diff --git a/src/main.rs b/src/main.rs index 793d90b..0b671df 100644 --- a/src/main.rs +++ b/src/main.rs @@ -63,6 +63,11 @@ async fn main() -> Result<()> { }; let mut mqtt_interface = mqtt_interface::MqttInterfaceTask::start(mqtt_options).await?; + update_listener::UpdateListener::start( + section_runner.subscribe().await?, + mqtt_interface.clone(), + ); + program_runner.update_sections(sections.clone()).await?; mqtt_interface.publish_sections(§ions).await?; for section_id in sections.keys() { diff --git a/src/program_runner.rs b/src/program_runner.rs index ee11fe5..49143df 100644 --- a/src/program_runner.rs +++ b/src/program_runner.rs @@ -274,7 +274,7 @@ impl RunnerTask { let sec_event = sec_event.wrap_err("failed to receive section event")?; #[allow(clippy::single_match)] match sec_event { - SectionEvent::RunFinish(finished_run) => { + SectionEvent::RunFinish(finished_run, _) => { self.handle_finished_run(finished_run, run_queue); } _ => {} @@ -525,15 +525,15 @@ mod test { ProgramEvent::RunStart(prog) if prog.id == 1 ); - assert_matches!(sec_events.try_recv().unwrap(), SectionEvent::RunStart(_)); + assert_matches!(sec_events.try_recv().unwrap(), SectionEvent::RunStart(_, _)); assert_eq!(interface.get_section_state(0), true); tokio::time::pause(); - assert_matches!(sec_events.recv().await.unwrap(), SectionEvent::RunFinish(_)); - assert_matches!(sec_events.recv().await.unwrap(), SectionEvent::RunStart(_)); + assert_matches!(sec_events.recv().await.unwrap(), SectionEvent::RunFinish(_, _)); + assert_matches!(sec_events.recv().await.unwrap(), SectionEvent::RunStart(_, _)); assert_eq!(interface.get_section_state(0), false); assert_eq!(interface.get_section_state(1), true); - assert_matches!(sec_events.recv().await.unwrap(), SectionEvent::RunFinish(_)); + assert_matches!(sec_events.recv().await.unwrap(), SectionEvent::RunFinish(_, _)); assert_matches!( prog_events.recv().await.unwrap(), ProgramEvent::RunFinish(_) @@ -715,7 +715,7 @@ mod test { ProgramEvent::RunStart(prog) if prog.id == 1 ); - assert_matches!(sec_events.try_recv().unwrap(), SectionEvent::RunStart(_)); + assert_matches!(sec_events.try_recv().unwrap(), SectionEvent::RunStart(_, _)); runner.cancel_program(program.id).await.unwrap(); yield_now().await; @@ -724,7 +724,7 @@ mod test { ProgramEvent::RunCancel(prog) if prog.id == 1 ); - assert_matches!(sec_events.recv().await.unwrap(), SectionEvent::RunCancel(_)); + assert_matches!(sec_events.recv().await.unwrap(), SectionEvent::RunCancel(_, _)); runner.quit().await.unwrap(); sec_runner.quit().await.unwrap(); diff --git a/src/section_runner.rs b/src/section_runner.rs index 62c0737..0daf90b 100644 --- a/src/section_runner.rs +++ b/src/section_runner.rs @@ -17,7 +17,7 @@ use tokio::{ }; use tracing::{debug, trace, trace_span, warn}; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SectionRunHandle(i32); #[derive(Debug)] @@ -44,13 +44,13 @@ enum RunnerMsg { Subscribe(oneshot::Sender), } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub enum SectionEvent { - RunStart(SectionRunHandle), - RunFinish(SectionRunHandle), - RunPause(SectionRunHandle), - RunUnpause(SectionRunHandle), - RunCancel(SectionRunHandle), + RunStart(SectionRunHandle, SectionRef), + RunFinish(SectionRunHandle, SectionRef), + RunPause(SectionRunHandle, SectionRef), + RunUnpause(SectionRunHandle, SectionRef), + RunCancel(SectionRunHandle, SectionRef), RunnerPause, RunnerUnpause, } @@ -176,7 +176,10 @@ impl RunnerTask { run.state = Running { start_time: Instant::now(), }; - self.send_event(SectionEvent::RunStart(run.handle.clone())); + self.send_event(SectionEvent::RunStart( + run.handle.clone(), + run.section.clone(), + )); } fn finish_run(&mut self, run: &mut Arc) { @@ -186,7 +189,10 @@ impl RunnerTask { self.interface .set_section_state(run.section.interface_id, false); run.state = SecRunState::Finished; - self.send_event(SectionEvent::RunFinish(run.handle.clone())); + self.send_event(SectionEvent::RunFinish( + run.handle.clone(), + run.section.clone(), + )); } else { warn!( section_id = run.section.id, @@ -204,7 +210,10 @@ impl RunnerTask { .set_section_state(run.section.interface_id, false); } run.state = SecRunState::Cancelled; - self.send_event(SectionEvent::RunCancel(run.handle.clone())); + self.send_event(SectionEvent::RunCancel( + run.handle.clone(), + run.section.clone(), + )); } fn pause_run(&mut self, run: &mut Arc) { @@ -232,7 +241,10 @@ impl RunnerTask { } }; run.state = new_state; - self.send_event(SectionEvent::RunPause(run.handle.clone())); + self.send_event(SectionEvent::RunPause( + run.handle.clone(), + run.section.clone(), + )); } fn unpause_run(&mut self, run: &mut Arc) { @@ -251,7 +263,10 @@ impl RunnerTask { }; let ran_for = pause_time - start_time; run.duration -= ran_for; - self.send_event(SectionEvent::RunUnpause(run.handle.clone())); + self.send_event(SectionEvent::RunUnpause( + run.handle.clone(), + run.section.clone(), + )); } Waiting | Finished | Cancelled | Running { .. } => { warn!( @@ -467,6 +482,7 @@ mod test { model::{Section, Sections}, trace_listeners::{EventListener, Filters, SpanFilters, SpanListener}, }; + use assert_matches::assert_matches; use im::ordmap; use tracing_subscriber::prelude::*; @@ -732,59 +748,39 @@ mod test { .await .unwrap(); - assert_eq!( + assert_matches!( event_recv.recv().await, - Ok(SectionEvent::RunStart(run1.clone())) + Ok(SectionEvent::RunStart(handle, _)) + if handle == run1 ); runner.pause().await.unwrap(); - assert_eq!(event_recv.recv().await, Ok(SectionEvent::RunnerPause)); - assert_eq!( - event_recv.recv().await, - Ok(SectionEvent::RunPause(run1.clone())) - ); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunnerPause)); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunPause(handle, _)) if handle == run1); runner.unpause().await.unwrap(); - assert_eq!(event_recv.recv().await, Ok(SectionEvent::RunnerUnpause)); - assert_eq!( - event_recv.recv().await, - Ok(SectionEvent::RunUnpause(run1.clone())) - ); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunnerUnpause)); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunUnpause(handle, _)) if handle == run1); advance(Duration::from_secs(11)).await; - assert_eq!(event_recv.recv().await, Ok(SectionEvent::RunFinish(run1))); - assert_eq!( - event_recv.recv().await, - Ok(SectionEvent::RunStart(run2.clone())) - ); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunFinish(handle, _)) if handle == run1); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunStart(handle, _)) if handle == run2); runner.pause().await.unwrap(); - assert_eq!(event_recv.recv().await, Ok(SectionEvent::RunnerPause)); - assert_eq!( - event_recv.recv().await, - Ok(SectionEvent::RunPause(run2.clone())) - ); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunnerPause)); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunPause(handle, _)) if handle == run2); // cancel paused run runner.cancel_run(run2.clone()).await.unwrap(); - assert_eq!( - event_recv.recv().await, - Ok(SectionEvent::RunCancel(run2.clone())) - ); - assert_eq!( - event_recv.recv().await, - Ok(SectionEvent::RunPause(run3.clone())) - ); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunCancel(handle, _)) if handle == run2); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunPause(handle, _)) if handle == run3); runner.unpause().await.unwrap(); - assert_eq!(event_recv.recv().await, Ok(SectionEvent::RunnerUnpause)); - assert_eq!( - event_recv.recv().await, - Ok(SectionEvent::RunUnpause(run3.clone())) - ); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunnerUnpause)); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunUnpause(handle, _)) if handle == run3); advance(Duration::from_secs(11)).await; - assert_eq!(event_recv.recv().await, Ok(SectionEvent::RunFinish(run3))); + assert_matches!(event_recv.recv().await, Ok(SectionEvent::RunFinish(handle, _)) if handle == run3); runner.quit().await.unwrap(); } diff --git a/src/update_listener.rs b/src/update_listener.rs new file mode 100644 index 0000000..8ddb434 --- /dev/null +++ b/src/update_listener.rs @@ -0,0 +1,63 @@ +use crate::{ + mqtt_interface::MqttInterface, + section_runner::{SectionEvent, SectionEventRecv}, +}; +use tokio::{select, sync::broadcast, task::JoinHandle}; +use tracing::trace; + +pub struct UpdateListener { + mqtt_interface: MqttInterface, + running: bool, +} + +impl UpdateListener { + pub fn start( + section_events: SectionEventRecv, + mqtt_interface: MqttInterface, + ) -> JoinHandle<()> { + let update_listener = UpdateListener { + mqtt_interface, + running: true, + }; + tokio::spawn(update_listener.run(section_events)) + } + + async fn handle_section_event( + &mut self, + event: Result, + ) -> eyre::Result<()> { + let event = match event { + Err(broadcast::RecvError::Closed) => { + trace!("section events channel closed"); + self.running = false; + return Ok(()); + } + e => e, + }?; + if let Some((sec_id, state)) = match event { + SectionEvent::RunStart(_, sec) | SectionEvent::RunUnpause(_, sec) => { + Some((sec.id, true)) + } + SectionEvent::RunFinish(_, sec) + | SectionEvent::RunPause(_, sec) + | SectionEvent::RunCancel(_, sec) => Some((sec.id, false)), + SectionEvent::RunnerPause | SectionEvent::RunnerUnpause => None, + } { + self.mqtt_interface + .publish_section_state(sec_id, state) + .await?; + } + Ok(()) + } + + pub async fn run(mut self, mut section_events: SectionEventRecv) { + while self.running { + let result = select! { + section_event = section_events.recv() => { + self.handle_section_event(section_event).await + } + }; + result.expect("error in update_listener task"); + } + } +}