Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Always resolve dynamic types in schema #20406

Merged
merged 3 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,30 +253,39 @@ impl DataType {

/// Materialize this datatype if it is unknown. All other datatypes
/// are left unchanged.
pub fn materialize_unknown(&self) -> PolarsResult<DataType> {
pub fn materialize_unknown(self, allow_unknown: bool) -> PolarsResult<DataType> {
match self {
DataType::Unknown(u) => u
.materialize()
.ok_or_else(|| polars_err!(SchemaMismatch: "failed to materialize unknown type")),
DataType::List(inner) => Ok(DataType::List(Box::new(inner.materialize_unknown()?))),
DataType::Unknown(u) => match u.materialize() {
Some(known) => Ok(known),
None => {
if allow_unknown {
Ok(DataType::Unknown(u))
} else {
polars_bail!(SchemaMismatch: "failed to materialize unknown type")
}
},
},
DataType::List(inner) => Ok(DataType::List(Box::new(
inner.materialize_unknown(allow_unknown)?,
))),
#[cfg(feature = "dtype-array")]
DataType::Array(inner, size) => Ok(DataType::Array(
Box::new(inner.materialize_unknown()?),
*size,
Box::new(inner.materialize_unknown(allow_unknown)?),
size,
)),
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => Ok(DataType::Struct(
fields
.iter()
.into_iter()
.map(|f| {
PolarsResult::Ok(Field::new(
f.name().clone(),
f.dtype().materialize_unknown()?,
f.name,
f.dtype.materialize_unknown(allow_unknown)?,
))
})
.try_collect_vec()?,
)),
_ => Ok(self.clone()),
_ => Ok(self),
}
}

Expand Down
9 changes: 0 additions & 9 deletions crates/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ pub trait SchemaExt {
fn iter_fields(&self) -> impl ExactSizeIterator<Item = Field> + '_;

fn to_supertype(&mut self, other: &Schema) -> PolarsResult<bool>;

fn materialize_unknown_dtypes(&self) -> PolarsResult<Schema>;
}

impl SchemaExt for Schema {
Expand Down Expand Up @@ -90,13 +88,6 @@ impl SchemaExt for Schema {
}
Ok(changed)
}

/// Materialize all unknown dtypes in this schema.
fn materialize_unknown_dtypes(&self) -> PolarsResult<Schema> {
self.iter()
.map(|(name, dtype)| Ok((name.clone(), dtype.materialize_unknown()?)))
.collect()
}
}

pub trait SchemaNamesAndDtypes {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-expr/src/reduce/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub fn into_reduction(
expr_arena
.get(node)
.to_dtype(schema, Context::Default, expr_arena)?
.materialize_unknown()
.materialize_unknown(false)
};
let out = match expr_arena.get(node) {
AExpr::Agg(agg) => match agg {
Expand Down
13 changes: 12 additions & 1 deletion crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,17 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
convert_utils::convert_st_union(&mut inputs, ctxt.lp_arena, ctxt.expr_arena)
.map_err(|e| e.context(failed_here!(vertical concat)))?;
}

let first = *inputs.first().ok_or_else(
|| polars_err!(InvalidOperation: "expected at least one input in 'union'/'concat'"),
)?;
let schema = ctxt.lp_arena.get(first).schema(ctxt.lp_arena);
for n in &inputs[1..] {
let schema_i = ctxt.lp_arena.get(*n).schema(ctxt.lp_arena);
polars_ensure!(schema == schema_i, InvalidOperation: "'union'/'concat' inputs should all have the same schema,\
got\n{:?} and \n{:?}", schema, schema_i)
}

let options = args.into();
IR::Union { inputs, options }
},
Expand Down Expand Up @@ -976,7 +987,7 @@ fn resolve_with_columns(
);
polars_bail!(ComputeError: msg)
}
new_schema.with_column(field.name().clone(), field.dtype().clone());
new_schema.with_column(field.name, field.dtype.materialize_unknown(true)?);
arena.clear();
}

Expand Down
8 changes: 7 additions & 1 deletion crates/polars-plan/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,12 @@ pub fn expressions_to_schema(
) -> PolarsResult<Schema> {
let mut expr_arena = Arena::with_capacity(4 * expr.len());
expr.iter()
.map(|expr| expr.to_field_amortized(schema, ctxt, &mut expr_arena))
.map(|expr| {
let mut field = expr.to_field_amortized(schema, ctxt, &mut expr_arena)?;

field.dtype = field.dtype.materialize_unknown(true)?;
Ok(field)
})
.collect()
}

Expand Down Expand Up @@ -336,6 +341,7 @@ pub(crate) fn expr_irs_to_schema<I: IntoIterator<Item = K>, K: AsRef<ExprIR>>(
if let Some(name) = e.get_alias() {
field.name = name.clone()
}
field.dtype = field.dtype.materialize_unknown(true).unwrap();
field
})
.collect()
Expand Down
6 changes: 5 additions & 1 deletion crates/polars-stream/src/physical_plan/lower_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,11 @@ pub fn compute_output_schema(
.iter()
.map(|e| {
let name = e.output_name().clone();
let dtype = e.dtype(input_schema, Context::Default, expr_arena)?.clone();
let dtype = e
.dtype(input_schema, Context::Default, expr_arena)?
.clone()
.materialize_unknown(true)
.unwrap();
PolarsResult::Ok(Field::new(name, dtype))
})
.try_collect()?;
Expand Down
17 changes: 7 additions & 10 deletions crates/polars-stream/src/physical_plan/to_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use parking_lot::Mutex;
use polars_core::prelude::PlRandomState;
use polars_core::schema::{Schema, SchemaExt};
use polars_core::schema::Schema;
use polars_error::PolarsResult;
use polars_expr::groups::new_hash_grouper;
use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState};
Expand Down Expand Up @@ -416,9 +416,8 @@ fn to_graph_rec<'a>(
let input_key = to_graph_rec(*input, ctx)?;

let input_schema = &ctx.phys_sm[*input].output_schema;
let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)?
.materialize_unknown_dtypes()?;
let grouper = new_hash_grouper(Arc::new(key_schema));
let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)?;
let grouper = new_hash_grouper(key_schema);

let key_selectors = key
.iter()
Expand Down Expand Up @@ -521,11 +520,9 @@ fn to_graph_rec<'a>(
let right_input_schema = ctx.phys_sm[*input_right].output_schema.clone();

let left_key_schema =
compute_output_schema(&left_input_schema, left_on, ctx.expr_arena)?
.materialize_unknown_dtypes()?;
compute_output_schema(&left_input_schema, left_on, ctx.expr_arena)?;
let right_key_schema =
compute_output_schema(&right_input_schema, right_on, ctx.expr_arena)?
.materialize_unknown_dtypes()?;
compute_output_schema(&right_input_schema, right_on, ctx.expr_arena)?;

let left_key_selectors = left_on
.iter()
Expand All @@ -540,8 +537,8 @@ fn to_graph_rec<'a>(
nodes::joins::equi_join::EquiJoinNode::new(
left_input_schema,
right_input_schema,
Arc::new(left_key_schema),
Arc::new(right_key_schema),
left_key_schema,
right_key_schema,
left_key_selectors,
right_key_selectors,
args,
Expand Down
34 changes: 0 additions & 34 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,34 +706,6 @@ def test_multiple_columns_drop() -> None:
assert out.columns == ["a"]


def test_concat() -> None:
df1 = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
df2 = pl.concat([df1, df1], rechunk=True)

assert df2.shape == (6, 3)
assert df2.n_chunks() == 1
assert df2.rows() == df1.rows() + df1.rows()
assert pl.concat([df1, df1], rechunk=False).n_chunks() == 2

# concat from generator of frames
df3 = pl.concat(items=(df1 for _ in range(2)))
assert_frame_equal(df2, df3)

# check that df4 is not modified following concat of itself
df4 = pl.from_records(((1, 2), (1, 2)))
_ = pl.concat([df4, df4, df4])

assert df4.shape == (2, 2)
assert df4.rows() == [(1, 1), (2, 2)]

# misc error conditions
with pytest.raises(ValueError):
_ = pl.concat([])

with pytest.raises(ValueError):
pl.concat([df1, df1], how="rubbish") # type: ignore[arg-type]


def test_arg_where() -> None:
s = pl.Series([True, False, True, False])
assert_series_equal(
Expand Down Expand Up @@ -2262,12 +2234,6 @@ def test_list_of_list_of_struct() -> None:
assert df.to_dicts() == [] # type: ignore[union-attr]


def test_concat_to_empty() -> None:
assert pl.concat([pl.DataFrame([]), pl.DataFrame({"a": [1]})]).to_dict(
as_series=False
) == {"a": [1]}


def test_fill_null_limits() -> None:
assert pl.DataFrame(
{
Expand Down
45 changes: 1 addition & 44 deletions py-polars/tests/unit/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime, time, timezone
from decimal import Decimal
from itertools import chain
from typing import IO, TYPE_CHECKING, Any, Callable, Literal, cast
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

import fsspec
import numpy as np
Expand Down Expand Up @@ -1896,49 +1896,6 @@ def test_row_index_projection_pushdown_18463(
)


def test_concat_multiple_inmem() -> None:
f = io.BytesIO()
g = io.BytesIO()

df1 = pl.DataFrame(
{
"a": [1, 2, 3],
"b": ["xyz", "abc", "wow"],
}
)
df2 = pl.DataFrame(
{
"a": [5, 6, 7],
"b": ["a", "few", "entries"],
}
)

dfs = pl.concat([df1, df2])

df1.write_parquet(f)
df2.write_parquet(g)

f.seek(0)
g.seek(0)

items: list[IO[bytes]] = [f, g]
assert_frame_equal(pl.read_parquet(items), dfs)

f.seek(0)
g.seek(0)

assert_frame_equal(pl.read_parquet(items, use_pyarrow=True), dfs)

f.seek(0)
g.seek(0)

fb = f.read()
gb = g.read()

assert_frame_equal(pl.read_parquet([fb, gb]), dfs)
assert_frame_equal(pl.read_parquet([fb, gb], use_pyarrow=True), dfs)


@pytest.mark.write_disk
def test_write_binary_open_file(tmp_path: Path) -> None:
df = pl.DataFrame({"a": [1, 2, 3]})
Expand Down
99 changes: 99 additions & 0 deletions py-polars/tests/unit/operations/test_concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import io
from typing import IO

import pytest

import polars as pl
from polars.testing import assert_frame_equal


def test_concat_invalid_schema_err_20355() -> None:
    """Concatenating lazy frames with mismatched schemas must raise.

    Regression test for GitHub issue #20355: the streaming engine previously
    accepted a vertical concat whose inputs did not share the same schema.
    """
    lf1 = pl.LazyFrame({"x": [1], "y": [None]})
    lf2 = pl.LazyFrame({"y": [1]})
    # Schema validation happens during plan conversion, so collecting the
    # streaming plan is what surfaces the error.
    with pytest.raises(pl.exceptions.InvalidOperationError):
        pl.concat([lf1, lf2]).collect(streaming=True)


def test_concat_df() -> None:
    """Vertical DataFrame concat: shape, chunking, generator input, and errors."""
    base = pl.DataFrame({"a": [2, 1, 3], "b": [1, 2, 3], "c": [1, 2, 3]})

    rechunked = pl.concat([base, base], rechunk=True)
    assert rechunked.shape == (6, 3)
    assert rechunked.n_chunks() == 1
    assert rechunked.rows() == base.rows() + base.rows()

    # Without rechunking, each input contributes its own chunk.
    assert pl.concat([base, base], rechunk=False).n_chunks() == 2

    # A generator of frames is also an accepted input.
    from_generator = pl.concat(items=(base for _ in range(2)))
    assert_frame_equal(rechunked, from_generator)

    # Concatenating a frame with itself must not mutate the original.
    records = pl.from_records(((1, 2), (1, 2)))
    _ = pl.concat([records, records, records])
    assert records.shape == (2, 2)
    assert records.rows() == [(1, 1), (2, 2)]

    # Misc error conditions.
    with pytest.raises(ValueError):
        _ = pl.concat([])
    with pytest.raises(ValueError):
        pl.concat([base, base], how="rubbish")  # type: ignore[arg-type]


def test_concat_to_empty() -> None:
    """Concatenating an empty frame with a non-empty one yields the non-empty data."""
    combined = pl.concat([pl.DataFrame([]), pl.DataFrame({"a": [1]})])
    assert combined.to_dict(as_series=False) == {"a": [1]}


def test_concat_multiple_parquet_inmem() -> None:
    """Reading several in-memory parquet sources equals concatenating the frames."""
    left = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "b": ["xyz", "abc", "wow"],
        }
    )
    right = pl.DataFrame(
        {
            "a": [5, 6, 7],
            "b": ["a", "few", "entries"],
        }
    )
    expected = pl.concat([left, right])

    buf_left = io.BytesIO()
    buf_right = io.BytesIO()
    left.write_parquet(buf_left)
    right.write_parquet(buf_right)

    def rewind() -> None:
        # Each read consumes the streams; rewind before the next one.
        buf_left.seek(0)
        buf_right.seek(0)

    buffers: list[IO[bytes]] = [buf_left, buf_right]

    rewind()
    assert_frame_equal(pl.read_parquet(buffers), expected)

    rewind()
    assert_frame_equal(pl.read_parquet(buffers, use_pyarrow=True), expected)

    rewind()
    raw = [buf_left.read(), buf_right.read()]
    assert_frame_equal(pl.read_parquet(raw), expected)
    assert_frame_equal(pl.read_parquet(raw, use_pyarrow=True), expected)


def test_concat_series() -> None:
    """Series concat doubles the length and leaves the input untouched."""
    src = pl.Series("a", [2, 1, 3])
    assert pl.concat([src, src]).len() == 6
    # The source series must remain unchanged after the concat.
    assert src.len() == 3
Loading
Loading