diff --git a/vortex-flatbuffers/src/lib.rs b/vortex-flatbuffers/src/lib.rs index 8278d68274..5c6b214356 100644 --- a/vortex-flatbuffers/src/lib.rs +++ b/vortex-flatbuffers/src/lib.rs @@ -1,8 +1,7 @@ use std::io; -use std::io::{Read, Write}; +use std::io::Write; -use flatbuffers::{root, FlatBufferBuilder, Follow, Verifiable, WIPOffset}; -use vortex_error::{vortex_err, VortexResult}; +use flatbuffers::{FlatBufferBuilder, WIPOffset}; pub trait FlatBufferRoot {} @@ -35,35 +34,6 @@ impl FlatBufferToBytes for F { } } -pub trait FlatBufferReader { - /// Returns Ok(None) if the reader has reached EOF. - fn read_message<'a, F>(&mut self, buffer: &'a mut Vec) -> VortexResult> - where - F: 'a + Follow<'a, Inner = F> + Verifiable; -} - -impl FlatBufferReader for R { - fn read_message<'a, F>(&mut self, buffer: &'a mut Vec) -> VortexResult> - where - F: 'a + Follow<'a, Inner = F> + Verifiable, - { - let mut msg_size: [u8; 4] = [0; 4]; - if let Err(e) = self.read_exact(&mut msg_size) { - return match e.kind() { - io::ErrorKind::UnexpectedEof => Ok(None), - _ => Err(vortex_err!(IOError: e)), - }; - } - let msg_size = u32::from_le_bytes(msg_size) as u64; - if msg_size == 0 { - // FIXME(ngates): I think this is wrong. - return Ok(None); - } - self.take(msg_size).read_to_end(buffer)?; - Ok(Some(root::(buffer)?)) - } -} - pub trait FlatBufferWriter { // Write the given FlatBuffer message, appending padding until the total bytes written // are a multiple of `alignment`. diff --git a/vortex-ipc/src/lib.rs b/vortex-ipc/src/lib.rs index 04c31298d1..f5a502b18a 100644 --- a/vortex-ipc/src/lib.rs +++ b/vortex-ipc/src/lib.rs @@ -70,8 +70,10 @@ mod tests { let mut cursor = Cursor::new(Vec::new()); let ctx = SerdeContext::default(); - let mut writer = StreamWriter::try_new_unbuffered(&mut cursor, ctx).unwrap(); - writer.write_array(&arr).unwrap(); + { + let mut writer = StreamWriter::try_new_unbuffered(&mut cursor, ctx).unwrap(); + writer.write_array(&arr).unwrap(); + } cursor.flush().unwrap(); cursor.set_position(0); diff --git a/vortex-ipc/src/reader.rs b/vortex-ipc/src/reader.rs index 64b3c8e854..35f5cd7ebe 100644 --- a/vortex-ipc/src/reader.rs +++ b/vortex-ipc/src/reader.rs @@ -1,7 +1,9 @@ use std::io; use std::io::{BufReader, Read}; +use std::marker::PhantomData; use arrow_buffer::Buffer as ArrowBuffer; +use flatbuffers::{root, root_unchecked}; use nougat::gat; use vortex::array::chunked::ChunkedArray; use vortex::array::composite::VORTEX_COMPOSITE_EXTENSIONS; @@ -9,7 +11,7 @@ use vortex::buffer::Buffer; use vortex::stats::{ArrayStatistics, Stat}; use vortex::{Array, ArrayView, IntoArray, OwnedArray, SerdeContext, ToArray, ToStatic}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult}; -use vortex_flatbuffers::{FlatBufferReader, ReadFlatBuffer}; +use vortex_flatbuffers::ReadFlatBuffer; use vortex_schema::{DType, DTypeSerdeContext}; use crate::flatbuffers::ipc::Message; @@ -18,9 +20,8 @@ use crate::iter::{FallibleLendingIterator, FallibleLendingIteratorāļžItem}; #[allow(dead_code)] pub struct StreamReader { read: R, - - pub(crate) ctx: SerdeContext, - // Optionally take a projection? + messages: StreamMessageReader, + ctx: SerdeContext, } impl StreamReader> { @@ -31,16 +32,27 @@ impl StreamReader> { impl StreamReader { pub fn try_new_unbuffered(mut read: R) -> VortexResult { - let mut msg_vec = Vec::new(); - let fb_msg = read - .read_message::(&mut msg_vec)? - .ok_or_else(|| vortex_err!(InvalidSerde: "Unexpected EOF reading IPC format"))?; - let fb_ctx = fb_msg.header_as_context().ok_or_else( - || vortex_err!(InvalidSerde: "Expected IPC Context as first message in stream"), - )?; - let ctx: SerdeContext = fb_ctx.try_into()?; - - Ok(Self { read, ctx }) + let mut messages = StreamMessageReader::try_new(&mut read)?; + match messages.peek() { + None => vortex_bail!("IPC stream is empty"), + Some(msg) => { + if msg.header_as_context().is_none() { + vortex_bail!(InvalidSerde: "Expected IPC Context as first message in stream") + } + } + } + + let ctx: SerdeContext = messages + .next(&mut read)? + .header_as_context() + .unwrap() + .try_into()?; + + Ok(Self { + read, + messages, + ctx, + }) } /// Read a single array from the IPC stream. @@ -69,48 +81,49 @@ impl FallibleLendingIterator for StreamReader { type Item<'next> = StreamArrayReader<'next, R> where Self: 'next; fn next(&mut self) -> Result>, Self::Error> { - let mut fb_vec = Vec::new(); - let msg = self.read.read_message::(&mut fb_vec)?; - if msg.is_none() { - // End of the stream + if self + .messages + .peek() + .and_then(|msg| msg.header_as_schema()) + .is_none() + { return Ok(None); } - let msg = msg.unwrap(); - // FIXME(ngates): parse the schema? - let schema = msg + let schema_msg = self + .messages + .next(&mut self.read)? .header_as_schema() - .ok_or_else(|| vortex_err!(InvalidSerde: "Expected IPC Schema message"))?; + .unwrap(); // TODO(ngates): construct this from the SerdeContext. let dtype_ctx = DTypeSerdeContext::new(VORTEX_COMPOSITE_EXTENSIONS.iter().map(|e| e.id()).collect()); let dtype = DType::read_flatbuffer( &dtype_ctx, - &schema + &schema_msg .dtype() .ok_or_else(|| vortex_err!(InvalidSerde: "Schema missing DType"))?, ) .map_err(|e| vortex_err!(InvalidSerde: "Failed to parse DType: {}", e))?; - // Figure out how many columns we have and therefore how many buffers there? Ok(Some(StreamArrayReader { - read: &mut self.read, ctx: &self.ctx, + read: &mut self.read, + messages: &mut self.messages, dtype, buffers: vec![], - column_msg_buffer: vec![], })) } } #[allow(dead_code)] pub struct StreamArrayReader<'a, R: Read> { - read: &'a mut R, ctx: &'a SerdeContext, + read: &'a mut R, + messages: &'a mut StreamMessageReader, dtype: DType, buffers: Vec>, - column_msg_buffer: Vec, } impl<'a, R: Read> StreamArrayReader<'a, R> { @@ -136,55 +149,37 @@ impl<'iter, R: Read> FallibleLendingIterator for StreamArrayReader<'iter, R> { type Item<'next> = Array<'next> where Self: 'next; fn next(&mut self) -> Result>, Self::Error> { - self.column_msg_buffer.clear(); - let msg = self - .read - .read_message::(&mut self.column_msg_buffer)?; - if msg.is_none() { - // End of the stream + let Some(chunk_msg) = self.messages.peek().and_then(|msg| msg.header_as_chunk()) else { return Ok(None); - } - let msg = msg.unwrap(); - - let chunk_msg = msg - .header_as_chunk() - .ok_or_else(|| vortex_err!(InvalidSerde: "Expected IPC Chunk message")) - .unwrap(); - let col_array = chunk_msg - .array() - .ok_or_else(|| vortex_err!(InvalidSerde: "Chunk column missing Array")) - .unwrap(); + }; // Read all the column's buffers - // TODO(ngates): read into a single buffer, then Arc::clone and slice self.buffers.clear(); let mut offset = 0; for buffer in chunk_msg.buffers().unwrap_or_default().iter() { - let to_kill = buffer.offset() - offset; - io::copy(&mut self.read.take(to_kill), &mut io::sink()).unwrap(); - - let buffer_length = buffer.length(); - let mut bytes = Vec::with_capacity(buffer_length as usize); - let bytes_read = self - .read - .take(buffer.length()) - .read_to_end(&mut bytes) - .unwrap(); - if bytes_read < buffer_length as usize { - return Err(vortex_err!(InvalidSerde: "Unexpected EOF reading buffer")); - } + let _skip = buffer.offset() - offset; + self.read.skip(buffer.offset() - offset)?; + // TODO(ngates): read into a single buffer, then Arc::clone and slice + let mut bytes = Vec::with_capacity(buffer.length() as usize); + self.read.read_into(buffer.length(), &mut bytes)?; let arrow_buffer = ArrowBuffer::from_vec(bytes); - assert_eq!(arrow_buffer.len(), buffer_length as usize); self.buffers.push(Buffer::Owned(arrow_buffer)); offset = buffer.offset() + buffer.length(); } // Consume any remaining padding after the final buffer. - let to_kill = chunk_msg.buffer_size() - offset; - io::copy(&mut self.read.take(to_kill), &mut io::sink()).unwrap(); + self.read.skip(chunk_msg.buffer_size() - offset)?; + // After reading the buffers we're now able to load the next message. + let col_array = self + .messages + .next(self.read)? + .header_as_chunk() + .unwrap() + .array() + .unwrap(); let view = ArrayView::try_new(self.ctx, &self.dtype, col_array, self.buffers.as_slice())?; // Validate it @@ -193,3 +188,135 @@ impl<'iter, R: Read> FallibleLendingIterator for StreamArrayReader<'iter, R> { Ok(Some(view.into_array())) } } + +pub trait ReadExtensions: Read { + /// Skip n bytes in the stream. + fn skip(&mut self, nbytes: u64) -> io::Result<()> { + io::copy(&mut self.take(nbytes), &mut io::sink())?; + Ok(()) + } + + /// Read exactly nbytes into the buffer. + fn read_into(&mut self, nbytes: u64, buffer: &mut Vec) -> VortexResult<()> { + buffer.reserve_exact(nbytes as usize); + if self.take(nbytes).read_to_end(buffer)? != nbytes as usize { + vortex_bail!(InvalidSerde: "Failed to read all bytes") + } + Ok(()) + } +} + +impl ReadExtensions for R {} + +struct StreamMessageReader { + message: Vec, + prev_message: Vec, + finished: bool, + phantom: PhantomData, +} + +impl StreamMessageReader { + pub fn try_new(read: &mut R) -> VortexResult { + let mut reader = Self { + message: Vec::new(), + prev_message: Vec::new(), + finished: false, + phantom: PhantomData, + }; + reader.load_next_message(read)?; + Ok(reader) + } + + pub fn peek(&self) -> Option { + if self.finished { + return None; + } + // The message has been validated by the next() call. + Some(unsafe { root_unchecked::(&self.message) }) + } + + pub fn next(&mut self, read: &mut R) -> VortexResult { + if self.finished { + panic!("StreamMessageReader is finished - should've checked peek!"); + } + std::mem::swap(&mut self.prev_message, &mut self.message); + if !self.load_next_message(read)? { + self.finished = true; + } + Ok(unsafe { root_unchecked::(&self.prev_message) }) + } + + fn load_next_message(&mut self, read: &mut R) -> VortexResult { + let mut len_buf = [0u8; 4]; + match read.read_exact(&mut len_buf) { + Ok(_) => {} + Err(e) => { + return match e.kind() { + io::ErrorKind::UnexpectedEof => Ok(false), + _ => Err(e.into()), + } + } + } + + let len = u32::from_le_bytes(len_buf); + if len == u32::MAX { + // Marker for no more messages. + return Ok(false); + } + + self.message.clear(); + self.message.reserve(len as usize); + if read.take(len as u64).read_to_end(&mut self.message)? != len as usize { + vortex_bail!(InvalidSerde: "Failed to read all bytes") + } + + std::hint::black_box(root::(&self.message)?); + Ok(true) + } +} + +#[cfg(test)] +mod tests { + use std::io::{Cursor, Read, Write}; + + use vortex::array::chunked::{Chunked, ChunkedArray}; + use vortex::array::primitive::{Primitive, PrimitiveArray}; + use vortex::{ArrayDType, ArrayDef, IntoArray, SerdeContext}; + + use crate::reader::StreamReader; + use crate::writer::StreamWriter; + + #[test] + fn test_read_write() { + let array = PrimitiveArray::from(vec![0, 1, 2]).into_array(); + let chunked_array = + ChunkedArray::try_new(vec![array.clone(), array.clone()], array.dtype().clone()) + .unwrap() + .into_array(); + + let mut buffer = vec![]; + let mut cursor = Cursor::new(&mut buffer); + { + let mut writer = StreamWriter::try_new(&mut cursor, SerdeContext::default()).unwrap(); + writer.write_array(&array).unwrap(); + writer.write_array(&chunked_array).unwrap(); + } + // Push some extra bytes to test that the reader is well-behaved and doesn't read past the + // end of the stream. + let _ = cursor.write(b"hello").unwrap(); + + cursor.set_position(0); + { + let mut reader = StreamReader::try_new_unbuffered(&mut cursor).unwrap(); + let first = reader.read_array().unwrap(); + assert_eq!(first.encoding().id(), Primitive::ID); + let second = reader.read_array().unwrap(); + assert_eq!(second.encoding().id(), Chunked::ID); + } + let _pos = cursor.position(); + // Test our termination bytes exist + let mut terminator = [0u8; 5]; + cursor.read_exact(&mut terminator).unwrap(); + assert_eq!(&terminator, b"hello"); + } +} diff --git a/vortex-ipc/src/writer.rs b/vortex-ipc/src/writer.rs index e8ddeb7229..a12aec61c6 100644 --- a/vortex-ipc/src/writer.rs +++ b/vortex-ipc/src/writer.rs @@ -74,3 +74,10 @@ impl StreamWriter { Ok(()) } } + +impl Drop for StreamWriter { + fn drop(&mut self) { + // Terminate the stream + let _ = self.write.write_all(&[u8::MAX; 4]); + } +}