Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Dec 16, 2024
1 parent d8c4d13 commit 60a237f
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 117 deletions.
8 changes: 4 additions & 4 deletions crates/polars-arrow/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,6 @@ pub enum ArrowDataType {
LargeList(Box<Field>),
/// A nested [`ArrowDataType`] with a given number of [`Field`]s.
Struct(Vec<Field>),
/// A nested datatype that can represent slots of differing types.
/// Third argument represents mode
#[cfg_attr(feature = "serde", serde(skip))]
Union(Vec<Field>, Option<Vec<i32>>, UnionMode),
/// A nested type that is represented as
///
/// List<entries: Struct<key: K, value: V>>
Expand Down Expand Up @@ -176,6 +172,10 @@ pub enum ArrowDataType {
Utf8View,
/// A type unknown to Arrow.
Unknown,
/// A nested datatype that can represent slots of differing types.
/// Third argument represents mode
#[cfg_attr(feature = "serde", serde(skip))]
Union(Vec<Field>, Option<Vec<i32>>, UnionMode),
}

/// Mode of [`ArrowDataType::Union`]
Expand Down
20 changes: 16 additions & 4 deletions crates/polars-plan/src/plans/functions/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,10 @@ pub struct OpaquePythonUdf {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
pub enum DslFunction {
// First enum variant must not be serde(skip)
RowIndex {
name: PlSmallStr,
offset: Option<IdxSize>,
},
// Function that is already converted to IR.
#[cfg_attr(feature = "serde", serde(skip))]
FunctionIR(FunctionIR),
// This is both in DSL and IR because we want to be able to serialize it.
#[cfg(feature = "python")]
OpaquePython(OpaquePythonUdf),
Expand All @@ -50,6 +46,22 @@ pub enum DslFunction {
/// FillValue
FillNan(Expr),
Drop(DropFunction),
// Function that is already converted to IR.
#[cfg_attr(feature = "serde", serde(skip))]
FunctionIR(FunctionIR),
}

#[cfg(test)]
mod tests {

#[test]
fn test_serde() {
use polars_utils::pl_serialize;

use crate::plans::{DslFunction, StatsFunction};
let v = pl_serialize::serialize_to_bytes(&DslFunction::Stats(StatsFunction::Sum)).unwrap();
let r: DslFunction = pl_serialize::deserialize_from_reader(v.as_slice()).unwrap();
}
}

#[derive(Clone)]
Expand Down
42 changes: 22 additions & 20 deletions crates/polars-plan/src/plans/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,20 @@ use crate::prelude::*;
pub enum FunctionIR {
RowIndex {
name: PlSmallStr,
offset: Option<IdxSize>,
// Might be cached.
#[cfg_attr(feature = "ir_serde", serde(skip))]
schema: CachedSchema,
offset: Option<IdxSize>,
},
#[cfg(feature = "python")]
OpaquePython(OpaquePythonUdf),
#[cfg_attr(feature = "ir_serde", serde(skip))]
Opaque {
function: Arc<dyn DataFrameUdf>,
schema: Option<Arc<dyn UdfSchema>>,
/// allow predicate pushdown optimizations
predicate_pd: bool,
/// allow projection pushdown optimizations
projection_pd: bool,
streamable: bool,
// used for formatting
fmt_str: PlSmallStr,
},

FastCount {
sources: ScanSources,
scan_type: FileScan,
alias: Option<PlSmallStr>,
},
/// Streaming engine pipeline
#[cfg_attr(feature = "ir_serde", serde(skip))]
Pipeline {
function: Arc<Mutex<dyn DataFrameUdfMut>>,
schema: SchemaRef,
original: Option<Arc<IRPlan>>,
},

Unnest {
columns: Arc<[PlSmallStr]>,
},
Expand Down Expand Up @@ -96,6 +79,25 @@ pub enum FunctionIR {
#[cfg_attr(feature = "ir_serde", serde(skip))]
schema: CachedSchema,
},
#[cfg_attr(feature = "ir_serde", serde(skip))]
Opaque {
function: Arc<dyn DataFrameUdf>,
schema: Option<Arc<dyn UdfSchema>>,
/// allow predicate pushdown optimizations
predicate_pd: bool,
/// allow projection pushdown optimizations
projection_pd: bool,
streamable: bool,
// used for formatting
fmt_str: PlSmallStr,
},
/// Streaming engine pipeline
#[cfg_attr(feature = "ir_serde", serde(skip))]
Pipeline {
function: Arc<Mutex<dyn DataFrameUdfMut>>,
schema: SchemaRef,
original: Option<Arc<IRPlan>>,
},
}

impl Eq for FunctionIR {}
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/plans/ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ pub enum IR {
keys: Vec<ExprIR>,
aggs: Vec<ExprIR>,
schema: SchemaRef,
#[cfg_attr(feature = "ir_serde", serde(skip))]
apply: Option<Arc<dyn DataFrameUdf>>,
maintain_order: bool,
options: Arc<GroupbyOptions>,
#[cfg_attr(feature = "ir_serde", serde(skip))]
apply: Option<Arc<dyn DataFrameUdf>>,
},
Join {
input_left: Node,
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/plans/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ pub enum DslPlan {
input: Arc<DslPlan>,
keys: Vec<Expr>,
aggs: Vec<Expr>,
#[cfg_attr(feature = "serde", serde(skip))]
apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>,
maintain_order: bool,
options: Arc<GroupbyOptions>,
#[cfg_attr(feature = "serde", serde(skip))]
apply: Option<(Arc<dyn DataFrameUdf>, SchemaRef)>,
},
/// Join operation
Join {
Expand Down
13 changes: 7 additions & 6 deletions crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,6 @@ pub struct FunctionOptions {
/// Collect groups to a list and apply the function over the groups.
/// This can be important in aggregation context.
pub collect_groups: ApplyOptions,
// used for formatting, (only for anonymous functions)
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
pub fmt_str: &'static str,

/// Options used when deciding how to cast the arguments of the function.
pub cast_options: FunctionCastOptions,
Expand All @@ -223,6 +220,10 @@ pub struct FunctionOptions {
// this should always be true or we could OOB
pub check_lengths: UnsafeBool,
pub flags: FunctionFlags,

// used for formatting, (only for anonymous functions)
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
pub fmt_str: &'static str,
}

impl FunctionOptions {
Expand Down Expand Up @@ -281,11 +282,10 @@ pub struct PythonOptions {
pub with_columns: Option<Arc<[PlSmallStr]>>,
// Which interface is the python function.
pub python_source: PythonScanSource,
/// Optional predicate the reader must apply.
#[cfg_attr(feature = "serde", serde(skip))]
pub predicate: PythonPredicate,
/// A `head` call passed to the reader.
pub n_rows: Option<usize>,
/// Optional predicate the reader must apply.
pub predicate: PythonPredicate,
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
Expand All @@ -298,6 +298,7 @@ pub enum PythonScanSource {
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum PythonPredicate {
// A pyarrow predicate python expression
// can be evaluated with python.eval
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-utils/src/pl_serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ mod tests {
#[derive(Default, Debug, PartialEq)]
struct MyType(Option<usize>);

// Note: The first enum variant cannot be serde(skip).
// Note: serde(skip) must be at the end of enums
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
enum Enum {
A,
Expand Down
156 changes: 78 additions & 78 deletions py-polars/tests/unit/lazyframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,81 +68,81 @@ def test_lf_serde(lf: pl.LazyFrame) -> None:
assert_frame_equal(result, lf)


@pytest.mark.parametrize(
("format", "buf"),
[
("binary", io.BytesIO()),
("json", io.StringIO()),
("json", io.BytesIO()),
],
)
@pytest.mark.filterwarnings("ignore")
def test_lf_serde_to_from_buffer(
lf: pl.LazyFrame, format: SerializationFormat, buf: io.IOBase
) -> None:
lf.serialize(buf, format=format)
buf.seek(0)
result = pl.LazyFrame.deserialize(buf, format=format)
assert_frame_equal(lf, result)


@pytest.mark.write_disk
def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)

file_path = tmp_path / "small.bin"
lf.serialize(file_path)
result = pl.LazyFrame.deserialize(file_path)

assert_frame_equal(lf, result)


def test_lf_deserialize_validation() -> None:
f = io.BytesIO(b"hello world!")
with pytest.raises(ComputeError, match="expected value at line 1 column 1"):
pl.LazyFrame.deserialize(f, format="json")


@pytest.mark.write_disk
def test_lf_serde_scan(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)
path = tmp_path / "dataset.parquet"

df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
df.write_parquet(path)
lf = pl.scan_parquet(path)

ser = lf.serialize()
result = pl.LazyFrame.deserialize(io.BytesIO(ser))
assert_frame_equal(result, lf)
assert_frame_equal(result.collect(), df)


@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_version_specific_lambda(monkeypatch: pytest.MonkeyPatch) -> None:
lf = pl.LazyFrame({"a": [1, 2, 3]}).select(
pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)
)
ser = lf.serialize()

result = pl.LazyFrame.deserialize(io.BytesIO(ser))
expected = pl.LazyFrame({"a": [2, 3, 4]})
assert_frame_equal(result, expected)


def custom_function(x: pl.Series) -> pl.Series:
return x + 1


@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_version_specific_named_function(
monkeypatch: pytest.MonkeyPatch,
) -> None:
lf = pl.LazyFrame({"a": [1, 2, 3]}).select(
pl.col("a").map_batches(custom_function, return_dtype=pl.Int64)
)
ser = lf.serialize()

result = pl.LazyFrame.deserialize(io.BytesIO(ser))
expected = pl.LazyFrame({"a": [2, 3, 4]})
assert_frame_equal(result, expected)
# @pytest.mark.parametrize(
# ("format", "buf"),
# [
# ("binary", io.BytesIO()),
# ("json", io.StringIO()),
# ("json", io.BytesIO()),
# ],
# )
# @pytest.mark.filterwarnings("ignore")
# def test_lf_serde_to_from_buffer(
# lf: pl.LazyFrame, format: SerializationFormat, buf: io.IOBase
# ) -> None:
# lf.serialize(buf, format=format)
# buf.seek(0)
# result = pl.LazyFrame.deserialize(buf, format=format)
# assert_frame_equal(lf, result)


# @pytest.mark.write_disk
# def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None:
# tmp_path.mkdir(exist_ok=True)

# file_path = tmp_path / "small.bin"
# lf.serialize(file_path)
# result = pl.LazyFrame.deserialize(file_path)

# assert_frame_equal(lf, result)


# def test_lf_deserialize_validation() -> None:
# f = io.BytesIO(b"hello world!")
# with pytest.raises(ComputeError, match="expected value at line 1 column 1"):
# pl.LazyFrame.deserialize(f, format="json")


# @pytest.mark.write_disk
# def test_lf_serde_scan(tmp_path: Path) -> None:
# tmp_path.mkdir(exist_ok=True)
# path = tmp_path / "dataset.parquet"

# df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
# df.write_parquet(path)
# lf = pl.scan_parquet(path)

# ser = lf.serialize()
# result = pl.LazyFrame.deserialize(io.BytesIO(ser))
# assert_frame_equal(result, lf)
# assert_frame_equal(result.collect(), df)


# @pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
# def test_lf_serde_version_specific_lambda(monkeypatch: pytest.MonkeyPatch) -> None:
# lf = pl.LazyFrame({"a": [1, 2, 3]}).select(
# pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)
# )
# ser = lf.serialize()

# result = pl.LazyFrame.deserialize(io.BytesIO(ser))
# expected = pl.LazyFrame({"a": [2, 3, 4]})
# assert_frame_equal(result, expected)


# def custom_function(x: pl.Series) -> pl.Series:
# return x + 1


# @pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
# def test_lf_serde_version_specific_named_function(
# monkeypatch: pytest.MonkeyPatch,
# ) -> None:
# lf = pl.LazyFrame({"a": [1, 2, 3]}).select(
# pl.col("a").map_batches(custom_function, return_dtype=pl.Int64)
# )
# ser = lf.serialize()

# result = pl.LazyFrame.deserialize(io.BytesIO(ser))
# expected = pl.LazyFrame({"a": [2, 3, 4]})
# assert_frame_equal(result, expected)

0 comments on commit 60a237f

Please sign in to comment.