Skip to content

Commit

Permalink
fix no matching function error in array_contains/array_position
Browse files Browse the repository at this point in the history
Signed-off-by: silverbullet233 <[email protected]>
  • Loading branch information
silverbullet233 committed Oct 13, 2024
1 parent b550d07 commit f6b7f9e
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,13 @@ public static boolean argumentTypeContainDecimalV3(String fnName, Type[] argumen
return true;
}

if (FunctionSet.ARRAY_CONTAINS.equalsIgnoreCase(fnName) ||
FunctionSet.ARRAY_POSITION.equalsIgnoreCase(fnName)) {
return argumentTypes[0].isArrayType() &&
(((ArrayType) argumentTypes[0]).getItemType().isDecimalV3() || argumentTypes[1].isDecimalV3());
}


if (Arrays.stream(argumentTypes).anyMatch(Type::isDecimalV3)) {
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,22 @@ private static Function resolveByDeducingReturnType(Function fn, Type[] inputArg
return null;
}

private static Function resolvePolymorphicArrayFunction(Function fn, Type[] inputArgTypes) {
// for some special array function, they have ANY_ARRAY/ANY_ELEMENT in arguments, should align type
String fnName = fn.getFunctionName().getFunction();
if (FunctionSet.ARRAY_CONTAINS.equalsIgnoreCase(fnName) ||
FunctionSet.ARRAY_POSITION.equalsIgnoreCase(fnName)) {
Type elementType = ((ArrayType) inputArgTypes[0]).getItemType();
Type commonType = TypeManager.getCommonSuperType(elementType, inputArgTypes[1]);
if (commonType == null) {
return null;
}
return newScalarFunction((ScalarFunction) fn,
Arrays.asList(new ArrayType(commonType), commonType), fn.getReturnType());
}
return null;
}

/**
* Inspired by <a href="https://github.com/postgres/postgres/blob/master/src/backend/parser/parse_coerce.c#L1934">...</a>
* <p>
Expand Down Expand Up @@ -302,6 +318,11 @@ public static Function generatePolymorphicFunction(Function fn, Type[] paramType
return resolvedFunction;
}

resolvedFunction = resolvePolymorphicArrayFunction(fn, paramTypes);
if (resolvedFunction != null) {
return resolvedFunction;
}

// common deduce
ArrayType typeArray;
Type typeElement;
Expand Down
54 changes: 54 additions & 0 deletions test/sql/test_array_fn/R/test_array_contains
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,58 @@ select sum(array_contains(@arr, "abcdefg")) from t;
select sum(array_contains(@arr, str)) from t;
-- result:
0
-- !result
-- name: test_array_contains_with_decimal
create table t (
k bigint,
v1 array<decimal(38,5)>,
v2 array<array<decimal(38,5)>>,
v3 array<array<array<decimal(38,5)>>>
) duplicate key (`k`)
distributed by random buckets 1
properties('replication_num'='1');
-- result:
-- !result
insert into t values (1,[1.1], [[1.1]],[[[1.1]]]);
-- result:
-- !result
select array_contains(v1, 1.1) from t;
-- result:
1
-- !result
select array_contains(v2, [1.1]) from t;
-- result:
1
-- !result
select array_contains(v3, [[1.1]]) from t;
-- result:
1
-- !result
select array_contains(v2, v1) from t;
-- result:
1
-- !result
select array_contains(v3, v2) from t;
-- result:
1
-- !result
select array_position(v1, 1.1) from t;
-- result:
1
-- !result
select array_position(v2, [1.1]) from t;
-- result:
1
-- !result
select array_position(v3, [[1.1]]) from t;
-- result:
1
-- !result
select array_position(v2, v1) from t;
-- result:
1
-- !result
select array_position(v3, v2) from t;
-- result:
1
-- !result
23 changes: 23 additions & 0 deletions test/sql/test_array_fn/T/test_array_contains
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,26 @@ set @arr = array_repeat("abcdefg", 1000000);
select sum(array_contains(@arr, "abcdefg")) from t;
select sum(array_contains(@arr, str)) from t;

-- name: test_array_contains_with_decimal
create table t (
k bigint,
v1 array<decimal(38,5)>,
v2 array<array<decimal(38,5)>>,
v3 array<array<array<decimal(38,5)>>>
) duplicate key (`k`)
distributed by random buckets 1
properties('replication_num'='1');

insert into t values (1,[1.1], [[1.1]],[[[1.1]]]);

select array_contains(v1, 1.1) from t;
select array_contains(v2, [1.1]) from t;
select array_contains(v3, [[1.1]]) from t;
select array_contains(v2, v1) from t;
select array_contains(v3, v2) from t;

select array_position(v1, 1.1) from t;
select array_position(v2, [1.1]) from t;
select array_position(v3, [[1.1]]) from t;
select array_position(v2, v1) from t;
select array_position(v3, v2) from t;

0 comments on commit f6b7f9e

Please sign in to comment.