diff --git a/Cargo.toml b/Cargo.toml index 7799c93..10c6f61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,7 @@ env_logger = "0.7.1" color-eyre = "0.5.1" eyre = "0.6.0" thiserror = "1.0.20" +tokio = { version = "0.2.22", features = ["rt-core", "rt-threaded", "time", "sync", "macros"] } +tracing = { version = "0.1.19", features = ["log"] } +tracing-futures = "0.2.4" +tracing-subscriber = { version = "0.2.11", features = ["registry"] } diff --git a/src/main.rs b/src/main.rs index 5fb38ac..8ae767c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,15 @@ use color_eyre::eyre::Result; use rusqlite::Connection as DbConnection; use rusqlite::NO_PARAMS; +use tracing::info; mod db; mod migrations; mod model; mod section_interface; +mod section_runner; +#[cfg(test)] +mod trace_listeners; use model::Section; @@ -30,9 +34,9 @@ fn main() -> Result<()> { color_eyre::install()?; let conn = setup_db()?; - let sections = query_sections(&conn); + let sections = query_sections(&conn)?; for sec in sections { - println!("section: {:?}", sec); + info!("section: {:?}", sec); } Ok(()) diff --git a/src/migrations.rs b/src/migrations.rs index 865ed69..4bebe32 100644 --- a/src/migrations.rs +++ b/src/migrations.rs @@ -1,9 +1,9 @@ -use log::debug; use rusqlite::NO_PARAMS; use rusqlite::{params, Connection}; use std::collections::BTreeMap; use std::ops::Bound::{Excluded, Unbounded}; use thiserror::Error; +use tracing::{debug, trace}; #[derive(Debug, Error)] pub enum MigrationError { @@ -122,12 +122,16 @@ impl Migrations { } pub fn add(&mut self, migration: Box) { - assert!(migration.version() != NO_MIGRATIONS, "migration has bad vesion"); + assert!( + migration.version() != NO_MIGRATIONS, + "migration has bad vesion" + ); self.migrations.insert(migration.version(), migration); } pub fn apply(&self, conn: &mut Connection) -> MigrationResult<()> { let db_version = get_db_version(conn)?; + trace!(db_version, "read db_version"); if db_version != 0 && !self.migrations.contains_key(&db_version) { return Err(MigrationError::VersionTooNew(db_version)); } @@ -135,7 +139,7 @@ impl Migrations { let mut trans = conn.transaction()?; let mut last_ver: MigrationVersion = NO_MIGRATIONS; for (ver, mig) in mig_range { - debug!("applying migration version {}", ver); + debug!(version = ver, "applying migration version"); mig.up(&mut trans)?; last_ver = *ver; } diff --git a/src/model/mod.rs b/src/model/mod.rs index cabb658..6880e5d 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -1,3 +1,3 @@ mod section; -pub use section::Section; +pub use section::{Section, SectionRef}; diff --git a/src/model/section.rs b/src/model/section.rs index cbb4141..5f1be98 100644 --- a/src/model/section.rs +++ b/src/model/section.rs @@ -1,5 +1,6 @@ use crate::section_interface::SectionId; -use rusqlite::{Row as SqlRow, Error as SqlError, ToSql}; +use rusqlite::{Error as SqlError, Row as SqlRow, ToSql}; +use std::sync::Arc; #[derive(Debug, Clone)] pub struct Section { @@ -21,3 +22,5 @@ impl Section { vec![&self.id, &self.name, &self.interface_id] } } + +pub type SectionRef = Arc
; diff --git a/src/section_interface.rs b/src/section_interface.rs index 12eb77e..b004594 100644 --- a/src/section_interface.rs +++ b/src/section_interface.rs @@ -1,10 +1,64 @@ +use std::iter::repeat_with; +use std::sync::atomic::{AtomicBool, Ordering}; +use tracing::debug; + pub type SectionId = u32; -pub trait SectionInterface { - fn num_sections() -> SectionId; - fn set_section(id: SectionId, running: bool); - fn get_section(id: SectionId) -> bool; +pub trait SectionInterface: Send { + fn num_sections(&self) -> SectionId; + fn set_section_state(&self, id: SectionId, running: bool); + fn get_section_state(&self, id: SectionId) -> bool; } +pub struct MockSectionInterface { + states: Vec, +} +impl MockSectionInterface { + pub fn new(num_sections: SectionId) -> Self { + Self { + states: repeat_with(|| AtomicBool::new(false)) + .take(num_sections as usize) + .collect(), + } + } +} +impl SectionInterface for MockSectionInterface { + fn num_sections(&self) -> SectionId { + self.states.len() as SectionId + } + fn set_section_state(&self, id: SectionId, running: bool) { + debug!(id, running, "setting section"); + self.states[id as usize].store(running, Ordering::SeqCst); + } + fn get_section_state(&self, id: SectionId) -> bool { + self.states[id as usize].load(Ordering::SeqCst) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_mock_section_interface() { + let iface = MockSectionInterface::new(6); + assert_eq!(iface.num_sections(), 6); + for i in 0..6u32 { + assert_eq!(iface.get_section_state(i), false); + } + for i in 0..6u32 { + iface.set_section_state(i, true); + } + for i in 0..6u32 { + assert_eq!(iface.get_section_state(i), true); + } + for i in 0..6u32 { + iface.set_section_state(i, false); + } + for i in 0..6u32 { + assert_eq!(iface.get_section_state(i), false); + } + } +} diff --git a/src/section_runner.rs b/src/section_runner.rs new file mode 100644 index 0000000..2ad07b4 --- /dev/null +++ b/src/section_runner.rs @@ -0,0 +1,145 @@ +use crate::model::SectionRef; +use crate::section_interface::SectionInterface; +use mpsc::error::SendError; +use std::{ + sync::{ + atomic::{AtomicI32, Ordering}, + Arc, + }, + time::Duration, +}; +use thiserror::Error; +use tokio::{spawn, sync::mpsc}; +use tracing::{trace, trace_span}; + +#[derive(Debug, Clone)] +pub struct RunHandle(i32); + +#[derive(Debug)] +struct SectionRunnerInner { + next_run_id: AtomicI32, +} + +impl SectionRunnerInner { + fn new() -> Self { + Self { + next_run_id: AtomicI32::new(1), + } + } +} + +#[derive(Clone, Debug)] +enum RunnerMsg { + Quit, + QueueRun(RunHandle, SectionRef, Duration), +} + +async fn runner_task( + interface: Box, + mut msg_recv: mpsc::Receiver, +) { + let span = trace_span!("runner_task"); + let _enter = span.enter(); + while let Some(msg) = msg_recv.recv().await { + use RunnerMsg::*; + trace!(msg = debug(&msg), "runner_task recv"); + match msg { + Quit => return, + RunnerMsg::QueueRun(_, _, _) => todo!(), + } + } +} + +#[derive(Debug, Clone, Error)] +#[error("the SectionRunner channel is closed")] +pub struct ChannelClosed; + +pub type Result = std::result::Result; + +impl From> for ChannelClosed { + fn from(_: SendError) -> Self { + Self + } +} + +#[derive(Clone, Debug)] +pub struct SectionRunner { + inner: Arc, + msg_send: mpsc::Sender, +} + +impl SectionRunner { + pub fn new(interface: Box) -> Self { + let (msg_send, msg_recv) = mpsc::channel(8); + spawn(runner_task(interface, msg_recv)); + Self { + inner: Arc::new(SectionRunnerInner::new()), + msg_send, + } + } + + pub async fn quit(&mut self) -> Result<()> { + self.msg_send.send(RunnerMsg::Quit).await?; + Ok(()) + } + + pub async fn queue_run( + &mut self, + section: SectionRef, + duration: Duration, + ) -> Result { + let run_id = self.inner.next_run_id.fetch_add(1, Ordering::Relaxed); + let handle = RunHandle(run_id); + self.msg_send + .send(RunnerMsg::QueueRun(handle.clone(), section, duration)) + .await?; + Ok(handle) + } + + pub async fn cancel_run(&mut self, handle: RunHandle) -> Result<()> { + todo!() + } + + pub async fn cancel_all(&mut self) -> Result<()> { + todo!() + } + + pub async fn pause(&mut self) -> Result<()> { + todo!() + } + + pub async fn unpause(&mut self) -> Result<()> { + todo!() + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::section_interface::MockSectionInterface; + use tracing_subscriber::prelude::*; + use crate::trace_listeners::{EventListener, Filters, SpanFilters, SpanListener}; + + #[tokio::test] + async fn test_quit() { + let quit_msg = EventListener::new( + Filters::new() + .filter_message("runner_task recv") + .filter_field_value("msg", "Quit"), + ); + let task_span = SpanListener::new(SpanFilters::new().filter_name("runner_task")); + let subscriber = tracing_subscriber::registry() + .with(quit_msg.clone()) + .with(task_span.clone()); + let _sub = tracing::subscriber::set_default(subscriber); + + let interface = MockSectionInterface::new(6); + let mut runner = SectionRunner::new(Box::new(interface)); + tokio::task::yield_now().await; + runner.quit().await.unwrap(); + tokio::task::yield_now().await; + + assert_eq!(quit_msg.get_count(), 1); + assert_eq!(task_span.get_exit_count(), 1); + } +} diff --git a/src/trace_listeners.rs b/src/trace_listeners.rs new file mode 100644 index 0000000..b9f3e0a --- /dev/null +++ b/src/trace_listeners.rs @@ -0,0 +1,279 @@ +use std::sync::atomic::{AtomicU32, Ordering}; +use std::{fmt::Debug, sync::Arc}; +use tracing::{ + field::{Field, Visit}, + Subscriber, +}; +use tracing_subscriber::{ + layer::{Context, Layer}, + registry::{LookupSpan, SpanRef}, +}; + +#[derive(Clone, Debug)] +pub struct Filters { + filter_message: Option, + filter_field: Option, + filter_field_value: Option, +} + +impl Filters { + pub fn new() -> Self { + Self { + filter_message: None, + filter_field: None, + filter_field_value: None, + } + } + + pub fn filter_message(mut self, message: impl ToString) -> Self { + self.filter_message = Some(message.to_string()); + self + } + + pub fn filter_field(mut self, field: impl ToString) -> Self { + self.filter_field = Some(field.to_string()); + self + } + + pub fn filter_field_value(mut self, field: impl ToString, value: impl ToString) -> Self { + self.filter_field = Some(field.to_string()); + self.filter_field_value = Some(value.to_string()); + self + } +} + +struct TraceListenerVisitor<'a> { + filters: &'a Filters, + right_message: bool, + right_field: bool, +} + +impl<'a> TraceListenerVisitor<'a> { + fn new(filters: &'a Filters) -> Self { + Self { + filters, + right_message: false, + right_field: false, + } + } + + fn did_match(&self) -> bool { + (self.filters.filter_message.is_none() || self.right_message) + && (self.filters.filter_field.is_none() || self.right_field) + } +} +impl<'a> Visit for TraceListenerVisitor<'a> { + fn record_debug(&mut self, field: &Field, value: &dyn Debug) { + use std::fmt::Write; + let mut value_str = String::new(); + write!(value_str, "{:?}", value).unwrap(); + if let Some(message) = &self.filters.filter_message { + if field.name() == "message" && &value_str == message { + self.right_message = true; + } + } + if let Some(filter_field) = &self.filters.filter_field { + if field.name() == filter_field { + if let Some(filter_field_value) = &self.filters.filter_field_value { + self.right_field = &value_str == filter_field_value; + } else { + self.right_field = true; + } + } + } + } +} + +#[derive(Clone, Debug)] +pub struct EventListener { + count: Arc, + filters: Filters, +} + +impl EventListener { + pub fn new(filters: Filters) -> Self { + Self { + count: Arc::new(AtomicU32::new(0)), + filters, + } + } + + pub fn get_count(&self) -> u32 { + self.count.load(Ordering::SeqCst) + } +} +impl Layer for EventListener { + fn on_event(&self, ev: &tracing::Event, _: Context) { + let mut visit = TraceListenerVisitor::new(&self.filters); + ev.record(&mut visit); + if visit.did_match() { + self.count.fetch_add(1, Ordering::SeqCst); + } + } +} + +#[derive(Clone, Debug)] +pub struct SpanFilters { + filters: Filters, + filter_name: Option, +} + +impl SpanFilters { + pub fn new() -> Self { + Self { + filters: Filters::new(), + filter_name: None, + } + } + + pub fn filter_name(mut self, name: impl ToString) -> Self { + self.filter_name = Some(name.to_string()); + self + } +} + +impl SpanFilters { + fn span_matches(&self, span: &SpanRef) -> bool + where + S: Subscriber + for<'lookup> LookupSpan<'lookup>, + { + if let Some(name) = &self.filter_name { + if span.name() != name { + return false; + } + } + true + } +} + +#[derive(Debug)] +struct SpanListenerData { + enter_count: AtomicU32, + exit_count: AtomicU32, +} +impl SpanListenerData { + fn new() -> Self { + Self { + enter_count: AtomicU32::new(0), + exit_count: AtomicU32::new(0), + } + } +} +#[derive(Clone, Debug)] +pub struct SpanListener { + data: Arc, + filters: SpanFilters, +} + +impl SpanListener { + pub fn new(filters: SpanFilters) -> Self { + Self { + data: Arc::new(SpanListenerData::new()), + filters, + } + } + + pub fn get_enter_count(&self) -> u32 { + self.data.enter_count.load(Ordering::SeqCst) + } + + pub fn get_exit_count(&self) -> u32 { + self.data.enter_count.load(Ordering::SeqCst) + } +} + +impl LookupSpan<'lookup>> Layer for SpanListener { + fn on_enter(&self, id: &tracing::span::Id, ctx: Context) { + let span = match ctx.span(id) { + Some(span) => span, + None => return, + }; + if !self.filters.span_matches(&span) { + return; + } + self.data.enter_count.fetch_add(1, Ordering::SeqCst); + } + fn on_exit(&self, id: &tracing::span::Id, ctx: Context) { + let span = match ctx.span(id) { + Some(span) => span, + None => return, + }; + if !self.filters.span_matches(&span) { + return; + } + self.data.exit_count.fetch_add(1, Ordering::SeqCst); + } +} + +mod test { + use super::*; + use tracing_subscriber::prelude::*; + use tracing::info; + + #[test] + fn test_event_listener() { + let all_listener = EventListener::new(Filters::new()); + let msg_listener = EventListener::new(Filters::new().filter_message("filter message")); + let field_listener = EventListener::new(Filters::new().filter_field("field")); + let field_value_listener = + EventListener::new(Filters::new().filter_field_value("field", 1234)); + let msg_field_value_listener = EventListener::new( + Filters::new() + .filter_message("filter message") + .filter_field_value("field", 1234), + ); + let subscriber = tracing_subscriber::registry() + .with(all_listener.clone()) + .with(msg_listener.clone()) + .with(field_listener.clone()) + .with(field_value_listener.clone()) + .with(msg_field_value_listener.clone()); + let _sub = tracing::subscriber::set_default(subscriber); + + assert_eq!(all_listener.get_count(), 0); + assert_eq!(msg_listener.get_count(), 0); + assert_eq!(field_listener.get_count(), 0); + assert_eq!(field_value_listener.get_count(), 0); + assert_eq!(msg_field_value_listener.get_count(), 0); + + info!("not filter message"); + + assert_eq!(all_listener.get_count(), 1); + assert_eq!(msg_listener.get_count(), 0); + assert_eq!(field_listener.get_count(), 0); + assert_eq!(field_value_listener.get_count(), 0); + assert_eq!(msg_field_value_listener.get_count(), 0); + + info!("filter message"); + + assert_eq!(all_listener.get_count(), 2); + assert_eq!(msg_listener.get_count(), 1); + assert_eq!(field_listener.get_count(), 0); + assert_eq!(field_value_listener.get_count(), 0); + assert_eq!(msg_field_value_listener.get_count(), 0); + + info!(field = 1, "not filter message, field not value"); + + assert_eq!(all_listener.get_count(), 3); + assert_eq!(msg_listener.get_count(), 1); + assert_eq!(field_listener.get_count(), 1); + assert_eq!(field_value_listener.get_count(), 0); + assert_eq!(msg_field_value_listener.get_count(), 0); + + info!(field = 1234, "not filter message, field and value"); + + assert_eq!(all_listener.get_count(), 4); + assert_eq!(msg_listener.get_count(), 1); + assert_eq!(field_listener.get_count(), 2); + assert_eq!(field_value_listener.get_count(), 1); + assert_eq!(msg_field_value_listener.get_count(), 0); + + info!(field = 1234, "filter message"); + + assert_eq!(all_listener.get_count(), 5); + assert_eq!(msg_listener.get_count(), 2); + assert_eq!(field_listener.get_count(), 3); + assert_eq!(field_value_listener.get_count(), 2); + assert_eq!(msg_field_value_listener.get_count(), 1); + } +}