Skip to content

Commit

Permalink
Adding test cases and max,min aggregator support
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <[email protected]>
  • Loading branch information
sandeshkr419 committed Aug 23, 2024
1 parent e16e01f commit e4a270a
Show file tree
Hide file tree
Showing 10 changed files with 366 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@
import org.opensearch.script.ScriptFactory;
import org.opensearch.script.ScriptService;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.metrics.MaxAggregatorFactory;
import org.opensearch.search.aggregations.metrics.MinAggregatorFactory;
import org.opensearch.search.aggregations.metrics.SumAggregatorFactory;
import org.opensearch.search.aggregations.support.AggregationUsageService;
import org.opensearch.search.aggregations.support.ValuesSourceRegistry;
Expand Down Expand Up @@ -585,7 +587,7 @@ private Map<String, List<Predicate<Long>>> getStarTreePredicates(QueryBuilder qu
}

public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType compositeIndexFieldInfo, AggregatorFactory aggregatorFactory) {
String field = null;
String field;
Map<String, List<MetricStat>> supportedMetrics = compositeIndexFieldInfo.getMetrics()
.stream()
.collect(Collectors.toMap(Metric::getField, Metric::getMetrics));
Expand All @@ -595,14 +597,16 @@ public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType composite
return false;
}

// TODO: increment supported aggregation type
if (aggregatorFactory instanceof SumAggregatorFactory) {
field = ((SumAggregatorFactory) aggregatorFactory).getField();
if (supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.SUM)) {
return true;
}
return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.SUM);
} else if (aggregatorFactory instanceof MaxAggregatorFactory) {
field = ((MaxAggregatorFactory) aggregatorFactory).getField();
return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.MAX);
} else if (aggregatorFactory instanceof MinAggregatorFactory) {
field = ((MinAggregatorFactory) aggregatorFactory).getField();
return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.MIN);
}

return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.Bits;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.DoubleArray;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils;
import org.opensearch.index.fielddata.NumericDoubleValues;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.search.DocValueFormat;
Expand All @@ -51,6 +55,8 @@
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.OriginalOrStarTreeQuery;
import org.opensearch.search.startree.StarTreeQuery;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -120,6 +126,16 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
throw new CollectionTerminatedException();
}
}

if (context.query() instanceof OriginalOrStarTreeQuery && ((OriginalOrStarTreeQuery) context.query()).isStarTreeUsed()) {
StarTreeQuery starTreeQuery = ((OriginalOrStarTreeQuery) context.query()).getStarTreeQuery();
return getStarTreeLeafCollector(ctx, sub, starTreeQuery.getStarTree());
}
return getDefaultLeafCollector(ctx, sub);
}

private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {

final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
final NumericDoubleValues values = MultiValueMode.MAX.select(allValues);
Expand All @@ -143,6 +159,34 @@ public void collect(int doc, long bucket) throws IOException {
};
}

private LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
throws IOException {
StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();
String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, "max");
SortedNumericDocValues values = (SortedNumericDocValues) starTreeValues.getMetricDocValuesIteratorMap().get(metricName);

final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
return new LeafBucketCollectorBase(sub, allValues) {

@Override
public void collect(int doc, long bucket) throws IOException {
if (bucket >= maxes.size()) {
long from = maxes.size();
maxes = bigArrays.grow(maxes, bucket + 1);
maxes.fill(from, maxes.size(), Double.NEGATIVE_INFINITY);
}
if (values.advanceExact(doc)) {
final double value = Double.longBitsToDouble(values.nextValue());
double max = maxes.get(bucket);
max = Math.max(max, value);
maxes.set(bucket, max);
}
}
};
}

@Override
public double metric(long owningBucketOrd) {
if (valuesSource == null || owningBucketOrd >= maxes.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
*
* @opensearch.internal
*/
class MaxAggregatorFactory extends ValuesSourceAggregatorFactory {
public class MaxAggregatorFactory extends ValuesSourceAggregatorFactory {

static void registerAggregators(ValuesSourceRegistry.Builder builder) {
builder.register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.Bits;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.DoubleArray;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils;
import org.opensearch.index.fielddata.NumericDoubleValues;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.search.DocValueFormat;
Expand All @@ -51,6 +55,8 @@
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.OriginalOrStarTreeQuery;
import org.opensearch.search.startree.StarTreeQuery;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -119,6 +125,16 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
throw new CollectionTerminatedException();
}
}

if (context.query() instanceof OriginalOrStarTreeQuery && ((OriginalOrStarTreeQuery) context.query()).isStarTreeUsed()) {
StarTreeQuery starTreeQuery = ((OriginalOrStarTreeQuery) context.query()).getStarTreeQuery();
return getStarTreeLeafCollector(ctx, sub, starTreeQuery.getStarTree());
}
return getDefaultLeafCollector(ctx, sub);
}

private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub)
throws IOException {
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
final NumericDoubleValues values = MultiValueMode.MIN.select(allValues);
Expand All @@ -138,10 +154,38 @@ public void collect(int doc, long bucket) throws IOException {
mins.set(bucket, min);
}
}
};
}

private LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
throws IOException {
StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();
String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, "min");
SortedNumericDocValues values = (SortedNumericDocValues) starTreeValues.getMetricDocValuesIteratorMap().get(metricName);

final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
return new LeafBucketCollectorBase(sub, allValues) {

@Override
public void collect(int doc, long bucket) throws IOException {
if (bucket >= mins.size()) {
long from = mins.size();
mins = bigArrays.grow(mins, bucket + 1);
mins.fill(from, mins.size(), Double.POSITIVE_INFINITY);
}
if (values.advanceExact(doc)) {
final double value = Double.longBitsToDouble(values.nextValue());
double min = mins.get(bucket);
min = Math.min(min, value);
mins.set(bucket, min);
}
}
};
}


@Override
public double metric(long owningBucketOrd) {
if (valuesSource == null || owningBucketOrd >= mins.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
*
* @opensearch.internal
*/
class MinAggregatorFactory extends ValuesSourceAggregatorFactory {
public class MinAggregatorFactory extends ValuesSourceAggregatorFactory {

static void registerAggregators(ValuesSourceRegistry.Builder builder) {
builder.register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ public ScoreMode scoreMode() {

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

if (context.query() instanceof OriginalOrStarTreeQuery && ((OriginalOrStarTreeQuery) context.query()).isStarTreeUsed()) {
StarTreeQuery starTreeQuery = ((OriginalOrStarTreeQuery) context.query()).getStarTreeQuery();
return getStarTreeLeafCollector(ctx, sub, starTreeQuery.getStarTree());
Expand All @@ -100,9 +104,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
}

private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
Expand Down Expand Up @@ -134,15 +135,14 @@ public void collect(int doc, long bucket) throws IOException {

private LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
throws IOException {
final BigArrays bigArrays = context.bigArrays();
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);

StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();
String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, "sum");

SortedNumericDocValues values = (SortedNumericDocValues) starTreeValues.getMetricDocValuesIteratorMap().get(metricName);

final BigArrays bigArrays = context.bigArrays();
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ public void visit(QueryVisitor queryVisitor) {

@Override
public boolean equals(Object o) {
return false;
return true;
}

@Override
public int hashCode() {
return 0;
return originalQuery.hashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@
*/
@LuceneTestCase.SuppressSysoutChecks(bugUrl = "we log a lot on purpose")
public class StarTreeDocValuesFormatTests extends BaseDocValuesFormatTestCase {
MapperService mapperService = null;

StarTreeFieldConfiguration.StarTreeBuildMode buildMode;
MapperService mapperService;

public StarTreeDocValuesFormatTests(StarTreeFieldConfiguration.StarTreeBuildMode buildMode) {
this.buildMode = buildMode;
Expand Down Expand Up @@ -105,13 +106,14 @@ public void teardown() throws IOException {
@Override
protected Codec getCodec() {
final Logger testLogger = LogManager.getLogger(StarTreeDocValuesFormatTests.class);

Codec codec;
try {
createMapperService(getExpandedMapping());
mapperService = createMapperService(getExpandedMapping());
codec = new Composite99Codec(Lucene99Codec.Mode.BEST_SPEED, mapperService, testLogger);
} catch (IOException e) {
throw new RuntimeException(e);
}
Codec codec = new Composite99Codec(Lucene99Codec.Mode.BEST_SPEED, mapperService, testLogger);

return codec;
}

Expand Down Expand Up @@ -195,7 +197,7 @@ public void testStarTreeDocValues() throws IOException {
directory.close();
}

private XContentBuilder getExpandedMapping() throws IOException {
public static XContentBuilder getExpandedMapping() throws IOException {
return topMapping(b -> {
b.startObject("composite");
b.startObject("startree");
Expand All @@ -215,6 +217,8 @@ private XContentBuilder getExpandedMapping() throws IOException {
b.field("name", "field");
b.startArray("stats");
b.value("sum");
b.value("max");
b.value("min");
b.value("count");
b.endArray();
b.endObject();
Expand All @@ -236,13 +240,14 @@ private XContentBuilder getExpandedMapping() throws IOException {
});
}

private XContentBuilder topMapping(CheckedConsumer<XContentBuilder, IOException> buildFields) throws IOException {
private static XContentBuilder topMapping(CheckedConsumer<XContentBuilder, IOException> buildFields) throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("_doc");
buildFields.accept(builder);
return builder.endObject().endObject();
}

private void createMapperService(XContentBuilder builder) throws IOException {
public static MapperService createMapperService(XContentBuilder builder) throws IOException {
MapperService mapperService = null;
IndexMetadata indexMetadata = IndexMetadata.builder("test")
.settings(
Settings.builder()
Expand All @@ -261,5 +266,6 @@ private void createMapperService(XContentBuilder builder) throws IOException {
"test"
);
mapperService.merge(indexMetadata, MapperService.MergeReason.INDEX_TEMPLATE);
return mapperService;
}
}
Loading

0 comments on commit e4a270a

Please sign in to comment.