Skip to content

Commit

Permalink
JNI for HISTOGRAM and MERGE_HISTOGRAM aggregations (#14154)
Browse files Browse the repository at this point in the history
This implements JNI for  `HISTOGRAM` and `MERGE_HISTOGRAM` aggregations in both groupby and reduction.

Depends on:
 * #14045

Contributes to:
 * #13885.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Jason Lowe (https://github.com/jlowe)

URL: #14154
  • Loading branch information
ttnghia authored Sep 27, 2023
1 parent bff0fcd commit 66ac962
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 5 deletions.
26 changes: 24 additions & 2 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,7 +68,9 @@ enum Kind {
DENSE_RANK(29),
PERCENT_RANK(30),
TDIGEST(31), // This can take a delta argument for accuracy level
MERGE_TDIGEST(32); // This can take a delta argument for accuracy level
MERGE_TDIGEST(32), // This can take a delta argument for accuracy level
HISTOGRAM(33),
MERGE_HISTOGRAM(34);

final int nativeId;

Expand Down Expand Up @@ -918,6 +920,26 @@ static TDigestAggregation mergeTDigest(int delta) {
return new TDigestAggregation(Kind.MERGE_TDIGEST, delta);
}

static final class HistogramAggregation extends NoParamAggregation {
private HistogramAggregation() {
super(Kind.HISTOGRAM);
}
}

static final class MergeHistogramAggregation extends NoParamAggregation {
private MergeHistogramAggregation() {
super(Kind.MERGE_HISTOGRAM);
}
}

static HistogramAggregation histogram() {
return new HistogramAggregation();
}

static MergeHistogramAggregation mergeHistogram() {
return new MergeHistogramAggregation();
}

/**
* Create one of the aggregations that only needs a kind, no other parameters. This does not
* work for all types and for code safety reasons each kind is added separately.
Expand Down
24 changes: 23 additions & 1 deletion java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -315,4 +315,26 @@ public static GroupByAggregation createTDigest(int delta) {
public static GroupByAggregation mergeTDigest(int delta) {
return new GroupByAggregation(Aggregation.mergeTDigest(delta));
}

/**
* Histogram aggregation, computing the frequencies for each unique row.
*
* A histogram is given as a lists column, in which the first child stores unique rows from
* the input values and the second child stores their corresponding frequencies.
*
* @return A lists of structs column in which each list contains a histogram corresponding to
* an input key.
*/
public static GroupByAggregation histogram() {
return new GroupByAggregation(Aggregation.histogram());
}

/**
* MergeHistogram aggregation, to merge multiple histograms.
*
* @return A new histogram in which the frequencies of the unique rows are sum up.
*/
public static GroupByAggregation mergeHistogram() {
return new GroupByAggregation(Aggregation.mergeHistogram());
}
}
20 changes: 19 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ReductionAggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -286,4 +286,22 @@ public static ReductionAggregation mergeSets(NullEquality nullEquality, NaNEqual
return new ReductionAggregation(Aggregation.mergeSets(nullEquality, nanEquality));
}

/**
* Create HistogramAggregation, computing the frequencies for each unique row.
*
* @return A structs column in which the first child stores unique rows from the input and the
* second child stores their corresponding frequencies.
*/
public static ReductionAggregation histogram() {
return new ReductionAggregation(Aggregation.histogram());
}

/**
* Create MergeHistogramAggregation, to merge multiple histograms.
*
* @return A new histogram in which the frequencies of the unique rows are sum up.
*/
public static ReductionAggregation mergeHistogram() {
return new ReductionAggregation(Aggregation.mergeHistogram());
}
}
7 changes: 6 additions & 1 deletion java/src/main/native/src/AggregationJni.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -90,6 +90,11 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv
case 30: // ANSI SQL PERCENT_RANK
return cudf::make_rank_aggregation(cudf::rank_method::MIN, {}, cudf::null_policy::INCLUDE,
{}, cudf::rank_percentage::ONE_NORMALIZED);
case 33: // HISTOGRAM
return cudf::make_histogram_aggregation();
case 34: // MERGE_HISTOGRAM
return cudf::make_merge_histogram_aggregation();

default: throw std::logic_error("Unsupported No Parameter Aggregation Operation");
}
}();
Expand Down
109 changes: 109 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4129,6 +4129,115 @@ void testMergeTDigestReduction() {
}
}

@Test
void testGroupbyHistogram() {
StructType histogramStruct = new StructType(false,
new BasicType(false, DType.INT32), // values
new BasicType(false, DType.INT64)); // frequencies
ListType histogramList = new ListType(false, histogramStruct);

// key = 0: values = [2, 2, -3, -2, 2]
// key = 1: values = [2, 0, 5, 2, 1]
// key = 2: values = [-3, 1, 1, 2, 2]
try (Table input = new Table.TestBuilder()
.column(2, 0, 2, 1, 1, 1, 0, 0, 0, 1, 2, 2, 1, 0, 2)
.column(-3, 2, 1, 2, 0, 5, 2, -3, -2, 2, 1, 2, 1, 2, 2)
.build();
Table result = input.groupBy(0)
.aggregate(GroupByAggregation.histogram().onColumn(1));
Table sortedResult = result.orderBy(OrderByArg.asc(0));
ColumnVector sortedOutHistograms = sortedResult.getColumn(1).listSortRows(false, false);

ColumnVector expectedKeys = ColumnVector.fromInts(0, 1, 2);
ColumnVector expectedHistograms = ColumnVector.fromLists(histogramList,
Arrays.asList(new StructData(-3, 1L), new StructData(-2, 1L), new StructData(2, 3L)),
Arrays.asList(new StructData(0, 1L), new StructData(1, 1L), new StructData(2, 2L),
new StructData(5, 1L)),
Arrays.asList(new StructData(-3, 1L), new StructData(1, 2L), new StructData(2, 2L)))
) {
assertColumnsAreEqual(expectedKeys, sortedResult.getColumn(0));
assertColumnsAreEqual(expectedHistograms, sortedOutHistograms);
}
}

@Test
void testGroupbyMergeHistogram() {
StructType histogramStruct = new StructType(false,
new BasicType(false, DType.INT32), // values
new BasicType(false, DType.INT64)); // frequencies
ListType histogramList = new ListType(false, histogramStruct);

// key = 0: histograms = [[<-3, 1>, <-2, 1>, <2, 3>], [<0, 1>, <1, 1>], [<-3, 3>, <0, 1>, <1, 2>]]
// key = 1: histograms = [[<-2, 1>, <1, 3>, <2, 2>], [<0, 2>, <1, 1>, <2, 2>]]
try (Table input = new Table.TestBuilder()
.column(0, 1, 0, 1, 0)
.column(histogramStruct,
new StructData[]{new StructData(-3, 1L), new StructData(-2, 1L), new StructData(2, 3L)},
new StructData[]{new StructData(-2, 1L), new StructData(1, 3L), new StructData(2, 2L)},
new StructData[]{new StructData(0, 1L), new StructData(1, 1L)},
new StructData[]{new StructData(0, 2L), new StructData(1, 1L), new StructData(2, 2L)},
new StructData[]{new StructData(-3, 3L), new StructData(0, 1L), new StructData(1, 2L)})
.build();
Table result = input.groupBy(0)
.aggregate(GroupByAggregation.mergeHistogram().onColumn(1));
Table sortedResult = result.orderBy(OrderByArg.asc(0));
ColumnVector sortedOutHistograms = sortedResult.getColumn(1).listSortRows(false, false);

ColumnVector expectedKeys = ColumnVector.fromInts(0, 1);
ColumnVector expectedHistograms = ColumnVector.fromLists(histogramList,
Arrays.asList(new StructData(-3, 4L), new StructData(-2, 1L), new StructData(0, 2L),
new StructData(1, 3L), new StructData(2, 3L)),
Arrays.asList(new StructData(-2, 1L), new StructData(0, 2L), new StructData(1, 4L),
new StructData(2, 4L)))
) {
assertColumnsAreEqual(expectedKeys, sortedResult.getColumn(0));
assertColumnsAreEqual(expectedHistograms, sortedOutHistograms);
}
}

@Test
void testReductionHistogram() {
StructType histogramStruct = new StructType(false,
new BasicType(false, DType.INT32), // values
new BasicType(false, DType.INT64)); // frequencies

try (ColumnVector input = ColumnVector.fromInts(-3, 2, 1, 2, 0, 5, 2, -3, -2, 2, 1);
Scalar result = input.reduce(ReductionAggregation.histogram(), DType.LIST);
ColumnVector resultCV = result.getListAsColumnView().copyToColumnVector();
Table resultTable = new Table(resultCV);
Table sortedResult = resultTable.orderBy(OrderByArg.asc(0));

ColumnVector expectedHistograms = ColumnVector.fromStructs(histogramStruct,
new StructData(-3, 2L), new StructData(-2, 1L), new StructData(0, 1L),
new StructData(1, 2L), new StructData(2, 4L), new StructData(5, 1L))
) {
assertColumnsAreEqual(expectedHistograms, sortedResult.getColumn(0));
}
}

@Test
void testReductionMergeHistogram() {
StructType histogramStruct = new StructType(false,
new BasicType(false, DType.INT32), // values
new BasicType(false, DType.INT64)); // frequencies

try (ColumnVector input = ColumnVector.fromStructs(histogramStruct,
new StructData(-3, 2L), new StructData(2, 1L), new StructData(1, 1L),
new StructData(2, 2L), new StructData(0, 4L), new StructData(5, 1L),
new StructData(2, 2L), new StructData(-3, 3L), new StructData(-2, 5L),
new StructData(2, 3L), new StructData(1, 4L));
Scalar result = input.reduce(ReductionAggregation.mergeHistogram(), DType.LIST);
ColumnVector resultCV = result.getListAsColumnView().copyToColumnVector();
Table resultTable = new Table(resultCV);
Table sortedResult = resultTable.orderBy(OrderByArg.asc(0));

ColumnVector expectedHistograms = ColumnVector.fromStructs(histogramStruct,
new StructData(-3, 5L), new StructData(-2, 5L), new StructData(0, 4L),
new StructData(1, 5L), new StructData(2, 8L), new StructData(5, 1L))
) {
assertColumnsAreEqual(expectedHistograms, sortedResult.getColumn(0));
}
}
@Test
void testGroupByMinMaxDecimal() {
try (Table t1 = new Table.TestBuilder()
Expand Down

0 comments on commit 66ac962

Please sign in to comment.