diff --git a/Cargo.toml b/Cargo.toml index 17b89af..f7197f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,11 +7,16 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +futures = "0.3.8" +log = "0.4" [dependencies.tokio] version = "1" features = ["sync"] +[dev-dependencies] +env_logger = "0.8" + [dev-dependencies.tokio] version = "1" -features = ["rt", "macros"] +features = ["rt", "rt-multi-thread", "macros"] diff --git a/src/broker_task.rs b/src/broker_task.rs new file mode 100644 index 0000000..f0e1fcb --- /dev/null +++ b/src/broker_task.rs @@ -0,0 +1,194 @@ +use log::trace; +use std::{ + any::{Any, TypeId}, + collections::HashMap, + marker::PhantomData, +}; +use tokio::sync::{broadcast, mpsc, oneshot}; + +use crate::{message::Message, publication::Publication, subscription::Subscription}; + +pub type ErasedSender = Box; + +pub trait MessageType { + fn message_type_id(&self) -> TypeId; + + fn create_sender(&self) -> ErasedSender; +} + +pub struct BasicMessageType { + _phantom: PhantomData, +} + +impl Default for BasicMessageType { + fn default() -> Self { + BasicMessageType { + _phantom: PhantomData, + } + } +} + +impl MessageType for BasicMessageType { + fn message_type_id(&self) -> TypeId { + TypeId::of::() + } + + fn create_sender(&self) -> ErasedSender { + trace!( + "Creating sender for {} ({:?})", + std::any::type_name::(), + MessageType::type_id(self) + ); + // TODO: configurable queue size (per message?) + let (sender, _) = broadcast::channel::(8); + let sender: ErasedSender = Box::new(sender); + sender + } +} + +pub trait SubscribeRequest { + fn message_type(&self) -> &dyn MessageType; + + /// `sender` must be `tokio::sync::broadcast::Sender` where + /// `MessageType::get_message_type` returns the `TypeId` of `T`. + unsafe fn send_subscribe_response(self: Box, sender: &ErasedSender); +} + +pub type SubscribeRequestBox = Box; +pub type SubscribeRequestSender = mpsc::Sender; + +pub struct BasicSubscribeRequest { + msg_type: BasicMessageType, + response_tx: oneshot::Sender>, +} + +impl SubscribeRequest for BasicSubscribeRequest { + fn message_type(&self) -> &dyn MessageType { + &self.msg_type + } + + unsafe fn send_subscribe_response(self: Box, sender: &ErasedSender) { + let sender = &*(&**sender as *const dyn Any as *const broadcast::Sender); + // let sender = (**sender).downcast_ref::>().unwrap(); + let receiver = sender.subscribe(); + let subscription = Subscription::new(receiver); + let _ = self.response_tx.send(subscription); + } +} + +impl BasicSubscribeRequest { + pub(crate) fn new() -> (Self, oneshot::Receiver>) { + let (response_tx, response_rx) = oneshot::channel(); + ( + Self { + msg_type: Default::default(), + response_tx, + }, + response_rx, + ) + } +} + +pub trait AdvertiseRequest { + fn message_type(&self) -> &dyn MessageType; + + /// `sender` must be `tokio::sync::broadcast::Sender` where + /// `MessageType::get_message_type` returns the `TypeId` of `T`. + unsafe fn send_advertise_response(self: Box, sender: &ErasedSender); +} + +pub type AdvertiseRequestBox = Box; +pub type AdvertiseRequestSender = mpsc::Sender; + +pub struct BasicAdvertiseRequest { + msg_type: BasicMessageType, + response_tx: oneshot::Sender>, +} + +impl AdvertiseRequest for BasicAdvertiseRequest { + fn message_type(&self) -> &dyn MessageType { + &self.msg_type + } + + unsafe fn send_advertise_response(self: Box, sender: &ErasedSender) { + let sender = &*(&**sender as *const dyn Any as *const broadcast::Sender); + // let sender = (**sender).downcast_ref::>().unwrap(); + let publication = Publication::new(sender.clone()); + let _ = self.response_tx.send(publication); + } +} + +impl BasicAdvertiseRequest { + pub(crate) fn new() -> (Self, oneshot::Receiver>) { + let (response_tx, response_rx) = oneshot::channel(); + ( + Self { + msg_type: Default::default(), + response_tx, + }, + response_rx, + ) + } +} + +struct Registry { + senders: HashMap, +} + +impl Default for Registry { + fn default() -> Self { + Registry { + senders: HashMap::new(), + } + } +} + +impl Registry { + fn get_sender_for_type(&mut self, message_type: &dyn MessageType) -> &ErasedSender { + let type_id = message_type.message_type_id(); + let sender_entry = self.senders.entry(type_id); + sender_entry.or_insert_with(|| message_type.create_sender()) + } + + fn handle_subscribe(&mut self, subscribe_request: SubscribeRequestBox) { + let sender = self.get_sender_for_type(subscribe_request.message_type()); + unsafe { subscribe_request.send_subscribe_response(sender) } + } + + fn handle_advertise(&mut self, advertise_request: AdvertiseRequestBox) { + let sender = self.get_sender_for_type(advertise_request.message_type()); + unsafe { advertise_request.send_advertise_response(sender) } + } +} + +pub struct BrokerTask { + pub(crate) subscribe_rx: mpsc::Receiver, + pub(crate) advertise_rx: mpsc::Receiver, +} + +impl BrokerTask { + pub(crate) async fn run(mut self) { + trace!("BrokerTask starting"); + + let mut registry = Registry::default(); + + loop { + tokio::select! { + Some(subscribe_req) = self.subscribe_rx.recv() => { + registry.handle_subscribe(subscribe_req) + } + Some(advertise_req) = self.advertise_rx.recv() => { + registry.handle_advertise(advertise_req) + } + } + } + + // trace!("BrokerTask exiting"); + } +} + +impl Drop for BrokerTask { + fn drop(&mut self) { + trace!("BrokerTask dropped"); + } +} diff --git a/src/lib.rs b/src/lib.rs index ea8cd08..ec627fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,40 +1,79 @@ +mod broker_task; mod message; mod publication; mod subscription; +use broker_task::{ + AdvertiseRequestSender, BasicAdvertiseRequest, BasicSubscribeRequest, BrokerTask, + SubscribeRequestSender, +}; pub use message::Message; pub use publication::{Publication, SendError}; pub use subscription::{RecvError, Subscription}; -#[derive(Clone, Debug, PartialEq)] -#[non_exhaustive] -pub enum SubscribeError {} +use futures::executor::block_on; +use tokio::sync::mpsc; #[derive(Clone, Debug, PartialEq)] #[non_exhaustive] -pub enum AdvertiseError {} +pub enum OrsbError { + BrokerClosed, + NoResponse, +} -pub struct Orsb {} +#[derive(Debug, Clone)] +pub struct Orsb { + subscribe_tx: SubscribeRequestSender, + advertise_tx: AdvertiseRequestSender, +} impl Orsb { - pub fn new() -> Self { - todo!() + pub fn start_new() -> Self { + Self::start_new2().0 + } + + pub(crate) fn start_new2() -> (Self, tokio::task::JoinHandle<()>) { + let (subscribe_tx, subscribe_rx) = mpsc::channel(8); + let (advertise_tx, advertise_rx) = mpsc::channel(8); + + let broker = BrokerTask { + subscribe_rx, + advertise_rx, + }; + let join_handle = tokio::spawn(broker.run()); + + let orsb = Orsb { + subscribe_tx, + advertise_tx, + }; + + (orsb, join_handle) } - pub async fn subscribe(&mut self) -> Result, SubscribeError> { - todo!() + pub async fn subscribe(&mut self) -> Result, OrsbError> { + let (subscribe_request, response_rx) = BasicSubscribeRequest::::new(); + self.subscribe_tx + .send(Box::new(subscribe_request)) + .await + .or(Err(OrsbError::BrokerClosed))?; + response_rx.await.or(Err(OrsbError::NoResponse)) } - pub async fn advertise(&mut self) -> Result, AdvertiseError> { - todo!() + pub async fn advertise(&mut self) -> Result, OrsbError> { + let (advertise_request, response_rx) = BasicAdvertiseRequest::::new(); + self.advertise_tx + .send(Box::new(advertise_request)) + .await + .or(Err(OrsbError::BrokerClosed))?; + response_rx.await.or(Err(OrsbError::NoResponse)) } - pub fn subscribe_blocking(&mut self) -> Result, SubscribeError> { - todo!() + pub fn subscribe_blocking(&mut self) -> Result, OrsbError> { + block_on(self.subscribe::()) } - pub fn advertise_blocking(&mut self) -> Result, AdvertiseError> { - todo!() + pub fn advertise_blocking(&mut self) -> Result, OrsbError> { + block_on(self.advertise::()) } } @@ -45,20 +84,49 @@ mod test { #[derive(Clone, Debug, PartialEq)] struct TestMsg(u8); - #[test] - fn test_sync() { - let mut orsb = Orsb::new(); + fn init_logger() { + let _ = env_logger::try_init(); + } + + #[tokio::test] + async fn test_sync() { + init_logger(); + + let mut orsb = Orsb::start_new(); + + tokio::task::spawn_blocking(move || { + let mut sub = orsb.subscribe_blocking::().unwrap(); + let mut publ = orsb.advertise_blocking::().unwrap(); + + publ.send(TestMsg(10)).unwrap(); + publ.send(TestMsg(20)).unwrap(); + publ.send(TestMsg(30)).unwrap(); + + assert_eq!(sub.try_recv(), Ok(TestMsg(10))); + assert_eq!(sub.try_recv(), Ok(TestMsg(20))); + assert_eq!(sub.try_recv(), Ok(TestMsg(30))); + assert_eq!(sub.try_recv(), Err(RecvError::Empty)); + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_async() { + init_logger(); + + let mut orsb = Orsb::start_new(); - let mut sub = orsb.subscribe_blocking::().unwrap(); - let mut publ = orsb.advertise_blocking::().unwrap(); + let mut sub = orsb.subscribe::().await.unwrap(); + let mut publ = orsb.advertise::().await.unwrap(); publ.send(TestMsg(10)).unwrap(); publ.send(TestMsg(20)).unwrap(); publ.send(TestMsg(30)).unwrap(); - assert_eq!(sub.recv_blocking(), Ok(TestMsg(10))); - assert_eq!(sub.recv_blocking(), Ok(TestMsg(20))); - assert_eq!(sub.recv_blocking(), Ok(TestMsg(30))); + assert_eq!(sub.recv().await, Ok(TestMsg(10))); + assert_eq!(sub.recv().await, Ok(TestMsg(20))); + assert_eq!(sub.recv().await, Ok(TestMsg(30))); assert_eq!(sub.try_recv(), Err(RecvError::Empty)); } } diff --git a/src/message.rs b/src/message.rs index 093d7fa..0028b2d 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,3 +1,5 @@ -pub trait Message: 'static + Clone + Send + Sync {} +use std::any::Any; + +pub trait Message: 'static + Any + Clone + Send + Sync {} impl Message for T where T: 'static + Clone + Send + Sync {} diff --git a/src/publication.rs b/src/publication.rs index dc1ccaa..50c5ca5 100644 --- a/src/publication.rs +++ b/src/publication.rs @@ -1,20 +1,25 @@ -use std::marker::PhantomData; +use tokio::sync::broadcast; use crate::message::Message; #[derive(Clone, Debug, PartialEq)] #[non_exhaustive] -pub enum SendError { - NoListeners, -} +pub enum SendError {} +#[derive(Debug)] pub struct Publication { - _phantom: PhantomData, + sender: broadcast::Sender, } impl Publication { - pub fn send(&mut self, message: T) -> Result<(), SendError> { - let _ = message; - todo!() + pub(crate) fn new(sender: broadcast::Sender) -> Self { + Publication { sender } + } + + pub fn send(&mut self, message: T) -> Result { + match self.sender.send(message) { + Ok(subscribers) => Ok(subscribers), + Err(_) => Ok(0), + } } } diff --git a/src/subscription.rs b/src/subscription.rs index 6c338e9..692b1fb 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -1,4 +1,5 @@ -use std::marker::PhantomData; +use futures::executor::block_on; +use tokio::sync::broadcast; use crate::message::Message; @@ -10,20 +11,46 @@ pub enum RecvError { Lagged, } +impl From for RecvError { + fn from(err: broadcast::error::RecvError) -> Self { + match err { + broadcast::error::RecvError::Closed => RecvError::Closed, + broadcast::error::RecvError::Lagged(_) => RecvError::Lagged, + } + } +} + +impl From for RecvError { + fn from(err: broadcast::error::TryRecvError) -> Self { + match err { + broadcast::error::TryRecvError::Empty => RecvError::Empty, + broadcast::error::TryRecvError::Closed => RecvError::Closed, + broadcast::error::TryRecvError::Lagged(_) => RecvError::Lagged, + } + } +} + +#[derive(Debug)] pub struct Subscription { - _phantom: PhantomData, + receiver: broadcast::Receiver, +} + +impl Subscription { + pub(crate) fn new(receiver: broadcast::Receiver) -> Self { + Subscription { receiver } + } } impl Subscription { pub async fn recv(&mut self) -> Result { - todo!() + Ok(self.receiver.recv().await?) } pub fn recv_blocking(&mut self) -> Result { - todo!() + block_on(self.recv()) } pub fn try_recv(&mut self) -> Result { - todo!() + Ok(self.receiver.try_recv()?) } }