From 1cc4caae608a48afe544f794f585e3791e68de72 Mon Sep 17 00:00:00 2001 From: Alex Mikhalev Date: Thu, 1 Oct 2020 23:06:33 -0600 Subject: [PATCH] Implement updating program from MQTT --- sprinklers_actors/src/state_manager.rs | 20 ++++- sprinklers_database/src/program.rs | 18 ++++- sprinklers_mqtt/Cargo.toml | 2 +- sprinklers_mqtt/src/actor.rs | 2 +- sprinklers_mqtt/src/request/mod.rs | 14 ++-- sprinklers_mqtt/src/request/programs.rs | 50 ++++++++++-- sprinklers_mqtt/src/request/sections.rs | 11 ++- sprinklers_mqtt/src/update_listener.rs | 66 +++++++++++++-- sprinklers_rs/Cargo.toml | 1 + sprinklers_rs/src/main.rs | 21 ++--- sprinklers_rs/src/state_manager.rs | 103 ++++++++++++++++++++++++ 11 files changed, 264 insertions(+), 44 deletions(-) create mode 100644 sprinklers_rs/src/state_manager.rs diff --git a/sprinklers_actors/src/state_manager.rs b/sprinklers_actors/src/state_manager.rs index 805f10e..e942716 100644 --- a/sprinklers_actors/src/state_manager.rs +++ b/sprinklers_actors/src/state_manager.rs @@ -1,5 +1,5 @@ -use eyre::Result; use sprinklers_core::model::{ProgramId, ProgramRef, ProgramUpdateData, Programs}; +use thiserror::Error; use tokio::sync::{mpsc, oneshot, watch}; #[derive(Debug)] @@ -17,6 +17,20 @@ pub struct StateManager { programs_watch: watch::Receiver, } +#[derive(Debug, Error)] +pub enum StateError { + #[error("no such program: {0}")] + NoSuchProgram(ProgramId), + #[error("internal error: {0}")] + Other( + #[from] + #[source] + eyre::Report, + ), +} + +pub type Result = std::result::Result; + impl StateManager { pub fn new( request_tx: mpsc::Sender, @@ -40,8 +54,8 @@ impl StateManager { update, resp_tx, }) - .await?; - resp_rx.await? + .await.map_err(eyre::Report::from)?; + resp_rx.await.map_err(eyre::Report::from)? } pub fn get_programs(&self) -> watch::Receiver { diff --git a/sprinklers_database/src/program.rs b/sprinklers_database/src/program.rs index a81213f..5d5d826 100644 --- a/sprinklers_database/src/program.rs +++ b/sprinklers_database/src/program.rs @@ -9,6 +9,7 @@ use sprinklers_core::{ use eyre::Result; use rusqlite::{params, Row, ToSql, Transaction, NO_PARAMS}; +use thiserror::Error; type SqlProgramSequence = SqlJson; type SqlSchedule = SqlJson; @@ -99,6 +100,10 @@ fn sequence_as_sql<'a>( .map(move |(seq_num, item)| item_as_sql(item, program_id, seq_num)) } +#[derive(Clone, Debug, Error)] +#[error("no such program id: {0}")] +pub struct NoSuchProgram(pub ProgramId); + pub fn query_programs(conn: &DbConn) -> Result { let query_sql = "\ SELECT p.id, p.name, p.enabled, p.schedule, ps.sequence @@ -121,7 +126,12 @@ FROM programs AS p INNER JOIN program_sequences AS ps ON ps.program_id = p.id WHERE p.id = ?1;"; let mut statement = conn.prepare_cached(query_sql)?; - Ok(statement.query_row(params![id], from_sql)?) + statement + .query_row(params![id], from_sql) + .map_err(|err| match err { + rusqlite::Error::QueryReturnedNoRows => NoSuchProgram(id).into(), + e => e.into(), + }) } pub fn update_program( @@ -137,8 +147,12 @@ UPDATE programs enabled = ifnull(?3, enabled), schedule = ifnull(?4, schedule) WHERE id = ?1;"; - conn.prepare_cached(update_sql)? + let updated = conn + .prepare_cached(update_sql)? .execute(&update_as_sql(id, prog))?; + if updated == 0 { + return Err(NoSuchProgram(id).into()); + } if let Some(sequence) = &prog.sequence { let clear_seq_sql = "\ DELETE diff --git a/sprinklers_mqtt/Cargo.toml b/sprinklers_mqtt/Cargo.toml index bbe6027..db40c23 100644 --- a/sprinklers_mqtt/Cargo.toml +++ b/sprinklers_mqtt/Cargo.toml @@ -14,7 +14,7 @@ actix = { version = "0.10.0", default-features = false } eyre = "0.6.0" rumqttc = "0.1.0" tracing = "0.1.19" -serde = { version = "1.0.116", features = ["derive"] } +serde = { version = "1.0.116", features = ["derive", "rc"] } serde_json = "1.0.57" chrono = "0.4.15" num-traits = "0.2.12" diff --git a/sprinklers_mqtt/src/actor.rs b/sprinklers_mqtt/src/actor.rs index 65d8558..8df6871 100644 --- a/sprinklers_mqtt/src/actor.rs +++ b/sprinklers_mqtt/src/actor.rs @@ -36,7 +36,7 @@ impl MqttActor { }; let rid = request_value.rid; let request_fut = - serde_json::from_value::(request_value.rest).map(|mut request| { + serde_json::from_value::(request_value.rest).map(|request| { debug!(rid, "about to execute request: {:?}", request); request.execute(&mut self.request_context) }); diff --git a/sprinklers_mqtt/src/request/mod.rs b/sprinklers_mqtt/src/request/mod.rs index 23829f5..4532519 100644 --- a/sprinklers_mqtt/src/request/mod.rs +++ b/sprinklers_mqtt/src/request/mod.rs @@ -1,4 +1,4 @@ -use sprinklers_actors::{program_runner::ProgramRunner, section_runner::SectionRunner}; +use sprinklers_actors::{ProgramRunner, SectionRunner, StateManager}; use sprinklers_core::model::Sections; use futures_util::{ready, FutureExt}; @@ -13,6 +13,7 @@ pub struct RequestContext { pub sections: Sections, pub section_runner: SectionRunner, pub program_runner: ProgramRunner, + pub state_manager: StateManager, } type BoxFuture = Pin>>; @@ -190,11 +191,12 @@ type RequestFuture = BoxFuture>; trait IRequest { type Response: Serialize; - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture; + fn exec(self, ctx: &mut RequestContext) -> RequestFuture; - fn exec_erased(&mut self, ctx: &mut RequestContext) -> RequestFuture + fn exec_erased(self, ctx: &mut RequestContext) -> RequestFuture where Self::Response: 'static, + Self: Sized, { // TODO: figure out how to get rid of this nested box Box::pin(ErasedRequestFuture(self.exec(ctx))) @@ -263,12 +265,13 @@ pub enum Request { PauseSectionRunner(sections::PauseSectionRunnerRequest), RunProgram(programs::RunProgramRequest), CancelProgram(programs::CancelProgramRequest), + UpdateProgram(programs::UpdateProgramRequest), } impl IRequest for Request { type Response = ResponseValue; - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + fn exec(self, ctx: &mut RequestContext) -> RequestFuture { match self { Request::RunSection(req) => req.exec_erased(ctx), Request::CancelSection(req) => req.exec_erased(ctx), @@ -276,12 +279,13 @@ impl IRequest for Request { Request::PauseSectionRunner(req) => req.exec_erased(ctx), Request::RunProgram(req) => req.exec_erased(ctx), Request::CancelProgram(req) => req.exec_erased(ctx), + Request::UpdateProgram(req) => req.exec_erased(ctx), } } } impl Request { - pub fn execute(&mut self, ctx: &mut RequestContext) -> impl Future { + pub fn execute(self, ctx: &mut RequestContext) -> impl Future { self.exec(ctx).map(Response::from) } } diff --git a/sprinklers_mqtt/src/request/programs.rs b/sprinklers_mqtt/src/request/programs.rs index 49db41e..aad7d92 100644 --- a/sprinklers_mqtt/src/request/programs.rs +++ b/sprinklers_mqtt/src/request/programs.rs @@ -1,6 +1,6 @@ use super::*; -use sprinklers_actors::program_runner::ProgramRunnerError; -use sprinklers_core::model::ProgramId; +use sprinklers_actors::{program_runner::Error, state_manager::StateError}; +use sprinklers_core::model::{ProgramId, ProgramRef, ProgramUpdateData}; use eyre::WrapErr; @@ -13,7 +13,7 @@ pub struct RunProgramRequest { impl IRequest for RunProgramRequest { type Response = ResponseMessage; - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + fn exec(self, ctx: &mut RequestContext) -> RequestFuture { let mut program_runner = ctx.program_runner.clone(); let program_id = self.program_id; Box::pin(async move { @@ -22,7 +22,7 @@ impl IRequest for RunProgramRequest { "running program '{}'", program.name ))), - Err(e @ ProgramRunnerError::InvalidProgramId(_)) => Err(RequestError::with_name( + Err(e @ Error::InvalidProgramId(_)) => Err(RequestError::with_name( ErrorCode::NoSuchProgram, e, "program", @@ -42,7 +42,7 @@ pub struct CancelProgramRequest { impl IRequest for CancelProgramRequest { type Response = ResponseMessage; - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + fn exec(self, ctx: &mut RequestContext) -> RequestFuture { let mut program_runner = ctx.program_runner.clone(); let program_id = self.program_id; Box::pin(async move { @@ -64,3 +64,43 @@ impl IRequest for CancelProgramRequest { }) } } + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UpdateProgramRequest { + program_id: ProgramId, + data: ProgramUpdateData, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UpdateProgramResponse { + message: String, + data: ProgramRef, +} + +impl IRequest for UpdateProgramRequest { + type Response = UpdateProgramResponse; + + fn exec(self, ctx: &mut RequestContext) -> RequestFuture { + let mut state_manager = ctx.state_manager.clone(); + Box::pin(async move { + let new_program = state_manager + .update_program(self.program_id, self.data) + .await + .map_err(|err| match err { + e @ StateError::NoSuchProgram(_) => RequestError::with_name_and_cause( + ErrorCode::NoSuchProgram, + "could not update program", + "program", + e, + ), + e => RequestError::from(eyre::Report::from(e)), + })?; + Ok(UpdateProgramResponse { + message: format!("updated program '{}'", new_program.name), + data: new_program, + }) + }) + } +} diff --git a/sprinklers_mqtt/src/request/sections.rs b/sprinklers_mqtt/src/request/sections.rs index 23dcf41..f299321 100644 --- a/sprinklers_mqtt/src/request/sections.rs +++ b/sprinklers_mqtt/src/request/sections.rs @@ -36,7 +36,7 @@ pub struct RunSectionResponse { impl IRequest for RunSectionRequest { type Response = RunSectionResponse; - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + fn exec(self, ctx: &mut RequestContext) -> RequestFuture { let mut section_runner = ctx.section_runner.clone(); let section = self.section_id.get_section(&ctx.sections); let duration = self.duration; @@ -69,7 +69,7 @@ pub struct CancelSectionResponse { impl IRequest for CancelSectionRequest { type Response = CancelSectionResponse; - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + fn exec(self, ctx: &mut RequestContext) -> RequestFuture { let mut section_runner = ctx.section_runner.clone(); let section = self.section_id.get_section(&ctx.sections); Box::pin(async move { @@ -104,12 +104,11 @@ pub struct CancelSectionRunIdResponse { impl IRequest for CancelSectionRunIdRequest { type Response = ResponseMessage; - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + fn exec(self, ctx: &mut RequestContext) -> RequestFuture { let mut section_runner = ctx.section_runner.clone(); - let run_id = self.run_id.clone(); Box::pin(async move { let cancelled = section_runner - .cancel_run(run_id) + .cancel_run(self.run_id) .await .wrap_err("could not cancel section run")?; if cancelled { @@ -140,7 +139,7 @@ pub struct PauseSectionRunnerResponse { impl IRequest for PauseSectionRunnerRequest { type Response = PauseSectionRunnerResponse; - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + fn exec(self, ctx: &mut RequestContext) -> RequestFuture { let mut section_runner = ctx.section_runner.clone(); let paused = self.paused; Box::pin(async move { diff --git a/sprinklers_mqtt/src/update_listener.rs b/sprinklers_mqtt/src/update_listener.rs index 39fef84..9850be4 100644 --- a/sprinklers_mqtt/src/update_listener.rs +++ b/sprinklers_mqtt/src/update_listener.rs @@ -5,11 +5,22 @@ use sprinklers_actors::{ }; use actix::{fut::wrap_future, Actor, ActorContext, Addr, AsyncContext, Handler, StreamHandler}; -use tokio::sync::broadcast; +use sprinklers_core::model::Programs; +use tokio::sync::{broadcast, watch}; use tracing::{trace, warn}; struct UpdateListenerActor { mqtt_interface: MqttInterface, + has_published_program_states: bool, +} + +impl UpdateListenerActor { + fn new(mqtt_interface: MqttInterface) -> Self { + Self { + mqtt_interface, + has_published_program_states: false, + } + } } impl Actor for UpdateListenerActor { @@ -101,6 +112,35 @@ impl StreamHandler for UpdateListenerActor { } } +impl StreamHandler for UpdateListenerActor { + fn handle(&mut self, programs: Programs, ctx: &mut Self::Context) { + let mut mqtt_interface = self.mqtt_interface.clone(); + + let has_published_program_states = self.has_published_program_states; + self.has_published_program_states = true; + + let fut = async move { + if let Err(err) = mqtt_interface.publish_programs(&programs).await { + warn!("could not publish programs: {:?}", err); + } + // Some what of a hack + // Initialize program running states to false the first time we + // receive programs + if !has_published_program_states { + for program_id in programs.keys() { + if let Err(err) = mqtt_interface + .publish_program_running(*program_id, false) + .await + { + warn!("could not publish program running: {:?}", err); + } + } + } + }; + ctx.spawn(wrap_future(fut)); + } +} + #[derive(actix::Message)] #[rtype(result = "()")] struct Quit; @@ -143,13 +183,19 @@ impl Listenable for SectionEventRecv { } } -impl Listenable for ProgramEventRecv { +impl Listenable for SecRunnerStateRecv { fn listen(self, ctx: &mut ::Context) { ctx.add_stream(self); } } -impl Listenable for SecRunnerStateRecv { +impl Listenable for watch::Receiver { + fn listen(self, ctx: &mut ::Context) { + ctx.add_stream(self); + } +} + +impl Listenable for ProgramEventRecv { fn listen(self, ctx: &mut ::Context) { ctx.add_stream(self); } @@ -161,7 +207,7 @@ pub struct UpdateListener { impl UpdateListener { pub fn start(mqtt_interface: MqttInterface) -> Self { - let addr = UpdateListenerActor { mqtt_interface }.start(); + let addr = UpdateListenerActor::new(mqtt_interface).start(); Self { addr } } @@ -176,14 +222,18 @@ impl UpdateListener { self.listen(section_events); } - pub fn listen_program_events(&mut self, program_events: ProgramEventRecv) { - self.listen(program_events); - } - pub fn listen_section_runner(&mut self, sec_runner_state_recv: SecRunnerStateRecv) { self.listen(sec_runner_state_recv); } + pub fn listen_programs(&mut self, programs: watch::Receiver) { + self.listen(programs); + } + + pub fn listen_program_events(&mut self, program_events: ProgramEventRecv) { + self.listen(program_events); + } + pub async fn quit(self) -> eyre::Result<()> { Ok(self.addr.send(Quit).await?) } diff --git a/sprinklers_rs/Cargo.toml b/sprinklers_rs/Cargo.toml index 0cf05b7..daaf812 100644 --- a/sprinklers_rs/Cargo.toml +++ b/sprinklers_rs/Cargo.toml @@ -18,6 +18,7 @@ tokio = "0.2.22" tracing = { version = "0.1.19", features = ["log"] } actix = { version = "0.10.0", default-features = false } actix-rt = "1.1.1" +chrono = "0.4.19" [dependencies.tracing-subscriber] version = "0.2.11" diff --git a/sprinklers_rs/src/main.rs b/sprinklers_rs/src/main.rs index 67d3b74..28d4ffc 100644 --- a/sprinklers_rs/src/main.rs +++ b/sprinklers_rs/src/main.rs @@ -2,6 +2,7 @@ #![warn(clippy::print_stdout)] // mod option_future; +mod state_manager; use sprinklers_actors as actors; use sprinklers_core::section_interface::MockSectionInterface; @@ -24,9 +25,9 @@ async fn main() -> Result<()> { info!("Starting sprinklers_rs..."); color_eyre::install()?; - let conn = database::setup_db()?; + let db_conn = database::setup_db()?; - let sections = database::query_sections(&conn)?; + let sections = database::query_sections(&db_conn)?; for sec in sections.values() { debug!(section = debug(&sec), "read section"); } @@ -36,11 +37,9 @@ async fn main() -> Result<()> { let mut section_runner = actors::SectionRunner::new(section_interface); let mut program_runner = actors::ProgramRunner::new(section_runner.clone()); - let programs = database::query_programs(&conn)?; + let state_manager = crate::state_manager::StateManagerThread::start(db_conn); - for prog in programs.values() { - debug!(program = debug(&prog), "read program"); - } + program_runner.listen_programs(state_manager.get_programs()); let mqtt_options = mqtt::Options { broker_host: "localhost".into(), @@ -53,12 +52,14 @@ async fn main() -> Result<()> { sections: sections.clone(), section_runner: section_runner.clone(), program_runner: program_runner.clone(), + state_manager: state_manager.clone(), }; let mut mqtt_interface = mqtt::MqttInterfaceTask::start(mqtt_options, request_context); let mut update_listener = mqtt::UpdateListener::start(mqtt_interface.clone()); update_listener.listen_section_events(section_runner.subscribe().await?); update_listener.listen_section_runner(section_runner.get_state_recv()); + update_listener.listen_programs(state_manager.get_programs()); update_listener.listen_program_events(program_runner.subscribe().await?); program_runner.update_sections(sections.clone()).await?; @@ -69,13 +70,6 @@ async fn main() -> Result<()> { .publish_section_state(*section_id, false) .await?; } - program_runner.update_programs(programs.clone()).await?; - for program_id in programs.keys() { - mqtt_interface - .publish_program_running(*program_id, false) - .await?; - } - mqtt_interface.publish_programs(&programs).await?; info!("sprinklers_rs initialized"); @@ -84,6 +78,7 @@ async fn main() -> Result<()> { update_listener.quit().await?; mqtt_interface.quit().await?; + drop(state_manager); program_runner.quit().await?; section_runner.quit().await?; actix::System::current().stop(); diff --git a/sprinklers_rs/src/state_manager.rs b/sprinklers_rs/src/state_manager.rs new file mode 100644 index 0000000..80f4aea --- /dev/null +++ b/sprinklers_rs/src/state_manager.rs @@ -0,0 +1,103 @@ +use sprinklers_actors::{state_manager, StateManager}; +use sprinklers_database::{self as database, DbConn}; + +use eyre::{eyre, WrapErr}; +use sprinklers_core::model::{ProgramRef, Programs}; +use tokio::{ + runtime, + sync::{mpsc, watch}, +}; +use tracing::warn; + +pub struct StateManagerThread { + db_conn: DbConn, + request_rx: mpsc::Receiver, + programs_tx: watch::Sender, +} + +struct State { + programs: Programs, +} + +impl StateManagerThread { + pub fn start(db_conn: DbConn) -> StateManager { + let (request_tx, request_rx) = mpsc::channel(8); + let (programs_tx, programs_rx) = watch::channel(Programs::default()); + let task = StateManagerThread { + db_conn, + request_rx, + programs_tx, + }; + let runtime_handle = runtime::Handle::current(); + std::thread::Builder::new() + .name("sprinklers_rs::state_manager".into()) + .spawn(move || task.run(runtime_handle)) + .expect("could not start state_manager thread"); + StateManager::new(request_tx, programs_rx) + } + + fn broadcast_programs(&mut self, programs: Programs) { + if let Err(err) = self.programs_tx.broadcast(programs) { + warn!("could not broadcast programs: {}", err); + } + } + + fn handle_request( + &mut self, + request: state_manager::Request, + state: &mut State, + ) -> eyre::Result<()> { + use state_manager::Request; + + match request { + Request::UpdateProgram { + id, + update, + resp_tx, + } => { + // HACK: would really like stable try notation + let res = (|| -> state_manager::Result { + let mut trans = self + .db_conn + .transaction() + .wrap_err("failed to start transaction")?; + database::update_program(&mut trans, id, &update).map_err(|err| { + if let Some(e) = err.downcast_ref::() { + state_manager::StateError::NoSuchProgram(e.0) + } else { + err.into() + } + })?; + let new_program: ProgramRef = database::query_program_by_id(&trans, id)?.into(); + state.programs.insert(new_program.id, new_program.clone()); + trans.commit().wrap_err("could not commit transaction")?; + self.broadcast_programs(state.programs.clone()); + Ok(new_program) + })(); + resp_tx + .send(res) + .map_err(|_| eyre!("could not respond to UpdateProgram"))?; + } + } + Ok(()) + } + + fn load_state(&mut self) -> eyre::Result { + let programs = + database::query_programs(&self.db_conn).wrap_err("could not query programs")?; + + Ok(State { programs }) + } + + fn run(mut self, runtime_handle: runtime::Handle) { + let mut state = self.load_state().expect("could not load initial state"); + + self.broadcast_programs(state.programs.clone()); + + while let Some(request) = runtime_handle.block_on(self.request_rx.recv()) { + if let Err(err) = self.handle_request(request, &mut state) { + warn!("error handling request: {:?}", err); + } + } + } +}