diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 4fdd8039f8938..6833bcb780363 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -16,12 +16,13 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.index.query.AbstractQueryBuilder; -import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ParseField; @@ -45,6 +46,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder modelsForField = queryRewriteContext.getModelIdsForField(fieldName); @@ -114,7 +127,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return new SemanticQueryBuilder(this, inferenceResultsSupplier); } - private static QueryBuilder inferenceResultsToQuery(String fieldName, List inferenceResultsList) { + private QueryBuilder inferenceResultsToQuery(String fieldName, List inferenceResultsList) { if (inferenceResultsList.size() != 1) { throw new IllegalArgumentException("received multiple inference results for field " + fieldName); } @@ -130,8 +143,7 @@ private static QueryBuilder inferenceResultsToQuery(String fieldName, List me SemanticTextInferenceResultFieldMapper.NAME + "." + name + "." + "inference", meta ); + this.sparseVectorFieldTypeForSearch = new SparseVectorFieldMapper.SparseVectorFieldType( + name + "." + "inference", + meta + ); this.modelId = modelId; } @@ -108,7 +113,7 @@ public String getInferenceModel() { @Override public Query termQuery(Object value, SearchExecutionContext context) { - return sparseVectorFieldType.termQuery(value, context); + return sparseVectorFieldTypeForSearch.termQuery(value, context); } @Override