Skip to content

Commit

Permalink
temp temp temp
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <[email protected]>
  • Loading branch information
sandeshkr419 committed Sep 25, 2024
1 parent afbc128 commit 30b95d4
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.codec.composite.CompositeIndexReader;
Expand All @@ -28,14 +30,15 @@
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.metrics.CompensatedSum;
import org.opensearch.search.aggregations.metrics.MetricAggregatorFactory;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.StarTreeFilter;
import org.opensearch.search.startree.StarTreeQueryContext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -50,8 +53,6 @@
*/
public class StarTreeQueryHelper {

private static Map<LeafReaderContext, StarTreeValues> starTreeValuesMap = new HashMap<>();

/**
* Checks if the search context can be supported by star-tree
*/
Expand All @@ -69,10 +70,6 @@ public static boolean isStarTreeSupported(SearchContext context, boolean trackTo
return canUseStarTree;
}

/**
* Gets a parsed OriginalOrStarTreeQuery from the search context and source builder.
* Returns null if the query cannot be supported.
*/

/**
* Gets a parsed OriginalOrStarTreeQuery from the search context and source builder.
Expand All @@ -98,10 +95,14 @@ public static StarTreeQueryContext getStarTreeQueryContext(SearchContext context
return null;
}

boolean needCaching = context.aggregations().factories().getFactories().length > 1;
// List<MetricInfo> metricInfos = new ArrayList<>();
for (AggregatorFactory aggregatorFactory : context.aggregations().factories().getFactories()) {
if (validateStarTreeMetricSuport(compositeMappedFieldType, aggregatorFactory) == false) {
MetricStat metricStat = validateStarTreeMetricSuport(compositeMappedFieldType, aggregatorFactory);
if (metricStat == null) {
return null;
}
// metricInfos.add(new )
}

return starTreeQueryContext;
Expand Down Expand Up @@ -150,10 +151,11 @@ private static Map<String, Long> getStarTreePredicates(QueryBuilder queryBuilder
return predicateMap;
}

private static boolean validateStarTreeMetricSuport(
private static MetricStat validateStarTreeMetricSuport(
CompositeDataCubeFieldType compositeIndexFieldInfo,
AggregatorFactory aggregatorFactory
) {
// List<MetricStat> metricStats = new ArrayList<>();
if (aggregatorFactory instanceof MetricAggregatorFactory && aggregatorFactory.getSubFactories().getFactories().length == 0) {
String field;
Map<String, List<MetricStat>> supportedMetrics = compositeIndexFieldInfo.getMetrics()
Expand All @@ -162,10 +164,12 @@ private static boolean validateStarTreeMetricSuport(

MetricStat metricStat = ((MetricAggregatorFactory) aggregatorFactory).getMetricStat();
field = ((MetricAggregatorFactory) aggregatorFactory).getField();
return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(metricStat);
} else {
return false;

if (supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(metricStat)) {
return metricStat;
}
}
return null;
}

public static CompositeIndexFieldInfo getSupportedStarTree(SearchContext context) {
Expand Down Expand Up @@ -200,24 +204,164 @@ public static LeafBucketCollector getStarTreeLeafCollector(
SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues.getMetricValuesIterator(
metricName
);
StarTreeValuesIterator result = context.getStarTreeFilteredValues(ctx, starTreeValues);
// Obtain a FixedBitSet of matched document IDs
FixedBitSet matchedDocIds = context.getStarTreeFilteredValues(ctx, starTreeValues); // Assuming this method gives a FixedBitSet

// Safety check: make sure the FixedBitSet is non-null and valid
if (matchedDocIds == null) {
throw new IllegalStateException("FixedBitSet is null");
}

int numBits = matchedDocIds.length(); // Get the length of the FixedBitSet

int entryId;
while ((entryId = result.nextEntry()) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
if (valuesIterator.advance(entryId) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
// Iterate over the FixedBitSet
for (int bit = matchedDocIds.nextSetBit(0); bit != -1; bit = bit + 1 < numBits ? matchedDocIds.nextSetBit(bit + 1) : -1) {
// Advance to the bit (entryId) in the valuesIterator
if (valuesIterator.advance(bit) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
int count = valuesIterator.valuesCount();
for (int i = 0; i < count; i++) {
long value = valuesIterator.nextValue();
valueConsumer.accept(value); // Apply the operation (max, sum, etc.)
valueConsumer.accept(value); // Apply the consumer operation (e.g., max, sum)
}
}
}

// Call the final consumer after processing all entries
finalConsumer.run();

// Return a LeafBucketCollector that terminates collection
return new LeafBucketCollectorBase(sub, valuesSource.doubleValues(ctx)) {
@Override
public void collect(int doc, long bucket) {
throw new CollectionTerminatedException();
}
};
}

// public static LeafBucketCollector getStarTreeLeafCollectorNew(
// SearchContext context,
// ValuesSource.Numeric valuesSource,
// LeafReaderContext ctx,
// LeafBucketCollector sub,
// CompositeIndexFieldInfo starTree,
// String metric,
// Consumer<Long> valueConsumer,
// Runnable finalConsumer
// ) throws IOException {
// // Check in contextCache if the star-tree values are already computed
// Map<LeafReaderContext, Map<String, StarTreeQueryHelper.MetricInfo>> cache = context.getStarTreeQueryContext().getLeafResultsCache();
// if(cache != null) {
//
// if (cache.containsKey(ctx)) {
// MetricInfo metricInfoMap = cache.get(ctx).get(metric);
// finalConsumer.run();
// }
// }
// else if (!cache.containsKey(ctx)) {
// // TODO: fetch from cache
//
// } else {
// // TODO: compute cache first
// }
//
// StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
// String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();
// String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, metric);
//
// assert starTreeValues != null;
// List<SortedNumericStarTreeValuesIterator> valuesIterators;
// SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues.getMetricValuesIterator(
// metricName
// );
// StarTreeValuesIterator result = context.getStarTreeFilteredValues(ctx, starTreeValues);
//
// int entryId;
// while ((entryId = result.nextEntry()) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
// for
// if (valuesIterator.advance(entryId) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
// int count = valuesIterator.valuesCount();
// for (int i = 0; i < count; i++) {
// long value = valuesIterator.nextValue();
// valueConsumer.accept(value); // Apply the operation (max, sum, etc.)
// }
// }
// }
// finalConsumer.run();
// return new LeafBucketCollectorBase(sub, valuesSource.doubleValues(ctx)) {
// @Override
// public void collect(int doc, long bucket) {
// throw new CollectionTerminatedException();
// }
// };
// }
//
// public abstract class MetricInfo {
// String metric;
// MetricStat metricStat;
//
//
//
// MetricInfo (String metric, MetricStat metricStat) {
// if (metricStat == MetricStat.SUM) {
// return new SumMetricInfo(metric);
// }
// return null;
// }
//
//
// public abstract void valueConsumer(long value);
//
// public abstract <T extends Number> T getMetricValue();
// }
//
// public class SumMetricInfo extends MetricInfo {
// CompensatedSum compensatedSum;
//
// public SumMetricInfo(String metric) {
// super(metric, MetricStat.SUM);
// compensatedSum = new CompensatedSum(0,0);
// }
//
// public void valueConsumer(long value) {
// compensatedSum.add(NumericUtils.sortableLongToDouble(value));
// }
//
// public Double getMetricValue() {
// return compensatedSum.value();
// }
// }
//
// public static void computeLeafResultsCache(SearchContext context,
// LeafReaderContext ctx,
// CompositeIndexFieldInfo starTree,
// List<MetricInfo> metricInfos) throws IOException {
// Map<String, MetricInfo> leafCache = new HashMap<>();
// StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
// assert starTreeValues != null;
// StarTreeValuesIterator result = context.getStarTreeFilteredValues(ctx, starTreeValues);
//
// List<Integer> entryIdCache = new ArrayList<>();
// int entryId;
// while ((entryId = result.nextEntry()) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
// entryIdCache.add(entryId);
// }
//
// for (MetricInfo metricInfo : metricInfos) {
// SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues.getMetricValuesIterator(
// metricInfo.metric
// );
//
// for (int cachedEntryId : entryIdCache) {
// if (valuesIterator.advance(cachedEntryId) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
// int count = valuesIterator.valuesCount();
// for (int i = 0; i < count; i++) {
// long value = valuesIterator.nextValue();
// metricInfo.valueConsumer(value);
// }
// }
// }
// leafCache.put(metricInfo.metric, metricInfo);
// }
// context.getStarTreeQueryContext().getLeafResultsCache().put(ctx, leafCache);
// }
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.Version;
import org.opensearch.action.search.SearchShardTask;
import org.opensearch.action.search.SearchType;
Expand Down Expand Up @@ -1170,13 +1171,12 @@ public StarTreeQueryContext getStarTreeQueryContext() {
}

@Override
public StarTreeValuesIterator getStarTreeFilteredValues(LeafReaderContext ctx, StarTreeValues starTreeValues) throws IOException {
public FixedBitSet getStarTreeFilteredValues(LeafReaderContext ctx, StarTreeValues starTreeValues) throws IOException {
if (this.starTreeValuesMap.containsKey(ctx)) {

return starTreeValuesMap.get(ctx);
}
StarTreeFilter filter = new StarTreeFilter(starTreeValues, this.getStarTreeQueryContext().getQueryMap());
StarTreeValuesIterator result = filter.getStarTreeResult();
FixedBitSet result = filter.getStarTreeResult();

starTreeValuesMap.put(ctx, result);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.BigArrays;
Expand Down Expand Up @@ -268,18 +269,29 @@ public LeafBucketCollector getStarTreeLeafCollector(
SortedNumericStarTreeValuesIterator countValueIterator = (SortedNumericStarTreeValuesIterator) starTreeValues.getMetricValuesIterator(
countMetricName
);
StarTreeValuesIterator result = context.getStarTreeFilteredValues(ctx, starTreeValues);
FixedBitSet matchedDocIds = context.getStarTreeFilteredValues(ctx, starTreeValues);

int entryId;
while ((entryId = result.nextEntry()) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
if (sumValuesIterator.advance(entryId) != StarTreeValuesIterator.NO_MORE_ENTRIES) {

// Safety check: make sure the FixedBitSet is non-null and valid
if (matchedDocIds == null) {
throw new IllegalStateException("FixedBitSet is null");
}

int numBits = matchedDocIds.length(); // Get the length of the FixedBitSet

// Iterate over the FixedBitSet
for (int bit = matchedDocIds.nextSetBit(0); bit != -1; bit = bit + 1 < numBits ? matchedDocIds.nextSetBit(bit + 1) : -1) {
// Advance to the bit (entryId) in the valuesIterator
if (sumValuesIterator.advance(bit) != StarTreeValuesIterator.NO_MORE_ENTRIES &&
countValueIterator.advance(bit) != StarTreeValuesIterator.NO_MORE_ENTRIES) {
int count = sumValuesIterator.valuesCount();
for (int i = 0; i < count; i++) {
kahanSummation.add(NumericUtils.sortableLongToDouble(sumValuesIterator.nextValue()));
counts.increment(0, countValueIterator.nextValue());
counts.increment(0, countValueIterator.nextValue()); // Apply the consumer operation (e.g., max, sum)
}
}
}

sums.set(0, kahanSummation.value());
return new LeafBucketCollectorBase(sub, valuesSource.doubleValues(ctx)) {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public LeafBucketCollector getStarTreeCollector(LeafReaderContext ctx, LeafBucke
ctx,
sub,
starTree,
MetricStat.SUM.getTypeName(),
MetricStat.MAX.getTypeName(),
value -> {
max.set(Math.max(max.get(), (NumericUtils.sortableLongToDouble(value))));
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public LeafBucketCollector getStarTreeCollector(LeafReaderContext ctx, LeafBucke
ctx,
sub,
starTree,
MetricStat.SUM.getTypeName(),
MetricStat.MIN.getTypeName(),
value -> {
min.set(Math.min(min.get(), (NumericUtils.sortableLongToDouble(value))));
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
if (supportedStarTree != null) {
return getStarTreeCollector(ctx, sub, supportedStarTree);
}
System.out.println("nopes nopes");
return getDefaultLeafCollector(ctx, sub);
}

Expand Down Expand Up @@ -135,7 +136,7 @@ public void collect(int doc, long bucket) throws IOException {

public LeafBucketCollector getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
throws IOException {
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
final CompensatedSum kahanSummation = new CompensatedSum(sums.get(0), 0);
return StarTreeQueryHelper.getStarTreeLeafCollector(
context,
valuesSource,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.action.search.SearchShardTask;
import org.opensearch.action.search.SearchType;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -130,7 +131,7 @@ public List<InternalAggregation> toInternalAggregations(Collection<Collector> co
private final List<Releasable> releasables = new CopyOnWriteArrayList<>();
private final AtomicBoolean closed = new AtomicBoolean(false);
private InnerHitsContext innerHitsContext;
protected volatile Map<LeafReaderContext, StarTreeValuesIterator> starTreeValuesMap;
protected volatile Map<LeafReaderContext, FixedBitSet> starTreeValuesMap;
private volatile boolean searchTimedOut;

protected SearchContext() {}
Expand Down Expand Up @@ -545,7 +546,7 @@ public SearchContext starTreeQueryContext(StarTreeQueryContext starTreeQueryCont
return this;
}

public StarTreeValuesIterator getStarTreeFilteredValues(LeafReaderContext ctx, StarTreeValues starTreeValues) throws IOException {
public FixedBitSet getStarTreeFilteredValues(LeafReaderContext ctx, StarTreeValues starTreeValues) throws IOException {
return null;
}
}
Loading

0 comments on commit 30b95d4

Please sign in to comment.