diff --git a/bench-vortex/src/data_downloads.rs b/bench-vortex/src/data_downloads.rs
index 2cdeea796d..be19d4dd43 100644
--- a/bench-vortex/src/data_downloads.rs
+++ b/bench-vortex/src/data_downloads.rs
@@ -15,7 +15,7 @@ use vortex::{Array, IntoArray};
 use vortex_dtype::DType;
 use vortex_error::{VortexError, VortexResult};
 use vortex_serde::io::TokioAdapter;
-use vortex_serde::writer::ArrayWriter;
+use vortex_serde::stream_writer::StreamArrayWriter;
 
 use crate::idempotent;
 use crate::reader::BATCH_SIZE;
@@ -57,7 +57,7 @@ pub fn data_vortex_uncompressed(fname_out: &str, downloaded_data: PathBuf) -> PathBuf {
         .unwrap()
         .block_on(async move {
             let write = tokio::fs::File::create(path).await.unwrap();
-            ArrayWriter::new(TokioAdapter(write))
+            StreamArrayWriter::new(TokioAdapter(write))
                 .write_array(array)
                 .await
                 .unwrap();
diff --git a/bench-vortex/src/reader.rs b/bench-vortex/src/reader.rs
index bef82d4e57..c9af638039 100644
--- a/bench-vortex/src/reader.rs
+++ b/bench-vortex/src/reader.rs
@@ -33,8 +33,9 @@ use vortex_error::{vortex_err, VortexResult};
 use vortex_sampling_compressor::SamplingCompressor;
 use vortex_serde::chunked_reader::ChunkedArrayReader;
 use vortex_serde::io::{ObjectStoreExt, TokioAdapter, VortexReadAt, VortexWrite};
-use vortex_serde::writer::ArrayWriter;
-use vortex_serde::MessageReader;
+use vortex_serde::stream_reader::StreamArrayReader;
+use vortex_serde::stream_writer::StreamArrayWriter;
+use vortex_serde::DTypeReader;
 
 use crate::{COMPRESSORS, CTX};
 
@@ -49,10 +50,12 @@ pub struct VortexFooter {
 
 pub async fn open_vortex(path: &Path) -> VortexResult<Array> {
     let file = tokio::fs::File::open(path).await.unwrap();
-    let mut msgs = MessageReader::try_new(TokioAdapter(file)).await.unwrap();
-    msgs.array_stream_from_messages(CTX.clone())
-        .await
-        .unwrap()
+    let reader = StreamArrayReader::try_new(TokioAdapter(file), CTX.clone())
+        .await?
+        .load_dtype()
+        .await?;
+    reader
+        .into_array_stream()
         .collect_chunked()
         .await
         .map(|a| a.into_array())
@@ -64,7 +67,7 @@ pub async fn rewrite_parquet_as_vortex(
 ) -> VortexResult<()> {
     let chunked = compress_parquet_to_vortex(parquet_path.as_path())?;
 
-    let written = ArrayWriter::new(write)
+    let written = StreamArrayWriter::new(write)
         .write_array_stream(chunked.array_stream())
         .await?;
@@ -146,8 +149,7 @@ pub async fn read_vortex_footer_format(
     buf.reserve(header_len - buf.len());
     unsafe { buf.set_len(header_len) }
     buf = reader.read_at_into(footer.dtype_range.start, buf).await?;
-    let mut header_reader = MessageReader::try_new(buf).await?;
-    let dtype = header_reader.read_dtype().await?;
+    let dtype = DTypeReader::new(buf).await?.read_dtype().await?;
 
     ChunkedArrayReader::try_new(
         reader,
diff --git a/bench-vortex/src/tpch/mod.rs b/bench-vortex/src/tpch/mod.rs
index b922aae959..021b6ea8a2 100644
--- a/bench-vortex/src/tpch/mod.rs
+++ b/bench-vortex/src/tpch/mod.rs
@@ -22,7 +22,7 @@ use vortex_datafusion::persistent::config::{VortexFile, VortexTableOptions};
 use vortex_datafusion::SessionContextExt;
 use vortex_dtype::DType;
 use vortex_sampling_compressor::SamplingCompressor;
-use vortex_serde::layouts::writer::LayoutWriter;
+use vortex_serde::layouts::LayoutWriter;
 
 use crate::idempotent_async;
 
diff --git a/vortex-array/src/array/chunked/compute/take.rs b/vortex-array/src/array/chunked/compute/take.rs
index 09e4f01468..e23bbaff7c 100644
--- a/vortex-array/src/array/chunked/compute/take.rs
+++ b/vortex-array/src/array/chunked/compute/take.rs
@@ -4,11 +4,10 @@ use vortex_error::{vortex_err, VortexResult};
 use vortex_scalar::Scalar;
 
 use crate::array::chunked::ChunkedArray;
-use crate::array::primitive::PrimitiveArray;
 use crate::compute::unary::{scalar_at, subtract_scalar, try_cast};
 use crate::compute::{search_sorted, slice, take, SearchSortedSide, TakeFn};
 use crate::stats::ArrayStatistics;
-use crate::{Array, ArrayDType, IntoArray, ToArray};
+use crate::{Array, ArrayDType, IntoArray, IntoArrayVariant, ToArray};
 
 impl TakeFn for ChunkedArray {
     fn take(&self, indices: &Array) -> VortexResult<Array> {
@@ -25,8 +24,7 @@ impl TakeFn for ChunkedArray {
             return take_strict_sorted(self, indices);
         }
 
-        // FIXME(ngates): this is wrong, need to canonicalise
-        let indices = PrimitiveArray::try_from(try_cast(indices, PType::U64.into())?)?;
+        let indices = try_cast(indices, PType::U64.into())?.into_primitive()?;
 
         // While the chunk idx remains the same, accumulate a list of chunk indices.
         let mut chunks = Vec::new();
@@ -79,8 +77,8 @@ fn take_strict_sorted(chunked: &ChunkedArray, indices: &Array) -> VortexResult<Array> {
diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs
--- a/vortex-array/src/array/chunked/mod.rs
+++ b/vortex-array/src/array/chunked/mod.rs
     pub fn chunk(&self, idx: usize) -> Option<Array> {
-        let chunk_start = usize::try_from(&scalar_at(&self.chunk_ends(), idx).ok()?).ok()?;
-        let chunk_end = usize::try_from(&scalar_at(&self.chunk_ends(), idx + 1).ok()?).ok()?;
+        let chunk_start = usize::try_from(&scalar_at(&self.chunk_offsets(), idx).ok()?).ok()?;
+        let chunk_end = usize::try_from(&scalar_at(&self.chunk_offsets(), idx + 1).ok()?).ok()?;
 
         // Offset the index since chunk_ends is child 0.
         self.array()
@@ -83,7 +84,7 @@ impl ChunkedArray {
     }
 
     #[inline]
-    pub fn chunk_ends(&self) -> Array {
+    pub fn chunk_offsets(&self) -> Array {
         self.array()
             .child(0, &Self::ENDS_DTYPE, self.nchunks() + 1)
             .expect("missing chunk ends")
@@ -92,7 +93,7 @@ impl ChunkedArray {
 
     pub fn find_chunk_idx(&self, index: usize) -> (usize, usize) {
         assert!(index <= self.len(), "Index out of bounds of the array");
-        let search_result = search_sorted(&self.chunk_ends(), index, SearchSortedSide::Left)
+        let search_result = search_sorted(&self.chunk_offsets(), index, SearchSortedSide::Left)
             .unwrap_or_else(|err| {
                 panic!("Search sorted failed in find_chunk_idx: {}", err);
             });
@@ -106,7 +107,7 @@ impl ChunkedArray {
             }
             SearchResult::NotFound(i) => i - 1,
         };
-        let chunk_start = &scalar_at(&self.chunk_ends(), index_chunk)
+        let chunk_start = &scalar_at(&self.chunk_offsets(), index_chunk)
             .and_then(|s| usize::try_from(&s))
             .unwrap_or_else(|err| {
                 panic!("Failed to find chunk start in find_chunk_idx: {}", err);
@@ -154,7 +155,7 @@ impl FromIterator<Array> for ChunkedArray {
 
 impl AcceptArrayVisitor for ChunkedArray {
     fn accept(&self, visitor: &mut dyn ArrayVisitor) -> VortexResult<()> {
-        visitor.visit_child("chunk_ends", &self.chunk_ends())?;
+        visitor.visit_child("chunk_ends", &self.chunk_offsets())?;
         for (idx, chunk) in self.chunks().enumerate() {
             visitor.visit_child(format!("[{}]", idx).as_str(), &chunk)?;
         }
diff --git a/vortex-array/src/array/struct_/mod.rs b/vortex-array/src/array/struct_/mod.rs
index 9a453c12f3..d2e6e05d9c 100644
--- a/vortex-array/src/array/struct_/mod.rs
+++ b/vortex-array/src/array/struct_/mod.rs
@@ -1,4 +1,5 @@
 use serde::{Deserialize, Serialize};
+use vortex_dtype::field::Field;
 use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructDType};
 use vortex_error::{vortex_bail, vortex_err, VortexResult};
 
@@ -96,16 +97,25 @@ impl StructArray {
     /// # Panics
     /// This function will panic if the projection references columns not within the
     /// schema boundaries.
-    pub fn project(&self, projection: &[usize]) -> VortexResult<Self> {
+    pub fn project(&self, projection: &[Field]) -> VortexResult<Self> {
         let mut children = Vec::with_capacity(projection.len());
         let mut names = Vec::with_capacity(projection.len());
 
-        for &column_idx in projection {
+        for field in projection.iter() {
+            let idx = match field {
+                Field::Name(n) => self
+                    .names()
+                    .iter()
+                    .position(|name| name.as_ref() == n)
+                    .ok_or_else(|| vortex_err!("Unknown field {n}"))?,
+                Field::Index(i) => *i,
+            };
+
+            names.push(self.names()[idx].clone());
             children.push(
-                self.field(column_idx)
-                    .ok_or(vortex_err!(OutOfBounds: column_idx, 0, self.dtypes().len()))?,
+                self.field(idx)
+                    .ok_or_else(|| vortex_err!(OutOfBounds: idx, 0, self.dtypes().len()))?,
             );
-            names.push(self.names()[column_idx].clone());
         }
 
         StructArray::try_new(
@@ -166,6 +176,7 @@ impl ArrayStatisticsCompute for StructArray {}
 
 #[cfg(test)]
 mod test {
+    use vortex_dtype::field::Field;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};
 
    use crate::array::primitive::PrimitiveArray;
@@ -193,7 +204,9 @@ mod test {
        )
        .unwrap();
 
-        let struct_b = struct_a.project(&[2usize, 0]).unwrap();
+        let struct_b = struct_a
+            .project(&[Field::from(2usize), Field::from(0)])
+            .unwrap();
        assert_eq!(
            struct_b.names().as_ref(),
            [FieldName::from("zs"), FieldName::from("xs")],
diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs
index d53265849d..ebb62b44ef 100644
--- a/vortex-array/src/compute/compare.rs
+++ b/vortex-array/src/compute/compare.rs
@@ -29,7 +29,7 @@ impl Display for Operator {
             Operator::Lt => "<",
             Operator::Lte => "<=",
         };
-        write!(f, "{display}")
+        Display::fmt(display, f)
     }
 }
 
diff --git a/vortex-datafusion/examples/table_provider.rs b/vortex-datafusion/examples/table_provider.rs
index 693a40f5c1..fbcb63c547 100644
--- a/vortex-datafusion/examples/table_provider.rs
+++ b/vortex-datafusion/examples/table_provider.rs
@@ -14,7 +14,7 @@ use vortex::validity::Validity;
 use vortex::{Context, IntoArray};
 use vortex_datafusion::persistent::config::{VortexFile, VortexTableOptions};
 use vortex_datafusion::persistent::provider::VortexFileTableProvider;
-use vortex_serde::layouts::writer::LayoutWriter;
+use vortex_serde::layouts::LayoutWriter;
 
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
diff --git a/vortex-datafusion/src/lib.rs b/vortex-datafusion/src/lib.rs
index 7c7142e87a..a2b9a154a2 100644
--- a/vortex-datafusion/src/lib.rs
+++ b/vortex-datafusion/src/lib.rs
@@ -22,6 +22,7 @@ use persistent::config::VortexTableOptions;
 use persistent::provider::VortexFileTableProvider;
 use vortex::array::ChunkedArray;
 use vortex::{Array, ArrayDType, IntoArrayVariant};
+use vortex_dtype::field::Field;
 use vortex_error::vortex_err;
 
 pub mod memory;
@@ -190,9 +191,8 @@ impl Debug for VortexScanExec {
 }
 
 impl DisplayAs for VortexScanExec {
-    #[allow(clippy::use_debug)]
     fn fmt_as(&self, _display_type: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
-        write!(f, "{:?}", self)
+        Debug::fmt(self, f)
     }
 }
 
@@ -203,7 +203,7 @@ pub(crate) struct VortexRecordBatchStream {
     num_chunks: usize,
     chunks: ChunkedArray,
 
-    projection: Vec<usize>,
+    projection: Vec<Field>,
 }
 
 impl Stream for VortexRecordBatchStream {
@@ -227,12 +227,11 @@
                 .into_struct()
                 .map_err(|vortex_error| DataFusionError::Execution(format!("{}", vortex_error)))?;
 
-            let projected_struct =
-                struct_array
-                    .project(this.projection.as_slice())
-                    .map_err(|vortex_err| {
-                        exec_datafusion_err!("projection pushdown to Vortex failed: {vortex_err}")
-                    })?;
+            let projected_struct = struct_array
+                .project(&this.projection)
+                .map_err(|vortex_err| {
+                    exec_datafusion_err!("projection pushdown to Vortex failed: {vortex_err}")
+                })?;
 
             Poll::Ready(Some(Ok(projected_struct.into())))
         }
@@ -284,7 +283,12 @@ impl ExecutionPlan for VortexScanExec {
             idx: 0,
             num_chunks: self.array.nchunks(),
             chunks: self.array.clone(),
-            projection: self.scan_projection.clone(),
+            projection: self
+                .scan_projection
+                .iter()
+                .copied()
+                .map(Field::from)
+                .collect(),
         }))
     }
 }
diff --git a/vortex-datafusion/src/memory.rs b/vortex-datafusion/src/memory.rs
index c0d4116c21..7520568b42 100644
--- a/vortex-datafusion/src/memory.rs
+++ b/vortex-datafusion/src/memory.rs
@@ -9,12 +9,13 @@ use datafusion::prelude::*;
 use datafusion_common::{Result as DFResult, ToDFSchema};
 use datafusion_expr::utils::conjunction;
 use datafusion_expr::{TableProviderFilterPushDown, TableType};
-use datafusion_physical_expr::{create_physical_expr, EquivalenceProperties, PhysicalExpr};
+use datafusion_physical_expr::{create_physical_expr, EquivalenceProperties};
 use datafusion_physical_plan::{ExecutionMode, ExecutionPlan, Partitioning, PlanProperties};
 use itertools::Itertools;
 use vortex::array::ChunkedArray;
 use vortex::{Array, ArrayDType as _};
-use vortex_expr::datafusion::extract_columns_from_expr;
+use vortex_expr::datafusion::convert_expr_to_vortex;
+use vortex_expr::VortexExpr;
 
 use crate::datatype::infer_schema;
 use crate::plans::{RowSelectorExec, TakeRowsExec};
@@ -95,16 +96,11 @@ impl TableProvider for VortexMemTable {
                 let df_schema = self.schema_ref.clone().to_dfschema()?;
 
                 let filter_expr = create_physical_expr(&expr, &df_schema, state.execution_props())?;
-
-                let filter_projection =
-                    extract_columns_from_expr(Some(&filter_expr), self.schema_ref.clone())?
-                        .into_iter()
-                        .collect();
+                let filter_expr = convert_expr_to_vortex(filter_expr)?;
 
                 make_filter_then_take_plan(
                     self.schema_ref.clone(),
                     filter_expr,
-                    filter_projection,
                     self.array.clone(),
                     output_projection.clone(),
                     state,
@@ -192,17 +188,12 @@ impl VortexMemTableOptions {
 /// columns.
 fn make_filter_then_take_plan(
     schema: SchemaRef,
-    filter_expr: Arc<dyn PhysicalExpr>,
-    filter_projection: Vec<usize>,
+    filter_expr: Arc<dyn VortexExpr>,
     chunked_array: ChunkedArray,
     output_projection: Vec<usize>,
     _session_state: &dyn Session,
 ) -> DFResult<Arc<dyn ExecutionPlan>> {
-    let row_selector_op = Arc::new(RowSelectorExec::try_new(
-        filter_expr,
-        filter_projection,
-        &chunked_array,
-    )?);
+    let row_selector_op = Arc::new(RowSelectorExec::try_new(filter_expr, &chunked_array)?);
 
     Ok(Arc::new(TakeRowsExec::new(
         schema.clone(),
diff --git a/vortex-datafusion/src/persistent/opener.rs b/vortex-datafusion/src/persistent/opener.rs
index 4bc1a35639..19f200562c 100644
--- a/vortex-datafusion/src/persistent/opener.rs
+++ b/vortex-datafusion/src/persistent/opener.rs
@@ -6,15 +6,13 @@ use datafusion::datasource::physical_plan::{FileMeta, FileOpenFuture, FileOpener};
 use datafusion_common::Result as DFResult;
 use datafusion_physical_expr::PhysicalExpr;
 use futures::{FutureExt as _, TryStreamExt};
-use itertools::Itertools;
 use object_store::ObjectStore;
 use vortex::Context;
-use vortex_expr::datafusion::{convert_expr_to_vortex, extract_columns_from_expr};
+use vortex_expr::datafusion::convert_expr_to_vortex;
 use vortex_serde::io::ObjectStoreReadAt;
-use vortex_serde::layouts::reader::builder::LayoutReaderBuilder;
-use vortex_serde::layouts::reader::context::{LayoutContext, LayoutDeserializer};
-use vortex_serde::layouts::reader::filtering::RowFilter;
-use vortex_serde::layouts::reader::projections::Projection;
+use vortex_serde::layouts::{
+    LayoutContext, LayoutDeserializer, LayoutReaderBuilder, Projection, RowFilter,
+};
 
 pub struct VortexFileOpener {
     pub ctx: Arc<Context>,
@@ -39,9 +37,6 @@ impl FileOpener for VortexFileOpener {
             builder = builder.with_batch_size(batch_size);
         }
 
-        let predicate_projection =
-            extract_columns_from_expr(self.predicate.as_ref(), self.arrow_schema.clone())?;
-
         if let Some(predicate) = self
             .predicate
             .clone()
@@ -52,34 +47,17 @@
         }
 
         if let Some(projection) = self.projection.as_ref() {
-            let mut projection = projection.clone();
-            for col_idx in predicate_projection.into_iter() {
-                if !projection.contains(&col_idx) {
-                    projection.push(col_idx);
-                }
-            }
-
             builder = builder.with_projection(Projection::new(projection))
         }
 
-        let original_projection_len = self.projection.as_ref().map(|v| v.len());
-
-        Ok(async move {
-            let reader = builder.build().await?;
-
-            let stream = reader
-                .and_then(move |array| async move {
-                    let rb = RecordBatch::from(array);
-
-                    // If we had a projection, we cut the record batch down to the desired columns
-                    if let Some(len) = original_projection_len {
-                        Ok(rb.project(&(0..len).collect_vec())?)
-                    } else {
-                        Ok(rb)
-                    }
-                })
-                .map_err(|e| e.into());
-            Ok(Box::pin(stream) as _)
+        Ok(async {
+            Ok(Box::pin(
+                builder
+                    .build()
+                    .await?
+                    .map_ok(RecordBatch::from)
+                    .map_err(|e| e.into()),
+            ) as _)
         }
         .boxed())
     }
diff --git a/vortex-datafusion/src/plans.rs b/vortex-datafusion/src/plans.rs
index 89dae336de..8ac6b3f852 100644
--- a/vortex-datafusion/src/plans.rs
+++ b/vortex-datafusion/src/plans.rs
@@ -9,10 +9,10 @@ use std::task::{Context, Poll};
 use arrow_array::cast::AsArray;
 use arrow_array::types::UInt64Type;
 use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions, UInt64Array};
-use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use arrow_schema::{DataType, Schema, SchemaRef};
 use datafusion_common::{DataFusionError, Result as DFResult};
 use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
-use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr};
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
 use datafusion_physical_plan::{
     DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, PlanProperties,
 };
@@ -23,8 +23,8 @@ use vortex::array::ChunkedArray;
 use vortex::arrow::FromArrowArray;
 use vortex::compute::take;
 use vortex::{Array, AsArray as _, IntoArray, IntoArrayVariant, IntoCanonical};
+use vortex_dtype::field::Field;
 use vortex_error::vortex_err;
-use vortex_expr::datafusion::convert_expr_to_vortex;
 use vortex_expr::VortexExpr;
 
 /// Physical plan operator that applies a set of [filters][Expr] against the input, producing a
@@ -32,7 +32,6 @@ use vortex_expr::VortexExpr;
 /// chunks but for different columns.
 pub(crate) struct RowSelectorExec {
     filter_expr: Arc<dyn VortexExpr>,
-    filter_projection: Vec<usize>,
     /// cached PlanProperties object. We do not make use of this.
     cached_plan_props: PlanProperties,
     /// Full array. We only access partitions of this data.
 }
 
 lazy_static! {
-    static ref ROW_SELECTOR_SCHEMA_REF: SchemaRef = Arc::new(Schema::new(vec![Field::new(
-        "row_idx",
-        DataType::UInt64,
-        false
-    )]));
+    static ref ROW_SELECTOR_SCHEMA_REF: SchemaRef =
+        Arc::new(Schema::new(vec![arrow_schema::Field::new(
+            "row_idx",
+            DataType::UInt64,
+            false
+        )]));
 }
 
 impl RowSelectorExec {
     pub(crate) fn try_new(
-        filter_expr: Arc<dyn PhysicalExpr>,
-        filter_projection: Vec<usize>,
+        filter_expr: Arc<dyn VortexExpr>,
         chunked_array: &ChunkedArray,
     ) -> DFResult<Self> {
         let cached_plan_props = PlanProperties::new(
@@ -59,11 +58,8 @@ impl RowSelectorExec {
             ExecutionMode::Bounded,
         );
 
-        let filter_expr = convert_expr_to_vortex(filter_expr)?;
-
         Ok(Self {
             filter_expr,
-            filter_projection: filter_projection.clone(),
             chunked_array: chunked_array.clone(),
             cached_plan_props,
         })
@@ -130,7 +126,7 @@ impl ExecutionPlan for RowSelectorExec {
         Ok(Box::pin(RowIndicesStream {
             chunked_array: self.chunked_array.clone(),
             chunk_idx: 0,
-            filter_projection: self.filter_projection.clone(),
+            filter_projection: self.filter_expr.references().iter().cloned().collect(),
             conjunction_expr: self.filter_expr.clone(),
         }))
     }
@@ -141,7 +137,7 @@ pub(crate) struct RowIndicesStream {
     chunked_array: ChunkedArray,
     chunk_idx: usize,
     conjunction_expr: Arc<dyn VortexExpr>,
-    filter_projection: Vec<usize>,
+    filter_projection: Vec<Field>,
 }
 
 impl Stream for RowIndicesStream {
@@ -166,7 +162,7 @@ impl Stream for RowIndicesStream {
         let vortex_struct = next_chunk
             .into_struct()
             .expect("chunks must be StructArray")
-            .project(this.filter_projection.as_slice())
+            .project(&this.filter_projection)
             .expect("projection should succeed");
 
         let selection = this
@@ -203,7 +199,7 @@ pub(crate) struct TakeRowsExec {
     plan_properties: PlanProperties,
 
     // Array storing the indices used to take the plan nodes.
-    projection: Vec<usize>,
+    projection: Vec<Field>,
 
     // Input plan, a stream of indices on which we perform a take against the original dataset.
     input: Arc<dyn ExecutionPlan>,
 
@@ -230,7 +226,7 @@ impl TakeRowsExec {
         Self {
             plan_properties,
-            projection: projection.to_owned(),
+            projection: projection.iter().copied().map(Field::from).collect(),
             input: row_indices,
             output_schema: output_schema.clone(),
             table: table.clone(),
         }
@@ -309,7 +305,7 @@ pub(crate) struct TakeRowsStream {
     chunk_idx: usize,
 
     // Projection based on the schema here
-    output_projection: Vec<usize>,
+    output_projection: Vec<Field>,
     output_schema: SchemaRef,
 
     // The original Vortex array we're taking from
@@ -402,6 +398,7 @@ mod test {
     use vortex::array::{BoolArray, ChunkedArray, PrimitiveArray, StructArray};
     use vortex::validity::Validity;
     use vortex::{ArrayDType, IntoArray};
+    use vortex_dtype::field::Field;
     use vortex_dtype::FieldName;
     use vortex_expr::datafusion::convert_expr_to_vortex;
 
@@ -440,7 +437,7 @@ mod test {
             chunked_array: chunked_array.clone(),
             chunk_idx: 0,
             conjunction_expr: convert_expr_to_vortex(df_expr).unwrap(),
-            filter_projection: vec![0, 1],
+            filter_projection: vec![Field::from(0), Field::from(1)],
         };
 
         let rows: Vec<RecordBatch> = futures::executor::block_on_stream(filtering_stream)
diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs
index 68aa9e95d8..c513604962 100644
--- a/vortex-dtype/src/dtype.rs
+++ b/vortex-dtype/src/dtype.rs
@@ -3,9 +3,10 @@ use std::hash::Hash;
 use std::sync::Arc;
 
 use itertools::Itertools;
-use vortex_error::{vortex_bail, VortexResult};
+use vortex_error::{vortex_bail, vortex_err, VortexResult};
 use DType::*;
 
+use crate::field::Field;
 use crate::nullability::Nullability;
 use crate::{ExtDType, PType};
 
@@ -180,14 +181,22 @@ impl StructDType {
         &self.dtypes
     }
 
-    pub fn project(&self, indices: &[usize]) -> VortexResult<Self> {
-        let mut names = vec![];
-        let mut dtypes = vec![];
-
-        for &idx in indices.iter() {
-            if idx > self.names.len() {
-                vortex_bail!("Projection column is out of bounds");
-            }
+    pub fn project(&self, projection: &[Field]) -> VortexResult<Self> {
+        let mut names = Vec::with_capacity(projection.len());
+        let mut dtypes = Vec::with_capacity(projection.len());
+
+        for field in projection.iter() {
+            let idx = match field {
+                Field::Name(n) => self
+                    .find_name(n.as_ref())
+                    .ok_or_else(|| vortex_err!("Unknown field {n}"))?,
+                Field::Index(i) => {
+                    if *i > self.names.len() {
+                        vortex_bail!("Projection column is out of bounds");
+                    }
+                    *i
+                }
+            };
 
             names.push(self.names[idx].clone());
             dtypes.push(self.dtypes[idx].clone());
diff --git a/vortex-dtype/src/field.rs b/vortex-dtype/src/field.rs
index 80afc0c76e..af18979bd4 100644
--- a/vortex-dtype/src/field.rs
+++ b/vortex-dtype/src/field.rs
@@ -1,11 +1,13 @@
 use core::fmt;
 use std::fmt::{Display, Formatter};
 
+use itertools::Itertools;
+
 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
 pub enum Field {
     Name(String),
-    Index(i32),
+    Index(usize),
 }
 
 impl From<&str> for Field {
@@ -14,8 +16,14 @@ impl From<&str> for Field {
     }
 }
 
-impl From<i32> for Field {
-    fn from(value: i32) -> Self {
+impl From<String> for Field {
+    fn from(value: String) -> Self {
+        Field::Name(value)
+    }
+}
+
+impl From<usize> for Field {
+    fn from(value: usize) -> Self {
         Field::Index(value)
     }
 }
@@ -38,8 +46,8 @@ impl FieldPath {
         Self(vec![])
     }
 
-    pub fn from_name(name: &str) -> Self {
-        Self(vec![Field::from(name)])
+    pub fn from_name<F: Into<Field>>(name: F) -> Self {
+        Self(vec![name.into()])
     }
 
     pub fn path(&self) -> &[Field] {
@@ -75,12 +83,6 @@ impl From<Vec<Field>> for FieldPath {
 
 impl Display for FieldPath {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        let formatted = self
-            .0
-            .iter()
-            .map(|fid| format!("{fid}"))
-            .collect::<Vec<_>>()
-            .join(".");
-        write!(f, "{}", formatted)
+        Display::fmt(&self.0.iter().format("."), f)
     }
 }
diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs
index 59a81d2909..424760c4e0 100644
--- a/vortex-dtype/src/lib.rs
+++ b/vortex-dtype/src/lib.rs
@@ -4,12 +4,14 @@ pub use dtype::*;
 pub use extension::*;
 pub use half;
 pub use nullability::*;
+pub use project::*;
 pub use ptype::*;
 
 mod dtype;
 mod extension;
 pub mod field;
 mod nullability;
+mod project;
 mod ptype;
 mod serde;
diff --git a/vortex-dtype/src/project.rs b/vortex-dtype/src/project.rs
new file mode 100644
index 0000000000..8f696b1a5f
--- /dev/null
+++ b/vortex-dtype/src/project.rs
@@ -0,0 +1,61 @@
+use std::sync::Arc;
+
+use vortex_error::{vortex_err, VortexResult};
+
+use crate::field::Field;
+use crate::{flatbuffers as fb, DType, StructDType};
+
+/// Convert name references in projection list into index references.
+///
+/// This is mostly useful if you want to deduplicate multiple projections against serialized schema.
+pub fn resolve_field_references<'a, 'b: 'a>(
+    fb: fb::Struct_<'b>,
+    projection: &'a [Field],
+) -> impl Iterator<Item = VortexResult<usize>> + 'a {
+    projection.iter().map(move |field| match field {
+        Field::Name(n) => {
+            let names = fb
+                .names()
+                .ok_or_else(|| vortex_err!("Missing field names"))?;
+            names
+                .iter()
+                .position(|name| name == n)
+                .ok_or_else(|| vortex_err!("Unknown field name {n}"))
+        }
+        Field::Index(i) => Ok(*i),
+    })
+}
+
+/// Deserialize flatbuffer schema selecting only columns defined by projection
+pub fn deserialize_and_project(fb: fb::DType<'_>, projection: &[Field]) -> VortexResult<DType> {
+    let fb_struct = fb
+        .type__as_struct_()
+        .ok_or_else(|| vortex_err!("The top-level type should be a struct"))?;
+    let nullability = fb_struct.nullable().into();
+
+    let (names, dtypes): (Vec<Arc<str>>, Vec<DType>) =
+        resolve_field_references(fb_struct, projection)
+            .map(|idx| idx.and_then(|i| read_field(fb_struct, i)))
+            .collect::<VortexResult<Vec<_>>>()?
+            .into_iter()
+            .unzip();
+
+    Ok(DType::Struct(
+        StructDType::new(names.into(), dtypes),
+        nullability,
+    ))
+}
+
+fn read_field(fb_struct: fb::Struct_, idx: usize) -> VortexResult<(Arc<str>, DType)> {
+    let name = fb_struct
+        .names()
+        .ok_or_else(|| vortex_err!("Missing field names"))?
+        .get(idx);
+    let fb_dtype = fb_struct
+        .dtypes()
+        .ok_or_else(|| vortex_err!("Missing field dtypes"))?
+        .get(idx);
+    let dtype = DType::try_from(fb_dtype)?;
+
+    Ok((name.into(), dtype))
+}
diff --git a/vortex-dtype/src/serde/proto.rs b/vortex-dtype/src/serde/proto.rs
index 839235474c..6ffa8ae89c 100644
--- a/vortex-dtype/src/serde/proto.rs
+++ b/vortex-dtype/src/serde/proto.rs
@@ -143,7 +143,7 @@ impl TryFrom<&pb::FieldPath> for FieldPath {
             .ok_or_else(|| vortex_err!(InvalidSerde: "FieldPath part missing type"))?
         {
             FieldType::Name(name) => path.push(Field::from(name.as_str())),
-            FieldType::Index(idx) => path.push(Field::from(*idx)),
+            FieldType::Index(idx) => path.push(Field::from(*idx as usize)),
         }
     }
     Ok(FieldPath::from(path))
diff --git a/vortex-expr/src/datafusion.rs b/vortex-expr/src/datafusion.rs
index 5bd9d915fc..764a9b6404 100644
--- a/vortex-expr/src/datafusion.rs
+++ b/vortex-expr/src/datafusion.rs
@@ -1,11 +1,7 @@
 #![cfg(feature = "datafusion")]
 
-use std::collections::HashSet;
 use std::sync::Arc;
 
-use datafusion_common::arrow::datatypes::SchemaRef;
-use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
-use datafusion_common::{DataFusionError, Result as DFResult};
 use datafusion_expr::Operator as DFOperator;
 use datafusion_physical_expr::PhysicalExpr;
 use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
@@ -56,38 +52,6 @@ pub fn convert_expr_to_vortex(
     vortex_bail!("Couldn't convert DataFusion physical expression to a vortex expression")
 }
 
-/// Extract all indexes of all columns referenced by the physical expressions from the schema
-pub fn extract_columns_from_expr(
-    expr: Option<&Arc<dyn PhysicalExpr>>,
-    schema_ref: SchemaRef,
-) -> DFResult<HashSet<usize>> {
-    let mut predicate_projection = HashSet::new();
-
-    if let Some(expr) = expr {
-        expr.apply(|expr| {
-            if let Some(column) = expr
-                .as_any()
-                .downcast_ref::<datafusion_physical_expr::expressions::Column>()
-            {
-                match schema_ref.column_with_name(column.name()) {
-                    Some(_) => {
-                        predicate_projection.insert(column.index());
-                    }
-                    None => {
-                        return Err(DataFusionError::External(
-                            format!("Could not find expected column {} in schema", column.name())
-                                .into(),
-                        ))
-                    }
-                }
-            }
-            Ok(TreeNodeRecursion::Continue)
-        })?;
-    }
-
-    Ok(predicate_projection)
-}
-
 impl TryFrom<DFOperator> for Operator {
     type Error = VortexError;
 
diff --git a/vortex-expr/src/display.rs b/vortex-expr/src/display.rs
deleted file mode 100644
index d4883a7c5b..0000000000
--- a/vortex-expr/src/display.rs
+++ /dev/null
@@ -1,19 +0,0 @@
-use core::fmt;
-use std::fmt::{Display, Formatter};
-
-use crate::expressions::{Predicate, Value};
-
-impl Display for Predicate {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        write!(f, "({} {} {})", self.lhs, self.op, self.rhs)
-    }
-}
-
-impl Display for Value {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        match self {
-            Value::Field(field_path) => Display::fmt(field_path, f),
-            Value::Literal(scalar) => Display::fmt(&scalar, f),
-        }
-    }
-}
diff --git a/vortex-expr/src/expr.rs b/vortex-expr/src/expr.rs
index cfaf4f5bad..b80b4dd6b9 100644
--- a/vortex-expr/src/expr.rs
+++ b/vortex-expr/src/expr.rs
@@ -1,3 +1,4 @@
+use std::collections::HashSet;
 use std::fmt::Debug;
 use std::sync::Arc;
 
@@ -5,6 +6,7 @@ use vortex::array::{ConstantArray, StructArray};
 use vortex::compute::{compare, Operator as ArrayOperator};
 use vortex::variants::StructArrayTrait;
 use vortex::{Array, IntoArray};
+use vortex_dtype::field::Field;
 use vortex_error::{vortex_bail, vortex_err, VortexResult};
 use vortex_scalar::Scalar;
 
@@ -12,6 +14,8 @@ use crate::Operator;
 
 pub trait VortexExpr: Debug + Send + Sync {
     fn evaluate(&self, array: &Array) -> VortexResult<Array>;
+
+    fn references(&self) -> HashSet<Field>;
 }
 
 #[derive(Debug)]
@@ -36,12 +40,14 @@ impl BinaryExpr {
 
 #[derive(Debug)]
 pub struct Column {
-    name: String,
+    field: Field,
 }
 
 impl Column {
-    pub fn new(name: String) -> Self {
-        Self { name }
+    pub fn new(field: String) -> Self {
+        Self {
+            field: Field::from(field),
+        }
     }
 }
 
@@ -49,13 +55,17 @@ impl VortexExpr for Column {
     fn evaluate(&self, array: &Array) -> VortexResult<Array> {
         let s = StructArray::try_from(array)?;
 
-        let column = s.field_by_name(&self.name).ok_or(vortex_err!(
-            "Array doesn't contain child array of name {}",
-            self.name
-        ))?;
-
+        let column = match &self.field {
+            Field::Name(n) => s.field_by_name(n),
+            Field::Index(i) => s.field(*i),
+        }
+        .ok_or_else(|| vortex_err!("Array doesn't contain child array {}", self.field))?;
         Ok(column)
     }
+
+    fn references(&self) -> HashSet<Field> {
+        HashSet::from([self.field.clone()])
+    }
 }
 
 #[derive(Debug)]
@@ -73,6 +83,10 @@ impl VortexExpr for Literal {
     fn evaluate(&self, array: &Array) -> VortexResult<Array> {
         Ok(ConstantArray::new(self.value.clone(), array.len()).into_array())
     }
+
+    fn references(&self) -> HashSet<Field> {
+        HashSet::new()
+    }
 }
 
 impl VortexExpr for BinaryExpr {
@@ -93,10 +107,20 @@ impl VortexExpr for BinaryExpr {
 
         Ok(array)
     }
+
+    fn references(&self) -> HashSet<Field> {
+        let mut res = self.left.references();
+        res.extend(self.right.references());
+        res
+    }
 }
 
 impl VortexExpr for NoOp {
     fn evaluate(&self, _array: &Array) -> VortexResult<Array> {
         vortex_bail!("NoOp::evaluate() should not be called")
     }
+
+    fn references(&self) -> HashSet<Field> {
+        HashSet::new()
+    }
 }
diff --git a/vortex-expr/src/expressions.rs b/vortex-expr/src/expressions.rs
deleted file mode 100644
index ae0dcd26c8..0000000000
--- a/vortex-expr/src/expressions.rs
+++ /dev/null
@@ -1,195 +0,0 @@
-use core::fmt;
-use std::fmt::{Display, Formatter};
-
-use vortex_dtype::field::FieldPath;
-use vortex_scalar::Scalar;
-
-use crate::operators::Operator;
-
-#[derive(Clone, Debug, Default, PartialEq)]
-#[cfg_attr(
-    feature = "serde",
-    derive(serde::Serialize, serde::Deserialize),
-    serde(transparent)
-)]
-pub struct Disjunction {
-    conjunctions: Vec<Conjunction>,
-}
-
-impl Disjunction {
-    pub fn iter(&self) -> impl Iterator<Item = &Conjunction> {
-        self.conjunctions.iter()
-    }
-}
-
-impl From<Conjunction> for Disjunction {
-    fn from(value: Conjunction) -> Self {
-        Self {
-            conjunctions: vec![value],
-        }
-    }
-}
-
-impl FromIterator<Predicate> for Disjunction {
-    fn from_iter<T: IntoIterator<Item = Predicate>>(iter: T) -> Self {
-        Self {
-            conjunctions: iter
-                .into_iter()
-                .map(|predicate| Conjunction::from_iter([predicate]))
-                .collect(),
-        }
-    }
-}
-
-impl FromIterator<Conjunction> for Disjunction {
-    fn from_iter<T: IntoIterator<Item = Conjunction>>(iter: T) -> Self {
-        Self {
-            conjunctions: iter.into_iter().collect(),
-        }
-    }
-}
-
-impl Display for Disjunction {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        self.conjunctions
-            .iter()
-            .map(|v| format!("{}", v))
-            .intersperse("\nOR \n".to_string())
-            .try_for_each(|s| write!(f, "{}", s))
-    }
-}
-
-#[derive(Clone, Debug, Default, PartialEq)]
-#[cfg_attr(
-    feature = "serde",
-    derive(serde::Serialize, serde::Deserialize),
-    serde(transparent)
-)]
-pub struct Conjunction {
-    predicates: Vec<Predicate>,
-}
-
-impl Conjunction {
-    pub fn iter(&self) -> impl Iterator<Item = &Predicate> {
-        self.predicates.iter()
-    }
-}
-
-impl From<Predicate> for Conjunction {
-    fn from(value: Predicate) -> Self {
-        Self {
-            predicates: vec![value],
-        }
-    }
-}
-
-impl FromIterator<Predicate> for Conjunction {
-    fn from_iter<T: IntoIterator<Item = Predicate>>(iter: T) -> Self {
-        Self {
-            predicates: iter.into_iter().collect(),
-        }
-    }
-}
-
-impl Display for Conjunction {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        self.predicates
-            .iter()
-            .map(|v| format!("{}", v))
-            .intersperse(" AND ".to_string())
-            .try_for_each(|s| write!(f, "{}", s))
-    }
-}
-
-#[derive(Clone, Debug, PartialEq)]
-#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
-pub enum Value {
-    /// A named reference to a qualified field in a dtype.
-    Field(FieldPath),
-    /// A constant scalar value.
-    Literal(Scalar),
-}
-
-#[derive(Clone, Debug, PartialEq)]
-#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
-pub struct Predicate {
-    pub lhs: FieldPath,
-    pub op: Operator,
-    pub rhs: Value,
-}
-
-pub fn lit<T: Into<Scalar>>(n: T) -> Value {
-    Value::Literal(n.into())
-}
-
-impl Value {
-    // NB: We rewrite predicates to be Field-op-predicate, so these methods all must
-    // use the inverse operator.
-    pub fn equals(self, field: impl Into<FieldPath>) -> Predicate {
-        Predicate {
-            lhs: field.into(),
-            op: Operator::Eq,
-            rhs: self,
-        }
-    }
-
-    pub fn not_equals(self, field: impl Into<FieldPath>) -> Predicate {
-        Predicate {
-            lhs: field.into(),
-            op: Operator::NotEq,
-            rhs: self,
-        }
-    }
-
-    pub fn gt(self, field: impl Into<FieldPath>) -> Predicate {
-        Predicate {
-            lhs: field.into(),
-            op: Operator::Gt.inverse().unwrap(),
-            rhs: self,
-        }
-    }
-
-    pub fn gte(self, field: impl Into<FieldPath>) -> Predicate {
-        Predicate {
-            lhs: field.into(),
-            op: Operator::Gte.inverse().unwrap(),
-            rhs: self,
-        }
-    }
-
-    pub fn lt(self, field: impl Into<FieldPath>) -> Predicate {
-        Predicate {
-            lhs: field.into(),
-            op: Operator::Lt.inverse().unwrap(),
-            rhs: self,
-        }
-    }
-
-    pub fn lte(self, field: impl Into<FieldPath>) -> Predicate {
-        Predicate {
-            lhs: field.into(),
-            op: Operator::Lte.inverse().unwrap(),
-            rhs: self,
-        }
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use vortex_dtype::field::Field;
-
-    use super::*;
-
-    #[test]
-    fn test_lit() {
-        let scalar: Scalar = 1.into();
-        let value: Value = lit(scalar);
-        let field = Field::from("id");
-        let expr = Predicate {
-            lhs: FieldPath::from_iter([field]),
-            op: Operator::Eq,
-            rhs: value,
-        };
-        assert_eq!(format!("{}", expr), "($id = 1)");
-    }
-}
diff --git a/vortex-expr/src/field_paths.rs b/vortex-expr/src/field_paths.rs
deleted file mode 100644
index a21267b02d..0000000000
--- a/vortex-expr/src/field_paths.rs
+++ /dev/null
@@ -1,64 +0,0 @@
-use vortex_dtype::field::FieldPath;
-
-use crate::expressions::{Predicate, Value};
-use crate::operators::Operator;
-
-pub trait FieldPathOperations {
-    fn equal(&self, other: Value) -> Predicate;
-    fn not_equal(&self, other: Value) -> Predicate;
-    fn gt(&self, other: Value) -> Predicate;
-    fn gte(&self, other: Value) -> Predicate;
-    fn lt(&self, other: Value) -> Predicate;
-    fn lte(&self, other: Value) -> Predicate;
-}
-
-impl FieldPathOperations for FieldPath {
-    // comparisons
-    fn equal(&self, other: Value) -> Predicate {
-        Predicate {
-            lhs: self.clone(),
-            op: Operator::Eq,
-            rhs: other,
-        }
-    }
-
-    fn not_equal(&self, other: Value) -> Predicate {
-        Predicate {
-            lhs: self.clone(),
-            op: Operator::NotEq,
-            rhs: other,
-        }
-    }
-
-    fn gt(&self, other: Value) -> Predicate {
-        Predicate {
-            lhs: self.clone(),
-            op: Operator::Gt,
-            rhs: other,
-        }
-    }
-
-    fn gte(&self, other: Value) -> Predicate {
-        Predicate {
-            lhs: self.clone(),
-            op: Operator::Gte,
-            rhs: other,
-        }
-    }
-
-    fn lt(&self, other: Value) -> Predicate {
-        Predicate {
-            lhs: self.clone(),
-            op: Operator::Lt,
-            rhs: other,
-        }
-    }
-
-    fn lte(&self, other: Value) -> Predicate {
-        Predicate {
-            lhs: self.clone(),
-            op: Operator::Lte,
-            rhs: other,
-        }
-    }
-}
diff --git a/vortex-expr/src/lib.rs b/vortex-expr/src/lib.rs
index 4e158430bd..2b2174f83a 100644
--- a/vortex-expr/src/lib.rs
+++ b/vortex-expr/src/lib.rs
@@ -1,16 +1,8 @@
 #![feature(iter_intersperse)]
 
 pub mod datafusion;
-mod display;
 mod expr;
-mod expressions;
-mod field_paths;
 mod operators;
 
-#[cfg(all(feature = "proto", feature = "serde"))]
-mod serde_proto;
-
 pub use expr::*;
-pub use expressions::*;
-pub use field_paths::*;
 pub use operators::*;
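Note: with `expressions.rs`, `field_paths.rs`, and the `Predicate`/`Value` tree deleted, filter columns are now discovered through `VortexExpr::references()` rather than `extract_columns_from_expr`. Below is a minimal sketch of the new API; only `Column::new` is visible in this diff, so the `BinaryExpr::new` and `Literal::new` constructor shapes are assumptions:

```rust
use std::collections::HashSet;
use std::sync::Arc;

use vortex_dtype::field::Field;
use vortex_expr::{BinaryExpr, Column, Literal, Operator, VortexExpr};

// `a = 1` as a vortex expression tree. `references()` reports which fields a
// filter touches, so readers can widen their projection before evaluating it.
fn filter_a_eq_1() -> Arc<dyn VortexExpr> {
    Arc::new(BinaryExpr::new(             // assumed constructor
        Arc::new(Column::new("a".to_string())), // resolves by Field::Name
        Operator::Eq,
        Arc::new(Literal::new(1.into())), // assumed constructor; wraps a Scalar
    ))
}

fn main() {
    let expr = filter_a_eq_1();
    let fields: HashSet<Field> = expr.references();
    assert!(fields.contains(&Field::Name("a".to_string())));
}
```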
diff --git a/vortex-expr/src/serde_proto.rs b/vortex-expr/src/serde_proto.rs
deleted file mode 100644
index a2bdc9fe22..0000000000
--- a/vortex-expr/src/serde_proto.rs
+++ /dev/null
@@ -1,48 +0,0 @@
-#![cfg(feature = "proto")]
-
-use vortex_error::{vortex_bail, vortex_err, VortexError};
-use vortex_proto::expr as pb;
-use vortex_proto::expr::predicate::Rhs;
-
-use crate::{Operator, Predicate, Value};
-
-impl TryFrom<&pb::Predicate> for Predicate {
-    type Error = VortexError;
-
-    fn try_from(value: &pb::Predicate) -> Result<Self, Self::Error> {
-        Ok(Predicate {
-            lhs: value
-                .lhs
-                .as_ref()
-                .ok_or_else(|| vortex_err!(InvalidSerde: "Lhs is missing"))?
-                .try_into()?,
-            op: value.op().try_into()?,
-            rhs: match value
-                .rhs
-                .as_ref()
-                .ok_or_else(|| vortex_err!(InvalidSerde: "Rhs is missing"))?
-            {
-                Rhs::Field(f) => Value::Field(f.try_into()?),
-                Rhs::Scalar(scalar) => Value::Literal(scalar.try_into()?),
-            },
-        })
-    }
-}
-
-impl TryFrom<pb::Operator> for Operator {
-    type Error = VortexError;
-
-    fn try_from(value: pb::Operator) -> Result<Self, Self::Error> {
-        match value {
-            pb::Operator::Unknown => {
-                vortex_bail!(InvalidSerde: "Unknown operator {}", value.as_str_name())
-            }
-            pb::Operator::Eq => Ok(Self::Eq),
-            pb::Operator::Neq => Ok(Self::NotEq),
-            pb::Operator::Lt => Ok(Self::Lt),
-            pb::Operator::Lte => Ok(Self::Lte),
-            pb::Operator::Gt => Ok(Self::Gt),
-            pb::Operator::Gte => Ok(Self::Gte),
-        }
-    }
-}
diff --git a/vortex-flatbuffers/src/generated/array.rs b/vortex-flatbuffers/src/generated/array.rs
index 857cc2aac9..7ba6ba97c7 100644
--- a/vortex-flatbuffers/src/generated/array.rs
+++ b/vortex-flatbuffers/src/generated/array.rs
@@ -54,7 +54,7 @@ impl<'a> flatbuffers::Follow<'a> for Version {
   type Inner = Self;
   #[inline]
   unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
-    let b = unsafe { flatbuffers::read_scalar_at::<i16>(buf, loc) };
+    let b = flatbuffers::read_scalar_at::<i16>(buf, loc);
     Self(b)
   }
 }
@@ -63,7 +63,7 @@ impl flatbuffers::Push for Version {
     type Output = Version;
     #[inline]
     unsafe fn push(&self, dst: &mut [u8], _written_len: usize) {
-        unsafe { flatbuffers::emplace_scalar::<i16>(dst, self.0) };
+        flatbuffers::emplace_scalar::<i16>(dst, self.0);
     }
 }
 
@@ -103,7 +103,7 @@ impl<'a> flatbuffers::Follow<'a> for Array<'a> {
   type Inner = Array<'a>;
   #[inline]
   unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
-    Self { _tab: unsafe { flatbuffers::Table::new(buf, loc) } }
+    Self { _tab: flatbuffers::Table::new(buf, loc) }
   }
 }
 
@@ -285,7 +285,7 @@ impl<'a> flatbuffers::Follow<'a> for ArrayStats<'a> {
  type Inner = ArrayStats<'a>;
  #[inline]
  unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
-    Self { _tab: unsafe { flatbuffers::Table::new(buf, loc) } }
+    Self { _tab: flatbuffers::Table::new(buf, loc) }
  }
 }
 
@@ -575,17 +575,15 @@ pub fn size_prefixed_root_as_array_with_opts<'b, 'o>(
 /// # Safety
 /// Callers must trust the given bytes do indeed contain a valid `Array`.
 pub unsafe fn root_as_array_unchecked(buf: &[u8]) -> Array {
-  unsafe { flatbuffers::root_unchecked::<Array>(buf) }
+  flatbuffers::root_unchecked::<Array>(buf)
 }
-
 #[inline]
 /// Assumes, without verification, that a buffer of bytes contains a size prefixed Array and returns it.
 /// # Safety
 /// Callers must trust the given bytes do indeed contain a valid size prefixed `Array`.
 pub unsafe fn size_prefixed_root_as_array_unchecked(buf: &[u8]) -> Array {
-  unsafe { flatbuffers::size_prefixed_root_unchecked::<Array>(buf) }
+  flatbuffers::size_prefixed_root_unchecked::<Array>(buf)
 }
-
 #[inline]
 pub fn finish_array_buffer<'a, 'b, A: flatbuffers::Allocator + 'a>(
   fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>,
diff --git a/vortex-flatbuffers/src/lib.rs b/vortex-flatbuffers/src/lib.rs
index 53f571a098..9f16b5fa5a 100644
--- a/vortex-flatbuffers/src/lib.rs
+++ b/vortex-flatbuffers/src/lib.rs
@@ -58,9 +58,6 @@ pub mod footer;
 #[path = "./generated/message.rs"]
 pub mod message;
 
-use std::io;
-use std::io::Write;
-
 use flatbuffers::{root, FlatBufferBuilder, Follow, InvalidFlatbuffer, Verifiable, WIPOffset};
 
 pub trait FlatBufferRoot {}
@@ -103,34 +100,3 @@ impl<F: FlatBufferRoot + WriteFlatBuffer> FlatBufferToBytes for F {
         f(fbb.finished_data())
     }
 }
-
-pub trait FlatBufferWriter {
-    // Write the given FlatBuffer message, appending padding until the total bytes written
-    // are a multiple of `alignment`.
-    fn write_message<F: FlatBufferRoot + WriteFlatBuffer>(
-        &mut self,
-        msg: &F,
-        alignment: usize,
-    ) -> io::Result<()>;
-}
-
-impl<W: Write> FlatBufferWriter for W {
-    fn write_message<F: FlatBufferRoot + WriteFlatBuffer>(
-        &mut self,
-        msg: &F,
-        alignment: usize,
-    ) -> io::Result<()> {
-        let mut fbb = FlatBufferBuilder::new();
-        let root = msg.write_flatbuffer(&mut fbb);
-        fbb.finish_minimal(root);
-        let fb_data = fbb.finished_data();
-        let fb_size = fb_data.len();
-
-        let aligned_size = (fb_size + (alignment - 1)) & !(alignment - 1);
-        let padding_bytes = aligned_size - fb_size;
-
-        self.write_all(&(aligned_size as u32).to_le_bytes())?;
-        self.write_all(fb_data)?;
-        self.write_all(&vec![0; padding_bytes])
-    }
-}
diff --git a/vortex-proto/proto/dtype.proto b/vortex-proto/proto/dtype.proto
index 1b3fab7dda..d4fe588428 100644
--- a/vortex-proto/proto/dtype.proto
+++ b/vortex-proto/proto/dtype.proto
@@ -75,7 +75,7 @@ message DType {
 message Field {
   oneof field_type {
     string name = 1;
-    int32 index = 2;
+    uint64 index = 2;
   }
 }
 
diff --git a/vortex-proto/proto/expr.proto b/vortex-proto/proto/expr.proto
deleted file mode 100644
index de5ec9ebd0..0000000000
--- a/vortex-proto/proto/expr.proto
+++ /dev/null
@@ -1,33 +0,0 @@
-syntax = "proto3";
-
-package vortex.expr;
-
-import "dtype.proto";
-import "scalar.proto";
-
-message Disjunction {
-  repeated Conjunction conjunctions = 1;
-}
-
-message Conjunction {
-  repeated Predicate predicates = 1;
-}
-
-message Predicate {
-  vortex.dtype.FieldPath lhs = 1;
-  Operator op = 2;
-  oneof rhs {
-    vortex.dtype.FieldPath field = 3;
-    vortex.scalar.Scalar scalar = 4;
-  }
-}
-
-enum Operator {
-  UNKNOWN = 0;
-  EQ = 1;
-  NEQ = 2;
-  LT = 3;
-  LTE = 4;
-  GT = 5;
-  GTE = 6;
-}
diff --git a/vortex-proto/src/generated/vortex.dtype.rs b/vortex-proto/src/generated/vortex.dtype.rs
index ff74a77624..9a3a956007 100644
--- a/vortex-proto/src/generated/vortex.dtype.rs
+++ b/vortex-proto/src/generated/vortex.dtype.rs
@@ -110,8 +110,8 @@ pub mod field {
     pub enum FieldType {
         #[prost(string, tag = "1")]
         Name(::prost::alloc::string::String),
-        #[prost(int32, tag = "2")]
-        Index(i32),
+        #[prost(uint64, tag = "2")]
+        Index(u64),
     }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
diff --git a/vortex-serde/benches/ipc_array_reader_take.rs b/vortex-serde/benches/ipc_array_reader_take.rs
index 98c3594807..7a649e72b0 100644
--- a/vortex-serde/benches/ipc_array_reader_take.rs
+++ b/vortex-serde/benches/ipc_array_reader_take.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
 use criterion::async_executor::FuturesExecutor;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 use futures_executor::block_on;
+use futures_util::io::Cursor;
 use futures_util::{pin_mut, TryStreamExt};
 use itertools::Itertools;
 use vortex::array::{ChunkedArray, PrimitiveArray};
@@ -10,8 +11,8 @@ use vortex::stream::ArrayStreamExt;
 use vortex::validity::Validity;
 use vortex::{Context, IntoArray};
 use vortex_serde::io::FuturesAdapter;
-use vortex_serde::writer::ArrayWriter;
-use vortex_serde::MessageReader;
+use vortex_serde::stream_reader::StreamArrayReader;
+use vortex_serde::stream_writer::StreamArrayWriter;
 
 // 100 record batches, 100k rows each
 // take from the first 20 batches and last batch
@@ -33,21 +34,22 @@ fn ipc_array_reader_take(c: &mut Criterion) {
             )
             .into_array();
 
-            let buffer = block_on(async { ArrayWriter::new(vec![]).write_array(array).await })
+            let buffer = block_on(async { StreamArrayWriter::new(vec![]).write_array(array).await })
                 .unwrap()
                 .into_inner();
 
             let indices = indices.clone().into_array();
 
             b.to_async(FuturesExecutor).iter(|| async {
-                let mut cursor = futures_util::io::Cursor::new(&buffer);
-                let mut msgs = MessageReader::try_new(FuturesAdapter(&mut cursor))
-                    .await
-                    .unwrap();
-                let stream = msgs
-                    .array_stream_from_messages(ctx.clone())
-                    .await
-                    .unwrap()
+                let stream_reader =
+                    StreamArrayReader::try_new(FuturesAdapter(Cursor::new(&buffer)), ctx.clone())
+                        .await
+                        .unwrap()
+                        .load_dtype()
+                        .await
+                        .unwrap();
+                let stream = stream_reader
+                    .into_array_stream()
                     .take_rows(indices.clone())
                     .unwrap();
                 pin_mut!(stream);
diff --git a/vortex-serde/benches/ipc_take.rs b/vortex-serde/benches/ipc_take.rs
index fc107f3946..37aaabfacd 100644
--- a/vortex-serde/benches/ipc_take.rs
+++ b/vortex-serde/benches/ipc_take.rs
@@ -17,8 +17,8 @@ use vortex::compute::take;
 use vortex::{Context, IntoArray};
 use vortex_sampling_compressor::SamplingCompressor;
 use vortex_serde::io::FuturesAdapter;
-use vortex_serde::writer::ArrayWriter;
-use vortex_serde::MessageReader;
+use vortex_serde::stream_reader::StreamArrayReader;
+use vortex_serde::stream_writer::StreamArrayWriter;
 
 fn ipc_take(c: &mut Criterion) {
     let mut group = c.benchmark_group("ipc_take");
@@ -61,17 +61,22 @@ fn ipc_take(c: &mut Criterion) {
         let compressed = compressor.compress(&uncompressed).unwrap();
 
         // Try running take over an ArrayView.
-        let buffer = block_on(async { ArrayWriter::new(vec![]).write_array(compressed).await })
-            .unwrap()
-            .into_inner();
+        let buffer =
+            block_on(async { StreamArrayWriter::new(vec![]).write_array(compressed).await })
+                .unwrap()
+                .into_inner();
 
         let ctx_ref = &Arc::new(ctx);
         let ro_buffer = buffer.as_slice();
         let indices_ref = &indices;
 
         b.to_async(FuturesExecutor).iter(|| async move {
-            let mut msgs = MessageReader::try_new(FuturesAdapter(Cursor::new(ro_buffer))).await?;
-            let reader = msgs.array_stream_from_messages(ctx_ref.clone()).await?;
+            let stream_reader =
+                StreamArrayReader::try_new(FuturesAdapter(Cursor::new(ro_buffer)), ctx_ref.clone())
+                    .await?
+                    .load_dtype()
+                    .await?;
+            let reader = stream_reader.into_array_stream();
             pin_mut!(reader);
             let array_view = reader.try_next().await?.unwrap();
             black_box(take(&array_view, indices_ref))
diff --git a/vortex-serde/src/chunked_reader/take_rows.rs b/vortex-serde/src/chunked_reader/take_rows.rs
index 1f83ccfce2..db911eba61 100644
--- a/vortex-serde/src/chunked_reader/take_rows.rs
+++ b/vortex-serde/src/chunked_reader/take_rows.rs
@@ -215,17 +215,17 @@ mod test {
     use vortex_error::VortexResult;
 
     use crate::chunked_reader::ChunkedArrayReader;
-    use crate::writer::ArrayWriter;
+    use crate::stream_writer::StreamArrayWriter;
     use crate::MessageReader;
 
-    fn chunked_array() -> VortexResult<ArrayWriter<Vec<u8>>> {
+    fn chunked_array() -> VortexResult<StreamArrayWriter<Vec<u8>>> {
         let c = ChunkedArray::try_new(
             vec![PrimitiveArray::from((0i32..1000).collect_vec()).into_array(); 10],
             PType::I32.into(),
         )?
        .into_array();
 
-        block_on(async { ArrayWriter::new(vec![]).write_array(c).await })
+        block_on(async { StreamArrayWriter::new(vec![]).write_array(c).await })
     }
 
     #[test]
diff --git a/vortex-serde/src/dtype_reader.rs b/vortex-serde/src/dtype_reader.rs
new file mode 100644
index 0000000000..37ff0e45c7
--- /dev/null
+++ b/vortex-serde/src/dtype_reader.rs
@@ -0,0 +1,29 @@
+use vortex_dtype::DType;
+use vortex_error::VortexResult;
+
+use crate::io::VortexRead;
+use crate::message_reader::MessageReader;
+
+/// Reader for serialized dtype messages
+pub struct DTypeReader<R: VortexRead> {
+    msgs: MessageReader<R>,
+}
+
+impl<R: VortexRead> DTypeReader<R> {
+    /// Create new [`DTypeReader`] given readable contents
+    pub async fn new(read: R) -> VortexResult<Self> {
+        Ok(Self {
+            msgs: MessageReader::try_new(read).await?,
+        })
+    }
+
+    /// Deserialize dtype out of ipc serialized format
+    pub async fn read_dtype(&mut self) -> VortexResult<DType> {
+        self.msgs.read_dtype().await
+    }
+
+    /// Deconstruct this reader into its underlying contents for further reuse
+    pub fn into_inner(self) -> R {
+        self.msgs.into_inner()
+    }
+}
diff --git a/vortex-serde/src/layouts/mod.rs b/vortex-serde/src/layouts/mod.rs
index 9743e25b47..03d6b856f7 100644
--- a/vortex-serde/src/layouts/mod.rs
+++ b/vortex-serde/src/layouts/mod.rs
@@ -1,7 +1,10 @@
-pub mod reader;
-pub mod writer;
+mod read;
+mod write;
 
 #[cfg(test)]
 mod tests;
 
 pub const MAGIC_BYTES: [u8; 4] = *b"VRX1";
+
+pub use read::*;
+pub use write::*;
diff --git a/vortex-serde/src/layouts/reader/batch.rs b/vortex-serde/src/layouts/read/batch.rs
similarity index 97%
rename from vortex-serde/src/layouts/reader/batch.rs
rename to vortex-serde/src/layouts/read/batch.rs
index 4ebec12f0d..21bdf0f93f 100644
--- a/vortex-serde/src/layouts/reader/batch.rs
+++ b/vortex-serde/src/layouts/read/batch.rs
@@ -5,7 +5,7 @@ use vortex::array::StructArray;
 use vortex::{Array, IntoArray};
 use vortex_error::{vortex_err, VortexResult};
 
-use crate::layouts::reader::{Layout, ReadResult};
+use crate::layouts::read::{Layout, ReadResult};
 
 #[derive(Debug)]
 pub struct BatchReader {
diff --git a/vortex-serde/src/layouts/reader/buffered.rs b/vortex-serde/src/layouts/read/buffered.rs
similarity index 98%
rename from vortex-serde/src/layouts/reader/buffered.rs
rename to vortex-serde/src/layouts/read/buffered.rs
index 883d4673aa..2f72ec74cd 100644
--- a/vortex-serde/src/layouts/reader/buffered.rs
+++ b/vortex-serde/src/layouts/read/buffered.rs
@@ -5,7 +5,7 @@ use vortex::compute::slice;
 use vortex::{Array, ArrayDType, IntoArray};
 use vortex_error::VortexResult;
 
-use crate::layouts::reader::{Layout, ReadResult};
+use crate::layouts::read::{Layout, ReadResult};
 
 #[derive(Debug)]
 pub struct BufferedReader {
diff --git a/vortex-serde/src/layouts/reader/builder.rs b/vortex-serde/src/layouts/read/builder.rs
similarity index 65%
rename from vortex-serde/src/layouts/reader/builder.rs
rename to vortex-serde/src/layouts/read/builder.rs
index 84f40e7647..7e452c7ddc 100644
--- a/vortex-serde/src/layouts/reader/builder.rs
+++ b/vortex-serde/src/layouts/read/builder.rs
@@ -2,16 +2,17 @@ use std::sync::{Arc, RwLock};
 
 use bytes::BytesMut;
 use vortex::{Array, ArrayDType};
+use vortex_dtype::field::Field;
 use vortex_error::{vortex_bail, VortexResult};
 
 use crate::io::VortexReadAt;
-use crate::layouts::reader::cache::{LayoutMessageCache, RelativeLayoutCache};
-use crate::layouts::reader::context::LayoutDeserializer;
-use crate::layouts::reader::filtering::RowFilter;
-use crate::layouts::reader::footer::Footer;
-use crate::layouts::reader::projections::Projection;
-use crate::layouts::reader::stream::VortexLayoutBatchStream;
-use crate::layouts::reader::{Scan, DEFAULT_BATCH_SIZE, FILE_POSTSCRIPT_SIZE, INITIAL_READ_SIZE};
+use crate::layouts::read::cache::{LayoutMessageCache, RelativeLayoutCache};
+use crate::layouts::read::context::LayoutDeserializer;
+use crate::layouts::read::filtering::RowFilter;
+use crate::layouts::read::footer::Footer;
+use crate::layouts::read::projections::Projection;
+use crate::layouts::read::stream::LayoutBatchStream;
+use crate::layouts::read::{Scan, DEFAULT_BATCH_SIZE, FILE_POSTSCRIPT_SIZE, INITIAL_READ_SIZE};
 use crate::layouts::MAGIC_BYTES;
 
 pub struct LayoutReaderBuilder<R> {
@@ -67,18 +68,42 @@ impl<R: VortexReadAt> LayoutReaderBuilder<R> {
         self
     }
 
-    pub async fn build(mut self) -> VortexResult<VortexLayoutBatchStream<R>> {
+    pub async fn build(mut self) -> VortexResult<LayoutBatchStream<R>> {
         let footer = self.read_footer().await?;
-        let projection = self.projection.unwrap_or_default();
+
+        // TODO(robert): Don't leak filter references into read projection
+        let (read_projection, result_projection) = if let Some(filter_columns) = self
+            .row_filter
+            .as_ref()
+            .map(|f| f.filter.references())
+            .filter(|refs| !refs.is_empty())
+            .map(|refs| footer.resolve_references(&refs.into_iter().collect::<Vec<_>>()))
+            .transpose()?
+        {
+            match self.projection.unwrap_or_default() {
+                Projection::All => (Projection::All, Projection::All),
+                Projection::Flat(mut v) => {
+                    let original_len = v.len();
+                    v.extend(filter_columns.into_iter());
+                    (
+                        Projection::Flat(v),
+                        Projection::Flat((0..original_len).map(Field::from).collect()),
+                    )
+                }
+            }
+        } else {
+            (self.projection.unwrap_or_default(), Projection::All)
+        };
+
         let batch_size = self.batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
 
-        let projected_dtype = match &projection {
+        let projected_dtype = match &read_projection {
             Projection::All => footer.dtype()?,
-            Projection::Partial(projection) => footer.projected_dtype(projection)?,
+            Projection::Flat(projection) => footer.projected_dtype(projection)?,
         };
 
         let scan = Scan {
-            projection,
+            projection: read_projection,
             indices: self.indices,
             filter: self.row_filter,
             batch_size,
@@ -90,7 +115,14 @@ impl<R: VortexReadAt> LayoutReaderBuilder<R> {
 
         let layout = footer.layout(scan.clone(), layouts_cache)?;
 
-        VortexLayoutBatchStream::try_new(self.reader, layout, message_cache, projected_dtype, scan)
+        LayoutBatchStream::try_new(
+            self.reader,
+            layout,
+            message_cache,
+            projected_dtype,
+            scan,
+            result_projection,
+        )
     }
 
     async fn len(&self) -> usize {
diff --git a/vortex-serde/src/layouts/reader/cache.rs b/vortex-serde/src/layouts/read/cache.rs
similarity index 97%
rename from vortex-serde/src/layouts/reader/cache.rs
rename to vortex-serde/src/layouts/read/cache.rs
index 562e62bbb4..c3dca0fbaf 100644
--- a/vortex-serde/src/layouts/reader/cache.rs
+++ b/vortex-serde/src/layouts/read/cache.rs
@@ -4,7 +4,7 @@ use ahash::HashMap;
 use bytes::Bytes;
 use vortex_dtype::DType;
 
-use crate::layouts::reader::{LayoutPartId, MessageId};
+use crate::layouts::read::{LayoutPartId, MessageId};
 
 #[derive(Default, Debug)]
 pub struct LayoutMessageCache {
diff --git a/vortex-serde/src/layouts/reader/context.rs b/vortex-serde/src/layouts/read/context.rs
similarity index 94%
rename from vortex-serde/src/layouts/reader/context.rs
rename to vortex-serde/src/layouts/read/context.rs
index ef3e5a4342..e30ab69256 100644
--- a/vortex-serde/src/layouts/reader/context.rs
+++ b/vortex-serde/src/layouts/read/context.rs
@@ -8,9 +8,9 @@ use vortex_error::{vortex_err, VortexResult};
 use vortex_flatbuffers::footer as fb;
 use vortex_flatbuffers::footer::LayoutVariant;
 
-use crate::layouts::reader::cache::RelativeLayoutCache;
-use crate::layouts::reader::layouts::{ChunkedLayoutSpec, ColumnLayoutSpec, FlatLayout};
-use crate::layouts::reader::{Layout, Scan};
+use crate::layouts::read::cache::RelativeLayoutCache;
+use crate::layouts::read::layouts::{ChunkedLayoutSpec, ColumnLayoutSpec, FlatLayout};
+use crate::layouts::read::{Layout, Scan};
 
 #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
 pub struct LayoutId(pub u16);
diff --git a/vortex-serde/src/layouts/reader/filtering.rs b/vortex-serde/src/layouts/read/filtering.rs
similarity index 100%
rename from vortex-serde/src/layouts/reader/filtering.rs
rename to vortex-serde/src/layouts/read/filtering.rs
diff --git a/vortex-serde/src/layouts/read/footer.rs b/vortex-serde/src/layouts/read/footer.rs
new file mode 100644
index 0000000000..71d4e8600d
--- /dev/null
+++ b/vortex-serde/src/layouts/read/footer.rs
@@ -0,0 +1,113 @@
+use bytes::Bytes;
+use flatbuffers::root;
+use vortex_dtype::field::Field;
+use vortex_dtype::{deserialize_and_project, resolve_field_references, DType};
+use vortex_error::{vortex_err, VortexResult};
+use vortex_flatbuffers::{message as fb, ReadFlatBuffer};
+
+use crate::layouts::read::cache::RelativeLayoutCache;
+use crate::layouts::read::context::LayoutDeserializer;
+use crate::layouts::read::{Layout, Scan, FILE_POSTSCRIPT_SIZE};
+use crate::messages::IPCDType;
+use crate::FLATBUFFER_SIZE_LENGTH;
+
+/// Wrapper around serialized file footer. Provides handle on file schema and
+/// layout metadata to read the contents.
+///
+/// # Footer format
+/// ┌────────────────────────────┐
+/// │                            │
+/// ...
+/// ├────────────────────────────┤
+/// │                            │
+/// │           Schema           │
+/// │                            │
+/// ├────────────────────────────┤
+/// │                            │
+/// │           Layouts          │
+/// │                            │
+/// ├────────────────────────────┤
+/// │  Schema Offset (8 bytes)   │
+/// ├────────────────────────────┤
+/// │  Layout Offset (8 bytes)   │
+/// ├────────────────────────────┤
+/// │   Magic bytes (4 bytes)    │
+/// └────────────────────────────┘
+///
+pub struct Footer {
+    pub(crate) schema_offset: u64,
+    /// This is actually layouts
+    pub(crate) footer_offset: u64,
+    pub(crate) leftovers: Bytes,
+    pub(crate) leftovers_offset: u64,
+    pub(crate) layout_serde: LayoutDeserializer,
+}
+
+impl Footer {
+    pub fn leftovers_footer_offset(&self) -> usize {
+        (self.footer_offset - self.leftovers_offset) as usize
+    }
+
+    pub fn leftovers_schema_offset(&self) -> usize {
+        (self.schema_offset - self.leftovers_offset) as usize
+    }
+
+    pub fn layout(
+        &self,
+        scan: Scan,
+        message_cache: RelativeLayoutCache,
+    ) -> VortexResult<Box<dyn Layout>> {
+        let start_offset = self.leftovers_footer_offset();
+        let end_offset = self.leftovers.len() - FILE_POSTSCRIPT_SIZE;
+        let footer_bytes = self
+            .leftovers
+            .slice(start_offset + FLATBUFFER_SIZE_LENGTH..end_offset);
+        let fb_footer = root::<vortex_flatbuffers::footer::Footer>(&footer_bytes)?;
+
+        let fb_layout = fb_footer
+            .layout()
+            .ok_or_else(|| vortex_err!("Footer must contain a layout"))?;
+        let loc = fb_layout._tab.loc();
+        self.layout_serde
+            .read_layout(footer_bytes, loc, scan, message_cache)
+    }
+
+    pub fn dtype(&self) -> VortexResult<DType> {
+        Ok(IPCDType::read_flatbuffer(&self.fb_schema()?)?.0)
+    }
+
+    pub fn projected_dtype(&self, projection: &[Field]) -> VortexResult<DType> {
+        let fb_dtype = self
+            .fb_schema()?
+            .dtype()
+            .ok_or_else(|| vortex_err!(InvalidSerde: "Schema missing DType"))?;
+        deserialize_and_project(fb_dtype, projection)
+    }
+
+    /// Convert all name based references to index based for sake of augmenting read projection
+    pub(crate) fn resolve_references(&self, projection: &[Field]) -> VortexResult<Vec<Field>> {
+        let dtype = self
+            .fb_schema()?
+ .dtype() + .ok_or_else(|| vortex_err!(InvalidSerde: "Schema missing DType"))?; + let fb_struct = dtype + .type__as_struct_() + .ok_or_else(|| vortex_err!("The top-level type should be a struct"))?; + resolve_field_references(fb_struct, projection) + .map(|idx| idx.map(Field::from)) + .collect::<VortexResult<Vec<_>>>() + } + + fn fb_schema(&self) -> VortexResult<fb::Schema> { + let start_offset = self.leftovers_schema_offset(); + let end_offset = self.leftovers_footer_offset(); + let dtype_bytes = &self.leftovers[start_offset + FLATBUFFER_SIZE_LENGTH..end_offset]; + + root::<fb::Message>(dtype_bytes) + .map_err(|e| e.into()) + .and_then(|m| { + m.header_as_schema() + .ok_or_else(|| vortex_err!("Message was not a schema")) + }) + } +} diff --git a/vortex-serde/src/layouts/reader/layouts.rs b/vortex-serde/src/layouts/read/layouts.rs similarity index 89% rename from vortex-serde/src/layouts/reader/layouts.rs rename to vortex-serde/src/layouts/read/layouts.rs index b89738c477..4dd1c038d3 100644 --- a/vortex-serde/src/layouts/reader/layouts.rs +++ b/vortex-serde/src/layouts/read/layouts.rs @@ -4,17 +4,18 @@ use std::sync::Arc; use bytes::Bytes; use flatbuffers::{ForwardsUOffset, Vector}; use vortex::Context; +use vortex_dtype::field::Field; use vortex_dtype::DType; use vortex_error::{vortex_bail, vortex_err, VortexResult}; use vortex_flatbuffers::footer as fb; use super::projections::Projection; -use crate::layouts::reader::batch::BatchReader; -use crate::layouts::reader::buffered::BufferedReader; -use crate::layouts::reader::cache::RelativeLayoutCache; -use crate::layouts::reader::context::{LayoutDeserializer, LayoutId, LayoutSpec}; -use crate::layouts::reader::{Layout, ReadResult, Scan}; -use crate::writer::ByteRange; +use crate::layouts::read::batch::BatchReader; +use crate::layouts::read::buffered::BufferedReader; +use crate::layouts::read::cache::RelativeLayoutCache; +use crate::layouts::read::context::{LayoutDeserializer, LayoutId, LayoutSpec}; +use crate::layouts::read::{Layout, ReadResult, Scan}; +use crate::stream_writer::ByteRange; use crate::ArrayBufferReader; #[derive(Debug)] @@ -191,15 +192,17 @@ impl Layout for ColumnLayout { Projection::All => (0..fb_children.len()) .map(|idx| self.read_child(idx, fb_children, s.dtypes()[idx].clone())) .collect::<VortexResult<Vec<_>>>()?, - Projection::Partial(ref v) => v + Projection::Flat(ref v) => v .iter() - .enumerate() - .map(|(position, &projection_idx)| { - self.read_child( - projection_idx, - fb_children, - s.dtypes()[position].clone(), - ) + .zip(s.dtypes().iter().cloned()) + .map(|(projected_field, dtype)| { + let child_idx = match projected_field { + Field::Name(n) => s.find_name(n.as_ref()).ok_or_else(|| { + vortex_err!("Invalid projection, trying to select {n}") + })?, + Field::Index(i) => *i, + }; + self.read_child(child_idx, fb_children, dtype) }) .collect::<VortexResult<Vec<_>>>()?, }; diff --git a/vortex-serde/src/layouts/reader/mod.rs b/vortex-serde/src/layouts/read/mod.rs similarity index 83% rename from vortex-serde/src/layouts/reader/mod.rs rename to vortex-serde/src/layouts/read/mod.rs index f323462b9f..817620ad82 100644 --- a/vortex-serde/src/layouts/reader/mod.rs +++ b/vortex-serde/src/layouts/read/mod.rs @@ -1,24 +1,29 @@ use std::fmt::Debug; pub use layouts::{ChunkedLayoutSpec, ColumnLayoutSpec}; -use projections::Projection; use vortex::Array; use vortex_error::VortexResult; -use crate::layouts::reader::filtering::RowFilter; -use crate::writer::ByteRange; - -pub mod batch; -pub mod buffered; -pub mod builder; +mod batch; +mod buffered; +mod builder; mod cache; -pub mod context; -pub mod filtering; 
+mod context; +mod filtering; mod footer; mod layouts; -pub mod projections; -pub mod schema; -pub mod stream; +mod projections; +mod schema; +mod stream; + +pub use builder::LayoutReaderBuilder; +pub use context::*; +pub use filtering::RowFilter; +pub use projections::Projection; +pub use schema::Schema; +pub use stream::LayoutBatchStream; + +use crate::stream_writer::ByteRange; // Recommended read-size according to the AWS performance guide const INITIAL_READ_SIZE: usize = 8 * 1024 * 1024; diff --git a/vortex-serde/src/layouts/read/projections.rs b/vortex-serde/src/layouts/read/projections.rs new file mode 100644 index 0000000000..375da3d764 --- /dev/null +++ b/vortex-serde/src/layouts/read/projections.rs @@ -0,0 +1,28 @@ +use vortex_dtype::field::Field; + +// TODO(robert): Add ability to project nested columns. +// Until datafusion supports nested column pruning we should create a separate variant to implement it +#[derive(Debug, Clone, Default)] +pub enum Projection { + #[default] + All, + Flat(Vec), +} + +impl Projection { + pub fn new(indices: impl AsRef<[usize]>) -> Self { + Self::Flat(indices.as_ref().iter().copied().map(Field::from).collect()) + } +} + +impl From> for Projection { + fn from(indices: Vec) -> Self { + Self::Flat(indices) + } +} + +impl From> for Projection { + fn from(indices: Vec) -> Self { + Self::Flat(indices.into_iter().map(Field::from).collect()) + } +} diff --git a/vortex-serde/src/layouts/reader/schema.rs b/vortex-serde/src/layouts/read/schema.rs similarity index 93% rename from vortex-serde/src/layouts/reader/schema.rs rename to vortex-serde/src/layouts/read/schema.rs index 38c88dbc7e..6b5b0170bd 100644 --- a/vortex-serde/src/layouts/reader/schema.rs +++ b/vortex-serde/src/layouts/read/schema.rs @@ -10,7 +10,7 @@ impl Schema { pub fn project(&self, projection: Projection) -> VortexResult { match projection { Projection::All => Ok(self.clone()), - Projection::Partial(indices) => { + Projection::Flat(indices) => { let DType::Struct(s, n) = &self.0 else { vortex_bail!("Can't project non struct types") }; diff --git a/vortex-serde/src/layouts/reader/stream.rs b/vortex-serde/src/layouts/read/stream.rs similarity index 88% rename from vortex-serde/src/layouts/reader/stream.rs rename to vortex-serde/src/layouts/read/stream.rs index b482567ea4..02f7f053c5 100644 --- a/vortex-serde/src/layouts/reader/stream.rs +++ b/vortex-serde/src/layouts/read/stream.rs @@ -7,7 +7,7 @@ use bytes::{Bytes, BytesMut}; use futures::Stream; use futures_util::future::BoxFuture; use futures_util::{stream, FutureExt, StreamExt, TryStreamExt}; -use vortex::array::BoolArray; +use vortex::array::{BoolArray, StructArray}; use vortex::compute::unary::subtract_scalar; use vortex::compute::{filter, search_sorted, slice, take, SearchSortedSide}; use vortex::validity::Validity; @@ -17,12 +17,13 @@ use vortex_error::{vortex_err, VortexError, VortexResult}; use vortex_scalar::Scalar; use crate::io::VortexReadAt; -use crate::layouts::reader::cache::LayoutMessageCache; -use crate::layouts::reader::schema::Schema; -use crate::layouts::reader::{Layout, MessageId, ReadResult, Scan}; -use crate::writer::ByteRange; +use crate::layouts::read::cache::LayoutMessageCache; +use crate::layouts::read::schema::Schema; +use crate::layouts::read::{Layout, MessageId, ReadResult, Scan}; +use crate::layouts::Projection; +use crate::stream_writer::ByteRange; -pub struct VortexLayoutBatchStream { +pub struct LayoutBatchStream { reader: Option, layout: Box, scan: Scan, @@ -30,17 +31,19 @@ pub struct VortexLayoutBatchStream 
{ state: StreamingState, dtype: DType, current_offset: usize, + result_projection: Projection, } -impl VortexLayoutBatchStream { +impl LayoutBatchStream { pub fn try_new( reader: R, layout: Box, messages_cache: Arc>, dtype: DType, scan: Scan, + result_projection: Projection, ) -> VortexResult { - Ok(VortexLayoutBatchStream { + Ok(LayoutBatchStream { reader: Some(reader), layout, scan, @@ -48,6 +51,7 @@ impl VortexLayoutBatchStream { state: Default::default(), dtype, current_offset: 0, + result_projection, }) } @@ -89,7 +93,7 @@ enum StreamingState { Error, } -impl Stream for VortexLayoutBatchStream { +impl Stream for LayoutBatchStream { type Item = VortexResult; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -122,6 +126,13 @@ impl Stream for VortexLayoutBatchStrea batch = filter(&batch, &filter_array)?; } + batch = match &self.result_projection { + Projection::All => batch, + Projection::Flat(v) => { + StructArray::try_from(batch)?.project(v)?.into_array() + } + }; + self.state = StreamingState::Init; return Poll::Ready(Some(Ok(batch))); } @@ -189,7 +200,7 @@ mod tests { use vortex::validity::Validity; use vortex::IntoArrayVariant; - use crate::layouts::reader::stream::null_as_false; + use crate::layouts::read::stream::null_as_false; #[test] fn coerces_nulls() { diff --git a/vortex-serde/src/layouts/reader/footer.rs b/vortex-serde/src/layouts/reader/footer.rs deleted file mode 100644 index cbcd199775..0000000000 --- a/vortex-serde/src/layouts/reader/footer.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::sync::Arc; - -use bytes::Bytes; -use flatbuffers::root; -use vortex_dtype::{DType, StructDType}; -use vortex_error::{vortex_err, VortexResult}; -use vortex_flatbuffers::dtype::Struct_; -use vortex_flatbuffers::ReadFlatBuffer; - -use crate::layouts::reader::cache::RelativeLayoutCache; -use crate::layouts::reader::context::LayoutDeserializer; -use crate::layouts::reader::{Layout, Scan, FILE_POSTSCRIPT_SIZE}; -use crate::messages::IPCDType; - -pub struct Footer { - pub(crate) schema_offset: u64, - /// This is actually layouts - pub(crate) footer_offset: u64, - pub(crate) leftovers: Bytes, - pub(crate) leftovers_offset: u64, - pub(crate) layout_serde: LayoutDeserializer, -} - -impl Footer { - pub fn leftovers_footer_offset(&self) -> usize { - (self.footer_offset - self.leftovers_offset) as usize - } - - pub fn leftovers_schema_offset(&self) -> usize { - (self.schema_offset - self.leftovers_offset) as usize - } - - pub fn layout( - &self, - scan: Scan, - message_cache: RelativeLayoutCache, - ) -> VortexResult> { - let start_offset = self.leftovers_footer_offset(); - let end_offset = self.leftovers.len() - FILE_POSTSCRIPT_SIZE; - let footer_bytes = self.leftovers.slice(start_offset..end_offset); - let fb_footer = root::(&footer_bytes)?; - - let fb_layout = fb_footer - .layout() - .ok_or_else(|| vortex_err!("Footer must contain a layout"))?; - let loc = fb_layout._tab.loc(); - self.layout_serde - .read_layout(footer_bytes, loc, scan, message_cache) - } - - pub fn dtype(&self) -> VortexResult { - let start_offset = self.leftovers_schema_offset(); - let end_offset = self.leftovers_footer_offset(); - let dtype_bytes = &self.leftovers[start_offset..end_offset]; - - Ok( - IPCDType::read_flatbuffer(&root::(dtype_bytes)?)? 
- .0, - ) - } - - pub fn projected_dtype(&self, projection: &[usize]) -> VortexResult { - let start_offset = self.leftovers_schema_offset(); - let end_offset = self.leftovers_footer_offset(); - let dtype_bytes = &self.leftovers[start_offset..end_offset]; - - let fb_schema = root::(dtype_bytes)?; - let fb_dtype = fb_schema - .dtype() - .ok_or_else(|| vortex_err!(InvalidSerde: "Schema missing DType"))?; - - let fb_struct = fb_dtype - .type__as_struct_() - .ok_or_else(|| vortex_err!("The top-level type should be a struct"))?; - let nullability = fb_struct.nullable().into(); - - let (names, dtypes): (Vec>, Vec) = projection - .iter() - .map(|idx| Self::read_field(fb_struct, *idx)) - .collect::>>()? - .into_iter() - .unzip(); - - Ok(DType::Struct( - StructDType::new(names.into(), dtypes), - nullability, - )) - } - - fn read_field(fb_struct: Struct_, idx: usize) -> VortexResult<(Arc, DType)> { - let name = fb_struct - .names() - .ok_or_else(|| vortex_err!("Missing field names"))? - .get(idx); - let fb_dtype = fb_struct - .dtypes() - .ok_or_else(|| vortex_err!("Missing field dtypes"))? - .get(idx); - let dtype = DType::try_from(fb_dtype)?; - - Ok((name.into(), dtype)) - } -} diff --git a/vortex-serde/src/layouts/reader/projections.rs b/vortex-serde/src/layouts/reader/projections.rs deleted file mode 100644 index 7b45f1166d..0000000000 --- a/vortex-serde/src/layouts/reader/projections.rs +++ /dev/null @@ -1,18 +0,0 @@ -#[derive(Debug, Clone, Default)] -pub enum Projection { - #[default] - All, - Partial(Vec), -} - -impl Projection { - pub fn new(indices: impl AsRef<[usize]>) -> Self { - Self::Partial(Vec::from(indices.as_ref())) - } -} - -impl From> for Projection { - fn from(indices: Vec) -> Self { - Self::Partial(indices) - } -} diff --git a/vortex-serde/src/layouts/tests.rs b/vortex-serde/src/layouts/tests.rs index 8ab8b7b48e..9122da488f 100644 --- a/vortex-serde/src/layouts/tests.rs +++ b/vortex-serde/src/layouts/tests.rs @@ -3,10 +3,8 @@ use vortex::array::{ChunkedArray, PrimitiveArray, StructArray, VarBinArray}; use vortex::{ArrayDType, IntoArray, IntoArrayVariant}; use vortex_dtype::PType; -use crate::layouts::reader::builder::LayoutReaderBuilder; -use crate::layouts::reader::context::LayoutDeserializer; -use crate::layouts::reader::projections::Projection; -use crate::layouts::writer::LayoutWriter; +use crate::layouts::write::LayoutWriter; +use crate::layouts::{LayoutDeserializer, LayoutReaderBuilder, Projection}; #[tokio::test] #[cfg_attr(miri, ignore)] diff --git a/vortex-serde/src/layouts/writer/footer.rs b/vortex-serde/src/layouts/write/footer.rs similarity index 93% rename from vortex-serde/src/layouts/writer/footer.rs rename to vortex-serde/src/layouts/write/footer.rs index dd78f7983e..c7d61ea9f5 100644 --- a/vortex-serde/src/layouts/writer/footer.rs +++ b/vortex-serde/src/layouts/write/footer.rs @@ -1,7 +1,7 @@ use flatbuffers::{FlatBufferBuilder, WIPOffset}; use vortex_flatbuffers::{footer as fb, WriteFlatBuffer}; -use crate::layouts::writer::layouts::Layout; +use crate::layouts::write::layouts::Layout; #[derive(Debug)] pub struct Footer { diff --git a/vortex-serde/src/layouts/writer/layouts.rs b/vortex-serde/src/layouts/write/layouts.rs similarity index 96% rename from vortex-serde/src/layouts/writer/layouts.rs rename to vortex-serde/src/layouts/write/layouts.rs index a56ebc5841..2049682b84 100644 --- a/vortex-serde/src/layouts/writer/layouts.rs +++ b/vortex-serde/src/layouts/write/layouts.rs @@ -3,8 +3,8 @@ use std::collections::VecDeque; use flatbuffers::{FlatBufferBuilder, 
WIPOffset}; use vortex_flatbuffers::{footer as fb, WriteFlatBuffer}; -use crate::layouts::reader::context::LayoutId; -use crate::writer::ByteRange; +use crate::layouts::LayoutId; +use crate::stream_writer::ByteRange; #[derive(Debug, Clone)] pub enum Layout { diff --git a/vortex-serde/src/layouts/write/mod.rs b/vortex-serde/src/layouts/write/mod.rs new file mode 100644 index 0000000000..2956adda4b --- /dev/null +++ b/vortex-serde/src/layouts/write/mod.rs @@ -0,0 +1,5 @@ +pub use writer::LayoutWriter; + +mod footer; +mod layouts; +mod writer; diff --git a/vortex-serde/src/layouts/writer/layout_writer.rs b/vortex-serde/src/layouts/write/writer.rs similarity index 60% rename from vortex-serde/src/layouts/writer/layout_writer.rs rename to vortex-serde/src/layouts/write/writer.rs index cdfcb3fec0..70dd49c60f 100644 --- a/vortex-serde/src/layouts/writer/layout_writer.rs +++ b/vortex-serde/src/layouts/write/writer.rs @@ -1,25 +1,20 @@ use std::collections::VecDeque; use std::mem; -use flatbuffers::FlatBufferBuilder; use futures::{Stream, TryStreamExt}; -use itertools::Itertools; use vortex::array::{ChunkedArray, StructArray}; use vortex::stream::ArrayStream; use vortex::validity::Validity; use vortex::{Array, ArrayDType, IntoArray}; -use vortex_buffer::io_buf::IoBuf; use vortex_dtype::DType; use vortex_error::{vortex_bail, VortexResult}; -use vortex_flatbuffers::WriteFlatBuffer; use crate::io::VortexWrite; -use crate::layouts::reader::{ChunkedLayoutSpec, ColumnLayoutSpec}; -use crate::layouts::writer::footer::Footer; -use crate::layouts::writer::layouts::{FlatLayout, Layout, NestedLayout}; +use crate::layouts::read::{ChunkedLayoutSpec, ColumnLayoutSpec}; +use crate::layouts::write::footer::Footer; +use crate::layouts::write::layouts::{FlatLayout, Layout, NestedLayout}; use crate::layouts::MAGIC_BYTES; -use crate::messages::IPCSchema; -use crate::writer::ChunkOffsets; +use crate::stream_writer::ChunkOffsets; use crate::MessageWriter; pub struct LayoutWriter { @@ -68,26 +63,20 @@ impl LayoutWriter { while let Some(columns) = array_stream.try_next().await? { let st = StructArray::try_from(&columns)?; for (i, field) in st.children().enumerate() { - let chunk_pos = if let Ok(chunked_array) = ChunkedArray::try_from(field.clone()) { + if let Ok(chunked_array) = ChunkedArray::try_from(field.clone()) { self.write_column_chunks(chunked_array.array_stream(), i) .await? } else { self.write_column_chunks(field.into_array_stream(), i) .await? 
- }; - - self.merge_chunk_offsets(i, chunk_pos); + } } } Ok(self) } - async fn write_column_chunks<S>( - &mut self, - mut stream: S, - column_idx: usize, - ) -> VortexResult<ChunkOffsets> + async fn write_column_chunks<S>(&mut self, mut stream: S, column_idx: usize) -> VortexResult<()> where S: Stream<Item = VortexResult<Array>> + Unpin, { @@ -111,70 +100,40 @@ impl LayoutWriter { byte_offsets.push(self.msgs.tell()); } - Ok(ChunkOffsets { - byte_offsets, - row_offsets, - }) - } - - fn merge_chunk_offsets(&mut self, column_idx: usize, chunk_pos: ChunkOffsets) { if let Some(chunk) = self.column_chunks.get_mut(column_idx) { - chunk.byte_offsets.extend(chunk_pos.byte_offsets); - chunk.row_offsets.extend(chunk_pos.row_offsets); + // Remove the last entry from the list, since it would be the same as the first entry of the next chunk + byte_offsets.truncate(byte_offsets.len() - 1); + row_offsets.truncate(row_offsets.len() - 1); + + chunk.byte_offsets.extend(byte_offsets); + chunk.row_offsets.extend(row_offsets); } else { - self.column_chunks.push(chunk_pos); + self.column_chunks + .push(ChunkOffsets::new(byte_offsets, row_offsets)); } + + Ok(()) } async fn write_metadata_arrays(&mut self) -> VortexResult<NestedLayout> { - let DType::Struct(..) = self.dtype.as_ref().expect("Should have written values") else { - unreachable!("Values are a structarray") - }; - let mut column_layouts = VecDeque::with_capacity(self.column_chunks.len()); for mut chunk in mem::take(&mut self.column_chunks) { - let mut chunks = VecDeque::new(); let len = chunk.byte_offsets.len() - 1; - let byte_counts = chunk + let mut chunks: VecDeque<Layout> = chunk .byte_offsets .iter() - .skip(1) - .zip(chunk.byte_offsets.iter()) - .map(|(a, b)| a - b) - .collect_vec(); - - chunks.extend( - chunk - .byte_offsets - .iter() - .zip(chunk.byte_offsets.iter().skip(1)) - .map(|(begin, end)| Layout::Flat(FlatLayout::new(*begin, *end))), - ); - let row_counts = chunk - .row_offsets - .iter() - .skip(1) - .zip(chunk.row_offsets.iter()) - .map(|(a, b)| a - b) - .collect_vec(); + .zip(chunk.byte_offsets.iter().skip(1)) + .map(|(begin, end)| Layout::Flat(FlatLayout::new(*begin, *end))) + .collect(); chunk.byte_offsets.truncate(len); chunk.row_offsets.truncate(len); let metadata_array = StructArray::try_new( - [ - "byte_offset".into(), - "byte_count".into(), - "row_offset".into(), - "row_count".into(), - ] - .into(), + ["byte_offset".into(), "row_offset".into()].into(), vec![ chunk.byte_offsets.into_array(), - byte_counts.into_array(), chunk.row_offsets.into_array(), - row_counts.into_array(), ], len, Validity::NonNullable, @@ -196,43 +155,27 @@ impl LayoutWriter { Ok(NestedLayout::new(column_layouts, ColumnLayoutSpec::ID)) } - async fn write_file_trailer(self, footer: Footer) -> VortexResult<W> { - let schema_offset = self.msgs.tell(); - let mut w = self.msgs.into_inner(); - - let dtype_len = Self::write_flatbuffer( - &mut w, - &IPCSchema(&self.dtype.expect("Needed a schema at this point")), - ) - .await?; - let _ = Self::write_flatbuffer(&mut w, &footer).await?; - - w.write_all(schema_offset.to_le_bytes()).await?; - w.write_all((schema_offset + dtype_len).to_le_bytes()) + async fn write_footer(&mut self, footer: Footer) -> VortexResult<(u64, u64)> { + let dtype_offset = self.msgs.tell(); + self.msgs + .write_dtype(&self.dtype.take().expect("Needed a schema at this point")) .await?; - w.write_all(MAGIC_BYTES).await?; - Ok(w) - } - - // TODO(robert): Remove this once messagewriter/reader can write non length prefixed messages - async fn write_flatbuffer(write: &mut W, fb: &F) -> VortexResult<u64> { - let mut fbb = 
FlatBufferBuilder::new(); - let fb_offset = fb.write_flatbuffer(&mut fbb); - fbb.finish_minimal(fb_offset); - - let (buffer, buffer_begin) = fbb.collapse(); - let buffer_end = buffer.len(); - let sliced_buf = buffer.slice_owned(buffer_begin..buffer_end); - let buf_len = sliced_buf.as_slice().len() as u64; - - write.write_all(sliced_buf).await?; - Ok(buf_len) + let footer_offset = self.msgs.tell(); + self.msgs.write_message(footer).await?; + Ok((dtype_offset, footer_offset)) } pub async fn finalize(mut self) -> VortexResult { let top_level_layout = self.write_metadata_arrays().await?; - self.write_file_trailer(Footer::new(Layout::Nested(top_level_layout))) - .await + let (dtype_offset, footer_offset) = self + .write_footer(Footer::new(Layout::Nested(top_level_layout))) + .await?; + let mut w = self.msgs.into_inner(); + + w.write_all(dtype_offset.to_le_bytes()).await?; + w.write_all(footer_offset.to_le_bytes()).await?; + w.write_all(MAGIC_BYTES).await?; + Ok(w) } } @@ -243,7 +186,7 @@ mod tests { use vortex::validity::Validity; use vortex::IntoArray; - use crate::layouts::writer::layout_writer::LayoutWriter; + use crate::layouts::LayoutWriter; #[test] fn write_columns() { diff --git a/vortex-serde/src/layouts/writer/mod.rs b/vortex-serde/src/layouts/writer/mod.rs deleted file mode 100644 index 8409c16e31..0000000000 --- a/vortex-serde/src/layouts/writer/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub use layout_writer::LayoutWriter; - -mod footer; -mod layout_writer; -mod layouts; diff --git a/vortex-serde/src/lib.rs b/vortex-serde/src/lib.rs index cb6359433d..2849972333 100644 --- a/vortex-serde/src/lib.rs +++ b/vortex-serde/src/lib.rs @@ -1,14 +1,16 @@ -pub use message_reader::*; -pub use message_writer::*; +use message_reader::*; +use message_writer::*; pub mod chunked_reader; +mod dtype_reader; pub mod io; pub mod layouts; mod message_reader; mod message_writer; mod messages; pub mod stream_reader; -pub mod writer; +pub mod stream_writer; +pub use dtype_reader::*; pub const ALIGNMENT: usize = 64; @@ -28,12 +30,12 @@ mod test { use vortex_error::VortexResult; use crate::io::FuturesAdapter; - use crate::writer::ArrayWriter; - use crate::MessageReader; + use crate::stream_reader::StreamArrayReader; + use crate::stream_writer::StreamArrayWriter; fn write_ipc(array: A) -> Vec { block_on(async { - ArrayWriter::new(vec![]) + StreamArrayWriter::new(vec![]) .write_array(array.into_array()) .await .unwrap() @@ -50,14 +52,17 @@ mod test { let indices = PrimitiveArray::from(vec![1, 2, 10]).into_array(); let ctx = Arc::new(Context::default()); - let mut messages = block_on(async { - MessageReader::try_new(FuturesAdapter(Cursor::new(buffer))) + let stream_reader = block_on(async { + StreamArrayReader::try_new(FuturesAdapter(Cursor::new(buffer)), ctx) + .await + .unwrap() + .load_dtype() .await .unwrap() }); - let reader = block_on(async { messages.array_stream_from_messages(ctx).await })?; + let reader = stream_reader.into_array_stream(); - let result_iter = reader.take_rows(indices).unwrap(); + let result_iter = reader.take_rows(indices)?; pin_mut!(result_iter); let _result = block_on(async { result_iter.next().await.unwrap().unwrap() }); @@ -79,12 +84,17 @@ mod test { let chunked = ChunkedArray::try_new(vec![data.clone(), data2], data.dtype().clone())?; let buffer = write_ipc(chunked); - let mut messages = - block_on(async { MessageReader::try_new(FuturesAdapter(Cursor::new(buffer))).await })?; - let ctx = Arc::new(Context::default()); - let take_iter = block_on(async { 
messages.array_stream_from_messages(ctx).await })? - .take_rows(indices)?; + let stream_reader = block_on(async { + StreamArrayReader::try_new(FuturesAdapter(Cursor::new(buffer)), ctx) + .await + .unwrap() + .load_dtype() + .await + .unwrap() + }); + + let take_iter = stream_reader.into_array_stream().take_rows(indices)?; pin_mut!(take_iter); let next = block_on(async { take_iter.try_next().await })?.expect("Expected a chunk"); diff --git a/vortex-serde/src/message_reader.rs b/vortex-serde/src/message_reader.rs index 592769087f..bff55f431d 100644 --- a/vortex-serde/src/message_reader.rs +++ b/vortex-serde/src/message_reader.rs @@ -14,7 +14,7 @@ use vortex_flatbuffers::{message as fb, ReadFlatBuffer}; use crate::io::VortexRead; use crate::messages::IPCDType; -const FLATBUFFER_SIZE_LENGTH: usize = 4; +pub const FLATBUFFER_SIZE_LENGTH: usize = 4; pub struct MessageReader { read: R, @@ -126,15 +126,6 @@ impl MessageReader { array_reader.into_array(ctx, dtype).map(Some) } - /// Construct an ArrayStream pulling the DType from the stream. - pub async fn array_stream_from_messages( - &mut self, - ctx: Arc, - ) -> VortexResult { - let dtype = self.read_dtype().await?; - Ok(self.array_stream(ctx, dtype)) - } - pub fn array_stream(&mut self, ctx: Arc, dtype: DType) -> impl ArrayStream + '_ { struct State<'a, R: VortexRead> { msgs: &'a mut MessageReader, @@ -207,6 +198,10 @@ impl MessageReader { let _ = self.next().await?; page_buffer } + + pub fn into_inner(self) -> R { + self.read + } } pub enum ReadState { diff --git a/vortex-serde/src/message_writer.rs b/vortex-serde/src/message_writer.rs index d39b0fe6b2..2dfe4c6c38 100644 --- a/vortex-serde/src/message_writer.rs +++ b/vortex-serde/src/message_writer.rs @@ -85,7 +85,7 @@ impl MessageWriter { Ok(()) } - async fn write_message(&mut self, flatbuffer: F) -> io::Result<()> { + pub async fn write_message(&mut self, flatbuffer: F) -> io::Result<()> { // We reuse the scratch buffer each time and then replace it at the end. // The scratch buffer may be missing if a previous write failed. We could use scopeguard // or similar here if it becomes a problem in practice. diff --git a/vortex-serde/src/writer.rs b/vortex-serde/src/stream_writer/mod.rs similarity index 89% rename from vortex-serde/src/writer.rs rename to vortex-serde/src/stream_writer/mod.rs index fc1535e974..bca208ba62 100644 --- a/vortex-serde/src/writer.rs +++ b/vortex-serde/src/stream_writer/mod.rs @@ -9,14 +9,14 @@ use vortex_error::VortexResult; use crate::io::VortexWrite; use crate::MessageWriter; -pub struct ArrayWriter { +pub struct StreamArrayWriter { msgs: MessageWriter, array_layouts: Vec, page_ranges: Vec, } -impl ArrayWriter { +impl StreamArrayWriter { pub fn new(write: W) -> Self { Self { msgs: MessageWriter::new(write), @@ -59,10 +59,7 @@ impl ArrayWriter { byte_offsets.push(self.msgs.tell()); } - Ok(ChunkOffsets { - byte_offsets, - row_offsets, - }) + Ok(ChunkOffsets::new(byte_offsets, row_offsets)) } pub async fn write_array_stream( @@ -103,6 +100,10 @@ pub struct ByteRange { #[allow(clippy::len_without_is_empty)] impl ByteRange { + pub fn new(begin: u64, end: u64) -> Self { + Self { begin, end } + } + pub fn len(&self) -> usize { (self.end - self.begin) as usize } @@ -119,3 +120,12 @@ pub struct ChunkOffsets { pub byte_offsets: Vec, pub row_offsets: Vec, } + +impl ChunkOffsets { + pub fn new(byte_offsets: Vec, row_offsets: Vec) -> Self { + Self { + byte_offsets, + row_offsets, + } + } +}