Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Eq, PartialEq, Hash for dyn PhysicalExpr #13005

Merged
merged 7 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions datafusion/core/tests/sql/path_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ use bytes::Bytes;
use chrono::{TimeZone, Utc};
use datafusion_expr::{col, lit, Expr, Operator};
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::{self, BoxStream};
use object_store::{
path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta,
Expand Down Expand Up @@ -97,7 +96,7 @@ async fn parquet_partition_pruning_filter() -> Result<()> {
assert!(pred.as_any().is::<BinaryExpr>());
let pred = pred.as_any().downcast_ref::<BinaryExpr>().unwrap();

assert_eq!(pred, expected.as_any());
assert_eq!(pred, expected.as_ref());

Ok(())
}
Expand Down
81 changes: 34 additions & 47 deletions datafusion/physical-expr-common/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use datafusion_expr_common::sort_properties::ExprProperties;
/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
/// Returns the physical expression as [`Any`] so that it can be
/// downcast to a specific implementation.
fn as_any(&self) -> &dyn Any;
Expand Down Expand Up @@ -141,38 +141,6 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
Ok(Some(vec![]))
}

/// Update the hash `state` with this expression requirements from
/// [`Hash`].
///
/// This method is required to support hashing [`PhysicalExpr`]s. To
/// implement it, typically the type implementing
/// [`PhysicalExpr`] implements [`Hash`] and
/// then the following boiler plate is used:
///
/// # Example:
/// ```
/// // User defined expression that derives Hash
/// #[derive(Hash, Debug, PartialEq, Eq)]
/// struct MyExpr {
/// val: u64
/// }
///
/// // impl PhysicalExpr {
/// // ...
/// # impl MyExpr {
/// // Boiler plate to call the derived Hash impl
/// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
/// use std::hash::Hash;
/// let mut s = state;
/// self.hash(&mut s);
/// }
/// // }
/// # }
/// ```
/// Note: [`PhysicalExpr`] is not constrained by [`Hash`]
/// directly because it must remain object safe.
fn dyn_hash(&self, _state: &mut dyn Hasher);

/// Calculates the properties of this [`PhysicalExpr`] based on its
/// children's properties (i.e. order and range), recursively aggregating
/// the information from its children. In cases where the [`PhysicalExpr`]
Expand All @@ -183,6 +151,39 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
}
}

pub trait DynEq {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can add some documentation here explaining why this trait is needed and what it is used for.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in f1917fa.

fn dyn_eq(&self, other: &dyn Any) -> bool;
}

/// Blanket implementation of [`DynEq`] for every `Eq + Any` type.
///
/// `other` is first downcast to the concrete type of `self`; a failed
/// downcast (i.e. a different concrete type) is treated as "not equal",
/// otherwise the type's own `==` decides the result.
impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<Self>() {
            // Same concrete type: defer to its `PartialEq` implementation.
            Some(rhs) => rhs == self,
            // Different concrete type: never equal.
            None => false,
        }
    }
}

/// Equality between two `dyn PhysicalExpr` trait objects.
///
/// The comparison is delegated to `dyn_eq` on `self`, with `other` passed as
/// `&dyn Any` (via `as_any`) so the implementation can attempt a downcast to
/// its own concrete type.
impl PartialEq for dyn PhysicalExpr {
fn eq(&self, other: &Self) -> bool {
// `other` is type-erased here; `dyn_eq` decides how to compare it.
self.dyn_eq(other.as_any())
}
}

/// Marker impl declaring the equality above to be a full equivalence relation.
// NOTE(review): this relies on every `dyn_eq` implementation being reflexive,
// symmetric and transitive — confirm for any hand-written `DynEq` impls.
impl Eq for dyn PhysicalExpr {}

/// Hashing counterpart of [`Hash`] that works through a type-erased
/// [`Hasher`] (`&mut dyn Hasher`) instead of a generic `H: Hasher`
/// parameter.
///
/// Note: [`PhysicalExpr`] is not constrained by [`Hash`] directly because it must remain
/// object safe.
pub trait DynHash {
/// Feeds `self` into the given hasher `_state`.
fn dyn_hash(&self, _state: &mut dyn Hasher);
}

/// Blanket implementation of [`DynHash`] for every `Hash + Any` type.
///
/// The concrete type's `TypeId` is written to the hasher first, followed by
/// the value's own [`Hash`] output.
impl<T: Hash + Any> DynHash for T {
    fn dyn_hash(&self, state: &mut dyn Hasher) {
        // Rebind so that `&mut hasher` (a `&mut &mut dyn Hasher`) can be
        // handed to the generic `Hash::hash`, which needs a sized `H: Hasher`.
        let mut hasher = state;
        let concrete_type = <T as Any>::type_id(self);
        Hash::hash(&concrete_type, &mut hasher);
        Hash::hash(self, &mut hasher);
    }
}

impl Hash for dyn PhysicalExpr {
fn hash<H: Hasher>(&self, state: &mut H) {
self.dyn_hash(state);
Expand Down Expand Up @@ -210,20 +211,6 @@ pub fn with_new_children_if_necessary(
}
}

/// Looks through an `Arc<dyn PhysicalExpr>` or `Box<dyn PhysicalExpr>` hidden
/// behind `any` and returns the wrapped expression's own `as_any()` view.
///
/// Any other value is returned unchanged.
pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    if let Some(shared) = any.downcast_ref::<Arc<dyn PhysicalExpr>>() {
        shared.as_any()
    } else if let Some(boxed) = any.downcast_ref::<Box<dyn PhysicalExpr>>() {
        boxed.as_any()
    } else {
        any
    }
}

/// Returns a [`Display`]able list of [`PhysicalExpr`]s
///
/// Example output: `[a + 1, b]`
Expand Down
4 changes: 0 additions & 4 deletions datafusion/physical-expr-common/src/sort_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,10 @@ use datafusion_expr_common::columnar_value::ColumnarValue;
/// # fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {todo!() }
/// # fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {todo!()}
/// # fn with_new_children(self: Arc<Self>, children: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn PhysicalExpr>> {todo!()}
/// # fn dyn_hash(&self, _state: &mut dyn Hasher) {todo!()}
/// # }
/// # impl Display for MyPhysicalExpr {
/// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") }
/// # }
/// # impl PartialEq<dyn Any> for MyPhysicalExpr {
/// # fn eq(&self, _other: &dyn Any) -> bool { true }
/// # }
/// # fn col(name: &str) -> Arc<dyn PhysicalExpr> { Arc::new(MyPhysicalExpr) }
/// // Sort by a ASC
/// let options = SortOptions::default();
Expand Down
7 changes: 3 additions & 4 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ pub struct ConstExpr {

impl PartialEq for ConstExpr {
fn eq(&self, other: &Self) -> bool {
self.across_partitions == other.across_partitions
&& self.expr.eq(other.expr.as_any())
self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
}
}

Expand Down Expand Up @@ -121,7 +120,7 @@ impl ConstExpr {

/// Returns true if this constant expression is equal to the given expression
pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool {
self.expr.eq(other.as_ref().as_any())
self.expr.as_ref() == other.as_ref()
}

/// Returns a [`Display`]able list of `ConstExpr`.
Expand Down Expand Up @@ -557,7 +556,7 @@ impl EquivalenceGroup {
new_classes.push((source, vec![Arc::clone(target)]));
}
if let Some((_, values)) =
new_classes.iter_mut().find(|(key, _)| key.eq(source))
new_classes.iter_mut().find(|(key, _)| *key == source)
{
if !physical_exprs_contains(values, target) {
values.push(Arc::clone(target));
Expand Down
30 changes: 5 additions & 25 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

mod kernels;

use std::hash::{Hash, Hasher};
use std::hash::Hash;
use std::{any::Any, sync::Arc};

use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;

use arrow::array::*;
Expand All @@ -48,11 +47,11 @@ use kernels::{
};

/// Binary expression
#[derive(Debug, Hash, Clone)]
pub struct BinaryExpr {
left: Arc<dyn PhysicalExpr>,
#[derive(Debug, Hash, Clone, Eq, PartialEq)]
pub struct BinaryExpr<DynPhysicalExpr: ?Sized = dyn PhysicalExpr> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The removal of the boiler plate in this PR is great ❤️

I feel like this additional trait / workaround is non-ideal, as it makes the BinaryExpr much harder to understand (there is now a level of indirection that is irrelevant for most users of BinaryExpr).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am playing around with it to see if I can avoid it somehow

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is one way that seems to work: peter-toth#5

More verbose, but I think it keeps the structs simpler to understand

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we don't need to keep the generic parameter workaround. I agree that those params look a bit weird.
peter-toth#5 looks good to me, if you extend it to other expressions I can merge it tomorrow and amend my PR description.

Copy link
Contributor

@alamb alamb Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GitHub Copilot and I banged out the code. However, I think I accidentally pushed the commits to your branch 😬 -- hopefully that is ok

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem. 😁 I've just updated the PR description.

left: Arc<DynPhysicalExpr>,
op: Operator,
right: Arc<dyn PhysicalExpr>,
right: Arc<DynPhysicalExpr>,
/// Specifies whether an error is returned on overflow or not
fail_on_overflow: bool,
}
Expand Down Expand Up @@ -477,11 +476,6 @@ impl PhysicalExpr for BinaryExpr {
}
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
}

/// For each operator, [`BinaryExpr`] has distinct rules.
/// TODO: There may be rules specific to some data types and expression ranges.
fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
Expand Down Expand Up @@ -525,20 +519,6 @@ impl PhysicalExpr for BinaryExpr {
}
}

impl PartialEq<dyn Any> for BinaryExpr {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removing this is great

fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.left.eq(&x.left)
&& self.op == x.op
&& self.right.eq(&x.right)
&& self.fail_on_overflow.eq(&x.fail_on_overflow)
})
.unwrap_or(false)
}
}

/// Casts dictionary array to result type for binary numerical operators. Such operators
/// between array and scalar produce a dictionary array other than primitive array of the
/// same operators between array and array. This leads to inconsistent result types causing
Expand Down
40 changes: 3 additions & 37 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
// under the License.

use std::borrow::Cow;
use std::hash::{Hash, Hasher};
use std::hash::Hash;
use std::{any::Any, sync::Arc};

use crate::expressions::try_cast;
use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;

use arrow::array::*;
Expand All @@ -37,7 +36,7 @@ use itertools::Itertools;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);

#[derive(Debug, Hash)]
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
/// CASE WHEN condition THEN result
/// [WHEN ...]
Expand Down Expand Up @@ -80,7 +79,7 @@ enum EvalMethod {
/// [WHEN ...]
/// [ELSE result]
/// END
#[derive(Debug, Hash)]
#[derive(Debug, Hash, PartialEq, Eq)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can it permit auto derivation of PartialEq for CaseExpr? 🤔 -- we need manual impls for the others

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is probably the strangest part of the bug. If you have double wrappers around dyn Trait then there is no issue: rust-lang/rust#78808 (comment)
CaseExpr doesn't have any dyn Trait fields with a single wrapper so the derive macro has no problem.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we create a ticket to remove those impls once the bug is resolved?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about #13196?

pub struct CaseExpr {
/// Optional base expression that can be compared to literal values in the "when" expressions
expr: Option<Arc<dyn PhysicalExpr>>,
Expand Down Expand Up @@ -506,39 +505,6 @@ impl PhysicalExpr for CaseExpr {
)?))
}
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
}
}

impl PartialEq<dyn Any> for CaseExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
let expr_eq = match (&self.expr, &x.expr) {
(Some(expr1), Some(expr2)) => expr1.eq(expr2),
(None, None) => true,
_ => false,
};
let else_expr_eq = match (&self.else_expr, &x.else_expr) {
(Some(expr1), Some(expr2)) => expr1.eq(expr2),
(None, None) => true,
_ => false,
};
expr_eq
&& else_expr_eq
&& self.when_then_expr.len() == x.when_then_expr.len()
&& self.when_then_expr.iter().zip(x.when_then_expr.iter()).all(
|((when1, then1), (when2, then2))| {
when1.eq(when2) && then1.eq(then2)
},
)
})
.unwrap_or(false)
}
}

/// Create a CASE expression
Expand Down
30 changes: 5 additions & 25 deletions datafusion/physical-expr/src/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

use std::any::Any;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::hash::Hash;
use std::sync::Arc;

use crate::physical_expr::{down_cast_any_ref, PhysicalExpr};
use crate::physical_expr::PhysicalExpr;

use arrow::compute::{can_cast_types, CastOptions};
use arrow::datatypes::{DataType, DataType::*, Schema};
Expand All @@ -42,10 +42,10 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
};

/// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast
#[derive(Debug, Clone)]
pub struct CastExpr {
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct CastExpr<DynPhysicalExpr: ?Sized = dyn PhysicalExpr> {
/// The expression to cast
pub expr: Arc<dyn PhysicalExpr>,
pub expr: Arc<DynPhysicalExpr>,
/// The data type to cast to
cast_type: DataType,
/// Cast options
Expand Down Expand Up @@ -160,13 +160,6 @@ impl PhysicalExpr for CastExpr {
]))
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.expr.hash(&mut s);
self.cast_type.hash(&mut s);
self.cast_options.hash(&mut s);
}

/// A [`CastExpr`] preserves the ordering of its child if the cast is done
/// under the same datatype family.
fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
Expand All @@ -186,19 +179,6 @@ impl PhysicalExpr for CastExpr {
}
}

impl PartialEq<dyn Any> for CastExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.expr.eq(&x.expr)
&& self.cast_type == x.cast_type
&& self.cast_options == x.cast_options
})
.unwrap_or(false)
}
}

/// Return a PhysicalExpression representing `expr` casted to
/// `cast_type`, if any casting is needed.
///
Expand Down
Loading