diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java index 83c81ebd0e0cf..241caf33fb5bb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; @@ -110,10 +111,18 @@ public Query termQuery(Object value, SearchExecutionContext context) { return sparseVectorFieldType.termQuery(value, context); } - public Query textExpansionQuery(TextExpansionResults expansionResults, SearchExecutionContext context) { + public Query semanticQuery(InferenceResults inferenceResults, SearchExecutionContext context) { + + if (inferenceResults instanceof TextExpansionResults == false) { + throw new IllegalArgumentException( + "field [" + name() + "] does not use a model that outputs sparse vector inference results" + ); + } + + TextExpansionResults textExpansionResults = (TextExpansionResults) inferenceResults; String fieldName = name() + "." + "inference"; BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); - for (var weightedToken : expansionResults.getWeightedTokens()) { + for (var weightedToken : textExpansionResults.getWeightedTokens()) { queryBuilder.add( new BooleanClause( FeatureField.newLinearQuery(fieldName, indexedValueForSearch(weightedToken.token()), weightedToken.weight()), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java index 1f3c67d08d817..1acf46e795daf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.inference.InferenceResults; @@ -125,16 +124,10 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { } InferenceResults inferenceResults = inferenceResultsList.get(0); - if (inferenceResults instanceof TextExpansionResults expansionResults) { - SemanticTextFieldMapper.SemanticTextFieldType mapper = (SemanticTextFieldMapper.SemanticTextFieldType) context.getFieldType( - fieldName - ); - return mapper.textExpansionQuery(expansionResults, context); - } - - throw new IllegalArgumentException( - "field [" + fieldName + "] does not use a model that outputs sparse vector inference results" + SemanticTextFieldMapper.SemanticTextFieldType mapper = (SemanticTextFieldMapper.SemanticTextFieldType) context.getFieldType( + fieldName ); + return mapper.semanticQuery(inferenceResults, context); } @Override