Skip to content

Commit

Permalink
Add support for show_only_intersecting
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Brusic <[email protected]>
  • Loading branch information
brusic committed Jan 3, 2024
1 parent 8440468 commit 68d1746
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,35 @@ setup:

- match: { aggregations.conns.buckets.3.doc_count: 1 }
- match: { aggregations.conns.buckets.3.key: "4" }


---
"Show only intersections":

- do:
search:
index: test
rest_total_hits_as_int: true
body:
size: 0
aggs:
conns:
adjacency_matrix:
show_only_intersecting: true
filters:
1:
term:
num: 1
2:
term:
num: 2
4:
term:
num: 4

- match: { hits.total: 3 }

- length: { aggregations.conns.buckets: 1 }

- match: { aggregations.conns.buckets.0.doc_count: 1 }
- match: { aggregations.conns.buckets.0.key: "1&2" }
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ public class AdjacencyMatrixAggregationBuilder extends AbstractAggregationBuilde

private static final ParseField SEPARATOR_FIELD = new ParseField("separator");
private static final ParseField FILTERS_FIELD = new ParseField("filters");

public static final ParseField SHOW_ONLY_INTERSECTING = new ParseField("show_only_intersecting");

private List<KeyedFilter> filters;
private boolean showOnlyIntersecting = false;
private String separator = DEFAULT_SEPARATOR;

private static final ObjectParser<AdjacencyMatrixAggregationBuilder, String> PARSER = ObjectParser.fromBuilder(
Expand All @@ -81,6 +85,10 @@ public class AdjacencyMatrixAggregationBuilder extends AbstractAggregationBuilde
static {
PARSER.declareString(AdjacencyMatrixAggregationBuilder::separator, SEPARATOR_FIELD);
PARSER.declareNamedObjects(AdjacencyMatrixAggregationBuilder::setFiltersAsList, KeyedFilter.PARSER, FILTERS_FIELD);
PARSER.declareBoolean(
AdjacencyMatrixAggregationBuilder::showOnlyIntersecting,
AdjacencyMatrixAggregationBuilder.SHOW_ONLY_INTERSECTING
);
}

public static AggregationBuilder parse(XContentParser parser, String name) throws IOException {
Expand Down Expand Up @@ -115,6 +123,7 @@ protected AdjacencyMatrixAggregationBuilder(
super(clone, factoriesBuilder, metadata);
this.filters = new ArrayList<>(clone.filters);
this.separator = clone.separator;
this.showOnlyIntersecting = clone.showOnlyIntersecting;
}

@Override
Expand All @@ -138,13 +147,36 @@ public AdjacencyMatrixAggregationBuilder(String name, String separator, Map<Stri
setFiltersAsMap(filters);
}

/**
* @param name
* the name of this aggregation
* @param separator
* the string used to separate keys in intersections buckets e.g.
* &amp; character for keyed filters A and B would return an
* intersection bucket named A&amp;B
* @param filters
* the filters and their key to use with this aggregation.
* @param showOnlyIntersecting
* show only the buckets that intersection multiple documents
*/
public AdjacencyMatrixAggregationBuilder(
String name,
String separator,
Map<String, QueryBuilder> filters,
boolean showOnlyIntersecting
) {
this(name, separator, filters);
this.showOnlyIntersecting = showOnlyIntersecting;
}

/**
* Read from a stream.
*/
public AdjacencyMatrixAggregationBuilder(StreamInput in) throws IOException {
super(in);
int filtersSize = in.readVInt();
separator = in.readString();
showOnlyIntersecting = in.readBoolean();
filters = new ArrayList<>(filtersSize);
for (int i = 0; i < filtersSize; i++) {
filters.add(new KeyedFilter(in));
Expand All @@ -155,6 +187,7 @@ public AdjacencyMatrixAggregationBuilder(StreamInput in) throws IOException {
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeVInt(filters.size());
out.writeString(separator);
out.writeBoolean(showOnlyIntersecting);
for (KeyedFilter keyedFilter : filters) {
keyedFilter.writeTo(out);
}
Expand Down Expand Up @@ -185,6 +218,11 @@ private AdjacencyMatrixAggregationBuilder setFiltersAsList(List<KeyedFilter> fil
return this;
}

public AdjacencyMatrixAggregationBuilder showOnlyIntersecting(boolean showOnlyIntersecting) {
this.showOnlyIntersecting = showOnlyIntersecting;
return this;
}

/**
* Set the separator used to join pairs of bucket keys
*/
Expand Down Expand Up @@ -214,6 +252,10 @@ public Map<String, QueryBuilder> filters() {
return result;
}

public boolean isShowOnlyIntersecting() {
return showOnlyIntersecting;
}

@Override
protected AdjacencyMatrixAggregationBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException {
boolean modified = false;
Expand All @@ -224,7 +266,9 @@ protected AdjacencyMatrixAggregationBuilder doRewrite(QueryRewriteContext queryS
rewrittenFilters.add(new KeyedFilter(kf.key(), rewritten));
}
if (modified) {
return new AdjacencyMatrixAggregationBuilder(name).separator(separator).setFiltersAsList(rewrittenFilters);
return new AdjacencyMatrixAggregationBuilder(name).separator(separator)
.setFiltersAsList(rewrittenFilters)
.showOnlyIntersecting(showOnlyIntersecting);
}
return this;
}
Expand All @@ -245,7 +289,16 @@ protected AggregatorFactory doBuild(QueryShardContext queryShardContext, Aggrega
+ "] index level setting."
);
}
return new AdjacencyMatrixAggregatorFactory(name, filters, separator, queryShardContext, parent, subFactoriesBuilder, metadata);
return new AdjacencyMatrixAggregatorFactory(
name,
filters,
showOnlyIntersecting,
separator,
queryShardContext,
parent,
subFactoriesBuilder,
metadata
);
}

@Override
Expand All @@ -257,7 +310,8 @@ public BucketCardinality bucketCardinality() {
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(SEPARATOR_FIELD.getPreferredName(), separator);
builder.startObject(AdjacencyMatrixAggregator.FILTERS_FIELD.getPreferredName());
builder.field(SHOW_ONLY_INTERSECTING.getPreferredName(), showOnlyIntersecting);
builder.startObject(FILTERS_FIELD.getPreferredName());
for (KeyedFilter keyedFilter : filters) {
builder.field(keyedFilter.key(), keyedFilter.filter());
}
Expand All @@ -268,7 +322,7 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), filters, separator);
return Objects.hash(super.hashCode(), filters, showOnlyIntersecting, separator);
}

@Override
Expand All @@ -277,7 +331,9 @@ public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) return false;
if (super.equals(obj) == false) return false;
AdjacencyMatrixAggregationBuilder other = (AdjacencyMatrixAggregationBuilder) obj;
return Objects.equals(filters, other.filters) && Objects.equals(separator, other.separator);
return Objects.equals(filters, other.filters)
&& Objects.equals(separator, other.separator)
&& Objects.equals(showOnlyIntersecting, other.showOnlyIntersecting);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -70,8 +69,6 @@
*/
public class AdjacencyMatrixAggregator extends BucketsAggregator {

public static final ParseField FILTERS_FIELD = new ParseField("filters");

/**
* A keyed filter
*
Expand Down Expand Up @@ -145,6 +142,8 @@ public boolean equals(Object obj) {

private final String[] keys;
private final Weight[] filters;

private final boolean showOnlyIntersecting;
private final int totalNumKeys;
private final int totalNumIntersections;
private final String separator;
Expand All @@ -155,6 +154,7 @@ public AdjacencyMatrixAggregator(
String separator,
String[] keys,
Weight[] filters,
boolean showOnlyIntersecting,
SearchContext context,
Aggregator parent,
Map<String, Object> metadata
Expand All @@ -163,6 +163,7 @@ public AdjacencyMatrixAggregator(
this.separator = separator;
this.keys = keys;
this.filters = filters;
this.showOnlyIntersecting = showOnlyIntersecting;
this.totalNumIntersections = ((keys.length * keys.length) - keys.length) / 2;
this.totalNumKeys = keys.length + totalNumIntersections;
}
Expand All @@ -177,10 +178,12 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
return new LeafBucketCollectorBase(sub, null) {
@Override
public void collect(int doc, long bucket) throws IOException {
// Check each of the provided filters
for (int i = 0; i < bits.length; i++) {
if (bits[i].get(doc)) {
collectBucket(sub, doc, bucketOrd(bucket, i));
if (!showOnlyIntersecting) {
// Check each of the provided filters
for (int i = 0; i < bits.length; i++) {
if (bits[i].get(doc)) {
collectBucket(sub, doc, bucketOrd(bucket, i));
}
}
}
// Check all the possible intersections of the provided filters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,14 @@ public class AdjacencyMatrixAggregatorFactory extends AggregatorFactory {

private final String[] keys;
private final Weight[] weights;

private final boolean showOnlyIntersecting;
private final String separator;

public AdjacencyMatrixAggregatorFactory(
String name,
List<KeyedFilter> filters,
boolean showOnlyIntersecting,
String separator,
QueryShardContext queryShardContext,
AggregatorFactory parent,
Expand All @@ -79,6 +82,7 @@ public AdjacencyMatrixAggregatorFactory(
Query filter = keyedFilter.filter().toQuery(queryShardContext);
weights[i] = contextSearcher.createWeight(contextSearcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1f);
}
this.showOnlyIntersecting = showOnlyIntersecting;
}

@Override
Expand All @@ -88,7 +92,17 @@ public Aggregator createInternal(
CardinalityUpperBound cardinality,
Map<String, Object> metadata
) throws IOException {
return new AdjacencyMatrixAggregator(name, factories, separator, keys, weights, searchContext, parent, metadata);
return new AdjacencyMatrixAggregator(
name,
factories,
separator,
keys,
weights,
showOnlyIntersecting,
searchContext,
parent,
metadata
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,13 @@ public void testFiltersSameMap() {
assertEquals(original, builder.filters());
assert original != builder.filters();
}

public void testShowOnlyIntersecting() {
Map<String, QueryBuilder> original = new HashMap<>();
original.put("bbb", new MatchNoneQueryBuilder());
original.put("aaa", new MatchNoneQueryBuilder());
AdjacencyMatrixAggregationBuilder builder;
builder = new AdjacencyMatrixAggregationBuilder("my-agg", "&", original, true);
assertEquals(true, builder.isShowOnlyIntersecting());
}
}

0 comments on commit 68d1746

Please sign in to comment.