Skip to content

Commit

Permalink
Implement array_repeat, remove array_fill (#7199)
Browse files Browse the repository at this point in the history
* feat: implement array_repeat, remove array_fill

* fix: proto
  • Loading branch information
izveigor authored Aug 6, 2023
1 parent 8e7a09b commit fa1b21c
Show file tree
Hide file tree
Showing 12 changed files with 387 additions and 190 deletions.
99 changes: 81 additions & 18 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ CREATE TABLE values(
(8, 15, 16, 8.8, NULL, '')
;

# Fixture table for the array_repeat column tests: nine rows whose first
# column runs 1..9 and whose fourth column runs 1.1..9.9.
# NOTE(review): despite the table name, the sixth column of row 4 is NULL.
# Confirm whether that is intentional before relying on every column being
# NULL-free.
statement ok
CREATE TABLE values_without_nulls
AS VALUES
(1, 1, 2, 1.1, 'Lorem', 'A'),
(2, 3, 4, 2.2, 'ipsum', ''),
(3, 5, 6, 3.3, 'dolor', 'BB'),
(4, 7, 8, 4.4, 'sit', NULL),
(5, 9, 10, 5.5, 'amet', 'CCC'),
(6, 11, 12, 6.6, ',', 'DD'),
(7, 13, 14, 7.7, 'consectetur', 'E'),
(8, 15, 16, 8.8, 'adipiscing', 'F'),
(9, 17, 18, 9.9, 'elit', '')
;

statement ok
CREATE TABLE arrays
AS VALUES
Expand Down Expand Up @@ -996,25 +1010,71 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma
[[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]]
[[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]]

## array_fill
## array_repeat (aliases: `list_repeat`)

# array_fill scalar function #1
# array_repeat scalar function #1
query ???
select array_fill(11, make_array(1, 2, 3)), array_fill(3, make_array(2, 3)), array_fill(2, make_array(2));
select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4);
----
[[[11, 11, 11], [11, 11, 11]]] [[3, 3, 3], [3, 3, 3]] [2, 2]
[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l]

# array_fill scalar function #2
query ??
select array_fill(1, make_array(1, 1, 1)), array_fill(2, make_array(2, 2, 2, 2, 2));
# array_repeat scalar function #2 (element as list)
query ???
select array_repeat([1], 5), array_repeat([1.1, 2.2, 3.3], 3), array_repeat([[1, 2], [3, 4]], 2);
----
[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]

# list_repeat scalar function #3 (function alias: `array_repeat`)
query ???
select list_repeat(1, 5), list_repeat(3.14, 3), list_repeat('l', 4);
----
[[[1]]] [[[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]], [[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]]]
[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l]

# array_fill scalar function #3
# array_repeat with columns #1
# Both arguments come from columns: each row repeats its column4 value
# column1 times, so the k-th row of values_without_nulls yields a list of
# length k.
query ?
select array_repeat(column4, column1) from values_without_nulls;
----
[1.1]
[2.2, 2.2]
[3.3, 3.3, 3.3]
[4.4, 4.4, 4.4, 4.4]
[5.5, 5.5, 5.5, 5.5, 5.5]
[6.6, 6.6, 6.6, 6.6, 6.6, 6.6]
[7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7]
[8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8]
[9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9]

# array_repeat with columns #2 (element as list)
query ?
select array_fill(1, make_array())
select array_repeat(column1, column3) from arrays_values_without_nulls;
----
[]
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
[[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]
[[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]
[[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]]

# array_repeat with columns and scalars #1
# Mixes a literal with a column in each argument position: the first
# expression repeats the literal 1 column1 times (count from a column), the
# second repeats column4 a fixed 3 times (element from a column).
query ??
select array_repeat(1, column1), array_repeat(column4, 3) from values_without_nulls;
----
[1] [1.1, 1.1, 1.1]
[1, 1] [2.2, 2.2, 2.2]
[1, 1, 1] [3.3, 3.3, 3.3]
[1, 1, 1, 1] [4.4, 4.4, 4.4]
[1, 1, 1, 1, 1] [5.5, 5.5, 5.5]
[1, 1, 1, 1, 1, 1] [6.6, 6.6, 6.6]
[1, 1, 1, 1, 1, 1, 1] [7.7, 7.7, 7.7]
[1, 1, 1, 1, 1, 1, 1, 1] [8.8, 8.8, 8.8]
[1, 1, 1, 1, 1, 1, 1, 1, 1] [9.9, 9.9, 9.9]

# array_repeat with columns and scalars #2 (element as list)
query ??
select array_repeat([1], column3), array_repeat(column1, 3) from arrays_values_without_nulls;
----
[[1]] [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
[[1], [1]] [[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]
[[1], [1], [1]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]
[[1], [1], [1], [1]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]]

## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`)

Expand Down Expand Up @@ -1570,7 +1630,7 @@ h,e,l,l,o 1-2-3-4-5 1|2|3

# array_to_string scalar function #2
query TTT
select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_fill(3, [3, 2, 2]), '/\');
select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_repeat(array_repeat(array_repeat(3, 2), 2), 3), '/\');
----
11111 1+2+3+4+5+6 3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3

Expand Down Expand Up @@ -1670,7 +1730,7 @@ select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinali

# cardinality scalar function #2
query II
select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_fill(3, array[3, 2, 3]));
select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_repeat(array_repeat(array_repeat(3, 3), 2), 3));
----
6 18

Expand Down Expand Up @@ -1883,10 +1943,10 @@ select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2,
NULL NULL 2

# array_length scalar function #4
query IIII
select array_length(array_fill(3, [3, 2, 5]), 1), array_length(array_fill(3, [3, 2, 5]), 2), array_length(array_fill(3, [3, 2, 5]), 3), array_length(array_fill(3, [3, 2, 5]), 4);
query II
select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2);
----
3 2 5 NULL
3 2

# array_length scalar function #5
query III
Expand Down Expand Up @@ -1936,7 +1996,7 @@ select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])),

# array_dims scalar function #2
query ??
select array_dims(array_fill(2, [1, 2, 3])), array_dims(array_fill(3, [2, 5, 4]));
select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2));
----
[1, 2, 3] [2, 5, 4]

Expand Down Expand Up @@ -1974,7 +2034,7 @@ select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4]))

# array_ndims scalar function #2
query II
select array_ndims(array_fill(1, [1, 2, 3])), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]);
select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]);
----
3 21

Expand Down Expand Up @@ -2264,6 +2324,9 @@ select array_concat(column1, [7]) from arrays_values_v2;
statement ok
drop table values;

statement ok
drop table values_without_nulls;

statement ok
drop table nested_arrays;

Expand Down
20 changes: 10 additions & 10 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ pub enum BuiltinScalarFunction {
ArrayDims,
/// array_element
ArrayElement,
/// array_fill
ArrayFill,
/// array_length
ArrayLength,
/// array_ndims
Expand All @@ -149,6 +147,8 @@ pub enum BuiltinScalarFunction {
ArrayRemoveN,
/// array_remove_all
ArrayRemoveAll,
/// array_repeat
ArrayRepeat,
/// array_replace
ArrayReplace,
/// array_replace_n
Expand Down Expand Up @@ -354,12 +354,12 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayHas => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayFill => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
BuiltinScalarFunction::ArrayNdims => Volatility::Immutable,
BuiltinScalarFunction::ArrayPosition => Volatility::Immutable,
BuiltinScalarFunction::ArrayPositions => Volatility::Immutable,
BuiltinScalarFunction::ArrayPrepend => Volatility::Immutable,
BuiltinScalarFunction::ArrayRepeat => Volatility::Immutable,
BuiltinScalarFunction::ArrayRemove => Volatility::Immutable,
BuiltinScalarFunction::ArrayRemoveN => Volatility::Immutable,
BuiltinScalarFunction::ArrayRemoveAll => Volatility::Immutable,
Expand Down Expand Up @@ -536,18 +536,18 @@ impl BuiltinScalarFunction {
"The {self} function can only accept list as the first argument"
))),
},
BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new(
"item",
input_expr_types[1].clone(),
true,
)))),
BuiltinScalarFunction::ArrayLength => Ok(UInt64),
BuiltinScalarFunction::ArrayNdims => Ok(UInt64),
BuiltinScalarFunction::ArrayPosition => Ok(UInt64),
BuiltinScalarFunction::ArrayPositions => {
Ok(List(Arc::new(Field::new("item", UInt64, true))))
}
BuiltinScalarFunction::ArrayPrepend => Ok(input_expr_types[1].clone()),
BuiltinScalarFunction::ArrayRepeat => Ok(List(Arc::new(Field::new(
"item",
input_expr_types[0].clone(),
true,
)))),
BuiltinScalarFunction::ArrayRemove => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayRemoveN => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayRemoveAll => Ok(input_expr_types[0].clone()),
Expand Down Expand Up @@ -822,7 +822,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
Expand All @@ -832,6 +831,7 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArrayRemoveAll => Signature::any(2, self.volatility()),
Expand Down Expand Up @@ -1310,7 +1310,6 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::ArrayHas => {
&["array_has", "list_has", "array_contains", "list_contains"]
}
BuiltinScalarFunction::ArrayFill => &["array_fill"],
BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"],
BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"],
BuiltinScalarFunction::ArrayPosition => &[
Expand All @@ -1326,6 +1325,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
"array_push_front",
"list_push_front",
],
BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"],
BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"],
BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"],
BuiltinScalarFunction::ArrayRemoveAll => &["array_remove_all", "list_remove_all"],
Expand Down
14 changes: 7 additions & 7 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,12 +576,6 @@ scalar_expr!(
array element,
"extracts the element with the index n from the array."
);
scalar_expr!(
ArrayFill,
array_fill,
element array,
"returns an array filled with copies of the given value."
);
scalar_expr!(
ArrayLength,
array_length,
Expand Down Expand Up @@ -612,6 +606,12 @@ scalar_expr!(
array element,
"prepends an element to the beginning of an array."
);
scalar_expr!(
ArrayRepeat,
array_repeat,
element count,
"returns an array containing element `count` times."
);
scalar_expr!(
ArrayRemove,
array_remove,
Expand Down Expand Up @@ -1062,12 +1062,12 @@ mod test {

test_scalar_expr!(ArrayAppend, array_append, array, element);
test_unary_scalar_expr!(ArrayDims, array_dims);
test_scalar_expr!(ArrayFill, array_fill, element, array);
test_scalar_expr!(ArrayLength, array_length, array, dimension);
test_unary_scalar_expr!(ArrayNdims, array_ndims);
test_scalar_expr!(ArrayPosition, array_position, array, element, index);
test_scalar_expr!(ArrayPositions, array_positions, array, element);
test_scalar_expr!(ArrayPrepend, array_prepend, array, element);
test_scalar_expr!(ArrayRepeat, array_repeat, element, count);
test_scalar_expr!(ArrayRemove, array_remove, array, element);
test_scalar_expr!(ArrayRemoveN, array_remove_n, array, element, max);
test_scalar_expr!(ArrayRemoveAll, array_remove_all, array, element);
Expand Down
Loading

0 comments on commit fa1b21c

Please sign in to comment.