diff --git a/src/bag.rs b/src/bag.rs index 2cafb7e..6132351 100644 --- a/src/bag.rs +++ b/src/bag.rs @@ -2,8 +2,16 @@ use std::{fs::File, path::Path}; use eyre::Context; use log::debug; +use rayon::iter::ParallelIterator; -use crate::{index::BagIndex, info::BagInfo, message::compute_layout, reader::MmapReader, Result}; +use crate::{ + chunk::{read_chunks_messages, MessageData}, + index::BagIndex, + info::BagInfo, + message::compute_layout, + reader::MmapReader, + Result, +}; pub struct Bag { reader: MmapReader, @@ -29,6 +37,12 @@ impl Bag { } pub fn compute_info(&mut self) -> Result { - BagInfo::compute(&mut self.reader, &self.index) + let reader = self.reader.clone(); + BagInfo::compute(|| reader.clone(), &self.index) + } + + pub fn read_messages(&mut self) -> impl ParallelIterator> + '_ { + let reader = self.reader.clone(); + read_chunks_messages(move || reader.clone(), &self.index.chunks) } } diff --git a/src/chunk.rs b/src/chunk.rs index e5c2172..170bd7a 100644 --- a/src/chunk.rs +++ b/src/chunk.rs @@ -1,13 +1,19 @@ -use std::{io, str::FromStr}; +use std::{ + io::{self, Read, SeekFrom}, + mem, + str::FromStr, +}; -use eyre::bail; +use bytes::Bytes; +use eyre::{bail, eyre, Context}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use crate::{ error, + index::ChunkInfo, parse::{Header, Op}, - reader::BagReader, - Error, Result, - Time + reader::{BagReader, BytesReader}, + Error, Result, Time, }; #[derive(Clone, Copy, Debug)] @@ -31,12 +37,18 @@ impl FromStr for Compression { } impl Compression { - pub fn decompress_stream<'a, R: io::Read + 'a>(self, read: R) -> Box { + pub fn decompress(self, mut read: R, decompressed_size: usize) -> Result> { + let mut decompressed = Vec::with_capacity(decompressed_size); match self { - Compression::None => Box::new(read), + Compression::None => { + read.read_to_end(&mut decompressed)?; + } Compression::Bz2 => todo!("bz2 decompression"), - Compression::Lz4 => Box::new(lz4_flex::frame::FrameDecoder::new(read)), + Compression::Lz4 => { + lz4_flex::frame::FrameDecoder::new(read).read_to_end(&mut decompressed)?; + } } + Ok(decompressed) } } @@ -60,19 +72,41 @@ impl ChunkHeader { } } -pub fn read_chunk(bag_reader: &mut R, pos: u64) -> Result> { - bag_reader.seek(io::SeekFrom::Start(pos))?; - +pub fn read_chunk_data(bag_reader: &mut R) -> Result> { let chunk_header = ChunkHeader::read(bag_reader)?; let compressed_data = bag_reader.read_data()?; - let mut data = Vec::with_capacity(chunk_header.uncompressed_size as usize); - let mut decompresor = chunk_header.compression.decompress_stream(compressed_data); - io::copy(&mut decompresor, &mut data)?; + let data = chunk_header + .compression + .decompress(compressed_data, chunk_header.uncompressed_size as usize)?; Ok(data) } +pub fn read_chunk_data_at( + bag_reader: &mut R, + pos: u64, +) -> Result> { + bag_reader.seek(SeekFrom::Start(pos as u64))?; + read_chunk_data(bag_reader).wrap_err_with(|| eyre!("failed to read chunk at offset {}", pos)) +} + +pub fn read_chunks_data<'a, R, F, C>( + make_reader: F, + chunks: C, +) -> impl ParallelIterator>> + 'a +where + R: BagReader + io::Seek, + F: Fn() -> R + Send + Sync + 'a, + C: IntoParallelIterator + 'a, +{ + chunks.into_par_iter().map(move |chunk| { + let mut reader = make_reader(); + read_chunk_data_at(&mut reader, chunk.pos) + }) +} + +#[derive(Debug)] pub struct MessageDataHeader { pub conn_id: u32, pub time: Time, @@ -91,3 +125,66 @@ impl MessageDataHeader { }) } } + +#[derive(Debug)] +pub struct MessageData { + pub header: MessageDataHeader, + pub data: Bytes, +} + +pub fn read_chunks_messages<'a, R, F, C>( + make_reader: F, + chunks: C, +) -> impl ParallelIterator> + 'a +where + R: BagReader + io::Seek + 'a, + F: Fn() -> R + Send + Sync + 'a, + C: IntoParallelIterator + 'a, +{ + read_chunks_data(make_reader, chunks).flat_map_iter(move |data| ChunkMessageIterator { + reader: data.map(|data| BytesReader::from(Bytes::from(data))), + }) +} + +pub struct ChunkMessageIterator { + reader: Result, +} + +impl ChunkMessageIterator { + fn next_impl(&mut self) -> Result> { + let reader = match &mut self.reader { + Ok(reader) => reader, + Err(err) => { + // workaround for eyre::Report not being clone + let mut new_err = eyre!("original error already consumed"); + mem::swap(err, &mut new_err); + return Err(new_err); + } + }; + while reader.remaining() > 0 { + let header = reader.read_header()?; + let op = header.read_op()?; + + match op { + Op::MsgData => { + let header = MessageDataHeader::from_header(header)?; + let data = reader.read_data_bytes()?; + return Ok(Some(MessageData { header, data })); + } + Op::Connection => { + reader.skip_data()?; + } + _ => bail!("unexpected op in chunk: {:?}", op), + } + } + Ok(None) + } +} + +impl Iterator for ChunkMessageIterator { + type Item = Result; + + fn next(&mut self) -> Option { + self.next_impl().transpose() + } +} diff --git a/src/index.rs b/src/index.rs index 1af4d88..cae0979 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,4 +1,4 @@ -use std::io::SeekFrom; +use std::io; use eyre::bail; use log::trace; @@ -95,7 +95,7 @@ pub struct BagIndex { } impl BagIndex { - fn read_v2(reader: &mut R) -> Result { + fn read_v2(reader: &mut R) -> Result { let file_header = reader.read_header_op(Op::FileHeader)?; let data_length = reader.read_data_length()?; @@ -112,7 +112,7 @@ impl BagIndex { bail!(error::UnsupportedEncryptor(encryptor)); } - reader.seek(SeekFrom::Start(index_pos))?; + reader.seek(io::SeekFrom::Start(index_pos))?; let conn_count = file_header.read_u32(b"conn_count")?; trace!("connection count: {}", conn_count); @@ -141,7 +141,7 @@ impl BagIndex { }) } - pub fn read(reader: &mut R) -> Result { + pub fn read(reader: &mut R) -> Result { let version = reader.read_version()?; trace!("bag version: {}", version); if (version.major, version.minor) == (2, 0) { diff --git a/src/info.rs b/src/info.rs index 97f95aa..01f01ec 100644 --- a/src/info.rs +++ b/src/info.rs @@ -1,13 +1,11 @@ use std::{collections::HashMap, io}; -use eyre::{bail, Context}; use rayon::prelude::*; use crate::{ - chunk::{read_chunk, ChunkHeader, MessageDataHeader}, + chunk::{read_chunks_messages, ChunkHeader}, index::{BagIndex, IndexData}, - parse::Op, - reader::{BagReader, SliceReader}, + reader::BagReader, Result, }; @@ -26,15 +24,16 @@ impl BagInfo { self } - pub fn compute( - reader: &mut R, - index: &BagIndex, - ) -> Result { + pub fn compute(make_reader: F, index: &BagIndex) -> Result + where + R: BagReader + io::Seek, + F: Fn() -> R + Send + Sync, + { index .chunks .par_iter() - .try_fold(BagInfo::default, |mut info, chunk| -> Result<_> { - let mut reader = reader.clone(); + .try_fold(BagInfo::default, move |mut info, chunk| -> Result<_> { + let mut reader = make_reader(); reader.seek(io::SeekFrom::Start(chunk.pos))?; let chunk_header = ChunkHeader::read(&mut reader)?; info.total_uncompressed += chunk_header.uncompressed_size as u64; @@ -49,33 +48,17 @@ impl BagInfo { .try_reduce(BagInfo::default, |a, b| Ok(a.combine(b))) } - pub fn compute_without_index( - reader: &mut R, - index: &BagIndex, - ) -> Result { - index - .chunks - .par_iter() + pub fn compute_without_index(make_reader: F, index: &BagIndex) -> Result + where + R: BagReader + io::Seek, + F: Fn() -> R + Send + Sync, + { + read_chunks_messages(make_reader, &index.chunks) .try_fold(BagInfo::default, |mut info, chunk| -> Result<_> { - let mut reader = reader.clone(); - let data = read_chunk(&mut reader, chunk.pos) - .wrap_err_with(|| format!("failed to read chunk: {:#?}", chunk))?; - info.total_uncompressed += data.len() as u64; - let mut chunk_reader = SliceReader::from(data); - while chunk_reader.remaining() > 0 { - let header = chunk_reader.read_header()?; - let op = header.read_op()?; - match op { - Op::MsgData => { - let header = MessageDataHeader::from_header(header)?; - let count = info.per_connection.entry(header.conn_id).or_insert(0); - *count += 1; - chunk_reader.skip_data()?; - } - Op::Connection => chunk_reader.skip_data()?, - _ => bail!("unexpected op in chunk: {:?}", op), - } - } + let data = chunk?; + info.total_uncompressed += data.data.len() as u64; + let count = info.per_connection.entry(data.header.conn_id).or_insert(0); + *count += 1; Ok(info) }) .try_reduce(BagInfo::default, |a, b| Ok(a.combine(b))) diff --git a/src/reader.rs b/src/reader.rs index d8a7b45..72eb0f3 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,5 +1,4 @@ -use std::io::SeekFrom; - +use ::bytes::Bytes; use nom::multi::length_data; use nom::number::streaming::le_u32; @@ -11,18 +10,19 @@ mod io; #[cfg(feature = "mmap")] mod mmap; mod slice; +mod bytes; #[cfg(feature = "mmap")] pub use self::mmap::MmapReader; -pub use self::{io::IoReader, slice::SliceReader}; +pub use self::{io::IoReader, slice::SliceReader, bytes::BytesReader}; pub trait BagReader { + fn skip(&mut self, amount: usize) -> Result<()>; + fn read_parser<'a, O: 'a, P>(&'a mut self, parser: P) -> Result where P: nom::Parser<&'a [u8], O, parse::Error<&'a [u8]>>; - fn seek(&mut self, pos: SeekFrom) -> Result<()>; - fn read_version(&mut self) -> Result { self.read_parser(Version::parse) } @@ -46,11 +46,15 @@ pub trait BagReader { fn skip_data(&mut self) -> Result<()> { let data_length = self.read_data_length()?; - self.seek(SeekFrom::Current(data_length as i64))?; + self.skip(data_length as usize)?; Ok(()) } fn read_data(&mut self) -> Result<&[u8]> { self.read_parser(length_data(le_u32)) } + + fn read_data_bytes(&mut self) -> Result { + Ok(Bytes::copy_from_slice(self.read_data()?)) + } } diff --git a/src/reader/bytes.rs b/src/reader/bytes.rs new file mode 100644 index 0000000..c0ee702 --- /dev/null +++ b/src/reader/bytes.rs @@ -0,0 +1,77 @@ +use bytes::{Buf, Bytes}; +use eyre::bail; + +use super::{error::UnexpectedEof, BagReader}; +use crate::{parse, Result}; + +// This is like a SliceReader, except a byte is reference +// counted to read_data_bytes can give out unlimited-lifetime references +// to the same buffer. +#[derive(Clone)] +pub struct BytesReader { + current: Bytes, + last: Option, +} + +impl BytesReader { + pub fn into_inner(self) -> Bytes { + self.current + } + + pub fn remaining(&self) -> usize { + self.as_ref().remaining() + } +} + +impl AsRef for BytesReader { + fn as_ref(&self) -> &Bytes { + &self.current + } +} + +impl AsMut for BytesReader { + fn as_mut(&mut self) -> &mut Bytes { + &mut self.current + } +} + +impl From for BytesReader { + fn from(bytes: Bytes) -> Self { + Self { + current: bytes, + last: None, + } + } +} + +impl BagReader for BytesReader { + fn skip(&mut self, amount: usize) -> Result<()> { + self.current.advance(amount); + Ok(()) + } + + fn read_parser<'a, O: 'a, P>(&'a mut self, mut parser: P) -> Result + where + P: nom::Parser<&'a [u8], O, parse::Error<&'a [u8]>>, + { + // Store the current buffer to last so we can return a reference + // to it. + self.last = Some(self.current.clone()); + let buf = self.last.as_ref().unwrap(); + match parser.parse(buf.as_ref()) { + Ok((rest, output)) => { + // Modify current to refer to rest. + // This is why parse last, as otherwise current would be borrowed here. + self.current = buf.slice_ref(rest); + Ok(output) + } + Err(nom::Err::Incomplete(_)) => bail!(UnexpectedEof), + Err(nom::Err::Error(e) | nom::Err::Failure(e)) => Err(e.into_owned().into()), + } + } + + fn read_data_bytes(&mut self) -> Result { + let len = self.read_data_length()?; + Ok(self.current.split_to(len as usize)) + } +} diff --git a/src/reader/io.rs b/src/reader/io.rs index 70a1f5c..c594dbf 100644 --- a/src/reader/io.rs +++ b/src/reader/io.rs @@ -1,4 +1,4 @@ -use std::io; +use std::io::{self, Seek}; use bytes::{Buf, BytesMut}; @@ -13,7 +13,7 @@ pub struct IoReader { consumed: usize, } -impl IoReader { +impl IoReader { pub fn new(read: R) -> Self { Self { read, @@ -38,6 +38,11 @@ impl IoReader { } impl BagReader for IoReader { + fn skip(&mut self, amount: usize) -> Result<()> { + self.seek(io::SeekFrom::Current(amount as i64))?; + Ok(()) + } + fn read_parser<'a, O: 'a, P>(&'a mut self, mut parser: P) -> Result where P: nom::Parser<&'a [u8], O, parse::Error<&'a [u8]>>, @@ -68,8 +73,10 @@ impl BagReader for IoReader { } } } +} - fn seek(&mut self, mut pos: io::SeekFrom) -> Result<()> { +impl io::Seek for IoReader { + fn seek(&mut self, mut pos: io::SeekFrom) -> io::Result { if let io::SeekFrom::Current(pos) = &mut pos { // If seeking relative to current position, subtract data // read from the file but not yet consumed. @@ -78,14 +85,14 @@ impl BagReader for IoReader { if *pos >= 0 && new_pos < 0 { // The new position is within the already read data, just consume more data self.consumed += *pos as usize; - return Ok(()); + // TODO: compute this correctly + return Ok(0); } *pos = new_pos; } self.buffer.clear(); self.consumed = 0; - self.read.seek(pos)?; - Ok(()) + self.read.seek(pos) } } diff --git a/src/reader/slice.rs b/src/reader/slice.rs index 26e9af4..08bdf54 100644 --- a/src/reader/slice.rs +++ b/src/reader/slice.rs @@ -48,6 +48,12 @@ where T: Deref, U: AsRef<[u8]> + ?Sized + 'static, { + fn skip(&mut self, amount: usize) -> Result<()> { + // TODO: bounds checking + self.pos += amount; + Ok(()) + } + fn read_parser<'a, O: 'a, P>(&'a mut self, mut parser: P) -> Result where P: nom::Parser<&'a [u8], O, parse::Error<&'a [u8]>>, @@ -63,8 +69,15 @@ where Err(nom::Err::Error(e) | nom::Err::Failure(e)) => Err(e.into_owned().into()), } } +} - fn seek(&mut self, pos: io::SeekFrom) -> Result<()> { +impl io::Seek for SliceReader +where + T: Deref, + U: AsRef<[u8]> + ?Sized + 'static, +{ + fn seek(&mut self, pos: io::SeekFrom) -> io::Result { + // TODO: bounds checking match pos { io::SeekFrom::Start(pos) => { self.pos = pos as usize; @@ -76,6 +89,6 @@ where self.pos = ((self.pos as i64) + pos) as usize; } } - Ok(()) + Ok(self.pos as u64) } }