diff --git a/Cargo.toml b/Cargo.toml index 8ca179a..19d47ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,8 @@ chrono = { version = "0.4.15" } assert_matches = "1.3.0" serde = { version = "1.0.116", features = ["derive"] } serde_json = "1.0.57" +actix = "0.10.0" +actix-rt = "1.1.1" [dependencies.rumqttc] git = "https://github.com/bytebeamio/rumqtt.git" diff --git a/src/main.rs b/src/main.rs index d02f780..765c5e9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,7 +20,7 @@ use tracing::{debug, info}; use tracing_subscriber::EnvFilter; use update_listener::UpdateListener; -#[tokio::main] +#[actix_rt::main] async fn main() -> Result<()> { tracing_subscriber::fmt() .with_ansi(true) diff --git a/src/program_runner.rs b/src/program_runner.rs index bbf933a..cd8e067 100644 --- a/src/program_runner.rs +++ b/src/program_runner.rs @@ -428,7 +428,7 @@ mod test { use tokio::task::yield_now; use tracing_subscriber::prelude::*; - #[tokio::test] + #[actix_rt::test] async fn test_quit() { let quit_msg = EventListener::new( Filters::new() @@ -496,7 +496,7 @@ mod test { .into() } - #[tokio::test] + #[actix_rt::test] async fn test_run_program() { let (sections, mut sec_runner, interface) = make_sections_and_runner(); let mut sec_events = sec_runner.subscribe().await.unwrap(); @@ -522,31 +522,31 @@ mod test { runner.run_program(program).await.unwrap(); yield_now().await; assert_matches!( - prog_events.try_recv().unwrap(), - ProgramEvent::RunStart(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunStart(prog)) if prog.id == 1 ); - assert_matches!(sec_events.try_recv().unwrap(), SectionEvent::RunStart(_, _)); + assert_matches!(sec_events.try_recv(), Ok(SectionEvent::RunStart(_, _))); assert_eq!(interface.get_section_state(0), true); tokio::time::pause(); assert_matches!( - sec_events.recv().await.unwrap(), - SectionEvent::RunFinish(_, _) + sec_events.recv().await, + Ok(SectionEvent::RunFinish(_, _)) ); assert_matches!( - sec_events.recv().await.unwrap(), - SectionEvent::RunStart(_, _) + sec_events.recv().await, + Ok(SectionEvent::RunStart(_, _)) ); assert_eq!(interface.get_section_state(0), false); assert_eq!(interface.get_section_state(1), true); assert_matches!( - sec_events.recv().await.unwrap(), - SectionEvent::RunFinish(_, _) + sec_events.recv().await, + Ok(SectionEvent::RunFinish(_, _)) ); assert_matches!( - prog_events.recv().await.unwrap(), - ProgramEvent::RunFinish(_) + prog_events.recv().await, + Ok(ProgramEvent::RunFinish(_)) ); runner.quit().await.unwrap(); @@ -554,7 +554,7 @@ mod test { yield_now().await; } - #[tokio::test] + #[actix_rt::test] async fn test_run_nonexistant_section() { let (sections, mut sec_runner, _) = make_sections_and_runner(); let mut runner = ProgramRunner::new(sec_runner.clone()); @@ -581,13 +581,13 @@ mod test { // Should immediately start and finish running program // due to nonexistant section assert_matches!( - prog_events.try_recv().unwrap(), - ProgramEvent::RunStart(prog) + prog_events.try_recv(), + Ok(ProgramEvent::RunStart(prog)) if prog.id == 1 ); assert_matches!( - prog_events.try_recv().unwrap(), - ProgramEvent::RunFinish(prog) + prog_events.try_recv(), + Ok(ProgramEvent::RunFinish(prog)) if prog.id == 1 ); @@ -595,14 +595,14 @@ mod test { yield_now().await; // Should run right away since last program should be done assert_matches!( - prog_events.try_recv().unwrap(), - ProgramEvent::RunStart(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunStart(prog)) if prog.id == 2 ); tokio::time::pause(); assert_matches!( - prog_events.recv().await.unwrap(), - ProgramEvent::RunFinish(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunFinish(prog)) if prog.id == 2 ); @@ -610,7 +610,7 @@ mod test { sec_runner.quit().await.unwrap(); } - #[tokio::test] + #[actix_rt::test] async fn test_close_event_chan() { let (sections, mut sec_runner, _) = make_sections_and_runner(); let mut runner = ProgramRunner::new(sec_runner.clone()); @@ -633,7 +633,7 @@ mod test { sec_runner.quit().await.unwrap(); } - #[tokio::test] + #[actix_rt::test] async fn test_run_program_id() { let (sections, mut sec_runner, _) = make_sections_and_runner(); let mut runner = ProgramRunner::new(sec_runner.clone()); @@ -666,28 +666,28 @@ mod test { runner.run_program_id(1).await.unwrap(); yield_now().await; assert_matches!( - prog_events.try_recv().unwrap(), - ProgramEvent::RunStart(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunStart(prog)) if prog.id == 1 ); tokio::time::pause(); assert_matches!( - prog_events.recv().await.unwrap(), - ProgramEvent::RunFinish(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunFinish(prog)) if prog.id == 1 ); runner.run_program_id(1).await.unwrap(); yield_now().await; assert_matches!( - prog_events.try_recv().unwrap(), - ProgramEvent::RunStart(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunStart(prog)) if prog.id == 1 ); assert_matches!( - prog_events.recv().await.unwrap(), - ProgramEvent::RunFinish(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunFinish(prog)) if prog.id == 1 ); @@ -695,7 +695,7 @@ mod test { sec_runner.quit().await.unwrap(); } - #[tokio::test] + #[actix_rt::test] async fn test_cancel_program() { let (sections, mut sec_runner, _) = make_sections_and_runner(); let mut sec_events = sec_runner.subscribe().await.unwrap(); @@ -721,8 +721,8 @@ mod test { runner.run_program(program.clone()).await.unwrap(); yield_now().await; assert_matches!( - prog_events.try_recv().unwrap(), - ProgramEvent::RunStart(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunStart(prog)) if prog.id == 1 ); assert_matches!(sec_events.try_recv().unwrap(), SectionEvent::RunStart(_, _)); @@ -730,20 +730,20 @@ mod test { runner.cancel_program(program.id).await.unwrap(); yield_now().await; assert_matches!( - prog_events.recv().await.unwrap(), - ProgramEvent::RunCancel(prog) + prog_events.recv().await, + Ok(ProgramEvent::RunCancel(prog)) if prog.id == 1 ); assert_matches!( - sec_events.recv().await.unwrap(), - SectionEvent::RunCancel(_, _) + sec_events.recv().await, + Ok(SectionEvent::RunCancel(_, _)) ); runner.quit().await.unwrap(); sec_runner.quit().await.unwrap(); } - #[tokio::test] + #[actix_rt::test] async fn test_scheduled_run() { tracing_subscriber::fmt().init(); let (sections, mut sec_runner, _) = make_sections_and_runner(); diff --git a/src/section_runner.rs b/src/section_runner.rs index d26a1a8..8e34f55 100644 --- a/src/section_runner.rs +++ b/src/section_runner.rs @@ -1,22 +1,15 @@ use crate::model::SectionRef; -use crate::option_future::OptionFuture; use crate::section_interface::SectionInterface; -use std::{ - mem::swap, - sync::{ - atomic::{AtomicI32, Ordering}, - Arc, - }, - time::Duration, +use actix::{ + Actor, ActorContext, Addr, AsyncContext, Handler, Message, MessageResult, SpawnHandle, }; +use std::{mem::swap, sync::Arc, time::Duration}; use thiserror::Error; use tokio::{ - spawn, - sync::{broadcast, mpsc, oneshot, watch}, - time::{delay_for, Instant}, + sync::{broadcast, watch}, + time::Instant, }; -use tracing::{debug, trace, trace_span, warn}; -use tracing_futures::Instrument; +use tracing::{debug, trace, warn}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SectionRunHandle(i32); @@ -27,30 +20,6 @@ impl SectionRunHandle { } } -#[derive(Debug)] -struct SectionRunnerInner { - next_run_id: AtomicI32, -} - -impl SectionRunnerInner { - fn new() -> Self { - Self { - next_run_id: AtomicI32::new(1), - } - } -} - -#[derive(Debug)] -enum RunnerMsg { - Quit(oneshot::Sender<()>), - QueueRun(SectionRunHandle, SectionRef, Duration), - CancelRun(SectionRunHandle), - CancelAll, - Pause, - Unpause, - Subscribe(oneshot::Sender), -} - #[derive(Clone, Debug)] pub enum SectionEvent { RunStart(SectionRunHandle, SectionRef), @@ -130,35 +99,15 @@ impl Default for SecRunnerState { pub type SecRunnerStateRecv = watch::Receiver; -struct RunnerTask { +struct SectionRunnerInner { interface: Arc, - msg_recv: mpsc::Receiver, - running: bool, - delay_future: OptionFuture, event_send: Option, state_send: watch::Sender, - quit_tx: Option>, + delay_future: Option, did_change: bool, } -impl RunnerTask { - fn new( - interface: Arc, - msg_recv: mpsc::Receiver, - state_send: watch::Sender, - ) -> Self { - Self { - interface, - msg_recv, - running: true, - delay_future: None.into(), - event_send: None, - state_send, - quit_tx: None, - did_change: false, - } - } - +impl SectionRunnerInner { fn send_event(&mut self, event: SectionEvent) { if let Some(event_send) = &mut self.event_send { match event_send.send(event) { @@ -297,35 +246,212 @@ impl RunnerTask { self.did_change = true; } - fn process_queue(&mut self, state: &mut SecRunnerState) { + fn process_after_delay( + &mut self, + after: Duration, + ctx: &mut ::Context, + ) { + let delay_future = ctx.notify_later(Process, after); + if let Some(old_future) = self.delay_future.replace(delay_future) { + ctx.cancel_future(old_future); + } + } + + fn cancel_process(&mut self, ctx: &mut ::Context) { + if let Some(old_future) = self.delay_future.take() { + ctx.cancel_future(old_future); + } + } +} + +struct SectionRunnerActor { + state: SecRunnerState, + next_run_id: i32, + inner: SectionRunnerInner, +} + +impl Actor for SectionRunnerActor { + type Context = actix::Context; + + fn started(&mut self, _ctx: &mut Self::Context) { + trace!("section_runner starting"); + } + + fn stopped(&mut self, _ctx: &mut Self::Context) { + trace!("section_runner stopped"); + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "()")] +struct Quit; + +impl Handler for SectionRunnerActor { + type Result = (); + + fn handle(&mut self, _msg: Quit, ctx: &mut Self::Context) -> Self::Result { + ctx.stop(); + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "SectionRunHandle")] +struct QueueRun(SectionRef, Duration); + +impl Handler for SectionRunnerActor { + type Result = MessageResult; + + fn handle(&mut self, msg: QueueRun, ctx: &mut Self::Context) -> Self::Result { + let QueueRun(section, duration) = msg; + + let run_id = self.next_run_id; + self.next_run_id += 1; + let handle = SectionRunHandle(run_id); + + let run: Arc = SecRun::new(handle.clone(), section, duration).into(); + self.state.run_queue.push_back(run); + self.inner.did_change = true; + + ctx.notify(Process); + + MessageResult(handle) + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "()")] +struct CancelRun(SectionRunHandle); + +impl Handler for SectionRunnerActor { + type Result = (); + + fn handle(&mut self, msg: CancelRun, ctx: &mut Self::Context) -> Self::Result { + let CancelRun(handle) = msg; + for run in self.state.run_queue.iter_mut() { + if run.handle != handle { + continue; + } + trace!(handle = handle.0, "cancelling run by handle"); + self.inner.cancel_run(run); + } + ctx.notify(Process); + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "()")] +struct CancelAll; + +impl Handler for SectionRunnerActor { + type Result = (); + + fn handle(&mut self, _msg: CancelAll, ctx: &mut Self::Context) -> Self::Result { + let mut old_runs = SecRunQueue::new(); + swap(&mut old_runs, &mut self.state.run_queue); + trace!(count = old_runs.len(), "cancelling all runs"); + for mut run in old_runs { + self.inner.cancel_run(&mut run); + } + ctx.notify(Process); + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "()")] +struct SetPaused(bool); + +impl Handler for SectionRunnerActor { + type Result = (); + + fn handle(&mut self, msg: SetPaused, ctx: &mut Self::Context) -> Self::Result { + let SetPaused(pause) = msg; + if pause != self.state.paused { + if pause { + self.state.paused = true; + self.inner.send_event(SectionEvent::RunnerPause); + } else { + self.state.paused = false; + self.inner.send_event(SectionEvent::RunnerUnpause); + } + self.inner.did_change = true; + ctx.notify(Process); + } + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "SectionEventRecv")] +struct Subscribe; + +impl Handler for SectionRunnerActor { + type Result = MessageResult; + + fn handle(&mut self, _msg: Subscribe, _ctx: &mut Self::Context) -> Self::Result { + let event_recv = self.inner.subscribe_event(); + MessageResult(event_recv) + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "()")] +struct Process; + +impl Handler for SectionRunnerActor { + type Result = (); + + fn handle(&mut self, _msg: Process, ctx: &mut Self::Context) -> Self::Result { + self.process(ctx) + } +} + +impl SectionRunnerActor { + fn new( + interface: Arc, + state_send: watch::Sender, + ) -> Self { + Self { + state: SecRunnerState::default(), + inner: SectionRunnerInner { + interface, + event_send: None, + state_send, + delay_future: None, + did_change: false, + }, + next_run_id: 1, + } + } + + fn process_queue(&mut self, ctx: &mut actix::Context) { use SecRunState::*; + let state = &mut self.state; while let Some(current_run) = state.run_queue.front_mut() { let run_finished = match (¤t_run.state, state.paused) { (Waiting, false) => { - self.start_run(current_run); - self.delay_future = Some(delay_for(current_run.duration)).into(); + self.inner.start_run(current_run); + self.inner.process_after_delay(current_run.duration, ctx); false } (Running { start_time }, false) => { let time_to_finish = start_time.elapsed() >= current_run.duration; if time_to_finish { - self.finish_run(current_run); - self.delay_future = None.into(); + self.inner.finish_run(current_run); + self.inner.cancel_process(ctx); } time_to_finish } (Paused { .. }, false) => { - self.unpause_run(current_run); - self.delay_future = Some(delay_for(current_run.duration)).into(); + self.inner.unpause_run(current_run); + self.inner.process_after_delay(current_run.duration, ctx); false } (Waiting, true) => { - self.pause_run(current_run); + self.inner.pause_run(current_run); false } (Running { .. }, true) => { - self.pause_run(current_run); - self.delay_future = None.into(); + self.inner.pause_run(current_run); + self.inner.cancel_process(ctx); false } (Paused { .. }, true) => false, @@ -340,140 +466,38 @@ impl RunnerTask { } } - fn handle_msg(&mut self, msg: Option, state: &mut SecRunnerState) { - let msg = msg.expect("SectionRunner channel closed"); - use RunnerMsg::*; - trace!(msg = debug(&msg), "runner_task recv"); - match msg { - Quit(quit_tx) => { - self.quit_tx = Some(quit_tx); - self.running = false; - } - QueueRun(handle, section, duration) => { - state - .run_queue - .push_back(Arc::new(SecRun::new(handle, section, duration))); - self.did_change = true; - } - CancelRun(handle) => { - for run in state.run_queue.iter_mut() { - if run.handle != handle { - continue; - } - trace!(handle = handle.0, "cancelling run by handle"); - self.cancel_run(run); - } - } - CancelAll => { - let mut old_runs = SecRunQueue::new(); - swap(&mut old_runs, &mut state.run_queue); - trace!(count = old_runs.len(), "cancelling all runs"); - for mut run in old_runs { - self.cancel_run(&mut run); - } - } - Pause => { - state.paused = true; - self.send_event(SectionEvent::RunnerPause); - self.did_change = true; - } - Unpause => { - state.paused = false; - self.send_event(SectionEvent::RunnerUnpause); - self.did_change = true; - } - Subscribe(res_send) => { - let event_recv = self.subscribe_event(); - // Ignore error if channel closed - let _ = res_send.send(event_recv); - } - } - } - - async fn run_impl(mut self) { - let mut state = SecRunnerState::default(); - - while self.running { - // Process all pending messages - // This is so if there are many pending messages, the state - // is only broadcast once - while let Ok(msg) = self.msg_recv.try_recv() { - self.handle_msg(Some(msg), &mut state); - } - - self.process_queue(&mut state); - - // If a change was made to state, broadcast it - if self.did_change { - let _ = self.state_send.broadcast(state.clone()); - self.did_change = false; - } - - let delay_done = || { - trace!("delay done"); - }; - tokio::select! { - msg = self.msg_recv.recv() => { - self.handle_msg(msg, &mut state) - }, - _ = &mut self.delay_future, if self.delay_future.is_some() => delay_done() - }; - } + fn process(&mut self, ctx: &mut actix::Context) { + self.process_queue(ctx); - if let Some(quit_tx) = self.quit_tx.take() { - let _ = quit_tx.send(()); + if self.inner.did_change { + let _ = self.inner.state_send.broadcast(self.state.clone()); + self.inner.did_change = false; } } - - async fn run(self) { - let span = trace_span!("section_runner task"); - - self.run_impl().instrument(span).await; - } } #[derive(Debug, Clone, Error)] -#[error("the SectionRunner channel is closed")] -pub struct ChannelClosed; - -pub type Result = std::result::Result; - -impl From> for ChannelClosed { - fn from(_: mpsc::error::SendError) -> Self { - Self - } -} +#[error("error communicating with SectionRunner: {0}")] +pub struct Error(#[from] actix::MailboxError); -impl From for ChannelClosed { - fn from(_: oneshot::error::RecvError) -> Self { - Self - } -} +pub type Result = std::result::Result; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct SectionRunner { - inner: Arc, - msg_send: mpsc::Sender, state_recv: SecRunnerStateRecv, + addr: Addr, } #[allow(dead_code)] impl SectionRunner { pub fn new(interface: Arc) -> Self { - let (msg_send, msg_recv) = mpsc::channel(32); let (state_send, state_recv) = watch::channel(SecRunnerState::default()); - spawn(RunnerTask::new(interface, msg_recv, state_send).run()); - Self { - inner: Arc::new(SectionRunnerInner::new()), - msg_send, - state_recv, - } + let addr = SectionRunnerActor::new(interface, state_send).start(); + Self { state_recv, addr } } pub async fn quit(&mut self) -> Result<()> { - let (quit_tx, quit_rx) = oneshot::channel(); - self.msg_send.send(RunnerMsg::Quit(quit_tx)).await?; - quit_rx.await?; + self.addr.send(Quit).await?; Ok(()) } @@ -482,38 +506,32 @@ impl SectionRunner { section: SectionRef, duration: Duration, ) -> Result { - let run_id = self.inner.next_run_id.fetch_add(1, Ordering::Relaxed); - let handle = SectionRunHandle(run_id); - self.msg_send - .send(RunnerMsg::QueueRun(handle.clone(), section, duration)) - .await?; + let handle = self.addr.send(QueueRun(section, duration)).await?; Ok(handle) } pub async fn cancel_run(&mut self, handle: SectionRunHandle) -> Result<()> { - self.msg_send.send(RunnerMsg::CancelRun(handle)).await?; + self.addr.send(CancelRun(handle)).await?; Ok(()) } pub async fn cancel_all(&mut self) -> Result<()> { - self.msg_send.send(RunnerMsg::CancelAll).await?; + self.addr.send(CancelAll).await?; Ok(()) } pub async fn pause(&mut self) -> Result<()> { - self.msg_send.send(RunnerMsg::Pause).await?; + self.addr.send(SetPaused(true)).await?; Ok(()) } pub async fn unpause(&mut self) -> Result<()> { - self.msg_send.send(RunnerMsg::Unpause).await?; + self.addr.send(SetPaused(false)).await?; Ok(()) } pub async fn subscribe(&mut self) -> Result { - let (res_send, res_recv) = oneshot::channel(); - self.msg_send.send(RunnerMsg::Subscribe(res_send)).await?; - let event_recv = res_recv.await?; + let event_recv = self.addr.send(Subscribe).await?; Ok(event_recv) } @@ -528,28 +546,21 @@ mod test { use crate::section_interface::MockSectionInterface; use crate::{ model::{Section, Sections}, - trace_listeners::{EventListener, Filters, SpanFilters, SpanListener}, + trace_listeners::{EventListener, Filters}, }; use assert_matches::assert_matches; use im::ordmap; use tracing_subscriber::prelude::*; - #[tokio::test] + #[actix_rt::test] async fn test_quit() { let quit_msg = EventListener::new( Filters::new() .target("sprinklers_rs::section_runner") - .message("runner_task recv") - .field_value("msg", "Quit"), - ); - let task_span = SpanListener::new( - SpanFilters::new() - .target("sprinklers_rs::section_runner") - .name("section_runner task"), + .message("section_runner stopped"), ); let subscriber = tracing_subscriber::registry() - .with(quit_msg.clone()) - .with(task_span.clone()); + .with(quit_msg.clone()); let _sub = tracing::subscriber::set_default(subscriber); let interface = MockSectionInterface::new(6); @@ -558,7 +569,6 @@ mod test { runner.quit().await.unwrap(); assert_eq!(quit_msg.get_count(), 1); - assert!(task_span.get_exit_count() > 1); } fn make_sections_and_interface() -> (Sections, Arc) { @@ -597,7 +607,7 @@ mod test { tokio::time::resume(); } - #[tokio::test] + #[actix_rt::test] async fn test_queue() { let (sections, interface) = make_sections_and_interface(); let mut runner = SectionRunner::new(interface.clone()); @@ -644,7 +654,7 @@ mod test { runner.quit().await.unwrap(); } - #[tokio::test] + #[actix_rt::test] async fn test_cancel_run() { let (sections, interface) = make_sections_and_interface(); let mut runner = SectionRunner::new(interface.clone()); @@ -681,7 +691,7 @@ mod test { runner.quit().await.unwrap(); } - #[tokio::test] + #[actix_rt::test] async fn test_cancel_all() { let (sections, interface) = make_sections_and_interface(); let mut runner = SectionRunner::new(interface.clone()); @@ -715,7 +725,7 @@ mod test { runner.quit().await.unwrap(); } - #[tokio::test] + #[actix_rt::test] async fn test_pause() { let (sections, interface) = make_sections_and_interface(); let mut runner = SectionRunner::new(interface.clone()); @@ -774,7 +784,7 @@ mod test { runner.quit().await.unwrap(); } - #[tokio::test] + #[actix_rt::test] async fn test_event() { let (sections, interface) = make_sections_and_interface(); let mut runner = SectionRunner::new(interface.clone());