Skip to content

Commit

Permalink
Query logic is done in the field type; SemanticQueryBuilder just uses it
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 30, 2023
1 parent 66a1d35 commit b8fce51
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

package org.elasticsearch.xpack.ml.mapper;

import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.index.mapper.DocumentParserContext;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
Expand All @@ -18,6 +24,8 @@
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

import java.io.IOException;
import java.util.Map;
Expand All @@ -27,7 +35,6 @@ public class SemanticTextFieldMapper extends FieldMapper {

public static final String CONTENT_TYPE = "semantic_text";


private static SemanticTextFieldMapper toType(FieldMapper in) {
return (SemanticTextFieldMapper) in;
}
Expand Down Expand Up @@ -57,10 +64,6 @@ protected Parameter<?>[] getParameters() {
return new Parameter<?>[] { modelId, meta };
}

private SemanticTextFieldType buildFieldType(MapperBuilderContext context) {
return new SemanticTextFieldType(context.buildFullName(name), modelId.getValue(), meta.getValue());
}

@Override
public SemanticTextFieldMapper build(MapperBuilderContext context) {
return new SemanticTextFieldMapper(
Expand Down Expand Up @@ -107,6 +110,29 @@ public Query termQuery(Object value, SearchExecutionContext context) {
return sparseVectorFieldType.termQuery(value, context);
}

public Query textExpansionQuery(TextExpansionResults expansionResults, SearchExecutionContext context) {
String fieldName = name() + "." + "inference";
BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder();
for (var weightedToken : expansionResults.getWeightedTokens()) {
queryBuilder.add(
new BooleanClause(
FeatureField.newLinearQuery(fieldName, indexedValueForSearch(weightedToken.token()), weightedToken.weight()),
BooleanClause.Occur.SHOULD
)
);
}
queryBuilder.setMinimumNumberShouldMatch(1);
var parentFilter = context.bitsetFilter(Queries.newNonNestedFilter(context.indexVersionCreated()));
return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, name());
}

private static String indexedValueForSearch(Object value) {
if (value instanceof BytesRef) {
return ((BytesRef) value).utf8ToString();
}
return value.toString();
}

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
return SourceValueFetcher.identity(name(), context, format);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,29 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;
package org.elasticsearch.xpack.ml.queries;

import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.inference.action.InferenceAction;
import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper;

import java.io.IOException;
import java.util.List;
Expand All @@ -46,7 +44,6 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil

private final String fieldName;
private final String query;
private QueryBuilder innerQueryBuilder;

private static final ParseField QUERY_FIELD = new ParseField("query");

Expand All @@ -67,14 +64,12 @@ public SemanticQueryBuilder(SemanticQueryBuilder other, SetOnce<List<? extends I
this.fieldName = other.fieldName;
this.query = other.query;
this.inferenceResultsSupplier = inferenceResultsSupplier;
this.innerQueryBuilder = other.innerQueryBuilder;
}

public SemanticQueryBuilder(SemanticQueryBuilder other, QueryBuilder innerQueryBuilder) {
this.fieldName = other.fieldName;
this.query = other.query;
this.inferenceResultsSupplier = other.inferenceResultsSupplier;
this.innerQueryBuilder = innerQueryBuilder;
}

public SemanticQueryBuilder(StreamInput in) throws IOException {
Expand All @@ -86,13 +81,6 @@ public SemanticQueryBuilder(StreamInput in) throws IOException {
@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
if (inferenceResultsSupplier != null) {
if (inferenceResultsSupplier.get() == null) {
// Inference still not returned
return this;
}
if (innerQueryBuilder == null) {
return inferenceResultsToQuery(fieldName, inferenceResultsSupplier.get());
}
return this;
}

Expand Down Expand Up @@ -152,8 +140,25 @@ private QueryBuilder inferenceResultsToQuery(String fieldName, List<? extends In
}

protected Query doToQuery(SearchExecutionContext context) throws IOException {
var parentFilter = context.bitsetFilter(Queries.newNonNestedFilter(context.indexVersionCreated()));
return new ESToParentBlockJoinQuery(innerQueryBuilder.toQuery(context), parentFilter, ScoreMode.Total, fieldName);
List<? extends InferenceResults> inferenceResultsList = inferenceResultsSupplier.get();
if (inferenceResultsList == null) {
throw new IllegalArgumentException("No inference retrieved for field " + fieldName);
}
if (inferenceResultsList.size() != 1) {
throw new IllegalArgumentException("received multiple inference results for field " + fieldName);
}

InferenceResults inferenceResults = inferenceResultsList.get(0);
if (inferenceResults instanceof TextExpansionResults expansionResults) {
SemanticTextFieldMapper.SemanticTextFieldType mapper = (SemanticTextFieldMapper.SemanticTextFieldType) context.getFieldType(
fieldName
);
return mapper.textExpansionQuery(expansionResults, context);
}

throw new IllegalArgumentException(
"field [" + fieldName + "] does not use a model that outputs sparse vector inference results"
);
}

@Override
Expand Down

0 comments on commit b8fce51

Please sign in to comment.