Skip to content

Commit

Permalink
Allow semanticQuery to receive InferenceResults
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Dec 1, 2023
1 parent a1b174c commit 1f7a60b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

import java.io.IOException;
Expand Down Expand Up @@ -110,10 +111,18 @@ public Query termQuery(Object value, SearchExecutionContext context) {
return sparseVectorFieldType.termQuery(value, context);
}

public Query textExpansionQuery(TextExpansionResults expansionResults, SearchExecutionContext context) {
public Query semanticQuery(InferenceResults inferenceResults, SearchExecutionContext context) {

if (inferenceResults instanceof TextExpansionResults == false) {
throw new IllegalArgumentException(
"field [" + name() + "] does not use a model that outputs sparse vector inference results"
);
}

TextExpansionResults textExpansionResults = (TextExpansionResults) inferenceResults;
String fieldName = name() + "." + "inference";
BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder();
for (var weightedToken : expansionResults.getWeightedTokens()) {
for (var weightedToken : textExpansionResults.getWeightedTokens()) {
queryBuilder.add(
new BooleanClause(
FeatureField.newLinearQuery(fieldName, indexedValueForSearch(weightedToken.token()), weightedToken.weight()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.InferenceResults;
Expand Down Expand Up @@ -125,16 +124,10 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
}

InferenceResults inferenceResults = inferenceResultsList.get(0);
if (inferenceResults instanceof TextExpansionResults expansionResults) {
SemanticTextFieldMapper.SemanticTextFieldType mapper = (SemanticTextFieldMapper.SemanticTextFieldType) context.getFieldType(
fieldName
);
return mapper.textExpansionQuery(expansionResults, context);
}

throw new IllegalArgumentException(
"field [" + fieldName + "] does not use a model that outputs sparse vector inference results"
SemanticTextFieldMapper.SemanticTextFieldType mapper = (SemanticTextFieldMapper.SemanticTextFieldType) context.getFieldType(
fieldName
);
return mapper.semanticQuery(inferenceResults, context);
}

@Override
Expand Down

0 comments on commit 1f7a60b

Please sign in to comment.