diff --git a/src/mqtt/request/mod.rs b/src/mqtt/request/mod.rs index 929572c..1be71b6 100644 --- a/src/mqtt/request/mod.rs +++ b/src/mqtt/request/mod.rs @@ -1,11 +1,12 @@ use crate::{model::Sections, section_runner::SectionRunner}; +use futures_util::ready; use futures_util::FutureExt; use num_derive::FromPrimitive; use serde::{Deserialize, Serialize}; -use std::{fmt, future::Future, pin::Pin}; +use std::{fmt, future::Future, pin::Pin, task::Poll}; -mod run_section; +mod sections; pub struct RequestContext { pub sections: Sections, @@ -14,8 +15,6 @@ pub struct RequestContext { type BoxFuture = Pin>>; -pub type ResponseValue = serde_json::Value; - #[derive(Copy, Clone, Debug, PartialEq, Eq, FromPrimitive)] #[repr(u16)] pub enum ErrorCode { @@ -156,17 +155,69 @@ impl RequestError { } } -type RequestResult = Result; -type RequestFuture = BoxFuture; +#[derive(Debug, Serialize, Deserialize)] +struct ResponseMessage(#[serde(rename = "message")] String); + +impl ResponseMessage { + fn new(message: M) -> Self + where + M: ToString, + { + ResponseMessage(message.to_string()) + } +} + +impl From for ResponseMessage { + fn from(message: String) -> Self { + ResponseMessage(message) + } +} + +pub type ResponseValue = serde_json::Value; + +type RequestResult = Result; +type RequestFuture = BoxFuture>; trait IRequest { - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture; + type Response: Serialize; + + fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture; + + fn exec_erased(&mut self, ctx: &mut RequestContext) -> RequestFuture + where + Self::Response: 'static, + { + // TODO: figure out how to get rid of this nested box + Box::pin(ErasedRequestFuture(self.exec(ctx))) + } +} + +struct ErasedRequestFuture(RequestFuture) +where + Response: Serialize; + +impl Future for ErasedRequestFuture +where + Response: Serialize, +{ + type Output = RequestResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + use eyre::WrapErr; + let response = ready!(self.as_mut().0.poll_unpin(cx)); + Poll::Ready(response.and_then(|res| { + serde_json::to_value(res) + .wrap_err("could not serialize response") + .map_err(RequestError::from) + })) + } } #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase", tag = "type")] pub enum Request { - RunSection(run_section::RequestData), + RunSection(sections::RunSectionRequest), + CancelSection(sections::CancelSectionRequest), } #[derive(Debug, Deserialize, Serialize)] @@ -192,9 +243,12 @@ impl From for Response { } impl IRequest for Request { - fn exec(&mut self, ctx: &mut RequestContext) -> BoxFuture { + type Response = ResponseValue; + + fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { match self { - Request::RunSection(req) => req.exec(ctx), + Request::RunSection(req) => req.exec_erased(ctx), + Request::CancelSection(req) => req.exec_erased(ctx), } } } diff --git a/src/mqtt/request/run_section.rs b/src/mqtt/request/run_section.rs deleted file mode 100644 index 2ee2c8f..0000000 --- a/src/mqtt/request/run_section.rs +++ /dev/null @@ -1,43 +0,0 @@ -use super::*; -use crate::{model::SectionId, section_runner::SectionRunHandle}; -use eyre::WrapErr; -use serde::{Deserialize, Serialize}; -use std::time::Duration; - -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct RequestData { - pub section_id: SectionId, - #[serde(with = "crate::serde::duration")] - pub duration: Duration, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct ResponseData { - pub message: String, - pub run_id: SectionRunHandle, -} - -impl IRequest for RequestData { - fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { - let mut section_runner = ctx.section_runner.clone(); - let section = ctx.sections.get(&self.section_id).cloned(); - let duration = self.duration; - Box::pin(async move { - let section = section.ok_or_else(|| { - RequestError::with_name(ErrorCode::NotFound, "section not found", "section") - })?; - let handle = section_runner - .queue_run(section.clone(), duration) - .await - .wrap_err("could not queue run")?; - let res = ResponseData { - message: format!("running section '{}' for {:?}", §ion.name, duration), - run_id: handle, - }; - let res_value = serde_json::to_value(res).wrap_err("could not serialize response")?; - Ok(res_value) - }) - } -} diff --git a/src/mqtt/request/sections.rs b/src/mqtt/request/sections.rs new file mode 100644 index 0000000..3ce50dc --- /dev/null +++ b/src/mqtt/request/sections.rs @@ -0,0 +1,87 @@ +use super::*; +use crate::{model::SectionRef, section_runner::SectionRunHandle}; +use eyre::WrapErr; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +#[serde(transparent)] +pub struct SectionId(pub crate::model::SectionId); + +impl SectionId { + fn get_section(self, sections: &Sections) -> Result { + sections.get(&self.0).cloned().ok_or_else(|| { + RequestError::with_name(ErrorCode::NotFound, "section not found", "section") + }) + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct RunSectionRequest { + pub section_id: SectionId, + #[serde(with = "crate::serde::duration")] + pub duration: Duration, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct RunSectionResponse { + pub message: String, + pub run_id: SectionRunHandle, +} + +impl IRequest for RunSectionRequest { + type Response = RunSectionResponse; + fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + let mut section_runner = ctx.section_runner.clone(); + let section = self.section_id.get_section(&ctx.sections); + let duration = self.duration; + Box::pin(async move { + let section = section?; + let handle = section_runner + .queue_run(section.clone(), duration) + .await + .wrap_err("could not queue run")?; + Ok(RunSectionResponse { + message: format!("running section '{}' for {:?}", §ion.name, duration), + run_id: handle, + }) + }) + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CancelSectionRequest { + pub section_id: SectionId, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CancelSectionResponse { + pub message: String, + pub cancelled: usize, +} + +impl IRequest for CancelSectionRequest { + type Response = CancelSectionResponse; + fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + let mut section_runner = ctx.section_runner.clone(); + let section = self.section_id.get_section(&ctx.sections); + Box::pin(async move { + let section = section?; + let cancelled = section_runner + .cancel_by_section(section.id) + .await + .wrap_err("could not cancel section")?; + Ok(CancelSectionResponse { + message: format!( + "cancelled {} runs for section '{}'", + cancelled, section.name + ), + cancelled, + }) + }) + } +} diff --git a/src/section_runner.rs b/src/section_runner.rs index b11c263..771e04d 100644 --- a/src/section_runner.rs +++ b/src/section_runner.rs @@ -1,4 +1,4 @@ -use crate::model::SectionRef; +use crate::model::{SectionId, SectionRef}; use crate::section_interface::SectionInterface; use actix::{ Actor, ActorContext, Addr, AsyncContext, Handler, Message, MessageResult, SpawnHandle, @@ -20,8 +20,7 @@ use tokio::{ }; use tracing::{debug, trace, warn}; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[derive(serde::Deserialize, serde::Serialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] pub struct SectionRunHandle(i32); impl SectionRunHandle { @@ -177,19 +176,24 @@ impl SectionRunnerInner { } } - fn cancel_run(&mut self, run: &mut Arc) { + fn cancel_run(&mut self, run: &mut Arc) -> bool { let run = Arc::make_mut(run); if run.is_running() { debug!(section_id = run.section.id, "cancelling running section"); self.interface .set_section_state(run.section.interface_id, false); } - run.state = SecRunState::Cancelled; - self.send_event(SectionEvent::RunCancel( - run.handle.clone(), - run.section.clone(), - )); - self.did_change = true; + if run.state != SecRunState::Cancelled { + run.state = SecRunState::Cancelled; + self.send_event(SectionEvent::RunCancel( + run.handle.clone(), + run.section.clone(), + )); + self.did_change = true; + true + } else { + false + } } fn pause_run(&mut self, run: &mut Arc) { @@ -322,40 +326,78 @@ impl Handler for SectionRunnerActor { } #[derive(Message, Debug, Clone)] -#[rtype(result = "()")] +#[rtype(result = "bool")] struct CancelRun(SectionRunHandle); impl Handler for SectionRunnerActor { - type Result = (); + type Result = bool; fn handle(&mut self, msg: CancelRun, ctx: &mut Self::Context) -> Self::Result { let CancelRun(handle) = msg; - for run in self.state.run_queue.iter_mut() { - if run.handle != handle { - continue; - } + let mut cancelled = false; + for run in self + .state + .run_queue + .iter_mut() + .filter(|run| run.handle == handle) + { trace!(handle = handle.0, "cancelling run by handle"); - self.inner.cancel_run(run); + cancelled = self.inner.cancel_run(run); } ctx.notify(Process); + cancelled } } #[derive(Message, Debug, Clone)] -#[rtype(result = "()")] +#[rtype(result = "usize")] +struct CancelBySection(SectionId); + +impl Handler for SectionRunnerActor { + type Result = usize; + + fn handle(&mut self, msg: CancelBySection, ctx: &mut Self::Context) -> Self::Result { + let CancelBySection(section_id) = msg; + let mut count = 0_usize; + for run in self + .state + .run_queue + .iter_mut() + .filter(|run| run.section.id == section_id) + { + trace!( + handle = run.handle.0, + section_id, + "cancelling run by section" + ); + if self.inner.cancel_run(run) { + count += 1; + } + } + ctx.notify(Process); + count + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "usize")] struct CancelAll; impl Handler for SectionRunnerActor { - type Result = (); + type Result = usize; fn handle(&mut self, _msg: CancelAll, ctx: &mut Self::Context) -> Self::Result { let mut old_runs = SecRunQueue::new(); swap(&mut old_runs, &mut self.state.run_queue); trace!(count = old_runs.len(), "cancelling all runs"); + let mut count = 0usize; for mut run in old_runs { - self.inner.cancel_run(&mut run); + if self.inner.cancel_run(&mut run) { + count += 1; + } } ctx.notify(Process); + count } } @@ -517,11 +559,7 @@ impl SectionRunner { (QueueRun(handle.clone(), section, duration), handle) } - pub fn do_queue_run( - &mut self, - section: SectionRef, - duration: Duration, - ) -> SectionRunHandle { + pub fn do_queue_run(&mut self, section: SectionRef, duration: Duration) -> SectionRunHandle { let (queue_run, handle) = self.queue_run_inner(section, duration); self.addr.do_send(queue_run); handle @@ -543,11 +581,20 @@ impl SectionRunner { self.addr.do_send(CancelRun(handle)) } - pub fn cancel_run(&mut self, handle: SectionRunHandle) -> impl Future> { + pub fn cancel_run(&mut self, handle: SectionRunHandle) -> impl Future> { self.addr.send(CancelRun(handle)).map_err(From::from) } - pub fn cancel_all(&mut self) -> impl Future> { + pub fn cancel_by_section( + &mut self, + section_id: SectionId, + ) -> impl Future> { + self.addr + .send(CancelBySection(section_id)) + .map_err(From::from) + } + + pub fn cancel_all(&mut self) -> impl Future> { self.addr.send(CancelAll).map_err(From::from) }