Browse Source

Add MPSC channels

master
Alex Mikhalev 4 years ago
parent
commit
1611572c13
  1. 4
      src/broker_task/mod.rs
  2. 101
      src/broker_task/registry.rs
  3. 119
      src/broker_task/request.rs
  4. 147
      src/lib.rs
  5. 5
      src/mpsc/mod.rs
  6. 53
      src/mpsc/receiver.rs
  7. 51
      src/mpsc/sender.rs

4
src/broker_task/mod.rs

@ -8,7 +8,9 @@ mod request;
pub(crate) use error::{BrokerError, BrokerResult}; pub(crate) use error::{BrokerError, BrokerResult};
use registry::Registry; use registry::Registry;
pub(crate) use request::{ pub(crate) use request::{
AdvertiseRequest, BrokerRequestBox, BrokerRequestSender, SubscribeRequest, AdvertiseRequest, BrokerRequestBox, BrokerRequestSender, ClaimReceiverError,
ClaimReceiverRequest, MakeBrokerRequest, SenderRequestNoWait, SenderRequestWait,
SubscribeRequest,
}; };
pub(crate) struct BrokerTask { pub(crate) struct BrokerTask {

101
src/broker_task/registry.rs

@ -1,4 +1,5 @@
use crate::Message; use super::{BrokerError, BrokerResult};
use crate::{Message, Sender};
use log::trace; use log::trace;
use std::{ use std::{
any::{Any, TypeId}, any::{Any, TypeId},
@ -6,8 +7,7 @@ use std::{
collections::HashMap, collections::HashMap,
marker::PhantomData, marker::PhantomData,
}; };
use tokio::sync::broadcast; use tokio::sync::{broadcast, mpsc, oneshot};
use super::{BrokerError, BrokerResult};
pub trait MessageType { pub trait MessageType {
fn message_type_id(&self) -> TypeId; fn message_type_id(&self) -> TypeId;
@ -15,6 +15,8 @@ pub trait MessageType {
fn message_type_name(&self) -> &'static str; fn message_type_name(&self) -> &'static str;
fn create_broadcast_sender(&self) -> ErasedSender; fn create_broadcast_sender(&self) -> ErasedSender;
fn create_mpsc_sender(&self) -> ErasedSender;
} }
pub struct BasicMessageType<T> { pub struct BasicMessageType<T> {
@ -40,7 +42,7 @@ impl<T: Message> MessageType for BasicMessageType<T> {
fn create_broadcast_sender(&self) -> ErasedSender { fn create_broadcast_sender(&self) -> ErasedSender {
trace!( trace!(
"Creating sender for {} ({:?})", "Creating broadcast sender for {} ({:?})",
std::any::type_name::<T>(), std::any::type_name::<T>(),
MessageType::type_id(self) MessageType::type_id(self)
); );
@ -49,10 +51,64 @@ impl<T: Message> MessageType for BasicMessageType<T> {
let sender: ErasedSender = Box::new(sender); let sender: ErasedSender = Box::new(sender);
sender sender
} }
fn create_mpsc_sender(&self) -> ErasedSender {
trace!(
"Creating mpsc sender for {} ({:?})",
std::any::type_name::<T>(),
MessageType::type_id(self)
);
Box::new(MpscSender::<T>::new())
}
} }
pub type ErasedSender = Box<dyn Any + Send + Sync>; pub type ErasedSender = Box<dyn Any + Send + Sync>;
pub struct MpscSender<T> {
sender: mpsc::Sender<T>,
waiting: Option<MpscWaiting<T>>,
}
struct MpscWaiting<T> {
receiver: mpsc::Receiver<T>,
waiting: Vec<oneshot::Sender<BrokerResult<Sender<T>>>>,
}
impl<T: Message> MpscSender<T> {
fn new() -> Self {
// TODO: configurable queue size (per message?)
let (sender, receiver) = mpsc::channel(8);
MpscSender {
sender,
waiting: Some(MpscWaiting {
receiver,
waiting: Vec::default(),
}),
}
}
pub fn clone_sender_or_wait(&mut self, wait: oneshot::Sender<BrokerResult<Sender<T>>>) {
if let Some(waiting) = &mut self.waiting {
waiting.waiting.push(wait);
} else {
let _ = wait.send(Ok(self.clone_sender()));
}
}
pub fn clone_sender(&self) -> Sender<T> {
Sender::new(self.sender.clone())
}
pub fn claim_receiver(&mut self) -> Option<mpsc::Receiver<T>> {
self.waiting.take().map(|w| {
for response_tx in w.waiting.into_iter() {
let _ = response_tx.send(Ok(self.clone_sender()));
}
w.receiver
})
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct TopicEntry { pub struct TopicEntry {
message_type_id: TypeId, message_type_id: TypeId,
@ -74,25 +130,44 @@ impl Default for Registry {
} }
impl Registry { impl Registry {
fn get_erased_sender_for_type(&mut self, message_type: &dyn MessageType) -> &ErasedSender { fn get_erased_sender_for_type<'a, 'b>(
&'a mut self,
message_type: &'b dyn MessageType,
create_sender: fn(&'b dyn MessageType) -> ErasedSender,
) -> &'a mut ErasedSender {
let type_id = message_type.message_type_id(); let type_id = message_type.message_type_id();
let type_name = message_type.message_type_name(); let type_name = message_type.message_type_name();
let topic_entry = self.topics.entry(type_name.into()); let topic_entry = self.topics.entry(type_name.into());
let topic_entry = topic_entry.or_insert_with(|| TopicEntry { let topic_entry = topic_entry.or_insert_with(|| TopicEntry {
message_type_id: type_id, message_type_id: type_id,
message_type_name: type_name.to_string(), message_type_name: type_name.to_string(),
sender: message_type.create_broadcast_sender(), sender: create_sender(message_type),
}); });
&topic_entry.sender &mut topic_entry.sender
} }
pub fn get_sender_for_type<T: Message>( fn get_sender_for_type<'a, 'b, T: 'static>(
&mut self, &'a mut self,
message_type: &dyn MessageType, message_type: &'b dyn MessageType,
) -> BrokerResult<&broadcast::Sender<T>> { create_sender: fn(&'b dyn MessageType) -> ErasedSender,
let erased = self.get_erased_sender_for_type(message_type); ) -> BrokerResult<&'a mut T> {
let erased = self.get_erased_sender_for_type(message_type, create_sender);
(**erased) (**erased)
.downcast_ref::<broadcast::Sender<T>>() .downcast_mut::<T>()
.ok_or(BrokerError::MismatchedType) .ok_or(BrokerError::MismatchedType)
} }
pub fn get_broadcast_sender<T: Message>(
&mut self,
message_type: &dyn MessageType,
) -> BrokerResult<&mut broadcast::Sender<T>> {
self.get_sender_for_type(message_type, MessageType::create_broadcast_sender)
}
pub fn get_mpsc_sender<T: Message>(
&mut self,
message_type: &dyn MessageType,
) -> BrokerResult<&mut MpscSender<T>> {
self.get_sender_for_type(message_type, MessageType::create_mpsc_sender)
}
} }

119
src/broker_task/request.rs

@ -1,18 +1,30 @@
use std::marker::PhantomData;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use super::{ use super::{
registry::{BasicMessageType, Registry}, registry::{BasicMessageType, Registry},
BrokerResult, BrokerError, BrokerResult,
}; };
use crate::{Message, Publication, Subscription}; use crate::{Message, Publication, Receiver, Sender, Subscription};
pub trait BrokerRequestInternal { pub trait BrokerRequestInternal: Send {
fn run_request(self: Box<Self>, registry: &mut Registry); fn run_request(self: Box<Self>, registry: &mut Registry);
} }
pub trait RequestHandler { pub trait MakeBrokerRequest: BrokerRequestInternal {
type Response; type Response;
fn create(response_tx: oneshot::Sender<Self::Response>) -> Box<Self>;
fn new() -> (Box<Self>, oneshot::Receiver<Self::Response>) {
let (response_tx, response_rx) = oneshot::channel();
(Self::create(response_tx), response_rx)
}
}
pub trait RequestHandler: Send {
type Response: Send;
fn handle(self, registry: &mut Registry) -> Self::Response; fn handle(self, registry: &mut Registry) -> Self::Response;
} }
@ -28,21 +40,18 @@ impl<H: RequestHandler> BrokerRequestInternal for BrokerRequest<H> {
} }
} }
impl<H: RequestHandler + Default> BrokerRequest<H> { impl<H: RequestHandler + Default> MakeBrokerRequest for BrokerRequest<H> {
type Response = H::Response;
fn create(response_tx: oneshot::Sender<H::Response>) -> Box<Self> { fn create(response_tx: oneshot::Sender<H::Response>) -> Box<Self> {
Box::new(Self { Box::new(Self {
handler: H::default(), handler: H::default(),
response_tx, response_tx,
}) })
} }
pub(crate) fn new() -> (Box<Self>, oneshot::Receiver<H::Response>) {
let (response_tx, response_rx) = oneshot::channel();
(Self::create(response_tx), response_rx)
}
} }
pub type BrokerRequestBox = Box<dyn BrokerRequestInternal + Send + Sync>; pub type BrokerRequestBox = Box<dyn BrokerRequestInternal>;
pub type BrokerRequestSender = mpsc::Sender<BrokerRequestBox>; pub type BrokerRequestSender = mpsc::Sender<BrokerRequestBox>;
pub type SubscribeRequest<T> = BrokerRequest<Subscribe<T>>; pub type SubscribeRequest<T> = BrokerRequest<Subscribe<T>>;
@ -62,7 +71,7 @@ impl<T: Message> RequestHandler for Subscribe<T> {
type Response = BrokerResult<Subscription<T>>; type Response = BrokerResult<Subscription<T>>;
fn handle(self, registry: &mut Registry) -> Self::Response { fn handle(self, registry: &mut Registry) -> Self::Response {
let sender = registry.get_sender_for_type::<T>(&self.msg_type); let sender = registry.get_broadcast_sender::<T>(&self.msg_type);
sender.map(|sender| { sender.map(|sender| {
let receiver = sender.subscribe(); let receiver = sender.subscribe();
Subscription::new(receiver) Subscription::new(receiver)
@ -87,7 +96,91 @@ impl<T: Message> RequestHandler for Advertise<T> {
type Response = BrokerResult<Publication<T>>; type Response = BrokerResult<Publication<T>>;
fn handle(self, registry: &mut Registry) -> Self::Response { fn handle(self, registry: &mut Registry) -> Self::Response {
let sender = registry.get_sender_for_type::<T>(&self.msg_type); let sender = registry.get_broadcast_sender::<T>(&self.msg_type);
sender.map(|sender| Publication::new(sender.clone())) sender.map(|sender| Publication::new(sender.clone()))
} }
} }
pub type ClaimReceiverRequest<T> = BrokerRequest<ClaimReceiver<T>>;
pub enum ClaimReceiverError {
AlreadyClaimed,
Broker(BrokerError),
}
pub struct ClaimReceiver<T: Message> {
msg_type: BasicMessageType<T>,
}
impl<T: Message> Default for ClaimReceiver<T> {
fn default() -> Self {
Self {
msg_type: BasicMessageType::<T>::default(),
}
}
}
impl<T: Message> RequestHandler for ClaimReceiver<T> {
type Response = Result<Receiver<T>, ClaimReceiverError>;
fn handle(self, registry: &mut Registry) -> Self::Response {
registry
.get_mpsc_sender::<T>(&self.msg_type)
.map_err(ClaimReceiverError::Broker)
.and_then(|sender| {
sender
.claim_receiver()
.ok_or(ClaimReceiverError::AlreadyClaimed)
})
.map(Receiver::new)
}
}
pub struct SenderRequest<T, W> {
msg_type: BasicMessageType<T>,
response_tx: oneshot::Sender<BrokerResult<Sender<T>>>,
_wait: PhantomData<W>,
}
pub trait GetSenderWait: Send {
const WAIT: bool;
}
pub struct NoWait;
impl GetSenderWait for NoWait {
const WAIT: bool = false;
}
pub type SenderRequestNoWait<T> = SenderRequest<T, NoWait>;
pub struct Wait;
impl GetSenderWait for Wait {
const WAIT: bool = true;
}
pub type SenderRequestWait<T> = SenderRequest<T, Wait>;
impl<T: Message, W: GetSenderWait> MakeBrokerRequest for SenderRequest<T, W> {
type Response = BrokerResult<Sender<T>>;
fn create(response_tx: oneshot::Sender<Self::Response>) -> Box<Self> {
Box::new(Self {
msg_type: BasicMessageType::default(),
response_tx,
_wait: PhantomData,
})
}
}
impl<T: Message, W: GetSenderWait> BrokerRequestInternal for SenderRequest<T, W> {
fn run_request(self: Box<Self>, registry: &mut Registry) {
let response = match registry.get_mpsc_sender::<T>(&self.msg_type) {
Ok(sender) => Ok(if W::WAIT {
sender.clone_sender_or_wait(self.response_tx);
return;
} else {
sender.clone_sender()
}),
Err(err) => Err(err),
};
let _ = self.response_tx.send(response);
}
}

147
src/lib.rs

@ -1,15 +1,16 @@
#![warn(clippy::all)]
mod broker_task; mod broker_task;
mod message; mod message;
mod pub_sub; pub mod mpsc;
pub mod pub_sub;
use broker_task::{ use broker_task::{BrokerError, BrokerRequestSender, BrokerResult, BrokerTask};
AdvertiseRequest, BrokerError, BrokerRequestSender, BrokerResult, BrokerTask, SubscribeRequest,
};
pub use message::Message; pub use message::Message;
pub use mpsc::*;
pub use pub_sub::*; pub use pub_sub::*;
use futures::executor::block_on; use futures::executor::block_on;
use tokio::sync::oneshot;
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
#[non_exhaustive] #[non_exhaustive]
@ -27,9 +28,7 @@ impl From<BrokerError> for OrsbError {
} }
} }
fn map_broker_response<T>( fn map_broker_response<T>(result: Result<BrokerResult<T>, OrsbError>) -> Result<T, OrsbError> {
result: Result<BrokerResult<T>, oneshot::error::RecvError>,
) -> Result<T, OrsbError> {
match result { match result {
Ok(Ok(value)) => Ok(value), Ok(Ok(value)) => Ok(value),
Ok(Err(err)) => Err(err.into()), Ok(Err(err)) => Err(err.into()),
@ -37,6 +36,22 @@ fn map_broker_response<T>(
} }
} }
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum ClaimReceiverError {
AlreadyClaimed,
Other(OrsbError),
}
impl From<broker_task::ClaimReceiverError> for ClaimReceiverError {
fn from(err: broker_task::ClaimReceiverError) -> Self {
match err {
broker_task::ClaimReceiverError::AlreadyClaimed => ClaimReceiverError::AlreadyClaimed,
broker_task::ClaimReceiverError::Broker(err) => ClaimReceiverError::Other(err.into()),
}
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Orsb { pub struct Orsb {
request_tx: BrokerRequestSender, request_tx: BrokerRequestSender,
@ -55,22 +70,29 @@ impl Orsb {
(orsb, join_handle) (orsb, join_handle)
} }
pub async fn subscribe<T: Message>(&mut self) -> Result<Subscription<T>, OrsbError> { async fn make_request<R: broker_task::MakeBrokerRequest + 'static>(
let (subscribe_request, response_rx) = SubscribeRequest::<T>::new(); &mut self,
) -> Result<R::Response, OrsbError> {
let (req, response_rx) = R::new();
self.request_tx self.request_tx
.send(subscribe_request) .send(req)
.await .await
.or(Err(OrsbError::BrokerClosed))?; .or(Err(OrsbError::BrokerClosed))?;
map_broker_response(response_rx.await) response_rx.await.map_err(|_| OrsbError::NoResponse)
}
pub async fn subscribe<T: Message>(&mut self) -> Result<Subscription<T>, OrsbError> {
map_broker_response(
self.make_request::<broker_task::SubscribeRequest<T>>()
.await,
)
} }
pub async fn advertise<T: Message>(&mut self) -> Result<Publication<T>, OrsbError> { pub async fn advertise<T: Message>(&mut self) -> Result<Publication<T>, OrsbError> {
let (advertise_request, response_rx) = AdvertiseRequest::<T>::new(); map_broker_response(
self.request_tx self.make_request::<broker_task::AdvertiseRequest<T>>()
.send(advertise_request) .await,
.await )
.or(Err(OrsbError::BrokerClosed))?;
map_broker_response(response_rx.await)
} }
pub fn subscribe_blocking<T: Message>(&mut self) -> Result<Subscription<T>, OrsbError> { pub fn subscribe_blocking<T: Message>(&mut self) -> Result<Subscription<T>, OrsbError> {
@ -80,6 +102,41 @@ impl Orsb {
pub fn advertise_blocking<T: Message>(&mut self) -> Result<Publication<T>, OrsbError> { pub fn advertise_blocking<T: Message>(&mut self) -> Result<Publication<T>, OrsbError> {
block_on(self.advertise::<T>()) block_on(self.advertise::<T>())
} }
pub async fn claim_receiver<T: Message>(&mut self) -> Result<Receiver<T>, ClaimReceiverError> {
self.make_request::<broker_task::ClaimReceiverRequest<T>>()
.await
.map_err(ClaimReceiverError::Other)
.and_then(|inner_result| inner_result.map_err(ClaimReceiverError::from))
}
pub async fn sender<T: Message>(&mut self) -> Result<Sender<T>, OrsbError> {
map_broker_response(
self.make_request::<broker_task::SenderRequestNoWait<T>>()
.await,
)
}
pub async fn wait_sender<T: Message>(&mut self) -> Result<Sender<T>, OrsbError> {
map_broker_response(
self.make_request::<broker_task::SenderRequestWait<T>>()
.await,
)
}
pub fn claim_receiver_blocking<T: Message>(
&mut self,
) -> Result<Receiver<T>, ClaimReceiverError> {
block_on(self.claim_receiver::<T>())
}
pub fn sender_blocking<T: Message>(&mut self) -> Result<Sender<T>, OrsbError> {
block_on(self.sender::<T>())
}
pub fn wait_sender_blocking<T: Message>(&mut self) -> Result<Sender<T>, OrsbError> {
block_on(self.wait_sender::<T>())
}
} }
#[cfg(test)] #[cfg(test)]
@ -94,7 +151,7 @@ mod test {
} }
#[tokio::test] #[tokio::test]
async fn test_sync() { async fn test_pub_sub_sync() {
init_logger(); init_logger();
let mut orsb = Orsb::start_new(); let mut orsb = Orsb::start_new();
@ -117,7 +174,7 @@ mod test {
} }
#[tokio::test] #[tokio::test]
async fn test_async() { async fn test_pub_sub_async() {
init_logger(); init_logger();
let mut orsb = Orsb::start_new(); let mut orsb = Orsb::start_new();
@ -134,4 +191,54 @@ mod test {
assert_eq!(sub.recv().await, Ok(TestMsg(30))); assert_eq!(sub.recv().await, Ok(TestMsg(30)));
assert_eq!(sub.try_recv(), Err(SubscriptionError::Empty)); assert_eq!(sub.try_recv(), Err(SubscriptionError::Empty));
} }
#[tokio::test]
async fn test_mpsc_sync() {
init_logger();
let mut orsb = Orsb::start_new();
tokio::task::spawn_blocking(move || {
let mut recv = orsb.claim_receiver_blocking::<TestMsg>().unwrap();
assert_eq!(
orsb.claim_receiver_blocking::<TestMsg>().unwrap_err(),
ClaimReceiverError::AlreadyClaimed
);
let mut send = orsb.sender_blocking::<TestMsg>().unwrap();
send.send_blocking(TestMsg(10)).unwrap();
send.send_blocking(TestMsg(20)).unwrap();
send.send_blocking(TestMsg(30)).unwrap();
assert_eq!(recv.recv_blocking(), Ok(TestMsg(10)));
assert_eq!(recv.recv_blocking(), Ok(TestMsg(20)));
assert_eq!(recv.recv_blocking(), Ok(TestMsg(30)));
assert_eq!(recv.try_recv(), Err(ReceiverError::Empty));
})
.await
.unwrap();
}
#[tokio::test]
async fn test_mpsc_async() {
init_logger();
let mut orsb = Orsb::start_new();
let mut recv = orsb.claim_receiver::<TestMsg>().await.unwrap();
assert_eq!(
orsb.claim_receiver::<TestMsg>().await.unwrap_err(),
ClaimReceiverError::AlreadyClaimed
);
let mut send = orsb.sender::<TestMsg>().await.unwrap();
send.send(TestMsg(10)).await.unwrap();
send.send(TestMsg(20)).await.unwrap();
send.send(TestMsg(30)).await.unwrap();
assert_eq!(recv.recv().await, Ok(TestMsg(10)));
assert_eq!(recv.recv().await, Ok(TestMsg(20)));
assert_eq!(recv.recv().await, Ok(TestMsg(30)));
assert_eq!(recv.try_recv(), Err(ReceiverError::Empty));
}
} }

5
src/mpsc/mod.rs

@ -0,0 +1,5 @@
mod receiver;
mod sender;
pub use receiver::{Receiver, ReceiverError};
pub use sender::{Sender, SenderError};

53
src/mpsc/receiver.rs

@ -0,0 +1,53 @@
use futures::FutureExt;
use tokio::sync::mpsc;
use crate::message::Message;
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum ReceiverError {
Empty,
Closed,
}
impl From<mpsc::error::RecvError> for ReceiverError {
fn from(_: mpsc::error::RecvError) -> Self {
ReceiverError::Closed
}
}
// impl From<mpsc::error::TryRecvError> for ReceiverError {
// fn from(err: mpsc::error::TryRecvError) -> Self {
// match err {
// mpsc::error::TryRecvError::Empty => ReceiverError::Empty,
// mpsc::error::TryRecvError::Closed => ReceiverError::Closed,
// mpsc::error::TryRecvError::Lagged(_) => ReceiverError::Lagged,
// }
// }
// }
#[derive(Debug)]
pub struct Receiver<T> {
receiver: mpsc::Receiver<T>,
}
impl<T: Message> Receiver<T> {
pub(crate) fn new(receiver: mpsc::Receiver<T>) -> Self {
Receiver { receiver }
}
pub async fn recv(&mut self) -> Result<T, ReceiverError> {
self.receiver.recv().await.ok_or(ReceiverError::Closed)
}
pub fn recv_blocking(&mut self) -> Result<T, ReceiverError> {
self.receiver.blocking_recv().ok_or(ReceiverError::Closed)
}
pub fn try_recv(&mut self) -> Result<T, ReceiverError> {
// Ok(self.receiver.try_recv()?)
self.recv()
.now_or_never()
.unwrap_or(Err(ReceiverError::Empty))
}
}

51
src/mpsc/sender.rs

@ -0,0 +1,51 @@
use tokio::sync::mpsc;
use crate::message::Message;
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum SenderError {
Closed,
Full
}
impl<T> From<mpsc::error::SendError<T>> for SenderError {
fn from(_: mpsc::error::SendError<T>) -> Self {
SenderError::Closed
}
}
impl<T> From<mpsc::error::TrySendError<T>> for SenderError {
fn from(err: mpsc::error::TrySendError<T>) -> Self {
match err {
mpsc::error::TrySendError::Full(_) => SenderError::Full,
mpsc::error::TrySendError::Closed(_) => SenderError::Closed
}
}
}
#[derive(Clone, Debug)]
pub struct Sender<T> {
sender: mpsc::Sender<T>,
}
impl<T: Message> Sender<T> {
pub(crate) fn new(sender: mpsc::Sender<T>) -> Self {
Sender { sender }
}
pub async fn send(&mut self, message: T) -> Result<(), SenderError> {
self.sender.send(message).await?;
Ok(())
}
pub fn send_blocking(&mut self, message: T) -> Result<(), SenderError> {
self.sender.blocking_send(message)?;
Ok(())
}
pub fn try_send(&mut self, message: T) -> Result<(), SenderError> {
self.sender.try_send(message)?;
Ok(())
}
}
Loading…
Cancel
Save