Skip to content

Commit

Permalink
[BugFix] fix no matching function error in array_contains/array_posit…
Browse files Browse the repository at this point in the history
…ion (#51835)

Signed-off-by: silverbullet233 <[email protected]>
(cherry picked from commit 8cc2ae2)

# Conflicts:
#	test/sql/test_array_fn/R/test_array_contains
#	test/sql/test_array_fn/T/test_array_contains
  • Loading branch information
silverbullet233 authored and mergify[bot] committed Oct 14, 2024
1 parent 4099470 commit 1af0e6e
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,13 @@ public static boolean argumentTypeContainDecimalV3(String fnName, Type[] argumen
return true;
}

if (FunctionSet.ARRAY_CONTAINS.equalsIgnoreCase(fnName) ||
FunctionSet.ARRAY_POSITION.equalsIgnoreCase(fnName)) {
return argumentTypes[0].isArrayType() &&
(((ArrayType) argumentTypes[0]).getItemType().isDecimalV3() || argumentTypes[1].isDecimalV3());
}


if (Arrays.stream(argumentTypes).anyMatch(Type::isDecimalV3)) {
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,22 @@ private static Function resolveByDeducingReturnType(Function fn, Type[] inputArg
return null;
}

private static Function resolvePolymorphicArrayFunction(Function fn, Type[] inputArgTypes) {
// for some special array function, they have ANY_ARRAY/ANY_ELEMENT in arguments, should align type
String fnName = fn.getFunctionName().getFunction();
if (FunctionSet.ARRAY_CONTAINS.equalsIgnoreCase(fnName) ||
FunctionSet.ARRAY_POSITION.equalsIgnoreCase(fnName)) {
Type elementType = ((ArrayType) inputArgTypes[0]).getItemType();
Type commonType = TypeManager.getCommonSuperType(elementType, inputArgTypes[1]);
if (commonType == null) {
return null;
}
return newScalarFunction((ScalarFunction) fn,
Arrays.asList(new ArrayType(commonType), commonType), fn.getReturnType());
}
return null;
}

/**
* Inspired by <a href="https://github.com/postgres/postgres/blob/master/src/backend/parser/parse_coerce.c#L1934">...</a>
* <p>
Expand Down Expand Up @@ -292,6 +308,11 @@ public static Function generatePolymorphicFunction(Function fn, Type[] paramType
return resolvedFunction;
}

resolvedFunction = resolvePolymorphicArrayFunction(fn, paramTypes);
if (resolvedFunction != null) {
return resolvedFunction;
}

// common deduce
ArrayType typeArray;
Type typeElement;
Expand Down
249 changes: 249 additions & 0 deletions test/sql/test_array_fn/R/test_array_contains
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
-- name: test_array_contains_with_const
CREATE TABLE t (
pk bigint not null ,
str string,
arr_bigint array<bigint>,
arr_str array<string>,
arr_decimal array<decimal(38,5)>
) ENGINE=OLAP
DUPLICATE KEY(`pk`)
DISTRIBUTED BY HASH(`pk`) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
-- result:
-- !result
insert into t select generate_series, md5sum(generate_series), array_repeat(generate_series, 1000),array_repeat(md5sum(generate_series), 100), array_repeat(generate_series, 1000) from table(generate_series(0, 9999));
-- result:
-- !result
insert into t values (10000, md5sum(10000), array_append(array_generate(1000), null), array_append(array_repeat(md5sum(10000),100), null),array_append(array_generate(1000),null));
-- result:
-- !result
select array_contains([1,2,3,4], 1) from t order by pk limit 10;
-- result:
1
1
1
1
1
1
1
1
1
1
-- !result
select array_position([1,2,3,4], 1) from t order by pk limit 10;
-- result:
1
1
1
1
1
1
1
1
1
1
-- !result
select array_contains([1,2,3,4], null) from t order by pk limit 10;
-- result:
0
0
0
0
0
0
0
0
0
0
-- !result
select array_position([1,2,3,4], null) from t order by pk limit 10;
-- result:
0
0
0
0
0
0
0
0
0
0
-- !result
select array_contains([1,2,3,null], null) from t order by pk limit 10;
-- result:
1
1
1
1
1
1
1
1
1
1
-- !result
select array_position([1,2,3,null], null) from t order by pk limit 10;
-- result:
4
4
4
4
4
4
4
4
4
4
-- !result
select array_contains(null, null) from t order by pk limit 10;
-- result:
None
None
None
None
None
None
None
None
None
None
-- !result
select array_position(null, null) from t order by pk limit 10;
-- result:
None
None
None
None
None
None
None
None
None
None
-- !result
set @arr = array_generate(10000);
-- result:
-- !result
select sum(array_contains(@arr, pk)) from t;
-- result:
10000
-- !result
select sum(array_contains(@arr, 100)) from t;
-- result:
10001
-- !result
select sum(array_position(@arr, pk)) from t;
-- result:
50005000
-- !result
select sum(array_position(@arr, 100)) from t;
-- result:
1000100
-- !result
select sum(array_contains(array_append(@arr, null), pk)) from t;
-- result:
10000
-- !result
select sum(array_contains(array_append(@arr, null), null)) from t;
-- result:
10001
-- !result
select sum(array_contains(arr_bigint, 100)) from t;
-- result:
2
-- !result
select sum(array_position(arr_bigint, 100)) from t;
-- result:
101
-- !result
select sum(array_contains(arr_str, md5sum(100))) from t;
-- result:
1
-- !result
select sum(array_position(arr_str, md5sum(100))) from t;
-- result:
1
-- !result
select sum(array_contains(arr_decimal, pk)) from t;
-- result:
10000
-- !result
select sum(array_position(arr_decimal, pk)) from t;
-- result:
10000
-- !result
select sum(array_contains(arr_decimal, 100)) from t;
-- result:
2
-- !result
select sum(array_position(arr_decimal, 100)) from t;
-- result:
101
-- !result
set @arr = array_repeat("abcdefg", 1000000);
-- result:
-- !result
select sum(array_contains(@arr, "abcdefg")) from t;
-- result:
10001
-- !result
select sum(array_contains(@arr, str)) from t;
-- result:
0
-- !result
-- name: test_array_contains_with_decimal
create table t (
k bigint,
v1 array<decimal(38,5)>,
v2 array<array<decimal(38,5)>>,
v3 array<array<array<decimal(38,5)>>>
) duplicate key (`k`)
distributed by random buckets 1
properties('replication_num'='1');
-- result:
-- !result
insert into t values (1,[1.1], [[1.1]],[[[1.1]]]);
-- result:
-- !result
select array_contains(v1, 1.1) from t;
-- result:
1
-- !result
select array_contains(v2, [1.1]) from t;
-- result:
1
-- !result
select array_contains(v3, [[1.1]]) from t;
-- result:
1
-- !result
select array_contains(v2, v1) from t;
-- result:
1
-- !result
select array_contains(v3, v2) from t;
-- result:
1
-- !result
select array_position(v1, 1.1) from t;
-- result:
1
-- !result
select array_position(v2, [1.1]) from t;
-- result:
1
-- !result
select array_position(v3, [[1.1]]) from t;
-- result:
1
-- !result
select array_position(v2, v1) from t;
-- result:
1
-- !result
select array_position(v3, v2) from t;
-- result:
1
-- !result
70 changes: 70 additions & 0 deletions test/sql/test_array_fn/T/test_array_contains
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
-- name: test_array_contains_with_const
CREATE TABLE t (
pk bigint not null ,
str string,
arr_bigint array<bigint>,
arr_str array<string>,
arr_decimal array<decimal(38,5)>
) ENGINE=OLAP
DUPLICATE KEY(`pk`)
DISTRIBUTED BY HASH(`pk`) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);

insert into t select generate_series, md5sum(generate_series), array_repeat(generate_series, 1000),array_repeat(md5sum(generate_series), 100), array_repeat(generate_series, 1000) from table(generate_series(0, 9999));
insert into t values (10000, md5sum(10000), array_append(array_generate(1000), null), array_append(array_repeat(md5sum(10000),100), null),array_append(array_generate(1000),null));

select array_contains([1,2,3,4], 1) from t order by pk limit 10;
select array_position([1,2,3,4], 1) from t order by pk limit 10;
select array_contains([1,2,3,4], null) from t order by pk limit 10;
select array_position([1,2,3,4], null) from t order by pk limit 10;
select array_contains([1,2,3,null], null) from t order by pk limit 10;
select array_position([1,2,3,null], null) from t order by pk limit 10;
select array_contains(null, null) from t order by pk limit 10;
select array_position(null, null) from t order by pk limit 10;

set @arr = array_generate(10000);
select sum(array_contains(@arr, pk)) from t;
select sum(array_contains(@arr, 100)) from t;
select sum(array_position(@arr, pk)) from t;
select sum(array_position(@arr, 100)) from t;
select sum(array_contains(array_append(@arr, null), pk)) from t;
select sum(array_contains(array_append(@arr, null), null)) from t;
select sum(array_contains(arr_bigint, 100)) from t;
select sum(array_position(arr_bigint, 100)) from t;
select sum(array_contains(arr_str, md5sum(100))) from t;
select sum(array_position(arr_str, md5sum(100))) from t;
select sum(array_contains(arr_decimal, pk)) from t;
select sum(array_position(arr_decimal, pk)) from t;
select sum(array_contains(arr_decimal, 100)) from t;
select sum(array_position(arr_decimal, 100)) from t;


set @arr = array_repeat("abcdefg", 1000000);
select sum(array_contains(@arr, "abcdefg")) from t;
select sum(array_contains(@arr, str)) from t;

-- name: test_array_contains_with_decimal
create table t (
k bigint,
v1 array<decimal(38,5)>,
v2 array<array<decimal(38,5)>>,
v3 array<array<array<decimal(38,5)>>>
) duplicate key (`k`)
distributed by random buckets 1
properties('replication_num'='1');

insert into t values (1,[1.1], [[1.1]],[[[1.1]]]);

select array_contains(v1, 1.1) from t;
select array_contains(v2, [1.1]) from t;
select array_contains(v3, [[1.1]]) from t;
select array_contains(v2, v1) from t;
select array_contains(v3, v2) from t;

select array_position(v1, 1.1) from t;
select array_position(v2, [1.1]) from t;
select array_position(v3, [[1.1]]) from t;
select array_position(v2, v1) from t;
select array_position(v3, v2) from t;

0 comments on commit 1af0e6e

Please sign in to comment.