Skip to content

Commit

Permalink
Add support to parse sub-aggregations from filter/nested aggregations
Browse files Browse the repository at this point in the history
Signed-off-by: Abhinav Nath <[email protected]>
  • Loading branch information
abhinav-nath committed Oct 7, 2022
1 parent 94eb007 commit caad7a3
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,35 @@
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import jakarta.json.stream.JsonGenerator;
import org.opensearch.client.util.ObjectBuilder;

import javax.annotation.Nullable;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// typedef: _types.aggregations.SingleBucketAggregateBase



public abstract class SingleBucketAggregateBase extends AggregateBase {
private final Map<String, Aggregate> aggregations;
private final long docCount;

// ---------------------------------------------------------------------------------------------

protected SingleBucketAggregateBase(AbstractBuilder<?> builder) {
super(builder);
this.aggregations = ApiTypeHelper.unmodifiable(builder.aggregations);

this.docCount = ApiTypeHelper.requireNonNull(builder.docCount, this, "docCount");

}

public final Map<String, Aggregate> aggregations() {
return this.aggregations;
}

/**
* Required - API name: {@code doc_count}
*/
Expand All @@ -76,11 +88,24 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
protected abstract static class AbstractBuilder<BuilderT extends AbstractBuilder<BuilderT>>
extends
AggregateBase.AbstractBuilder<BuilderT> {
@Nullable
protected Map<String, Aggregate> aggregations = new HashMap<>();
private Long docCount;

/**
* Required - API name: {@code doc_count}
*/
public final BuilderT aggregations(Map<String, Aggregate> aggregateMap) {
this.aggregations = _mapPutAll(this.aggregations, aggregateMap);
return self();
}

public final BuilderT aggregations(String key, Aggregate value) {
this.aggregations = _mapPut(this.aggregations, key, value);
return self();
}

public final BuilderT aggregations(String key, Function<Aggregate.Builder, ObjectBuilder<Aggregate>> function) {
return aggregations(key, function.apply(new Aggregate.Builder()).build());
}

public final BuilderT docCount(long value) {
this.docCount = value;
return self();
Expand All @@ -94,6 +119,12 @@ protected static <BuilderT extends AbstractBuilder<BuilderT>> void setupSingleBu
AggregateBase.setupAggregateBaseDeserializer(op);
op.add(AbstractBuilder::docCount, JsonpDeserializer.longDeserializer(), "doc_count");

op.setUnknownFieldHandler((builder, name, parser, mapper) -> {
if (builder.aggregations == null) {
builder.aggregations = new HashMap<>();
}
Aggregate._TYPED_KEYS_DESERIALIZER.deserializeEntry(name, parser, mapper, builder.aggregations);
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,15 @@
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.client.opensearch.OpenSearchAsyncClient;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.OpenSearchException;
import org.opensearch.client.opensearch._types.Refresh;
import org.opensearch.client.opensearch._types.aggregations.Aggregate;
import org.opensearch.client.opensearch._types.aggregations.HistogramAggregate;
import org.opensearch.client.opensearch._types.aggregations.TermsAggregation;
import org.opensearch.client.opensearch._types.mapping.Property;
import org.opensearch.client.opensearch._types.query_dsl.BoolQuery;
import org.opensearch.client.opensearch._types.query_dsl.TermsQuery;
import org.opensearch.client.opensearch.cat.NodesResponse;
import org.opensearch.client.opensearch.core.BulkResponse;
import org.opensearch.client.opensearch.core.ClearScrollResponse;
Expand All @@ -48,6 +53,7 @@
import org.opensearch.client.opensearch.core.InfoResponse;
import org.opensearch.client.opensearch.core.MsearchResponse;
import org.opensearch.client.opensearch.core.SearchResponse;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.core.bulk.OperationType;
import org.opensearch.client.opensearch.core.msearch.RequestItem;
import org.opensearch.client.opensearch.indices.CreateIndexResponse;
Expand All @@ -58,9 +64,9 @@
import org.opensearch.client.opensearch.model.ModelTestCase;
import org.opensearch.client.transport.endpoints.BooleanResponse;


import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -330,6 +336,44 @@ public void testSearchAggregation() throws IOException {

}

@Test
public void testSubAggregation() throws IOException {

highLevelClient().create(_1 -> _1.index("products").id("A").document(new Product(5, "Blue")).refresh(Refresh.True));
highLevelClient().create(_1 -> _1.index("products").id("B").document(new Product(10, "Blue")).refresh(Refresh.True));
highLevelClient().create(_1 -> _1.index("products").id("C").document(new Product(15, "Black")).refresh(Refresh.True));

List<FieldValue> fieldValues = List.of(FieldValue.of("Blue"));

SearchRequest searchRequest = SearchRequest.of(_1 -> _1
.index("products")
.size(0)
.aggregations(
"price", _3 -> _3
.aggregations(Map.of("price", TermsAggregation.of(_4 -> _4
.field("price"))
._toAggregation()))
.filter(BoolQuery.of(_5 -> _5
.filter(List.of(TermsQuery.of(_6 -> _6
.field("color.keyword")
.terms(_7 -> _7
.value(fieldValues)))
._toQuery())))
._toQuery()
)
));
SearchResponse<Product> searchResponse = highLevelClient().search(searchRequest, Product.class);

Aggregate prices = searchResponse.aggregations().get("price")._get()._toAggregate();
assertEquals(2, searchResponse.aggregations().get("price").filter().docCount());
assertEquals(1, prices.filter().aggregations().get("price").dterms().buckets().array().get(0).docCount());
assertEquals(1, prices.filter().aggregations().get("price").dterms().buckets().array().get(1).docCount());

// We've set "size" to zero
assertEquals(0, searchResponse.hits().hits().size());

}

@Test
public void testGetMapping() throws Exception {
// See also VariantsTest.testNestedTaggedUnionWithDefaultTag()
Expand Down Expand Up @@ -405,12 +449,18 @@ public void setMsg(String msg) {

public static class Product {
public double price;
public String color;

public Product() {}
public Product(double price) {
this.price = price;
}

public Product(double price, String color) {
this.price = price;
this.color = color;
}

public double getPrice() {
return this.price;
}
Expand Down

0 comments on commit caad7a3

Please sign in to comment.