From f6b7f9e4479c8d96d3e18cebe1a7748c7865665c Mon Sep 17 00:00:00 2001 From: silverbullet233 <3675229+silverbullet233@users.noreply.github.com> Date: Sat, 12 Oct 2024 14:34:27 +0800 Subject: [PATCH] fix no matching function error in array_contains/array_position Signed-off-by: silverbullet233 <3675229+silverbullet233@users.noreply.github.com> --- .../analyzer/DecimalV3FunctionAnalyzer.java | 7 +++ .../analyzer/PolymorphicFunctionAnalyzer.java | 21 ++++++++ test/sql/test_array_fn/R/test_array_contains | 54 +++++++++++++++++++ test/sql/test_array_fn/T/test_array_contains | 23 ++++++++ 4 files changed, 105 insertions(+) diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java index 15dc610c26933..ec36b5a1c97bd 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java @@ -319,6 +319,13 @@ public static boolean argumentTypeContainDecimalV3(String fnName, Type[] argumen return true; } + if (FunctionSet.ARRAY_CONTAINS.equalsIgnoreCase(fnName) || + FunctionSet.ARRAY_POSITION.equalsIgnoreCase(fnName)) { + return argumentTypes[0].isArrayType() && + (((ArrayType) argumentTypes[0]).getItemType().isDecimalV3() || argumentTypes[1].isDecimalV3()); + } + + if (Arrays.stream(argumentTypes).anyMatch(Type::isDecimalV3)) { return true; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java index f1ecfad1c5dc6..1bee662560c99 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java @@ -239,6 +239,22 @@ private static Function resolveByDeducingReturnType(Function fn, Type[] inputArg return null; } + private static Function resolvePolymorphicArrayFunction(Function fn, Type[] inputArgTypes) { + // for some special array function, they have ANY_ARRAY/ANY_ELEMENT in arguments, should align type + String fnName = fn.getFunctionName().getFunction(); + if (FunctionSet.ARRAY_CONTAINS.equalsIgnoreCase(fnName) || + FunctionSet.ARRAY_POSITION.equalsIgnoreCase(fnName)) { + Type elementType = ((ArrayType) inputArgTypes[0]).getItemType(); + Type commonType = TypeManager.getCommonSuperType(elementType, inputArgTypes[1]); + if (commonType == null) { + return null; + } + return newScalarFunction((ScalarFunction) fn, + Arrays.asList(new ArrayType(commonType), commonType), fn.getReturnType()); + } + return null; + } + /** * Inspired by ... *
@@ -302,6 +318,11 @@ public static Function generatePolymorphicFunction(Function fn, Type[] paramType
return resolvedFunction;
}
+ resolvedFunction = resolvePolymorphicArrayFunction(fn, paramTypes);
+ if (resolvedFunction != null) {
+ return resolvedFunction;
+ }
+
// common deduce
ArrayType typeArray;
Type typeElement;
diff --git a/test/sql/test_array_fn/R/test_array_contains b/test/sql/test_array_fn/R/test_array_contains
index d680eb9d52900..34623ba442bac 100644
--- a/test/sql/test_array_fn/R/test_array_contains
+++ b/test/sql/test_array_fn/R/test_array_contains
@@ -192,4 +192,58 @@ select sum(array_contains(@arr, "abcdefg")) from t;
select sum(array_contains(@arr, str)) from t;
-- result:
0
+-- !result
+-- name: test_array_contains_with_decimal
+create table t (
+ k bigint,
+ v1 array