diff --git a/src/program_runner.rs b/src/program_runner.rs index 1ea4e77..4c08e5c 100644 --- a/src/program_runner.rs +++ b/src/program_runner.rs @@ -23,6 +23,7 @@ enum RunnerMsg { pub enum ProgramEvent { RunStart(ProgramRef), RunFinish(ProgramRef), + RunCancel(ProgramRef), } pub type ProgramEventRecv = broadcast::Receiver; @@ -35,6 +36,7 @@ enum RunState { Waiting, Running, Finished, + Cancelled, } #[derive(Debug)] @@ -141,6 +143,18 @@ impl RunnerTask { } } + async fn cancel_program_run(&mut self, run: &mut ProgRun) { + for handle in run.sec_run_handles.drain(..) { + if let Err(_closed) = self.section_runner.cancel_run(handle).await { + error!("section runner channel closed"); + self.running = false; + return; + } + } + debug!(program_id = run.program.id, "program run is cancelled"); + self.send_event(ProgramEvent::RunCancel(run.program.clone())); + } + async fn process_queue(&mut self, run_queue: &mut RunQueue) { while let Some(current_run) = run_queue.front_mut() { let run_finished = match current_run.state { @@ -150,6 +164,10 @@ impl RunnerTask { } RunState::Running => false, RunState::Finished => true, + RunState::Cancelled => { + self.cancel_program_run(current_run).await; + true + } }; if run_finished { run_queue.pop_front(); @@ -189,7 +207,13 @@ impl RunnerTask { RunProgram(program) => { run_queue.push_back(ProgRun::new(program)); } - RunnerMsg::CancelProgram(_) => todo!(), + RunnerMsg::CancelProgram(program_id) => { + for run in run_queue { + if run.program.id == program_id { + run.state = RunState::Cancelled; + } + } + } } } @@ -352,8 +376,8 @@ mod test { }; use im::ordmap; use std::{sync::Arc, time::Duration}; - use tracing_subscriber::prelude::*; use tokio::task::yield_now; + use tracing_subscriber::prelude::*; #[tokio::test] async fn test_quit() { @@ -475,23 +499,19 @@ mod test { let program1: ProgramRef = Program { id: 1, name: "Program 1".into(), - sequence: vec![ - ProgramItem { - section_id: 3, - duration: Duration::from_secs(10), - }, - ], + sequence: vec![ProgramItem { + section_id: 3, + duration: Duration::from_secs(10), + }], } .into(); let program2: ProgramRef = Program { id: 2, name: "Program 2".into(), - sequence: vec![ - ProgramItem { - section_id: 1, - duration: Duration::from_secs(10), - }, - ], + sequence: vec![ProgramItem { + section_id: 1, + duration: Duration::from_secs(10), + }], } .into(); @@ -570,23 +590,19 @@ mod test { let program1: ProgramRef = Program { id: 1, name: "Program 1".into(), - sequence: vec![ - ProgramItem { - section_id: 2, - duration: Duration::from_secs(10), - }, - ], + sequence: vec![ProgramItem { + section_id: 2, + duration: Duration::from_secs(10), + }], } .into(); let program2: ProgramRef = Program { id: 2, name: "Program 2".into(), - sequence: vec![ - ProgramItem { - section_id: 2, - duration: Duration::from_secs(10), - }, - ], + sequence: vec![ProgramItem { + section_id: 2, + duration: Duration::from_secs(10), + }], } .into(); let programs = ordmap![ 1 => program1, 2 => program2 ]; @@ -634,4 +650,58 @@ mod test { sec_runner.quit().await.unwrap(); yield_now().await; } + + #[tokio::test] + async fn test_cancel_program() { + let (sections, mut sec_runner, _) = make_sections_and_runner(); + let mut sec_events = sec_runner.subscribe().await.unwrap(); + let mut runner = ProgramRunner::new(sec_runner.clone()); + let mut prog_events = runner.subscribe().await.unwrap(); + + let program: ProgramRef = Program { + id: 1, + name: "Program 1".into(), + sequence: vec![ + ProgramItem { + section_id: 1, + duration: Duration::from_secs(10), + }, + ProgramItem { + section_id: 2, + duration: Duration::from_secs(10), + }, + ], + } + .into(); + + runner.update_sections(sections.clone()).await.unwrap(); + + runner.run_program(program.clone()).await.unwrap(); + yield_now().await; + assert!(matches!( + prog_events.try_recv().unwrap(), + ProgramEvent::RunStart(prog) + if prog.id == 1 + )); + assert!(matches!( + sec_events.try_recv().unwrap(), + SectionEvent::RunStart(_) + )); + + runner.cancel_program(program.id).await.unwrap(); + yield_now().await; + assert!(matches!( + prog_events.recv().await.unwrap(), + ProgramEvent::RunCancel(prog) + if prog.id == 1 + )); + assert!(matches!( + sec_events.recv().await.unwrap(), + SectionEvent::RunCancel(_) + )); + + runner.quit().await.unwrap(); + sec_runner.quit().await.unwrap(); + yield_now().await; + } }