diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index 43cf8b9621f..bcc8c2d43ef 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -1071,6 +1071,17 @@ static bool tryCombineArrayTypes(const LogicalType& left, const LogicalType& rig return true; } +static bool tryCombineListArrayTypes(const LogicalType& left, const LogicalType& right, + LogicalType& result) { + LogicalType childType; + if (!LogicalTypeUtils::tryGetMaxLogicalType(*ListType::getChildType(&left), + *ArrayType::getChildType(&right), childType)) { + return false; + } + result = *LogicalType::LIST(childType); + return true; +} + // If we can match child labels and combine their types, then we can combine // the struct static bool tryCombineStructTypes(const LogicalType& left, const LogicalType& right, @@ -1184,6 +1195,7 @@ static LogicalTypeID joinDifferentSignIntegrals(const LogicalTypeID& signedType, } } +/* static uint32_t internalTypeOrder(const LogicalTypeID& type) { switch (type) { case LogicalTypeID::ANY: @@ -1255,6 +1267,26 @@ static uint32_t internalTypeOrder(const LogicalTypeID& type) { // LCOV_EXCL_END } } +*/ + +static uint32_t internalTimeOrder(const LogicalTypeID& type) { + switch (type) { + case LogicalTypeID::DATE: + return 50; + case LogicalTypeID::TIMESTAMP_SEC: + return 51; + case LogicalTypeID::TIMESTAMP_MS: + return 52; + case LogicalTypeID::TIMESTAMP: + return 53; + case LogicalTypeID::TIMESTAMP_TZ: + return 54; + case LogicalTypeID::TIMESTAMP_NS: + return 55; + default: + return 0; // return 0 if not timestamp + } +} bool canAlwaysCast(const LogicalTypeID& typeID) { switch (typeID) { @@ -1296,14 +1328,26 @@ bool LogicalTypeUtils::tryGetMaxLogicalTypeID(const LogicalTypeID& left, const L return true; } } - auto leftPlacement = internalTypeOrder(left); - auto rightPlacement = internalTypeOrder(right); - if (leftPlacement > rightPlacement) { - result = left; + + // check timestamp combination + // note: this will become obsolete if implicit casting + // between timestamps is allowed + auto leftOrder = internalTimeOrder(left); + auto rightOrder = internalTimeOrder(right); + if (leftOrder && rightOrder) { + if (leftOrder > rightOrder) { + result = left; + } else { + result = right; + } return true; } - result = right; - return true; + + return false; +} + +static inline bool isSemanticallyNested(LogicalTypeID ID) { + return LogicalTypeUtils::isNested(ID) && ID != LogicalTypeID::RDF_VARIANT; } bool LogicalTypeUtils::tryGetMaxLogicalType(const LogicalType& left, const LogicalType& right, @@ -1316,9 +1360,12 @@ bool LogicalTypeUtils::tryGetMaxLogicalType(const LogicalType& left, const Logic result = left; return true; } - if ((isNested(left) && left.typeID != LogicalTypeID::RDF_VARIANT) && - (isNested(right) && right.typeID != LogicalTypeID::RDF_VARIANT)) { - if (left.typeID != right.typeID) { + if (isSemanticallyNested(left.typeID) || isSemanticallyNested(right.typeID)) { + if (left.typeID == LogicalTypeID::LIST && right.typeID == LogicalTypeID::ARRAY) { + return tryCombineListArrayTypes(left, right, result); + } else if (left.typeID == LogicalTypeID::ARRAY && right.typeID == LogicalTypeID::LIST) { + return tryCombineListArrayTypes(right, left, result); + } else if (left.typeID != right.typeID) { return false; } switch (left.typeID) { diff --git a/src/function/built_in_function_utils.cpp b/src/function/built_in_function_utils.cpp index 3c860b3d2de..1473e1ad685 100644 --- a/src/function/built_in_function_utils.cpp +++ b/src/function/built_in_function_utils.cpp @@ -148,6 +148,7 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy case LogicalTypeID::TIMESTAMP_NS: case LogicalTypeID::TIMESTAMP_TZ: // currently don't allow timestamp to other timestamp types + // When we implement this in the future, revise tryGetMaxLogicalTypeID return castTimestamp(targetTypeID); default: return UNDEFINED_CAST_COST; @@ -156,6 +157,7 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy uint32_t BuiltInFunctionsUtils::getTargetTypeCost(LogicalTypeID typeID) { switch (typeID) { + case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: return 101; case LogicalTypeID::INT32: @@ -187,6 +189,8 @@ uint32_t BuiltInFunctionsUtils::castInt64(LogicalTypeID targetTypeID) { case LogicalTypeID::FLOAT: case LogicalTypeID::DOUBLE: return getTargetTypeCost(targetTypeID); + case LogicalTypeID::SERIAL: + return 0; default: return UNDEFINED_CAST_COST; } @@ -194,6 +198,7 @@ uint32_t BuiltInFunctionsUtils::castInt64(LogicalTypeID targetTypeID) { uint32_t BuiltInFunctionsUtils::castInt32(LogicalTypeID targetTypeID) { switch (targetTypeID) { + case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: case LogicalTypeID::INT128: case LogicalTypeID::FLOAT: @@ -206,6 +211,7 @@ uint32_t BuiltInFunctionsUtils::castInt32(LogicalTypeID targetTypeID) { uint32_t BuiltInFunctionsUtils::castInt16(LogicalTypeID targetTypeID) { switch (targetTypeID) { + case LogicalTypeID::SERIAL: case LogicalTypeID::INT32: case LogicalTypeID::INT64: case LogicalTypeID::INT128: @@ -219,6 +225,7 @@ uint32_t BuiltInFunctionsUtils::castInt16(LogicalTypeID targetTypeID) { uint32_t BuiltInFunctionsUtils::castInt8(LogicalTypeID targetTypeID) { switch (targetTypeID) { + case LogicalTypeID::SERIAL: case LogicalTypeID::INT16: case LogicalTypeID::INT32: case LogicalTypeID::INT64: @@ -244,6 +251,7 @@ uint32_t BuiltInFunctionsUtils::castUInt64(LogicalTypeID targetTypeID) { uint32_t BuiltInFunctionsUtils::castUInt32(LogicalTypeID targetTypeID) { switch (targetTypeID) { + case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: case LogicalTypeID::INT128: case LogicalTypeID::UINT64: @@ -258,6 +266,7 @@ uint32_t BuiltInFunctionsUtils::castUInt32(LogicalTypeID targetTypeID) { uint32_t BuiltInFunctionsUtils::castUInt16(LogicalTypeID targetTypeID) { switch (targetTypeID) { case LogicalTypeID::INT32: + case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: case LogicalTypeID::INT128: case LogicalTypeID::UINT32: @@ -274,6 +283,7 @@ uint32_t BuiltInFunctionsUtils::castUInt8(LogicalTypeID targetTypeID) { switch (targetTypeID) { case LogicalTypeID::INT16: case LogicalTypeID::INT32: + case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: case LogicalTypeID::INT128: case LogicalTypeID::UINT16: