diff --git a/Cargo.toml b/Cargo.toml index e3bf94e..b54b2ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,8 @@ serde_json = "1.0.57" actix = "0.10.0" actix-rt = "1.1.1" futures-util = { version = "0.3.5", default-features = false, features = ["std", "async-await"] } +num-traits = "0.2.12" +num-derive = "0.3.2" [dependencies.rumqttc] version = "0.1.0" diff --git a/src/main.rs b/src/main.rs index 3a4a7f6..6e1a598 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ mod schedule; mod section_interface; mod section_runner; mod section_runner_json; +mod serde; #[cfg(test)] mod trace_listeners; mod update_listener; @@ -63,7 +64,12 @@ async fn main() -> Result<()> { device_id: "sprinklers_rs-0001".into(), client_id: "sprinklers_rs-0001".into(), }; - let mut mqtt_interface = mqtt::MqttInterfaceTask::start(mqtt_options); + // TODO: have ability to update sections / other data + let request_context = mqtt::RequestContext { + sections: sections.clone(), + section_runner: section_runner.clone(), + }; + let mut mqtt_interface = mqtt::MqttInterfaceTask::start(mqtt_options, request_context); let update_listener = { let section_events = section_runner.subscribe().await?; diff --git a/src/model/program.rs b/src/model/program.rs index 112da50..8ce2a4f 100644 --- a/src/model/program.rs +++ b/src/model/program.rs @@ -7,33 +7,10 @@ use std::{sync::Arc, time::Duration}; #[serde(rename_all = "camelCase")] pub struct ProgramItem { pub section_id: SectionId, - #[serde( - serialize_with = "ser::serialize_duration", - deserialize_with = "ser::deserialize_duration" - )] + #[serde(with = "crate::serde::duration")] pub duration: Duration, } -mod ser { - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - use std::time::Duration; - - pub fn serialize_duration(duration: &Duration, serializer: S) -> Result - where - S: Serializer, - { - duration.as_secs_f64().serialize(serializer) - } - - pub fn deserialize_duration<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let secs: f64 = Deserialize::deserialize(deserializer)?; - Ok(Duration::from_secs_f64(secs)) - } -} - pub type ProgramSequence = Vec; pub type ProgramId = u32; diff --git a/src/mqtt/actor.rs b/src/mqtt/actor.rs index e07c4e1..bc39002 100644 --- a/src/mqtt/actor.rs +++ b/src/mqtt/actor.rs @@ -1,5 +1,6 @@ -use super::{event_loop::EventLoopTask, MqttInterface}; +use super::{event_loop::EventLoopTask, request, MqttInterface}; use actix::{Actor, ActorContext, ActorFuture, AsyncContext, Handler, WrapFuture}; +use request::{ErrorCode, RequestContext, RequestError, WithRequestId}; use tokio::sync::oneshot; use tracing::{debug, error, info, trace, warn}; @@ -7,16 +8,61 @@ pub(super) struct MqttActor { interface: MqttInterface, event_loop: Option, quit_rx: Option>, + request_context: RequestContext, } impl MqttActor { - pub(super) fn new(interface: MqttInterface, event_loop: rumqttc::EventLoop) -> Self { + pub(super) fn new( + interface: MqttInterface, + event_loop: rumqttc::EventLoop, + request_context: RequestContext, + ) -> Self { Self { interface, event_loop: Some(event_loop), quit_rx: None, + request_context, } } + + fn handle_request(&mut self, payload: &[u8], ctx: &mut ::Context) { + let request_value = + match serde_json::from_slice::>(payload) { + Ok(r) => r, + Err(err) => { + warn!("could not deserialize request: {}", err); + return; + } + }; + let rid = request_value.rid; + let request_fut = + serde_json::from_value::(request_value.rest).map(|mut request| { + trace!("deserialized request: {:?}", request); + request.execute(&mut self.request_context) + }); + let mut interface = self.interface.clone(); + let fut = async move { + let response = match request_fut { + Ok(request_fut) => request_fut.await, + Err(deser_err) => RequestError::with_name_and_cause( + ErrorCode::Parse, + "could not parse request", + "request", + deser_err, + ) + .into(), + }; + let resp_with_id = WithRequestId:: { + rid, + rest: response, + }; + trace!("sending request response: {:?}", resp_with_id); + if let Err(err) = interface.publish_response(resp_with_id).await { + error!("could not publish request response: {}", err); + } + }; + ctx.spawn(fut.into_actor(self)); + } } impl Actor for MqttActor { @@ -85,10 +131,11 @@ pub(super) struct PubRecieve(pub(super) rumqttc::Publish); impl Handler for MqttActor { type Result = (); - fn handle(&mut self, msg: PubRecieve, _ctx: &mut Self::Context) -> Self::Result { + fn handle(&mut self, msg: PubRecieve, ctx: &mut Self::Context) -> Self::Result { let topic = &msg.0.topic; if topic == &self.interface.topics.requests() { debug!("received request: {:?}", msg.0); + self.handle_request(msg.0.payload.as_ref(), ctx); } else { warn!("received on unknown topic: {:?}", topic); } diff --git a/src/mqtt/mod.rs b/src/mqtt/mod.rs index 091f22c..7cc2bfd 100644 --- a/src/mqtt/mod.rs +++ b/src/mqtt/mod.rs @@ -1,7 +1,10 @@ mod actor; mod event_loop; +mod request; mod topics; +pub use request::RequestContext; + use self::topics::Topics; use crate::{ model::{Program, ProgramId, Programs, Section, SectionId, Sections}, @@ -130,6 +133,17 @@ impl MqttInterface { .wrap_err("failed to publish section runner") } + pub async fn publish_response(&mut self, resp: request::ResponseWithId) -> eyre::Result<()> { + let payload_vec = + serde_json::to_vec(&resp).wrap_err("failed to serialize request response")?; + // TODO: if couldn't serialize, just in case can have a static response + self.client + .publish(self.topics.responses(), QoS::AtMostOnce, false, payload_vec) + .await + .wrap_err("failed to publish request response")?; + Ok(()) + } + pub async fn subscribe_requests(&mut self) -> eyre::Result<()> { self.client .subscribe(self.topics.requests(), QoS::ExactlyOnce) @@ -144,10 +158,10 @@ pub struct MqttInterfaceTask { } impl MqttInterfaceTask { - pub fn start(options: Options) -> Self { + pub fn start(options: Options, request_context: RequestContext) -> Self { let (interface, event_loop) = MqttInterface::new(options); - let addr = actor::MqttActor::new(interface.clone(), event_loop).start(); + let addr = actor::MqttActor::new(interface.clone(), event_loop, request_context).start(); Self { interface, addr } } diff --git a/src/mqtt/request.rs b/src/mqtt/request.rs new file mode 100644 index 0000000..2aaccb3 --- /dev/null +++ b/src/mqtt/request.rs @@ -0,0 +1,261 @@ +use crate::{ + model::{SectionId, Sections}, + section_runner::SectionRunner, +}; +use eyre::WrapErr; +use futures_util::FutureExt; +use num_derive::FromPrimitive; +use serde::{Deserialize, Serialize}; +use std::{fmt, future::Future, pin::Pin, time::Duration}; + +pub struct RequestContext { + pub sections: Sections, + pub section_runner: SectionRunner, +} + +type BoxFuture = Pin>>; + +pub type ResponseValue = serde_json::Value; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, FromPrimitive)] +#[repr(u16)] +pub enum ErrorCode { + BadRequest = 100, + NotSpecified = 101, + Parse = 102, + Range = 103, + InvalidData = 104, + BadToken = 105, + Unauthorized = 106, + NoPermission = 107, + NotFound = 109, + // NotUnique = 110, + Internal = 200, + NotImplemented = 201, + Timeout = 300, + // ServerDisconnected = 301, + // BrokerDisconnected = 302, +} + +mod ser { + use super::ErrorCode; + use num_traits::FromPrimitive; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + impl Serialize for ErrorCode { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_u16(*self as u16) + } + } + + impl<'de> Deserialize<'de> for ErrorCode { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let prim = u16::deserialize(deserializer)?; + ErrorCode::from_u16(prim) + .ok_or_else(|| ::custom("invalid ErrorCode")) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase", tag = "result")] +pub struct RequestError { + code: ErrorCode, + message: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + cause: Option, +} + +impl fmt::Display for RequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "request error (code {:?}", self.code)?; + if let Some(name) = &self.name { + write!(f, "on {}", name)?; + } + write!(f, "): {}", self.message)?; + if let Some(cause) = &self.cause { + write!(f, ", caused by {}", cause)?; + } + Ok(()) + } +} + +impl std::error::Error for RequestError {} + +impl From for RequestError { + fn from(report: eyre::Report) -> Self { + let mut chain = report.chain(); + let message = match chain.next() { + Some(a) => a.to_string(), + None => "unknown error".to_string(), + }; + let cause = chain.fold(None, |cause, err| match cause { + Some(cause) => Some(format!("{}: {}", cause, err)), + None => Some(err.to_string()), + }); + RequestError::new(ErrorCode::Internal, message, None, cause) + } +} + +#[allow(dead_code)] +impl RequestError { + pub fn new(code: ErrorCode, message: M, name: Option, cause: Option) -> Self + where + M: ToString, + { + Self { + code, + message: message.to_string(), + name, + cause, + } + } + + pub fn simple(code: ErrorCode, message: M) -> Self + where + M: ToString, + { + Self::new(code, message, None, None) + } + + pub fn with_name(code: ErrorCode, message: M, name: N) -> Self + where + M: ToString, + N: ToString, + { + Self::new(code, message, Some(name.to_string()), None) + } + + pub fn with_cause(code: ErrorCode, message: M, cause: C) -> Self + where + M: ToString, + C: ToString, + { + Self::new(code, message, None, Some(cause.to_string())) + } + + pub fn with_name_and_cause(code: ErrorCode, message: M, name: N, cause: C) -> Self + where + M: ToString, + N: ToString, + C: ToString, + { + Self::new( + code, + message, + Some(name.to_string()), + Some(cause.to_string()), + ) + } +} + +type RequestResult = Result; +type RequestFuture = BoxFuture; + +trait IRequest { + fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture; +} + +mod run_section { + use super::*; + use crate::section_runner::SectionRunHandle; + + #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct RequestData { + pub section_id: SectionId, + #[serde(with = "crate::serde::duration")] + pub duration: Duration, + } + + #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "camelCase")] + pub struct ResponseData { + pub message: String, + pub run_id: SectionRunHandle, + } + + impl IRequest for RequestData { + fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture { + let mut section_runner = ctx.section_runner.clone(); + let section = ctx.sections.get(&self.section_id).cloned(); + let duration = self.duration; + Box::pin(async move { + let section = section.ok_or_else(|| { + RequestError::with_name(ErrorCode::NotFound, "section not found", "section") + })?; + let handle = section_runner + .queue_run(section.clone(), duration) + .await + .wrap_err("could not queue run")?; + let res = ResponseData { + message: format!("running section '{}' for {:?}", §ion.name, duration), + run_id: handle, + }; + let res_value = + serde_json::to_value(res).wrap_err("could not serialize response")?; + Ok(res_value) + }) + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase", tag = "type")] +pub enum Request { + RunSection(run_section::RequestData), +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase", tag = "result")] +pub enum Response { + Success(ResponseValue), + Error(RequestError), +} + +impl From for Response { + fn from(res: RequestResult) -> Self { + match res { + Ok(value) => Response::Success(value), + Err(error) => Response::Error(error), + } + } +} + +impl From for Response { + fn from(error: RequestError) -> Self { + Response::Error(error) + } +} + +impl IRequest for Request { + fn exec(&mut self, ctx: &mut RequestContext) -> BoxFuture { + match self { + Request::RunSection(req) => req.exec(ctx), + } + } +} + +impl Request { + pub fn execute(&mut self, ctx: &mut RequestContext) -> impl Future { + self.exec(ctx).map(Response::from) + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WithRequestId { + pub rid: i32, + #[serde(flatten)] + pub rest: T, +} + +pub type ResponseWithId = WithRequestId; diff --git a/src/section_runner.rs b/src/section_runner.rs index 8d45b13..b11c263 100644 --- a/src/section_runner.rs +++ b/src/section_runner.rs @@ -21,6 +21,7 @@ use tokio::{ use tracing::{debug, trace, warn}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(serde::Deserialize, serde::Serialize)] pub struct SectionRunHandle(i32); impl SectionRunHandle { diff --git a/src/serde.rs b/src/serde.rs new file mode 100644 index 0000000..92c3ac6 --- /dev/null +++ b/src/serde.rs @@ -0,0 +1,21 @@ +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +pub mod duration { + use super::*; + use std::time::Duration; + + pub fn serialize(duration: &Duration, serializer: S) -> Result + where + S: Serializer, + { + duration.as_secs_f64().serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let secs: f64 = Deserialize::deserialize(deserializer)?; + Ok(Duration::from_secs_f64(secs)) + } +}