Skip to content

Commit

Permalink
Do the query emulating what nested query does
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 28, 2023
1 parent e911424 commit 6bd9d36
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -45,6 +46,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil

private final String fieldName;
private final String query;
private QueryBuilder innerQueryBuilder;

private static final ParseField QUERY_FIELD = new ParseField("query");

Expand All @@ -65,6 +67,14 @@ public SemanticQueryBuilder(SemanticQueryBuilder other, SetOnce<List<? extends I
this.fieldName = other.fieldName;
this.query = other.query;
this.inferenceResultsSupplier = inferenceResultsSupplier;
this.innerQueryBuilder = other.innerQueryBuilder;
}

public SemanticQueryBuilder(SemanticQueryBuilder other, QueryBuilder innerQueryBuilder) {
this.fieldName = other.fieldName;
this.query = other.query;
this.inferenceResultsSupplier = other.inferenceResultsSupplier;
this.innerQueryBuilder = innerQueryBuilder;
}

public SemanticQueryBuilder(StreamInput in) throws IOException {
Expand All @@ -80,7 +90,10 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
// Inference still not returned
return this;
}
return inferenceResultsToQuery(fieldName, inferenceResultsSupplier.get());
if (innerQueryBuilder == null) {
return inferenceResultsToQuery(fieldName, inferenceResultsSupplier.get());
}
return this;
}

Set<String> modelsForField = queryRewriteContext.getModelIdsForField(fieldName);
Expand Down Expand Up @@ -114,7 +127,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
return new SemanticQueryBuilder(this, inferenceResultsSupplier);
}

private static QueryBuilder inferenceResultsToQuery(String fieldName, List<? extends InferenceResults> inferenceResultsList) {
private QueryBuilder inferenceResultsToQuery(String fieldName, List<? extends InferenceResults> inferenceResultsList) {
if (inferenceResultsList.size() != 1) {
throw new IllegalArgumentException("received multiple inference results for field " + fieldName);
}
Expand All @@ -130,15 +143,19 @@ private static QueryBuilder inferenceResultsToQuery(String fieldName, List<? ext
);
}
boolQuery.minimumShouldMatch(1);
var nestedQuery = QueryBuilders.nestedQuery("_semantic_text_inference." + fieldName, boolQuery, ScoreMode.Total);
return nestedQuery;
return new SemanticQueryBuilder(this, boolQuery);
} else {
throw new IllegalArgumentException(
"field [" + fieldName + "] does not use a model that outputs sparse vector inference results"
);
}
}

protected Query doToQuery(SearchExecutionContext context) throws IOException {
var parentFilter = context.bitsetFilter(Queries.newNonNestedFilter(context.indexVersionCreated()));
return new ESToParentBlockJoinQuery(innerQueryBuilder.toQuery(context), parentFilter, ScoreMode.Total, fieldName);
}

@Override
public String getWriteableName() {
return NAME;
Expand Down Expand Up @@ -166,11 +183,6 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
builder.endObject();
}

@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
throw new IllegalStateException("semantic_query should have been rewritten to another query type");
}

@Override
protected boolean doEquals(SemanticQueryBuilder other) {
return Objects.equals(fieldName, other.fieldName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
public static class SemanticTextFieldType extends SimpleMappedFieldType {

private final SparseVectorFieldMapper.SparseVectorFieldType sparseVectorFieldType;
private final SparseVectorFieldMapper.SparseVectorFieldType sparseVectorFieldTypeForSearch;

private final String modelId;

Expand All @@ -86,6 +87,10 @@ public SemanticTextFieldType(String name, String modelId, Map<String, String> me
SemanticTextInferenceResultFieldMapper.NAME + "." + name + "." + "inference",
meta
);
this.sparseVectorFieldTypeForSearch = new SparseVectorFieldMapper.SparseVectorFieldType(
name + "." + "inference",
meta
);
this.modelId = modelId;
}

Expand All @@ -108,7 +113,7 @@ public String getInferenceModel() {

@Override
public Query termQuery(Object value, SearchExecutionContext context) {
return sparseVectorFieldType.termQuery(value, context);
return sparseVectorFieldTypeForSearch.termQuery(value, context);
}

@Override
Expand Down

0 comments on commit 6bd9d36

Please sign in to comment.