From 450aee8b7d13e92d4de084c225011fd712393a86 Mon Sep 17 00:00:00 2001 From: ritchie Date: Thu, 2 Jan 2025 09:03:26 +0100 Subject: [PATCH] fix: Fix union --- crates/polars-core/src/datatypes/mod.rs | 2 ++ crates/polars-core/src/datatypes/schema.rs | 21 +++++++++++++++++++ crates/polars-core/src/lib.rs | 1 + .../src/plans/conversion/dsl_to_ir.rs | 4 +++- .../tests/unit/operations/test_concat.py | 10 +++++++++ 5 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 crates/polars-core/src/datatypes/schema.rs diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index ed2f810f05bd..f9e4b71e5602 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -22,6 +22,7 @@ use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub, SubAssign}; +mod schema; pub use aliases::*; pub use any_value::*; pub use arrow::array::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, StaticArray}; @@ -42,6 +43,7 @@ use polars_utils::abs_diff::AbsDiff; use polars_utils::float::IsFloat; use polars_utils::min_max::MinMax; use polars_utils::nulls::IsNull; +pub use schema::SchemaExtPl; #[cfg(feature = "serde")] use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor}; #[cfg(any(feature = "serde", feature = "serde-lazy"))] diff --git a/crates/polars-core/src/datatypes/schema.rs b/crates/polars-core/src/datatypes/schema.rs new file mode 100644 index 000000000000..05a0fc3a3657 --- /dev/null +++ b/crates/polars-core/src/datatypes/schema.rs @@ -0,0 +1,21 @@ +use super::*; + +pub trait SchemaExtPl { + /// Answers if this schema matches the given schema. + /// + /// Allows (nested) Null types in this schema to match any type in the other schema, + /// but not vice versa. In such a case Ok(true) is returned, because a cast + /// is necessary. 
If no cast is necessary Ok(false) is returned, and an + /// error is returned if the types are incompatible. NOTE(review): fields are paired by position via `zip`; names and differing column counts appear unchecked here. + fn matches_schema(&self, other: &Schema) -> PolarsResult; +} + +impl SchemaExtPl for Schema { + fn matches_schema(&self, other: &Schema) -> PolarsResult { + let mut cast = false; + for (a, b) in self.iter_values().zip(other.iter_values()) { + cast |= a.matches_schema_type(b)?; + } + Ok(cast) + } +} diff --git a/crates/polars-core/src/lib.rs b/crates/polars-core/src/lib.rs index b81a65674eaa..25377fbfe62e 100644 --- a/crates/polars-core/src/lib.rs +++ b/crates/polars-core/src/lib.rs @@ -31,6 +31,7 @@ mod tests; use std::sync::Mutex; use std::time::{SystemTime, UNIX_EPOCH}; +pub use datatypes::SchemaExtPl; pub use hashing::IdBuildHasher; use once_cell::sync::Lazy; use rayon::{ThreadPool, ThreadPoolBuilder}; diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 60512a9b0703..3474c8079079 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -384,8 +384,10 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult let schema = ctxt.lp_arena.get(first).schema(ctxt.lp_arena); for n in &inputs[1..] 
{ let schema_i = ctxt.lp_arena.get(*n).schema(ctxt.lp_arena); - polars_ensure!(schema == schema_i, InvalidOperation: "'union'/'concat' inputs should all have the same schema,\ + // Null dtypes in later inputs may match the first input's schema, not vice versa; the cast flag is unused here. + schema_i.matches_schema(schema.as_ref()).map_err(|_| polars_err!(InvalidOperation: "'union'/'concat' inputs should all have the same schema,\ got\n{:?} and \n{:?}", schema, schema_i) + )?; } let options = args.into(); diff --git a/py-polars/tests/unit/operations/test_concat.py b/py-polars/tests/unit/operations/test_concat.py index 6c964764c181..a2664df1b000 100644 --- a/py-polars/tests/unit/operations/test_concat.py +++ b/py-polars/tests/unit/operations/test_concat.py @@ -97,3 +97,13 @@ def test_concat_series() -> None: assert pl.concat([s, s]).len() == 6 # check if s remains unchanged assert s.len() == 3 + + +def test_concat_null_20501() -> None: + a = pl.DataFrame({"id": [1], "value": ["foo"]}) + b = pl.DataFrame({"id": [2], "value": [None]}) + + assert pl.concat([a.lazy(), b.lazy()]).collect().to_dict(as_series=False) == { + "id": [1, 2], + "value": ["foo", None], + }