Compute multiple float aggregations in one go #12547

Open: wants to merge 1 commit into base: main

2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
@@ -151,6 +151,8 @@ New Features
* GITHUB#12479: Add new Maximum Inner Product vector similarity function for non-normalized dot-product
vector search. (Jack Mazanec, Ben Trent)

* GITHUB#12546 Enable doing multiple aggregations at a time with FloatTaxonomyFacets. (Stefan Vodita)

Improvements
---------------------
* GITHUB#12374: Add CachingLeafSlicesSupplier to compute the LeafSlices for concurrent segment search (Sorabh Hamirwasia)
@@ -37,33 +37,43 @@ abstract class FloatTaxonomyFacets extends TaxonomyFacets {

// TODO: also use native hash map for sparse collection, like IntTaxonomyFacets

/** Aggregation function used for combining values. */
final AssociationAggregationFunction aggregationFunction;
/** Aggregation functions used for combining values. */
final List<AssociationAggregationFunction> aggregationFunctions;

/** Per-ordinal value. */
float[] values;
float[][] values;

@Override
boolean hasValues() {
return values != null;
}

void initializeValueCounters() {
if (values == null) {
values = new float[aggregationFunctions.size()][taxoReader.getSize()];
Contributor

This seems like a scary amount of RAM we could end up requiring. Are we sure that all the labels in taxoReader will have nonzero values? I wonder if we ought to switch to a sparse approach?

Contributor Author

This is a great point. IntTaxonomyFacets can choose sparse values when the taxonomy is large and there aren't many hits. We can have the same functionality in FloatTaxonomyFacets. This was also mentioned recently in another issue, which calls into question the way we decide between sparse and dense values.
Fundamentally, I think users of this feature will have to decide whether they can afford the space-for-time tradeoff of computing multiple aggregations.

Member

Hmm, you mean FloatTaxonomyFacets today is never sparse in its aggregation?

Contributor Author

That is correct. Compare initializeValueCounters for IntTaxonomyFacets and FloatTaxonomyFacets. I don't think there's a good reason for Int/FloatTaxonomyFacets to differ here. Maybe sparse values just never got implemented for FloatTaxonomyFacets.

Member

Maybe open a spinoff to implement sparse values for FloatTaxonomyFacets? But let's not block this otherwise great change?
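
To put rough numbers on the memory concern raised in this thread, here is a minimal sketch. It is not Lucene code: `DenseVsSparseSketch`, `denseBytes`, and `SparseValues` are hypothetical names, and `java.util.HashMap` is only a stand-in for whatever primitive hash map a real sparse implementation would use. The estimate assumes 4 bytes per float and ignores array headers and map overhead.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical illustration only; none of these names exist in Lucene. */
final class DenseVsSparseSketch {

  /** Float payload, in bytes, of a dense float[numAggregations][taxonomySize]. */
  static long denseBytes(int numAggregations, int taxonomySize) {
    return (long) numAggregations * taxonomySize * Float.BYTES;
  }

  /** Sparse alternative: only ordinals that actually receive a value take up space. */
  static final class SparseValues {
    private final List<Map<Integer, Float>> perAggregation = new ArrayList<>();

    SparseValues(int numAggregations) {
      for (int i = 0; i < numAggregations; i++) {
        perAggregation.add(new HashMap<>());
      }
    }

    /** Adds value to the running sum for (aggregationIdx, ord). */
    void add(int aggregationIdx, int ord, float value) {
      perAggregation.get(aggregationIdx).merge(ord, value, Float::sum);
    }

    /** Returns the stored value, or 0 for an ordinal that was never touched. */
    float get(int aggregationIdx, int ord) {
      return perAggregation.get(aggregationIdx).getOrDefault(ord, 0f);
    }

    /** Number of ordinals holding a value for this aggregation. */
    int touchedOrdinals(int aggregationIdx) {
      return perAggregation.get(aggregationIdx).size();
    }
  }

  public static void main(String[] args) {
    // 3 aggregations over a 10M-label taxonomy: 120,000,000 bytes of floats.
    System.out.println(denseBytes(3, 10_000_000));

    SparseValues sparse = new SparseValues(3);
    sparse.add(0, 42, 1.5f);
    sparse.add(0, 42, 2.5f);
    System.out.println(sparse.get(0, 42) + " over " + sparse.touchedOrdinals(0) + " ordinal(s)");
  }
}
```

The dense cost grows with the number of aggregations regardless of how many labels are hit, which is the space-for-time tradeoff mentioned above. A sparse layout only pays for touched ordinals, at the price of slower lookups.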

}
}

/** Sole constructor. */
FloatTaxonomyFacets(
String indexFieldName,
TaxonomyReader taxoReader,
AssociationAggregationFunction aggregationFunction,
FacetsConfig config,
FacetsCollector fc)
throws IOException {
super(indexFieldName, taxoReader, config, fc);
this.aggregationFunction = aggregationFunction;
this(indexFieldName, taxoReader, List.of(aggregationFunction), config, fc);
}

@Override
boolean hasValues() {
return values != null;
}

void initializeValueCounters() {
if (values == null) {
values = new float[taxoReader.getSize()];
}
FloatTaxonomyFacets(
String indexFieldName,
TaxonomyReader taxoReader,
List<AssociationAggregationFunction> aggregationFunctions,
FacetsConfig config,
FacetsCollector fc)
throws IOException {
super(indexFieldName, taxoReader, config, fc);
this.aggregationFunctions = aggregationFunctions;
values = new float[aggregationFunctions.size()][taxoReader.getSize()];
}

/** Rolls up any single-valued hierarchical dimensions. */
@@ -80,19 +90,30 @@ void rollup() throws IOException {
if (ft.hierarchical && ft.multiValued == false) {
int dimRootOrd = taxoReader.getOrdinal(new FacetLabel(dim));
assert dimRootOrd > 0;
float newValue =
aggregationFunction.aggregate(values[dimRootOrd], rollup(children[dimRootOrd]));
values[dimRootOrd] = newValue;
for (int aggregationIdx = 0;
aggregationIdx < aggregationFunctions.size();
aggregationIdx++) {
AssociationAggregationFunction aggregationFunction =
aggregationFunctions.get(aggregationIdx);
float newValue =
aggregationFunction.aggregate(
values[aggregationIdx][dimRootOrd],
rollup(aggregationFunction, values[aggregationIdx], children[dimRootOrd]));
values[aggregationIdx][dimRootOrd] = newValue;
}
}
}
}

private float rollup(int ord) throws IOException {
private float rollup(AssociationAggregationFunction aggregationFunction, float[] values, int ord)
throws IOException {
int[] children = getChildren();
int[] siblings = getSiblings();
float aggregationValue = 0f;
while (ord != TaxonomyReader.INVALID_ORDINAL) {
float childValue = aggregationFunction.aggregate(values[ord], rollup(children[ord]));
float childValue =
aggregationFunction.aggregate(
values[ord], rollup(aggregationFunction, values, children[ord]));
values[ord] = childValue;
aggregationValue = aggregationFunction.aggregate(aggregationValue, childValue);
ord = siblings[ord];
@@ -102,6 +123,11 @@ private float rollup(int ord) throws IOException {

@Override
public Number getSpecificValue(String dim, String... path) throws IOException {
return getSpecificValue(0, dim, path);
}

public Number getSpecificValue(int aggregationIdx, String dim, String... path)
throws IOException {
DimConfig dimConfig = verifyDim(dim);
if (path.length == 0) {
if (dimConfig.hierarchical && dimConfig.multiValued == false) {
@@ -117,11 +143,16 @@ public Number getSpecificValue(String dim, String... path) throws IOException {
if (ord < 0) {
return -1;
}
return values == null ? 0 : values[ord];
return values == null ? 0 : values[aggregationIdx][ord];
}

@Override
public FacetResult getAllChildren(String dim, String... path) throws IOException {
return getAllChildren(0, dim, path);
}

public FacetResult getAllChildren(int aggregationIdx, String dim, String... path)
throws IOException {
DimConfig dimConfig = verifyDim(dim);
FacetLabel cp = new FacetLabel(dim, path);
int dimOrd = taxoReader.getOrdinal(cp);
@@ -143,10 +174,13 @@ public FacetResult getAllChildren(String dim, String... path) throws IOException
FloatArrayList ordValues = new FloatArrayList();

while (ord != TaxonomyReader.INVALID_ORDINAL) {
if (values[ord] > 0) {
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, values[ord]);
if (values[aggregationIdx][ord] > 0) {
aggregatedValue =
aggregationFunctions
.get(aggregationIdx)
.aggregate(aggregatedValue, values[aggregationIdx][ord]);
ordinals.add(ord);
ordValues.add(values[ord]);
ordValues.add(values[aggregationIdx][ord]);
}
ord = siblings[ord];
}
@@ -157,7 +191,7 @@ public FacetResult getAllChildren(String dim, String... path) throws IOException

if (dimConfig.multiValued) {
if (dimConfig.requireDimCount) {
aggregatedValue = values[dimOrd];
aggregatedValue = values[aggregationIdx][dimOrd];
} else {
// Our sum'd count is not correct, in general:
aggregatedValue = -1;
@@ -179,6 +213,11 @@ public FacetResult getAllChildren(String dim, String... path) throws IOException

@Override
public FacetResult getTopChildren(int topN, String dim, String... path) throws IOException {
return getTopChildren(0, topN, dim, path);
}

public FacetResult getTopChildren(int aggregationIdx, int topN, String dim, String... path)
throws IOException {
validateTopN(topN);
DimConfig dimConfig = verifyDim(dim);
FacetLabel cp = new FacetLabel(dim, path);
@@ -191,7 +230,8 @@ public FacetResult getTopChildren(int topN, String dim, String... path) throws I
return null;
}

TopChildrenForPath topChildrenForPath = getTopChildrenForPath(dimConfig, dimOrd, topN);
TopChildrenForPath topChildrenForPath =
getTopChildrenForPath(aggregationIdx, dimConfig, dimOrd, topN);
return createFacetResult(topChildrenForPath, dim, path);
}

@@ -201,6 +241,11 @@ public FacetResult getTopChildren(int topN, String dim, String... path) throws I
*/
private TopChildrenForPath getTopChildrenForPath(DimConfig dimConfig, int pathOrd, int topN)
throws IOException {
return getTopChildrenForPath(0, dimConfig, pathOrd, topN);
}

private TopChildrenForPath getTopChildrenForPath(
int aggregationIdx, DimConfig dimConfig, int pathOrd, int topN) throws IOException {

TopOrdAndFloatQueue q = new TopOrdAndFloatQueue(Math.min(taxoReader.getSize(), topN));
float bottomValue = 0;
@@ -215,9 +260,10 @@ private TopChildrenForPath getTopChildrenForPath(DimConfig dimConfig, int pathOr

TopOrdAndFloatQueue.OrdAndValue reuse = null;
while (ord != TaxonomyReader.INVALID_ORDINAL) {
float value = values[ord];
float value = values[aggregationIdx][ord];
if (value > 0) {
aggregatedValue = aggregationFunction.aggregate(aggregatedValue, value);
aggregatedValue =
aggregationFunctions.get(aggregationIdx).aggregate(aggregatedValue, value);
childCount++;
if (value > bottomValue || (value == bottomValue && ord < bottomOrd)) {
if (reuse == null) {
@@ -238,7 +284,7 @@ private TopChildrenForPath getTopChildrenForPath(DimConfig dimConfig, int pathOr

if (dimConfig.multiValued) {
if (dimConfig.requireDimCount) {
aggregatedValue = values[pathOrd];
aggregatedValue = values[aggregationIdx][pathOrd];
} else {
// Our sum'd count is not correct, in general:
aggregatedValue = -1;
@@ -286,6 +332,11 @@ FacetResult createFacetResult(TopChildrenForPath topChildrenForPath, String dim,

@Override
public List<FacetResult> getTopDims(int topNDims, int topNChildren) throws IOException {
return getTopDims(0, topNDims, topNChildren);
}

public List<FacetResult> getTopDims(int aggregationIdx, int topNDims, int topNChildren)
throws IOException {
validateTopN(topNDims);
validateTopN(topNChildren);

@@ -330,7 +381,7 @@ protected boolean lessThan(DimValue a, DimValue b) {
if (dimConfig.requireDimCount) {
// If the dim is configured as multi-valued and requires dim counts, we can access
// an accurate count for the dim computed at indexing time:
dimValue = values[dimOrd];
dimValue = values[aggregationIdx][dimOrd];
} else {
// If the dim is configured as multi-valued but not requiring dim counts, we cannot
// compute an accurate dim count, and use -1 as a place-holder:
@@ -88,8 +88,19 @@ public TaxonomyFacetFloatAssociations(
FacetsCollector fc,
AssociationAggregationFunction aggregationFunction)
throws IOException {
super(indexFieldName, taxoReader, aggregationFunction, config, fc);
aggregateValues(aggregationFunction, fc.getMatchingDocs());
this(indexFieldName, taxoReader, config, fc, List.of(aggregationFunction));
}

/** Create {@code TaxonomyFacetFloatAssociations} with multiple aggregations. */
public TaxonomyFacetFloatAssociations(
String indexFieldName,
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
List<AssociationAggregationFunction> aggregationFunctions)
throws IOException {
super(indexFieldName, taxoReader, aggregationFunctions, config, fc);
aggregateValues(aggregationFunctions, fc.getMatchingDocs());
}

/**
@@ -104,8 +115,26 @@ public TaxonomyFacetFloatAssociations(
AssociationAggregationFunction aggregationFunction,
DoubleValuesSource valuesSource)
throws IOException {
super(indexFieldName, taxoReader, aggregationFunction, config, fc);
aggregateValues(aggregationFunction, fc.getMatchingDocs(), fc.getKeepScores(), valuesSource);
this(
indexFieldName,
taxoReader,
config,
fc,
List.of(aggregationFunction),
List.of((valuesSource)));
}

/** Create {@code TaxonomyFacetFloatAssociations} with multiple aggregations. */
public TaxonomyFacetFloatAssociations(
String indexFieldName,
TaxonomyReader taxoReader,
FacetsConfig config,
FacetsCollector fc,
List<AssociationAggregationFunction> aggregationFunctions,
List<DoubleValuesSource> valuesSources)
throws IOException {
super(indexFieldName, taxoReader, aggregationFunctions, config, fc);
aggregateValues(aggregationFunctions, fc.getMatchingDocs(), fc.getKeepScores(), valuesSources);
}

private static DoubleValues scores(MatchingDocs hits) {
@@ -128,32 +157,57 @@ public boolean advanceExact(int doc) throws IOException {

/** Aggregate using the provided {@code DoubleValuesSource}. */
private void aggregateValues(
AssociationAggregationFunction aggregationFunction,
List<AssociationAggregationFunction> aggregationFunctions,
List<MatchingDocs> matchingDocs,
boolean keepScores,
DoubleValuesSource valueSource)
List<DoubleValuesSource> valuesSources)
throws IOException {
for (MatchingDocs hits : matchingDocs) {
if (hits.totalHits == 0) {
continue;
}
initializeValueCounters();

int numAggregations = aggregationFunctions.size();
assert numAggregations == valuesSources.size();

SortedNumericDocValues ordinalValues =
DocValues.getSortedNumeric(hits.context.reader(), indexFieldName);
DoubleValues scores = keepScores ? scores(hits) : null;
DoubleValues functionValues = valueSource.getValues(hits.context, scores);
DoubleValues[] functionValues = new DoubleValues[numAggregations];
for (int aggregationIdx = 0; aggregationIdx < numAggregations; aggregationIdx++) {
functionValues[aggregationIdx] =
valuesSources.get(aggregationIdx).getValues(hits.context, scores);
}
DocIdSetIterator it =
ConjunctionUtils.intersectIterators(List.of(hits.bits.iterator(), ordinalValues));

int[] advanced = new int[numAggregations];
float[] incomingValues = new float[numAggregations];
for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
if (functionValues.advanceExact(doc)) {
float value = (float) functionValues.doubleValue();
int ordinalCount = ordinalValues.docValueCount();
for (int i = 0; i < ordinalCount; i++) {
int ord = (int) ordinalValues.nextValue();
float newValue = aggregationFunction.aggregate(values[ord], value);
values[ord] = newValue;
int numAdvanced = 0;
for (int aggregationIdx = 0; aggregationIdx < numAggregations; aggregationIdx++) {
if (functionValues[aggregationIdx].advanceExact(doc)) {
advanced[numAdvanced] = aggregationIdx;
incomingValues[numAdvanced] = (float) functionValues[aggregationIdx].doubleValue();
numAdvanced++;
}
}
if (numAdvanced == 0) {
continue;
}

int ordinalCount = ordinalValues.docValueCount();
for (int i = 0; i < ordinalCount; i++) {
int ord = (int) ordinalValues.nextValue();
for (int advancedIdx = 0; advancedIdx < numAdvanced; advancedIdx++) {
int aggregationIdx = advanced[advancedIdx];
float value = incomingValues[advancedIdx];
float newValue =
aggregationFunctions
.get(aggregationIdx)
.aggregate(values[aggregationIdx][ord], value);
values[aggregationIdx][ord] = newValue;
}
}
}
@@ -165,7 +219,7 @@ private void aggregateValues(

/** Aggregate from indexed association values. */
private void aggregateValues(
AssociationAggregationFunction aggregationFunction, List<MatchingDocs> matchingDocs)
List<AssociationAggregationFunction> aggregationFunctions, List<MatchingDocs> matchingDocs)
throws IOException {

for (MatchingDocs hits : matchingDocs) {
@@ -188,8 +242,15 @@ private void aggregateValues(
offset += 4;
float value = (float) BitUtil.VH_BE_FLOAT.get(bytes, offset);
offset += 4;
float newValue = aggregationFunction.aggregate(values[ord], value);
values[ord] = newValue;
for (int aggregationIdx = 0;
aggregationIdx < aggregationFunctions.size();
aggregationIdx++) {
float newValue =
aggregationFunctions
.get(aggregationIdx)
.aggregate(values[aggregationIdx][ord], value);
values[aggregationIdx][ord] = newValue;
}
}
}
}
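
For readers skimming the diff, the core idea of the change (one pass over the hits, one `float[]` row per aggregation function) can be sketched in isolation. This is a simplified illustration and not the patch itself: `MultiAggregationSketch` and its `Aggregation` interface are hypothetical stand-ins for `AssociationAggregationFunction`, and plain `ords` and `weights` arrays stand in for the doc-values iteration.

```java
import java.util.List;

/** Hypothetical illustration only; these names are not part of Lucene. */
final class MultiAggregationSketch {

  /** Stand-in for an aggregation function that combines two float values. */
  interface Aggregation {
    float aggregate(float existing, float incoming);
  }

  static final Aggregation SUM = (a, b) -> a + b;
  static final Aggregation MAX = (a, b) -> Math.max(a, b);

  /**
   * Visits each (ordinal, weight) pair once and updates every aggregation's row.
   * Returns values[aggregationIdx][ord], mirroring the layout used in the diff.
   */
  static float[][] aggregate(
      List<Aggregation> aggregations, int numOrdinals, int[] ords, float[] weights) {
    float[][] values = new float[aggregations.size()][numOrdinals];
    for (int i = 0; i < ords.length; i++) {
      int ord = ords[i];
      for (int aggregationIdx = 0; aggregationIdx < aggregations.size(); aggregationIdx++) {
        values[aggregationIdx][ord] =
            aggregations.get(aggregationIdx).aggregate(values[aggregationIdx][ord], weights[i]);
      }
    }
    return values;
  }

  public static void main(String[] args) {
    // Ordinal 0 is hit twice (weights 2 and 5), ordinal 1 once (weight 3).
    float[][] values =
        aggregate(List.of(SUM, MAX), 2, new int[] {0, 1, 0}, new float[] {2f, 3f, 5f});
    System.out.println("sum[0]=" + values[0][0] + " max[0]=" + values[1][0]);
  }
}
```

The saving comes from iterating the matching documents and their ordinals only once, however many aggregations are requested. The cost is the extra row per aggregation discussed in the review thread.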