diff --git a/src/section_runner.rs b/src/section_runner.rs index b30c572..557a1e3 100644 --- a/src/section_runner.rs +++ b/src/section_runner.rs @@ -10,7 +10,11 @@ use std::{ time::Duration, }; use thiserror::Error; -use tokio::{spawn, sync::mpsc, time::{delay_for, Instant}}; +use tokio::{ + spawn, + sync::mpsc, + time::{delay_for, Instant}, +}; use tracing::{debug, trace, trace_span}; #[derive(Debug, Clone)] @@ -44,44 +48,45 @@ struct SecRun { } mod option_future { -use pin_project::pin_project; -use std::{pin::Pin, future::Future, task::{Poll, Context}, ops::Deref}; + use pin_project::pin_project; + use std::{ + future::Future, + ops::Deref, + pin::Pin, + task::{Context, Poll}, + }; -#[pin_project] -#[derive(Debug, Clone)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct OptionFuture(#[pin] Option); - -impl Future for OptionFuture { - type Output = Option; - - fn poll( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll { - match self.project().0.as_pin_mut() { - Some(x) => x.poll(cx).map(Some), - None => Poll::Ready(None), + #[pin_project] + #[derive(Debug, Clone)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct OptionFuture(#[pin] Option); + + impl Future for OptionFuture { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().0.as_pin_mut() { + Some(x) => x.poll(cx).map(Some), + None => Poll::Ready(None), + } } } -} -impl Deref for OptionFuture { - type Target = Option; + impl Deref for OptionFuture { + type Target = Option; - fn deref(&self) -> &Self::Target { - &self.0 + fn deref(&self) -> &Self::Target { + &self.0 + } } -} -impl From> for OptionFuture { - fn from(option: Option) -> Self { - OptionFuture(option) + impl From> for OptionFuture { + fn from(option: Option) -> Self { + OptionFuture(option) + } } } -} - use option_future::OptionFuture; async fn runner_task( @@ -214,7 +219,7 @@ mod test { model::Section, trace_listeners::{EventListener, Filters, SpanFilters, SpanListener}, }; - use tokio::time::{advance, pause, resume}; + use tokio::time::{pause, resume}; use tracing_subscriber::prelude::*; #[tokio::test] @@ -245,9 +250,7 @@ mod test { assert_eq!(task_span.get_exit_count(), 1); } - #[tokio::test] - async fn test_queue() { - env_logger::builder().filter_level(log::LevelFilter::Trace).init(); + fn make_sections_and_interface() -> (Vec, Arc) { let interface = Arc::new(MockSectionInterface::new(2)); let sections: Vec = vec![ Arc::new(Section { @@ -261,10 +264,32 @@ mod test { interface_id: 1, }), ]; + (sections, interface) + } + + fn assert_section_states(interface: &MockSectionInterface, states: &[bool]) { + for (id, state) in states.iter().enumerate() { + assert_eq!( + interface.get_section_state(id as u32), + *state, + "section interface id {} did not match", + id + ); + } + } + + async fn advance(dur: Duration) { + // HACK: advance should really be enough, but we need another yield_now + tokio::time::advance(Duration::from_secs(10)).await; + tokio::task::yield_now().await; + } + + #[tokio::test] + async fn test_queue() { + let (sections, interface) = make_sections_and_interface(); let mut runner = SectionRunner::new(interface.clone()); - assert_eq!(interface.get_section_state(0), false); - assert_eq!(interface.get_section_state(1), false); + assert_section_states(&interface, &[false, false]); // Queue single section, make sure it runs runner @@ -277,18 +302,13 @@ mod test { pause(); advance(Duration::from_secs(1)).await; - assert_eq!(interface.get_section_state(0), true); - assert_eq!(interface.get_section_state(1), false); + assert_section_states(&interface, &[true, false]); - // HACK: advance should really be enough, but we need another yield_now advance(Duration::from_secs(10)).await; - tokio::task::yield_now().await; - assert_eq!(interface.get_section_state(0), false); - assert_eq!(interface.get_section_state(1), false); + assert_section_states(&interface, &[false, false]); // Queue two sections, make sure they run one at a time - runner .queue_run(sections[1].clone(), Duration::from_secs(10)) .await @@ -301,20 +321,60 @@ mod test { advance(Duration::from_secs(1)).await; - assert_eq!(interface.get_section_state(0), false); - assert_eq!(interface.get_section_state(1), true); + assert_section_states(&interface, &[false, true]); advance(Duration::from_secs(10)).await; - tokio::task::yield_now().await; - assert_eq!(interface.get_section_state(0), true); - assert_eq!(interface.get_section_state(1), false); + assert_section_states(&interface, &[true, false]); advance(Duration::from_secs(10)).await; + + assert_section_states(&interface, &[false, false]); + + resume(); + + runner.quit().await.unwrap(); + tokio::task::yield_now().await; + } + + #[tokio::test] + async fn test_cancel_run() { + env_logger::builder() + .filter_level(log::LevelFilter::Trace) + .init(); + let (sections, interface) = make_sections_and_interface(); + let mut runner = SectionRunner::new(interface.clone()); + + let run1 = runner + .queue_run(sections[1].clone(), Duration::from_secs(10)) + .await + .unwrap(); + + let run2 = runner + .queue_run(sections[0].clone(), Duration::from_secs(10)) + .await + .unwrap(); + + let run3 = runner + .queue_run(sections[1].clone(), Duration::from_secs(10)) + .await + .unwrap(); + + pause(); + + advance(Duration::from_secs(1)).await; + + assert_section_states(&interface, &[false, true]); + + runner.cancel_run(run1).await.unwrap(); tokio::task::yield_now().await; - assert_eq!(interface.get_section_state(0), false); - assert_eq!(interface.get_section_state(1), false); + assert_section_states(&interface, &[true, false]); + + runner.cancel_run(run3).await.unwrap(); + advance(Duration::from_secs(10)).await; + + assert_section_states(&interface, &[false, false]); resume();