diff --git a/modules/reindex/src/main/java/org/elasticsearch/reindex/RestUpdateByQueryAction.java b/modules/reindex/src/main/java/org/elasticsearch/reindex/RestUpdateByQueryAction.java index 50a2b7de6db39..ecddfef7e971f 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/reindex/RestUpdateByQueryAction.java +++ b/modules/reindex/src/main/java/org/elasticsearch/reindex/RestUpdateByQueryAction.java @@ -71,6 +71,8 @@ protected UpdateByQueryRequest buildRequest(RestRequest request, NamedWriteableR consumers.put("script", o -> internal.setScript(Script.parse(o))); consumers.put("max_docs", s -> setMaxDocsValidateIdentical(internal, ((Number) s).intValue())); + // TODO There surely must be a better way of doing this + request.params().put("_source_includes", "*"); parseInternalRequest(internal, request, namedWriteableRegistry, consumers); internal.setPipeline(request.param("pipeline")); diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index cef6bf92cc5e6..b5b0194d68c33 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -165,7 +165,7 @@ private void runCoordinatorRewritePhase() { continue; } boolean canMatch = true; - CoordinatorRewriteContext coordinatorRewriteContext = coordinatorRewriteContextProvider.getCoordinatorRewriteContext( + CoordinatorRewriteContext coordinatorRewriteContext = coordinatorRewriteContextProvider.getCoordinatorRewriteContextForIndex( request.shardId().getIndex() ); if (coordinatorRewriteContext != null) { diff --git a/server/src/main/java/org/elasticsearch/action/search/CoordinatorQueryRewriteSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CoordinatorQueryRewriteSearchPhase.java new file mode 100644 index 0000000000000..6e97008dbd7a9 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/search/CoordinatorQueryRewriteSearchPhase.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.search; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.routing.GroupShardsIterator; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.index.query.CoordinatorRewriteContext; +import org.elasticsearch.index.query.CoordinatorRewriteContextProvider; +import org.elasticsearch.index.query.Rewriteable; +import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Executor; + +import static org.elasticsearch.core.Strings.format; + +/** + * This search phase can be used as an initial search phase to pre-filter search shards based on query rewriting. + * The queries are rewritten against the shards and based on the rewrite result shards might be able to be excluded + * from the search. The extra round trip to the search shards is very cheap and is not subject to rejections + * which allows to fan out to more shards at the same time without running into rejections even if we are hitting a + * large portion of the clusters indices. + * This phase can also be used to pre-sort shards based on min/max values in each shard of the provided primary sort. + * When the query primary sort is perform on a field, this phase extracts the min/max value in each shard and + * sort them according to the provided order. This can be useful for instance to ensure that shards that contain recent + * data are executed first when sorting by descending timestamp. + */ +final class CoordinatorQueryRewriteSearchPhase extends SearchPhase { + + private final Logger logger; + private final SearchRequest request; + private final GroupShardsIterator shardsIts; + private final ActionListener listener; + + private final CoordinatorRewriteContextProvider coordinatorRewriteContextProvider; + + private final IndicesService indicesService; + + private final Executor executor; + + CoordinatorQueryRewriteSearchPhase( + Logger logger, + SearchRequest request, + GroupShardsIterator shardsIts, + Executor executor, + IndicesService indicesService, + CoordinatorRewriteContextProvider coordinatorRewriteContextProvider, + ActionListener listener + ) { + super("coordinator_rewrite"); + + this.logger = logger; + this.request = request; + this.executor = executor; + this.listener = listener; + this.shardsIts = shardsIts; + this.coordinatorRewriteContextProvider = coordinatorRewriteContextProvider; + this.indicesService = indicesService; + } + + private static boolean assertSearchCoordinationThread() { + return ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION); + } + + @Override + public void run() throws IOException { + assert assertSearchCoordinationThread(); + runCoordinatorRewritePhase(); + } + + // tries to pre-filter shards based on information that's available to the coordinator + // without having to reach out to the actual shards + private void runCoordinatorRewritePhase() { + // TODO: the index filter (i.e, `_index:patten`) should be prefiltered on the coordinator + assert assertSearchCoordinationThread(); + final List matchedShardLevelRequests = new ArrayList<>(); + Map> fieldModelIds = new HashMap<>(); + for (SearchShardIterator searchShardIterator : shardsIts) { + Index index = searchShardIterator.shardId().getIndex(); + MappingLookup mappingLookup = indicesService.indexService(index).mapperService().mappingLookup(); + mappingLookup.modelsForFields().forEach((k, v) -> { + Set modelIds = fieldModelIds.computeIfAbsent(k, value -> new HashSet()); + modelIds.add(v); + }); + } + + CoordinatorRewriteContext coordinatorRewriteContext = coordinatorRewriteContextProvider.getCoordinatorRewriteContextForModels( + fieldModelIds + ); + Rewriteable.rewriteAndFetch(request, coordinatorRewriteContext, listener); + } + + @Override + public void start() { + // Note that the search is failed when this task is rejected by the executor + executor.execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + if (logger.isDebugEnabled()) { + logger.debug(() -> format("Failed to execute [%s] while running [%s] phase", request, getName()), e); + } + listener.onFailure(new SearchPhaseExecutionException(getName(), e.getMessage(), e.getCause(), ShardSearchFailure.EMPTY_ARRAY)); + } + + @Override + protected void doRun() throws IOException { + CoordinatorQueryRewriteSearchPhase.this.run(); + } + }); + } +} diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index aeb71a3b03d8f..2e01dbb229c4c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -123,6 +123,7 @@ public SearchPhase newSearchPhase( Map aliasFilter, Map concreteIndexBoosts, boolean preFilter, + boolean runCoordinatorPhase, ThreadPool threadPool, SearchResponse.Clusters clusters ) { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index a2739e2c2a85e..8ef767c2cb478 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -1118,6 +1118,7 @@ private void executeSearch( Collections.unmodifiableMap(aliasFilter), concreteIndexBoosts, preFilterSearchShards, + SearchService.canRewriteInCoordinator(searchRequest.source()), threadPool, clusters ).start(); @@ -1210,6 +1211,7 @@ SearchPhase newSearchPhase( Map aliasFilter, Map concreteIndexBoosts, boolean preFilter, + boolean runCoordinatorPhase, ThreadPool threadPool, SearchResponse.Clusters clusters ); @@ -1234,10 +1236,38 @@ public SearchPhase newSearchPhase( Map aliasFilter, Map concreteIndexBoosts, boolean preFilter, + boolean runCoordinatorPhase, ThreadPool threadPool, SearchResponse.Clusters clusters ) { - if (preFilter) { + if (runCoordinatorPhase) { + return new CoordinatorQueryRewriteSearchPhase( + logger, + searchRequest, + shardIterators, + threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), + searchService.getIndicesService(), + searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis), + listener.delegateFailureAndWrap((l, newSearchRequest) -> { + SearchPhase action = newSearchPhase( + task, + newSearchRequest, + executor, + shardIterators, + timeProvider, + connectionLookup, + clusterState, + aliasFilter, + concreteIndexBoosts, + preFilter, + false, + threadPool, + clusters + ); + action.start(); + } + )); + } else if (preFilter) { return new CanMatchPreFilterSearchPhase( logger, searchTransportService, @@ -1263,6 +1293,7 @@ public SearchPhase newSearchPhase( aliasFilter, concreteIndexBoosts, false, + false, threadPool, clusters ); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index 796d10d5c893b..9336721ecfe0f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -120,6 +120,10 @@ String modelForField(String fieldName) { return this.fieldToInferenceModels.get(fieldName); } + Map modelsForFields() { + return this.fieldToInferenceModels; + } + /** * Returns the mapped field type for the given field name. */ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index 14c3c4371c030..cf8d39905e2e1 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -495,4 +495,8 @@ public void validateDoesNotShadow(String name) { public String modelForField(String fieldName) { return fieldTypeLookup.modelForField(fieldName); } + + public Map modelsForFields() { + return fieldTypeLookup.modelsForFields(); + } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java index 4bc79628268ff..bb56dc1b40236 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java @@ -80,12 +80,13 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { public static class SemanticTextFieldType extends SimpleMappedFieldType { - private SparseVectorFieldType sparseVectorFieldType; + private final SparseVectorFieldType sparseVectorFieldType; private final String modelId; public SemanticTextFieldType(String name, String modelId, Map meta) { super(name, true, false, false, TextSearchInfo.NONE, meta); + this.sparseVectorFieldType = new SparseVectorFieldType(name + "." + SPARSE_VECTOR_SUBFIELD_NAME, meta); this.modelId = modelId; } diff --git a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java index 2a1062f8876d2..eb87cc91cbb4f 100644 --- a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContext.java @@ -17,6 +17,8 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import java.util.Collections; +import java.util.Map; +import java.util.Set; import java.util.function.LongSupplier; /** @@ -27,8 +29,11 @@ * don't hold queried data. See IndexMetadata#getTimestampRange() for more details */ public class CoordinatorRewriteContext extends QueryRewriteContext { + @Nullable private final IndexLongFieldRange indexLongFieldRange; + @Nullable private final DateFieldMapper.DateFieldType timestampFieldType; + private final Map> fieldNamesToInferenceModel; public CoordinatorRewriteContext( XContentParserConfiguration parserConfig, @@ -55,6 +60,34 @@ public CoordinatorRewriteContext( ); this.indexLongFieldRange = indexLongFieldRange; this.timestampFieldType = timestampFieldType; + this.fieldNamesToInferenceModel = Map.of(); + } + + public CoordinatorRewriteContext( + XContentParserConfiguration parserConfig, + Client client, + LongSupplier nowInMillis, + Map> fieldNamesToInferenceModel + ) { + super( + parserConfig, + client, + nowInMillis, + null, + MappingLookup.EMPTY, + Collections.emptyMap(), + null, + null, + null, + null, + null, + null, + null, + null + ); + this.indexLongFieldRange = null; + this.timestampFieldType = null; + this.fieldNamesToInferenceModel = fieldNamesToInferenceModel; } long getMinTimestamp() { @@ -82,4 +115,9 @@ public MappedFieldType getFieldType(String fieldName) { public CoordinatorRewriteContext convertToCoordinatorRewriteContext() { return this; } + + @Nullable + public Set inferenceModelsForFieldName(String fieldName) { + return fieldNamesToInferenceModel.get(fieldName); + } } diff --git a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContextProvider.java b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContextProvider.java index e44861b4afe8a..962c0b2bf9900 100644 --- a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContextProvider.java +++ b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteContextProvider.java @@ -13,9 +13,12 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.index.Index; import org.elasticsearch.index.mapper.DateFieldMapper; +import org.elasticsearch.index.mapper.StringFieldType; import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.xcontent.XContentParserConfiguration; +import java.util.Map; +import java.util.Set; import java.util.function.Function; import java.util.function.LongSupplier; import java.util.function.Supplier; @@ -42,7 +45,7 @@ public CoordinatorRewriteContextProvider( } @Nullable - public CoordinatorRewriteContext getCoordinatorRewriteContext(Index index) { + public CoordinatorRewriteContext getCoordinatorRewriteContextForIndex(Index index) { var clusterState = clusterStateSupplier.get(); var indexMetadata = clusterState.metadata().index(index); @@ -63,4 +66,9 @@ public CoordinatorRewriteContext getCoordinatorRewriteContext(Index index) { return new CoordinatorRewriteContext(parserConfig, client, nowInMillis, timestampRange, dateFieldType); } + + @Nullable + public CoordinatorRewriteContext getCoordinatorRewriteContextForModels(Map> fieldToModelIds) { + return new CoordinatorRewriteContext(parserConfig, client, nowInMillis, fieldToModelIds); + } } diff --git a/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteableQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteableQueryBuilder.java new file mode 100644 index 0000000000000..f65ccd8316def --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/query/CoordinatorRewriteableQueryBuilder.java @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.query; + +// Marker interface for queries that can be +public interface CoordinatorRewriteableQueryBuilder { +} diff --git a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java index 21a99365255d9..adbb51f682fc3 100644 --- a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java +++ b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java @@ -8,6 +8,8 @@ package org.elasticsearch.ingest; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.index.IndexRequest; @@ -34,6 +36,7 @@ public class FieldInferenceBulkRequestPreprocessor extends AbstractBulkRequestPreprocessor { + private static final Logger logger = LogManager.getLogger(FieldInferenceBulkRequestPreprocessor.class); public static final String SEMANTIC_TEXT_ORIGIN = "semantic_text"; private final IndicesService indicesService; @@ -99,11 +102,6 @@ public boolean shouldExecuteOnIngestNode() { } private boolean fieldNeedsInference(IndexRequest indexRequest, String fieldName, Object fieldValue) { - - if (fieldValue instanceof String == false) { - return false; - } - return getModelForField(indexRequest, fieldName) != null; } @@ -147,13 +145,17 @@ private void runInferenceForFields( String fieldName = fieldNames.get(0); List nextFieldNames = fieldNames.subList(1, fieldNames.size()); final String fieldValue = ingestDocument.getFieldValue(fieldName, String.class); - if (fieldValue == null) { + Object existingInference = ingestDocument.getFieldValue(fieldName + "." + SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME, Object.class, true); + if (fieldValue == null || existingInference != null) { // Run inference for next field + logger.info("Skipping inference for field [" + fieldName + "]"); runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure); + return; } String modelForField = getModelForField(indexRequest, fieldName); assert modelForField != null : "Field " + fieldName + " has no model associated in mappings"; + logger.info("Calculating inference for field [" + fieldName + "]"); // TODO Hardcoding task type, how to get that from model ID? InferenceAction.Request inferenceRequest = new InferenceAction.Request( diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 6919cfdbc00b4..0c009c560f01a 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -51,6 +51,7 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.query.CoordinatorRewriteContextProvider; +import org.elasticsearch.index.query.CoordinatorRewriteableQueryBuilder; import org.elasticsearch.index.query.InnerHitContextBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.MatchNoneQueryBuilder; @@ -1736,6 +1737,13 @@ public static boolean canRewriteToMatchNone(SearchSourceBuilder source) { return aggregations == null || aggregations.mustVisitAllDocs() == false; } + public static boolean canRewriteInCoordinator(SearchSourceBuilder source) { + if (source == null) { + return false; + } + return source.subSearches().stream().anyMatch(sqwb -> sqwb.getQueryBuilder() instanceof CoordinatorRewriteableQueryBuilder); + } + @SuppressWarnings({ "rawtypes", "unchecked" }) private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest request, ActionListener listener) { ActionListener actionListener = listener.delegateFailureAndWrap((l, r) -> { diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/FetchSourceContext.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/FetchSourceContext.java index bba614dce78a5..9d21cbe929fcd 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/FetchSourceContext.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/FetchSourceContext.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.SemanticTextFieldMapper; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.search.lookup.SourceFilter; import org.elasticsearch.xcontent.ParseField; @@ -124,7 +125,8 @@ public static FetchSourceContext parseFromRestRequest(RestRequest request) { if (fetchSource != null || sourceIncludes != null || sourceExcludes != null) { return FetchSourceContext.of(fetchSource == null || fetchSource, sourceIncludes, sourceExcludes); } - return null; + + return FetchSourceContext.of(true, null, new String[]{"*." + SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME}); } public static FetchSourceContext fromXContent(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index f4bce4906c0b0..736d2a178a001 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -366,6 +366,7 @@ import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.NativeStorageProvider; +import org.elasticsearch.xpack.ml.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder; import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction; import org.elasticsearch.xpack.ml.rest.RestMlInfoAction; @@ -1688,6 +1689,11 @@ public List> getQueries() { TextExpansionQueryBuilder.NAME, TextExpansionQueryBuilder::new, TextExpansionQueryBuilder::fromXContent + ), + new QuerySpec( + SemanticQueryBuilder.NAME, + SemanticQueryBuilder::new, + SemanticQueryBuilder::fromXContent ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java new file mode 100644 index 0000000000000..4a3abe72693bf --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/SemanticQueryBuilder.java @@ -0,0 +1,224 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.xpack.ml.queries; + +import org.apache.lucene.search.Query; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.inference.InferenceAction; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.mapper.SemanticTextFieldMapper; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.CoordinatorRewriteContext; +import org.elasticsearch.index.query.CoordinatorRewriteableQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_FIELD_ADDED; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class SemanticQueryBuilder extends AbstractQueryBuilder implements CoordinatorRewriteableQueryBuilder { + + public static final String NAME = "semantic_query"; + + private final String fieldName; + private final String query; + + private static final ParseField QUERY_FIELD = new ParseField("query"); + + private SetOnce inferenceResultsSupplier; + + public SemanticQueryBuilder(String fieldName, String query) { + if (fieldName == null) { + throw new IllegalArgumentException("[" + NAME + "] requires a fieldName"); + } + if (query == null) { + throw new IllegalArgumentException("[" + NAME + "] requires a query"); + } + this.fieldName = fieldName; + this.query = query; + } + + public SemanticQueryBuilder(SemanticQueryBuilder other, SetOnce inferenceResultsSupplier) { + this.fieldName = other.fieldName; + this.query = other.query; + this.inferenceResultsSupplier = inferenceResultsSupplier; + } + + public SemanticQueryBuilder(StreamInput in) throws IOException { + super(in); + this.fieldName = in.readString(); + this.query = in.readString(); + } + + @Override + protected QueryBuilder doCoordinatorRewrite(CoordinatorRewriteContext coordinatorRewriteContext) { + if (inferenceResultsSupplier != null) { + if (inferenceResultsSupplier.get() == null) { + // Inference still not returned + return this; + } + return inferenceResultsToQuery(fieldName, inferenceResultsSupplier.get()); + } + + Set modelNames = coordinatorRewriteContext.inferenceModelsForFieldName(fieldName); + if (modelNames == null) { + throw new IllegalArgumentException( + "field [" + fieldName + "] is not a " + SemanticTextFieldMapper.CONTENT_TYPE + " field type" + ); + } + + if (modelNames.size() > 1) { + throw new IllegalArgumentException("field [" + fieldName + "] has multiple models associated to it on different indices. " + + "A single model needs to be associated to the field in all the indices that contain it"); + } + + // TODO Hardcoding task type + String modelId = modelNames.iterator().next(); + InferenceAction.Request inferenceRequest = new InferenceAction.Request(TaskType.SPARSE_EMBEDDING, modelId, query, Map.of()); + + SetOnce inferenceResultsSupplier = new SetOnce<>(); + coordinatorRewriteContext.registerAsyncAction((client, listener) -> { + executeAsyncWithOrigin(client, ML_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, ActionListener.wrap(inferenceResponse -> { + inferenceResultsSupplier.set(inferenceResponse.getResult()); + listener.onResponse(null); + }, listener::onFailure)); + }); + + return new SemanticQueryBuilder(this, inferenceResultsSupplier); + } + + private static QueryBuilder inferenceResultsToQuery(String fieldName, InferenceResults inferenceResults) { + if (inferenceResults instanceof TextExpansionResults expansionResults) { + var boolQuery = QueryBuilders.boolQuery(); + for (var weightedToken : expansionResults.getWeightedTokens()) { + boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight())); + } + boolQuery.minimumShouldMatch(1); + return boolQuery; + } else { + throw new IllegalArgumentException( + "field [" + fieldName + "] does not use a model that outputs sparse vector inference results" + ); + } + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return SEMANTIC_TEXT_FIELD_ADDED; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + if (inferenceResultsSupplier != null) { + throw new IllegalStateException("inference supplier must be null, can't serialize suppliers, missing a rewriteAndFetch?"); + } + out.writeString(fieldName); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(NAME); + builder.startObject(fieldName); + boostAndQueryNameToXContent(builder); + builder.endObject(); + builder.endObject(); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + throw new IllegalStateException("semantic_query should have been rewritten to another query type"); + } + + @Override + protected boolean doEquals(SemanticQueryBuilder other) { + return Objects.equals(fieldName, other.fieldName) + && Objects.equals(query, other.query) + && Objects.equals(inferenceResultsSupplier, other.inferenceResultsSupplier); + } + + @Override + protected int doHashCode() { + return Objects.hash(fieldName, query, inferenceResultsSupplier); + } + + public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IOException { + String fieldName = null; + String query = null; + float boost = AbstractQueryBuilder.DEFAULT_BOOST; + String queryName = null; + String currentFieldName = null; + XContentParser.Token token; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.START_OBJECT) { + throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName); + fieldName = currentFieldName; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + query = parser.text(); + } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + boost = parser.floatValue(); + } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + queryName = parser.text(); + } else { + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] query does not support [" + currentFieldName + "]" + ); + } + } else { + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + ); + } + } + } + } + + if (fieldName == null) { + throw new ParsingException(parser.getTokenLocation(), "No field name specified for semantic query"); + } + + if (query == null) { + throw new ParsingException(parser.getTokenLocation(), "No query specified for semantic query"); + } + + SemanticQueryBuilder queryBuilder = new SemanticQueryBuilder(fieldName, query); + queryBuilder.queryName(queryName); + queryBuilder.boost(boost); + return queryBuilder; + } +}