Skip to content

Commit

Permalink
Handle aggs and start to adapt some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Jun 10, 2024
1 parent 5b36174 commit 87142ad
Show file tree
Hide file tree
Showing 18 changed files with 240 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RANK_FEATURE_PHASE_ADDED = def(8_678_00_0);
public static final TransportVersion RANK_DOC_IN_SHARD_FETCH_REQUEST = def(8_679_00_0);
public static final TransportVersion SECURITY_SETTINGS_REQUEST_TIMEOUTS = def(8_680_00_0);
public static final TransportVersion RANK_DOCS_RETRIEVER = def(8_681_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import org.elasticsearch.index.analysis.NamedAnalyzer;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.lucene.grouping.TopFieldGroups;
import org.elasticsearch.search.retriever.RankDocsSortField;
import org.elasticsearch.search.sort.ShardDocSortField;

import java.io.IOException;
Expand Down Expand Up @@ -548,6 +549,8 @@ private static SortField rewriteMergeSortField(SortField sortField) {
return newSortField;
} else if (sortField.getClass() == ShardDocSortField.class) {
return new SortField(sortField.getField(), SortField.Type.LONG, sortField.getReverse());
} else if (sortField.getClass() == RankDocsSortField.class) {
return new SortField(sortField.getField(), SortField.Type.INT, sortField.getReverse());
} else {
return sortField;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ private void registerSorts() {
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScoreSortBuilder.NAME, ScoreSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, ScriptSortBuilder.NAME, ScriptSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, FieldSortBuilder.NAME, FieldSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, RankDocsSortBuilder.NAME, FieldSortBuilder::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SortBuilder.class, RankDocsSortBuilder.NAME, RankDocsSortBuilder::new));
}

private static <T> void registerFromPlugin(List<SearchPlugin> plugins, Function<SearchPlugin, List<T>> producer, Consumer<T> consumer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
Expand All @@ -24,7 +25,6 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -94,12 +94,6 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
}

private final KnnSearchBuilder knnSearchBuilder;
private final String field;
private final float[] queryVector;
private final QueryVectorBuilder queryVectorBuilder;
private final int k;
private final int numCands;
private final Float similarity;

public KnnRetrieverBuilder(
String field,
Expand All @@ -109,13 +103,20 @@ public KnnRetrieverBuilder(
int numCands,
Float similarity
) {
this.knnSearchBuilder = new KnnSearchBuilder(field, VectorData.fromFloats(queryVector), queryVectorBuilder, k, numCands, similarity);
this.field = field;
this.queryVector = queryVector;
this.queryVectorBuilder = queryVectorBuilder;
this.k = k;
this.numCands = numCands;
this.similarity = similarity;
this.knnSearchBuilder = new KnnSearchBuilder(
field,
VectorData.fromFloats(queryVector),
queryVectorBuilder,
k,
numCands,
similarity
);
}

/**
 * Private copy-style constructor: builds a new retriever from {@code clone} while substituting
 * the kNN search builder and the pre-filter list. Used by this class when a rewrite produced
 * new components.
 *
 * @param clone the instance handed to the superclass copy constructor
 * @param knnSearchBuilder the kNN search builder the new instance will hold
 * @param preFilterQueryBuilders the pre-filters the new instance will hold; assigned after
 *                               {@code super(clone)}, so this list is the one that is kept
 */
private KnnRetrieverBuilder(KnnRetrieverBuilder clone, KnnSearchBuilder knnSearchBuilder, List<QueryBuilder> preFilterQueryBuilders) {
super(clone);
this.knnSearchBuilder = knnSearchBuilder;
this.preFilterQueryBuilders = preFilterQueryBuilders;
}

// ---- FOR TESTING XCONTENT PARSING ----
Expand All @@ -126,13 +127,22 @@ public String getName() {
}

@Override
public QueryBuilder originalQuery() {
// TODO nested + inner_hits
ExactKnnQueryBuilder knn = new ExactKnnQueryBuilder(knnSearchBuilder.getQueryVector(), knnSearchBuilder.getField());
if (preFilterQueryBuilders.isEmpty()) {
return knn;
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
var rewritten = knnSearchBuilder.rewrite(ctx);
boolean hasChanged = rewritten != knnSearchBuilder;
var rewrittenFilters = rewritePreFilters(ctx);
hasChanged |= rewrittenFilters != preFilterQueryBuilders;
if (hasChanged) {
return new KnnRetrieverBuilder(this, rewritten, rewrittenFilters);
}
var ret = new BoolQueryBuilder().should(knn);
return this;
}

/**
 * Builds the query that represents this kNN retriever: a bool query that requires
 * {@code leadQuery}, adds an exact kNN clause on the configured field and query vector
 * as a {@code should} clause, and applies every pre-filter as a {@code filter} clause.
 *
 * @param leadQuery the query added as the {@code must} clause
 * @return the assembled bool query
 */
@Override
public QueryBuilder originalQuery(QueryBuilder leadQuery) {
    // TODO nested + inner_hits
    BoolQueryBuilder boolQuery = new BoolQueryBuilder();
    boolQuery.must(leadQuery);
    ExactKnnQueryBuilder exactKnn = new ExactKnnQueryBuilder(knnSearchBuilder.getQueryVector(), knnSearchBuilder.getField());
    boolQuery.should(exactKnn);
    for (QueryBuilder preFilter : preFilterQueryBuilders) {
        boolQuery.filter(preFilter);
    }
    return boolQuery;
}
Expand All @@ -152,39 +162,32 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder

@Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(FIELD_FIELD.getPreferredName(), field);
builder.field(K_FIELD.getPreferredName(), k);
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
builder.field(FIELD_FIELD.getPreferredName(), knnSearchBuilder.getField());
builder.field(K_FIELD.getPreferredName(), knnSearchBuilder.k());
builder.field(NUM_CANDS_FIELD.getPreferredName(), knnSearchBuilder.k());

if (queryVector != null) {
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
if (knnSearchBuilder.getQueryVector() != null) {
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), knnSearchBuilder.getQueryVector());
}

if (queryVectorBuilder != null) {
builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), queryVectorBuilder);
if (knnSearchBuilder.getQueryVectorBuilder() != null) {
builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), knnSearchBuilder.getQueryVectorBuilder());
}

if (similarity != null) {
builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity);
if (knnSearchBuilder.getSimilarity() != null) {
builder.field(VECTOR_SIMILARITY.getPreferredName(), knnSearchBuilder.getSimilarity());
}
}

@Override
public boolean doEquals(Object o) {
KnnRetrieverBuilder that = (KnnRetrieverBuilder) o;
return k == that.k
&& numCands == that.numCands
&& Objects.equals(field, that.field)
&& Arrays.equals(queryVector, that.queryVector)
&& Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
&& Objects.equals(similarity, that.similarity);
return Objects.equals(knnSearchBuilder, that.knnSearchBuilder);
}

@Override
public int doHashCode() {
int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity);
result = 31 * result + Arrays.hashCode(queryVector);
return result;
return Objects.hash(knnSearchBuilder);
}

// ---- END TESTING ----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
Expand All @@ -23,7 +24,7 @@
import java.util.Comparator;

public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuilder> {
public static final String NAME = "rank_docs";
public static final String NAME = "rank";

private final RankDoc[] rankDocs;

Expand Down Expand Up @@ -92,7 +93,6 @@ protected int doHashCode() {

@Override
public TransportVersion getMinimalSupportedVersion() {
// TODO
return TransportVersion.current();
return TransportVersions.RANK_DOCS_RETRIEVER;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.DisMaxQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand All @@ -25,6 +25,9 @@
import java.util.Objects;
import java.util.function.Supplier;

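/**
 * A {@link RetrieverBuilder} that holds a list of source retrievers together with a supplier
 * of the {@link RankDoc} results associated with them.
 */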
public class RankDocsRetrieverBuilder extends RetrieverBuilder {
private static final Logger logger = LogManager.getLogger(RankDocsRetrieverBuilder.class);

Expand All @@ -33,22 +36,54 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
private final List<RetrieverBuilder> sources;
private final Supplier<RankDoc[]> rankDocs;

public RankDocsRetrieverBuilder(int windowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
public RankDocsRetrieverBuilder(
int windowSize,
List<RetrieverBuilder> rewritten,
Supplier<RankDoc[]> rankDocs,
List<QueryBuilder> preFilterQueryBuilders
) {
this.windowSize = windowSize;
this.rankDocs = rankDocs;
this.sources = sources;
this.sources = rewritten;
this.preFilterQueryBuilders = preFilterQueryBuilders;
}

/**
 * Returns the name of this retriever, i.e. the {@code NAME} constant.
 */
@Override
public String getName() {
return NAME;
}

/**
 * Checks whether any source retriever still changes when rewritten.
 *
 * @param ctx the rewrite context passed to each source
 * @return {@code true} as soon as one source's {@code rewrite} returns a different instance,
 *         {@code false} if every source rewrites to itself
 * @throws IOException if rewriting a source fails
 */
private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException {
    for (RetrieverBuilder candidate : sources) {
        if (candidate.rewrite(ctx) != candidate) {
            return true;
        }
    }
    return false;
}

@Override
public QueryBuilder originalQuery() {
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
    // Sources are expected to be fully rewritten before this retriever is; only checked with assertions enabled.
    assert sourceShouldRewrite(ctx) == false : "Retriever sources should be rewritten first";
    List<QueryBuilder> filters = rewritePreFilters(ctx);
    if (filters == preFilterQueryBuilders) {
        // Same list instance back means no pre-filter changed, so this retriever is already final.
        return this;
    }
    return new RankDocsRetrieverBuilder(windowSize, sources, rankDocs, filters);
}

@Override
public QueryBuilder originalQuery(QueryBuilder leadQuery) {
DisMaxQueryBuilder disMax = new DisMaxQueryBuilder().tieBreaker(0f);
for (var source : sources) {
disMax.add(source.originalQuery());
var query = source.originalQuery(leadQuery);
if (query != null) {
if (source.retrieverName != null) {
query.queryName(source.retrieverName);
}
disMax.add(query);
}
}
return disMax;
}
Expand All @@ -70,7 +105,10 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
for (var preFilterQueryBuilder : preFilterQueryBuilders) {
bq.filter(preFilterQueryBuilder);
}
bq.should(originalQuery());
QueryBuilder originalQuery = originalQuery(rankQuery);
if (originalQuery != null) {
bq.should(originalQuery);
}
searchSourceBuilder.query(bq);
}

Expand All @@ -82,7 +120,7 @@ protected boolean doEquals(Object o) {

@Override
protected int doHashCode() {
return Objects.hash(super.hashCode(), windowSize);
return Objects.hash(super.hashCode(), Arrays.hashCode(rankDocs.get()), windowSize);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package org.elasticsearch.search.retriever;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.index.query.QueryRewriteContext;
Expand All @@ -24,24 +26,28 @@
import java.util.Arrays;

public class RankDocsSortBuilder extends SortBuilder<RankDocsSortBuilder> {
public static final String NAME = "rank_docs";
public static final String NAME = "rank_sort";

private final RankDoc[] rankDocs;

/**
 * Creates a sort builder backed by the given rank docs.
 *
 * @param rankDocs the rank docs to sort by; the array reference is stored as-is, without copying
 */
public RankDocsSortBuilder(RankDoc[] rankDocs) {
this.rankDocs = rankDocs;
}

@Override
public String getWriteableName() {
return NAME;
public RankDocsSortBuilder(StreamInput in) throws IOException {
this.rankDocs = in.readArray(c -> c.readNamedWriteable(RankDoc.class), RankDoc[]::new);
}

/**
 * Serializes this sort builder by writing the rank docs array to {@code out},
 * with each element written as a named writeable.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeArray(StreamOutput::writeNamedWriteable, rankDocs);
}

/**
 * Returns the writeable name of this sort builder, i.e. the {@code NAME} constant.
 */
@Override
public String getWriteableName() {
return NAME;
}

@Override
public SortBuilder<?> rewrite(QueryRewriteContext ctx) throws IOException {
return this;
Expand All @@ -57,8 +63,7 @@ protected SortFieldAndFormat build(SearchExecutionContext context) throws IOExce

@Override
public TransportVersion getMinimalSupportedVersion() {
// TODO
return TransportVersion.current();
return TransportVersions.RANK_DOCS_RETRIEVER;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,27 @@ public static RetrieverBuilder parseTopLevelRetrieverBuilder(XContentParser pars

protected String retrieverName;

/** Creates a retriever builder without copying state from another instance. */
public RetrieverBuilder() {}

/**
 * Copy constructor for subclasses: takes over the pre-filter list and the retriever name
 * of {@code clone}. The pre-filter list reference is shared with {@code clone}, not copied.
 *
 * @param clone the retriever whose pre-filters and name are reused
 */
protected RetrieverBuilder(RetrieverBuilder clone) {
this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
this.retrieverName = clone.retrieverName;
}

/**
 * Rewrites every pre-filter of this retriever against {@code ctx}.
 *
 * @param ctx the rewrite context passed to each filter
 * @return a new list holding the rewritten filters if at least one filter rewrote to a
 *         different instance; otherwise the current {@code preFilterQueryBuilders} list
 *         itself, so callers can detect "no change" with a reference comparison
 * @throws IOException if rewriting a filter fails
 */
protected final List<QueryBuilder> rewritePreFilters(QueryRewriteContext ctx) throws IOException {
    final int count = preFilterQueryBuilders.size();
    List<QueryBuilder> rewritten = new ArrayList<>(count);
    boolean anyChanged = false;
    for (QueryBuilder original : preFilterQueryBuilders) {
        QueryBuilder candidate = original.rewrite(ctx);
        if (candidate != original) {
            anyChanged = true;
        }
        rewritten.add(candidate);
    }
    return anyChanged ? rewritten : preFilterQueryBuilders;
}

/**
* Gets the filters for this retriever.
*/
Expand All @@ -173,12 +194,13 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {

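/**
 * Returns the original {@link QueryBuilder} used to compute the top documents.
 * @param leadQuery a query that implementations combine with their own query
 *                  (for example as a required clause) or forward to nested retrievers
 */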
public abstract QueryBuilder originalQuery();
public abstract QueryBuilder originalQuery(QueryBuilder leadQuery);

/**
* This method is called at the end of parsing on behalf of a {@link SearchSourceBuilder}.
* Elements from retrievers are expected to be "extracted" into the {@link SearchSourceBuilder}.
* This method is called at the end of rewrite on the final retriever.
* Elements of the search request are expected to be "extracted" into the {@link SearchSourceBuilder}.
*/
public abstract void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed);

Expand Down
Loading

0 comments on commit 87142ad

Please sign in to comment.