From f6b7f9e4479c8d96d3e18cebe1a7748c7865665c Mon Sep 17 00:00:00 2001 From: silverbullet233 <3675229+silverbullet233@users.noreply.github.com> Date: Sat, 12 Oct 2024 14:34:27 +0800 Subject: [PATCH] fix no matching function error in array_contains/array_position Signed-off-by: silverbullet233 <3675229+silverbullet233@users.noreply.github.com> --- .../analyzer/DecimalV3FunctionAnalyzer.java | 7 +++ .../analyzer/PolymorphicFunctionAnalyzer.java | 21 ++++++++ test/sql/test_array_fn/R/test_array_contains | 54 +++++++++++++++++++ test/sql/test_array_fn/T/test_array_contains | 23 ++++++++ 4 files changed, 105 insertions(+) diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java index 15dc610c26933..ec36b5a1c97bd 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/DecimalV3FunctionAnalyzer.java @@ -319,6 +319,13 @@ public static boolean argumentTypeContainDecimalV3(String fnName, Type[] argumen return true; } + if (FunctionSet.ARRAY_CONTAINS.equalsIgnoreCase(fnName) || + FunctionSet.ARRAY_POSITION.equalsIgnoreCase(fnName)) { + return argumentTypes[0].isArrayType() && + (((ArrayType) argumentTypes[0]).getItemType().isDecimalV3() || argumentTypes[1].isDecimalV3()); + } + + if (Arrays.stream(argumentTypes).anyMatch(Type::isDecimalV3)) { return true; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java index f1ecfad1c5dc6..1bee662560c99 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/PolymorphicFunctionAnalyzer.java @@ -239,6 +239,22 @@ private static Function resolveByDeducingReturnType(Function fn, Type[] inputArg return null; } + private static Function resolvePolymorphicArrayFunction(Function fn, Type[] inputArgTypes) { + // for some special array function, they have ANY_ARRAY/ANY_ELEMENT in arguments, should align type + String fnName = fn.getFunctionName().getFunction(); + if (FunctionSet.ARRAY_CONTAINS.equalsIgnoreCase(fnName) || + FunctionSet.ARRAY_POSITION.equalsIgnoreCase(fnName)) { + Type elementType = ((ArrayType) inputArgTypes[0]).getItemType(); + Type commonType = TypeManager.getCommonSuperType(elementType, inputArgTypes[1]); + if (commonType == null) { + return null; + } + return newScalarFunction((ScalarFunction) fn, + Arrays.asList(new ArrayType(commonType), commonType), fn.getReturnType()); + } + return null; + } + /** * Inspired by ... *

@@ -302,6 +318,11 @@ public static Function generatePolymorphicFunction(Function fn, Type[] paramType return resolvedFunction; } + resolvedFunction = resolvePolymorphicArrayFunction(fn, paramTypes); + if (resolvedFunction != null) { + return resolvedFunction; + } + // common deduce ArrayType typeArray; Type typeElement; diff --git a/test/sql/test_array_fn/R/test_array_contains b/test/sql/test_array_fn/R/test_array_contains index d680eb9d52900..34623ba442bac 100644 --- a/test/sql/test_array_fn/R/test_array_contains +++ b/test/sql/test_array_fn/R/test_array_contains @@ -192,4 +192,58 @@ select sum(array_contains(@arr, "abcdefg")) from t; select sum(array_contains(@arr, str)) from t; -- result: 0 +-- !result +-- name: test_array_contains_with_decimal +create table t ( + k bigint, + v1 array, + v2 array>, + v3 array>> +) duplicate key (`k`) +distributed by random buckets 1 +properties('replication_num'='1'); +-- result: +-- !result +insert into t values (1,[1.1], [[1.1]],[[[1.1]]]); +-- result: +-- !result +select array_contains(v1, 1.1) from t; +-- result: +1 +-- !result +select array_contains(v2, [1.1]) from t; +-- result: +1 +-- !result +select array_contains(v3, [[1.1]]) from t; +-- result: +1 +-- !result +select array_contains(v2, v1) from t; +-- result: +1 +-- !result +select array_contains(v3, v2) from t; +-- result: +1 +-- !result +select array_position(v1, 1.1) from t; +-- result: +1 +-- !result +select array_position(v2, [1.1]) from t; +-- result: +1 +-- !result +select array_position(v3, [[1.1]]) from t; +-- result: +1 +-- !result +select array_position(v2, v1) from t; +-- result: +1 +-- !result +select array_position(v3, v2) from t; +-- result: +1 -- !result \ No newline at end of file diff --git a/test/sql/test_array_fn/T/test_array_contains b/test/sql/test_array_fn/T/test_array_contains index 2053bf975790f..4b8799df1bcea 100644 --- a/test/sql/test_array_fn/T/test_array_contains +++ b/test/sql/test_array_fn/T/test_array_contains @@ -45,3 +45,26 @@ set @arr = array_repeat("abcdefg", 1000000); select sum(array_contains(@arr, "abcdefg")) from t; select sum(array_contains(@arr, str)) from t; +-- name: test_array_contains_with_decimal +create table t ( + k bigint, + v1 array, + v2 array>, + v3 array>> +) duplicate key (`k`) +distributed by random buckets 1 +properties('replication_num'='1'); + +insert into t values (1,[1.1], [[1.1]],[[[1.1]]]); + +select array_contains(v1, 1.1) from t; +select array_contains(v2, [1.1]) from t; +select array_contains(v3, [[1.1]]) from t; +select array_contains(v2, v1) from t; +select array_contains(v3, v2) from t; + +select array_position(v1, 1.1) from t; +select array_position(v2, [1.1]) from t; +select array_position(v3, [[1.1]]) from t; +select array_position(v2, v1) from t; +select array_position(v3, v2) from t; \ No newline at end of file