Skip to content

Commit

Permalink
adding tests
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <[email protected]>
  • Loading branch information
sandeshkr419 committed Sep 30, 2024
1 parent 2aaa5f1 commit a18249d
Show file tree
Hide file tree
Showing 11 changed files with 418 additions and 381 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.codec.composite.CompositeIndexReader;
Expand All @@ -30,15 +29,14 @@
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.metrics.CompensatedSum;
import org.opensearch.search.aggregations.metrics.MetricAggregatorFactory;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.StarTreeFilter;
import org.opensearch.search.startree.StarTreeQueryContext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -95,14 +93,16 @@ public static StarTreeQueryContext getStarTreeQueryContext(SearchContext context
return null;
}

boolean needCaching = context.aggregations().factories().getFactories().length > 1;
// List<MetricInfo> metricInfos = new ArrayList<>();

for (AggregatorFactory aggregatorFactory : context.aggregations().factories().getFactories()) {
MetricStat metricStat = validateStarTreeMetricSuport(compositeMappedFieldType, aggregatorFactory);
MetricStat metricStat = validateStarTreeMetricSupport(compositeMappedFieldType, aggregatorFactory);
if (metricStat == null) {
return null;
}
// metricInfos.add(new )
}

if (context.aggregations().factories().getFactories().length > 1) {
context.initializeStarTreeValuesMap();
}

return starTreeQueryContext;
Expand Down Expand Up @@ -134,7 +134,7 @@ private static StarTreeQueryContext toStarTreeQueryContext(

/**
* Parse query body to star-tree predicates
* @param queryBuilder
* @param queryBuilder to match supported query shape
* @return predicates to match
*/
private static Map<String, Long> getStarTreePredicates(QueryBuilder queryBuilder, List<String> supportedDimensions) {
Expand All @@ -151,11 +151,10 @@ private static Map<String, Long> getStarTreePredicates(QueryBuilder queryBuilder
return predicateMap;
}

private static MetricStat validateStarTreeMetricSuport(
private static MetricStat validateStarTreeMetricSupport(
CompositeDataCubeFieldType compositeIndexFieldInfo,
AggregatorFactory aggregatorFactory
) {
// List<MetricStat> metricStats = new ArrayList<>();
if (aggregatorFactory instanceof MetricAggregatorFactory && aggregatorFactory.getSubFactories().getFactories().length == 0) {
String field;
Map<String, List<MetricStat>> supportedMetrics = compositeIndexFieldInfo.getMetrics()
Expand Down Expand Up @@ -197,6 +196,7 @@ public static LeafBucketCollector getStarTreeLeafCollector(
Runnable finalConsumer
) throws IOException {
StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
assert starTreeValues != null;
String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();
String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, metric);

Expand All @@ -205,15 +205,10 @@ public static LeafBucketCollector getStarTreeLeafCollector(
metricName
);
// Obtain a FixedBitSet of matched document IDs
FixedBitSet matchedDocIds = context.getStarTreeFilteredValues(ctx, starTreeValues); // Assuming this method gives a FixedBitSet

// Safety check: make sure the FixedBitSet is non-null and valid
if (matchedDocIds == null) {
throw new IllegalStateException("FixedBitSet is null");
}
FixedBitSet matchedDocIds = getStarTreeFilteredValues(context, ctx, starTreeValues); // Assuming this method gives a FixedBitSet
assert matchedDocIds != null;

int numBits = matchedDocIds.length(); // Get the length of the FixedBitSet

if (numBits > 0) {
// Iterate over the FixedBitSet
for (int bit = matchedDocIds.nextSetBit(0); bit != -1; bit = (bit + 1 < numBits) ? matchedDocIds.nextSetBit(bit + 1) : -1) {
Expand All @@ -230,6 +225,7 @@ public static LeafBucketCollector getStarTreeLeafCollector(
}
}


// Call the final consumer after processing all entries
finalConsumer.run();

Expand All @@ -241,4 +237,16 @@ public void collect(int doc, long bucket) {
}
};
}

public static FixedBitSet getStarTreeFilteredValues(SearchContext context, LeafReaderContext ctx, StarTreeValues starTreeValues) throws IOException {
if (context.getStarTreeValuesMap() != null && context.getStarTreeValuesMap().containsKey(ctx)) {
return context.getStarTreeValuesMap().get(ctx);
}

StarTreeFilter filter = new StarTreeFilter(starTreeValues, context.getStarTreeQueryContext().getQueryMap());
FixedBitSet result = filter.getStarTreeResult();

context.getStarTreeValuesMap().put(ctx, result);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ protected void parseCreateField(ParseContext context) {
* @opensearch.experimental
*/
@ExperimentalApi
public static final class StarTreeFieldType extends CompositeDataCubeFieldType {
public static class StarTreeFieldType extends CompositeDataCubeFieldType {

private final StarTreeFieldConfiguration starTreeConfig;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.Version;
import org.opensearch.action.search.SearchShardTask;
Expand All @@ -59,10 +58,7 @@
import org.opensearch.index.IndexService;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.cache.bitset.BitsetFilterCache;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper;
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.StarTreeValuesIterator;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.MapperService;
Expand Down Expand Up @@ -185,7 +181,6 @@ final class DefaultSearchContext extends SearchContext {
private SliceBuilder sliceBuilder;
private SearchShardTask task;
private final Version minNodeVersion;
private StarTreeQueryContext starTreeQueryContext;

/**
* The original query as sent by the user without the types and aliases
Expand Down Expand Up @@ -1158,27 +1153,4 @@ public boolean evaluateKeywordIndexOrDocValuesEnabled() {
}
return false;
}

@Override
public SearchContext starTreeQueryContext(StarTreeQueryContext starTreeQueryContext) {
this.starTreeQueryContext = starTreeQueryContext;
return this;
}

@Override
public StarTreeQueryContext getStarTreeQueryContext() {
return this.starTreeQueryContext;
}

@Override
public FixedBitSet getStarTreeFilteredValues(LeafReaderContext ctx, StarTreeValues starTreeValues) throws IOException {
if (this.starTreeValuesMap.containsKey(ctx)) {
return starTreeValuesMap.get(ctx);
}
StarTreeFilter filter = new StarTreeFilter(starTreeValues, this.getStarTreeQueryContext().getQueryMap());
FixedBitSet result = filter.getStarTreeResult();

starTreeValuesMap.put(ctx, result);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,11 @@
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.StarTreeFilter;

import java.io.IOException;
import java.util.Map;
import java.util.function.Consumer;

import static org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper.getStarTreeFilteredValues;
import static org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper.getSupportedStarTree;

/**
Expand Down Expand Up @@ -109,13 +108,12 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
}
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context);
if (supportedStarTree != null) {
// return getStarTreeLeafCollector(ctx, sub, supportedStarTree);
return getStarTreeLeafCollector(ctx, sub, supportedStarTree);
}
return getDefaultLeafCollector(ctx, sub);
}

private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {

final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
Expand Down Expand Up @@ -149,105 +147,11 @@ public void collect(int doc, long bucket) throws IOException {
};
}

public LeafBucketCollector getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
throws IOException {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return StarTreeQueryHelper.getStarTreeLeafCollector(
context,
valuesSource,
ctx,
sub,
starTree,
MetricStat.SUM.getTypeName(),
value -> kahanSummation.add(NumericUtils.sortableLongToDouble(value)),
() -> sums.set(0, kahanSummation.value())
);
}

// private LeafBucketCollector getStarTreeLeafCollector1(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
// throws IOException {
// final BigArrays bigArrays = context.bigArrays();
// final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
//
// StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
// String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();
// String sumMetricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(
// starTree.getField(),
// fieldName,
// MetricStat.SUM.getTypeName()
// );
// assert starTreeValues != null;
// SortedNumericDocValues values = (SortedNumericDocValues) starTreeValues.getMetricDocIdSetIterator(sumMetricName);
//
// String countMetricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(
// starTree.getField(),
// fieldName,
// MetricStat.VALUE_COUNT.getTypeName()
// );
// SortedNumericDocValues countValues = (SortedNumericDocValues) starTreeValues.getMetricDocIdSetIterator(countMetricName);
//
// return new LeafBucketCollectorBase(sub, values) {
// @Override
// public void collect(int doc, long bucket) throws IOException {
// counts = bigArrays.grow(counts, bucket + 1);
// sums = bigArrays.grow(sums, bucket + 1);
// compensations = bigArrays.grow(compensations, bucket + 1);
//
// if (values.advanceExact(doc) && countValues.advanceExact(doc)) {
// final long valueCount = values.docValueCount();
// counts.increment(bucket, countValues.nextValue());
// // Compute the sum of double values with Kahan summation algorithm which is more
// // accurate than naive summation.
// double sum = sums.get(bucket);
// double compensation = compensations.get(bucket);
//
// kahanSummation.reset(sum, compensation);
//
// for (int i = 0; i < valueCount; i++) {
// double value = NumericUtils.sortableLongToDouble(values.nextValue());
// kahanSummation.add(value);
// }
//
// sums.set(bucket, kahanSummation.value());
// compensations.set(bucket, kahanSummation.delta());
// }
// }
// };
// }


public LeafBucketCollector getStarTreeCollector2(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
throws IOException {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
StarTreeQueryHelper.getStarTreeLeafCollector(
context,
valuesSource,
ctx,
sub,
starTree,
MetricStat.SUM.getTypeName(),
value -> kahanSummation.add(NumericUtils.sortableLongToDouble(value)),
() -> sums.set(0, kahanSummation.value())
);
StarTreeQueryHelper.getStarTreeLeafCollector(
context,
valuesSource,
ctx,
sub,
starTree,
MetricStat.VALUE_COUNT.getTypeName(),
value -> counts.increment(0, value),
() -> {}
);
return LeafBucketCollector.NO_OP_COLLECTOR;
}

public LeafBucketCollector getStarTreeLeafCollector(
LeafReaderContext ctx,
LeafBucketCollector sub,
CompositeIndexFieldInfo starTree
) throws IOException {
StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree);
assert starTreeValues != null;

String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();
String sumMetricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(
starTree.getField(),
Expand All @@ -260,35 +164,31 @@ public LeafBucketCollector getStarTreeLeafCollector(
MetricStat.VALUE_COUNT.getTypeName()
);

assert starTreeValues != null;

final CompensatedSum kahanSummation = new CompensatedSum(sums.get(0), 0);
SortedNumericStarTreeValuesIterator sumValuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues.getMetricValuesIterator(
sumMetricName
);
SortedNumericStarTreeValuesIterator countValueIterator = (SortedNumericStarTreeValuesIterator) starTreeValues.getMetricValuesIterator(
countMetricName
);
FixedBitSet matchedDocIds = context.getStarTreeFilteredValues(ctx, starTreeValues);


// Safety check: make sure the FixedBitSet is non-null and valid
if (matchedDocIds == null) {
throw new IllegalStateException("FixedBitSet is null");
}
SortedNumericStarTreeValuesIterator sumValuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
.getMetricValuesIterator(sumMetricName);
SortedNumericStarTreeValuesIterator countValueIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
.getMetricValuesIterator(countMetricName);
FixedBitSet matchedDocIds = getStarTreeFilteredValues(context, ctx, starTreeValues);
assert matchedDocIds != null;

int numBits = matchedDocIds.length(); // Get the length of the FixedBitSet
if (numBits > 0) {
// Iterate over the FixedBitSet
for (int bit = matchedDocIds.nextSetBit(0); bit != -1; bit = bit + 1 < numBits ? matchedDocIds.nextSetBit(bit + 1) : -1) {
// Advance to the bit (entryId) in the valuesIterator
if (sumValuesIterator.advance(bit) == StarTreeValuesIterator.NO_MORE_ENTRIES
|| countValueIterator.advance(bit) == StarTreeValuesIterator.NO_MORE_ENTRIES) {
continue; // Skip if no more entries
}

// Iterate over the FixedBitSet
for (int bit = matchedDocIds.nextSetBit(0); bit != -1; bit = bit + 1 < numBits ? matchedDocIds.nextSetBit(bit + 1) : -1) {
// Advance to the bit (entryId) in the valuesIterator
if (sumValuesIterator.advance(bit) != StarTreeValuesIterator.NO_MORE_ENTRIES &&
countValueIterator.advance(bit) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
int count = sumValuesIterator.valuesCount();
for (int i = 0; i < count; i++) {
// Iterate over the values for the current entryId
for (int i = 0; i < sumValuesIterator.valuesCount(); i++) {
kahanSummation.add(NumericUtils.sortableLongToDouble(sumValuesIterator.nextValue()));
counts.increment(0, countValueIterator.nextValue()); // Apply the consumer operation (e.g., max, sum)
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
if (supportedStarTree != null) {
return getStarTreeCollector(ctx, sub, supportedStarTree);
}
System.out.println("nopes nopes");
return getDefaultLeafCollector(ctx, sub);
}

Expand Down
Loading

0 comments on commit a18249d

Please sign in to comment.