Skip to content

Commit

Permalink
Add mean_horizontal for temporals
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Jan 2, 2025
1 parent 1ebd039 commit e8621c0
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 41 deletions.
180 changes: 147 additions & 33 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;

use arrow::temporal_conversions::MILLISECONDS_IN_DAY;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::series::arithmetic::coerce_lhs_rhs;
Expand Down Expand Up @@ -190,10 +191,44 @@ pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
}
}

fn null_with_supertype(
columns: Vec<&Series>,
date_to_datetime: bool,
) -> PolarsResult<Option<Column>> {
// We must first determine the correct return dtype.
let mut return_dtype = dtypes_to_supertype(columns.iter().map(|c| c.dtype()))?;
if return_dtype == DataType::Boolean {
return_dtype = IDX_DTYPE;
} else if date_to_datetime && return_dtype == DataType::Date {
return_dtype = DataType::Datetime(TimeUnit::Milliseconds, None);
}
Ok(Some(Column::full_null(
columns[0].name().clone(),
columns[0].len(),
&return_dtype,
)))
}

fn null_with_supertype_from_series(
columns: &[Column],
date_to_datetime: bool,
) -> PolarsResult<Option<Column>> {
null_with_supertype(
columns
.iter()
.map(|c| c.as_materialized_series())
.collect::<Vec<_>>(),
date_to_datetime,
)
}

pub fn sum_horizontal(
columns: &[Column],
null_strategy: NullStrategy,
) -> PolarsResult<Option<Column>> {
if columns.is_empty() {
return Ok(None);
}
validate_column_lengths(columns)?;
let ignore_nulls = null_strategy == NullStrategy::Ignore;

Expand All @@ -220,17 +255,12 @@ pub fn sum_horizontal(
.collect::<Vec<_>>();

// If we have any null columns and null strategy is not `Ignore`, we can return immediately.
if !ignore_nulls && non_null_cols.len() < columns.len() {
// We must first determine the correct return dtype.
let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
DataType::Boolean => DataType::UInt32,
dt => dt,
};
return Ok(Some(Column::full_null(
columns[0].name().clone(),
columns[0].len(),
&return_dtype,
)));
let num_cols = non_null_cols.len();
let name = columns[0].name();
if num_cols == 0 {
return Ok(Some(columns[0].clone().with_name(name.clone())));
} else if !ignore_nulls && non_null_cols.len() < columns.len() {
return null_with_supertype(non_null_cols, false);
}

match non_null_cols.len() {
Expand All @@ -239,15 +269,16 @@ pub fn sum_horizontal(
Ok(None)
} else {
// all columns are null dtype, so result is null dtype
Ok(Some(columns[0].clone()))
Ok(Some(columns[0].clone().with_name(name.clone())))
}
},
1 => Ok(Some(
apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
non_null_cols[0].cast(&DataType::UInt32)?
non_null_cols[0].cast(&IDX_DTYPE)?
} else {
non_null_cols[0].clone()
})?
.with_name(name.clone())
.into(),
)),
2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
Expand All @@ -274,33 +305,98 @@ pub fn mean_horizontal(
columns: &[Column],
null_strategy: NullStrategy,
) -> PolarsResult<Option<Column>> {
if columns.is_empty() {
return Ok(None);
}
let ignore_nulls = null_strategy == NullStrategy::Ignore;

validate_column_lengths(columns)?;
let name = columns[0].name().clone();

let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
let dtype = s.dtype();
dtype.is_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
});
let Some(first_non_null_idx) = columns.iter().position(|c| !c.dtype().is_null()) else {
// All columns are null; return f64 nulls.
return Ok(Some(
Float64Chunked::full_null(name, columns[0].len()).into_column(),
));
};

if !non_numeric_columns.is_empty() {
let col = non_numeric_columns.first().cloned();
if first_non_null_idx > 0 && !ignore_nulls {
return null_with_supertype_from_series(columns, true);
}

// Ensure column dtypes are all valid
let first_dtype = columns[first_non_null_idx].dtype();
let is_temporal = first_dtype.is_temporal();
let columns = if is_temporal {
// All remaining dtypes must be the same temporal dtype (or null).
for col in &columns[first_non_null_idx + 1..] {
let dtype = col.dtype();
if !ignore_nulls && dtype == &DataType::Null {
// A null column guarantees null output.
return null_with_supertype_from_series(columns, true);
} else if dtype != first_dtype && dtype != &DataType::Null {
polars_bail!(
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
columns[first_non_null_idx].name(),
first_dtype,
dtype,
col.name(),
);
};
}
columns[first_non_null_idx..]
.iter()
.map(|c| c.cast(&DataType::Int64).unwrap())
.collect::<Vec<_>>()
} else if first_dtype.is_numeric()
|| first_dtype.is_decimal()
|| first_dtype.is_bool()
|| first_dtype.is_null()
|| first_dtype.is_temporal()
{
// All remaining must be numeric (or null).
for col in &columns[first_non_null_idx + 1..] {
let dtype = col.dtype();
if !(dtype.is_numeric()
|| dtype.is_decimal()
|| dtype.is_bool()
|| dtype.is_temporal()
|| dtype.is_null())
{
polars_bail!(
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
columns[first_non_null_idx].name(),
first_dtype,
col.name(),
dtype,
);
}
}
columns[first_non_null_idx..].to_vec()
} else {
polars_bail!(
InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
col.unwrap().name(),
col.unwrap().dtype(),
InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={})",
columns[first_non_null_idx].name(),
first_dtype,
);
}
let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
};

let num_rows = columns.len();
match num_rows {
0 => Ok(None),
1 => Ok(Some(match columns[0].dtype() {
dt if dt != &DataType::Float32 && !dt.is_decimal() => {
columns[0].cast(&DataType::Float64)?
1 => Ok(Some(match first_dtype {
dt if dt != &DataType::Float32 && !is_temporal && !dt.is_decimal() => {
columns[0].cast(&DataType::Float64)?.with_name(name)
},
_ => match first_dtype {
DataType::Date => (&columns[0] * MILLISECONDS_IN_DAY)
.with_name(name)
.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?,
dt if is_temporal => columns[0].cast(dt)?.with_name(name),
_ => columns[0].clone().with_name(name),
},
_ => columns[0].clone(),
})),
_ => {
let sum = || sum_horizontal(columns.as_slice(), null_strategy);
let sum = || sum_horizontal(&columns, null_strategy);
let null_count = || {
columns
.par_iter()
Expand All @@ -321,7 +417,7 @@ pub fn mean_horizontal(
};

let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count));
let sum = sum?;
let sum = sum?.map(|c| c.with_name(name));
let null_count = null_count?;

// value lengths: len - null_count
Expand Down Expand Up @@ -349,8 +445,26 @@ pub fn mean_horizontal(
.into_column()
.cast(dt)?;

sum.map(|sum| std::ops::Div::div(&sum, &value_length))
.transpose()
let out = sum.map(|sum| std::ops::Div::div(&sum, &value_length));

let x = out.map(|opt| {
opt.and_then(|value| {
if is_temporal {
if first_dtype == &DataType::Date {
// Cast to DateTime(us)
(value * MILLISECONDS_IN_DAY)
.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
} else {
// Cast to original
value.cast(first_dtype)
}
} else {
Ok(value)
}
})
});

x.transpose()
},
}
}
Expand Down
12 changes: 7 additions & 5 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,17 +331,19 @@ impl FunctionExpr {
MinHorizontal => mapper.map_to_supertype(),
SumHorizontal { .. } => {
mapper.map_to_supertype().map(|mut f| {
match f.dtype {
// Booleans sum to UInt32.
DataType::Boolean => { f.dtype = DataType::UInt32; f},
_ => f,
if f.dtype == DataType::Boolean {
f.dtype = IDX_DTYPE;
}
f
})
},
MeanHorizontal { .. } => {
mapper.map_to_supertype().map(|mut f| {
match f.dtype {
dt @ DataType::Float32 => { f.dtype = dt; },
DataType::Boolean => { f.dtype = DataType::Float64; },
DataType::Float32 => { f.dtype = DataType::Float32; },
DataType::Date => { f.dtype = DataType::Datetime(TimeUnit::Milliseconds, None); }
dt if dt.is_temporal() => { f.dtype = dt; }
_ => { f.dtype = DataType::Float64; },
};
f
Expand Down
Loading

0 comments on commit e8621c0

Please sign in to comment.