diff --git a/src/main.rs b/src/main.rs index 765c5e9..6c61ce5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -63,7 +63,7 @@ async fn main() -> Result<()> { device_id: "sprinklers_rs-0001".into(), client_id: "sprinklers_rs-0001".into(), }; - let mut mqtt_interface = mqtt_interface::MqttInterfaceTask::start(mqtt_options).await?; + let mut mqtt_interface = mqtt_interface::MqttInterfaceTask::start(mqtt_options); let update_listener = { let section_events = section_runner.subscribe().await?; @@ -101,7 +101,7 @@ async fn main() -> Result<()> { mqtt_interface.quit().await?; program_runner.quit().await?; section_runner.quit().await?; - tokio::task::yield_now().await; + actix::System::current().stop(); Ok(()) } diff --git a/src/mqtt_interface.rs b/src/mqtt_interface.rs index 1926f16..d775b9c 100644 --- a/src/mqtt_interface.rs +++ b/src/mqtt_interface.rs @@ -3,15 +3,17 @@ use crate::{ section_runner::SecRunnerState, section_runner_json::SecRunnerStateJson, }; +use actix::{Actor, ActorContext, ActorFuture, Addr, AsyncContext, Handler, WrapFuture}; use eyre::WrapErr; -use rumqttc::{LastWill, MqttOptions, QoS}; +use rumqttc::{LastWill, MqttOptions, Packet, QoS}; use std::{ + collections::HashSet, ops::{Deref, DerefMut}, sync::Arc, time::Duration, }; -use tokio::task::JoinHandle; -use tracing::{debug, info, trace, warn}; +use tokio::sync::oneshot; +use tracing::{debug, error, info, trace, warn}; #[derive(Clone, Debug)] struct Topics @@ -60,56 +62,98 @@ where fn section_runner(&self) -> String { format!("{}/section_runner", self.prefix.as_ref()) } + + fn requests(&self) -> String { + format!("{}/requests", self.prefix.as_ref()) + } + + fn responses(&self) -> String { + format!("{}/responses", self.prefix.as_ref()) + } } -#[derive(Clone, Debug)] -pub struct Options { - pub broker_host: String, - pub broker_port: u16, - pub device_id: String, - pub client_id: String, +struct EventLoopTask { + event_loop: rumqttc::EventLoop, + mqtt_addr: Addr, + quit_tx: oneshot::Sender<()>, + unreleased_pubs: HashSet, } -async fn event_loop_task( - interface: MqttInterface, - mut event_loop: rumqttc::EventLoop, -) -> eyre::Result<()> { - use rumqttc::{ConnectionError, Event}; - let reconnect_timeout = Duration::from_secs(5); - event_loop.set_reconnection_delay(reconnect_timeout); - loop { - match event_loop.poll().await { - Ok(Event::Incoming(incoming)) => { - debug!(incoming = debug(&incoming), "MQTT incoming message"); - #[allow(clippy::single_match)] - match incoming { - rumqttc::Packet::ConnAck(_) => { - info!("MQTT connected"); - { - // HACK: this really should just be await - // but that can sometimes deadlock if the publish channel is full - let mut interface = interface.clone(); - let fut = async move { interface.publish_connected(true).await }; - tokio::spawn(fut); - } - //.await?; +impl EventLoopTask { + fn new( + event_loop: rumqttc::EventLoop, + mqtt_addr: Addr, + quit_tx: oneshot::Sender<()>, + ) -> Self { + Self { + event_loop, + mqtt_addr, + quit_tx, + unreleased_pubs: HashSet::default(), + } + } + + fn handle_incoming(&mut self, incoming: Packet) { + trace!(incoming = debug(&incoming), "MQTT incoming message"); + #[allow(clippy::single_match)] + match incoming { + Packet::ConnAck(_) => { + self.mqtt_addr.do_send(Connected); + } + Packet::Publish(publish) => { + // Only deliver QoS 2 packets once + let deliver = if publish.qos == QoS::ExactlyOnce { + if self.unreleased_pubs.contains(&publish.pkid) { + false + } else { + self.unreleased_pubs.insert(publish.pkid); + true } - _ => {} + } else { + true + }; + if deliver { + self.mqtt_addr.do_send(PubRecieve(publish)); } } - Ok(Event::Outgoing(outgoing)) => { - trace!(outgoing = debug(&outgoing), "MQTT outgoing message"); - } - Err(ConnectionError::Cancel) => { - debug!("MQTT disconnecting"); - break; + Packet::PubRel(pubrel) => { + self.unreleased_pubs.remove(&pubrel.pkid); } - Err(err) => { - warn!("MQTT error, reconnecting: {}", err); + _ => {} + } + } + + async fn run(mut self) { + use rumqttc::{ConnectionError, Event}; + let reconnect_timeout = Duration::from_secs(5); + self.event_loop.set_reconnection_delay(reconnect_timeout); + loop { + match self.event_loop.poll().await { + Ok(Event::Incoming(incoming)) => { + self.handle_incoming(incoming); + } + Ok(Event::Outgoing(outgoing)) => { + trace!(outgoing = debug(&outgoing), "MQTT outgoing message"); + } + Err(ConnectionError::Cancel) => { + debug!("MQTT disconnecting"); + break; + } + Err(err) => { + warn!("MQTT error, reconnecting: {}", err); + } } } + let _ = self.quit_tx.send(()); } - Ok(()) +} + +#[derive(Clone, Debug)] +pub struct Options { + pub broker_host: String, + pub broker_port: u16, + pub device_id: String, + pub client_id: String, } #[derive(Clone)] @@ -217,35 +261,118 @@ impl MqttInterface { .await .wrap_err("failed to publish section runner") } + + pub async fn subscribe_requests(&mut self) -> eyre::Result<()> { + self.client + .subscribe(self.topics.requests(), QoS::ExactlyOnce) + .await?; + Ok(()) + } +} + +struct MqttActor { + interface: MqttInterface, + event_loop: Option, + quit_rx: Option>, +} + +impl MqttActor { + fn new(interface: MqttInterface, event_loop: rumqttc::EventLoop) -> Self { + Self { + interface, + event_loop: Some(event_loop), + quit_rx: None, + } + } +} + +impl Actor for MqttActor { + type Context = actix::Context; + + fn started(&mut self, ctx: &mut Self::Context) { + trace!("MqttActor starting"); + let event_loop = self.event_loop.take().expect("MqttActor already started"); + let (quit_tx, quit_rx) = oneshot::channel(); + ctx.spawn( + EventLoopTask::new(event_loop, ctx.address(), quit_tx) + .run() + .into_actor(self), + ); + self.quit_rx = Some(quit_rx); + } +} + +#[derive(actix::Message)] +#[rtype(result = "()")] +struct Quit; + +impl Handler for MqttActor { + type Result = actix::ResponseActFuture; + fn handle(&mut self, _msg: Quit, _ctx: &mut Self::Context) -> Self::Result { + let mut interface = self.interface.clone(); + let quit_rx = self.quit_rx.take().expect("MqttActor has already quit!"); + let fut = async move { + interface + .cancel() + .await + .expect("could not cancel MQTT client"); + let _ = quit_rx.await; + } + .into_actor(self) + .map(|_, _, ctx| ctx.stop()); + Box::pin(fut) + } +} + +#[derive(actix::Message)] +#[rtype(result = "()")] +struct Connected; + +impl Handler for MqttActor { + type Result = (); + + fn handle(&mut self, _msg: Connected, ctx: &mut Self::Context) -> Self::Result { + info!("MQTT connected"); + let mut interface = self.interface.clone(); + let fut = async move { + let res = interface.publish_connected(true).await; + let res = res.and(interface.subscribe_requests().await); + if let Err(err) = res { + error!("error in connection setup: {}", err); + } + }; + ctx.spawn(fut.into_actor(self)); + } +} + +#[derive(actix::Message)] +#[rtype(result = "()")] +struct PubRecieve(rumqttc::Publish); + +impl Handler for MqttActor { + type Result = (); + + fn handle(&mut self, msg: PubRecieve, _ctx: &mut Self::Context) -> Self::Result { + debug!("received MQTT pub: {:?}", msg.0); + } } pub struct MqttInterfaceTask { interface: MqttInterface, - join_handle: JoinHandle<()>, + addr: Addr, } impl MqttInterfaceTask { - pub async fn start(options: Options) -> eyre::Result { + pub fn start(options: Options) -> Self { let (interface, event_loop) = MqttInterface::new(options); - let join_handle = tokio::spawn({ - let interface = interface.clone(); - async move { - event_loop_task(interface, event_loop) - .await - .expect("error in event loop task") - } - }); + let addr = MqttActor::new(interface.clone(), event_loop).start(); - Ok(Self { - interface, - join_handle, - }) + Self { interface, addr } } - pub async fn quit(mut self) -> eyre::Result<()> { - self.interface.cancel().await?; - self.join_handle.await?; + pub async fn quit(self) -> eyre::Result<()> { + self.addr.send(Quit).await?; Ok(()) } }