diff --git a/Cargo.toml b/Cargo.toml index 10c6f61..1234682 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,8 @@ env_logger = "0.7.1" color-eyre = "0.5.1" eyre = "0.6.0" thiserror = "1.0.20" -tokio = { version = "0.2.22", features = ["rt-core", "rt-threaded", "time", "sync", "macros"] } +tokio = { version = "0.2.22", features = ["rt-core", "time", "sync", "macros", "test-util"] } tracing = { version = "0.1.19", features = ["log"] } tracing-futures = "0.2.4" tracing-subscriber = { version = "0.2.11", features = ["registry"] } +pin-project = "0.4.23" diff --git a/src/section_runner.rs b/src/section_runner.rs index b2e64a5..b30c572 100644 --- a/src/section_runner.rs +++ b/src/section_runner.rs @@ -2,6 +2,7 @@ use crate::model::SectionRef; use crate::section_interface::SectionInterface; use mpsc::error::SendError; use std::{ + collections::VecDeque, sync::{ atomic::{AtomicI32, Ordering}, Arc, @@ -9,8 +10,8 @@ use std::{ time::Duration, }; use thiserror::Error; -use tokio::{spawn, sync::mpsc}; -use tracing::{trace, trace_span}; +use tokio::{spawn, sync::mpsc, time::{delay_for, Instant}}; +use tracing::{debug, trace, trace_span}; #[derive(Debug, Clone)] pub struct RunHandle(i32); @@ -34,19 +35,111 @@ enum RunnerMsg { QueueRun(RunHandle, SectionRef, Duration), } +#[derive(Debug)] +struct SecRun { + handle: RunHandle, + section: SectionRef, + duration: Duration, + start_time: Option, +} + +mod option_future { +use pin_project::pin_project; +use std::{pin::Pin, future::Future, task::{Poll, Context}, ops::Deref}; + +#[pin_project] +#[derive(Debug, Clone)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct OptionFuture(#[pin] Option); + +impl Future for OptionFuture { + type Output = Option; + + fn poll( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll { + match self.project().0.as_pin_mut() { + Some(x) => x.poll(cx).map(Some), + None => Poll::Ready(None), + } + } +} + +impl Deref for OptionFuture { + type Target = Option; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From> for OptionFuture { + fn from(option: Option) -> Self { + OptionFuture(option) + } +} + +} + +use option_future::OptionFuture; + async fn runner_task( - interface: Box, + interface: Arc, mut msg_recv: mpsc::Receiver, ) { let span = trace_span!("runner_task"); let _enter = span.enter(); - while let Some(msg) = msg_recv.recv().await { - use RunnerMsg::*; - trace!(msg = debug(&msg), "runner_task recv"); - match msg { - Quit => return, - RunnerMsg::QueueRun(_, _, _) => todo!(), + + let mut running = true; + let mut run_queue: VecDeque = VecDeque::new(); + let mut delay_future: OptionFuture<_> = None.into(); + while running { + if let Some(current_run) = run_queue.front_mut() { + let current_sec = ¤t_run.section; + let done = if let Some(start_time) = ¤t_run.start_time { + let elapsed = Instant::now() - *start_time; + elapsed >= current_run.duration + } else { + debug!(section_id = current_sec.id, "starting running section"); + interface.set_section_state(current_sec.interface_id, true); + current_run.start_time = Some(Instant::now()); + delay_future = Some(delay_for(current_run.duration)).into(); + false + }; + + if done { + debug!(section_id = current_sec.id, "finished running section"); + interface.set_section_state(current_sec.interface_id, false); + run_queue.pop_front(); + delay_future = None.into(); + continue; + } } + + let mut handle_msg = |msg: Option| { + let msg = msg.expect("SectionRunner channel closed"); + use RunnerMsg::*; + trace!(msg = debug(&msg), "runner_task recv"); + match msg { + Quit => running = false, + QueueRun(handle, section, duration) => { + run_queue.push_back(SecRun { + handle, + section, + duration, + start_time: None, + }); + } + } + }; + let delay_done = || { + trace!("delay done"); + }; + tokio::select! { + msg = msg_recv.recv() => handle_msg(msg), + _ = &mut delay_future, if delay_future.is_some() => delay_done() + }; } } @@ -69,7 +162,7 @@ pub struct SectionRunner { } impl SectionRunner { - pub fn new(interface: Box) -> Self { + pub fn new(interface: Arc) -> Self { let (msg_send, msg_recv) = mpsc::channel(8); spawn(runner_task(interface, msg_recv)); Self { @@ -117,7 +210,11 @@ impl SectionRunner { mod test { use super::*; use crate::section_interface::MockSectionInterface; - use crate::trace_listeners::{EventListener, Filters, SpanFilters, SpanListener}; + use crate::{ + model::Section, + trace_listeners::{EventListener, Filters, SpanFilters, SpanListener}, + }; + use tokio::time::{advance, pause, resume}; use tracing_subscriber::prelude::*; #[tokio::test] @@ -139,7 +236,7 @@ mod test { let _sub = tracing::subscriber::set_default(subscriber); let interface = MockSectionInterface::new(6); - let mut runner = SectionRunner::new(Box::new(interface)); + let mut runner = SectionRunner::new(Arc::new(interface)); tokio::task::yield_now().await; runner.quit().await.unwrap(); tokio::task::yield_now().await; @@ -147,4 +244,81 @@ mod test { assert_eq!(quit_msg.get_count(), 1); assert_eq!(task_span.get_exit_count(), 1); } + + #[tokio::test] + async fn test_queue() { + env_logger::builder().filter_level(log::LevelFilter::Trace).init(); + let interface = Arc::new(MockSectionInterface::new(2)); + let sections: Vec = vec![ + Arc::new(Section { + id: 1, + name: "Section 1".into(), + interface_id: 0, + }), + Arc::new(Section { + id: 2, + name: "Section 2".into(), + interface_id: 1, + }), + ]; + let mut runner = SectionRunner::new(interface.clone()); + + assert_eq!(interface.get_section_state(0), false); + assert_eq!(interface.get_section_state(1), false); + + // Queue single section, make sure it runs + runner + .queue_run(sections[0].clone(), Duration::from_secs(10)) + .await + .unwrap(); + + tokio::task::yield_now().await; + + pause(); + advance(Duration::from_secs(1)).await; + + assert_eq!(interface.get_section_state(0), true); + assert_eq!(interface.get_section_state(1), false); + + // HACK: advance should really be enough, but we need another yield_now + advance(Duration::from_secs(10)).await; + tokio::task::yield_now().await; + + assert_eq!(interface.get_section_state(0), false); + assert_eq!(interface.get_section_state(1), false); + + // Queue two sections, make sure they run one at a time + + runner + .queue_run(sections[1].clone(), Duration::from_secs(10)) + .await + .unwrap(); + + runner + .queue_run(sections[0].clone(), Duration::from_secs(10)) + .await + .unwrap(); + + advance(Duration::from_secs(1)).await; + + assert_eq!(interface.get_section_state(0), false); + assert_eq!(interface.get_section_state(1), true); + + advance(Duration::from_secs(10)).await; + tokio::task::yield_now().await; + + assert_eq!(interface.get_section_state(0), true); + assert_eq!(interface.get_section_state(1), false); + + advance(Duration::from_secs(10)).await; + tokio::task::yield_now().await; + + assert_eq!(interface.get_section_state(0), false); + assert_eq!(interface.get_section_state(1), false); + + resume(); + + runner.quit().await.unwrap(); + tokio::task::yield_now().await; + } }