diff --git a/src/trace_listeners.rs b/src/trace_listeners.rs index 71ede93..2455455 100644 --- a/src/trace_listeners.rs +++ b/src/trace_listeners.rs @@ -9,13 +9,18 @@ use tracing_subscriber::{ registry::{LookupSpan, SpanRef}, }; +#[derive(Clone, Debug)] +struct FieldFilter { + name: String, + value: Option, +} + #[derive(Clone, Debug)] pub struct Filters { filter_target: Option, filter_name: Option, filter_message: Option, - filter_field: Option, - filter_field_value: Option, + filter_fields: Vec, } impl Filters { @@ -24,8 +29,7 @@ impl Filters { filter_target: None, filter_name: None, filter_message: None, - filter_field: None, - filter_field_value: None, + filter_fields: Vec::new(), } } @@ -45,13 +49,18 @@ impl Filters { } pub fn field(mut self, field: impl ToString) -> Self { - self.filter_field = Some(field.to_string()); + self.filter_fields.push(FieldFilter { + name: field.to_string(), + value: None, + }); self } pub fn field_value(mut self, field: impl ToString, value: impl ToString) -> Self { - self.filter_field = Some(field.to_string()); - self.filter_field_value = Some(value.to_string()); + self.filter_fields.push(FieldFilter { + name: field.to_string(), + value: Some(value.to_string()), + }); self } @@ -73,7 +82,7 @@ impl Filters { struct TraceListenerVisitor<'a> { filters: &'a Filters, right_message: bool, - right_field: bool, + right_fields: Vec, } impl<'a> TraceListenerVisitor<'a> { @@ -81,13 +90,13 @@ impl<'a> TraceListenerVisitor<'a> { Self { filters, right_message: false, - right_field: false, + right_fields: vec![false; filters.filter_fields.len()], } } fn did_match(&self) -> bool { (self.filters.filter_message.is_none() || self.right_message) - && (self.filters.filter_field.is_none() || self.right_field) + && self.right_fields.iter().all(|right| *right) } } impl<'a> Visit for TraceListenerVisitor<'a> { @@ -100,12 +109,17 @@ impl<'a> Visit for TraceListenerVisitor<'a> { self.right_message = true; } } - if let Some(filter_field) = &self.filters.filter_field { - if field.name() == filter_field { - if let Some(filter_field_value) = &self.filters.filter_field_value { - self.right_field = &value_str == filter_field_value; + for (filter_field, right) in self + .filters + .filter_fields + .iter() + .zip(self.right_fields.iter_mut()) + { + if field.name() == filter_field.name { + if let Some(filter_field_value) = &filter_field.value { + *right = &value_str == filter_field_value; } else { - self.right_field = true; + *right = true; } } } @@ -254,8 +268,8 @@ impl LookupSpan<'lookup>> Layer for SpanListener mod test { use super::*; - use tracing_subscriber::prelude::*; use tracing::info; + use tracing_subscriber::prelude::*; #[test] fn test_event_listener() { @@ -263,8 +277,7 @@ mod test { let all_listener = EventListener::new(Filters::new().target(std::module_path!())); let msg_listener = EventListener::new(Filters::new().message("filter message")); let field_listener = EventListener::new(Filters::new().field("field")); - let field_value_listener = - EventListener::new(Filters::new().field_value("field", 1234)); + let field_value_listener = EventListener::new(Filters::new().field_value("field", 1234)); let msg_field_value_listener = EventListener::new( Filters::new() .message("filter message") @@ -326,4 +339,42 @@ mod test { assert_eq!(field_value_listener.get_count(), 2); assert_eq!(msg_field_value_listener.get_count(), 1); } + + #[test] + fn test_filter_fields() { + let no_fields = EventListener::new(Filters::new()); + let one_field = EventListener::new(Filters::new().field("field_one")); + let two_fields = EventListener::new( + Filters::new() + .field("field_one") + .field_value("field_two", "123"), + ); + let subscriber = tracing_subscriber::registry() + .with(no_fields.clone()) + .with(one_field.clone()) + .with(two_fields.clone()); + let _sub = tracing::subscriber::set_default(subscriber); + + assert_eq!(no_fields.get_count(), 0); + assert_eq!(one_field.get_count(), 0); + assert_eq!(two_fields.get_count(), 0); + + info!("no fields message"); + + assert_eq!(no_fields.get_count(), 1); + assert_eq!(one_field.get_count(), 0); + assert_eq!(two_fields.get_count(), 0); + + info!(field_one = 123, "one field message"); + + assert_eq!(no_fields.get_count(), 2); + assert_eq!(one_field.get_count(), 1); + assert_eq!(two_fields.get_count(), 0); + + info!(field_one = 123, field_two = 123, "two field message"); + + assert_eq!(no_fields.get_count(), 3); + assert_eq!(one_field.get_count(), 2); + assert_eq!(two_fields.get_count(), 1); + } }