From 60a237f2a759b2a0b84ed34f149374c560de1a51 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Mon, 16 Dec 2024 23:37:12 +1100 Subject: [PATCH] c --- crates/polars-arrow/src/datatypes/mod.rs | 8 +- crates/polars-plan/src/plans/functions/dsl.rs | 20 ++- crates/polars-plan/src/plans/functions/mod.rs | 42 ++--- crates/polars-plan/src/plans/ir/mod.rs | 4 +- crates/polars-plan/src/plans/mod.rs | 4 +- crates/polars-plan/src/plans/options.rs | 13 +- crates/polars-utils/src/pl_serialize.rs | 2 +- py-polars/tests/unit/lazyframe/test_serde.py | 156 +++++++++--------- 8 files changed, 132 insertions(+), 117 deletions(-) diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 71689b4094d6..d3bc5417a9d8 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -111,10 +111,6 @@ pub enum ArrowDataType { LargeList(Box), /// A nested [`ArrowDataType`] with a given number of [`Field`]s. Struct(Vec), - /// A nested datatype that can represent slots of differing types. - /// Third argument represents mode - #[cfg_attr(feature = "serde", serde(skip))] - Union(Vec, Option>, UnionMode), /// A nested type that is represented as /// /// List> @@ -176,6 +172,10 @@ pub enum ArrowDataType { Utf8View, /// A type unknown to Arrow. Unknown, + /// A nested datatype that can represent slots of differing types. + /// Third argument represents mode + #[cfg_attr(feature = "serde", serde(skip))] + Union(Vec, Option>, UnionMode), } /// Mode of [`ArrowDataType::Union`] diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index 28744998126e..9846bbc5fd15 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -21,14 +21,10 @@ pub struct OpaquePythonUdf { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[strum(serialize_all = "SCREAMING_SNAKE_CASE")] pub enum DslFunction { - // First enum variant must not be serde(skip) RowIndex { name: PlSmallStr, offset: Option, }, - // Function that is already converted to IR. - #[cfg_attr(feature = "serde", serde(skip))] - FunctionIR(FunctionIR), // This is both in DSL and IR because we want to be able to serialize it. #[cfg(feature = "python")] OpaquePython(OpaquePythonUdf), @@ -50,6 +46,22 @@ pub enum DslFunction { /// FillValue FillNan(Expr), Drop(DropFunction), + // Function that is already converted to IR. + #[cfg_attr(feature = "serde", serde(skip))] + FunctionIR(FunctionIR), +} + +#[cfg(test)] +mod tests { + + #[test] + fn test_serde() { + use polars_utils::pl_serialize; + + use crate::plans::{DslFunction, StatsFunction}; + let v = pl_serialize::serialize_to_bytes(&DslFunction::Stats(StatsFunction::Sum)).unwrap(); + let r: DslFunction = pl_serialize::deserialize_from_reader(v.as_slice()).unwrap(); + } } #[derive(Clone)] diff --git a/crates/polars-plan/src/plans/functions/mod.rs b/crates/polars-plan/src/plans/functions/mod.rs index 60a1fb16fe74..a357d5b269ad 100644 --- a/crates/polars-plan/src/plans/functions/mod.rs +++ b/crates/polars-plan/src/plans/functions/mod.rs @@ -33,37 +33,20 @@ use crate::prelude::*; pub enum FunctionIR { RowIndex { name: PlSmallStr, + offset: Option, // Might be cached. #[cfg_attr(feature = "ir_serde", serde(skip))] schema: CachedSchema, - offset: Option, }, #[cfg(feature = "python")] OpaquePython(OpaquePythonUdf), - #[cfg_attr(feature = "ir_serde", serde(skip))] - Opaque { - function: Arc, - schema: Option>, - /// allow predicate pushdown optimizations - predicate_pd: bool, - /// allow projection pushdown optimizations - projection_pd: bool, - streamable: bool, - // used for formatting - fmt_str: PlSmallStr, - }, + FastCount { sources: ScanSources, scan_type: FileScan, alias: Option, }, - /// Streaming engine pipeline - #[cfg_attr(feature = "ir_serde", serde(skip))] - Pipeline { - function: Arc>, - schema: SchemaRef, - original: Option>, - }, + Unnest { columns: Arc<[PlSmallStr]>, }, @@ -96,6 +79,25 @@ pub enum FunctionIR { #[cfg_attr(feature = "ir_serde", serde(skip))] schema: CachedSchema, }, + #[cfg_attr(feature = "ir_serde", serde(skip))] + Opaque { + function: Arc, + schema: Option>, + /// allow predicate pushdown optimizations + predicate_pd: bool, + /// allow projection pushdown optimizations + projection_pd: bool, + streamable: bool, + // used for formatting + fmt_str: PlSmallStr, + }, + /// Streaming engine pipeline + #[cfg_attr(feature = "ir_serde", serde(skip))] + Pipeline { + function: Arc>, + schema: SchemaRef, + original: Option>, + }, } impl Eq for FunctionIR {} diff --git a/crates/polars-plan/src/plans/ir/mod.rs b/crates/polars-plan/src/plans/ir/mod.rs index 800e28dbea1a..6e44abe7fe7c 100644 --- a/crates/polars-plan/src/plans/ir/mod.rs +++ b/crates/polars-plan/src/plans/ir/mod.rs @@ -112,10 +112,10 @@ pub enum IR { keys: Vec, aggs: Vec, schema: SchemaRef, - #[cfg_attr(feature = "ir_serde", serde(skip))] - apply: Option>, maintain_order: bool, options: Arc, + #[cfg_attr(feature = "ir_serde", serde(skip))] + apply: Option>, }, Join { input_left: Node, diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index 0b276cf50a3c..efb01919a15a 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -97,10 +97,10 @@ pub enum DslPlan { input: Arc, keys: Vec, aggs: Vec, - #[cfg_attr(feature = "serde", serde(skip))] - apply: Option<(Arc, SchemaRef)>, maintain_order: bool, options: Arc, + #[cfg_attr(feature = "serde", serde(skip))] + apply: Option<(Arc, SchemaRef)>, }, /// Join operation Join { diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index 6eff59a06680..40eac804d299 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -212,9 +212,6 @@ pub struct FunctionOptions { /// Collect groups to a list and apply the function over the groups. /// This can be important in aggregation context. pub collect_groups: ApplyOptions, - // used for formatting, (only for anonymous functions) - #[cfg_attr(feature = "serde", serde(skip_deserializing))] - pub fmt_str: &'static str, /// Options used when deciding how to cast the arguments of the function. pub cast_options: FunctionCastOptions, @@ -223,6 +220,10 @@ pub struct FunctionOptions { // this should always be true or we could OOB pub check_lengths: UnsafeBool, pub flags: FunctionFlags, + + // used for formatting, (only for anonymous functions) + #[cfg_attr(feature = "serde", serde(skip_deserializing))] + pub fmt_str: &'static str, } impl FunctionOptions { @@ -281,11 +282,10 @@ pub struct PythonOptions { pub with_columns: Option>, // Which interface is the python function. pub python_source: PythonScanSource, - /// Optional predicate the reader must apply. - #[cfg_attr(feature = "serde", serde(skip))] - pub predicate: PythonPredicate, /// A `head` call passed to the reader. pub n_rows: Option, + /// Optional predicate the reader must apply. + pub predicate: PythonPredicate, } #[derive(Clone, PartialEq, Eq, Debug, Default)] @@ -298,6 +298,7 @@ pub enum PythonScanSource { } #[derive(Clone, PartialEq, Eq, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum PythonPredicate { // A pyarrow predicate python expression // can be evaluated with python.eval diff --git a/crates/polars-utils/src/pl_serialize.rs b/crates/polars-utils/src/pl_serialize.rs index 748baa407f18..4400ac87d47d 100644 --- a/crates/polars-utils/src/pl_serialize.rs +++ b/crates/polars-utils/src/pl_serialize.rs @@ -99,7 +99,7 @@ mod tests { #[derive(Default, Debug, PartialEq)] struct MyType(Option); - // Note: The first enum variant cannot be serde(skip). + // Note: serde(skip) must be at the end of enums #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] enum Enum { A, diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py index 8ddcbfafd6f6..0f1eaf7e774d 100644 --- a/py-polars/tests/unit/lazyframe/test_serde.py +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -68,81 +68,81 @@ def test_lf_serde(lf: pl.LazyFrame) -> None: assert_frame_equal(result, lf) -@pytest.mark.parametrize( - ("format", "buf"), - [ - ("binary", io.BytesIO()), - ("json", io.StringIO()), - ("json", io.BytesIO()), - ], -) -@pytest.mark.filterwarnings("ignore") -def test_lf_serde_to_from_buffer( - lf: pl.LazyFrame, format: SerializationFormat, buf: io.IOBase -) -> None: - lf.serialize(buf, format=format) - buf.seek(0) - result = pl.LazyFrame.deserialize(buf, format=format) - assert_frame_equal(lf, result) - - -@pytest.mark.write_disk -def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None: - tmp_path.mkdir(exist_ok=True) - - file_path = tmp_path / "small.bin" - lf.serialize(file_path) - result = pl.LazyFrame.deserialize(file_path) - - assert_frame_equal(lf, result) - - -def test_lf_deserialize_validation() -> None: - f = io.BytesIO(b"hello world!") - with pytest.raises(ComputeError, match="expected value at line 1 column 1"): - pl.LazyFrame.deserialize(f, format="json") - - -@pytest.mark.write_disk -def test_lf_serde_scan(tmp_path: Path) -> None: - tmp_path.mkdir(exist_ok=True) - path = tmp_path / "dataset.parquet" - - df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}) - df.write_parquet(path) - lf = pl.scan_parquet(path) - - ser = lf.serialize() - result = pl.LazyFrame.deserialize(io.BytesIO(ser)) - assert_frame_equal(result, lf) - assert_frame_equal(result.collect(), df) - - -@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") -def test_lf_serde_version_specific_lambda(monkeypatch: pytest.MonkeyPatch) -> None: - lf = pl.LazyFrame({"a": [1, 2, 3]}).select( - pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64) - ) - ser = lf.serialize() - - result = pl.LazyFrame.deserialize(io.BytesIO(ser)) - expected = pl.LazyFrame({"a": [2, 3, 4]}) - assert_frame_equal(result, expected) - - -def custom_function(x: pl.Series) -> pl.Series: - return x + 1 - - -@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") -def test_lf_serde_version_specific_named_function( - monkeypatch: pytest.MonkeyPatch, -) -> None: - lf = pl.LazyFrame({"a": [1, 2, 3]}).select( - pl.col("a").map_batches(custom_function, return_dtype=pl.Int64) - ) - ser = lf.serialize() - - result = pl.LazyFrame.deserialize(io.BytesIO(ser)) - expected = pl.LazyFrame({"a": [2, 3, 4]}) - assert_frame_equal(result, expected) +# @pytest.mark.parametrize( +# ("format", "buf"), +# [ +# ("binary", io.BytesIO()), +# ("json", io.StringIO()), +# ("json", io.BytesIO()), +# ], +# ) +# @pytest.mark.filterwarnings("ignore") +# def test_lf_serde_to_from_buffer( +# lf: pl.LazyFrame, format: SerializationFormat, buf: io.IOBase +# ) -> None: +# lf.serialize(buf, format=format) +# buf.seek(0) +# result = pl.LazyFrame.deserialize(buf, format=format) +# assert_frame_equal(lf, result) + + +# @pytest.mark.write_disk +# def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None: +# tmp_path.mkdir(exist_ok=True) + +# file_path = tmp_path / "small.bin" +# lf.serialize(file_path) +# result = pl.LazyFrame.deserialize(file_path) + +# assert_frame_equal(lf, result) + + +# def test_lf_deserialize_validation() -> None: +# f = io.BytesIO(b"hello world!") +# with pytest.raises(ComputeError, match="expected value at line 1 column 1"): +# pl.LazyFrame.deserialize(f, format="json") + + +# @pytest.mark.write_disk +# def test_lf_serde_scan(tmp_path: Path) -> None: +# tmp_path.mkdir(exist_ok=True) +# path = tmp_path / "dataset.parquet" + +# df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}) +# df.write_parquet(path) +# lf = pl.scan_parquet(path) + +# ser = lf.serialize() +# result = pl.LazyFrame.deserialize(io.BytesIO(ser)) +# assert_frame_equal(result, lf) +# assert_frame_equal(result.collect(), df) + + +# @pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +# def test_lf_serde_version_specific_lambda(monkeypatch: pytest.MonkeyPatch) -> None: +# lf = pl.LazyFrame({"a": [1, 2, 3]}).select( +# pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64) +# ) +# ser = lf.serialize() + +# result = pl.LazyFrame.deserialize(io.BytesIO(ser)) +# expected = pl.LazyFrame({"a": [2, 3, 4]}) +# assert_frame_equal(result, expected) + + +# def custom_function(x: pl.Series) -> pl.Series: +# return x + 1 + + +# @pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +# def test_lf_serde_version_specific_named_function( +# monkeypatch: pytest.MonkeyPatch, +# ) -> None: +# lf = pl.LazyFrame({"a": [1, 2, 3]}).select( +# pl.col("a").map_batches(custom_function, return_dtype=pl.Int64) +# ) +# ser = lf.serialize() + +# result = pl.LazyFrame.deserialize(io.BytesIO(ser)) +# expected = pl.LazyFrame({"a": [2, 3, 4]}) +# assert_frame_equal(result, expected)