diff --git a/src/broker_task.rs b/src/broker_task.rs index ec1c59b..5993275 100644 --- a/src/broker_task.rs +++ b/src/broker_task.rs @@ -9,8 +9,20 @@ use tokio::sync::{broadcast, mpsc, oneshot}; use crate::{message::Message, publication::Publication, subscription::Subscription}; +pub enum BrokerError { + MismatchedType, +} + +pub type BrokerResult = Result; + pub type ErasedSender = Box; +fn downcast_sender_ref(erased: &ErasedSender) -> BrokerResult<&broadcast::Sender> { + (**erased) + .downcast_ref::>() + .ok_or(BrokerError::MismatchedType) +} + pub trait MessageType { fn message_type_id(&self) -> TypeId; @@ -58,7 +70,7 @@ pub trait SubscribeRequest { /// `sender` must be `tokio::sync::broadcast::Sender` where /// `MessageType::get_message_type` returns the `TypeId` of `T`. - unsafe fn send_subscribe_response(self: Box, sender: &ErasedSender); + fn send_subscribe_response(self: Box, sender: &ErasedSender); } pub type SubscribeRequestBox = Box; @@ -66,7 +78,7 @@ pub type SubscribeRequestSender = mpsc::Sender; pub struct BasicSubscribeRequest { msg_type: BasicMessageType, - response_tx: oneshot::Sender>, + response_tx: oneshot::Sender>>, } impl SubscribeRequest for BasicSubscribeRequest { @@ -74,17 +86,17 @@ impl SubscribeRequest for BasicSubscribeRequest { &self.msg_type } - unsafe fn send_subscribe_response(self: Box, sender: &ErasedSender) { - let sender = &*(&**sender as *const dyn Any as *const broadcast::Sender); - // let sender = (**sender).downcast_ref::>().unwrap(); - let receiver = sender.subscribe(); - let subscription = Subscription::new(receiver); + fn send_subscribe_response(self: Box, sender: &ErasedSender) { + let subscription = downcast_sender_ref::(sender).map(|sender| { + let receiver = sender.subscribe(); + Subscription::new(receiver) + }); let _ = self.response_tx.send(subscription); } } impl BasicSubscribeRequest { - pub(crate) fn new() -> (Self, oneshot::Receiver>) { + pub(crate) fn new() -> (Self, oneshot::Receiver>>) { let (response_tx, response_rx) = oneshot::channel(); ( Self { @@ -101,7 +113,7 @@ pub trait AdvertiseRequest { /// `sender` must be `tokio::sync::broadcast::Sender` where /// `MessageType::get_message_type` returns the `TypeId` of `T`. - unsafe fn send_advertise_response(self: Box, sender: &ErasedSender); + fn send_advertise_response(self: Box, sender: &ErasedSender); } pub type AdvertiseRequestBox = Box; @@ -109,7 +121,7 @@ pub type AdvertiseRequestSender = mpsc::Sender; pub struct BasicAdvertiseRequest { msg_type: BasicMessageType, - response_tx: oneshot::Sender>, + response_tx: oneshot::Sender>>, } impl AdvertiseRequest for BasicAdvertiseRequest { @@ -117,16 +129,15 @@ impl AdvertiseRequest for BasicAdvertiseRequest { &self.msg_type } - unsafe fn send_advertise_response(self: Box, sender: &ErasedSender) { - let sender = &*(&**sender as *const dyn Any as *const broadcast::Sender); - // let sender = (**sender).downcast_ref::>().unwrap(); - let publication = Publication::new(sender.clone()); + fn send_advertise_response(self: Box, sender: &ErasedSender) { + let publication = + downcast_sender_ref::(sender).map(|sender| Publication::new(sender.clone())); let _ = self.response_tx.send(publication); } } impl BasicAdvertiseRequest { - pub(crate) fn new() -> (Self, oneshot::Receiver>) { + pub(crate) fn new() -> (Self, oneshot::Receiver>>) { let (response_tx, response_rx) = oneshot::channel(); ( Self { @@ -173,12 +184,12 @@ impl Registry { fn handle_subscribe(&mut self, subscribe_request: SubscribeRequestBox) { let sender = self.get_sender_for_type(subscribe_request.message_type()); - unsafe { subscribe_request.send_subscribe_response(sender) } + subscribe_request.send_subscribe_response(sender) } fn handle_advertise(&mut self, advertise_request: AdvertiseRequestBox) { let sender = self.get_sender_for_type(advertise_request.message_type()); - unsafe { advertise_request.send_advertise_response(sender) } + advertise_request.send_advertise_response(sender) } } diff --git a/src/lib.rs b/src/lib.rs index 3d13eee..4e5655c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,21 +4,40 @@ mod publication; mod subscription; use broker_task::{ - AdvertiseRequestSender, BasicAdvertiseRequest, BasicSubscribeRequest, BrokerTask, - SubscribeRequestSender, + AdvertiseRequestSender, BasicAdvertiseRequest, BasicSubscribeRequest, BrokerError, + BrokerResult, BrokerTask, SubscribeRequestSender, }; pub use message::Message; pub use publication::{Publication, PublishError}; pub use subscription::{Subscription, SubscriptionError}; use futures::executor::block_on; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; #[derive(Clone, Debug, PartialEq)] #[non_exhaustive] pub enum OrsbError { BrokerClosed, NoResponse, + MismatchedType, +} + +impl From for OrsbError { + fn from(err: BrokerError) -> Self { + match err { + BrokerError::MismatchedType => OrsbError::MismatchedType, + } + } +} + +fn map_broker_response( + result: Result, oneshot::error::RecvError>, +) -> Result { + match result { + Ok(Ok(value)) => Ok(value), + Ok(Err(err)) => Err(err.into()), + Err(_) => Err(OrsbError::NoResponse), + } } #[derive(Debug, Clone)] @@ -56,7 +75,7 @@ impl Orsb { .send(Box::new(subscribe_request)) .await .or(Err(OrsbError::BrokerClosed))?; - response_rx.await.or(Err(OrsbError::NoResponse)) + map_broker_response(response_rx.await) } pub async fn advertise(&mut self) -> Result, OrsbError> { @@ -65,7 +84,7 @@ impl Orsb { .send(Box::new(advertise_request)) .await .or(Err(OrsbError::BrokerClosed))?; - response_rx.await.or(Err(OrsbError::NoResponse)) + map_broker_response(response_rx.await) } pub fn subscribe_blocking(&mut self) -> Result, OrsbError> {