diff --git a/src/broker_task/mod.rs b/src/broker_task/mod.rs index 54fdf6f..219cf7b 100644 --- a/src/broker_task/mod.rs +++ b/src/broker_task/mod.rs @@ -8,7 +8,9 @@ mod request; pub(crate) use error::{BrokerError, BrokerResult}; use registry::Registry; pub(crate) use request::{ - AdvertiseRequest, BrokerRequestBox, BrokerRequestSender, SubscribeRequest, + AdvertiseRequest, BrokerRequestBox, BrokerRequestSender, ClaimReceiverError, + ClaimReceiverRequest, MakeBrokerRequest, SenderRequestNoWait, SenderRequestWait, + SubscribeRequest, }; pub(crate) struct BrokerTask { diff --git a/src/broker_task/registry.rs b/src/broker_task/registry.rs index 9efff2a..afc62d7 100644 --- a/src/broker_task/registry.rs +++ b/src/broker_task/registry.rs @@ -1,4 +1,5 @@ -use crate::Message; +use super::{BrokerError, BrokerResult}; +use crate::{Message, Sender}; use log::trace; use std::{ any::{Any, TypeId}, @@ -6,8 +7,7 @@ use std::{ collections::HashMap, marker::PhantomData, }; -use tokio::sync::broadcast; -use super::{BrokerError, BrokerResult}; +use tokio::sync::{broadcast, mpsc, oneshot}; pub trait MessageType { fn message_type_id(&self) -> TypeId; @@ -15,6 +15,8 @@ pub trait MessageType { fn message_type_name(&self) -> &'static str; fn create_broadcast_sender(&self) -> ErasedSender; + + fn create_mpsc_sender(&self) -> ErasedSender; } pub struct BasicMessageType { @@ -40,7 +42,7 @@ impl MessageType for BasicMessageType { fn create_broadcast_sender(&self) -> ErasedSender { trace!( - "Creating sender for {} ({:?})", + "Creating broadcast sender for {} ({:?})", std::any::type_name::(), MessageType::type_id(self) ); @@ -49,10 +51,64 @@ impl MessageType for BasicMessageType { let sender: ErasedSender = Box::new(sender); sender } + + fn create_mpsc_sender(&self) -> ErasedSender { + trace!( + "Creating mpsc sender for {} ({:?})", + std::any::type_name::(), + MessageType::type_id(self) + ); + Box::new(MpscSender::::new()) + } } pub type ErasedSender = Box; +pub struct MpscSender { + sender: mpsc::Sender, + waiting: Option>, +} + +struct MpscWaiting { + receiver: mpsc::Receiver, + waiting: Vec>>>, +} + +impl MpscSender { + fn new() -> Self { + // TODO: configurable queue size (per message?) + let (sender, receiver) = mpsc::channel(8); + MpscSender { + sender, + waiting: Some(MpscWaiting { + receiver, + waiting: Vec::default(), + }), + } + } + + pub fn clone_sender_or_wait(&mut self, wait: oneshot::Sender>>) { + if let Some(waiting) = &mut self.waiting { + waiting.waiting.push(wait); + } else { + let _ = wait.send(Ok(self.clone_sender())); + } + } + + pub fn clone_sender(&self) -> Sender { + Sender::new(self.sender.clone()) + } + + pub fn claim_receiver(&mut self) -> Option> { + self.waiting.take().map(|w| { + for response_tx in w.waiting.into_iter() { + let _ = response_tx.send(Ok(self.clone_sender())); + } + w.receiver + }) + } +} + #[derive(Debug)] pub struct TopicEntry { message_type_id: TypeId, @@ -74,25 +130,44 @@ impl Default for Registry { } impl Registry { - fn get_erased_sender_for_type(&mut self, message_type: &dyn MessageType) -> &ErasedSender { + fn get_erased_sender_for_type<'a, 'b>( + &'a mut self, + message_type: &'b dyn MessageType, + create_sender: fn(&'b dyn MessageType) -> ErasedSender, + ) -> &'a mut ErasedSender { let type_id = message_type.message_type_id(); let type_name = message_type.message_type_name(); let topic_entry = self.topics.entry(type_name.into()); let topic_entry = topic_entry.or_insert_with(|| TopicEntry { message_type_id: type_id, message_type_name: type_name.to_string(), - sender: message_type.create_broadcast_sender(), + sender: create_sender(message_type), }); - &topic_entry.sender + &mut topic_entry.sender } - pub fn get_sender_for_type( - &mut self, - message_type: &dyn MessageType, - ) -> BrokerResult<&broadcast::Sender> { - let erased = self.get_erased_sender_for_type(message_type); + fn get_sender_for_type<'a, 'b, T: 'static>( + &'a mut self, + message_type: &'b dyn MessageType, + create_sender: fn(&'b dyn MessageType) -> ErasedSender, + ) -> BrokerResult<&'a mut T> { + let erased = self.get_erased_sender_for_type(message_type, create_sender); (**erased) - .downcast_ref::>() + .downcast_mut::() .ok_or(BrokerError::MismatchedType) } + + pub fn get_broadcast_sender( + &mut self, + message_type: &dyn MessageType, + ) -> BrokerResult<&mut broadcast::Sender> { + self.get_sender_for_type(message_type, MessageType::create_broadcast_sender) + } + + pub fn get_mpsc_sender( + &mut self, + message_type: &dyn MessageType, + ) -> BrokerResult<&mut MpscSender> { + self.get_sender_for_type(message_type, MessageType::create_mpsc_sender) + } } diff --git a/src/broker_task/request.rs b/src/broker_task/request.rs index a7d9c0a..7451098 100644 --- a/src/broker_task/request.rs +++ b/src/broker_task/request.rs @@ -1,18 +1,30 @@ +use std::marker::PhantomData; use tokio::sync::{mpsc, oneshot}; use super::{ registry::{BasicMessageType, Registry}, - BrokerResult, + BrokerError, BrokerResult, }; -use crate::{Message, Publication, Subscription}; +use crate::{Message, Publication, Receiver, Sender, Subscription}; -pub trait BrokerRequestInternal { +pub trait BrokerRequestInternal: Send { fn run_request(self: Box, registry: &mut Registry); } -pub trait RequestHandler { +pub trait MakeBrokerRequest: BrokerRequestInternal { type Response; + fn create(response_tx: oneshot::Sender) -> Box; + + fn new() -> (Box, oneshot::Receiver) { + let (response_tx, response_rx) = oneshot::channel(); + (Self::create(response_tx), response_rx) + } +} + +pub trait RequestHandler: Send { + type Response: Send; + fn handle(self, registry: &mut Registry) -> Self::Response; } @@ -28,21 +40,18 @@ impl BrokerRequestInternal for BrokerRequest { } } -impl BrokerRequest { +impl MakeBrokerRequest for BrokerRequest { + type Response = H::Response; + fn create(response_tx: oneshot::Sender) -> Box { Box::new(Self { handler: H::default(), response_tx, }) } - - pub(crate) fn new() -> (Box, oneshot::Receiver) { - let (response_tx, response_rx) = oneshot::channel(); - (Self::create(response_tx), response_rx) - } } -pub type BrokerRequestBox = Box; +pub type BrokerRequestBox = Box; pub type BrokerRequestSender = mpsc::Sender; pub type SubscribeRequest = BrokerRequest>; @@ -62,7 +71,7 @@ impl RequestHandler for Subscribe { type Response = BrokerResult>; fn handle(self, registry: &mut Registry) -> Self::Response { - let sender = registry.get_sender_for_type::(&self.msg_type); + let sender = registry.get_broadcast_sender::(&self.msg_type); sender.map(|sender| { let receiver = sender.subscribe(); Subscription::new(receiver) @@ -87,7 +96,91 @@ impl RequestHandler for Advertise { type Response = BrokerResult>; fn handle(self, registry: &mut Registry) -> Self::Response { - let sender = registry.get_sender_for_type::(&self.msg_type); + let sender = registry.get_broadcast_sender::(&self.msg_type); sender.map(|sender| Publication::new(sender.clone())) } } + +pub type ClaimReceiverRequest = BrokerRequest>; + +pub enum ClaimReceiverError { + AlreadyClaimed, + Broker(BrokerError), +} + +pub struct ClaimReceiver { + msg_type: BasicMessageType, +} + +impl Default for ClaimReceiver { + fn default() -> Self { + Self { + msg_type: BasicMessageType::::default(), + } + } +} + +impl RequestHandler for ClaimReceiver { + type Response = Result, ClaimReceiverError>; + + fn handle(self, registry: &mut Registry) -> Self::Response { + registry + .get_mpsc_sender::(&self.msg_type) + .map_err(ClaimReceiverError::Broker) + .and_then(|sender| { + sender + .claim_receiver() + .ok_or(ClaimReceiverError::AlreadyClaimed) + }) + .map(Receiver::new) + } +} + +pub struct SenderRequest { + msg_type: BasicMessageType, + response_tx: oneshot::Sender>>, + _wait: PhantomData, +} + +pub trait GetSenderWait: Send { + const WAIT: bool; +} + +pub struct NoWait; +impl GetSenderWait for NoWait { + const WAIT: bool = false; +} +pub type SenderRequestNoWait = SenderRequest; + +pub struct Wait; +impl GetSenderWait for Wait { + const WAIT: bool = true; +} +pub type SenderRequestWait = SenderRequest; + +impl MakeBrokerRequest for SenderRequest { + type Response = BrokerResult>; + + fn create(response_tx: oneshot::Sender) -> Box { + Box::new(Self { + msg_type: BasicMessageType::default(), + response_tx, + _wait: PhantomData, + }) + } +} + +impl BrokerRequestInternal for SenderRequest { + fn run_request(self: Box, registry: &mut Registry) { + let response = match registry.get_mpsc_sender::(&self.msg_type) { + Ok(sender) => Ok(if W::WAIT { + sender.clone_sender_or_wait(self.response_tx); + return; + } else { + sender.clone_sender() + }), + Err(err) => Err(err), + }; + let _ = self.response_tx.send(response); + } +} diff --git a/src/lib.rs b/src/lib.rs index d14705a..7b6a8dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,16 @@ +#![warn(clippy::all)] + mod broker_task; mod message; -mod pub_sub; +pub mod mpsc; +pub mod pub_sub; -use broker_task::{ - AdvertiseRequest, BrokerError, BrokerRequestSender, BrokerResult, BrokerTask, SubscribeRequest, -}; +use broker_task::{BrokerError, BrokerRequestSender, BrokerResult, BrokerTask}; pub use message::Message; +pub use mpsc::*; pub use pub_sub::*; use futures::executor::block_on; -use tokio::sync::oneshot; #[derive(Clone, Debug, PartialEq)] #[non_exhaustive] @@ -27,9 +28,7 @@ impl From for OrsbError { } } -fn map_broker_response( - result: Result, oneshot::error::RecvError>, -) -> Result { +fn map_broker_response(result: Result, OrsbError>) -> Result { match result { Ok(Ok(value)) => Ok(value), Ok(Err(err)) => Err(err.into()), @@ -37,6 +36,22 @@ fn map_broker_response( } } +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum ClaimReceiverError { + AlreadyClaimed, + Other(OrsbError), +} + +impl From for ClaimReceiverError { + fn from(err: broker_task::ClaimReceiverError) -> Self { + match err { + broker_task::ClaimReceiverError::AlreadyClaimed => ClaimReceiverError::AlreadyClaimed, + broker_task::ClaimReceiverError::Broker(err) => ClaimReceiverError::Other(err.into()), + } + } +} + #[derive(Debug, Clone)] pub struct Orsb { request_tx: BrokerRequestSender, @@ -55,22 +70,29 @@ impl Orsb { (orsb, join_handle) } - pub async fn subscribe(&mut self) -> Result, OrsbError> { - let (subscribe_request, response_rx) = SubscribeRequest::::new(); + async fn make_request( + &mut self, + ) -> Result { + let (req, response_rx) = R::new(); self.request_tx - .send(subscribe_request) + .send(req) .await .or(Err(OrsbError::BrokerClosed))?; - map_broker_response(response_rx.await) + response_rx.await.map_err(|_| OrsbError::NoResponse) + } + + pub async fn subscribe(&mut self) -> Result, OrsbError> { + map_broker_response( + self.make_request::>() + .await, + ) } pub async fn advertise(&mut self) -> Result, OrsbError> { - let (advertise_request, response_rx) = AdvertiseRequest::::new(); - self.request_tx - .send(advertise_request) - .await - .or(Err(OrsbError::BrokerClosed))?; - map_broker_response(response_rx.await) + map_broker_response( + self.make_request::>() + .await, + ) } pub fn subscribe_blocking(&mut self) -> Result, OrsbError> { @@ -80,6 +102,41 @@ impl Orsb { pub fn advertise_blocking(&mut self) -> Result, OrsbError> { block_on(self.advertise::()) } + + pub async fn claim_receiver(&mut self) -> Result, ClaimReceiverError> { + self.make_request::>() + .await + .map_err(ClaimReceiverError::Other) + .and_then(|inner_result| inner_result.map_err(ClaimReceiverError::from)) + } + + pub async fn sender(&mut self) -> Result, OrsbError> { + map_broker_response( + self.make_request::>() + .await, + ) + } + + pub async fn wait_sender(&mut self) -> Result, OrsbError> { + map_broker_response( + self.make_request::>() + .await, + ) + } + + pub fn claim_receiver_blocking( + &mut self, + ) -> Result, ClaimReceiverError> { + block_on(self.claim_receiver::()) + } + + pub fn sender_blocking(&mut self) -> Result, OrsbError> { + block_on(self.sender::()) + } + + pub fn wait_sender_blocking(&mut self) -> Result, OrsbError> { + block_on(self.wait_sender::()) + } } #[cfg(test)] @@ -94,7 +151,7 @@ mod test { } #[tokio::test] - async fn test_sync() { + async fn test_pub_sub_sync() { init_logger(); let mut orsb = Orsb::start_new(); @@ -117,7 +174,7 @@ mod test { } #[tokio::test] - async fn test_async() { + async fn test_pub_sub_async() { init_logger(); let mut orsb = Orsb::start_new(); @@ -134,4 +191,54 @@ mod test { assert_eq!(sub.recv().await, Ok(TestMsg(30))); assert_eq!(sub.try_recv(), Err(SubscriptionError::Empty)); } + + #[tokio::test] + async fn test_mpsc_sync() { + init_logger(); + + let mut orsb = Orsb::start_new(); + + tokio::task::spawn_blocking(move || { + let mut recv = orsb.claim_receiver_blocking::().unwrap(); + assert_eq!( + orsb.claim_receiver_blocking::().unwrap_err(), + ClaimReceiverError::AlreadyClaimed + ); + let mut send = orsb.sender_blocking::().unwrap(); + + send.send_blocking(TestMsg(10)).unwrap(); + send.send_blocking(TestMsg(20)).unwrap(); + send.send_blocking(TestMsg(30)).unwrap(); + + assert_eq!(recv.recv_blocking(), Ok(TestMsg(10))); + assert_eq!(recv.recv_blocking(), Ok(TestMsg(20))); + assert_eq!(recv.recv_blocking(), Ok(TestMsg(30))); + assert_eq!(recv.try_recv(), Err(ReceiverError::Empty)); + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_mpsc_async() { + init_logger(); + + let mut orsb = Orsb::start_new(); + + let mut recv = orsb.claim_receiver::().await.unwrap(); + assert_eq!( + orsb.claim_receiver::().await.unwrap_err(), + ClaimReceiverError::AlreadyClaimed + ); + let mut send = orsb.sender::().await.unwrap(); + + send.send(TestMsg(10)).await.unwrap(); + send.send(TestMsg(20)).await.unwrap(); + send.send(TestMsg(30)).await.unwrap(); + + assert_eq!(recv.recv().await, Ok(TestMsg(10))); + assert_eq!(recv.recv().await, Ok(TestMsg(20))); + assert_eq!(recv.recv().await, Ok(TestMsg(30))); + assert_eq!(recv.try_recv(), Err(ReceiverError::Empty)); + } } diff --git a/src/mpsc/mod.rs b/src/mpsc/mod.rs new file mode 100644 index 0000000..d970be3 --- /dev/null +++ b/src/mpsc/mod.rs @@ -0,0 +1,5 @@ +mod receiver; +mod sender; + +pub use receiver::{Receiver, ReceiverError}; +pub use sender::{Sender, SenderError}; diff --git a/src/mpsc/receiver.rs b/src/mpsc/receiver.rs new file mode 100644 index 0000000..f58f666 --- /dev/null +++ b/src/mpsc/receiver.rs @@ -0,0 +1,53 @@ +use futures::FutureExt; +use tokio::sync::mpsc; + +use crate::message::Message; + +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum ReceiverError { + Empty, + Closed, +} + +impl From for ReceiverError { + fn from(_: mpsc::error::RecvError) -> Self { + ReceiverError::Closed + } +} + +// impl From for ReceiverError { +// fn from(err: mpsc::error::TryRecvError) -> Self { +// match err { +// mpsc::error::TryRecvError::Empty => ReceiverError::Empty, +// mpsc::error::TryRecvError::Closed => ReceiverError::Closed, +// mpsc::error::TryRecvError::Lagged(_) => ReceiverError::Lagged, +// } +// } +// } + +#[derive(Debug)] +pub struct Receiver { + receiver: mpsc::Receiver, +} + +impl Receiver { + pub(crate) fn new(receiver: mpsc::Receiver) -> Self { + Receiver { receiver } + } + + pub async fn recv(&mut self) -> Result { + self.receiver.recv().await.ok_or(ReceiverError::Closed) + } + + pub fn recv_blocking(&mut self) -> Result { + self.receiver.blocking_recv().ok_or(ReceiverError::Closed) + } + + pub fn try_recv(&mut self) -> Result { + // Ok(self.receiver.try_recv()?) + self.recv() + .now_or_never() + .unwrap_or(Err(ReceiverError::Empty)) + } +} diff --git a/src/mpsc/sender.rs b/src/mpsc/sender.rs new file mode 100644 index 0000000..1b19c71 --- /dev/null +++ b/src/mpsc/sender.rs @@ -0,0 +1,51 @@ +use tokio::sync::mpsc; + +use crate::message::Message; + +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum SenderError { + Closed, + Full +} + +impl From> for SenderError { + fn from(_: mpsc::error::SendError) -> Self { + SenderError::Closed + } +} + +impl From> for SenderError { + fn from(err: mpsc::error::TrySendError) -> Self { + match err { + mpsc::error::TrySendError::Full(_) => SenderError::Full, + mpsc::error::TrySendError::Closed(_) => SenderError::Closed + } + } +} + +#[derive(Clone, Debug)] +pub struct Sender { + sender: mpsc::Sender, +} + +impl Sender { + pub(crate) fn new(sender: mpsc::Sender) -> Self { + Sender { sender } + } + + pub async fn send(&mut self, message: T) -> Result<(), SenderError> { + self.sender.send(message).await?; + Ok(()) + } + + pub fn send_blocking(&mut self, message: T) -> Result<(), SenderError> { + self.sender.blocking_send(message)?; + Ok(()) + } + + pub fn try_send(&mut self, message: T) -> Result<(), SenderError> { + self.sender.try_send(message)?; + Ok(()) + } +}