diff --git a/Cargo.lock b/Cargo.lock index cb633d353da2..dc08fa610164 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -807,6 +807,9 @@ name = "bytes" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +dependencies = [ + "serde", +] [[package]] name = "bytes-utils" @@ -3390,7 +3393,6 @@ dependencies = [ "bytes", "chrono", "chrono-tz", - "ciborium", "either", "futures", "hashbrown 0.15.2", @@ -3425,10 +3427,11 @@ version = "0.45.1" dependencies = [ "ahash", "arboard", + "bincode", "bytemuck", "bytes", - "ciborium", "either", + "flate2", "itoa", "libc", "ndarray", @@ -3552,9 +3555,11 @@ name = "polars-utils" version = "0.45.1" dependencies = [ "ahash", + "bincode", "bytemuck", "bytes", "compact_str", + "flate2", "hashbrown 0.15.2", "indexmap", "libc", @@ -3567,6 +3572,7 @@ dependencies = [ "raw-cpuid", "rayon", "serde", + "serde_json", "stacker", "sysinfo", "version_check", diff --git a/Cargo.toml b/Cargo.toml index a431283c497f..b2d1b988fc4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,12 +26,12 @@ atoi_simd = "0.16" atomic-waker = "1" avro-schema = { version = "0.3" } base64 = "0.22.0" +bincode = "1.3.3" bitflags = "2" bytemuck = { version = "1.11", features = ["derive", "extern_crate_alloc"] } bytes = { version = "1.7" } chrono = { version = "0.4.31", default-features = false, features = ["std"] } chrono-tz = "0.10" -ciborium = "0.2" compact_str = { version = "0.8.0", features = ["serde"] } crossbeam-channel = "0.5.8" crossbeam-deque = "0.8.5" diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 71689b4094d6..d3bc5417a9d8 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -111,10 +111,6 @@ pub enum ArrowDataType { LargeList(Box<Field>), /// A nested [`ArrowDataType`] with a given number of [`Field`]s. Struct(Vec<Field>), - /// A nested datatype that can represent slots of differing types. - /// Third argument represents mode - #[cfg_attr(feature = "serde", serde(skip))] - Union(Vec<Field>, Option<Vec<i32>>, UnionMode), /// A nested type that is represented as /// /// List<entries: Struct<key: K, value: V>> @@ -176,6 +172,10 @@ pub enum ArrowDataType { Utf8View, /// A type unknown to Arrow. Unknown, + /// A nested datatype that can represent slots of differing types. + /// Third argument represents mode + #[cfg_attr(feature = "serde", serde(skip))] + Union(Vec<Field>, Option<Vec<i32>>, UnionMode), } /// Mode of [`ArrowDataType::Union`] diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index b04633106b7c..0a6e70fdaf83 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -42,7 +42,7 @@ thiserror = { workspace = true } xxhash-rust = { workspace = true } [dev-dependencies] -bincode = { version = "1" } +bincode = { workspace = true } serde_json = { workspace = true } [build-dependencies] @@ -123,7 +123,15 @@ dtype-struct = [] bigidx = ["arrow/bigidx", "polars-utils/bigidx"] python = [] -serde = ["dep:serde", "bitflags/serde", "polars-schema/serde", "polars-utils/serde"] +serde = [ + "dep:serde", + "bitflags/serde", + "polars-schema/serde", + "polars-utils/serde", + "arrow/io_ipc", + "arrow/io_ipc_compression", + "serde_json", +] serde-lazy = ["serde", "arrow/serde", "indexmap/serde", "chrono/serde"] docs-selection = [ @@ -143,6 +151,7 @@ docs-selection = [ "row_hash", "rolling_window", "rolling_window_by", + "serde", "dtype-categorical", "dtype-decimal", "diagonal_concat", diff --git a/crates/polars-core/src/datatypes/_serde.rs b/crates/polars-core/src/datatypes/_serde.rs index e9d961ef4be0..5cc1d7f86c94 100644 --- a/crates/polars-core/src/datatypes/_serde.rs +++ b/crates/polars-core/src/datatypes/_serde.rs @@ -108,11 +108,11 @@ enum SerializableDataType { // some logical types we cannot know statically, e.g. Datetime Unknown(UnknownKind), #[cfg(feature = "dtype-categorical")] - Categorical(Option<Wrap<Utf8ViewArray>>, CategoricalOrdering), + Categorical(Option<Series>, CategoricalOrdering), #[cfg(feature = "dtype-decimal")] Decimal(Option<usize>, Option<usize>), #[cfg(feature = "dtype-categorical")] - Enum(Option<Wrap<Utf8ViewArray>>, CategoricalOrdering), + Enum(Option<Series>, CategoricalOrdering), #[cfg(feature = "object")] Object(String), } @@ -146,11 +146,23 @@ impl From<&DataType> for SerializableDataType { #[cfg(feature = "dtype-struct")] Struct(flds) => Self::Struct(flds.clone()), #[cfg(feature = "dtype-categorical")] - Categorical(_, ordering) => Self::Categorical(None, *ordering), + Categorical(Some(rev_map), ordering) => Self::Categorical( + Some( + StringChunked::with_chunk(PlSmallStr::EMPTY, rev_map.get_categories().clone()) + .into_series(), + ), + *ordering, + ), #[cfg(feature = "dtype-categorical")] - Enum(Some(rev_map), ordering) => { - Self::Enum(Some(Wrap(rev_map.get_categories().clone())), *ordering) - }, + Categorical(None, ordering) => Self::Categorical(None, *ordering), + #[cfg(feature = "dtype-categorical")] + Enum(Some(rev_map), ordering) => Self::Enum( + Some( + StringChunked::with_chunk(PlSmallStr::EMPTY, rev_map.get_categories().clone()) + .into_series(), + ), + *ordering, + ), #[cfg(feature = "dtype-categorical")] Enum(None, ordering) => Self::Enum(None, *ordering), #[cfg(feature = "dtype-decimal")] @@ -190,9 +202,26 @@ impl From<SerializableDataType> for DataType { #[cfg(feature = "dtype-struct")] Struct(flds) => Self::Struct(flds), #[cfg(feature = "dtype-categorical")] - Categorical(_, ordering) => Self::Categorical(None, ordering), + Categorical(Some(categories), ordering) => Self::Categorical( + Some(Arc::new(RevMapping::build_local( + categories.0.rechunk().chunks()[0] + .as_any() + .downcast_ref::<Utf8ViewArray>() + .unwrap() + .clone(), + ))), + ordering, + ), + #[cfg(feature = "dtype-categorical")] + Categorical(None, ordering) => Self::Categorical(None, ordering), #[cfg(feature = "dtype-categorical")] - Enum(Some(categories), _) => create_enum_dtype(categories.0), + Enum(Some(categories), _) => create_enum_dtype( + categories.rechunk().chunks()[0] + .as_any() + .downcast_ref::<Utf8ViewArray>() + .unwrap() + .clone(), + ), #[cfg(feature = "dtype-categorical")] Enum(None, ordering) => Self::Enum(None, ordering), #[cfg(feature = "dtype-decimal")] diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 2120298ce693..e83658efb2d1 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -37,8 +37,6 @@ mod series; /// 2. A [`ScalarColumn`] that repeats a single [`Scalar`] #[derive(Debug, Clone)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -#[cfg_attr(feature = "serde", serde(from = "Series"))] -#[cfg_attr(feature = "serde", serde(into = "_SerdeSeries"))] pub enum Column { Series(SeriesColumn), Partitioned(PartitionedColumn), diff --git a/crates/polars-core/src/frame/column/partitioned.rs b/crates/polars-core/src/frame/column/partitioned.rs index 93471c662d72..e6de4a5e7efb 100644 --- a/crates/polars-core/src/frame/column/partitioned.rs +++ b/crates/polars-core/src/frame/column/partitioned.rs @@ -12,12 +12,14 @@ use crate::frame::Scalar; use crate::series::IsSorted; #[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct PartitionedColumn { name: PlSmallStr, values: Series, ends: Arc<[IdxSize]>, + #[cfg_attr(feature = "serde", serde(skip))] materialized: OnceLock<Series>, } diff --git a/crates/polars-core/src/frame/column/scalar.rs b/crates/polars-core/src/frame/column/scalar.rs index cc33697a42ff..1906b2c7a472 100644 --- a/crates/polars-core/src/frame/column/scalar.rs +++ b/crates/polars-core/src/frame/column/scalar.rs @@ -307,3 +307,71 @@ impl From<ScalarColumn> for Column { Self::Scalar(value) } } + +#[cfg(feature = "serde")] +mod serde_impl { + use std::sync::OnceLock; + + use polars_error::PolarsError; + use polars_utils::pl_str::PlSmallStr; + + use super::ScalarColumn; + use crate::frame::{Scalar, Series}; + + #[derive(serde::Serialize, serde::Deserialize)] + struct SerializeWrap { + name: PlSmallStr, + /// Unit-length series for dispatching to IPC serialize + unit_series: Series, + length: usize, + } + + impl From<&ScalarColumn> for SerializeWrap { + fn from(value: &ScalarColumn) -> Self { + Self { + name: value.name.clone(), + unit_series: value.scalar.clone().into_series(PlSmallStr::EMPTY), + length: value.length, + } + } + } + + impl TryFrom<SerializeWrap> for ScalarColumn { + type Error = PolarsError; + + fn try_from(value: SerializeWrap) -> Result<Self, Self::Error> { + let slf = Self { + name: value.name, + scalar: Scalar::new( + value.unit_series.dtype().clone(), + value.unit_series.get(0)?.into_static(), + ), + length: value.length, + materialized: OnceLock::new(), + }; + + Ok(slf) + } + } + + impl serde::ser::Serialize for ScalarColumn { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + SerializeWrap::from(self).serialize(serializer) + } + } + + impl<'de> serde::de::Deserialize<'de> for ScalarColumn { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + + SerializeWrap::deserialize(deserializer) + .and_then(|x| ScalarColumn::try_from(x).map_err(D::Error::custom)) + } + } +} diff --git a/crates/polars-core/src/frame/column/series.rs b/crates/polars-core/src/frame/column/series.rs index c7f79906ea0d..d7c7e1b5b773 100644 --- a/crates/polars-core/src/frame/column/series.rs +++ b/crates/polars-core/src/frame/column/series.rs @@ -7,10 +7,12 @@ use super::Series; /// At the moment this just conditionally tracks where it was created so that materialization /// problems can be tracked down. #[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct SeriesColumn { inner: Series, #[cfg(debug_assertions)] + #[cfg_attr(feature = "serde", serde(skip))] materialized_at: Option<std::sync::Arc<std::backtrace::Backtrace>>, } diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index c33788a95ce9..7bd94791b6a0 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -3562,4 +3562,41 @@ mod test { assert_eq!(df.get_column_names(), &["a", "b", "c"]); Ok(()) } + + #[cfg(feature = "serde")] + #[test] + fn test_deserialize_height_validation_8751() { + // Construct an invalid directly from the inner fields as the `new_unchecked_*` functions + // have debug assertions + + use polars_utils::pl_serialize; + + let df = DataFrame { + height: 2, + columns: vec![ + Int64Chunked::full("a".into(), 1, 2).into_column(), + Int64Chunked::full("b".into(), 1, 1).into_column(), + ], + cached_schema: OnceLock::new(), + }; + + // We rely on the fact that the serialization doesn't check the heights of all columns + let serialized = serde_json::to_string(&df).unwrap(); + let err = serde_json::from_str::<DataFrame>(&serialized).unwrap_err(); + + assert!(err.to_string().contains( + "successful parse invalid data: lengths don't match: could not create a new DataFrame:", + )); + + let serialized = pl_serialize::SerializeOptions::default() + .serialize_to_bytes(&df) + .unwrap(); + let err = pl_serialize::SerializeOptions::default() + .deserialize_from_reader::<DataFrame, _>(serialized.as_slice()) + .unwrap_err(); + + assert!(err.to_string().contains( + "successful parse invalid data: lengths don't match: could not create a new DataFrame:", + )); + } } diff --git a/crates/polars-core/src/serde/mod.rs b/crates/polars-core/src/serde/mod.rs index d355f959fd15..997170210388 100644 --- a/crates/polars-core/src/serde/mod.rs +++ b/crates/polars-core/src/serde/mod.rs @@ -12,14 +12,14 @@ mod test { fn test_serde() -> PolarsResult<()> { let ca = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(2)]); - let json = serde_json::to_string(&ca).unwrap(); + let json = serde_json::to_string(&ca.clone().into_series()).unwrap(); let out = serde_json::from_str::<Series>(&json).unwrap(); assert!(ca.into_series().equals_missing(&out)); let ca = StringChunked::new("foo".into(), &[Some("foo"), None, Some("bar")]); - let json = serde_json::to_string(&ca).unwrap(); + let json = serde_json::to_string(&ca.clone().into_series()).unwrap(); let out = serde_json::from_str::<Series>(&json).unwrap(); // uses `Deserialize<'de>` assert!(ca.into_series().equals_missing(&out)); @@ -32,7 +32,7 @@ mod test { fn test_serde_owned() { let ca = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(2)]); - let json = serde_json::to_string(&ca).unwrap(); + let json = serde_json::to_string(&ca.clone().into_series()).unwrap(); let out = serde_json::from_reader::<_, Series>(json.as_bytes()).unwrap(); // uses `DeserializeOwned` assert!(ca.into_series().equals_missing(&out)); @@ -54,7 +54,7 @@ mod test { for mut column in df.columns { column.set_sorted_flag(IsSorted::Descending); let json = serde_json::to_string(&column).unwrap(); - let out = serde_json::from_reader::<_, Series>(json.as_bytes()).unwrap(); + let out = serde_json::from_reader::<_, Column>(json.as_bytes()).unwrap(); let f = out.get_flags(); assert_ne!(f, MetadataFlags::empty()); assert_eq!(column.get_flags(), out.get_flags()); diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs index 0fb9d9f05f18..3b91c0048699 100644 --- a/crates/polars-core/src/serde/series.rs +++ b/crates/polars-core/src/serde/series.rs @@ -1,16 +1,17 @@ -use std::borrow::Cow; use std::fmt::Formatter; -use serde::de::{Error as DeError, MapAccess, Visitor}; -#[cfg(feature = "object")] -use serde::ser::Error as SerError; -use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use arrow::datatypes::Metadata; +use arrow::io::ipc::read::{read_stream_metadata, StreamReader, StreamState}; +use arrow::io::ipc::write::WriteOptions; +use serde::de::{Error as DeError, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; -#[cfg(feature = "dtype-array")] -use crate::chunked_array::builder::get_fixed_size_list_builder; -use crate::chunked_array::builder::AnonymousListBuilder; use crate::chunked_array::metadata::MetadataFlags; +use crate::config; use crate::prelude::*; +use crate::utils::accumulate_dataframes_vertical; + +const FLAGS_KEY: PlSmallStr = PlSmallStr::from_static("_PL_FLAGS"); impl Serialize for Series { fn serialize<S>( @@ -20,78 +21,52 @@ impl Serialize for Series { where S: Serializer, { - match self.dtype() { - DataType::Binary => { - let ca = self.binary().unwrap(); - ca.serialize(serializer) - }, - DataType::List(_) => { - let ca = self.list().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "dtype-array")] - DataType::Array(_, _) => { - let ca = self.array().unwrap(); - ca.serialize(serializer) - }, - DataType::Boolean => { - let ca = self.bool().unwrap(); - ca.serialize(serializer) - }, - DataType::String => { - let ca = self.str().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => { - let ca = self.struct_().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "dtype-date")] - DataType::Date => { - let ca = self.date().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "dtype-datetime")] - DataType::Datetime(_, _) => { - let ca = self.datetime().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - let ca = self.categorical().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "dtype-duration")] - DataType::Duration(_) => { - let ca = self.duration().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "dtype-time")] - DataType::Time => { - let ca = self.time().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "dtype-decimal")] - DataType::Decimal(_, _) => { - let ca = self.decimal().unwrap(); - ca.serialize(serializer) - }, - DataType::Null => { - let ca = self.null().unwrap(); - ca.serialize(serializer) - }, - #[cfg(feature = "object")] - DataType::Object(_, _) => Err(S::Error::custom( + use serde::ser::Error; + + if self.dtype().is_object() { + return Err(polars_err!( + ComputeError: "serializing data of type Object is not supported", - )), - dt => { - with_match_physical_numeric_polars_type!(dt, |$T| { - let ca: &ChunkedArray<$T> = self.as_ref().as_ref().as_ref(); - ca.serialize(serializer) - }) + )) + .map_err(S::Error::custom); + } + + let bytes = vec![]; + let mut bytes = std::io::Cursor::new(bytes); + let mut ipc_writer = arrow::io::ipc::write::StreamWriter::new( + &mut bytes, + WriteOptions { + // Compression should be done on an outer level + compression: Some(arrow::io::ipc::write::Compression::ZSTD), }, + ); + + let df = unsafe { + DataFrame::new_no_checks_height_from_first(vec![self.rechunk().into_column()]) + }; + + ipc_writer.set_custom_schema_metadata(Arc::new(Metadata::from([( + FLAGS_KEY, + PlSmallStr::from(self.get_flags().bits().to_string()), + )]))); + + ipc_writer + .start( + &ArrowSchema::from_iter([Field { + name: self.name().clone(), + dtype: self.dtype().clone(), + } + .to_arrow(CompatLevel::newest())]), + None, + ) + .map_err(S::Error::custom)?; + + for batch in df.iter_chunks(CompatLevel::newest(), false) { + ipc_writer.write(&batch, None).map_err(S::Error::custom)?; } + + ipc_writer.finish().map_err(S::Error::custom)?; + serializer.serialize_bytes(bytes.into_inner().as_slice()) } } @@ -100,233 +75,76 @@ impl<'de> Deserialize<'de> for Series { where D: Deserializer<'de>, { - const FIELDS: &[&str] = &["name", "datatype", "bit_settings", "length", "values"]; - struct SeriesVisitor; impl<'de> Visitor<'de> for SeriesVisitor { type Value = Series; fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { - formatter - .write_str("struct {name: <name>, datatype: <dtype>, bit_settings?: <settings>, length?: <length>, values: <values array>}") + formatter.write_str("bytes (IPC)") } - fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error> + fn visit_bytes<E>(self, mut v: &[u8]) -> Result<Self::Value, E> where - A: MapAccess<'de>, + E: DeError, { - let mut name: Option<Cow<'de, str>> = None; - let mut dtype = None; - let mut length = None; - let mut bit_settings: Option<MetadataFlags> = None; - let mut values_set = false; - while let Some(key) = map.next_key::<Cow<str>>().unwrap() { - match key.as_ref() { - "name" => { - name = match map.next_value::<Cow<str>>() { - Ok(s) => Some(s), - Err(_) => Some(Cow::Owned(map.next_value::<String>()?)), - }; - }, - "datatype" => { - dtype = Some(map.next_value()?); - }, - "bit_settings" => { - bit_settings = Some(map.next_value()?); - }, - // length is only used for struct at the moment - "length" => length = Some(map.next_value()?), - "values" => { - // we delay calling next_value until we know the dtype - values_set = true; - break; + let mut md = read_stream_metadata(&mut v).map_err(E::custom)?; + let arrow_schema = md.schema.clone(); + + let custom_metadata = md.custom_schema_metadata.take(); + + let reader = StreamReader::new(v, md, None); + let dfs = reader + .into_iter() + .map_while(|batch| match batch { + Ok(StreamState::Some(batch)) => { + Some(DataFrame::try_from((batch, &arrow_schema))) }, - fld => return Err(de::Error::unknown_field(fld, FIELDS)), - } - } - if !values_set { - return Err(de::Error::missing_field("values")); - } - let name = name.ok_or_else(|| de::Error::missing_field("name"))?; - let name = PlSmallStr::from_str(name.as_ref()); - let dtype = dtype.ok_or_else(|| de::Error::missing_field("datatype"))?; + Ok(StreamState::Waiting) => None, + Err(e) => Some(Err(e)), + }) + .collect::<PolarsResult<Vec<DataFrame>>>() + .map_err(E::custom)?; - let mut s = match dtype { - #[cfg(feature = "dtype-i8")] - DataType::Int8 => { - let values: Vec<Option<i8>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - #[cfg(feature = "dtype-u8")] - DataType::UInt8 => { - let values: Vec<Option<u8>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - #[cfg(feature = "dtype-i16")] - DataType::Int16 => { - let values: Vec<Option<i16>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - #[cfg(feature = "dtype-u16")] - DataType::UInt16 => { - let values: Vec<Option<u16>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - DataType::Int32 => { - let values: Vec<Option<i32>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - DataType::UInt32 => { - let values: Vec<Option<u32>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - DataType::Int64 => { - let values: Vec<Option<i64>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - DataType::UInt64 => { - let values: Vec<Option<u64>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - #[cfg(feature = "dtype-date")] - DataType::Date => { - let values: Vec<Option<i32>> = map.next_value()?; - Ok(Series::new(name, values).cast(&DataType::Date).unwrap()) - }, - #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => { - let values: Vec<Option<i64>> = map.next_value()?; - Ok(Series::new(name, values) - .cast(&DataType::Datetime(tu, tz)) - .unwrap()) - }, - #[cfg(feature = "dtype-duration")] - DataType::Duration(tu) => { - let values: Vec<Option<i64>> = map.next_value()?; - Ok(Series::new(name, values) - .cast(&DataType::Duration(tu)) - .unwrap()) - }, - #[cfg(feature = "dtype-time")] - DataType::Time => { - let values: Vec<Option<i64>> = map.next_value()?; - Ok(Series::new(name, values).cast(&DataType::Time).unwrap()) - }, - #[cfg(feature = "dtype-decimal")] - DataType::Decimal(precision, Some(scale)) => { - let values: Vec<Option<i128>> = map.next_value()?; - Ok(ChunkedArray::from_slice_options(name, &values) - .into_decimal_unchecked(precision, scale) - .into_series()) - }, - DataType::Boolean => { - let values: Vec<Option<bool>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - DataType::Float32 => { - let values: Vec<Option<f32>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - DataType::Float64 => { - let values: Vec<Option<f64>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - DataType::String => { - let values: Vec<Option<Cow<str>>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - DataType::List(inner) => { - let values: Vec<Option<Series>> = map.next_value()?; - let mut lb = AnonymousListBuilder::new(name, values.len(), Some(*inner)); - for value in &values { - lb.append_opt_series(value.as_ref()).map_err(|e| { - de::Error::custom(format!("could not append series to list: {e}")) - })?; - } - Ok(lb.finish().into_series()) - }, - #[cfg(feature = "dtype-array")] - DataType::Array(inner, width) => { - let values: Vec<Option<Series>> = map.next_value()?; - let mut builder = - get_fixed_size_list_builder(&inner, values.len(), width, name) - .map_err(|e| { - de::Error::custom(format!( - "could not get supported list builder: {e}" - )) - })?; - for value in &values { - if let Some(s) = value { - // we only have one chunk per series as we serialize it in this way. - let arr = &s.chunks()[0]; - // SAFETY, we are within bounds - unsafe { - builder.push_unchecked(arr.as_ref(), 0); - } - } else { - // SAFETY, we are within bounds - unsafe { - builder.push_null(); - } - } - } - Ok(builder.finish().into_series()) - }, - DataType::Binary => { - let values: Vec<Option<Cow<[u8]>>> = map.next_value()?; - Ok(Series::new(name, values)) - }, - #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) => { - let length = length.ok_or_else(|| de::Error::missing_field("length"))?; - let values: Vec<Series> = map.next_value()?; + let df = accumulate_dataframes_vertical(dfs).map_err(E::custom)?; - if fields.len() != values.len() { - let expected = format!("expected {} value series", fields.len()); - let expected = expected.as_str(); - return Err(de::Error::invalid_length(values.len(), &expected)); - } + if df.width() != 1 { + return Err(polars_err!( + ShapeMismatch: + "expected only 1 column when deserializing Series from IPC, got columns: {:?}", + df.schema().iter_names().collect::<Vec<_>>() + )).map_err(E::custom); + } + + let mut s = df.take_columns().swap_remove(0).take_materialized_series(); - for (f, v) in fields.iter().zip(values.iter()) { - if f.dtype() != v.dtype() { - let err = format!( - "type mismatch for struct. expected: {}, given: {}", - f.dtype(), - v.dtype() - ); - return Err(de::Error::custom(err)); + if let Some(custom_metadata) = custom_metadata { + if let Some(flags) = custom_metadata.get(&FLAGS_KEY) { + if let Ok(v) = flags.parse::<u8>() { + if let Some(flags) = MetadataFlags::from_bits(v) { + s.set_flags(flags); } + } else if config::verbose() { + eprintln!("Series::Deserialize: Failed to parse as u8: {:?}", flags) } - - let ca = StructChunked::from_series(name.clone(), length, values.iter()) - .unwrap(); - let mut s = ca.into_series(); - s.rename(name); - Ok(s) - }, - #[cfg(feature = "dtype-categorical")] - dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { - let values: Vec<Option<Cow<str>>> = map.next_value()?; - Ok(Series::new(name, values).cast(&dt).unwrap()) - }, - DataType::Null => { - let values: Vec<usize> = map.next_value()?; - let len = values.first().unwrap(); - Ok(Series::new_null(name, *len)) - }, - dt => Err(A::Error::custom(format!( - "deserializing data of type {dt} is not supported" - ))), - }?; - - if let Some(f) = bit_settings { - s.set_flags(f) + } } + Ok(s) } + + fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error> + where + A: serde::de::SeqAccess<'de>, + { + // This is not ideal, but we hit here if the serialization format is JSON. + let bytes = std::iter::from_fn(|| seq.next_element::<u8>().transpose()) + .collect::<Result<Vec<_>, A::Error>>()?; + + self.visit_bytes(&bytes) + } } - deserializer.deserialize_map(SeriesVisitor) + deserializer.deserialize_bytes(SeriesVisitor) } } diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 73cfeb50a730..f6547ed249a7 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -27,10 +27,9 @@ ahash = { workspace = true } arrow = { workspace = true } bitflags = { workspace = true } bytemuck = { workspace = true } -bytes = { workspace = true } +bytes = { workspace = true, features = ["serde"] } chrono = { workspace = true, optional = true } chrono-tz = { workspace = true, optional = true } -ciborium = { workspace = true, optional = true } either = { workspace = true } futures = { workspace = true, optional = true } hashbrown = { workspace = true } @@ -52,8 +51,9 @@ version_check = { workspace = true } [features] # debugging utility debugging = [] -python = ["dep:pyo3", "ciborium", "polars-utils/python"] +python = ["dep:pyo3", "polars-utils/python"] serde = [ + "ir_serde", "dep:serde", "polars-core/serde-lazy", "polars-time/serde", @@ -189,7 +189,7 @@ month_end = ["polars-time/month_end"] offset_by = ["polars-time/offset_by"] bigidx = ["polars-core/bigidx", "polars-utils/bigidx"] -polars_cloud = ["serde", "ciborium"] +polars_cloud = ["serde"] ir_serde = ["serde", "polars-utils/ir_serde"] panic_on_schema = [] @@ -279,7 +279,6 @@ features = [ "replace", "dtype-u16", "regex", - "ciborium", "dtype-decimal", "arg_where", "business", diff --git a/crates/polars-plan/src/client/mod.rs b/crates/polars-plan/src/client/mod.rs index f5a5cdb0f763..c42e481cd1f6 100644 --- a/crates/polars-plan/src/client/mod.rs +++ b/crates/polars-plan/src/client/mod.rs @@ -1,7 +1,7 @@ mod check; -use arrow::legacy::error::to_compute_err; use polars_core::error::PolarsResult; +use polars_utils::pl_serialize; use crate::plans::DslPlan; @@ -12,7 +12,9 @@ pub fn prepare_cloud_plan(dsl: DslPlan) -> PolarsResult<Vec<u8>> { // Serialize the plan. let mut writer = Vec::new(); - ciborium::into_writer(&dsl, &mut writer).map_err(to_compute_err)?; + pl_serialize::SerializeOptions::default() + .with_compression(true) + .serialize_into_writer(&mut writer, &dsl)?; Ok(writer) } diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index cb7b834627a2..08f6be59c5b9 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -51,6 +51,8 @@ impl PythonUdfExpression { #[cfg(feature = "serde")] pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> { // Handle byte mark + + use polars_utils::pl_serialize; debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK)); let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..]; @@ -72,7 +74,7 @@ impl PythonUdfExpression { // Load UDF metadata let mut reader = Cursor::new(buf); let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) = - ciborium::de::from_reader(&mut reader).map_err(map_err)?; + pl_serialize::deserialize_from_reader(&mut reader)?; let remainder = &buf[reader.position() as usize..]; @@ -132,6 +134,8 @@ impl ColumnsUdf for PythonUdfExpression { #[cfg(feature = "serde")] fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> { // Write byte marks + + use polars_utils::pl_serialize; buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK); Python::with_gil(|py| { @@ -160,15 +164,14 @@ impl ColumnsUdf for PythonUdfExpression { buf.extend_from_slice(&*PYTHON3_VERSION); // Write UDF metadata - ciborium::ser::into_writer( + pl_serialize::serialize_into_writer( + &mut *buf, &( self.output_type.clone(), self.is_elementwise, self.returns_scalar, ), - &mut *buf, - ) - .unwrap(); + )?; // Write UDF let dumped = dumped.extract::<PyBackedBytes>().unwrap(); @@ -191,12 +194,13 @@ impl PythonGetOutput { #[cfg(feature = "serde")] pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> { // Skip header. + + use polars_utils::pl_serialize; debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK)); let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..]; let mut reader = Cursor::new(buf); - let return_dtype: Option<DataType> = - ciborium::de::from_reader(&mut reader).map_err(map_err)?; + let return_dtype: Option<DataType> = pl_serialize::deserialize_from_reader(&mut reader)?; Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>) } @@ -220,9 +224,10 @@ impl FunctionOutputField for PythonGetOutput { #[cfg(feature = "serde")] fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> { + use polars_utils::pl_serialize; + buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK); - ciborium::ser::into_writer(&self.return_dtype, &mut *buf).unwrap(); - Ok(()) + pl_serialize::serialize_into_writer(&mut *buf, &self.return_dtype) } } diff --git a/crates/polars-plan/src/plans/file_scan.rs b/crates/polars-plan/src/plans/file_scan.rs index 655907e3f666..d4b9e2b14666 100644 --- a/crates/polars-plan/src/plans/file_scan.rs +++ b/crates/polars-plan/src/plans/file_scan.rs @@ -19,6 +19,11 @@ pub enum FileScan { options: CsvReadOptions, cloud_options: Option<polars_io::cloud::CloudOptions>, }, + #[cfg(feature = "json")] + NDJson { + options: NDJsonReadOptions, + cloud_options: Option<polars_io::cloud::CloudOptions>, + }, #[cfg(feature = "parquet")] Parquet { options: ParquetOptions, @@ -33,11 +38,6 @@ pub enum FileScan { #[cfg_attr(feature = "serde", serde(skip))] metadata: Option<arrow::io::ipc::read::FileMetadata>, }, - #[cfg(feature = "json")] - NDJson { - options: NDJsonReadOptions, - cloud_options: Option<polars_io::cloud::CloudOptions>, - }, #[cfg_attr(feature = "serde", serde(skip))] Anonymous { options: Arc<AnonymousScanOptions>, diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index f1aa33a7e7dd..ec95067be4be 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -21,9 +21,10 @@ pub struct OpaquePythonUdf { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[strum(serialize_all = "SCREAMING_SNAKE_CASE")] pub enum DslFunction { - // Function that is already converted to IR. - #[cfg_attr(feature = "serde", serde(skip))] - FunctionIR(FunctionIR), + RowIndex { + name: PlSmallStr, + offset: Option<IdxSize>, + }, // This is both in DSL and IR because we want to be able to serialize it. #[cfg(feature = "python")] OpaquePython(OpaquePythonUdf), @@ -35,10 +36,6 @@ pub enum DslFunction { Unpivot { args: UnpivotArgsDSL, }, - RowIndex { - name: PlSmallStr, - offset: Option<IdxSize>, - }, Rename { existing: Arc<[PlSmallStr]>, new: Arc<[PlSmallStr]>, @@ -49,6 +46,9 @@ pub enum DslFunction { /// FillValue FillNan(Expr), Drop(DropFunction), + // Function that is already converted to IR. + #[cfg_attr(feature = "serde", serde(skip))] + FunctionIR(FunctionIR), } #[derive(Clone)] diff --git a/crates/polars-plan/src/plans/functions/mod.rs b/crates/polars-plan/src/plans/functions/mod.rs index 61cce46de9af..a357d5b269ad 100644 --- a/crates/polars-plan/src/plans/functions/mod.rs +++ b/crates/polars-plan/src/plans/functions/mod.rs @@ -31,32 +31,22 @@ use crate::prelude::*; #[derive(Clone, IntoStaticStr)] #[strum(serialize_all = "SCREAMING_SNAKE_CASE")] pub enum FunctionIR { + RowIndex { + name: PlSmallStr, + offset: Option<IdxSize>, + // Might be cached. + #[cfg_attr(feature = "ir_serde", serde(skip))] + schema: CachedSchema, + }, #[cfg(feature = "python")] OpaquePython(OpaquePythonUdf), - #[cfg_attr(feature = "ir_serde", serde(skip))] - Opaque { - function: Arc<dyn DataFrameUdf>, - schema: Option<Arc<dyn UdfSchema>>, - /// allow predicate pushdown optimizations - predicate_pd: bool, - /// allow projection pushdown optimizations - projection_pd: bool, - streamable: bool, - // used for formatting - fmt_str: PlSmallStr, - }, + FastCount { sources: ScanSources, scan_type: FileScan, alias: Option<PlSmallStr>, }, - /// Streaming engine pipeline - #[cfg_attr(feature = "ir_serde", serde(skip))] - Pipeline { - function: Arc<Mutex<dyn DataFrameUdfMut>>, - schema: SchemaRef, - original: Option<Arc<IRPlan>>, - }, + Unnest { columns: Arc<[PlSmallStr]>, }, @@ -89,12 +79,24 @@ pub enum FunctionIR { #[cfg_attr(feature = "ir_serde", serde(skip))] schema: CachedSchema, }, - RowIndex { - name: PlSmallStr, - // Might be cached. - #[cfg_attr(feature = "ir_serde", serde(skip))] - schema: CachedSchema, - offset: Option<IdxSize>, + #[cfg_attr(feature = "ir_serde", serde(skip))] + Opaque { + function: Arc<dyn DataFrameUdf>, + schema: Option<Arc<dyn UdfSchema>>, + /// allow predicate pushdown optimizations + predicate_pd: bool, + /// allow projection pushdown optimizations + projection_pd: bool, + streamable: bool, + // used for formatting + fmt_str: PlSmallStr, + }, + /// Streaming engine pipeline + #[cfg_attr(feature = "ir_serde", serde(skip))] + Pipeline { + function: Arc<Mutex<dyn DataFrameUdfMut>>, + schema: SchemaRef, + original: Option<Arc<IRPlan>>, }, } diff --git a/crates/polars-plan/src/plans/ir/mod.rs b/crates/polars-plan/src/plans/ir/mod.rs index 800e28dbea1a..6e44abe7fe7c 100644 --- a/crates/polars-plan/src/plans/ir/mod.rs +++ b/crates/polars-plan/src/plans/ir/mod.rs @@ -112,10 +112,10 @@ pub enum IR { keys: Vec<ExprIR>, aggs: Vec<ExprIR>, schema: SchemaRef, - #[cfg_attr(feature = "ir_serde", serde(skip))] - apply: Option<Arc<dyn DataFrameUdf>>, maintain_order: bool, options: Arc<GroupbyOptions>, + #[cfg_attr(feature = "ir_serde", serde(skip))] + apply: Option<Arc<dyn DataFrameUdf>>, }, Join { input_left: Node, diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index 76599a6e977a..efb01919a15a 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -97,10 +97,10 @@ pub enum DslPlan { input: Arc<DslPlan>, keys: Vec<Expr>, aggs: Vec<Expr>, - #[cfg_attr(feature = "serde", serde(skip))] - apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>, maintain_order: bool, options: Arc<GroupbyOptions>, + #[cfg_attr(feature = "serde", serde(skip))] + apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>, }, /// Join operation Join { @@ -162,11 +162,11 @@ pub enum DslPlan { payload: SinkType, }, IR { - #[cfg_attr(feature = "serde", serde(skip))] - node: Option<Node>, - version: u32, // Keep the original Dsl around as we need that for serialization. dsl: Arc<DslPlan>, + version: u32, + #[cfg_attr(feature = "serde", serde(skip))] + node: Option<Node>, }, } diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index 6eff59a06680..4cb3ae064b75 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -212,9 +212,6 @@ pub struct FunctionOptions { /// Collect groups to a list and apply the function over the groups. /// This can be important in aggregation context. pub collect_groups: ApplyOptions, - // used for formatting, (only for anonymous functions) - #[cfg_attr(feature = "serde", serde(skip_deserializing))] - pub fmt_str: &'static str, /// Options used when deciding how to cast the arguments of the function. pub cast_options: FunctionCastOptions, @@ -223,6 +220,10 @@ pub struct FunctionOptions { // this should always be true or we could OOB pub check_lengths: UnsafeBool, pub flags: FunctionFlags, + + // used for formatting, (only for anonymous functions) + #[cfg_attr(feature = "serde", serde(skip))] + pub fmt_str: &'static str, } impl FunctionOptions { @@ -281,11 +282,10 @@ pub struct PythonOptions { pub with_columns: Option<Arc<[PlSmallStr]>>, // Which interface is the python function. pub python_source: PythonScanSource, - /// Optional predicate the reader must apply. - #[cfg_attr(feature = "serde", serde(skip))] - pub predicate: PythonPredicate, /// A `head` call passed to the reader. pub n_rows: Option<usize>, + /// Optional predicate the reader must apply. + pub predicate: PythonPredicate, } #[derive(Clone, PartialEq, Eq, Debug, Default)] @@ -298,6 +298,7 @@ pub enum PythonScanSource { } #[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum PythonPredicate { // A pyarrow predicate python expression // can be evaluated with python.eval diff --git a/crates/polars-plan/src/plans/python/predicate.rs b/crates/polars-plan/src/plans/python/predicate.rs index 2e4a21af2749..55b55388ba28 100644 --- a/crates/polars-plan/src/plans/python/predicate.rs +++ b/crates/polars-plan/src/plans/python/predicate.rs @@ -1,5 +1,5 @@ -use polars_core::error::polars_err; use polars_core::prelude::PolarsResult; +use polars_utils::pl_serialize; use crate::prelude::*; @@ -62,8 +62,7 @@ pub fn serialize(expr: &Expr) -> PolarsResult<Option<Vec<u8>>> { return Ok(None); } let mut buf = vec![]; - ciborium::into_writer(expr, &mut buf) - .map_err(|_| polars_err!(ComputeError: "could not serialize: {}", expr))?; + pl_serialize::serialize_into_writer(&mut buf, expr)?; Ok(Some(buf)) } diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml index 8a56d853b4e8..13a2164ed560 100644 --- a/crates/polars-python/Cargo.toml +++ b/crates/polars-python/Cargo.toml @@ -27,10 +27,11 @@ polars-utils = { workspace = true } ahash = { workspace = true } arboard = { workspace = true, optional = true } +bincode = { workspace = true } bytemuck = { workspace = true } bytes = { workspace = true } -ciborium = { workspace = true } either = { workspace = true } +flate2 = { workspace = true } itoa = { workspace = true } libc = { workspace = true } ndarray = { workspace = true } diff --git a/crates/polars-python/src/cloud.rs b/crates/polars-python/src/cloud.rs index 08379da8e955..372e260b1d1b 100644 --- a/crates/polars-python/src/cloud.rs +++ b/crates/polars-python/src/cloud.rs @@ -1,10 +1,9 @@ -use std::io::Cursor; - -use polars_core::error::{polars_err, to_compute_err, PolarsResult}; +use polars_core::error::{polars_err, PolarsResult}; use polars_expr::state::ExecutionState; use polars_mem_engine::create_physical_plan; use polars_plan::plans::{AExpr, IRPlan, IR}; use polars_plan::prelude::{Arena, Node}; +use polars_utils::pl_serialize; use pyo3::intern; use pyo3::prelude::{PyAnyMethods, PyModule, Python, *}; use pyo3::types::{IntoPyDict, PyBytes}; @@ -28,10 +27,8 @@ pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python<'_>) -> PyResult<Bound<'_, #[pyfunction] pub fn _execute_ir_plan_with_gpu(ir_plan_ser: Vec<u8>, py: Python) -> PyResult<PyDataFrame> { // Deserialize into IRPlan. - let reader = Cursor::new(ir_plan_ser); - let mut ir_plan = ciborium::from_reader::<IRPlan, _>(reader) - .map_err(to_compute_err) - .map_err(PyPolarsErr::from)?; + let mut ir_plan: IRPlan = + pl_serialize::deserialize_from_reader(ir_plan_ser.as_slice()).map_err(PyPolarsErr::from)?; // Edit for use with GPU engine. gpu_post_opt( diff --git a/crates/polars-python/src/dataframe/serde.rs b/crates/polars-python/src/dataframe/serde.rs index c421dee342b7..48dd22fdc0d6 100644 --- a/crates/polars-python/src/dataframe/serde.rs +++ b/crates/polars-python/src/dataframe/serde.rs @@ -1,8 +1,9 @@ -use std::io::{BufReader, BufWriter, Cursor}; +use std::io::{BufReader, BufWriter}; use std::ops::Deref; use polars::prelude::*; use polars_io::mmap::ReaderBytes; +use polars_utils::pl_serialize; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; @@ -17,29 +18,26 @@ impl PyDataFrame { #[cfg(feature = "ipc_streaming")] fn __getstate__<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { // Used in pickle/pickling - let mut buf: Vec<u8> = vec![]; - IpcStreamWriter::new(&mut buf) - .with_compat_level(CompatLevel::newest()) - .finish(&mut self.df.clone()) - .expect("ipc writer"); - PyBytes::new(py, &buf) + PyBytes::new( + py, + &pl_serialize::SerializeOptions::default() + .with_compression(true) + .serialize_to_bytes(&self.df) + .unwrap(), + ) } #[cfg(feature = "ipc_streaming")] fn __setstate__(&mut self, state: &Bound<PyAny>) -> PyResult<()> { // Used in pickle/pickling match state.extract::<PyBackedBytes>() { - Ok(s) => { - let c = Cursor::new(&*s); - let reader = IpcStreamReader::new(c); - - reader - .finish() - .map(|df| { - self.df = df; - }) - .map_err(|e| PyPolarsErr::from(e).into()) - }, + Ok(s) => pl_serialize::SerializeOptions::default() + .with_compression(true) + .deserialize_from_reader(&*s) + .map(|df| { + self.df = df; + }) + .map_err(|e| PyPolarsErr::from(e).into()), Err(e) => Err(e), } } @@ -48,7 +46,9 @@ impl PyDataFrame { fn serialize_binary(&self, py_f: PyObject) -> PyResult<()> { let file = get_file_like(py_f, true)?; let writer = BufWriter::new(file); - ciborium::into_writer(&self.df, writer) + pl_serialize::SerializeOptions::default() + .with_compression(true) + .serialize_into_writer(writer, &self.df) .map_err(|err| ComputeError::new_err(err.to_string())) } @@ -65,8 +65,10 @@ impl PyDataFrame { #[staticmethod] fn deserialize_binary(py_f: PyObject) -> PyResult<Self> { let file = get_file_like(py_f, false)?; - let reader = BufReader::new(file); - let df = ciborium::from_reader::<DataFrame, _>(reader) + let file = BufReader::new(file); + let df: DataFrame = pl_serialize::SerializeOptions::default() + .with_compression(true) + .deserialize_from_reader(file) .map_err(|err| ComputeError::new_err(err.to_string()))?; Ok(df.into()) } diff --git a/crates/polars-python/src/expr/serde.rs b/crates/polars-python/src/expr/serde.rs index 0900c40e0722..08685baed417 100644 --- a/crates/polars-python/src/expr/serde.rs +++ b/crates/polars-python/src/expr/serde.rs @@ -1,6 +1,7 @@ use std::io::{BufReader, BufWriter, Cursor}; use polars::lazy::prelude::Expr; +use polars_utils::pl_serialize; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; @@ -15,7 +16,9 @@ impl PyExpr { fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> { // Used in pickle/pickling let mut writer: Vec<u8> = vec![]; - ciborium::ser::into_writer(&self.inner, &mut writer) + pl_serialize::SerializeOptions::default() + .with_compression(true) + .serialize_into_writer(&mut writer, &self.inner) .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?; Ok(PyBytes::new(py, &writer)) @@ -23,10 +26,13 @@ impl PyExpr { fn __setstate__(&mut self, state: &Bound<PyAny>) -> PyResult<()> { // Used in pickle/pickling + let bytes = state.extract::<PyBackedBytes>()?; let cursor = Cursor::new(&*bytes); - self.inner = - ciborium::de::from_reader(cursor).map_err(|e| PyPolarsErr::Other(format!("{}", e)))?; + self.inner = pl_serialize::SerializeOptions::default() + .with_compression(true) + .deserialize_from_reader(cursor) + .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?; Ok(()) } @@ -34,7 +40,9 @@ impl PyExpr { fn serialize_binary(&self, py_f: PyObject) -> PyResult<()> { let file = get_file_like(py_f, true)?; let writer = BufWriter::new(file); - ciborium::into_writer(&self.inner, writer) + pl_serialize::SerializeOptions::default() + .with_compression(true) + .serialize_into_writer(writer, &self.inner) .map_err(|err| ComputeError::new_err(err.to_string())) } @@ -52,7 +60,9 @@ impl PyExpr { fn deserialize_binary(py_f: PyObject) -> PyResult<PyExpr> { let file = get_file_like(py_f, false)?; let reader = BufReader::new(file); - let expr = ciborium::from_reader::<Expr, _>(reader) + let expr: Expr = pl_serialize::SerializeOptions::default() + .with_compression(true) + .deserialize_from_reader(reader) .map_err(|err| ComputeError::new_err(err.to_string()))?; Ok(expr.into()) } diff --git a/crates/polars-python/src/file.rs b/crates/polars-python/src/file.rs index 29b1df01cecc..741c5c695152 100644 --- a/crates/polars-python/src/file.rs +++ b/crates/polars-python/src/file.rs @@ -386,8 +386,7 @@ fn read_if_bytesio(py_f: Bound<PyAny>) -> Bound<PyAny> { py_f } -/// Create reader from PyBytes or a file-like object. To get BytesIO to have -/// better performance, use read_if_bytesio() before calling this. +/// Create reader from PyBytes or a file-like object. pub fn get_mmap_bytes_reader(py_f: &Bound<PyAny>) -> PyResult<Box<dyn MmapBytesReader>> { get_mmap_bytes_reader_and_path(py_f).map(|t| t.0) } diff --git a/crates/polars-python/src/lazyframe/serde.rs b/crates/polars-python/src/lazyframe/serde.rs index 82e2d4c87f5d..2164785ddb24 100644 --- a/crates/polars-python/src/lazyframe/serde.rs +++ b/crates/polars-python/src/lazyframe/serde.rs @@ -1,5 +1,6 @@ use std::io::{BufReader, BufWriter}; +use polars_utils::pl_serialize; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; @@ -16,7 +17,10 @@ impl PyLazyFrame { fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> { // Used in pickle/pickling let mut writer: Vec<u8> = vec![]; - ciborium::ser::into_writer(&self.ldf.logical_plan, &mut writer) + + pl_serialize::SerializeOptions::default() + .with_compression(true) + .serialize_into_writer(&mut writer, &self.ldf.logical_plan) .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?; Ok(PyBytes::new(py, &writer)) @@ -26,7 +30,9 @@ impl PyLazyFrame { // Used in pickle/pickling match state.extract::<PyBackedBytes>(py) { Ok(s) => { - let lp: DslPlan = ciborium::de::from_reader(&*s) + let lp: DslPlan = pl_serialize::SerializeOptions::default() + .with_compression(true) + .deserialize_from_reader(&*s) .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?; self.ldf = LazyFrame::from(lp); Ok(()) @@ -39,7 +45,9 @@ impl PyLazyFrame { fn serialize_binary(&self, py_f: PyObject) -> PyResult<()> { let file = get_file_like(py_f, true)?; let writer = BufWriter::new(file); - ciborium::into_writer(&self.ldf.logical_plan, writer) + pl_serialize::SerializeOptions::default() + .with_compression(true) + .serialize_into_writer(writer, &self.ldf.logical_plan) .map_err(|err| ComputeError::new_err(err.to_string())) } @@ -57,7 +65,9 @@ impl PyLazyFrame { fn deserialize_binary(py_f: PyObject) -> PyResult<Self> { let file = get_file_like(py_f, false)?; let reader = BufReader::new(file); - let lp = ciborium::from_reader::<DslPlan, _>(reader) + let lp: DslPlan = pl_serialize::SerializeOptions::default() + .with_compression(true) + .deserialize_from_reader(reader) .map_err(|err| ComputeError::new_err(err.to_string()))?; Ok(LazyFrame::from(lp).into()) } diff --git a/crates/polars-python/src/series/general.rs b/crates/polars-python/src/series/general.rs index 523f699af7c0..a3a80505596f 100644 --- a/crates/polars-python/src/series/general.rs +++ b/crates/polars-python/src/series/general.rs @@ -1,9 +1,8 @@ -use std::io::Cursor; - use polars_core::chunked_array::cast::CastOptions; use polars_core::series::IsSorted; use polars_core::utils::flatten::flatten_series; use polars_row::RowEncodingOptions; +use polars_utils::pl_serialize; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyBytes; @@ -389,40 +388,30 @@ impl PySeries { Wrap(result).into_pyobject(py) } - #[cfg(feature = "ipc_streaming")] fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> { // Used in pickle/pickling let mut buf: Vec<u8> = vec![]; - // IPC only support DataFrames so we need to convert it - let mut df = self.series.clone().into_frame(); - IpcStreamWriter::new(&mut buf) - .with_compat_level(CompatLevel::newest()) - .finish(&mut df) - .expect("ipc writer"); + + pl_serialize::SerializeOptions::default() + .with_compression(true) + .serialize_into_writer(&mut buf, &self.series) + .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?; + Ok(PyBytes::new(py, &buf)) } - #[cfg(feature = "ipc_streaming")] fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { // Used in pickle/pickling use pyo3::pybacked::PyBackedBytes; match state.extract::<PyBackedBytes>(py) { Ok(s) => { - let c = Cursor::new(&s); - let reader = IpcStreamReader::new(c); - let mut df = reader.finish().map_err(PyPolarsErr::from)?; - - df.pop() - .map(|s| { - self.series = s.take_materialized_series(); - }) - .ok_or_else(|| { - PyPolarsErr::from(PolarsError::NoData( - "No columns found in IPC byte stream".into(), - )) - .into() - }) + let s: Series = pl_serialize::SerializeOptions::default() + .with_compression(true) + .deserialize_from_reader(&*s) + .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?; + self.series = s; + Ok(()) }, Err(e) => Err(e), } diff --git a/crates/polars-utils/Cargo.toml b/crates/polars-utils/Cargo.toml index b21f2a5a04e6..2fb92f49e7f7 100644 --- a/crates/polars-utils/Cargo.toml +++ b/crates/polars-utils/Cargo.toml @@ -12,9 +12,11 @@ description = "Private utils for the Polars DataFrame library" polars-error = { workspace = true } ahash = { workspace = true } +bincode = { workspace = true, optional = true } bytemuck = { workspace = true } bytes = { workspace = true } compact_str = { workspace = true } +flate2 = { workspace = true, default-features = true, optional = true } hashbrown = { workspace = true } indexmap = { workspace = true } libc = { workspace = true } @@ -26,6 +28,7 @@ rand = { workspace = true } raw-cpuid = { workspace = true } rayon = { workspace = true } serde = { workspace = true, optional = true } +serde_json = { workspace = true, optional = true } stacker = { workspace = true } sysinfo = { version = "0.32", default-features = false, features = ["system"], optional = true } @@ -40,5 +43,5 @@ mmap = ["memmap"] bigidx = [] nightly = [] ir_serde = ["serde"] -serde = ["dep:serde", "serde/derive"] +serde = ["dep:serde", "serde/derive", "dep:bincode", "dep:flate2", "dep:serde_json"] python = ["pyo3"] diff --git a/crates/polars-utils/src/config.rs b/crates/polars-utils/src/config.rs new file mode 100644 index 000000000000..837d1126e323 --- /dev/null +++ b/crates/polars-utils/src/config.rs @@ -0,0 +1,3 @@ +pub(crate) fn verbose() -> bool { + std::env::var("POLARS_VERBOSE").as_deref().unwrap_or("") == "1" +} diff --git a/crates/polars-utils/src/io.rs b/crates/polars-utils/src/io.rs index cce2eafe22a7..563f8c399c27 100644 --- a/crates/polars-utils/src/io.rs +++ b/crates/polars-utils/src/io.rs @@ -4,9 +4,7 @@ use std::path::Path; use polars_error::*; -fn verbose() -> bool { - std::env::var("POLARS_VERBOSE").as_deref().unwrap_or("") == "1" -} +use crate::config::verbose; pub fn _limit_path_len_io_err(path: &Path, err: io::Error) -> PolarsError { let path = path.to_string_lossy(); diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index 75d31caba5f8..ba6da732d76c 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -14,6 +14,7 @@ pub mod cardinality_sketch; pub mod cell; pub mod chunks; pub mod clmul; +mod config; pub mod cpuid; mod error; pub mod floor_divmod; @@ -57,3 +58,6 @@ pub use io::*; #[cfg(feature = "python")] pub mod python_function; + +#[cfg(feature = "serde")] +pub mod pl_serialize; diff --git a/crates/polars-utils/src/pl_serialize.rs b/crates/polars-utils/src/pl_serialize.rs new file mode 100644 index 000000000000..d0ab267daa3a --- /dev/null +++ b/crates/polars-utils/src/pl_serialize.rs @@ -0,0 +1,140 @@ +use polars_error::{to_compute_err, PolarsResult}; + +fn serialize_impl<W, T>(writer: W, value: &T) -> PolarsResult<()> +where + W: std::io::Write, + T: serde::ser::Serialize, +{ + bincode::serialize_into(writer, value).map_err(to_compute_err) +} + +pub fn deserialize_impl<T, R>(reader: R) -> PolarsResult<T> +where + T: serde::de::DeserializeOwned, + R: std::io::Read, +{ + bincode::deserialize_from(reader).map_err(to_compute_err) +} + +/// Mainly used to enable compression when serializing the final outer value. +/// For intermediate serialization steps, the function in the module should +/// be used instead. +pub struct SerializeOptions { + compression: bool, +} + +impl SerializeOptions { + pub fn with_compression(mut self, compression: bool) -> Self { + self.compression = compression; + self + } + + pub fn serialize_into_writer<W, T>(&self, writer: W, value: &T) -> PolarsResult<()> + where + W: std::io::Write, + T: serde::ser::Serialize, + { + if self.compression { + let writer = flate2::write::ZlibEncoder::new(writer, flate2::Compression::fast()); + serialize_impl(writer, value) + } else { + serialize_impl(writer, value) + } + } + + pub fn deserialize_from_reader<T, R>(&self, reader: R) -> PolarsResult<T> + where + T: serde::de::DeserializeOwned, + R: std::io::Read, + { + if self.compression { + deserialize_impl(flate2::read::ZlibDecoder::new(reader)) + } else { + deserialize_impl(reader) + } + } + + pub fn serialize_to_bytes<T>(&self, value: &T) -> PolarsResult<Vec<u8>> + where + T: serde::ser::Serialize, + { + let mut v = vec![]; + + self.serialize_into_writer(&mut v, value)?; + + Ok(v) + } +} + +#[allow(clippy::derivable_impls)] +impl Default for SerializeOptions { + fn default() -> Self { + Self { compression: false } + } +} + +pub fn serialize_into_writer<W, T>(writer: W, value: &T) -> PolarsResult<()> +where + W: std::io::Write, + T: serde::ser::Serialize, +{ + serialize_impl(writer, value) +} + +pub fn deserialize_from_reader<T, R>(reader: R) -> PolarsResult<T> +where + T: serde::de::DeserializeOwned, + R: std::io::Read, +{ + deserialize_impl(reader) +} + +pub fn serialize_to_bytes<T>(value: &T) -> PolarsResult<Vec<u8>> +where + T: serde::ser::Serialize, +{ + let mut v = vec![]; + + serialize_into_writer(&mut v, value)?; + + Ok(v) +} + +#[cfg(test)] +mod tests { + #[test] + fn test_serde_skip_enum() { + #[derive(Default, Debug, PartialEq)] + struct MyType(Option<usize>); + + // Note: serde(skip) must be at the end of enums + #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] + enum Enum { + A, + #[serde(skip)] + B(MyType), + } + + impl Default for Enum { + fn default() -> Self { + Self::B(MyType(None)) + } + } + + let v = Enum::A; + let b = super::serialize_to_bytes(&v).unwrap(); + let r: Enum = super::deserialize_from_reader(b.as_slice()).unwrap(); + + assert_eq!(r, v); + + let v = Enum::A; + let b = super::SerializeOptions::default() + .serialize_to_bytes(&v) + .unwrap(); + let r: Enum = super::SerializeOptions::default() + .deserialize_from_reader(b.as_slice()) + .unwrap(); + + assert_eq!(r, v); + } +} diff --git a/crates/polars-utils/src/python_function.rs b/crates/polars-utils/src/python_function.rs index a3f20e8717db..8fc69a774f7c 100644 --- a/crates/polars-utils/src/python_function.rs +++ b/crates/polars-utils/src/python_function.rs @@ -45,11 +45,11 @@ impl serde::Serialize for PythonFunction { S: serde::Serializer, { use serde::ser::Error; - serializer.serialize_bytes( - self.try_serialize_to_bytes() - .map_err(|e| S::Error::custom(e.to_string()))? - .as_slice(), - ) + let bytes = self + .try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))?; + + Vec::<u8>::serialize(&bytes, serializer) } } @@ -61,7 +61,9 @@ impl<'a> serde::Deserialize<'a> for PythonFunction { { use serde::de::Error; let bytes = Vec::<u8>::deserialize(deserializer)?; - Self::try_deserialize_bytes(bytes.as_slice()).map_err(|e| D::Error::custom(e.to_string())) + let v = Self::try_deserialize_bytes(bytes.as_slice()) + .map_err(|e| D::Error::custom(e.to_string())); + v } } diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 28bab4dc8c0f..a30ba30c82be 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -64,7 +64,7 @@ default = [ ] ndarray = ["polars-core/ndarray"] # serde support for dataframes and series -serde = ["polars-core/serde", "polars-utils/serde"] +serde = ["polars-core/serde", "polars-utils/serde", "ir_serde"] serde-lazy = [ "polars-core/serde-lazy", "polars-lazy?/serde", diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index d44201de7b35..4de12cc95aa6 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -2636,7 +2636,7 @@ def serialize( ... ) >>> bytes = df.serialize() >>> bytes # doctest: +ELLIPSIS - b'\xa1gcolumns\x82\xa4dnamecfoohdatatypeeInt64lbit_settings\x00fvalues\x83...' + b'x\x01bb@\x80\x15...' The bytes can later be deserialized back into a DataFrame. diff --git a/py-polars/polars/expr/meta.py b/py-polars/polars/expr/meta.py index 554325ab8119..6c97082cbcbb 100644 --- a/py-polars/polars/expr/meta.py +++ b/py-polars/polars/expr/meta.py @@ -332,7 +332,7 @@ def serialize( >>> expr = pl.col("foo").sum().over("bar") >>> bytes = expr.meta.serialize() >>> bytes # doctest: +ELLIPSIS - b'\xa1fWindow\xa4hfunction\xa1cAgg\xa1cSum\xa1fColumncfoolpartition_by\x81...' + b'x\x01\x02L\x80\x81...' The bytes can later be deserialized back into an `Expr` object. diff --git a/py-polars/tests/unit/dataframe/test_serde.py b/py-polars/tests/unit/dataframe/test_serde.py index 71936c9eae81..64dd42255dfe 100644 --- a/py-polars/tests/unit/dataframe/test_serde.py +++ b/py-polars/tests/unit/dataframe/test_serde.py @@ -19,31 +19,24 @@ from polars._typing import SerializationFormat -@given( - df=dataframes( - excluded_dtypes=[pl.Struct], # Outer nullability not supported - ) -) def test_df_serde_roundtrip_binary(df: pl.DataFrame) -> None: serialized = df.serialize() result = pl.DataFrame.deserialize(io.BytesIO(serialized), format="binary") assert_frame_equal(result, df, categorical_as_str=True) -@given( - df=dataframes( - excluded_dtypes=[ - pl.Float32, # Bug, see: https://github.com/pola-rs/polars/issues/17211 - pl.Float64, # Bug, see: https://github.com/pola-rs/polars/issues/17211 - pl.Struct, # Outer nullability not supported - ], - ) -) +@given(df=dataframes()) @example(df=pl.DataFrame({"a": [None, None]}, schema={"a": pl.Null})) @example(df=pl.DataFrame(schema={"a": pl.List(pl.String)})) def test_df_serde_roundtrip_json(df: pl.DataFrame) -> None: serialized = df.serialize(format="json") result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json") + + if isinstance(dt := df.to_series(0).dtype, pl.Decimal): + if dt.precision is None: + # This gets converted to precision 38 upon `to_arrow()` + pytest.skip("precision None") + assert_frame_equal(result, df, categorical_as_str=True) @@ -64,9 +57,12 @@ def test_df_serde_json_stringio(df: pl.DataFrame) -> None: def test_df_serialize_json() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}).sort("a") result = df.serialize(format="json") - expected = '{"columns":[{"name":"a","datatype":"Int64","bit_settings":"SORTED_ASC","values":[1,2,3]},{"name":"b","datatype":"Int64","bit_settings":"","values":[9,5,6]}]}' - print(result) - assert result == expected + + assert isinstance(result, str) + + f = io.StringIO(result) + + assert_frame_equal(pl.DataFrame.deserialize(f, format="json"), df) @pytest.mark.parametrize( @@ -193,7 +189,6 @@ def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> Non assert_frame_equal(result, df) -@pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17211") def test_df_serde_float_inf_nan() -> None: df = pl.DataFrame({"a": [1.0, float("inf"), float("-inf"), float("nan")]}) ser = df.serialize(format="json") @@ -201,34 +196,6 @@ def test_df_serde_float_inf_nan() -> None: assert_frame_equal(result, df) -def test_df_deserialize_validation() -> None: - f = io.StringIO( - """ - { - "columns": [ - { - "name": "a", - "datatype": "Int64", - "values": [ - 1, - 2 - ] - }, - { - "name": "b", - "datatype": "Int64", - "values": [ - 1 - ] - } - ] - } - """ - ) - with pytest.raises(ComputeError, match=r"lengths don't match"): - pl.DataFrame.deserialize(f, format="json") - - def test_df_serialize_invalid_type() -> None: df = pl.DataFrame({"a": [object()]}) with pytest.raises( diff --git a/py-polars/tests/unit/io/cloud/test_credential_provider.py b/py-polars/tests/unit/io/cloud/test_credential_provider.py index b88568cad3dc..5afbc343fdec 100644 --- a/py-polars/tests/unit/io/cloud/test_credential_provider.py +++ b/py-polars/tests/unit/io/cloud/test_credential_provider.py @@ -71,11 +71,14 @@ def __call__(self) -> pl.CredentialProviderFunctionReturn: def test_scan_credential_provider_serialization_pyversion() -> None: + import zlib + lf = pl.scan_parquet( "s3://bucket/path", credential_provider=pl.CredentialProviderAWS() ) serialized = lf.serialize() + serialized = zlib.decompress(serialized) serialized = bytearray(serialized) # We can't monkeypatch sys.python_version so we just mutate the output @@ -90,6 +93,8 @@ def test_scan_credential_provider_serialization_pyversion() -> None: serialized[i] = 255 serialized[i + 1] = 254 + serialized = zlib.compress(serialized) + with pytest.raises(ComputeError, match=r"python version.*(3, 255, 254).*differs.*"): lf = pl.LazyFrame.deserialize(io.BytesIO(serialized)) diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py index 8ddcbfafd6f6..a3ffb47e01c7 100644 --- a/py-polars/tests/unit/lazyframe/test_serde.py +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -119,7 +119,7 @@ def test_lf_serde_scan(tmp_path: Path) -> None: @pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") -def test_lf_serde_version_specific_lambda(monkeypatch: pytest.MonkeyPatch) -> None: +def test_lf_serde_version_specific_lambda() -> None: lf = pl.LazyFrame({"a": [1, 2, 3]}).select( pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64) ) @@ -135,9 +135,7 @@ def custom_function(x: pl.Series) -> pl.Series: @pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") -def test_lf_serde_version_specific_named_function( - monkeypatch: pytest.MonkeyPatch, -) -> None: +def test_lf_serde_version_specific_named_function() -> None: lf = pl.LazyFrame({"a": [1, 2, 3]}).select( pl.col("a").map_batches(custom_function, return_dtype=pl.Int64) ) @@ -146,3 +144,13 @@ def test_lf_serde_version_specific_named_function( result = pl.LazyFrame.deserialize(io.BytesIO(ser)) expected = pl.LazyFrame({"a": [2, 3, 4]}) assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +def test_lf_serde_map_batches_on_lazyframe() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}).map_batches(lambda x: x + 1) + ser = lf.serialize() + + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + expected = pl.LazyFrame({"a": [2, 3, 4]}) + assert_frame_equal(result, expected)