Skip to content

Commit

Permalink
Change MaxLogicalType to Work Better (#3316)
Browse files Browse the repository at this point in the history
* remove force maxlogicaltype

* touch ups

* fix broken tests

* Run clang-format

---------

Co-authored-by: CI Bot <[email protected]>
  • Loading branch information
mxwli and mxwli authored Apr 18, 2024
1 parent 4496f4b commit 7984feb
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 9 deletions.
65 changes: 56 additions & 9 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,17 @@ static bool tryCombineArrayTypes(const LogicalType& left, const LogicalType& rig
return true;
}

static bool tryCombineListArrayTypes(const LogicalType& left, const LogicalType& right,
LogicalType& result) {
LogicalType childType;
if (!LogicalTypeUtils::tryGetMaxLogicalType(*ListType::getChildType(&left),
*ArrayType::getChildType(&right), childType)) {
return false;
}
result = *LogicalType::LIST(childType);
return true;
}

// If we can match child labels and combine their types, then we can combine
// the struct
static bool tryCombineStructTypes(const LogicalType& left, const LogicalType& right,
Expand Down Expand Up @@ -1184,6 +1195,7 @@ static LogicalTypeID joinDifferentSignIntegrals(const LogicalTypeID& signedType,
}
}

/*
static uint32_t internalTypeOrder(const LogicalTypeID& type) {
switch (type) {
case LogicalTypeID::ANY:
Expand Down Expand Up @@ -1255,6 +1267,26 @@ static uint32_t internalTypeOrder(const LogicalTypeID& type) {
// LCOV_EXCL_END
}
}
*/

static uint32_t internalTimeOrder(const LogicalTypeID& type) {
switch (type) {
case LogicalTypeID::DATE:
return 50;
case LogicalTypeID::TIMESTAMP_SEC:
return 51;
case LogicalTypeID::TIMESTAMP_MS:
return 52;
case LogicalTypeID::TIMESTAMP:
return 53;
case LogicalTypeID::TIMESTAMP_TZ:
return 54;
case LogicalTypeID::TIMESTAMP_NS:
return 55;
default:
return 0; // return 0 if not timestamp
}
}

bool canAlwaysCast(const LogicalTypeID& typeID) {
switch (typeID) {
Expand Down Expand Up @@ -1296,14 +1328,26 @@ bool LogicalTypeUtils::tryGetMaxLogicalTypeID(const LogicalTypeID& left, const L
return true;
}
}
auto leftPlacement = internalTypeOrder(left);
auto rightPlacement = internalTypeOrder(right);
if (leftPlacement > rightPlacement) {
result = left;

// check timestamp combination
// note: this will become obsolete if implicit casting
// between timestamps is allowed
auto leftOrder = internalTimeOrder(left);
auto rightOrder = internalTimeOrder(right);
if (leftOrder && rightOrder) {
if (leftOrder > rightOrder) {
result = left;
} else {
result = right;
}
return true;
}
result = right;
return true;

return false;
}

static inline bool isSemanticallyNested(LogicalTypeID ID) {
return LogicalTypeUtils::isNested(ID) && ID != LogicalTypeID::RDF_VARIANT;
}

bool LogicalTypeUtils::tryGetMaxLogicalType(const LogicalType& left, const LogicalType& right,
Expand All @@ -1316,9 +1360,12 @@ bool LogicalTypeUtils::tryGetMaxLogicalType(const LogicalType& left, const Logic
result = left;
return true;
}
if ((isNested(left) && left.typeID != LogicalTypeID::RDF_VARIANT) &&
(isNested(right) && right.typeID != LogicalTypeID::RDF_VARIANT)) {
if (left.typeID != right.typeID) {
if (isSemanticallyNested(left.typeID) || isSemanticallyNested(right.typeID)) {
if (left.typeID == LogicalTypeID::LIST && right.typeID == LogicalTypeID::ARRAY) {
return tryCombineListArrayTypes(left, right, result);
} else if (left.typeID == LogicalTypeID::ARRAY && right.typeID == LogicalTypeID::LIST) {
return tryCombineListArrayTypes(right, left, result);
} else if (left.typeID != right.typeID) {
return false;
}
switch (left.typeID) {
Expand Down
10 changes: 10 additions & 0 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy
case LogicalTypeID::TIMESTAMP_NS:
case LogicalTypeID::TIMESTAMP_TZ:
// currently don't allow timestamp to other timestamp types
// When we implement this in the future, revise tryGetMaxLogicalTypeID
return castTimestamp(targetTypeID);
default:
return UNDEFINED_CAST_COST;
Expand All @@ -156,6 +157,7 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy

uint32_t BuiltInFunctionsUtils::getTargetTypeCost(LogicalTypeID typeID) {
switch (typeID) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
return 101;
case LogicalTypeID::INT32:
Expand Down Expand Up @@ -187,13 +189,16 @@ uint32_t BuiltInFunctionsUtils::castInt64(LogicalTypeID targetTypeID) {
case LogicalTypeID::FLOAT:
case LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
case LogicalTypeID::SERIAL:
return 0;
default:
return UNDEFINED_CAST_COST;
}
}

uint32_t BuiltInFunctionsUtils::castInt32(LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
case LogicalTypeID::INT128:
case LogicalTypeID::FLOAT:
Expand All @@ -206,6 +211,7 @@ uint32_t BuiltInFunctionsUtils::castInt32(LogicalTypeID targetTypeID) {

uint32_t BuiltInFunctionsUtils::castInt16(LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT32:
case LogicalTypeID::INT64:
case LogicalTypeID::INT128:
Expand All @@ -219,6 +225,7 @@ uint32_t BuiltInFunctionsUtils::castInt16(LogicalTypeID targetTypeID) {

uint32_t BuiltInFunctionsUtils::castInt8(LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT16:
case LogicalTypeID::INT32:
case LogicalTypeID::INT64:
Expand All @@ -244,6 +251,7 @@ uint32_t BuiltInFunctionsUtils::castUInt64(LogicalTypeID targetTypeID) {

uint32_t BuiltInFunctionsUtils::castUInt32(LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
case LogicalTypeID::INT128:
case LogicalTypeID::UINT64:
Expand All @@ -258,6 +266,7 @@ uint32_t BuiltInFunctionsUtils::castUInt32(LogicalTypeID targetTypeID) {
uint32_t BuiltInFunctionsUtils::castUInt16(LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case LogicalTypeID::INT32:
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
case LogicalTypeID::INT128:
case LogicalTypeID::UINT32:
Expand All @@ -274,6 +283,7 @@ uint32_t BuiltInFunctionsUtils::castUInt8(LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case LogicalTypeID::INT16:
case LogicalTypeID::INT32:
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
case LogicalTypeID::INT128:
case LogicalTypeID::UINT16:
Expand Down

0 comments on commit 7984feb

Please sign in to comment.