diff --git a/sprinklers_actors/src/program_runner.rs b/sprinklers_actors/src/program_runner.rs index d3ab723..8c165a3 100644 --- a/sprinklers_actors/src/program_runner.rs +++ b/sprinklers_actors/src/program_runner.rs @@ -316,6 +316,24 @@ impl Handler for ProgramRunnerActor { } } +impl StreamHandler for ProgramRunnerActor { + fn handle(&mut self, item: Zones, ctx: &mut Self::Context) { + ctx.notify(UpdateZones(item)) + } +} + +#[derive(Message)] +#[rtype(result = "()")] +struct ListenZones(watch::Receiver); + +impl Handler for ProgramRunnerActor { + type Result = (); + + fn handle(&mut self, msg: ListenZones, ctx: &mut Self::Context) -> Self::Result { + ctx.add_stream(msg.0); + } +} + impl StreamHandler for ProgramRunnerActor { fn handle(&mut self, item: Programs, ctx: &mut Self::Context) { ctx.notify(UpdatePrograms(item)) @@ -498,6 +516,11 @@ impl ProgramRunner { Ok(event_recv) } + pub fn listen_zones(&mut self, zones_watch: watch::Receiver) { + // TODO: should this adopt a similar pattern to update_listener? + self.addr.do_send(ListenZones(zones_watch)) + } + pub fn listen_programs(&mut self, programs_watch: watch::Receiver) { self.addr.do_send(ListenPrograms(programs_watch)) } diff --git a/sprinklers_actors/src/state_manager.rs b/sprinklers_actors/src/state_manager.rs index 0a5ecaf..5f8c955 100644 --- a/sprinklers_actors/src/state_manager.rs +++ b/sprinklers_actors/src/state_manager.rs @@ -1,4 +1,4 @@ -use sprinklers_core::model::{ProgramId, ProgramRef, ProgramUpdateData, Programs}; +use sprinklers_core::model::{ProgramId, ProgramRef, ProgramUpdateData, Programs, Zones}; use thiserror::Error; use tokio::sync::{mpsc, oneshot, watch}; @@ -14,6 +14,7 @@ pub enum Request { #[derive(Clone)] pub struct StateManager { request_tx: mpsc::Sender, + zones_watch: watch::Receiver, programs_watch: watch::Receiver, } @@ -34,10 +35,12 @@ pub type Result = std::result::Result; impl StateManager { pub fn new( request_tx: mpsc::Sender, + zones_watch: watch::Receiver, programs_watch: watch::Receiver, ) -> Self { Self { request_tx, + zones_watch, programs_watch, } } @@ -59,6 +62,10 @@ impl StateManager { resp_rx.await.map_err(eyre::Report::from)? } + pub fn get_zones(&self) -> watch::Receiver { + self.zones_watch.clone() + } + pub fn get_programs(&self) -> watch::Receiver { self.programs_watch.clone() } diff --git a/sprinklers_mqtt/src/lib.rs b/sprinklers_mqtt/src/lib.rs index 582f9b4..3ffb0cb 100644 --- a/sprinklers_mqtt/src/lib.rs +++ b/sprinklers_mqtt/src/lib.rs @@ -75,15 +75,42 @@ impl MqttInterface { pub async fn publish_zones(&mut self, zones: &Zones) -> eyre::Result<()> { let zone_ids: Vec<_> = zones.keys().cloned().collect(); - self.publish_data(self.topics.zones(), &zone_ids) - .await - .wrap_err("failed to publish zone ids")?; + self.publish_zone_ids(&zone_ids).await?; for zone in zones.values() { self.publish_zone(zone).await?; } Ok(()) } + // TODO: figure out how to share logic with publish_programs_diff and publish_zones + pub async fn publish_zones_diff( + &mut self, + old_zones: &Zones, + zones: &Zones, + ) -> eyre::Result<()> { + for (id, zone) in zones { + let publish = match old_zones.get(id) { + Some(old_zone) => !Arc::ptr_eq(old_zone, zone), + None => { + let zone_ids: Vec<_> = zones.keys().cloned().collect(); + self.publish_zone_ids(&zone_ids).await?; + true + } + }; + if publish { + self.publish_zone(zone).await?; + } + } + Ok(()) + } + + pub async fn publish_zone_ids(&mut self, zone_ids: &[ZoneId]) -> eyre::Result<()> { + self.publish_data(self.topics.zones(), &zone_ids) + .await + .wrap_err("failed to publish zone ids")?; + Ok(()) + } + pub async fn publish_zone(&mut self, zone: &Zone) -> eyre::Result<()> { self.publish_data(self.topics.zone_data(zone.id), zone) .await diff --git a/sprinklers_mqtt/src/request/mod.rs b/sprinklers_mqtt/src/request/mod.rs index 430acae..4cb338f 100644 --- a/sprinklers_mqtt/src/request/mod.rs +++ b/sprinklers_mqtt/src/request/mod.rs @@ -5,12 +5,13 @@ use futures_util::{ready, FutureExt}; use num_derive::FromPrimitive; use serde::{Deserialize, Serialize}; use std::{fmt, future::Future, pin::Pin, task::Poll}; +use tokio::sync::watch; mod programs; mod zones; pub struct RequestContext { - pub zones: Zones, + pub zones: watch::Receiver, pub zone_runner: ZoneRunner, pub program_runner: ProgramRunner, pub state_manager: StateManager, diff --git a/sprinklers_mqtt/src/request/zones.rs b/sprinklers_mqtt/src/request/zones.rs index 958a6ae..88d3876 100644 --- a/sprinklers_mqtt/src/request/zones.rs +++ b/sprinklers_mqtt/src/request/zones.rs @@ -41,7 +41,7 @@ impl IRequest for RunZoneRequest { type Response = RunZoneResponse; fn exec(self, ctx: &mut RequestContext) -> RequestFuture { let mut zone_runner = ctx.zone_runner.clone(); - let zone = self.zone_id.get_zone(&ctx.zones); + let zone = self.zone_id.get_zone(&*ctx.zones.borrow()); let duration = self.duration; Box::pin(async move { let zone = zone?; @@ -76,7 +76,7 @@ impl IRequest for CancelZoneRequest { type Response = CancelZoneResponse; fn exec(self, ctx: &mut RequestContext) -> RequestFuture { let mut zone_runner = ctx.zone_runner.clone(); - let zone = self.zone_id.get_zone(&ctx.zones); + let zone = self.zone_id.get_zone(&*ctx.zones.borrow()); Box::pin(async move { let zone = zone?; let cancelled = zone_runner diff --git a/sprinklers_mqtt/src/update_listener.rs b/sprinklers_mqtt/src/update_listener.rs index 67d3c0f..c2b282e 100644 --- a/sprinklers_mqtt/src/update_listener.rs +++ b/sprinklers_mqtt/src/update_listener.rs @@ -6,12 +6,13 @@ use sprinklers_actors::{ use actix::{fut::wrap_future, Actor, ActorContext, Addr, AsyncContext, Handler, StreamHandler}; use futures_util::TryFutureExt; -use sprinklers_core::model::Programs; +use sprinklers_core::model::{Programs, Zones}; use tokio::sync::{broadcast, watch}; use tracing::{trace, warn}; struct UpdateListenerActor { mqtt_interface: MqttInterface, + old_zones: Option, old_programs: Option, } @@ -19,6 +20,7 @@ impl UpdateListenerActor { fn new(mqtt_interface: MqttInterface) -> Self { Self { mqtt_interface, + old_zones: None, old_programs: None, } } @@ -36,6 +38,42 @@ impl Actor for UpdateListenerActor { } } +impl StreamHandler for UpdateListenerActor { + fn handle(&mut self, zones: Zones, ctx: &mut Self::Context) { + let mut mqtt_interface = self.mqtt_interface.clone(); + + let old_zones = self.old_zones.replace(zones.clone()); + + let fut = async move { + mqtt_interface.publish_zones(&zones).await?; + for zone_id in zones.keys() { + mqtt_interface.publish_zone_state(*zone_id, false).await?; + } + + match old_zones { + None => { + mqtt_interface.publish_zones(&zones).await?; + + // Some what of a hack + // Initialize zone running states to false the first time we + // receive zones + for zone_id in zones.keys() { + mqtt_interface.publish_zone_state(*zone_id, false).await?; + } + } + Some(old_zones) => { + mqtt_interface + .publish_zones_diff(&old_zones, &zones) + .await?; + } + } + Ok(()) + } + .unwrap_or_else(|err: eyre::Report| warn!("could not publish programs: {:?}", err)); + ctx.spawn(wrap_future(fut)); + } +} + impl StreamHandler> for UpdateListenerActor { fn handle(&mut self, event: Result, ctx: &mut Self::Context) { let event = match event { @@ -196,6 +234,12 @@ where } } +impl Listenable for watch::Receiver { + fn listen(self, ctx: &mut ::Context) { + ctx.add_stream(self); + } +} + impl Listenable for ZoneEventRecv { fn listen(self, ctx: &mut ::Context) { ctx.add_stream(self); @@ -237,6 +281,10 @@ impl UpdateListener { self.addr.do_send(Listen(listener)); } + pub fn listen_zones(&mut self, zones: watch::Receiver) { + self.listen(zones); + } + pub fn listen_zone_events(&mut self, zone_events: ZoneEventRecv) { self.listen(zone_events); } diff --git a/sprinklers_rs/src/main.rs b/sprinklers_rs/src/main.rs index fea0316..d2eb578 100644 --- a/sprinklers_rs/src/main.rs +++ b/sprinklers_rs/src/main.rs @@ -12,7 +12,7 @@ use sprinklers_mqtt as mqtt; use eyre::{Result, WrapErr}; use settings::Settings; -use tracing::{debug, info}; +use tracing::info; use tracing_subscriber::EnvFilter; #[actix_rt::main] @@ -31,11 +31,6 @@ async fn main() -> Result<()> { let db_conn = database::setup_db()?; - let zones = database::query_zones(&db_conn)?; - for zone in zones.values() { - debug!(zone = debug(&zone), "read zone"); - } - let zone_interface = settings.zone_interface.build()?; let mut zone_runner = actors::ZoneRunner::new(zone_interface); let mut program_runner = actors::ProgramRunner::new(zone_runner.clone()); @@ -45,29 +40,24 @@ async fn main() -> Result<()> { let mqtt_options = settings.mqtt; // TODO: have ability to update zones / other data let request_context = mqtt::RequestContext { - zones: zones.clone(), + zones: state_manager.get_zones(), zone_runner: zone_runner.clone(), program_runner: program_runner.clone(), state_manager: state_manager.clone(), }; - let mut mqtt_interface = mqtt::MqttInterfaceTask::start(mqtt_options, request_context); + let mqtt_interface = mqtt::MqttInterfaceTask::start(mqtt_options, request_context); let mut update_listener = mqtt::UpdateListener::start(mqtt_interface.clone()); + update_listener.listen_zones(state_manager.get_zones()); update_listener.listen_zone_events(zone_runner.subscribe().await?); update_listener.listen_zone_runner(zone_runner.get_state_recv()); update_listener.listen_programs(state_manager.get_programs()); update_listener.listen_program_events(program_runner.subscribe().await?); // Only listen to programs now so above subscriptions get events + program_runner.listen_zones(state_manager.get_zones()); program_runner.listen_programs(state_manager.get_programs()); - program_runner.update_zones(zones.clone()).await?; - // TODO: update listener should probably do this - mqtt_interface.publish_zones(&zones).await?; - for zone_id in zones.keys() { - mqtt_interface.publish_zone_state(*zone_id, false).await?; - } - info!("sprinklers_rs initialized"); tokio::signal::ctrl_c().await?; diff --git a/sprinklers_rs/src/state_manager.rs b/sprinklers_rs/src/state_manager.rs index 80f4aea..5f9e2de 100644 --- a/sprinklers_rs/src/state_manager.rs +++ b/sprinklers_rs/src/state_manager.rs @@ -2,30 +2,34 @@ use sprinklers_actors::{state_manager, StateManager}; use sprinklers_database::{self as database, DbConn}; use eyre::{eyre, WrapErr}; -use sprinklers_core::model::{ProgramRef, Programs}; +use sprinklers_core::model::{ProgramRef, Programs, Zones}; use tokio::{ runtime, sync::{mpsc, watch}, }; -use tracing::warn; +use tracing::{trace, warn}; pub struct StateManagerThread { db_conn: DbConn, request_rx: mpsc::Receiver, + zones_tx: watch::Sender, programs_tx: watch::Sender, } struct State { + zones: Zones, programs: Programs, } impl StateManagerThread { pub fn start(db_conn: DbConn) -> StateManager { let (request_tx, request_rx) = mpsc::channel(8); + let (zones_tx, zones_rx) = watch::channel(Zones::default()); let (programs_tx, programs_rx) = watch::channel(Programs::default()); let task = StateManagerThread { db_conn, request_rx, + zones_tx, programs_tx, }; let runtime_handle = runtime::Handle::current(); @@ -33,7 +37,13 @@ impl StateManagerThread { .name("sprinklers_rs::state_manager".into()) .spawn(move || task.run(runtime_handle)) .expect("could not start state_manager thread"); - StateManager::new(request_tx, programs_rx) + StateManager::new(request_tx, zones_rx, programs_rx) + } + + fn broadcast_zones(&mut self, zones: Zones) { + if let Err(err) = self.zones_tx.broadcast(zones) { + warn!("could not broadcast zones: {}", err); + } } fn broadcast_programs(&mut self, programs: Programs) { @@ -83,15 +93,22 @@ impl StateManagerThread { } fn load_state(&mut self) -> eyre::Result { + let zones = database::query_zones(&self.db_conn)?; + + for zone in zones.values() { + trace!(zone = debug(&zone), "read zone"); + } + let programs = database::query_programs(&self.db_conn).wrap_err("could not query programs")?; - Ok(State { programs }) + Ok(State { zones, programs }) } fn run(mut self, runtime_handle: runtime::Handle) { let mut state = self.load_state().expect("could not load initial state"); + self.broadcast_zones(state.zones.clone()); self.broadcast_programs(state.programs.clone()); while let Some(request) = runtime_handle.block_on(self.request_rx.recv()) {