Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LogicalType for TypeSignature Numeric and String, Coercible #13240

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions datafusion/common/src/types/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ impl fmt::Debug for dyn LogicalType {
}
}

impl std::fmt::Display for dyn LogicalType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can make a distinction between Debug and Display. Currently, we print something like

        let int: Box<dyn LogicalType> = Box::new(Int32);
        println!(format!("{}", int));

------- output ------
LogicalType(Native(Int32), Int32))

I imagine we can print Int32 toInt32, print JSON to JSON ... 🤔
However, I think it may not be the point of this PR. We can do it in another PR.

}
}

impl PartialEq for dyn LogicalType {
fn eq(&self, other: &Self) -> bool {
self.signature().eq(&other.signature())
Expand Down
37 changes: 34 additions & 3 deletions datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ impl LogicalType for NativeType {
// mapping solutions to provide backwards compatibility while transitioning from
// the purely physical system to a logical / physical system.

impl From<&DataType> for NativeType {
fn from(value: &DataType) -> Self {
value.clone().into()
}
}

impl From<DataType> for NativeType {
fn from(value: DataType) -> Self {
use NativeType::*;
Expand Down Expand Up @@ -392,8 +398,33 @@ impl From<DataType> for NativeType {
}
}

impl From<&DataType> for NativeType {
fn from(value: &DataType) -> Self {
value.clone().into()
impl NativeType {
#[inline]
pub fn is_numeric(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
| Decimal(_, _)
)
}

#[inline]
pub fn is_integer(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64
)
}
}
14 changes: 10 additions & 4 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! and return types of functions in DataFusion.

use arrow::datatypes::DataType;
use datafusion_common::types::LogicalTypeRef;

/// Constant that is used as a placeholder for any valid timezone.
/// This is used where a function can accept a timestamp type with any
Expand Down Expand Up @@ -109,7 +110,7 @@ pub enum TypeSignature {
/// For example, `Coercible(vec![DataType::Float64])` accepts
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
/// since i32 and f32 can be casted to f64
Coercible(Vec<DataType>),
Coercible(Vec<LogicalTypeRef>),
/// Fixed number of arguments of arbitrary types
/// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
Any(usize),
Expand All @@ -123,7 +124,9 @@ pub enum TypeSignature {
/// Specifies Signatures for array functions
ArraySignature(ArrayFunctionSignature),
/// Fixed number of arguments of numeric types.
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric
/// See [`NativeType::is_numeric`] to know which type is considered numeric
///
/// [`NativeType::is_numeric`]: datafusion_common
Numeric(usize),
/// Fixed number of arguments of all the same string types.
/// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8.
Expand Down Expand Up @@ -201,7 +204,10 @@ impl TypeSignature {
TypeSignature::Numeric(num) => {
vec![format!("Numeric({num})")]
}
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Exact(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Any(arg_count) => {
Expand Down Expand Up @@ -322,7 +328,7 @@ impl Signature {
}
}
/// Target coerce types in order
pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) -> Self {
pub fn coercible(target_types: Vec<LogicalTypeRef>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
Expand Down
149 changes: 94 additions & 55 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow::{
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err,
types::{LogicalType, NativeType},
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
Expand Down Expand Up @@ -395,40 +396,56 @@ fn get_valid_types(
}
}

fn function_length_check(length: usize, expected_length: usize) -> Result<()> {
if length < 1 {
return plan_err!(
"The signature expected at least one argument but received {expected_length}"
);
}

if length != expected_length {
return plan_err!(
"The signature expected {length} arguments but received {expected_length}"
);
}

Ok(())
}

let valid_types = match signature {
TypeSignature::Variadic(valid_types) => valid_types
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::String(number) => {
if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if *number != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
number,
current_types.len()
);
function_length_check(current_types.len(), *number)?;

let mut new_types = Vec::with_capacity(current_types.len());
for data_type in current_types.iter() {
let logical_data_type: NativeType = data_type.into();
if logical_data_type == NativeType::String {
new_types.push(data_type.to_owned());
} else if logical_data_type == NativeType::Null {
// TODO: Switch to Utf8View if all the string functions supports Utf8View
new_types.push(DataType::Utf8);
} else {
return plan_err!(
"The signature expected NativeType::String but received {data_type}"
);
}
}

fn coercion_rule(
// Find the common string type for the given types
fn find_common_type(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Result<DataType> {
match (lhs_type, rhs_type) {
(DataType::Null, DataType::Null) => Ok(DataType::Utf8),
(DataType::Null, data_type) | (data_type, DataType::Null) => {
coercion_rule(data_type, &DataType::Utf8)
}
(DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
coercion_rule(lhs, rhs)
find_common_type(lhs, rhs)
}
(DataType::Dictionary(_, v), other)
| (other, DataType::Dictionary(_, v)) => coercion_rule(v, other),
| (other, DataType::Dictionary(_, v)) => find_common_type(v, other),
_ => {
if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
Ok(coerced_type)
Expand All @@ -444,15 +461,13 @@ fn get_valid_types(
}

// Length checked above, safe to unwrap
let mut coerced_type = current_types.first().unwrap().to_owned();
for t in current_types.iter().skip(1) {
coerced_type = coercion_rule(&coerced_type, t)?;
let mut coerced_type = new_types.first().unwrap().to_owned();
for t in new_types.iter().skip(1) {
coerced_type = find_common_type(&coerced_type, t)?;
}

fn base_type_or_default_type(data_type: &DataType) -> DataType {
if data_type.is_null() {
DataType::Utf8
} else if let DataType::Dictionary(_, v) = data_type {
if let DataType::Dictionary(_, v) = data_type {
base_type_or_default_type(v)
} else {
data_type.to_owned()
Expand All @@ -462,22 +477,22 @@ fn get_valid_types(
vec![vec![base_type_or_default_type(&coerced_type); *number]]
}
TypeSignature::Numeric(number) => {
if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if *number != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
number,
current_types.len()
);
}
function_length_check(current_types.len(), *number)?;

let mut valid_type = current_types.first().unwrap().clone();
// Find common numeric type amongs given types except string
let mut valid_type = current_types.first().unwrap().to_owned();
for t in current_types.iter().skip(1) {
let logical_data_type: NativeType = t.into();
if logical_data_type == NativeType::Null {
continue;
}

if !logical_data_type.is_numeric() {
return plan_err!(
"The signature expected NativeType::Numeric but received {t}"
);
}

if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
valid_type = coerced_type;
} else {
Expand All @@ -489,31 +504,55 @@ fn get_valid_types(
}
}

let logical_data_type: NativeType = valid_type.clone().into();
// Fallback to default type if we don't know which type to coerced to
// f64 is chosen since most of the math functions utilize Signature::numeric,
// and their default type is double precision
if logical_data_type == NativeType::Null {
valid_type = DataType::Float64;
}

vec![vec![valid_type; *number]]
}
TypeSignature::Coercible(target_types) => {
if target_types.is_empty() {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if target_types.len() != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
target_types.len(),
current_types.len()
);
function_length_check(current_types.len(), target_types.len())?;

// Aim to keep this logic as SIMPLE as possible!
// Make sure the corresponding test is covered
// If this function becomes COMPLEX, create another new signature!
fn can_coerce_to(
logical_type: &NativeType,
target_type: &NativeType,
) -> bool {
if logical_type == target_type {
return true;
}

if logical_type == &NativeType::Null {
return true;
}

if target_type.is_integer() && logical_type.is_integer() {
return true;
}

false
}

for (data_type, target_type) in current_types.iter().zip(target_types.iter())
let mut new_types = Vec::with_capacity(current_types.len());
for (current_type, target_type) in
current_types.iter().zip(target_types.iter())
{
if !can_cast_types(data_type, target_type) {
return plan_err!("{data_type} is not coercible to {target_type}");
let logical_type: NativeType = current_type.into();
let target_logical_type = target_type.native();
if can_coerce_to(&logical_type, target_logical_type) {
let target_type =
target_logical_type.default_cast_for(current_type)?;
new_types.push(target_type);
}
}

vec![target_types.to_owned()]
vec![new_types]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
Expand Down
24 changes: 4 additions & 20 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Documentation, Expr,
ExprFunctionExt, Signature, SortExpr, TypeSignature, Volatility,
Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, Signature,
SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::utils::get_sort_options;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef};
Expand Down Expand Up @@ -79,15 +79,7 @@ impl Default for FirstValue {
impl FirstValue {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
// TODO: we can introduce more strict signature that only numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Numeric(1),
TypeSignature::Uniform(1, vec![DataType::Utf8]),
],
Volatility::Immutable,
),
signature: Signature::any(1, Volatility::Immutable),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally first/last value get the first/last item in the column, so it should be the type of the column.

I forgot why we don't use Any.
Without the change, we need to add another type signature for boolean, so I change it to Any in this PR

requirement_satisfied: false,
}
}
Expand Down Expand Up @@ -406,15 +398,7 @@ impl Default for LastValue {
impl LastValue {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
// TODO: we can introduce more strict signature that only numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Numeric(1),
TypeSignature::Uniform(1, vec![DataType::Utf8]),
],
Volatility::Immutable,
),
signature: Signature::any(1, Volatility::Immutable),
requirement_satisfied: false,
}
}
Expand Down
5 changes: 1 addition & 4 deletions datafusion/functions-aggregate/src/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ impl Stddev {
/// Create a new STDDEV aggregate function
pub fn new() -> Self {
Self {
signature: Signature::coercible(
vec![DataType::Float64],
Volatility::Immutable,
),
signature: Signature::numeric(1, Volatility::Immutable),
alias: vec!["stddev_samp".to_string()],
}
}
Expand Down
5 changes: 1 addition & 4 deletions datafusion/functions-aggregate/src/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@ impl VarianceSample {
pub fn new() -> Self {
Self {
aliases: vec![String::from("var_sample"), String::from("var_samp")],
signature: Signature::coercible(
vec![DataType::Float64],
Volatility::Immutable,
),
signature: Signature::numeric(1, Volatility::Immutable),
}
}
}
Expand Down
Loading