From 1d1e81915929f0302e2fa02d7f63db85c0b152ee Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Sat, 30 Nov 2024 17:04:04 +0000 Subject: [PATCH] iter --- .../action/update/TransportUpdateAction.java | 21 +- .../index/mapper/DocumentParserContext.java | 21 +- .../index/mapper/InferenceFieldMapper.java | 9 - .../mapper/InferenceMetadataFieldsMapper.java | 3 +- .../index/mapper/SourceFieldMapper.java | 2 +- .../index/mapper/MapperTestCase.java | 2 + x-pack/plugin/core/build.gradle | 1 + .../core/ml/search}/SparseVectorQuery.java | 2 +- .../ml/search}/SparseVectorQueryBuilder.java | 3 +- .../ml/search}/TextExpansionQueryBuilder.java | 4 +- .../core/ml/search}/TokenPruningConfig.java | 4 +- .../search}/WeightedTokensQueryBuilder.java | 3 +- .../core/ml/search}/WeightedTokensUtils.java | 3 +- .../search/SparseVectorQueryBuilderTests.java | 331 +++++ .../TextExpansionQueryBuilderTests.java | 292 ++++ .../ml/search/TokenPruningConfigTests.java | 41 + .../WeightedTokensQueryBuilderTests.java | 466 +++++++ x-pack/plugin/inference/build.gradle | 1 - .../xpack/inference/InferenceFeatures.java | 16 +- .../xpack/inference/InferencePlugin.java | 6 +- .../ShardBulkInferenceActionFilter.java | 23 +- .../highlight/SemanticTextHighlighter.java | 11 +- .../mapper/AbstractSemanticTextFieldType.java | 30 - .../mapper/LegacySemanticTextField.java | 324 ----- .../mapper/LegacySemanticTextFieldMapper.java | 817 ----------- .../mapper/OffsetSourceFieldMapper.java | 41 +- .../inference/mapper/SemanticTextField.java | 152 +- .../mapper/SemanticTextFieldMapper.java | 223 +-- .../inference/mapper/SemanticTextUtils.java | 152 ++ .../queries/SemanticQueryBuilder.java | 4 +- .../LegacySemanticTextFieldMapperTests.java | 1227 ----------------- .../mapper/LegacySemanticTextFieldTests.java | 292 ---- .../mapper/SemanticTextFieldMapperTests.java | 128 +- .../mapper/SemanticTextFieldTests.java | 59 +- .../mapper/SemanticTextUtilsTests.java | 351 +++++ .../queries/SemanticQueryBuilderTests.java | 1 + .../SparseVectorQueryBuilderTests.java | 5 +- .../TextExpansionQueryBuilderTests.java | 3 + .../queries/TokenPruningConfigTests.java | 1 + .../WeightedTokensQueryBuilderTests.java | 5 +- 40 files changed, 2026 insertions(+), 3054 deletions(-) rename x-pack/plugin/{inference/src/main/java/org/elasticsearch/xpack/inference/queries => core/src/main/java/org/elasticsearch/xpack/core/ml/search}/SparseVectorQuery.java (97%) rename x-pack/plugin/{inference/src/main/java/org/elasticsearch/xpack/inference/queries => core/src/main/java/org/elasticsearch/xpack/core/ml/search}/SparseVectorQueryBuilder.java (99%) rename x-pack/plugin/{inference/src/main/java/org/elasticsearch/xpack/inference/queries => core/src/main/java/org/elasticsearch/xpack/core/ml/search}/TextExpansionQueryBuilder.java (98%) rename x-pack/plugin/{inference/src/main/java/org/elasticsearch/xpack/inference/queries => core/src/main/java/org/elasticsearch/xpack/core/ml/search}/TokenPruningConfig.java (98%) rename x-pack/plugin/{inference/src/main/java/org/elasticsearch/xpack/inference/queries => core/src/main/java/org/elasticsearch/xpack/core/ml/search}/WeightedTokensQueryBuilder.java (98%) rename x-pack/plugin/{inference/src/main/java/org/elasticsearch/xpack/inference/queries => core/src/main/java/org/elasticsearch/xpack/core/ml/search}/WeightedTokensUtils.java (97%) create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java create mode 100644 
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilderTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TokenPruningConfigTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/AbstractSemanticTextFieldType.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextField.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldMapper.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextUtils.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldMapperTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextUtilsTests.java diff --git a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java index 0749512635f83..6662eba8189fb 100644 --- a/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java +++ b/server/src/main/java/org/elasticsearch/action/update/TransportUpdateAction.java @@ -412,7 +412,7 @@ private static UpdateHelper.Result deleteInferenceResults( // This has two important side effects: // - The inference field value will remain parsable by its mapper // - The inference results will be removed, forcing them to be re-generated downstream - updatedSource.put(inferenceFieldName, inferenceFieldMapper.getOriginalValue(updatedSource)); + updatedSource.put(inferenceFieldName, getOriginalValueLegacy(inferenceFieldName, updatedSource)); updatedSourceModified = true; break; } @@ -435,4 +435,23 @@ private static UpdateHelper.Result deleteInferenceResults( return returnedResult; } + + /** + * Get the field's original value (i.e. the value the user specified) from the provided source. 
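+ * For example, with a legacy semantic_text value indexed as {@code {"field": {"text": "the user text", "inference": {...}}}} this returns {@code "the user text"}, while a plain {@code {"field": "the user text"}} value is returned unchanged (illustrative values; the layout is inferred from the {@code "text"} extraction below).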
+ * + * @param fullPath The full path of the inference field in the source + * @param sourceAsMap The source as a map + * @return The field's original value, or {@code null} if none was provided + */ + private static Object getOriginalValueLegacy(String fullPath, Map<String, Object> sourceAsMap) { + Object fieldValue = sourceAsMap.get(fullPath); + if (fieldValue == null) { + return null; + } else if (fieldValue instanceof Map == false) { + // Don't try to further validate the non-map value, that will be handled when the source is fully parsed + return fieldValue; + } + + Map<String, Object> fieldValueMap = XContentMapValues.nodeMapValue(fieldValue, "Field [" + fullPath + "]"); + return XContentMapValues.extractValue("text", fieldValueMap); + } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java index d8ff772ce4dcf..9f4dbabf4e700 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java @@ -66,8 +66,8 @@ public boolean isWithinInferenceMetadata() { } @Override - public void markInferenceMetadata() { - in.markInferenceMetadata(); + public void markInferenceMetadataField() { + in.markInferenceMetadataField(); } @Override @@ -155,7 +155,7 @@ private enum Scope { // Indicates if the source for this context has been marked to be recorded. Applies to synthetic source only. private boolean recordedSource; - private boolean inferenceMetadata; + private boolean hasInferenceMetadata; private DocumentParserContext( MappingLookup mappingLookup, @@ -349,12 +349,19 @@ public final DocumentParserContext addIgnoredFieldFromContext(IgnoredSourceField return this; } - public void markInferenceMetadata() { - this.inferenceMetadata = true; + /** + * Called by {@link InferenceMetadataFieldsMapper} to indicate whether the metadata field is present + * in _source. + */ + public void markInferenceMetadataField() { + this.hasInferenceMetadata = true; } - public final boolean hasInferenceMetadata() { - return false;// TODO: inferenceMetadata; + /** + * Returns whether the _source contains an inference metadata field. + */ + public final boolean hasInferenceMetadataField() { + return hasInferenceMetadata; } /** diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java index 249ef5004e59c..f7c6eef7dfd49 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceFieldMapper.java @@ -12,7 +12,6 @@ import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.inference.InferenceService; -import java.util.Map; import java.util.Set; /** @@ -26,12 +25,4 @@ public interface InferenceFieldMapper { * @param sourcePaths The source paths that populate the input for the field (before inference) */ InferenceFieldMetadata getMetadata(Set<String> sourcePaths); - - /** - * Get the field's original value (i.e. the value the user specified) from the provided source. 
- * - * @param sourceAsMap The source as a map - * @return The field's original value, or {@code null} if none was provided - */ - Object getOriginalValue(Map<String, Object> sourceAsMap); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/InferenceMetadataFieldsMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/InferenceMetadataFieldsMapper.java index 1550f6298135b..b713ecada5237 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/InferenceMetadataFieldsMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/InferenceMetadataFieldsMapper.java @@ -71,10 +71,11 @@ protected boolean supportsParsingObject() { protected void parseCreateField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - context.markInferenceMetadata(); + context.markInferenceMetadataField(); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.currentToken(), parser); String fieldName = parser.currentName(); + // TODO: Find the leaf field under objects Mapper mapper = context.mappingLookup().getMapper(fieldName); if (mapper != null && mapper instanceof InferenceFieldMapper && mapper instanceof FieldMapper fieldMapper) { fieldMapper.parseCreateField(new DocumentParserContext.Wrapper(context.parent(), context) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java index 709cfd562af71..b0bac9eaaefaf 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java @@ -428,7 +428,7 @@ public BytesReference applyFilters( return null; } var modSourceFilter = sourceFilter; - if (context != null && context.hasInferenceMetadata()) { + if (context != null && context.hasInferenceMetadataField()) { String[] modExcludes = new String[excludes != null ? 
excludes.length + 1 : 1]; if (excludes != null) { System.arraycopy(excludes, 0, modExcludes, 0, excludes.length); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java index 29bb3b15a9f86..02aa483c4e214 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperTestCase.java @@ -1166,8 +1166,10 @@ public void testSupportsParsingObject() throws IOException { Object sampleValueForDocument = getSampleObjectForDocument(); assertThat(sampleValueForDocument, instanceOf(Map.class)); SourceToParse source = source(builder -> { + builder.startObject(InferenceMetadataFieldsMapper.NAME); builder.field("field"); builder.value(sampleValueForDocument); + builder.endObject(); }); ParsedDocument doc = mapper.parse(source); assertNotNull(doc); diff --git a/x-pack/plugin/core/build.gradle b/x-pack/plugin/core/build.gradle index f0eea64f9be6e..5d2c8c9811c0a 100644 --- a/x-pack/plugin/core/build.gradle +++ b/x-pack/plugin/core/build.gradle @@ -68,6 +68,7 @@ dependencies { testImplementation project(path: ':modules:analysis-common') testImplementation project(path: ':modules:rest-root') testImplementation project(path: ':modules:health-shards-availability') + testImplementation project(path: ':modules:mapper-extras') // Needed for Fips140ProviderVerificationTests testCompileOnly('org.bouncycastle:bc-fips:1.0.2.5') diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SparseVectorQuery.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQuery.java similarity index 97% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SparseVectorQuery.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQuery.java index 32e4623454f17..2cb4d6777dcb9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SparseVectorQuery.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQuery.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.IndexSearcher; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SparseVectorQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java similarity index 99% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SparseVectorQueryBuilder.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java index 752009b7b910a..e9e4e90421adc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SparseVectorQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; @@ -33,7 +33,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java similarity index 98% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/TextExpansionQueryBuilder.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java index be435bd18b55c..81758ec5f9342 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilder.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.search.Query; import org.apache.lucene.util.SetOnce; @@ -39,7 +39,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.inference.queries.WeightedTokensQueryBuilder.PRUNING_CONFIG; +import static org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder.PRUNING_CONFIG; /** * @deprecated Replaced by sparse_vector query diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/TokenPruningConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TokenPruningConfig.java similarity index 98% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/TokenPruningConfig.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TokenPruningConfig.java index 6f5c2995af8b8..13358839830ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/TokenPruningConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/TokenPruningConfig.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.queries; +package org.elasticsearch.xpack.core.ml.search; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.io.stream.StreamInput; @@ -22,7 +22,7 @@ import java.util.Objects; import java.util.Set; -import static org.elasticsearch.xpack.inference.queries.WeightedTokensQueryBuilder.PRUNING_CONFIG; +import static org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder.PRUNING_CONFIG; public class TokenPruningConfig implements Writeable, ToXContentObject { public static final ParseField TOKENS_FREQ_RATIO_THRESHOLD = new ParseField("tokens_freq_ratio_threshold"); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java similarity index 98% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/WeightedTokensQueryBuilder.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java index 8246b8a399310..f41fcd77ce627 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilder.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; @@ -24,7 +24,6 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/WeightedTokensUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java similarity index 97% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/WeightedTokensUtils.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java index 0d3e628cea07a..0fcd07ed8ce08 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/WeightedTokensUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensUtils.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.queries; +package org.elasticsearch.xpack.core.ml.search; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; @@ -16,7 +16,6 @@ import org.apache.lucene.search.Query; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java new file mode 100644 index 0000000000000..f88e7467c29b2 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java @@ -0,0 +1,331 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.search; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.document.FloatDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.AbstractQueryTestCase; +import org.elasticsearch.test.index.IndexVersionUtils; +import org.elasticsearch.xpack.core.XPackClientPlugin; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.either; +import static org.hamcrest.Matchers.hasSize; + +public class SparseVectorQueryBuilderTests extends AbstractQueryTestCase<SparseVectorQueryBuilder> { + + private static final String SPARSE_VECTOR_FIELD = "mySparseVectorField"; + private static final List<WeightedToken> WEIGHTED_TOKENS = List.of(new WeightedToken("foo", .42f)); + private static final int NUM_TOKENS = WEIGHTED_TOKENS.size(); + + @Override + protected SparseVectorQueryBuilder doCreateTestQueryBuilder() { + TokenPruningConfig tokenPruningConfig = randomBoolean() + ? new TokenPruningConfig(randomIntBetween(1, 100), randomFloat(), randomBoolean()) + : null; + return createTestQueryBuilder(tokenPruningConfig); + } + + private SparseVectorQueryBuilder createTestQueryBuilder(TokenPruningConfig tokenPruningConfig) { + SparseVectorQueryBuilder builder; + if (randomBoolean()) { + builder = new SparseVectorQueryBuilder( + SPARSE_VECTOR_FIELD, + null, + randomAlphaOfLength(10), + randomAlphaOfLengthBetween(10, 25), + tokenPruningConfig != null, + tokenPruningConfig + ); + } else { + builder = new SparseVectorQueryBuilder( + SPARSE_VECTOR_FIELD, + WEIGHTED_TOKENS, + null, + null, + tokenPruningConfig != null, + tokenPruningConfig + ); + } + + if (randomBoolean()) { + builder.boost((float) randomDoubleBetween(0.1, 10.0, true)); + } + if (randomBoolean()) { + builder.queryName(randomAlphaOfLength(4)); + } + return builder; + } + + @Override + protected Collection<Class<? extends Plugin>> getPlugins() { + return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class); + } + + @Override + protected Settings createTestIndexSettings() { + // The sparse_vector field is not supported on versions 8.0 to 8.10. Because of this we'll only allow + // index versions after its reintroduction. + final IndexVersion indexVersionCreated = randomBoolean() + ? IndexVersion.current() + : IndexVersionUtils.randomVersionBetween(random(), IndexVersions.NEW_SPARSE_VECTOR, IndexVersion.current()); + return Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, indexVersionCreated).build(); + } + + @Override + protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchMethodException { + return method.equals(Client.class.getMethod("execute", ActionType.class, ActionRequest.class, ActionListener.class)) + && (args[0] instanceof CoordinatedInferenceAction); + } + + @Override + protected Object simulateMethod(Method method, Object[] args) { + CoordinatedInferenceAction.Request request = (CoordinatedInferenceAction.Request) args[1]; + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout()); + assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType()); + assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, request.getRequestModelType()); + + // Randomisation cannot be used here as {@code #doAssertLuceneQuery} + // asserts that 2 rewritten queries are the same + var tokens = new ArrayList<WeightedToken>(); + for (int i = 0; i < NUM_TOKENS; i++) { + tokens.add(new WeightedToken(Integer.toString(i), (i + 1) * 1.0f)); + } + + var response = InferModelAction.Response.builder() + .setId(request.getModelId()) + .addInferenceResults(List.of(new TextExpansionResults("foo", tokens, randomBoolean()))) + .build(); + @SuppressWarnings("unchecked") // We matched the method above. 
+ ActionListener<InferModelAction.Response> listener = (ActionListener<InferModelAction.Response>) args[2]; + listener.onResponse(response); + return null; + } + + @Override + protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { + mapperService.merge( + "_doc", + new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(SPARSE_VECTOR_FIELD, "type=sparse_vector"))), + MapperService.MergeReason.MAPPING_UPDATE + ); + } + + @Override + protected void doAssertLuceneQuery(SparseVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) { + assertThat(query, instanceOf(BooleanQuery.class)); + BooleanQuery booleanQuery = (BooleanQuery) query; + assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1); + assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS)); + + Class<?> featureQueryClass = FeatureField.newLinearQuery("", "", 0.5f).getClass(); + // if the weight is 1.0f a BoostQuery is returned + Class<?> boostQueryClass = FeatureField.newLinearQuery("", "", 1.0f).getClass(); + + for (var clause : booleanQuery.clauses()) { + assertEquals(BooleanClause.Occur.SHOULD, clause.occur()); + assertThat(clause.query(), either(instanceOf(featureQueryClass)).or(instanceOf(boostQueryClass))); + } + } + + /** + * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader} + */ + @Override + public void testCacheability() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + Document document = new Document(); + document.add(new FloatDocValuesField(SPARSE_VECTOR_FIELD, 1.0f)); + iw.addDocument(document); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder(); + QueryBuilder rewriteQuery = rewriteQuery(queryBuilder, new SearchExecutionContext(context)); + + assertNotNull(rewriteQuery.toQuery(context)); + assertTrue("query should be cacheable: " + queryBuilder.toString(), context.isCacheable()); + } + } + } + + /** + * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader}; this query should always be rewritten + */ + @Override + public void testMustRewrite() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + Document document = new Document(); + document.add(new FloatDocValuesField(SPARSE_VECTOR_FIELD, 1.0f)); + iw.addDocument(document); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder(); + queryBuilder.toQuery(context); + } + } + } + + /** + * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader}; this query should always be rewritten + */ + @Override + public void testToQuery() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + Document document = new Document(); + document.add(new FloatDocValuesField(SPARSE_VECTOR_FIELD, 1.0f)); + iw.addDocument(document); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder();
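+ // When no literal query vectors are supplied, the builder fetches them from the inference service during rewrite, so the query must go through rewriteAndFetch before toQuery can succeed.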
+ if (queryBuilder.getQueryVectors() == null) { + QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, context); + assertTrue(rewrittenQueryBuilder instanceof SparseVectorQueryBuilder); + testDoToQuery((SparseVectorQueryBuilder) rewrittenQueryBuilder, context); + } else { + testDoToQuery(queryBuilder, context); + } + } + } + } + + private void testDoToQuery(SparseVectorQueryBuilder queryBuilder, SearchExecutionContext context) throws IOException { + Query query = queryBuilder.doToQuery(context); + if (queryBuilder.shouldPruneTokens()) { + // It's possible that all documents were pruned for aggressive pruning configurations + assertTrue(query instanceof BooleanQuery || query instanceof MatchNoDocsQuery); + } else { + assertTrue(query instanceof SparseVectorQuery); + } + } + + public void testIllegalValues() { + { + // This will be caught and returned in the API as an IllegalArgumentException + NullPointerException e = expectThrows( + NullPointerException.class, + () -> new SparseVectorQueryBuilder(null, "model text", "model id") + ); + assertEquals("[sparse_vector] requires a [field]", e.getMessage()); + } + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new SparseVectorQueryBuilder("field name", null, "model id") + ); + assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage()); + } + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new SparseVectorQueryBuilder("field name", "model text", null) + ); + assertEquals("[sparse_vector] requires [query] when [inference_id] is specified", e.getMessage()); + } + } + + public void testToXContent() throws IOException { + QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", "baz"); + checkGeneratedJson(""" + { + "sparse_vector": { + "field": "foo", + "inference_id": "bar", + "query": "baz", + "prune": false + } + }""", query); + } + + public void testToXContentWithThresholds() throws IOException { + QueryBuilder query = new SparseVectorQueryBuilder("foo", null, "bar", "baz", true, new TokenPruningConfig(4, 0.3f, false)); + checkGeneratedJson(""" + { + "sparse_vector": { + "field": "foo", + "inference_id": "bar", + "query": "baz", + "prune": true, + "pruning_config": { + "tokens_freq_ratio_threshold": 4.0, + "tokens_weight_threshold": 0.3 + } + } + }""", query); + } + + public void testToXContentWithThresholdsAndOnlyScorePrunedTokens() throws IOException { + QueryBuilder query = new SparseVectorQueryBuilder("foo", null, "bar", "baz", true, new TokenPruningConfig(4, 0.3f, true)); + + checkGeneratedJson(""" + { + "sparse_vector": { + "field": "foo", + "inference_id": "bar", + "query": "baz", + "prune": true, + "pruning_config": { + "tokens_freq_ratio_threshold": 4.0, + "tokens_weight_threshold": 0.3, + "only_score_pruned_tokens": true + } + } + }""", query); + } + + @Override + protected String[] shuffleProtectedFields() { + return new String[] { QUERY_VECTOR_FIELD.getPreferredName() }; + } + + public void testThatWeCorrectlyRewriteQueryIntoVectors() { + SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); + + TokenPruningConfig tokenPruningConfig = randomBoolean() ? 
new TokenPruningConfig(2, 0.3f, false) : null; + + SparseVectorQueryBuilder queryBuilder = createTestQueryBuilder(tokenPruningConfig); + QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext); + assertTrue(rewrittenQueryBuilder instanceof SparseVectorQueryBuilder); + assertEquals(queryBuilder.shouldPruneTokens(), ((SparseVectorQueryBuilder) rewrittenQueryBuilder).shouldPruneTokens()); + assertNotNull(((SparseVectorQueryBuilder) rewrittenQueryBuilder).getQueryVectors()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilderTests.java new file mode 100644 index 0000000000000..b99158bba788b --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TextExpansionQueryBuilderTests.java @@ -0,0 +1,292 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.search; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.document.FloatDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.AbstractQueryTestCase; +import org.elasticsearch.xpack.core.XPackClientPlugin; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.either; +import static org.hamcrest.Matchers.hasSize; + +public class TextExpansionQueryBuilderTests extends AbstractQueryTestCase<TextExpansionQueryBuilder> { + + private static final String RANK_FEATURES_FIELD = "rank"; + private static final int NUM_TOKENS = 10; + + @Override + protected TextExpansionQueryBuilder doCreateTestQueryBuilder() { + TokenPruningConfig tokenPruningConfig = randomBoolean() + ? 
new TokenPruningConfig(randomIntBetween(1, 100), randomFloat(), randomBoolean()) + : null; + var builder = new TextExpansionQueryBuilder( + RANK_FEATURES_FIELD, + randomAlphaOfLength(4), + randomAlphaOfLength(4), + tokenPruningConfig + ); + if (randomBoolean()) { + builder.boost((float) randomDoubleBetween(0.1, 10.0, true)); + } + if (randomBoolean()) { + builder.queryName(randomAlphaOfLength(4)); + } + return builder; + } + + @Override + protected Collection<Class<? extends Plugin>> getPlugins() { + return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class); + } + + @Override + public void testMustRewrite() { + SearchExecutionContext context = createSearchExecutionContext(); + TextExpansionQueryBuilder builder = new TextExpansionQueryBuilder("foo", "bar", "baz"); + IllegalStateException e = expectThrows(IllegalStateException.class, () -> builder.toQuery(context)); + assertEquals("text_expansion should have been rewritten to another query type", e.getMessage()); + } + + @Override + protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchMethodException { + return method.equals(Client.class.getMethod("execute", ActionType.class, ActionRequest.class, ActionListener.class)) + && (args[0] instanceof CoordinatedInferenceAction); + } + + @Override + protected Object simulateMethod(Method method, Object[] args) { + CoordinatedInferenceAction.Request request = (CoordinatedInferenceAction.Request) args[1]; + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout()); + assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType()); + assertEquals(CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL, request.getRequestModelType()); + + // Randomisation cannot be used here as {@code #doAssertLuceneQuery} + // asserts that 2 rewritten queries are the same + var tokens = new ArrayList<WeightedToken>(); + for (int i = 0; i < NUM_TOKENS; i++) { + tokens.add(new WeightedToken(Integer.toString(i), (i + 1) * 1.0f)); + } + + var response = InferModelAction.Response.builder() + .setId(request.getModelId()) + .addInferenceResults(List.of(new TextExpansionResults("foo", tokens, randomBoolean()))) + .build(); + @SuppressWarnings("unchecked") // We matched the method above. 
+ ActionListener<InferModelAction.Response> listener = (ActionListener<InferModelAction.Response>) args[2]; + listener.onResponse(response); + return null; + } + + @Override + protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { + mapperService.merge( + "_doc", + new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(RANK_FEATURES_FIELD, "type=rank_features"))), + MapperService.MergeReason.MAPPING_UPDATE + ); + } + + @Override + protected void doAssertLuceneQuery(TextExpansionQueryBuilder queryBuilder, Query query, SearchExecutionContext context) { + assertThat(query, instanceOf(BooleanQuery.class)); + BooleanQuery booleanQuery = (BooleanQuery) query; + assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1); + assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS)); + + Class<?> featureQueryClass = FeatureField.newLinearQuery("", "", 0.5f).getClass(); + // if the weight is 1.0f a BoostQuery is returned + Class<?> boostQueryClass = FeatureField.newLinearQuery("", "", 1.0f).getClass(); + + for (var clause : booleanQuery.clauses()) { + assertEquals(BooleanClause.Occur.SHOULD, clause.occur()); + assertThat(clause.query(), either(instanceOf(featureQueryClass)).or(instanceOf(boostQueryClass))); + } + } + + /** + * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader} + */ + @Override + public void testCacheability() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + Document document = new Document(); + document.add(new FloatDocValuesField(RANK_FEATURES_FIELD, 1.0f)); + iw.addDocument(document); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder(); + QueryBuilder rewriteQuery = rewriteQuery(queryBuilder, new SearchExecutionContext(context)); + + assertNotNull(rewriteQuery.toQuery(context)); + assertTrue("query should be cacheable: " + queryBuilder.toString(), context.isCacheable()); + } + } + } + + /** + * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader}; this query should always be rewritten + */ + @Override + public void testToQuery() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + Document document = new Document(); + document.add(new FloatDocValuesField(RANK_FEATURES_FIELD, 1.0f)); + iw.addDocument(document); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder(); + IllegalStateException e = expectThrows(IllegalStateException.class, () -> queryBuilder.toQuery(context)); + assertEquals("text_expansion should have been rewritten to another query type", e.getMessage()); + } + } + } + + @Override + public void testFromXContent() throws IOException { + super.testFromXContent(); + assertCriticalWarnings(TextExpansionQueryBuilder.TEXT_EXPANSION_DEPRECATION_MESSAGE); + } + + @Override + public void testUnknownField() throws IOException { + super.testUnknownField(); + assertCriticalWarnings(TextExpansionQueryBuilder.TEXT_EXPANSION_DEPRECATION_MESSAGE); + } + + @Override + public void testUnknownObjectException() throws IOException { + super.testUnknownObjectException(); + assertCriticalWarnings(TextExpansionQueryBuilder.TEXT_EXPANSION_DEPRECATION_MESSAGE); + } 
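+ // The XContent test overrides above and below re-assert TEXT_EXPANSION_DEPRECATION_MESSAGE because parsing a text_expansion query always emits a critical deprecation warning, and the test framework fails tests that leave emitted warnings unasserted.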
+ + @Override + public void testValidOutput() throws IOException { + super.testValidOutput(); + assertCriticalWarnings(TextExpansionQueryBuilder.TEXT_EXPANSION_DEPRECATION_MESSAGE); + } + + public void testIllegalValues() { + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new TextExpansionQueryBuilder(null, "model text", "model id") + ); + assertEquals("[text_expansion] requires a fieldName", e.getMessage()); + } + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new TextExpansionQueryBuilder("field name", null, "model id") + ); + assertEquals("[text_expansion] requires a model_text value", e.getMessage()); + } + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new TextExpansionQueryBuilder("field name", "model text", null) + ); + assertEquals("[text_expansion] requires a model_id value", e.getMessage()); + } + } + + public void testToXContent() throws IOException { + QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz"); + checkGeneratedJson(""" + { + "text_expansion": { + "foo": { + "model_text": "bar", + "model_id": "baz" + } + } + }""", query); + } + + public void testToXContentWithThresholds() throws IOException { + QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new TokenPruningConfig(4, 0.3f, false)); + checkGeneratedJson(""" + { + "text_expansion": { + "foo": { + "model_text": "bar", + "model_id": "baz", + "pruning_config": { + "tokens_freq_ratio_threshold": 4.0, + "tokens_weight_threshold": 0.3 + } + } + } + }""", query); + } + + public void testToXContentWithThresholdsAndOnlyScorePrunedTokens() throws IOException { + QueryBuilder query = new TextExpansionQueryBuilder("foo", "bar", "baz", new TokenPruningConfig(4, 0.3f, true)); + checkGeneratedJson(""" + { + "text_expansion": { + "foo": { + "model_text": "bar", + "model_id": "baz", + "pruning_config": { + "tokens_freq_ratio_threshold": 4.0, + "tokens_weight_threshold": 0.3, + "only_score_pruned_tokens": true + } + } + } + }""", query); + } + + @Override + protected String[] shuffleProtectedFields() { + return new String[] { WeightedTokensQueryBuilder.TOKENS_FIELD.getPreferredName() }; + } + + public void testThatTokensAreCorrectlyPruned() { + SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); + TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder(); + QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext); + if (queryBuilder.getTokenPruningConfig() == null) { + assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder); + } else { + assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TokenPruningConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TokenPruningConfigTests.java new file mode 100644 index 0000000000000..8cdf44ae51dd4 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/TokenPruningConfigTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.search; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractXContentSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +public class TokenPruningConfigTests extends AbstractXContentSerializingTestCase<TokenPruningConfig> { + + public static TokenPruningConfig testInstance() { + return new TokenPruningConfig(randomIntBetween(1, 100), randomFloat(), randomBoolean()); + } + + @Override + protected Writeable.Reader<TokenPruningConfig> instanceReader() { + return TokenPruningConfig::new; + } + + @Override + protected TokenPruningConfig createTestInstance() { + return testInstance(); + } + + @Override + protected TokenPruningConfig mutateInstance(TokenPruningConfig instance) throws IOException { + return null; + } + + @Override + protected TokenPruningConfig doParseInstance(XContentParser parser) throws IOException { + return TokenPruningConfig.fromXContent(parser); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java new file mode 100644 index 0000000000000..7372def52355e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/WeightedTokensQueryBuilderTests.java @@ -0,0 +1,466 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.search; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.document.FloatDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.AbstractQueryTestCase; +import org.elasticsearch.xpack.core.XPackClientPlugin; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.util.Collection; +import java.util.List; + +import static org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder.TOKENS_FIELD; +import static org.hamcrest.CoreMatchers.equalTo; +import static 
org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.either; +import static org.hamcrest.Matchers.hasSize; + +public class WeightedTokensQueryBuilderTests extends AbstractQueryTestCase<WeightedTokensQueryBuilder> { + + private static final String RANK_FEATURES_FIELD = "rank"; + private static final List<WeightedToken> WEIGHTED_TOKENS = List.of(new WeightedToken("foo", .42f)); + private static final int NUM_TOKENS = WEIGHTED_TOKENS.size(); + + @Override + protected WeightedTokensQueryBuilder doCreateTestQueryBuilder() { + return createTestQueryBuilder(randomBoolean()); + } + + private WeightedTokensQueryBuilder createTestQueryBuilder(boolean onlyScorePrunedTokens) { + TokenPruningConfig tokenPruningConfig = randomBoolean() + ? new TokenPruningConfig(randomIntBetween(1, 100), randomFloat(), onlyScorePrunedTokens) + : null; + + var builder = new WeightedTokensQueryBuilder(RANK_FEATURES_FIELD, WEIGHTED_TOKENS, tokenPruningConfig); + if (randomBoolean()) { + builder.boost((float) randomDoubleBetween(0.1, 10.0, true)); + } + if (randomBoolean()) { + builder.queryName(randomAlphaOfLength(4)); + } + return builder; + } + + @Override + protected Collection<Class<? extends Plugin>> getPlugins() { + return List.of(MapperExtrasPlugin.class, XPackClientPlugin.class); + } + + @Override + protected boolean canSimulateMethod(Method method, Object[] args) throws NoSuchMethodException { + return method.equals(Client.class.getMethod("execute", ActionType.class, ActionRequest.class, ActionListener.class)) + && (args[0] instanceof InferModelAction); + } + + @Override + protected Object simulateMethod(Method method, Object[] args) { + InferModelAction.Request request = (InferModelAction.Request) args[1]; + assertEquals(InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API, request.getInferenceTimeout()); + assertEquals(TrainedModelPrefixStrings.PrefixType.SEARCH, request.getPrefixType()); + + // Randomisation of tokens cannot be used here as {@code #doAssertLuceneQuery} + // asserts that 2 rewritten queries are the same + var response = InferModelAction.Response.builder() + .setId(request.getId()) + .addInferenceResults(List.of(new TextExpansionResults("foo", WEIGHTED_TOKENS.stream().toList(), randomBoolean()))) + .build(); + @SuppressWarnings("unchecked") // We matched the method above. 
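+ // Unlike the sparse_vector and text_expansion tests, this test simulates InferModelAction directly (matched in canSimulateMethod), so the unchecked cast below to ActionListener<InferModelAction.Response> is safe.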
+ ActionListener<InferModelAction.Response> listener = (ActionListener<InferModelAction.Response>) args[2]; + listener.onResponse(response); + return null; + } + + @Override + protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { + mapperService.merge( + "_doc", + new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(RANK_FEATURES_FIELD, "type=rank_features"))), + MapperService.MergeReason.MAPPING_UPDATE + ); + } + + /** + * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader} + */ + @Override + public void testToQuery() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + // Index at least one document so we have a freq > 0 + Document document = new Document(); + document.add(new FeatureField(RANK_FEATURES_FIELD, "foo", 1.0f)); + iw.addDocument(document); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + // We need to force token pruning config here, to get repeatable lucene queries for comparison + WeightedTokensQueryBuilder firstQuery = createTestQueryBuilder(false); + WeightedTokensQueryBuilder controlQuery = copyQuery(firstQuery); + QueryBuilder rewritten = rewriteQuery(firstQuery, context); + Query firstLuceneQuery = rewritten.toQuery(context); + assertNotNull("toQuery should not return null", firstLuceneQuery); + assertLuceneQuery(firstQuery, firstLuceneQuery, context); + assertEquals( + "query is not equal to its copy after calling toQuery, firstQuery: " + firstQuery + ", secondQuery: " + controlQuery, + firstQuery, + controlQuery + ); + assertEquals( + "equals is not symmetric after calling toQuery, firstQuery: " + firstQuery + ", secondQuery: " + controlQuery, + controlQuery, + firstQuery + ); + assertThat( + "query copy's hashcode is different from original hashcode after calling toQuery, firstQuery: " + + firstQuery + + ", secondQuery: " + + controlQuery, + controlQuery.hashCode(), + equalTo(firstQuery.hashCode()) + ); + WeightedTokensQueryBuilder secondQuery = copyQuery(firstQuery); + + // query _name should never affect the result of toQuery, we randomly set it to make sure + if (randomBoolean()) { + secondQuery.queryName( + secondQuery.queryName() == null + ? 
randomAlphaOfLengthBetween(1, 30) + : secondQuery.queryName() + randomAlphaOfLengthBetween(1, 10) + ); + } + context = new SearchExecutionContext(context); + Query secondLuceneQuery = rewriteQuery(secondQuery, context).toQuery(context); + assertNotNull("toQuery should not return null", secondLuceneQuery); + assertLuceneQuery(secondQuery, secondLuceneQuery, context); + + if (builderGeneratesCacheableQueries()) { + assertEquals( + "two equivalent query builders lead to different lucene queries hashcode", + secondLuceneQuery.hashCode(), + firstLuceneQuery.hashCode() + ); + assertEquals( + "two equivalent query builders lead to different lucene queries", + rewrite(secondLuceneQuery), + rewrite(firstLuceneQuery) + ); + } + + if (supportsBoost() && firstLuceneQuery instanceof MatchNoDocsQuery == false) { + secondQuery.boost(firstQuery.boost() + 1f + randomFloat()); + Query thirdLuceneQuery = rewriteQuery(secondQuery, context).toQuery(context); + assertNotEquals( + "modifying the boost doesn't affect the corresponding lucene query", + rewrite(firstLuceneQuery), + rewrite(thirdLuceneQuery) + ); + } + + } + } + } + + @Override + public void testFromXContent() throws IOException { + super.testFromXContent(); + assertCriticalWarnings(WeightedTokensQueryBuilder.WEIGHTED_TOKENS_DEPRECATION_MESSAGE); + } + + @Override + public void testUnknownField() throws IOException { + super.testUnknownField(); + assertCriticalWarnings(WeightedTokensQueryBuilder.WEIGHTED_TOKENS_DEPRECATION_MESSAGE); + } + + @Override + public void testUnknownObjectException() throws IOException { + super.testUnknownObjectException(); + assertCriticalWarnings(WeightedTokensQueryBuilder.WEIGHTED_TOKENS_DEPRECATION_MESSAGE); + } + + @Override + public void testValidOutput() throws IOException { + super.testValidOutput(); + assertCriticalWarnings(WeightedTokensQueryBuilder.WEIGHTED_TOKENS_DEPRECATION_MESSAGE); + } + + public void testPruningIsAppliedCorrectly() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + List<Document> documents = List.of( + createDocument( + List.of("the", "quick", "brown", "fox", "jumped", "over", "lazy", "dog", "me"), + List.of(.2f, 1.8f, 1.75f, 5.9f, 1.6f, 1.4f, .4f, 4.8f, 2.1f) + ), + createDocument( + List.of("the", "rains", "in", "spain", "fall", "mainly", "on", "plain", "me"), + List.of(.1f, 3.6f, .1f, 4.8f, .6f, .3f, .1f, 2.6f, 2.1f) + ), + createDocument( + List.of("betty", "bought", "butter", "but", "the", "was", "bitter", "me"), + List.of(6.8f, 1.4f, .5f, 3.2f, .1f, 3.2f, .6f, 2.1f) + ), + createDocument( + List.of("she", "sells", "seashells", "by", "the", "seashore", "me"), + List.of(.2f, 1.4f, 5.9f, .1f, .1f, 3.6f, 2.1f) + ) + ); + iw.addDocuments(documents); + + List<WeightedToken> inputTokens = List.of( + new WeightedToken("the", .1f), // Will be pruned - score too low, freq too high + new WeightedToken("black", 5.3f), // Will be pruned - does not exist in index + new WeightedToken("dog", 7.5f), // Will be kept - high score and low freq + new WeightedToken("jumped", 4.5f), // Will be kept - high score and low freq + new WeightedToken("on", .1f), // Will be kept - low score but also low freq + new WeightedToken("me", 3.8f) // Will be kept - high freq but also high score + );
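+ // With TokenPruningConfig(2, 0.5f, ...) below, tokens that are comparatively frequent in the field but carry a low query weight are pruned, as is "black", which never appears in the index; with only_score_pruned_tokens=true the query scores only the pruned set, as the last assertion shows.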
noPruningQuery.doToQuery(context); + assertCorrectLuceneQuery("noPruningQuery", query, List.of("the", "black", "dog", "jumped", "on", "me")); + + WeightedTokensQueryBuilder queryThatShouldBePruned = new WeightedTokensQueryBuilder( + RANK_FEATURES_FIELD, + inputTokens, + new TokenPruningConfig(2, 0.5f, false) + ); + query = queryThatShouldBePruned.doToQuery(context); + assertCorrectLuceneQuery("queryThatShouldBePruned", query, List.of("dog", "jumped", "on", "me")); + + WeightedTokensQueryBuilder onlyScorePrunedTokensQuery = new WeightedTokensQueryBuilder( + RANK_FEATURES_FIELD, + inputTokens, + new TokenPruningConfig(2, 0.5f, true) + ); + query = onlyScorePrunedTokensQuery.doToQuery(context); + assertCorrectLuceneQuery("onlyScorePrunedTokensQuery", query, List.of("the", "black")); + } + } + } + + private void assertCorrectLuceneQuery(String name, Query query, List expectedFeatureFields) { + assertTrue(query instanceof SparseVectorQuery); + Query termsQuery = ((SparseVectorQuery) query).getTermsQuery(); + assertTrue(termsQuery instanceof BooleanQuery); + List booleanClauses = ((BooleanQuery) termsQuery).clauses(); + assertEquals( + name + " had " + booleanClauses.size() + " clauses, expected " + expectedFeatureFields.size(), + expectedFeatureFields.size(), + booleanClauses.size() + ); + for (int i = 0; i < booleanClauses.size(); i++) { + Query clauseQuery = booleanClauses.get(i).query(); + assertTrue(name + " query " + query + " expected to be a BoostQuery", clauseQuery instanceof BoostQuery); + // FeatureQuery is not visible so we check the String representation + assertTrue(name + " query " + query + " expected to be a FeatureQuery", clauseQuery.toString().contains("FeatureQuery")); + assertTrue( + name + " query " + query + " expected to have field " + expectedFeatureFields.get(i), + clauseQuery.toString().contains("feature=" + expectedFeatureFields.get(i)) + ); + } + } + + private Document createDocument(List tokens, List weights) { + if (tokens.size() != weights.size()) { + throw new IllegalArgumentException( + "tokens and weights must have the same size. Got " + tokens.size() + " and " + weights.size() + "." 
+ ); + } + Document document = new Document(); + for (int i = 0; i < tokens.size(); i++) { + document.add(new FeatureField(RANK_FEATURES_FIELD, tokens.get(i), weights.get(i))); + } + return document; + } + + /** + * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader} + */ + @Override + public void testCacheability() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + Document document = new Document(); + document.add(new FloatDocValuesField(RANK_FEATURES_FIELD, 1.0f)); + iw.addDocument(document); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + WeightedTokensQueryBuilder queryBuilder = createTestQueryBuilder(); + QueryBuilder rewriteQuery = rewriteQuery(queryBuilder, new SearchExecutionContext(context)); + + assertNotNull(rewriteQuery.toQuery(context)); + assertTrue("query should be cacheable: " + queryBuilder.toString(), context.isCacheable()); + } + } + } + + /** + * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader} + */ + @Override + public void testMustRewrite() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + Document document = new Document(); + document.add(new FloatDocValuesField(RANK_FEATURES_FIELD, 1.0f)); + iw.addDocument(document); + try (IndexReader reader = iw.getReader()) { + SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); + context.setAllowUnmappedFields(true); + WeightedTokensQueryBuilder queryBuilder = createTestQueryBuilder(); + queryBuilder.toQuery(context); + } + } + } + + @Override + protected void doAssertLuceneQuery(WeightedTokensQueryBuilder queryBuilder, Query query, SearchExecutionContext context) { + assertThat(query, instanceOf(SparseVectorQuery.class)); + Query termsQuery = ((SparseVectorQuery) query).getTermsQuery(); + assertThat(termsQuery, instanceOf(BooleanQuery.class)); + BooleanQuery booleanQuery = (BooleanQuery) termsQuery; + assertEquals(booleanQuery.getMinimumNumberShouldMatch(), 1); + assertThat(booleanQuery.clauses(), hasSize(NUM_TOKENS)); + + Class featureQueryClass = FeatureField.newLinearQuery("", "", 0.5f).getClass(); + // if the weight is 1.0f a BoostQuery is returned + Class boostQueryClass = FeatureField.newLinearQuery("", "", 1.0f).getClass(); + + for (var clause : booleanQuery.clauses()) { + assertEquals(BooleanClause.Occur.SHOULD, clause.occur()); + assertThat(clause.query(), either(instanceOf(featureQueryClass)).or(instanceOf(boostQueryClass))); + } + } + + public void testIllegalValues() { + List weightedTokens = List.of(new WeightedToken("foo", 1.0f)); + { + NullPointerException e = expectThrows( + NullPointerException.class, + () -> new WeightedTokensQueryBuilder(null, weightedTokens, null) + ); + assertEquals("[weighted_tokens] requires a fieldName", e.getMessage()); + } + { + NullPointerException e = expectThrows( + NullPointerException.class, + () -> new WeightedTokensQueryBuilder("field name", null, null) + ); + assertEquals("[weighted_tokens] requires tokens", e.getMessage()); + } + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new WeightedTokensQueryBuilder("field name", List.of(), null) + ); + assertEquals("[weighted_tokens] requires at least one token", e.getMessage()); + } + { + IllegalArgumentException e = 
expectThrows( + IllegalArgumentException.class, + () -> new WeightedTokensQueryBuilder("field name", weightedTokens, new TokenPruningConfig(-1, 0.0f, false)) + ); + assertEquals("[tokens_freq_ratio_threshold] must be between [1] and [100], got -1.0", e.getMessage()); + } + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new WeightedTokensQueryBuilder("field name", weightedTokens, new TokenPruningConfig(101, 0.0f, false)) + ); + assertEquals("[tokens_freq_ratio_threshold] must be between [1] and [100], got 101.0", e.getMessage()); + } + { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new WeightedTokensQueryBuilder("field name", weightedTokens, new TokenPruningConfig(5, 5f, false)) + ); + assertEquals("[tokens_weight_threshold] must be between 0 and 1", e.getMessage()); + } + } + + public void testToXContent() throws Exception { + QueryBuilder query = new WeightedTokensQueryBuilder("foo", WEIGHTED_TOKENS, null); + checkGeneratedJson(""" + { + "weighted_tokens": { + "foo": { + "tokens": { + "foo": 0.42 + } + } + } + }""", query); + } + + public void testToXContentWithThresholds() throws Exception { + QueryBuilder query = new WeightedTokensQueryBuilder("foo", WEIGHTED_TOKENS, new TokenPruningConfig(4, 0.4f, false)); + checkGeneratedJson(""" + { + "weighted_tokens": { + "foo": { + "tokens": { + "foo": 0.42 + }, + "pruning_config": { + "tokens_freq_ratio_threshold": 4.0, + "tokens_weight_threshold": 0.4 + } + } + } + }""", query); + } + + public void testToXContentWithThresholdsAndOnlyScorePrunedTokens() throws Exception { + QueryBuilder query = new WeightedTokensQueryBuilder("foo", WEIGHTED_TOKENS, new TokenPruningConfig(4, 0.4f, true)); + checkGeneratedJson(""" + { + "weighted_tokens": { + "foo": { + "tokens": { + "foo": 0.42 + }, + "pruning_config": { + "tokens_freq_ratio_threshold": 4.0, + "tokens_weight_threshold": 0.4, + "only_score_pruned_tokens": true + } + } + } + }""", query); + } + + @Override + protected String[] shuffleProtectedFields() { + return new String[] { TOKENS_FIELD.getPreferredName() }; + } +} diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 3e0ff7633267f..3c19e11a450b4 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -38,7 +38,6 @@ dependencies { testImplementation(testArtifact(project(':server'))) testImplementation(project(':x-pack:plugin:inference:qa:test-service-plugin')) testImplementation project(':modules:reindex') - testImplementation project(':modules:mapper-extras') clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') api "com.ibm.icu:icu4j:${versions.icu4j}" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 22980467a44ae..c82f287792a7c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -9,7 +9,7 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; -import org.elasticsearch.xpack.inference.mapper.LegacySemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import 
org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; @@ -26,9 +26,9 @@ public Set getFeatures() { return Set.of( TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED, RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED, - LegacySemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID, + SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID, SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS, - LegacySemanticTextFieldMapper.SEMANTIC_TEXT_DEFAULT_ELSER_2, + SemanticTextFieldMapper.SEMANTIC_TEXT_DEFAULT_ELSER_2, TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED ); } @@ -36,11 +36,11 @@ public Set getFeatures() { @Override public Set getTestFeatures() { return Set.of( - LegacySemanticTextFieldMapper.SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX, - LegacySemanticTextFieldMapper.SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX, - LegacySemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX, - LegacySemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX, - LegacySemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX + SemanticTextFieldMapper.SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX, + SemanticTextFieldMapper.SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX, + SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX, + SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX, + SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 98509b78f6452..8f1d201bba0e8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -77,9 +77,9 @@ import org.elasticsearch.xpack.inference.mapper.OffsetSourceMetaFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; -import org.elasticsearch.xpack.inference.queries.SparseVectorQueryBuilder; -import org.elasticsearch.xpack.inference.queries.TextExpansionQueryBuilder; -import org.elasticsearch.xpack.inference.queries.WeightedTokensQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.TextExpansionQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 029a91ca208a7..5bafac8f08457 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -42,10 +42,9 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import 
org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField; -import org.elasticsearch.xpack.inference.mapper.LegacySemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; @@ -394,33 +393,35 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons var model = responses.get(0).model(); // ensure that the order in the original field is consistent in case of multiple inputs Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); - List inputs = responses.stream() - .filter(r -> r.field().equals(fieldName)) - .map(r -> r.input) - .collect(Collectors.toList()); - assert inputs.size() == 1; List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); if (addMetadataField) { + List inputs = responses.stream() + .filter(r -> r.field().equals(fieldName)) + .map(r -> r.input) + .collect(Collectors.toList()); + assert inputs.size() == 1; var result = new SemanticTextField( fieldName, model.getInferenceEntityId(), new SemanticTextField.ModelSettings(model), - SemanticTextField.toSemanticTextFieldChunks(fieldName, inputs.get(0), results, indexRequest.getContentType()), + SemanticTextField.toSemanticTextFieldChunks(indexCreatedVersion, inputs.get(0), results, indexRequest.getContentType()), indexRequest.getContentType() ); inferenceFieldsMap.put(fieldName, result); } else { + List inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList()); + assert inputs.size() == 1; var result = new LegacySemanticTextField( fieldName, inputs, new LegacySemanticTextField.InferenceResult( model.getInferenceEntityId(), - new LegacySemanticTextField.ModelSettings(model), + new SemanticTextField.ModelSettings(model), LegacySemanticTextField.toSemanticTextFieldChunks(results, indexRequest.getContentType()) ), indexRequest.getContentType() ); - LegacySemanticTextFieldMapper.insertValue(fieldName, newDocMap, result); + SemanticTextUtils.insertValue(fieldName, newDocMap, result); } } if (addMetadataField) { @@ -500,7 +501,7 @@ private Map> createFieldInferenceRequests(Bu ensureResponseAccumulatorSlot(itemIndex); final String value; try { - value = SemanticTextField.nodeStringValues(field, valueObj); + value = SemanticTextUtils.nodeStringValues(field, valueObj); } catch (Exception exc) { addInferenceResponseFailure(item.id(), exc); break; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java index 01723c5573ec2..09723d76bb358 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java @@ -33,9 +33,9 @@ import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; import org.elasticsearch.xpack.inference.mapper.OffsetSourceMetaFieldMapper; -import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import 
org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.queries.SparseVectorQuery; +import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQuery; import java.io.IOException; import java.util.ArrayList; @@ -96,6 +96,9 @@ public HighlightField highlight(FieldHighlightContext fieldContext) throws IOExc fieldContext.hitContext.docId(), queries ); + if (chunks.size() == 0) { + return null; + } Map inputs = extractContentFields(fieldContext.hitContext, mappingLookup, inferenceMetadata.getSourceFields()); chunks.sort(Comparator.comparingDouble(OffsetAndScore::score).reversed()); @@ -122,7 +125,7 @@ private Map extractContentFields( } Object sourceValue = hitContext.source().extractValue(sourceFieldType.name(), null); if (sourceValue != null) { - inputs.put(sourceField, SemanticTextField.nodeStringValues(sourceFieldType.name(), sourceValue)); + inputs.put(sourceField, SemanticTextUtils.nodeStringValues(sourceFieldType.name(), sourceValue)); } } return inputs; @@ -144,7 +147,7 @@ private List extractOffsetAndScores( Scorer scorer = weight.scorer(reader.getContext()); var terms = reader.terms(OffsetSourceMetaFieldMapper.NAME); if (terms == null) { - // TODO: Empty terms + // The field is empty return List.of(); } var offsetReader = new OffsetSourceFieldMapper.OffsetsReader(terms, fieldType.getOffsetsField().fullPath()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/AbstractSemanticTextFieldType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/AbstractSemanticTextFieldType.java deleted file mode 100644 index ec601d6162e78..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/AbstractSemanticTextFieldType.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.elasticsearch.index.mapper.SimpleMappedFieldType; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.inference.InferenceResults; - -import java.util.Map; - -public abstract class AbstractSemanticTextFieldType extends SimpleMappedFieldType { - protected AbstractSemanticTextFieldType( - String name, - boolean isIndexed, - boolean isStored, - boolean hasDocValues, - TextSearchInfo textSearchInfo, - Map meta - ) { - super(name, isIndexed, isStored, hasDocValues, textSearchInfo, meta); - } - - public abstract QueryBuilder semanticQuery(InferenceResults inferenceResults, Integer requestSize, float boost, String queryName); -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextField.java deleted file mode 100644 index 80fc21d68ff85..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextField.java +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.DeprecationHandler; -import org.elasticsearch.xcontent.NamedXContentRegistry; -import org.elasticsearch.xcontent.ObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xcontent.support.MapXContentParser; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; -import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; - -/** - * A {@link ToXContentObject} that is used to represent the transformation of the semantic text field's inputs. - * The resulting object preserves the original input under the {@link LegacySemanticTextField#TEXT_FIELD} and exposes - * the inference results under the {@link LegacySemanticTextField#INFERENCE_FIELD}. - * - * @param fieldName The original field name. - * @param originalValues The original values associated with the field name. - * @param inference The inference result. - * @param contentType The {@link XContentType} used to store the embeddings chunks. 
- */ -public record LegacySemanticTextField(String fieldName, List originalValues, InferenceResult inference, XContentType contentType) - implements - ToXContentObject { - - static final String TEXT_FIELD = "text"; - static final String INFERENCE_FIELD = "inference"; - static final String INFERENCE_ID_FIELD = "inference_id"; - static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id"; - static final String CHUNKS_FIELD = "chunks"; - static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; - static final String CHUNKED_TEXT_FIELD = "text"; - static final String MODEL_SETTINGS_FIELD = "model_settings"; - static final String TASK_TYPE_FIELD = "task_type"; - static final String DIMENSIONS_FIELD = "dimensions"; - static final String SIMILARITY_FIELD = "similarity"; - static final String ELEMENT_TYPE_FIELD = "element_type"; - - public record InferenceResult(String inferenceId, ModelSettings modelSettings, List chunks) {} - - public record Chunk(String text, BytesReference rawEmbeddings) {} - - public record ModelSettings( - TaskType taskType, - Integer dimensions, - SimilarityMeasure similarity, - DenseVectorFieldMapper.ElementType elementType - ) implements ToXContentObject { - public ModelSettings(Model model) { - this( - model.getTaskType(), - model.getServiceSettings().dimensions(), - model.getServiceSettings().similarity(), - model.getServiceSettings().elementType() - ); - } - - public ModelSettings( - TaskType taskType, - Integer dimensions, - SimilarityMeasure similarity, - DenseVectorFieldMapper.ElementType elementType - ) { - this.taskType = Objects.requireNonNull(taskType, "task type must not be null"); - this.dimensions = dimensions; - this.similarity = similarity; - this.elementType = elementType; - validate(); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(TASK_TYPE_FIELD, taskType.toString()); - if (dimensions != null) { - builder.field(DIMENSIONS_FIELD, dimensions); - } - if (similarity != null) { - builder.field(SIMILARITY_FIELD, similarity); - } - if (elementType != null) { - builder.field(ELEMENT_TYPE_FIELD, elementType); - } - return builder.endObject(); - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder(); - sb.append("task_type=").append(taskType); - if (dimensions != null) { - sb.append(", dimensions=").append(dimensions); - } - if (similarity != null) { - sb.append(", similarity=").append(similarity); - } - if (elementType != null) { - sb.append(", element_type=").append(elementType); - } - return sb.toString(); - } - - private void validate() { - switch (taskType) { - case TEXT_EMBEDDING: - validateFieldPresent(DIMENSIONS_FIELD, dimensions); - validateFieldPresent(SIMILARITY_FIELD, similarity); - validateFieldPresent(ELEMENT_TYPE_FIELD, elementType); - break; - case SPARSE_EMBEDDING: - validateFieldNotPresent(DIMENSIONS_FIELD, dimensions); - validateFieldNotPresent(SIMILARITY_FIELD, similarity); - validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType); - break; - - default: - throw new IllegalArgumentException( - "Wrong [" - + TASK_TYPE_FIELD - + "], expected " - + TEXT_EMBEDDING - + " or " - + SPARSE_EMBEDDING - + ", got " - + taskType.name() - ); - } - } - - private void validateFieldPresent(String field, Object fieldValue) { - if (fieldValue == null) { - throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]"); - } - } - - private void 
validateFieldNotPresent(String field, Object fieldValue) { - if (fieldValue != null) { - throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]"); - } - } - } - - public static String getOriginalTextFieldName(String fieldName) { - return fieldName + "." + TEXT_FIELD; - } - - public static String getInferenceFieldName(String fieldName) { - return fieldName + "." + INFERENCE_FIELD; - } - - public static String getChunksFieldName(String fieldName) { - return getInferenceFieldName(fieldName) + "." + CHUNKS_FIELD; - } - - public static String getEmbeddingsFieldName(String fieldName) { - return getChunksFieldName(fieldName) + "." + CHUNKED_EMBEDDINGS_FIELD; - } - - static LegacySemanticTextField parse(XContentParser parser, Tuple context) throws IOException { - return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context); - } - - static ModelSettings parseModelSettings(XContentParser parser) throws IOException { - return MODEL_SETTINGS_PARSER.parse(parser, null); - } - - static ModelSettings parseModelSettingsFromMap(Object node) { - if (node == null) { - return null; - } - try { - Map map = XContentMapValues.nodeMapValue(node, MODEL_SETTINGS_FIELD); - XContentParser parser = new MapXContentParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.IGNORE_DEPRECATIONS, - map, - XContentType.JSON - ); - return parseModelSettings(parser); - } catch (Exception exc) { - throw new ElasticsearchException(exc); - } - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (originalValues.isEmpty() == false) { - builder.field(TEXT_FIELD, originalValues.size() == 1 ? originalValues.get(0) : originalValues); - } - builder.startObject(INFERENCE_FIELD); - builder.field(INFERENCE_ID_FIELD, inference.inferenceId); - builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings); - builder.startArray(CHUNKS_FIELD); - for (var chunk : inference.chunks) { - builder.startObject(); - builder.field(CHUNKED_TEXT_FIELD, chunk.text); - XContentParser parser = XContentHelper.createParserNotCompressed( - XContentParserConfiguration.EMPTY, - chunk.rawEmbeddings, - contentType - ); - builder.field(CHUNKED_EMBEDDINGS_FIELD).copyCurrentStructure(parser); - builder.endObject(); - } - builder.endArray(); - builder.endObject(); - builder.endObject(); - return builder; - } - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser> SEMANTIC_TEXT_FIELD_PARSER = - new ConstructingObjectParser<>( - SemanticTextFieldMapper.CONTENT_TYPE, - true, - (args, context) -> new LegacySemanticTextField( - context.v1(), - (List) (args[0] == null ? List.of() : args[0]), - (InferenceResult) args[1], - context.v2() - ) - ); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>( - INFERENCE_FIELD, - true, - args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (List) args[2]) - ); - - private static final ConstructingObjectParser CHUNKS_PARSER = new ConstructingObjectParser<>( - CHUNKS_FIELD, - true, - args -> new Chunk((String) args[0], (BytesReference) args[1]) - ); - - private static final ConstructingObjectParser MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>( - MODEL_SETTINGS_FIELD, - true, - args -> { - TaskType taskType = TaskType.fromString((String) args[0]); - Integer dimensions = (Integer) args[1]; - SimilarityMeasure similarity = args[2] == null ? 
null : SimilarityMeasure.fromString((String) args[2]); - DenseVectorFieldMapper.ElementType elementType = args[3] == null - ? null - : DenseVectorFieldMapper.ElementType.fromString((String) args[3]); - return new ModelSettings(taskType, dimensions, similarity, elementType); - } - ); - - static { - SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD)); - SEMANTIC_TEXT_FIELD_PARSER.declareObject( - constructorArg(), - (p, c) -> INFERENCE_RESULT_PARSER.parse(p, null), - new ParseField(INFERENCE_FIELD) - ); - - INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD)); - INFERENCE_RESULT_PARSER.declareObject(constructorArg(), MODEL_SETTINGS_PARSER, new ParseField(MODEL_SETTINGS_FIELD)); - INFERENCE_RESULT_PARSER.declareObjectArray(constructorArg(), CHUNKS_PARSER, new ParseField(CHUNKS_FIELD)); - - CHUNKS_PARSER.declareString(constructorArg(), new ParseField(CHUNKED_TEXT_FIELD)); - CHUNKS_PARSER.declareField(constructorArg(), (p, c) -> { - XContentBuilder b = XContentBuilder.builder(p.contentType().xContent()); - b.copyCurrentStructure(p); - return BytesReference.bytes(b); - }, new ParseField(CHUNKED_EMBEDDINGS_FIELD), ObjectParser.ValueType.OBJECT_ARRAY); - - MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD)); - MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD)); - MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD)); - MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ELEMENT_TYPE_FIELD)); - } - - /** - * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. - */ - public static List toSemanticTextFieldChunks(List results, XContentType contentType) { - List chunks = new ArrayList<>(); - for (var result : results) { - for (Iterator it = result.chunksAsMatchedTextAndByteReference(contentType.xContent()); it - .hasNext();) { - var chunkAsByteReference = it.next(); - chunks.add(new Chunk(chunkAsByteReference.matchedText(), chunkAsByteReference.bytesReference())); - } - } - return chunks; - } - -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldMapper.java deleted file mode 100644 index addb616d7638f..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldMapper.java +++ /dev/null @@ -1,817 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.index.FieldInfos; -import org.apache.lucene.search.MatchNoDocsQuery; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.ScoreMode; -import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.support.XContentMapValues; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.features.NodeFeature; -import org.elasticsearch.index.IndexSettings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.fielddata.FieldDataContext; -import org.elasticsearch.index.fielddata.IndexFieldData; -import org.elasticsearch.index.mapper.BlockLoader; -import org.elasticsearch.index.mapper.BlockSourceReader; -import org.elasticsearch.index.mapper.DocumentParserContext; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.InferenceFieldMapper; -import org.elasticsearch.index.mapper.KeywordFieldMapper; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperBuilderContext; -import org.elasticsearch.index.mapper.MapperMergeContext; -import org.elasticsearch.index.mapper.MappingLookup; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ObjectMapper; -import org.elasticsearch.index.mapper.SourceValueFetcher; -import org.elasticsearch.index.mapper.TextSearchInfo; -import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.query.MatchNoneQueryBuilder; -import org.elasticsearch.index.query.NestedQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentLocation; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; - -import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; - -/** - * A {@link FieldMapper} for semantic text fields. 
- */ -public class LegacySemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { - public static final NodeFeature SEMANTIC_TEXT_SEARCH_INFERENCE_ID = new NodeFeature("semantic_text.search_inference_id"); - public static final NodeFeature SEMANTIC_TEXT_DEFAULT_ELSER_2 = new NodeFeature("semantic_text.default_elser_2"); - public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix"); - public static final NodeFeature SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX = new NodeFeature("semantic_text.single_field_update_fix"); - public static final NodeFeature SEMANTIC_TEXT_DELETE_FIX = new NodeFeature("semantic_text.delete_fix"); - public static final NodeFeature SEMANTIC_TEXT_ZERO_SIZE_FIX = new NodeFeature("semantic_text.zero_size_fix"); - public static final NodeFeature SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX = new NodeFeature( - "semantic_text.always_emit_inference_id_fix" - ); - - public static final String CONTENT_TYPE = "semantic_text"; - private final IndexSettings indexSettings; - - public static class Builder extends FieldMapper.Builder { - private final IndexVersion indexVersionCreated; - private final IndexSettings indexSettings; - - private final Parameter inferenceId = Parameter.stringParam( - LegacySemanticTextField.INFERENCE_ID_FIELD, - false, - mapper -> ((LegacySemanticTextFieldType) mapper.fieldType()).inferenceId, - SemanticTextFieldMapper.DEFAULT_ELSER_2_INFERENCE_ID - ).addValidator(v -> { - if (Strings.isEmpty(v)) { - throw new IllegalArgumentException( - "[" - + LegacySemanticTextField.INFERENCE_ID_FIELD - + "] on mapper [" - + leafName() - + "] of type [" - + CONTENT_TYPE - + "] must not be empty" - ); - } - }).alwaysSerialize(); - - private final Parameter searchInferenceId = Parameter.stringParam( - LegacySemanticTextField.SEARCH_INFERENCE_ID_FIELD, - true, - mapper -> ((LegacySemanticTextFieldType) mapper.fieldType()).searchInferenceId, - null - ).acceptsNull().addValidator(v -> { - if (v != null && Strings.isEmpty(v)) { - throw new IllegalArgumentException( - "[" - + LegacySemanticTextField.SEARCH_INFERENCE_ID_FIELD - + "] on mapper [" - + leafName() - + "] of type [" - + CONTENT_TYPE - + "] must not be empty" - ); - } - }); - - private final Parameter modelSettings = new Parameter<>( - LegacySemanticTextField.MODEL_SETTINGS_FIELD, - true, - () -> null, - (n, c, o) -> LegacySemanticTextField.parseModelSettingsFromMap(o), - mapper -> ((LegacySemanticTextFieldType) mapper.fieldType()).modelSettings, - XContentBuilder::field, - Objects::toString - ).acceptsNull().setMergeValidator(LegacySemanticTextFieldMapper::canMergeModelSettings); - - private final Parameter> meta = Parameter.metaParam(); - - private Function inferenceFieldBuilder; - - public static Builder from(LegacySemanticTextFieldMapper mapper) { - Builder builder = new Builder( - mapper.leafName(), - mapper.fieldType().indexVersionCreated, - mapper.fieldType().getChunksField().bitsetProducer(), - mapper.indexSettings - ); - builder.init(mapper); - return builder; - } - - public Builder( - String name, - IndexVersion indexVersionCreated, - Function bitSetProducer, - IndexSettings indexSettings - ) { - super(name); - this.indexVersionCreated = indexVersionCreated; - this.indexSettings = indexSettings; - this.inferenceFieldBuilder = c -> createInferenceField( - c, - indexVersionCreated, - modelSettings.get(), - bitSetProducer, - indexSettings - ); - } - - public Builder setInferenceId(String id) { - this.inferenceId.setValue(id); - return 
this; - } - - public Builder setSearchInferenceId(String id) { - this.searchInferenceId.setValue(id); - return this; - } - - public Builder setModelSettings(LegacySemanticTextField.ModelSettings value) { - this.modelSettings.setValue(value); - return this; - } - - @Override - protected Parameter[] getParameters() { - return new Parameter[] { inferenceId, searchInferenceId, modelSettings, meta }; - } - - @Override - protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeContext mapperMergeContext) { - LegacySemanticTextFieldMapper semanticMergeWith = (LegacySemanticTextFieldMapper) mergeWith; - semanticMergeWith = copySettings(semanticMergeWith, mapperMergeContext); - - super.merge(semanticMergeWith, conflicts, mapperMergeContext); - conflicts.check(); - var context = mapperMergeContext.createChildContext(semanticMergeWith.leafName(), ObjectMapper.Dynamic.FALSE); - var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext()); - var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), context); - inferenceFieldBuilder = c -> mergedInferenceField; - } - - @Override - public LegacySemanticTextFieldMapper build(MapperBuilderContext context) { - if (copyTo.copyToFields().isEmpty() == false) { - throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support [copy_to]"); - } - if (multiFieldsBuilder.hasMultiFields()) { - throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields"); - } - final String fullName = context.buildFullName(leafName()); - - if (context.isInNestedContext()) { - throw new IllegalArgumentException(CONTENT_TYPE + " field [" + fullName + "] cannot be nested"); - } - var childContext = context.createChildContext(leafName(), ObjectMapper.Dynamic.FALSE); - final ObjectMapper inferenceField = inferenceFieldBuilder.apply(childContext); - - return new LegacySemanticTextFieldMapper( - leafName(), - new LegacySemanticTextFieldType( - fullName, - inferenceId.getValue(), - searchInferenceId.getValue(), - modelSettings.getValue(), - inferenceField, - indexVersionCreated, - meta.getValue() - ), - builderParams(this, context), - indexSettings - ); - } - - /** - * As necessary, copy settings from this builder to the passed-in mapper. - * Used to preserve {@link SemanticTextField.ModelSettings} when updating a semantic text mapping to one where the model settings - * are not specified. 
- * - * @param mapper The mapper - * @return A mapper with the copied settings applied - */ - private LegacySemanticTextFieldMapper copySettings(LegacySemanticTextFieldMapper mapper, MapperMergeContext mapperMergeContext) { - LegacySemanticTextFieldMapper returnedMapper = mapper; - if (mapper.fieldType().getModelSettings() == null) { - Builder builder = from(mapper); - builder.setModelSettings(modelSettings.getValue()); - returnedMapper = builder.build(mapperMergeContext.getMapperBuilderContext()); - } - - return returnedMapper; - } - } - - private LegacySemanticTextFieldMapper( - String simpleName, - MappedFieldType mappedFieldType, - BuilderParams builderParams, - IndexSettings indexSettings - ) { - super(simpleName, mappedFieldType, builderParams); - this.indexSettings = indexSettings; - } - - @Override - public Iterator iterator() { - List subIterators = new ArrayList<>(); - subIterators.add(fieldType().getInferenceField()); - return subIterators.iterator(); - } - - @Override - public FieldMapper.Builder getMergeBuilder() { - return Builder.from(this); - } - - @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - if (parser.currentToken() == XContentParser.Token.VALUE_NULL) { - return; - } - - XContentLocation xContentLocation = parser.getTokenLocation(); - final LegacySemanticTextField field; - boolean isWithinLeaf = context.path().isWithinLeafObject(); - try { - context.path().setWithinLeafObject(true); - field = LegacySemanticTextField.parse(parser, new Tuple<>(fullPath(), context.parser().contentType())); - } finally { - context.path().setWithinLeafObject(isWithinLeaf); - } - - final String fullFieldName = fieldType().name(); - if (field.inference().inferenceId().equals(fieldType().getInferenceId()) == false) { - throw new DocumentParsingException( - xContentLocation, - Strings.format( - "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", - LegacySemanticTextField.INFERENCE_ID_FIELD, - field.inference().inferenceId(), - fullFieldName, - LegacySemanticTextField.INFERENCE_ID_FIELD, - fieldType().getInferenceId() - ) - ); - } - - final LegacySemanticTextFieldMapper mapper; - if (fieldType().getModelSettings() == null) { - context.path().remove(); - Builder builder = (Builder) new Builder( - leafName(), - fieldType().indexVersionCreated, - fieldType().getChunksField().bitsetProducer(), - indexSettings - ).init(this); - try { - mapper = builder.setModelSettings(field.inference().modelSettings()) - .setInferenceId(field.inference().inferenceId()) - .build(context.createDynamicMapperBuilderContext()); - context.addDynamicMapper(mapper); - } finally { - context.path().add(leafName()); - } - } else { - Conflicts conflicts = new Conflicts(fullFieldName); - canMergeModelSettings(fieldType().getModelSettings(), field.inference().modelSettings(), conflicts); - try { - conflicts.check(); - } catch (Exception exc) { - throw new DocumentParsingException( - xContentLocation, - "Incompatible model settings for field [" - + fullPath() - + "]. 
Check that the " - + LegacySemanticTextField.INFERENCE_ID_FIELD - + " is not using different model settings", - exc - ); - } - mapper = this; - } - - var chunksField = mapper.fieldType().getChunksField(); - var embeddingsField = mapper.fieldType().getEmbeddingsField(); - for (var chunk : field.inference().chunks()) { - try ( - XContentParser subParser = XContentHelper.createParserNotCompressed( - XContentParserConfiguration.EMPTY, - chunk.rawEmbeddings(), - context.parser().contentType() - ) - ) { - DocumentParserContext subContext = context.createNestedContext(chunksField).switchParser(subParser); - subParser.nextToken(); - embeddingsField.parse(subContext); - } - } - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - - @Override - public LegacySemanticTextFieldType fieldType() { - return (LegacySemanticTextFieldType) super.fieldType(); - } - - @Override - public InferenceFieldMetadata getMetadata(Set sourcePaths) { - String[] copyFields = sourcePaths.toArray(String[]::new); - // ensure consistent order - Arrays.sort(copyFields); - return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields); - } - - @Override - public Object getOriginalValue(Map sourceAsMap) { - Object fieldValue = sourceAsMap.get(fullPath()); - if (fieldValue == null) { - return null; - } else if (fieldValue instanceof Map == false) { - // Don't try to further validate the non-map value, that will be handled when the source is fully parsed - return fieldValue; - } - - Map fieldValueMap = XContentMapValues.nodeMapValue(fieldValue, "Field [" + fullPath() + "]"); - return XContentMapValues.extractValue(LegacySemanticTextField.TEXT_FIELD, fieldValueMap); - } - - @Override - protected void doValidate(MappingLookup mappers) { - int parentPathIndex = fullPath().lastIndexOf(leafName()); - if (parentPathIndex > 0) { - // Check that the parent object field allows subobjects. - // Subtract one from the parent path index to omit the trailing dot delimiter. 
- ObjectMapper parentMapper = mappers.objectMappers().get(fullPath().substring(0, parentPathIndex - 1)); - if (parentMapper == null) { - throw new IllegalStateException(CONTENT_TYPE + " field [" + fullPath() + "] does not have a parent object mapper"); - } - - if (parentMapper.subobjects() == ObjectMapper.Subobjects.DISABLED) { - throw new IllegalArgumentException( - CONTENT_TYPE + " field [" + fullPath() + "] cannot be in an object field with subobjects disabled" - ); - } - } - } - - public static class LegacySemanticTextFieldType extends AbstractSemanticTextFieldType { - private final String inferenceId; - private final String searchInferenceId; - private final LegacySemanticTextField.ModelSettings modelSettings; - private final ObjectMapper inferenceField; - private final IndexVersion indexVersionCreated; - - public LegacySemanticTextFieldType( - String name, - String inferenceId, - String searchInferenceId, - LegacySemanticTextField.ModelSettings modelSettings, - ObjectMapper inferenceField, - IndexVersion indexVersionCreated, - Map meta - ) { - super(name, true, false, false, TextSearchInfo.NONE, meta); - this.inferenceId = inferenceId; - this.searchInferenceId = searchInferenceId; - this.modelSettings = modelSettings; - this.inferenceField = inferenceField; - this.indexVersionCreated = indexVersionCreated; - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - public String getInferenceId() { - return inferenceId; - } - - public String getSearchInferenceId() { - return searchInferenceId == null ? inferenceId : searchInferenceId; - } - - public LegacySemanticTextField.ModelSettings getModelSettings() { - return modelSettings; - } - - public ObjectMapper getInferenceField() { - return inferenceField; - } - - public NestedObjectMapper getChunksField() { - return (NestedObjectMapper) inferenceField.getMapper(LegacySemanticTextField.CHUNKS_FIELD); - } - - public FieldMapper getEmbeddingsField() { - return (FieldMapper) getChunksField().getMapper(LegacySemanticTextField.CHUNKED_EMBEDDINGS_FIELD); - } - - @Override - public Query termQuery(Object value, SearchExecutionContext context) { - throw new IllegalArgumentException(CONTENT_TYPE + " fields do not support term query"); - } - - @Override - public Query existsQuery(SearchExecutionContext context) { - if (getEmbeddingsField() == null) { - return new MatchNoDocsQuery(); - } - - return NestedQueryBuilder.toQuery( - (c -> getEmbeddingsField().fieldType().existsQuery(c)), - LegacySemanticTextField.getChunksFieldName(name()), - ScoreMode.None, - false, - context - ); - } - - @Override - public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - // Redirect the fetcher to load the original values of the field - return SourceValueFetcher.toString(LegacySemanticTextField.getOriginalTextFieldName(name()), context, format); - } - - @Override - public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) { - throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); - } - - @Override - public boolean fieldHasValue(FieldInfos fieldInfos) { - return fieldInfos.fieldInfo(LegacySemanticTextField.getEmbeddingsFieldName(name())) != null; - } - - @Override - public QueryBuilder semanticQuery(InferenceResults inferenceResults, Integer requestSize, float boost, String queryName) { - String nestedFieldPath = LegacySemanticTextField.getChunksFieldName(name()); - String inferenceResultsFieldName = 
LegacySemanticTextField.getEmbeddingsFieldName(name()); - QueryBuilder childQueryBuilder; - - if (modelSettings == null) { - // No inference results have been indexed yet - childQueryBuilder = new MatchNoneQueryBuilder(); - } else { - childQueryBuilder = switch (modelSettings.taskType()) { - case SPARSE_EMBEDDING -> { - if (inferenceResults instanceof TextExpansionResults == false) { - throw new IllegalArgumentException( - generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, TextExpansionResults.NAME) - ); - } - - // TODO: Use WeightedTokensQueryBuilder - TextExpansionResults textExpansionResults = (TextExpansionResults) inferenceResults; - var boolQuery = QueryBuilders.boolQuery(); - for (var weightedToken : textExpansionResults.getWeightedTokens()) { - boolQuery.should( - QueryBuilders.termQuery(inferenceResultsFieldName, weightedToken.token()).boost(weightedToken.weight()) - ); - } - boolQuery.minimumShouldMatch(1); - - yield boolQuery; - } - case TEXT_EMBEDDING -> { - if (inferenceResults instanceof MlTextEmbeddingResults == false) { - throw new IllegalArgumentException( - generateQueryInferenceResultsTypeMismatchMessage(inferenceResults, MlTextEmbeddingResults.NAME) - ); - } - - MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults; - float[] inference = textEmbeddingResults.getInferenceAsFloat(); - if (inference.length != modelSettings.dimensions()) { - throw new IllegalArgumentException( - generateDimensionCountMismatchMessage(inference.length, modelSettings.dimensions()) - ); - } - - Integer k = requestSize; - if (k != null) { - // Ensure that k is at least the default size so that aggregations work when size is set to 0 in the request - k = Math.max(k, DEFAULT_SIZE); - } - - yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null); - } - default -> throw new IllegalStateException( - "Field [" - + name() - + "] is configured to use an inference endpoint with an unsupported task type [" - + modelSettings.taskType() - + "]" - ); - }; - } - - return new NestedQueryBuilder(nestedFieldPath, childQueryBuilder, ScoreMode.Max).boost(boost).queryName(queryName); - } - - private String generateQueryInferenceResultsTypeMismatchMessage(InferenceResults inferenceResults, String expectedResultsType) { - StringBuilder sb = new StringBuilder( - "Field [" - + name() - + "] expected query inference results to be of type [" - + expectedResultsType - + "]," - + " got [" - + inferenceResults.getWriteableName() - + "]." - ); - - return generateInvalidQueryInferenceResultsMessage(sb); - } - - private String generateDimensionCountMismatchMessage(int inferenceDimCount, int expectedDimCount) { - StringBuilder sb = new StringBuilder( - "Field [" - + name() - + "] expected query inference results with " - + expectedDimCount - + " dimensions, got " - + inferenceDimCount - + " dimensions." - ); - - return generateInvalidQueryInferenceResultsMessage(sb); - } - - private String generateInvalidQueryInferenceResultsMessage(StringBuilder baseMessageBuilder) { - if (searchInferenceId != null && searchInferenceId.equals(inferenceId) == false) { - baseMessageBuilder.append( - " Is the search inference endpoint [" - + searchInferenceId - + "] compatible with the inference endpoint [" - + inferenceId - + "]?" 
- ); - } else { - baseMessageBuilder.append(" Has the configuration for inference endpoint [" + inferenceId + "] changed?"); - } - - return baseMessageBuilder.toString(); - } - - @Override - public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { - SourceValueFetcher fetcher = SourceValueFetcher.toString(blContext.sourcePaths(name().concat(".text"))); - return new BlockSourceReader.BytesRefsBlockLoader(fetcher, BlockSourceReader.lookupMatchingAll()); - } - } - - /** - *
-     * Insert or replace the path's value in the map with the provided new value. The map will be modified in-place.
-     * If the complete path does not exist in the map, it will be added to the deepest (sub-)map possible.
-     * <p>
-     * For example, given the map:
-     * <pre>
-     * {
-     *   "path1": {
-     *     "path2": {
-     *       "key1": "value1"
-     *     }
-     *   }
-     * }
-     * </pre>
-     * <p>
-     * And the caller wanted to insert {@code "path1.path2.path3.key2": "value2"}, the method would emit the modified map:
-     * <pre>
-     * {
-     *   "path1": {
-     *     "path2": {
-     *       "key1": "value1",
-     *       "path3.key2": "value2"
-     *     }
-     *   }
-     * }
-     * </pre>
- * - * @param path the value's path in the map. - * @param map the map to search and modify in-place. - * @param newValue the new value to assign to the path. - * - * @throws IllegalArgumentException If either the path cannot be fully traversed or there is ambiguity about where to insert the new - * value. - */ - public static void insertValue(String path, Map map, Object newValue) { - String[] pathElements = path.split("\\."); - if (pathElements.length == 0) { - return; - } - - List suffixMaps = extractSuffixMaps(pathElements, 0, map); - if (suffixMaps.isEmpty()) { - // This should never happen. Throw in case it does for some reason. - throw new IllegalStateException("extractSuffixMaps returned an empty suffix map list"); - } else if (suffixMaps.size() == 1) { - SuffixMap suffixMap = suffixMaps.getFirst(); - suffixMap.map().put(suffixMap.suffix(), newValue); - } else { - throw new IllegalArgumentException( - "Path [" + path + "] could be inserted in " + suffixMaps.size() + " distinct ways, it is ambiguous which one to use" - ); - } - } - - private record SuffixMap(String suffix, Map map) {} - - private static List extractSuffixMaps(String[] pathElements, int index, Object currentValue) { - if (currentValue instanceof List valueList) { - List suffixMaps = new ArrayList<>(valueList.size()); - for (Object o : valueList) { - suffixMaps.addAll(extractSuffixMaps(pathElements, index, o)); - } - - return suffixMaps; - } else if (currentValue instanceof Map) { - @SuppressWarnings("unchecked") - Map map = (Map) currentValue; - List suffixMaps = new ArrayList<>(map.size()); - - String key = pathElements[index]; - while (index < pathElements.length) { - if (map.containsKey(key)) { - if (index + 1 == pathElements.length) { - // We found the complete path - suffixMaps.add(new SuffixMap(key, map)); - } else { - // We've matched that path partially, keep traversing to try to match it fully - suffixMaps.addAll(extractSuffixMaps(pathElements, index + 1, map.get(key))); - } - } - - if (++index < pathElements.length) { - key += "." + pathElements[index]; - } - } - - if (suffixMaps.isEmpty()) { - // We checked for all remaining elements in the path, and they do not exist. This means we found a leaf map that we should - // add the value to. 
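For readers following this logic as it moves to SemanticTextUtils.insertValue (see the ShardBulkInferenceActionFilter hunk above), here is a minimal standalone sketch of the insertion behavior documented in the javadoc examples. It is an illustration only: it covers the single-map happy path and omits the list traversal and ambiguity checks that extractSuffixMaps performs; the class and method names are hypothetical.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

class InsertValueSketch {
    @SuppressWarnings("unchecked")
    static void insertValue(String path, Map<String, Object> map, Object newValue) {
        String[] elements = path.split("\\.");
        Map<String, Object> current = map;
        int i = 0;
        // Traverse existing sub-maps for as long as the path matches.
        while (i < elements.length - 1 && current.get(elements[i]) instanceof Map) {
            current = (Map<String, Object>) current.get(elements[i]);
            i++;
        }
        // Store the untraversed remainder as a single flattened key in the deepest map reached.
        current.put(String.join(".", Arrays.copyOfRange(elements, i, elements.length)), newValue);
    }

    public static void main(String[] args) {
        Map<String, Object> inner = new HashMap<>(Map.of("key1", "value1"));
        Map<String, Object> map = new HashMap<>(Map.of("path1", new HashMap<>(Map.of("path2", inner))));
        insertValue("path1.path2.path3.key2", map, "value2");
        // e.g. {path1={path2={key1=value1, path3.key2=value2}}}, matching the javadoc example
        // (HashMap iteration order may vary).
        System.out.println(map);
    }
}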
- suffixMaps.add(new SuffixMap(key, map)); - } - - return suffixMaps; - } else { - throw new IllegalArgumentException( - "Path [" - + String.join(".", Arrays.copyOfRange(pathElements, 0, index)) - + "] has value [" - + currentValue - + "] of type [" - + currentValue.getClass().getSimpleName() - + "], which cannot be traversed into further" - ); - } - } - - private static ObjectMapper createInferenceField( - MapperBuilderContext context, - IndexVersion indexVersionCreated, - @Nullable LegacySemanticTextField.ModelSettings modelSettings, - Function bitSetProducer, - IndexSettings indexSettings - ) { - return new ObjectMapper.Builder(LegacySemanticTextField.INFERENCE_FIELD, Optional.of(ObjectMapper.Subobjects.ENABLED)).dynamic( - ObjectMapper.Dynamic.FALSE - ).add(createChunksField(indexVersionCreated, modelSettings, bitSetProducer, indexSettings)).build(context); - } - - private static NestedObjectMapper.Builder createChunksField( - IndexVersion indexVersionCreated, - @Nullable LegacySemanticTextField.ModelSettings modelSettings, - Function bitSetProducer, - IndexSettings indexSettings - ) { - NestedObjectMapper.Builder chunksField = new NestedObjectMapper.Builder( - LegacySemanticTextField.CHUNKS_FIELD, - indexVersionCreated, - bitSetProducer, - indexSettings - ); - chunksField.dynamic(ObjectMapper.Dynamic.FALSE); - KeywordFieldMapper.Builder chunkTextField = new KeywordFieldMapper.Builder( - LegacySemanticTextField.CHUNKED_TEXT_FIELD, - indexVersionCreated - ).indexed(false).docValues(false); - if (modelSettings != null) { - chunksField.add(createEmbeddingsField(indexVersionCreated, modelSettings)); - } - chunksField.add(chunkTextField); - return chunksField; - } - - private static Mapper.Builder createEmbeddingsField( - IndexVersion indexVersionCreated, - LegacySemanticTextField.ModelSettings modelSettings - ) { - return switch (modelSettings.taskType()) { - case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(LegacySemanticTextField.CHUNKED_EMBEDDINGS_FIELD); - case TEXT_EMBEDDING -> { - DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - LegacySemanticTextField.CHUNKED_EMBEDDINGS_FIELD, - indexVersionCreated - ); - - SimilarityMeasure similarity = modelSettings.similarity(); - if (similarity != null) { - switch (similarity) { - case COSINE -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.COSINE); - case DOT_PRODUCT -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT); - case L2_NORM -> denseVectorMapperBuilder.similarity(DenseVectorFieldMapper.VectorSimilarity.L2_NORM); - default -> throw new IllegalArgumentException( - "Unknown similarity measure in model_settings [" + similarity.name() + "]" - ); - } - } - denseVectorMapperBuilder.dimensions(modelSettings.dimensions()); - denseVectorMapperBuilder.elementType(modelSettings.elementType()); - - yield denseVectorMapperBuilder; - } - default -> throw new IllegalArgumentException("Invalid task_type in model_settings [" + modelSettings.taskType().name() + "]"); - }; - } - - private static boolean canMergeModelSettings( - LegacySemanticTextField.ModelSettings previous, - LegacySemanticTextField.ModelSettings current, - Conflicts conflicts - ) { - if (Objects.equals(previous, current)) { - return true; - } - if (previous == null) { - return true; - } - conflicts.addConflict("model_settings", ""); - return false; - } -} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapper.java index fd309233e6aec..bd5be953b69a8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapper.java @@ -129,25 +129,32 @@ protected boolean supportsParsingObject() { @Override protected void parseCreateField(DocumentParserContext context) throws IOException { XContentParser parser = context.parser(); - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - String fieldName = null; - String sourceFieldName = null; - int startOffset = -1; - int endOffset = -1; - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - fieldName = parser.currentName(); - } else if (SOURCE_NAME_FIELD.equals(fieldName)) { - sourceFieldName = parser.text(); - } else if (START_OFFSET_FIELD.equals(fieldName)) { - startOffset = parser.intValue(); - } else if (END_OFFSET_FIELD.equals(fieldName)) { - endOffset = parser.intValue(); - } else { - throw new IllegalArgumentException("Unkown field name [" + fieldName + "]"); + boolean isWithinLeafObject = context.path().isWithinLeafObject(); + // make sure that we don't expand dots in field names while parsing + context.path().setWithinLeafObject(true); + try { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + String fieldName = null; + String sourceFieldName = null; + int startOffset = -1; + int endOffset = -1; + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); + } else if (SOURCE_NAME_FIELD.equals(fieldName)) { + sourceFieldName = parser.text(); + } else if (START_OFFSET_FIELD.equals(fieldName)) { + startOffset = parser.intValue(); + } else if (END_OFFSET_FIELD.equals(fieldName)) { + endOffset = parser.intValue(); + } else { + throw new IllegalArgumentException("Unknown field name [" + fieldName + "]"); + } } + context.doc().addWithKey(fullPath(), new OffsetField(NAME, fullPath() + "." + sourceFieldName, startOffset, endOffset)); + } finally { + context.path().setWithinLeafObject(isWithinLeafObject); } - context.doc().addWithKey(fullPath(), new OffsetField(NAME, fullPath() + "." 
+ sourceFieldName, startOffset, endOffset)); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index b36390cf8ef74..a23647b274a04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -8,18 +8,18 @@ package org.elasticsearch.xpack.inference.mapper; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.DeprecationHandler; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -34,7 +34,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -42,49 +41,44 @@ import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; -import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; /** * A {@link ToXContentObject} that is used to represent the transformation of the semantic text field's inputs. + * The resulting object preserves the original input under the {@link SemanticTextField#TEXT_FIELD} and exposes + * the inference results under the {@link SemanticTextField#INFERENCE_FIELD}. * * @param fieldName The original field name. + * @param originalValues The original values associated with the field name. + * @param inference The inference result. * @param contentType The {@link XContentType} used to store the embeddings chunks. 
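The updated javadoc above describes the new two-part layout: the raw input is preserved under the "text" key while everything produced by the inference service is grouped under "inference". As a rough illustration only, the sketch below assembles that shape with plain Java maps; the key names follow the constants introduced in this patch, while the endpoint id, input string, and offsets are hypothetical and the embeddings object is omitted for brevity.

import java.util.List;
import java.util.Map;

public class SemanticTextLayoutSketch {
    public static void main(String[] args) {
        // Hypothetical rendering of a parsed semantic_text value: the original
        // input stays readable under "text", inference output under "inference".
        Map<String, Object> value = Map.of(
            "text", "a brown fox",
            "inference", Map.of(
                "inference_id", "my-elser-endpoint", // hypothetical endpoint id
                "model_settings", Map.of("task_type", "sparse_embedding"),
                "chunks", List.of(
                    Map.of(
                        "source_field", "my_semantic_field", // field the chunk was taken from
                        "start_offset", 0,  // character offsets into the original input
                        "end_offset", 11    // "embeddings" omitted for brevity
                    )
                )
            )
        );
        System.out.println(value);
    }
}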
*/ -public record SemanticTextField( - String fieldName, - String inferenceId, - ModelSettings modelSettings, - List chunks, - XContentType contentType -) implements ToXContentObject { +public record SemanticTextField(String fieldName, List originalValues, InferenceResult inference, XContentType contentType) + implements + ToXContentObject { + static final String TEXT_FIELD = "text"; + static final String INFERENCE_FIELD = "inference"; static final String INFERENCE_ID_FIELD = "inference_id"; static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id"; static final String CHUNKS_FIELD = "chunks"; static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; static final String CHUNKED_OFFSET_FIELD = "offset"; - static final String CHUNKED_OFFSET_SOURCE_FIELD = "field"; - static final String CHUNKED_OFFSET_START_FIELD = "start"; - static final String CHUNKED_OFFSET_END_FIELD = "end"; + static final String CHUNKED_SOURCE_FIELD_FIELD = "source_field"; + static final String CHUNKED_START_OFFSET_FIELD = "start_offset"; + static final String CHUNKED_END_OFFSET_FIELD = "end_offset"; static final String MODEL_SETTINGS_FIELD = "model_settings"; static final String TASK_TYPE_FIELD = "task_type"; static final String DIMENSIONS_FIELD = "dimensions"; static final String SIMILARITY_FIELD = "similarity"; static final String ELEMENT_TYPE_FIELD = "element_type"; - public record Chunk(Offset offset, BytesReference rawEmbeddings) {} + public record InferenceResult(String inferenceId, ModelSettings modelSettings, List chunks) {} - public record Offset(String sourceFieldName, int startOffset, int endOffset) implements ToXContentObject { - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CHUNKED_OFFSET_SOURCE_FIELD, sourceFieldName()); - builder.field(CHUNKED_OFFSET_START_FIELD, startOffset()); - builder.field(CHUNKED_OFFSET_END_FIELD, endOffset()); - return builder.endObject(); - } - } + public record Chunk(String sourceField, @Nullable String text, int startOffset, int endOffset, BytesReference rawEmbeddings) {} + + public record Offset(String sourceFieldName, int startOffset, int endOffset) {} public record ModelSettings( TaskType taskType, @@ -186,8 +180,16 @@ private void validateFieldNotPresent(String field, Object fieldValue) { } } + public static String getOriginalTextFieldName(String fieldName) { + return fieldName + "." + TEXT_FIELD; + } + + public static String getInferenceFieldName(String fieldName) { + return fieldName + "." + INFERENCE_FIELD; + } + public static String getChunksFieldName(String fieldName) { - return fieldName + "." + CHUNKS_FIELD; + return getInferenceFieldName(fieldName) + "." + CHUNKS_FIELD; } public static String getEmbeddingsFieldName(String fieldName) { @@ -223,13 +225,21 @@ static ModelSettings parseModelSettingsFromMap(Object node) { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(INFERENCE_ID_FIELD, inferenceId); - builder.field(MODEL_SETTINGS_FIELD, modelSettings); + if (originalValues != null && originalValues.isEmpty() == false) { + builder.field(TEXT_FIELD, originalValues.size() == 1 ? 
originalValues.get(0) : originalValues); + } + builder.startObject(INFERENCE_FIELD); + builder.field(INFERENCE_ID_FIELD, inference.inferenceId); + builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings); builder.startArray(CHUNKS_FIELD); - for (var chunk : chunks) { + for (var chunk : inference.chunks) { builder.startObject(); - builder.field(CHUNKED_OFFSET_FIELD); - chunk.offset.toXContent(builder, params); + if (chunk.text != null) { + builder.field(TEXT_FIELD, chunk.text); + } else if (chunk.startOffset != -1) { + builder.field(CHUNKED_START_OFFSET_FIELD, chunk.startOffset); + builder.field(CHUNKED_END_OFFSET_FIELD, chunk.endOffset); + } XContentParser parser = XContentHelper.createParserNotCompressed( XContentParserConfiguration.EMPTY, chunk.rawEmbeddings, @@ -240,6 +250,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); builder.endObject(); + builder.endObject(); return builder; } @@ -250,23 +261,25 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws true, (args, context) -> new SemanticTextField( context.v1(), - (String) args[0], - (ModelSettings) args[1], - (List) args[2], + (List) (args[0] == null ? List.of() : args[0]), + (InferenceResult) args[1], context.v2() ) ); - private static final ConstructingObjectParser CHUNKS_PARSER = new ConstructingObjectParser<>( - CHUNKS_FIELD, + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>( + INFERENCE_FIELD, true, - args -> new Chunk((Offset) args[0], (BytesReference) args[1]) + args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (List) args[2]) ); - private static final ConstructingObjectParser OFFSET_PARSER = new ConstructingObjectParser<>( - CHUNKED_OFFSET_FIELD, + + private static final ConstructingObjectParser CHUNKS_PARSER = new ConstructingObjectParser<>( + CHUNKS_FIELD, true, - args -> new Offset((String) args[0], (int) args[1], (int) args[2]) + args -> new Chunk((String) args[0], args[1] != null ? (String) args[1] : null, args[2] != null ? (int) args[2] : -1, args[3] != null ?
(int) args[3] : -1, (BytesReference) args[4]) ); + private static final ConstructingObjectParser MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>( MODEL_SETTINGS_FIELD, true, @@ -280,34 +293,28 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return new ModelSettings(taskType, dimensions, similarity, elementType); } ); + static { - SEMANTIC_TEXT_FIELD_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD)); + SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD)); SEMANTIC_TEXT_FIELD_PARSER.declareObject( constructorArg(), - (p, c) -> MODEL_SETTINGS_PARSER.parse(p, null), - new ParseField(MODEL_SETTINGS_FIELD) - ); - SEMANTIC_TEXT_FIELD_PARSER.declareObjectArray( - constructorArg(), - (p, c) -> CHUNKS_PARSER.parse(p, null), - new ParseField(CHUNKS_FIELD) + (p, c) -> INFERENCE_RESULT_PARSER.parse(p, null), + new ParseField(INFERENCE_FIELD) ); - CHUNKS_PARSER.declareField( - constructorArg(), - (p, c) -> OFFSET_PARSER.parse(p, null), - new ParseField(CHUNKED_OFFSET_FIELD), - ObjectParser.ValueType.OBJECT - ); + INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD)); + INFERENCE_RESULT_PARSER.declareObject(constructorArg(), MODEL_SETTINGS_PARSER, new ParseField(MODEL_SETTINGS_FIELD)); + INFERENCE_RESULT_PARSER.declareObjectArray(constructorArg(), CHUNKS_PARSER, new ParseField(CHUNKS_FIELD)); + + CHUNKS_PARSER.declareString(constructorArg(), new ParseField(CHUNKED_SOURCE_FIELD_FIELD)); + CHUNKS_PARSER.declareString(optionalConstructorArg(), new ParseField(TEXT_FIELD)); CHUNKS_PARSER.declareField(constructorArg(), (p, c) -> { XContentBuilder b = XContentBuilder.builder(p.contentType().xContent()); b.copyCurrentStructure(p); return BytesReference.bytes(b); }, new ParseField(CHUNKED_EMBEDDINGS_FIELD), ObjectParser.ValueType.OBJECT_ARRAY); - - OFFSET_PARSER.declareString(constructorArg(), new ParseField(CHUNKED_OFFSET_SOURCE_FIELD)); - OFFSET_PARSER.declareInt(constructorArg(), new ParseField(CHUNKED_OFFSET_START_FIELD)); - OFFSET_PARSER.declareInt(constructorArg(), new ParseField(CHUNKED_OFFSET_END_FIELD)); + CHUNKS_PARSER.declareInt(optionalConstructorArg(), new ParseField(CHUNKED_START_OFFSET_FIELD)); + CHUNKS_PARSER.declareInt(optionalConstructorArg(), new ParseField(CHUNKED_END_OFFSET_FIELD)); MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD)); MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD)); @@ -330,34 +337,9 @@ public static List toSemanticTextFieldChunks( .hasNext();) { var chunkAsByteReference = it.next(); int startOffset = input.indexOf(chunkAsByteReference.matchedText()); - chunks.add( - new Chunk( - new Offset(sourceFieldName, startOffset, startOffset + chunkAsByteReference.matchedText().length()), - chunkAsByteReference.bytesReference() - ) - ); + chunks.add(new Chunk(sourceFieldName, null, startOffset, startOffset + chunkAsByteReference.matchedText().length(), chunkAsByteReference.bytesReference())); } } return chunks; } - - /** - * This method converts the given {@code valueObj} into a list of strings. - * If {@code valueObj} is not a string or a collection of strings, it throws an ElasticsearchStatusException. 
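The helper documented above (moved verbatim into SemanticTextUtils later in this patch) flattens a field value into a single string before inference. A standalone sketch of that behavior follows; it assumes MULTIVAL_SEP_CHAR is a single separator character and substitutes IllegalArgumentException for the ElasticsearchStatusException thrown by the real helper.

import java.util.Collection;
import java.util.List;

public class NodeStringValuesSketch {
    // Simplified stand-in for the real nodeStringValues: numbers and booleans
    // are stringified, strings pass through, and collections are joined with
    // an assumed single-character multi-value separator.
    static String nodeStringValues(String field, Object valueObj) {
        if (valueObj instanceof Number || valueObj instanceof Boolean) {
            return valueObj.toString();
        } else if (valueObj instanceof String value) {
            return value;
        } else if (valueObj instanceof Collection<?> values) {
            StringBuilder sb = new StringBuilder();
            for (Object v : values) {
                if (sb.length() > 0) {
                    sb.append('\u0000'); // assumed value of MULTIVAL_SEP_CHAR
                }
                sb.append(v);
            }
            return sb.toString();
        }
        throw new IllegalArgumentException("Invalid format for field [" + field + "]");
    }

    public static void main(String[] args) {
        System.out.println(nodeStringValues("body", "single value"));
        // Multi-valued fields collapse into one string with separators:
        System.out.println(nodeStringValues("body", List.of("a", "b")).replace('\u0000', '|'));
    }
}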
- */ - public static String nodeStringValues(String field, Object valueObj) { - if (valueObj instanceof Number || valueObj instanceof Boolean) { - return valueObj.toString(); - } else if (valueObj instanceof String value) { - return value; - } else if (valueObj instanceof Collection values) { - return Strings.collectionToDelimitedString(values, String.valueOf(MULTIVAL_SEP_CHAR)); - } - throw new ElasticsearchStatusException( - "Invalid format for field [{}], expected [String] got [{}]", - RestStatus.BAD_REQUEST, - field, - valueObj.getClass().getSimpleName() - ); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 56ef3369ea352..68ef6dc788571 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -14,11 +14,11 @@ import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; @@ -30,6 +30,7 @@ import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.InferenceFieldMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.index.mapper.MapperBuilderContext; @@ -37,6 +38,7 @@ import org.elasticsearch.index.mapper.MappingLookup; import org.elasticsearch.index.mapper.NestedObjectMapper; import org.elasticsearch.index.mapper.ObjectMapper; +import org.elasticsearch.index.mapper.SimpleMappedFieldType; import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; @@ -50,12 +52,13 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentLocation; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; -import org.elasticsearch.xpack.inference.queries.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import java.io.IOException; import java.util.ArrayList; @@ -64,71 +67,79 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.function.Function; import static 
org.elasticsearch.search.SearchService.DEFAULT_SIZE; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_OFFSET_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.MODEL_SETTINGS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.SEARCH_INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.TEXT_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.DEFAULT_ELSER_ID; /** * A {@link FieldMapper} for semantic text fields. */ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFieldMapper { + public static final NodeFeature SEMANTIC_TEXT_SEARCH_INFERENCE_ID = new NodeFeature("semantic_text.search_inference_id"); + public static final NodeFeature SEMANTIC_TEXT_DEFAULT_ELSER_2 = new NodeFeature("semantic_text.default_elser_2"); + public static final NodeFeature SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX = new NodeFeature("semantic_text.in_object_field_fix"); + public static final NodeFeature SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX = new NodeFeature("semantic_text.single_field_update_fix"); + public static final NodeFeature SEMANTIC_TEXT_DELETE_FIX = new NodeFeature("semantic_text.delete_fix"); + public static final NodeFeature SEMANTIC_TEXT_ZERO_SIZE_FIX = new NodeFeature("semantic_text.zero_size_fix"); + public static final NodeFeature SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX = new NodeFeature( + "semantic_text.always_emit_inference_id_fix" + ); + public static final String CONTENT_TYPE = "semantic_text"; public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID; private final IndexSettings indexSettings; - public static final TypeParser PARSER = new TypeParser((n, c) -> { - if (c.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)) { - return new Builder(n, c::bitSetProducer, c.getIndexSettings()); - } - return new LegacySemanticTextFieldMapper.Builder(n, c.indexVersionCreated(), c::bitSetProducer, c.getIndexSettings()); - }, List.of(notInMultiFields(CONTENT_TYPE), notFromDynamicTemplates(CONTENT_TYPE))); + public static final TypeParser PARSER = new TypeParser( + (n, c) -> new Builder(n, c::bitSetProducer, c.getIndexSettings()), + List.of(notInMultiFields(CONTENT_TYPE), notFromDynamicTemplates(CONTENT_TYPE)) + ); public static class Builder extends FieldMapper.Builder { private final IndexSettings indexSettings; private final Parameter inferenceId = Parameter.stringParam( - SemanticTextField.INFERENCE_ID_FIELD, + INFERENCE_ID_FIELD, false, mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId, DEFAULT_ELSER_2_INFERENCE_ID ).addValidator(v -> { if (Strings.isEmpty(v)) { throw new IllegalArgumentException( - "[" - + SemanticTextField.INFERENCE_ID_FIELD - + "] on mapper [" - + 
leafName() - + "] of type [" - + CONTENT_TYPE - + "] must not be empty" + "[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must not be empty" ); } }).alwaysSerialize(); private final Parameter searchInferenceId = Parameter.stringParam( - SemanticTextField.SEARCH_INFERENCE_ID_FIELD, + SEARCH_INFERENCE_ID_FIELD, true, mapper -> ((SemanticTextFieldType) mapper.fieldType()).searchInferenceId, null ).acceptsNull().addValidator(v -> { if (v != null && Strings.isEmpty(v)) { throw new IllegalArgumentException( - "[" - + SemanticTextField.SEARCH_INFERENCE_ID_FIELD - + "] on mapper [" - + leafName() - + "] of type [" - + CONTENT_TYPE - + "] must not be empty" + "[" + SEARCH_INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must not be empty" ); } }); private final Parameter modelSettings = new Parameter<>( - SemanticTextField.MODEL_SETTINGS_FIELD, + MODEL_SETTINGS_FIELD, true, () -> null, (n, c, o) -> SemanticTextField.parseModelSettingsFromMap(o), @@ -139,18 +150,32 @@ public static class Builder extends FieldMapper.Builder { private final Parameter> meta = Parameter.metaParam(); - private Function chunksFieldBuilder; + private Function inferenceFieldBuilder; public static Builder from(SemanticTextFieldMapper mapper) { - Builder builder = new Builder(mapper.leafName(), mapper.fieldType().getChunksField().bitsetProducer(), mapper.indexSettings); + Builder builder = new Builder( + mapper.leafName(), + mapper.fieldType().getChunksField().bitsetProducer(), + mapper.indexSettings + ); builder.init(mapper); return builder; } - public Builder(String name, Function bitSetProducer, IndexSettings indexSettings) { + public Builder( + String name, + Function bitSetProducer, + IndexSettings indexSettings + ) { super(name); this.indexSettings = indexSettings; - this.chunksFieldBuilder = c -> createChunksField(modelSettings.get(), bitSetProducer, indexSettings).build(c); + this.inferenceFieldBuilder = c -> createInferenceField( + c, + indexSettings.getIndexVersionCreated(), + modelSettings.get(), + bitSetProducer, + indexSettings + ); } public Builder setInferenceId(String id) { @@ -181,9 +206,9 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont super.merge(semanticMergeWith, conflicts, mapperMergeContext); conflicts.check(); var context = mapperMergeContext.createChildContext(semanticMergeWith.leafName(), ObjectMapper.Dynamic.FALSE); - var inferenceField = chunksFieldBuilder.apply(context.getMapperBuilderContext()); - var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getChunksField(), context); - chunksFieldBuilder = c -> (NestedObjectMapper) mergedInferenceField; + var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext()); + var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), context); + inferenceFieldBuilder = c -> mergedInferenceField; } @Override @@ -200,7 +225,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { throw new IllegalArgumentException(CONTENT_TYPE + " field [" + fullName + "] cannot be nested"); } var childContext = context.createChildContext(leafName(), ObjectMapper.Dynamic.FALSE); - final NestedObjectMapper chunksField = chunksFieldBuilder.apply(childContext); + final ObjectMapper inferenceField = inferenceFieldBuilder.apply(childContext); return new SemanticTextFieldMapper( leafName(), @@ -209,7 +234,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext 
context) { inferenceId.getValue(), searchInferenceId.getValue(), modelSettings.getValue(), - chunksField, + inferenceField, + indexSettings.getIndexVersionCreated(), meta.getValue() ), builderParams(this, context), @@ -250,7 +276,7 @@ private SemanticTextFieldMapper( @Override public Iterator iterator() { List subIterators = new ArrayList<>(); - subIterators.add(fieldType().getChunksField()); + subIterators.add(fieldType().getInferenceField()); return subIterators.iterator(); } @@ -262,8 +288,8 @@ public FieldMapper.Builder getMergeBuilder() { @Override protected void parseCreateField(DocumentParserContext context) throws IOException { if (context.isWithinInferenceMetadata() == false) { - assert indexSettings.getIndexVersionCreated().onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS); // ignore original text value + context.parser().skipChildren(); return; } XContentParser parser = context.parser(); @@ -282,15 +308,15 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio } final String fullFieldName = fieldType().name(); - if (field.inferenceId().equals(fieldType().getInferenceId()) == false) { + if (field.inference().inferenceId().equals(fieldType().getInferenceId()) == false) { throw new DocumentParsingException( xContentLocation, Strings.format( "The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.", - SemanticTextField.INFERENCE_ID_FIELD, - field.inferenceId(), + INFERENCE_ID_FIELD, + field.inference().inferenceId(), fullFieldName, - SemanticTextField.INFERENCE_ID_FIELD, + INFERENCE_ID_FIELD, fieldType().getInferenceId() ) ); @@ -299,10 +325,14 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio final SemanticTextFieldMapper mapper; if (fieldType().getModelSettings() == null) { context.path().remove(); - Builder builder = (Builder) new Builder(leafName(), fieldType().getChunksField().bitsetProducer(), indexSettings).init(this); + Builder builder = (Builder) new Builder( + leafName(), + fieldType().getChunksField().bitsetProducer(), + indexSettings + ).init(this); try { - mapper = builder.setModelSettings(field.modelSettings()) - .setInferenceId(field.inferenceId()) + mapper = builder.setModelSettings(field.inference().modelSettings()) + .setInferenceId(field.inference().inferenceId()) .build(context.createDynamicMapperBuilderContext()); context.addDynamicMapper(mapper); } finally { @@ -310,7 +340,7 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio } } else { Conflicts conflicts = new Conflicts(fullFieldName); - canMergeModelSettings(fieldType().getModelSettings(), field.modelSettings(), conflicts); + canMergeModelSettings(fieldType().getModelSettings(), field.inference().modelSettings(), conflicts); try { conflicts.check(); } catch (Exception exc) { @@ -319,7 +349,7 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio "Incompatible model settings for field [" + fullPath() + "]. 
Check that the " - + SemanticTextField.INFERENCE_ID_FIELD + + INFERENCE_ID_FIELD + " is not using different model settings", exc ); @@ -330,7 +360,7 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio var chunksField = mapper.fieldType().getChunksField(); var embeddingsField = mapper.fieldType().getEmbeddingsField(); var offsetsField = mapper.fieldType().getOffsetsField(); - for (var chunk : field.chunks()) { + for (var chunk : field.inference().chunks()) { var nestedContext = context.createNestedContext(chunksField); try ( XContentParser subParser = XContentHelper.createParserNotCompressed( @@ -343,17 +373,23 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio subParser.nextToken(); embeddingsField.parse(subContext); } - - try ( - XContentParser subParser = XContentHelper.createParserNotCompressed( - XContentParserConfiguration.EMPTY, - new BytesArray(Strings.toString(chunk.offset())), - context.parser().contentType() - ) - ) { - DocumentParserContext subContext = nestedContext.switchParser(subParser); - subParser.nextToken(); - offsetsField.parse(subContext); + try (XContentBuilder builder = XContentFactory.contentBuilder(context.parser().contentType())) { + builder.startObject(); + builder.field("field", chunk.sourceField()); + builder.field("start", chunk.startOffset()); + builder.field("end", chunk.endOffset()); + builder.endObject(); + try ( + XContentParser subParser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + BytesReference.bytes(builder), + context.parser().contentType() + ) + ) { + DocumentParserContext subContext = nestedContext.switchParser(subParser); + subParser.nextToken(); + offsetsField.parse(subContext); + } } } } @@ -376,12 +412,6 @@ public InferenceFieldMetadata getMetadata(Set sourcePaths) { return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields); } - @Override - public Object getOriginalValue(Map sourceAsMap) { - Object ret = XContentMapValues.extractValue(fullPath(), sourceAsMap); - return SemanticTextField.nodeStringValues(fullPath(), ret); - } - @Override protected void doValidate(MappingLookup mappers) { int parentPathIndex = fullPath().lastIndexOf(leafName()); @@ -401,25 +431,28 @@ protected void doValidate(MappingLookup mappers) { } } - public static class SemanticTextFieldType extends AbstractSemanticTextFieldType { + public static class SemanticTextFieldType extends SimpleMappedFieldType { private final String inferenceId; private final String searchInferenceId; private final SemanticTextField.ModelSettings modelSettings; - private final NestedObjectMapper chunksField; + private final ObjectMapper inferenceField; + private final IndexVersion indexVersionCreated; public SemanticTextFieldType( String name, String inferenceId, String searchInferenceId, SemanticTextField.ModelSettings modelSettings, - NestedObjectMapper chunksField, + ObjectMapper inferenceField, + IndexVersion indexVersionCreated, Map meta ) { super(name, true, false, false, TextSearchInfo.NONE, meta); this.inferenceId = inferenceId; this.searchInferenceId = searchInferenceId; this.modelSettings = modelSettings; - this.chunksField = chunksField; + this.inferenceField = inferenceField; + this.indexVersionCreated = indexVersionCreated; } @Override @@ -439,16 +472,20 @@ public SemanticTextField.ModelSettings getModelSettings() { return modelSettings; } - public NestedObjectMapper getChunksField() { - return chunksField; + public 
ObjectMapper getInferenceField() { + return inferenceField; } - public FieldMapper getOffsetsField() { - return (FieldMapper) getChunksField().getMapper(SemanticTextField.CHUNKED_OFFSET_FIELD); + public NestedObjectMapper getChunksField() { + return (NestedObjectMapper) inferenceField.getMapper(CHUNKS_FIELD); } public FieldMapper getEmbeddingsField() { - return (FieldMapper) getChunksField().getMapper(SemanticTextField.CHUNKED_EMBEDDINGS_FIELD); + return (FieldMapper) getChunksField().getMapper(CHUNKED_EMBEDDINGS_FIELD); + } + + public FieldMapper getOffsetsField() { + return (FieldMapper) getChunksField().getMapper(CHUNKED_OFFSET_FIELD); } @Override @@ -464,7 +501,7 @@ public Query existsQuery(SearchExecutionContext context) { return NestedQueryBuilder.toQuery( (c -> getEmbeddingsField().fieldType().existsQuery(c)), - SemanticTextField.getChunksFieldName(name()), + getChunksFieldName(name()), ScoreMode.None, false, context @@ -473,7 +510,12 @@ public Query existsQuery(SearchExecutionContext context) { @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.toString(name(), context, format); + if (indexVersionCreated.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)) { + return SourceValueFetcher.toString(name(), context, format); + } else { + // Redirect the fetcher to load the original values of the field + return SourceValueFetcher.toString(getOriginalTextFieldName(name()), context, format); + } } @Override @@ -483,12 +525,12 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext @Override public boolean fieldHasValue(FieldInfos fieldInfos) { - return fieldInfos.fieldInfo(SemanticTextField.getEmbeddingsFieldName(name())) != null; + return fieldInfos.fieldInfo(getEmbeddingsFieldName(name())) != null; } public QueryBuilder semanticQuery(InferenceResults inferenceResults, Integer requestSize, float boost, String queryName) { - String nestedFieldPath = SemanticTextField.getChunksFieldName(name()); - String inferenceResultsFieldName = SemanticTextField.getEmbeddingsFieldName(name()); + String nestedFieldPath = getChunksFieldName(name()); + String inferenceResultsFieldName = getEmbeddingsFieldName(name()); QueryBuilder childQueryBuilder; if (modelSettings == null) { @@ -601,7 +643,20 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { } } + private static ObjectMapper createInferenceField( + MapperBuilderContext context, + IndexVersion indexVersionCreated, + @Nullable SemanticTextField.ModelSettings modelSettings, + Function bitSetProducer, + IndexSettings indexSettings + ) { + return new ObjectMapper.Builder(INFERENCE_FIELD, Optional.of(ObjectMapper.Subobjects.ENABLED)).dynamic(ObjectMapper.Dynamic.FALSE) + .add(createChunksField(indexVersionCreated, modelSettings, bitSetProducer, indexSettings)) + .build(context); + } + private static NestedObjectMapper.Builder createChunksField( + IndexVersion indexVersionCreated, @Nullable SemanticTextField.ModelSettings modelSettings, Function bitSetProducer, IndexSettings indexSettings @@ -613,20 +668,24 @@ private static NestedObjectMapper.Builder createChunksField( indexSettings ); chunksField.dynamic(ObjectMapper.Dynamic.FALSE); - if (modelSettings != null) { chunksField.add(createEmbeddingsField(indexSettings.getIndexVersionCreated(), modelSettings)); } - chunksField.add(new OffsetSourceFieldMapper.Builder(SemanticTextField.CHUNKED_OFFSET_FIELD)); + if (indexVersionCreated.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)) { + 
chunksField.add(new OffsetSourceFieldMapper.Builder(CHUNKED_OFFSET_FIELD)); + } else { + var chunkTextField = new KeywordFieldMapper.Builder(TEXT_FIELD, indexVersionCreated).indexed(false).docValues(false); + chunksField.add(chunkTextField); + } return chunksField; } private static Mapper.Builder createEmbeddingsField(IndexVersion indexVersionCreated, SemanticTextField.ModelSettings modelSettings) { return switch (modelSettings.taskType()) { - case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(SemanticTextField.CHUNKED_EMBEDDINGS_FIELD); + case SPARSE_EMBEDDING -> new SparseVectorFieldMapper.Builder(CHUNKED_EMBEDDINGS_FIELD); case TEXT_EMBEDDING -> { DenseVectorFieldMapper.Builder denseVectorMapperBuilder = new DenseVectorFieldMapper.Builder( - SemanticTextField.CHUNKED_EMBEDDINGS_FIELD, + CHUNKED_EMBEDDINGS_FIELD, indexVersionCreated ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextUtils.java new file mode 100644 index 0000000000000..90814e02d128e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextUtils.java @@ -0,0 +1,152 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.rest.RestStatus; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR; + +public interface SemanticTextUtils { + /** + * This method converts the given {@code valueObj} into a list of strings. + * If {@code valueObj} is not a string or a collection of strings, it throws an ElasticsearchStatusException. + */ + static String nodeStringValues(String field, Object valueObj) { + if (valueObj instanceof Number || valueObj instanceof Boolean) { + return valueObj.toString(); + } else if (valueObj instanceof String value) { + return value; + } else if (valueObj instanceof Collection values) { + return Strings.collectionToDelimitedString(values, String.valueOf(MULTIVAL_SEP_CHAR)); + } + throw new ElasticsearchStatusException( + "Invalid format for field [{}], expected [String] got [{}]", + RestStatus.BAD_REQUEST, + field, + valueObj.getClass().getSimpleName() + ); + } + + /** + *

+ * Insert or replace the path's value in the map with the provided new value. The map will be modified in-place. + * If the complete path does not exist in the map, it will be added to the deepest (sub-)map possible. + *

+ *

+ * For example, given the map: + *

+ * <pre>
+   * {
+   *   "path1": {
+   *     "path2": {
+   *       "key1": "value1"
+   *     }
+   *   }
+   * }
+   * </pre>
+ *

+ * And the caller wanted to insert {@code "path1.path2.path3.key2": "value2"}, the method would emit the modified map: + *

+ * <pre>
+   * {
+   *   "path1": {
+   *     "path2": {
+   *       "key1": "value1",
+   *       "path3.key2": "value2"
+   *     }
+   *   }
+   * }
+   * </pre>
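The two examples above carry over unchanged from the legacy mapper; the sketch below exercises them end to end. It assumes the SemanticTextUtils class added by this patch is on the classpath, and that a path reachable through more than one sub-map (for example via a list of maps) is rejected with the IllegalArgumentException documented below.

import java.util.HashMap;
import java.util.Map;

public class InsertValueSketch {
    public static void main(String[] args) {
        // Rebuild the nested map from the javadoc example: path1 -> path2 -> key1.
        Map<String, Object> path2 = new HashMap<>(Map.of("key1", "value1"));
        Map<String, Object> path1 = new HashMap<>(Map.of("path2", path2));
        Map<String, Object> map = new HashMap<>(Map.of("path1", path1));

        // Only "path1.path2" exists, so the remaining suffix is stored as the
        // flattened key "path3.key2" in the deepest matching sub-map.
        SemanticTextUtils.insertValue("path1.path2.path3.key2", map, "value2");
        System.out.println(map); // {path1={path2={key1=value1, path3.key2=value2}}} (map order may vary)
    }
}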
+ * + * @param path the value's path in the map. + * @param map the map to search and modify in-place. + * @param newValue the new value to assign to the path. + * + * @throws IllegalArgumentException If either the path cannot be fully traversed or there is ambiguity about where to insert the new + * value. + */ + static void insertValue(String path, Map map, Object newValue) { + String[] pathElements = path.split("\\."); + if (pathElements.length == 0) { + return; + } + + List suffixMaps = extractSuffixMaps(pathElements, 0, map); + if (suffixMaps.isEmpty()) { + // This should never happen. Throw in case it does for some reason. + throw new IllegalStateException("extractSuffixMaps returned an empty suffix map list"); + } else if (suffixMaps.size() == 1) { + SuffixMap suffixMap = suffixMaps.getFirst(); + suffixMap.map().put(suffixMap.suffix(), newValue); + } else { + throw new IllegalArgumentException( + "Path [" + path + "] could be inserted in " + suffixMaps.size() + " distinct ways, it is ambiguous which one to use" + ); + } + } + + record SuffixMap(String suffix, Map map) {} + + private static List extractSuffixMaps(String[] pathElements, int index, Object currentValue) { + if (currentValue instanceof List valueList) { + List suffixMaps = new ArrayList<>(valueList.size()); + for (Object o : valueList) { + suffixMaps.addAll(extractSuffixMaps(pathElements, index, o)); + } + + return suffixMaps; + } else if (currentValue instanceof Map) { + @SuppressWarnings("unchecked") + Map map = (Map) currentValue; + List suffixMaps = new ArrayList<>(map.size()); + + String key = pathElements[index]; + while (index < pathElements.length) { + if (map.containsKey(key)) { + if (index + 1 == pathElements.length) { + // We found the complete path + suffixMaps.add(new SuffixMap(key, map)); + } else { + // We've matched that path partially, keep traversing to try to match it fully + suffixMaps.addAll(extractSuffixMaps(pathElements, index + 1, map.get(key))); + } + } + + if (++index < pathElements.length) { + key += "." + pathElements[index]; + } + } + + if (suffixMaps.isEmpty()) { + // We checked for all remaining elements in the path, and they do not exist. This means we found a leaf map that we should + // add the value to. 
+ suffixMaps.add(new SuffixMap(key, map)); + } + + return suffixMaps; + } else { + throw new IllegalArgumentException( + "Path [" + + String.join(".", Arrays.copyOfRange(pathElements, 0, index)) + + "] has value [" + + currentValue + + "] of type [" + + currentValue.getClass().getSimpleName() + + "], which cannot be traversed into further" + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 46c6da34a8f33..d648db2fbfdbc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -37,7 +37,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; -import org.elasticsearch.xpack.inference.mapper.AbstractSemanticTextFieldType; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import java.io.IOException; import java.util.Collection; @@ -162,7 +162,7 @@ private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchEx MappedFieldType fieldType = searchExecutionContext.getFieldType(fieldName); if (fieldType == null) { return new MatchNoneQueryBuilder(); - } else if (fieldType instanceof AbstractSemanticTextFieldType semanticTextFieldType) { + } else if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) { if (inferenceResults == null) { // This should never happen, but throw on it in case it ever does throw new IllegalStateException( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldMapperTests.java deleted file mode 100644 index 3e1b5d7e611d1..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldMapperTests.java +++ /dev/null @@ -1,1227 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.apache.lucene.document.FeatureField; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FieldInfos; -import org.apache.lucene.index.IndexableField; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.MatchNoDocsQuery; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.QueryBitSetProducer; -import org.apache.lucene.search.join.ScoreMode; -import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.common.CheckedBiConsumer; -import org.elasticsearch.common.CheckedBiFunction; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.compress.CompressedXContent; -import org.elasticsearch.common.lucene.search.Queries; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; -import org.elasticsearch.index.mapper.DocumentMapper; -import org.elasticsearch.index.mapper.DocumentParsingException; -import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.KeywordFieldMapper; -import org.elasticsearch.index.mapper.LuceneDocument; -import org.elasticsearch.index.mapper.MappedFieldType; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MapperParsingException; -import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.MapperTestCase; -import org.elasticsearch.index.mapper.NestedLookup; -import org.elasticsearch.index.mapper.NestedObjectMapper; -import org.elasticsearch.index.mapper.ParsedDocument; -import org.elasticsearch.index.mapper.SourceToParse; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; -import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.index.search.ESToParentBlockJoinQuery; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.search.LeafNestedDocuments; -import org.elasticsearch.search.NestedDocuments; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.test.index.IndexVersionUtils; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.model.TestModel; -import org.junit.AssumptionViolatedException; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.function.BiConsumer; -import java.util.stream.Stream; - -import 
static java.util.Collections.singletonList; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.CHUNKED_EMBEDDINGS_FIELD; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.CHUNKED_TEXT_FIELD; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.CHUNKS_FIELD; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.INFERENCE_FIELD; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.INFERENCE_ID_FIELD; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.MODEL_SETTINGS_FIELD; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.SEARCH_INFERENCE_ID_FIELD; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.getChunksFieldName; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextField.getEmbeddingsFieldName; -import static org.elasticsearch.xpack.inference.mapper.LegacySemanticTextFieldTests.randomSemanticText; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_ELSER_2_INFERENCE_ID; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; - -public class LegacySemanticTextFieldMapperTests extends MapperTestCase { - @Override - protected Collection getPlugins() { - return singletonList(new InferencePlugin(Settings.EMPTY)); - } - - @Override - protected Settings getIndexSettings() { - return Settings.builder() - .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersionUtils.getPreviousVersion(IndexVersions.INFERENCE_METADATA_FIELDS)) - .build(); - } - - @Override - protected IndexVersion getVersion() { - return IndexVersionUtils.getPreviousVersion(IndexVersions.INFERENCE_METADATA_FIELDS); - } - - @Override - protected void minimalMapping(XContentBuilder b) throws IOException { - b.field("type", "semantic_text"); - } - - @Override - protected String minimalIsInvalidRoutingPathErrorMessage(Mapper mapper) { - return "cannot have nested fields when index is in [index.mode=time_series]"; - } - - @Override - protected void metaMapping(XContentBuilder b) throws IOException { - super.metaMapping(b); - b.field(INFERENCE_ID_FIELD, DEFAULT_ELSER_2_INFERENCE_ID); - } - - @Override - protected Object getSampleValueForDocument() { - return null; - } - - @Override - protected boolean supportsIgnoreMalformed() { - return false; - } - - @Override - protected boolean supportsStoredFields() { - return false; - } - - @Override - protected void registerParameters(ParameterChecker checker) throws IOException {} - - @Override - protected Object generateRandomInputValue(MappedFieldType ft) { - assumeFalse("doc_values are not supported in semantic_text", true); - return null; - } - - @Override - protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { - throw new AssumptionViolatedException("not supported"); - } - - @Override - protected IngestScriptSupport ingestScriptSupport() { - throw new AssumptionViolatedException("not supported"); - } - - @Override - public MappedFieldType getMappedFieldType() { - return new LegacySemanticTextFieldMapper.LegacySemanticTextFieldType( - "field", - "fake-inference-id", - null, - null, - null, - IndexVersionUtils.randomVersionBetween( - random(), - IndexVersionUtils.getFirstVersion(), - IndexVersionUtils.getPreviousVersion(IndexVersions.INFERENCE_METADATA_FIELDS) - ), - Map.of() - ); - } - - 
@Override - protected void assertSearchable(MappedFieldType fieldType) { - assertThat(fieldType, instanceOf(LegacySemanticTextFieldMapper.LegacySemanticTextFieldType.class)); - assertTrue(fieldType.isIndexed()); - assertTrue(fieldType.isSearchable()); - } - - public void testDefaults() throws Exception { - final String fieldName = "field"; - final XContentBuilder fieldMapping = fieldMapping(this::minimalMapping); - final XContentBuilder expectedMapping = fieldMapping(this::metaMapping); - - MapperService mapperService = createMapperService(fieldMapping); - DocumentMapper mapper = mapperService.documentMapper(); - assertEquals(Strings.toString(expectedMapping), mapper.mappingSource().toString()); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID); - - ParsedDocument doc1 = mapper.parse(source(this::writeField)); - List fields = doc1.rootDoc().getFields("field"); - - // No indexable fields - assertTrue(fields.isEmpty()); - } - - @Override - public void testFieldHasValue() { - MappedFieldType fieldType = getMappedFieldType(); - FieldInfos fieldInfos = new FieldInfos(new FieldInfo[] { getFieldInfoWithName(getEmbeddingsFieldName("field")) }); - assertTrue(fieldType.fieldHasValue(fieldInfos)); - } - - public void testSetInferenceEndpoints() throws IOException { - final String fieldName = "field"; - final String inferenceId = "foo"; - final String searchInferenceId = "bar"; - - CheckedBiConsumer assertSerialization = (expectedMapping, mapperService) -> { - DocumentMapper mapper = mapperService.documentMapper(); - assertEquals(Strings.toString(expectedMapping), mapper.mappingSource().toString()); - }; - - { - final XContentBuilder fieldMapping = fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, inferenceId)); - final MapperService mapperService = createMapperService(fieldMapping); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - assertSerialization.accept(fieldMapping, mapperService); - } - { - final XContentBuilder fieldMapping = fieldMapping( - b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) - ); - final XContentBuilder expectedMapping = fieldMapping( - b -> b.field("type", "semantic_text") - .field(INFERENCE_ID_FIELD, DEFAULT_ELSER_2_INFERENCE_ID) - .field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) - ); - final MapperService mapperService = createMapperService(fieldMapping); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId); - assertSerialization.accept(expectedMapping, mapperService); - } - { - final XContentBuilder fieldMapping = fieldMapping( - b -> b.field("type", "semantic_text") - .field(INFERENCE_ID_FIELD, inferenceId) - .field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) - ); - MapperService mapperService = createMapperService(fieldMapping); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId); - assertSerialization.accept(fieldMapping, mapperService); - } - } - - public void testInvalidInferenceEndpoints() { - { - Exception e = expectThrows( - MapperParsingException.class, - () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, (String) null))) - ); - 
assertThat( - e.getMessage(), - containsString("[inference_id] on mapper [field] of type [semantic_text] must not have a [null] value") - ); - } - { - Exception e = expectThrows( - MapperParsingException.class, - () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, ""))) - ); - assertThat(e.getMessage(), containsString("[inference_id] on mapper [field] of type [semantic_text] must not be empty")); - } - { - Exception e = expectThrows( - MapperParsingException.class, - () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, ""))) - ); - assertThat(e.getMessage(), containsString("[search_inference_id] on mapper [field] of type [semantic_text] must not be empty")); - } - } - - public void testCannotBeUsedInMultiFields() { - Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { - b.field("type", "text"); - b.startObject("fields"); - b.startObject("semantic"); - b.field("type", "semantic_text"); - b.field("inference_id", "my_inference_id"); - b.endObject(); - b.endObject(); - }))); - assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); - } - - public void testUpdatesToInferenceIdNotSupported() throws IOException { - String fieldName = randomAlphaOfLengthBetween(5, 15); - MapperService mapperService = createMapperService( - mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) - ); - assertSemanticTextField(mapperService, fieldName, false); - Exception e = expectThrows( - IllegalArgumentException.class, - () -> merge( - mapperService, - mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "another_model").endObject()) - ) - ); - assertThat(e.getMessage(), containsString("Cannot update parameter [inference_id] from [test_model] to [another_model]")); - } - - public void testDynamicUpdate() throws IOException { - final String fieldName = "semantic"; - final String inferenceId = "test_service"; - final String searchInferenceId = "search_test_service"; - - { - MapperService mapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - new LegacySemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) - ); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - } - - { - MapperService mapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - searchInferenceId, - new LegacySemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) - ); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId); - } - } - - public void testUpdateModelSettings() throws IOException { - for (int depth = 1; depth < 5; depth++) { - String fieldName = randomFieldName(depth); - MapperService mapperService = createMapperService( - mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) - ); - assertSemanticTextField(mapperService, fieldName, false); - { - Exception exc = expectThrows( - MapperParsingException.class, - () -> merge( - mapperService, - mapping( - b -> b.startObject(fieldName) - .field("type", "semantic_text") - .field("inference_id", "test_model") - .startObject("model_settings") - 
.field("inference_id", "test_model") - .endObject() - .endObject() - ) - ) - ); - assertThat(exc.getMessage(), containsString("Required [task_type]")); - } - { - merge( - mapperService, - mapping( - b -> b.startObject(fieldName) - .field("type", "semantic_text") - .field("inference_id", "test_model") - .startObject("model_settings") - .field("task_type", "sparse_embedding") - .endObject() - .endObject() - ) - ); - assertSemanticTextField(mapperService, fieldName, true); - } - { - merge( - mapperService, - mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()) - ); - assertSemanticTextField(mapperService, fieldName, true); - } - { - Exception exc = expectThrows( - IllegalArgumentException.class, - () -> merge( - mapperService, - mapping( - b -> b.startObject(fieldName) - .field("type", "semantic_text") - .field("inference_id", "test_model") - .startObject("model_settings") - .field("task_type", "text_embedding") - .field("dimensions", 10) - .field("similarity", "cosine") - .field("element_type", "float") - .endObject() - .endObject() - ) - ) - ); - assertThat( - exc.getMessage(), - containsString( - "Cannot update parameter [model_settings] " - + "from [task_type=sparse_embedding] " - + "to [task_type=text_embedding, dimensions=10, similarity=cosine, element_type=float]" - ) - ); - } - } - } - - public void testUpdateSearchInferenceId() throws IOException { - final String inferenceId = "test_inference_id"; - final String searchInferenceId1 = "test_search_inference_id_1"; - final String searchInferenceId2 = "test_search_inference_id_2"; - - CheckedBiFunction buildMapping = (f, sid) -> mapping(b -> { - b.startObject(f).field("type", "semantic_text").field("inference_id", inferenceId); - if (sid != null) { - b.field("search_inference_id", sid); - } - b.endObject(); - }); - - for (int depth = 1; depth < 5; depth++) { - String fieldName = randomFieldName(depth); - MapperService mapperService = createMapperService(buildMapping.apply(fieldName, null)); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - - merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); - - merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2); - - merge(mapperService, buildMapping.apply(fieldName, null)); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - - mapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - new LegacySemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) - ); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - - merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); - - merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, 
fieldName, inferenceId, searchInferenceId2); - - merge(mapperService, buildMapping.apply(fieldName, null)); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - } - } - - private static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { - Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); - assertNotNull(mapper); - assertThat(mapper, instanceOf(LegacySemanticTextFieldMapper.class)); - LegacySemanticTextFieldMapper semanticFieldMapper = (LegacySemanticTextFieldMapper) mapper; - - var fieldType = mapperService.fieldType(fieldName); - assertNotNull(fieldType); - assertThat(fieldType, instanceOf(LegacySemanticTextFieldMapper.LegacySemanticTextFieldType.class)); - LegacySemanticTextFieldMapper.LegacySemanticTextFieldType semanticTextFieldType = - (LegacySemanticTextFieldMapper.LegacySemanticTextFieldType) fieldType; - assertTrue(semanticFieldMapper.fieldType() == semanticTextFieldType); - - NestedObjectMapper chunksMapper = mapperService.mappingLookup() - .nestedLookup() - .getNestedMappers() - .get(getChunksFieldName(fieldName)); - assertThat(chunksMapper, equalTo(semanticFieldMapper.fieldType().getChunksField())); - assertThat(chunksMapper.fullPath(), equalTo(getChunksFieldName(fieldName))); - Mapper textMapper = chunksMapper.getMapper(CHUNKED_TEXT_FIELD); - assertNotNull(textMapper); - assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); - KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; - assertFalse(textFieldMapper.fieldType().isIndexed()); - assertFalse(textFieldMapper.fieldType().hasDocValues()); - if (expectedModelSettings) { - assertNotNull(semanticFieldMapper.fieldType().getModelSettings()); - Mapper embeddingsMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD); - assertNotNull(embeddingsMapper); - assertThat(embeddingsMapper, instanceOf(FieldMapper.class)); - FieldMapper embeddingsFieldMapper = (FieldMapper) embeddingsMapper; - assertTrue(embeddingsFieldMapper.fieldType() == mapperService.mappingLookup().getFieldType(getEmbeddingsFieldName(fieldName))); - assertThat(embeddingsMapper.fullPath(), equalTo(getEmbeddingsFieldName(fieldName))); - switch (semanticFieldMapper.fieldType().getModelSettings().taskType()) { - case SPARSE_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(SparseVectorFieldMapper.class)); - case TEXT_EMBEDDING -> assertThat(embeddingsMapper, instanceOf(DenseVectorFieldMapper.class)); - default -> throw new AssertionError("Invalid task type"); - } - } else { - assertNull(semanticFieldMapper.fieldType().getModelSettings()); - } - } - - private static void assertInferenceEndpoints( - MapperService mapperService, - String fieldName, - String expectedInferenceId, - String expectedSearchInferenceId - ) { - var fieldType = mapperService.fieldType(fieldName); - assertNotNull(fieldType); - assertThat(fieldType, instanceOf(LegacySemanticTextFieldMapper.LegacySemanticTextFieldType.class)); - LegacySemanticTextFieldMapper.LegacySemanticTextFieldType semanticTextFieldType = - (LegacySemanticTextFieldMapper.LegacySemanticTextFieldType) fieldType; - assertEquals(expectedInferenceId, semanticTextFieldType.getInferenceId()); - assertEquals(expectedSearchInferenceId, semanticTextFieldType.getSearchInferenceId()); - } - - public void testSuccessfulParse() throws IOException { - for (int depth = 1; depth < 4; depth++) { - final String fieldName1 = randomFieldName(depth); - final String 
fieldName2 = randomFieldName(depth + 1); - final String searchInferenceId = randomAlphaOfLength(8); - final boolean setSearchInferenceId = randomBoolean(); - - Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - XContentBuilder mapping = mapping(b -> { - addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); - addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); - }); - - MapperService mapperService = createMapperService(mapping); - assertSemanticTextField(mapperService, fieldName1, false); - assertInferenceEndpoints( - mapperService, - fieldName1, - model1.getInferenceEntityId(), - setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId() - ); - assertSemanticTextField(mapperService, fieldName2, false); - assertInferenceEndpoints( - mapperService, - fieldName2, - model2.getInferenceEntityId(), - setSearchInferenceId ? searchInferenceId : model2.getInferenceEntityId() - ); - - DocumentMapper documentMapper = mapperService.documentMapper(); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - b, - List.of( - randomSemanticText(fieldName1, model1, List.of("a b", "c"), XContentType.JSON), - randomSemanticText(fieldName2, model2, List.of("d e f"), XContentType.JSON) - ) - ) - ) - ); - - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - for (int i = 0; i < 3; i++) { - assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); - } - // nested docs are in reversed order - assertSparseFeatures(luceneDocs.get(0), getEmbeddingsFieldName(fieldName1), 2); - assertSparseFeatures(luceneDocs.get(1), getEmbeddingsFieldName(fieldName1), 1); - assertSparseFeatures(luceneDocs.get(2), getEmbeddingsFieldName(fieldName2), 3); - assertEquals(doc.rootDoc(), luceneDocs.get(3)); - assertNull(luceneDocs.get(3).getParent()); - - withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { - NestedDocuments nested = new NestedDocuments( - mapperService.mappingLookup(), - QueryBitSetProducer::new, - IndexVersion.current() - ); - LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); - - Set visitedNestedIdentities = new HashSet<>(); - Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 0, null), - new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 1, null), - new SearchHit.NestedIdentity(getChunksFieldName(fieldName2), 0, null) - ); - - assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); - assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); - - assertNull(leaf.advance(3)); - assertEquals(3, leaf.doc()); - assertEquals(3, leaf.rootDoc()); - assertNull(leaf.nestedIdentity()); - - IndexSearcher searcher = newSearcher(reader); - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), - 10 - ); - assertEquals(1, topDocs.totalHits.value()); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), - 10 
- ); - assertEquals(1, topDocs.totalHits.value()); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), - 10 - ); - assertEquals(1, topDocs.totalHits.value()); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), - 10 - ); - assertEquals(0, topDocs.totalHits.value()); - } - }); - } - } - - public void testMissingInferenceId() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null))); - IllegalArgumentException ex = expectThrows( - DocumentParsingException.class, - IllegalArgumentException.class, - () -> documentMapper.parse( - source( - b -> b.startObject("field") - .startObject(INFERENCE_FIELD) - .field(MODEL_SETTINGS_FIELD, new LegacySemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)) - .field(CHUNKS_FIELD, List.of()) - .endObject() - .endObject() - ) - ) - ); - assertThat(ex.getCause().getMessage(), containsString("Required [inference_id]")); - } - - public void testMissingModelSettings() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null))); - IllegalArgumentException ex = expectThrows( - DocumentParsingException.class, - IllegalArgumentException.class, - () -> documentMapper.parse( - source(b -> b.startObject("field").startObject(INFERENCE_FIELD).field(INFERENCE_ID_FIELD, "my_id").endObject().endObject()) - ) - ); - assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); - } - - public void testMissingTaskType() throws IOException { - DocumentMapper documentMapper = createDocumentMapper(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null))); - IllegalArgumentException ex = expectThrows( - DocumentParsingException.class, - IllegalArgumentException.class, - () -> documentMapper.parse( - source( - b -> b.startObject("field") - .startObject(INFERENCE_FIELD) - .field(INFERENCE_ID_FIELD, "my_id") - .startObject(MODEL_SETTINGS_FIELD) - .endObject() - .endObject() - .endObject() - ) - ) - ); - assertThat(ex.getCause().getMessage(), containsString("failed to parse field [model_settings]")); - } - - public void testDenseVectorElementType() throws IOException { - final String fieldName = "field"; - final String inferenceId = "test_service"; - - BiConsumer assertMapperService = (m, e) -> { - Mapper mapper = m.mappingLookup().getMapper(fieldName); - assertThat(mapper, instanceOf(LegacySemanticTextFieldMapper.class)); - LegacySemanticTextFieldMapper semanticTextFieldMapper = (LegacySemanticTextFieldMapper) mapper; - assertThat(semanticTextFieldMapper.fieldType().getModelSettings().elementType(), equalTo(e)); - }; - - MapperService floatMapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - new LegacySemanticTextField.ModelSettings( - TaskType.TEXT_EMBEDDING, - 1024, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ) - ); - assertMapperService.accept(floatMapperService, DenseVectorFieldMapper.ElementType.FLOAT); - - MapperService byteMapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - new LegacySemanticTextField.ModelSettings( - TaskType.TEXT_EMBEDDING, - 1024, - 
SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.BYTE - ) - ); - assertMapperService.accept(byteMapperService, DenseVectorFieldMapper.ElementType.BYTE); - } - - private MapperService mapperServiceForFieldWithModelSettings( - String fieldName, - String inferenceId, - LegacySemanticTextField.ModelSettings modelSettings - ) throws IOException { - return mapperServiceForFieldWithModelSettings(fieldName, inferenceId, null, modelSettings); - } - - private MapperService mapperServiceForFieldWithModelSettings( - String fieldName, - String inferenceId, - String searchInferenceId, - LegacySemanticTextField.ModelSettings modelSettings - ) throws IOException { - String mappingParams = "type=semantic_text,inference_id=" + inferenceId; - if (searchInferenceId != null) { - mappingParams += ",search_inference_id=" + searchInferenceId; - } - - MapperService mapperService = createMapperService(mapping(b -> {})); - mapperService.merge( - "_doc", - new CompressedXContent(Strings.toString(PutMappingRequest.simpleMapping(fieldName, mappingParams))), - MapperService.MergeReason.MAPPING_UPDATE - ); - - LegacySemanticTextField semanticTextField = new LegacySemanticTextField( - fieldName, - List.of(), - new LegacySemanticTextField.InferenceResult(inferenceId, modelSettings, List.of()), - XContentType.JSON - ); - XContentBuilder builder = JsonXContent.contentBuilder().startObject(); - builder.field(semanticTextField.fieldName()); - builder.value(semanticTextField); - builder.endObject(); - - SourceToParse sourceToParse = new SourceToParse("test", BytesReference.bytes(builder), XContentType.JSON); - ParsedDocument parsedDocument = mapperService.documentMapper().parse(sourceToParse); - mapperService.merge( - "_doc", - parsedDocument.dynamicMappingsUpdate().toCompressedXContent(), - MapperService.MergeReason.MAPPING_UPDATE - ); - return mapperService; - } - - public void testExistsQuerySparseVector() throws IOException { - final String fieldName = "semantic"; - final String inferenceId = "test_service"; - - MapperService mapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - new LegacySemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) - ); - - Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); - assertNotNull(mapper); - SearchExecutionContext searchExecutionContext = createSearchExecutionContext(mapperService); - Query existsQuery = ((LegacySemanticTextFieldMapper) mapper).fieldType().existsQuery(searchExecutionContext); - assertThat(existsQuery, instanceOf(ESToParentBlockJoinQuery.class)); - } - - public void testExistsQueryDenseVector() throws IOException { - final String fieldName = "semantic"; - final String inferenceId = "test_service"; - - MapperService mapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - new LegacySemanticTextField.ModelSettings( - TaskType.TEXT_EMBEDDING, - 1024, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ) - ); - - Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); - assertNotNull(mapper); - SearchExecutionContext searchExecutionContext = createSearchExecutionContext(mapperService); - Query existsQuery = ((LegacySemanticTextFieldMapper) mapper).fieldType().existsQuery(searchExecutionContext); - assertThat(existsQuery, instanceOf(ESToParentBlockJoinQuery.class)); - } - - public void testInsertValueMapTraversal() throws IOException { - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("test", 
"value").endObject(); - - Map map = toSourceMap(Strings.toString(builder)); - LegacySemanticTextFieldMapper.insertValue("test", map, "value2"); - assertThat(getMapValue(map, "test"), equalTo("value2")); - LegacySemanticTextFieldMapper.insertValue("something.else", map, "something_else_value"); - assertThat(getMapValue(map, "something\\.else"), equalTo("something_else_value")); - } - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.startObject("path1").startObject("path2").field("test", "value").endObject().endObject(); - builder.endObject(); - - Map map = toSourceMap(Strings.toString(builder)); - LegacySemanticTextFieldMapper.insertValue("path1.path2.test", map, "value2"); - assertThat(getMapValue(map, "path1.path2.test"), equalTo("value2")); - LegacySemanticTextFieldMapper.insertValue("path1.path2.test_me", map, "test_me_value"); - assertThat(getMapValue(map, "path1.path2.test_me"), equalTo("test_me_value")); - LegacySemanticTextFieldMapper.insertValue("path1.non_path2.test", map, "test_value"); - assertThat(getMapValue(map, "path1.non_path2\\.test"), equalTo("test_value")); - - LegacySemanticTextFieldMapper.insertValue("path1.path2", map, Map.of("path3", "bar")); - assertThat(getMapValue(map, "path1.path2"), equalTo(Map.of("path3", "bar"))); - - LegacySemanticTextFieldMapper.insertValue("path1", map, "baz"); - assertThat(getMapValue(map, "path1"), equalTo("baz")); - - LegacySemanticTextFieldMapper.insertValue("path3.path4", map, Map.of("test", "foo")); - assertThat(getMapValue(map, "path3\\.path4"), equalTo(Map.of("test", "foo"))); - } - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - builder.startObject("path1").array("test", "value1", "value2").endObject(); - builder.endObject(); - Map map = toSourceMap(Strings.toString(builder)); - - LegacySemanticTextFieldMapper.insertValue("path1.test", map, List.of("value3", "value4", "value5")); - assertThat(getMapValue(map, "path1.test"), equalTo(List.of("value3", "value4", "value5"))); - - LegacySemanticTextFieldMapper.insertValue("path2.test", map, List.of("value6", "value7", "value8")); - assertThat(getMapValue(map, "path2\\.test"), equalTo(List.of("value6", "value7", "value8"))); - } - } - - public void testInsertValueListTraversal() throws IOException { - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - { - builder.startObject("path1"); - { - builder.startArray("path2"); - builder.startObject().field("test", "value1").endObject(); - builder.endArray(); - } - builder.endObject(); - } - { - builder.startObject("path3"); - { - builder.startArray("path4"); - builder.startObject().field("test", "value1").endObject(); - builder.endArray(); - } - builder.endObject(); - } - builder.endObject(); - Map map = toSourceMap(Strings.toString(builder)); - - LegacySemanticTextFieldMapper.insertValue("path1.path2.test", map, "value2"); - assertThat(getMapValue(map, "path1.path2.test"), equalTo("value2")); - LegacySemanticTextFieldMapper.insertValue("path1.path2.test2", map, "value3"); - assertThat(getMapValue(map, "path1.path2.test2"), equalTo("value3")); - assertThat(getMapValue(map, "path1.path2"), equalTo(List.of(Map.of("test", "value2", "test2", "value3")))); - - LegacySemanticTextFieldMapper.insertValue("path3.path4.test", map, "value4"); - assertThat(getMapValue(map, "path3.path4.test"), equalTo("value4")); - } - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - { - builder.startObject("path1"); - { - builder.startArray("path2"); - 
builder.startArray(); - builder.startObject().field("test", "value1").endObject(); - builder.endArray(); - builder.endArray(); - } - builder.endObject(); - } - builder.endObject(); - Map map = toSourceMap(Strings.toString(builder)); - - LegacySemanticTextFieldMapper.insertValue("path1.path2.test", map, "value2"); - assertThat(getMapValue(map, "path1.path2.test"), equalTo("value2")); - LegacySemanticTextFieldMapper.insertValue("path1.path2.test2", map, "value3"); - assertThat(getMapValue(map, "path1.path2.test2"), equalTo("value3")); - assertThat(getMapValue(map, "path1.path2"), equalTo(List.of(List.of(Map.of("test", "value2", "test2", "value3"))))); - } - } - - public void testInsertValueFieldsWithDots() throws IOException { - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("xxx.yyy", "value1").endObject(); - Map map = toSourceMap(Strings.toString(builder)); - - LegacySemanticTextFieldMapper.insertValue("xxx.yyy", map, "value2"); - assertThat(getMapValue(map, "xxx\\.yyy"), equalTo("value2")); - - LegacySemanticTextFieldMapper.insertValue("xxx", map, "value3"); - assertThat(getMapValue(map, "xxx"), equalTo("value3")); - } - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - { - builder.startObject("path1.path2"); - { - builder.startObject("path3.path4"); - builder.field("test", "value1"); - builder.endObject(); - } - builder.endObject(); - } - builder.endObject(); - Map map = toSourceMap(Strings.toString(builder)); - - LegacySemanticTextFieldMapper.insertValue("path1.path2.path3.path4.test", map, "value2"); - assertThat(getMapValue(map, "path1\\.path2.path3\\.path4.test"), equalTo("value2")); - - LegacySemanticTextFieldMapper.insertValue("path1.path2.path3.path4.test2", map, "value3"); - assertThat(getMapValue(map, "path1\\.path2.path3\\.path4.test2"), equalTo("value3")); - assertThat(getMapValue(map, "path1\\.path2.path3\\.path4"), equalTo(Map.of("test", "value2", "test2", "value3"))); - } - } - - public void testInsertValueAmbiguousPath() throws IOException { - // Mixed dotted object notation - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - { - builder.startObject("path1.path2"); - { - builder.startObject("path3"); - builder.field("test1", "value1"); - builder.endObject(); - } - builder.endObject(); - } - { - builder.startObject("path1"); - { - builder.startObject("path2.path3"); - builder.field("test2", "value2"); - builder.endObject(); - } - builder.endObject(); - } - builder.endObject(); - Map map = toSourceMap(Strings.toString(builder)); - final Map originalMap = Collections.unmodifiableMap(toSourceMap(Strings.toString(builder))); - - IllegalArgumentException ex = assertThrows( - IllegalArgumentException.class, - () -> LegacySemanticTextFieldMapper.insertValue("path1.path2.path3.test1", map, "value3") - ); - assertThat( - ex.getMessage(), - equalTo("Path [path1.path2.path3.test1] could be inserted in 2 distinct ways, it is ambiguous which one to use") - ); - - ex = assertThrows( - IllegalArgumentException.class, - () -> LegacySemanticTextFieldMapper.insertValue("path1.path2.path3.test3", map, "value4") - ); - assertThat( - ex.getMessage(), - equalTo("Path [path1.path2.path3.test3] could be inserted in 2 distinct ways, it is ambiguous which one to use") - ); - - assertThat(map, equalTo(originalMap)); - } - - // traversal through lists - { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - { - builder.startObject("path1.path2"); - { - builder.startArray("path3"); - 
builder.startObject().field("test1", "value1").endObject(); - builder.endArray(); - } - builder.endObject(); - } - { - builder.startObject("path1"); - { - builder.startArray("path2.path3"); - builder.startObject().field("test2", "value2").endObject(); - builder.endArray(); - } - builder.endObject(); - } - builder.endObject(); - Map map = toSourceMap(Strings.toString(builder)); - final Map originalMap = Collections.unmodifiableMap(toSourceMap(Strings.toString(builder))); - - IllegalArgumentException ex = assertThrows( - IllegalArgumentException.class, - () -> LegacySemanticTextFieldMapper.insertValue("path1.path2.path3.test1", map, "value3") - ); - assertThat( - ex.getMessage(), - equalTo("Path [path1.path2.path3.test1] could be inserted in 2 distinct ways, it is ambiguous which one to use") - ); - - ex = assertThrows( - IllegalArgumentException.class, - () -> LegacySemanticTextFieldMapper.insertValue("path1.path2.path3.test3", map, "value4") - ); - assertThat( - ex.getMessage(), - equalTo("Path [path1.path2.path3.test3] could be inserted in 2 distinct ways, it is ambiguous which one to use") - ); - - assertThat(map, equalTo(originalMap)); - } - } - - public void testInsertValueCannotTraversePath() throws IOException { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); - { - builder.startObject("path1"); - { - builder.startArray("path2"); - builder.startArray(); - builder.startObject().field("test", "value1").endObject(); - builder.endArray(); - builder.endArray(); - } - builder.endObject(); - } - builder.endObject(); - Map map = toSourceMap(Strings.toString(builder)); - final Map originalMap = Collections.unmodifiableMap(toSourceMap(Strings.toString(builder))); - - IllegalArgumentException ex = assertThrows( - IllegalArgumentException.class, - () -> LegacySemanticTextFieldMapper.insertValue("path1.path2.test.test2", map, "value2") - ); - assertThat( - ex.getMessage(), - equalTo("Path [path1.path2.test] has value [value1] of type [String], which cannot be traversed into further") - ); - - assertThat(map, equalTo(originalMap)); - } - - @Override - protected void assertExistsQuery(MappedFieldType fieldType, Query query, LuceneDocument fields) { - // Until a doc is indexed, the query is rewritten as match no docs - assertThat(query, instanceOf(MatchNoDocsQuery.class)); - } - - private static void addSemanticTextMapping( - XContentBuilder mappingBuilder, - String fieldName, - String inferenceId, - String searchInferenceId - ) throws IOException { - mappingBuilder.startObject(fieldName); - mappingBuilder.field("type", LegacySemanticTextFieldMapper.CONTENT_TYPE); - mappingBuilder.field("inference_id", inferenceId); - if (searchInferenceId != null) { - mappingBuilder.field("search_inference_id", searchInferenceId); - } - mappingBuilder.endObject(); - } - - private static void addSemanticTextInferenceResults( - XContentBuilder sourceBuilder, - List semanticTextInferenceResults - ) throws IOException { - for (var field : semanticTextInferenceResults) { - sourceBuilder.field(field.fieldName()); - sourceBuilder.value(field); - } - } - - static String randomFieldName(int numLevel) { - StringBuilder builder = new StringBuilder(); - for (int i = 0; i < numLevel; i++) { - if (i > 0) { - builder.append('.'); - } - builder.append(randomAlphaOfLengthBetween(5, 15)); - } - return builder.toString(); - } - - private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLookup, String fieldName, List tokens) { - NestedObjectMapper mapper = 
nestedLookup.getNestedMappers().get(getChunksFieldName(fieldName)); - assertNotNull(mapper); - - BitSetProducer parentFilter = new QueryBitSetProducer(Queries.newNonNestedFilter(IndexVersion.current())); - BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder(); - for (String token : tokens) { - queryBuilder.add( - new BooleanClause(new TermQuery(new Term(getEmbeddingsFieldName(fieldName), token)), BooleanClause.Occur.MUST) - ); - } - queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); - - return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); - } - - private static void assertChildLeafNestedDocument( - LeafNestedDocuments leaf, - int advanceToDoc, - int expectedRootDoc, - Set<SearchHit.NestedIdentity> visitedNestedIdentities - ) throws IOException { - - assertNotNull(leaf.advance(advanceToDoc)); - assertEquals(advanceToDoc, leaf.doc()); - assertEquals(expectedRootDoc, leaf.rootDoc()); - assertNotNull(leaf.nestedIdentity()); - visitedNestedIdentities.add(leaf.nestedIdentity()); - } - - private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { - int count = 0; - for (IndexableField field : doc.getFields()) { - if (field instanceof FeatureField featureField) { - assertThat(featureField.name(), equalTo(fieldName)); - ++count; - } - } - assertThat(count, equalTo(expectedCount)); - } - - private Map<String, Object> toSourceMap(String source) throws IOException { - try (XContentParser parser = createParser(JsonXContent.jsonXContent, source)) { - return parser.map(); - } - } - - private static Object getMapValue(Map<String, Object> map, String key) { - // Split the path on unescaped "." chars and then unescape the escaped "." chars - final String[] pathElements = Arrays.stream(key.split("(?<!\\\\)\\.")).map(k -> 
k.replace("\\.", ".")).toArray(String[]::new); - - Object value = null; - Object nextLayer = map; - for (int i = 0; i < pathElements.length; i++) { - if (nextLayer instanceof Map nextMap) { - value = nextMap.get(pathElements[i]); - } else if (nextLayer instanceof List nextList) { - final String pathElement = pathElements[i]; - List values = nextList.stream().flatMap(v -> { - Stream.Builder streamBuilder = Stream.builder(); - if (v instanceof List innerList) { - traverseList(innerList, streamBuilder); - } else { - streamBuilder.add(v); - } - return streamBuilder.build(); - }).filter(v -> v instanceof Map).map(v -> ((Map) v).get(pathElement)).filter(Objects::nonNull).toList(); - - if (values.isEmpty()) { - return null; - } else if (values.size() > 1) { - throw new AssertionError("List " + nextList + " contains multiple values for [" + pathElement + "]"); - } else { - value = values.getFirst(); - } - } else if (nextLayer == null) { - break; - } else { - throw new AssertionError( - "Path [" - + String.join(".", Arrays.copyOfRange(pathElements, 0, i)) - + "] has value [" - + value - + "] of type [" - + value.getClass().getSimpleName() - + "], which cannot be traversed into further" - ); - } - - nextLayer = value; - } - - return value; - } - - private static void traverseList(List list, Stream.Builder streamBuilder) { - for (Object value : list) { - if (value instanceof List innerList) { - traverseList(innerList, streamBuilder); - } else { - streamBuilder.add(value); - } - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldTests.java deleted file mode 100644 index f7a88865c58cd..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/LegacySemanticTextFieldTests.java +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.mapper; - -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.core.Tuple; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.AbstractXContentTestCase; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; -import org.elasticsearch.xpack.core.utils.FloatConversionUtils; -import org.elasticsearch.xpack.inference.model.TestModel; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.function.Predicate; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; - -public class LegacySemanticTextFieldTests extends AbstractXContentTestCase { - private static final String NAME = "field"; - - @Override - protected Predicate getRandomFieldsExcludeFilter() { - return n -> n.endsWith(LegacySemanticTextField.CHUNKED_EMBEDDINGS_FIELD); - } - - @Override - protected void assertEqualInstances(LegacySemanticTextField expectedInstance, LegacySemanticTextField newInstance) { - assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName())); - assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); - assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); - LegacySemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings(); - for (int i = 0; i < newInstance.inference().chunks().size(); i++) { - assertThat(newInstance.inference().chunks().get(i).text(), equalTo(expectedInstance.inference().chunks().get(i).text())); - switch (modelSettings.taskType()) { - case TEXT_EMBEDDING -> { - double[] expectedVector = parseDenseVector( - expectedInstance.inference().chunks().get(i).rawEmbeddings(), - modelSettings.dimensions(), - expectedInstance.contentType() - ); - double[] newVector = parseDenseVector( - newInstance.inference().chunks().get(i).rawEmbeddings(), - modelSettings.dimensions(), - newInstance.contentType() - ); - assertArrayEquals(expectedVector, newVector, 0.0000001f); - } - case SPARSE_EMBEDDING -> { - List expectedTokens = parseWeightedTokens( - expectedInstance.inference().chunks().get(i).rawEmbeddings(), - expectedInstance.contentType() - ); - List newTokens = parseWeightedTokens( - newInstance.inference().chunks().get(i).rawEmbeddings(), - newInstance.contentType() - ); - assertThat(newTokens, equalTo(expectedTokens)); - } - default -> throw new AssertionError("Invalid task type " + modelSettings.taskType()); - } - } - } - - @Override - protected LegacySemanticTextField createTestInstance() { - List rawValues = randomList(1, 5, () -> randomSemanticTextInput().toString()); - try { // try catch required for override 
- return randomSemanticText(NAME, TestModel.createRandomInstance(), rawValues, randomFrom(XContentType.values())); - } catch (IOException e) { - fail("Failed to create random LegacySemanticTextField instance"); - } - return null; - } - - @Override - protected LegacySemanticTextField doParseInstance(XContentParser parser) throws IOException { - return LegacySemanticTextField.parse(parser, new Tuple<>(NAME, parser.contentType())); - } - - @Override - protected boolean supportsUnknownFields() { - return true; - } - - public void testModelSettingsValidation() { - NullPointerException npe = expectThrows(NullPointerException.class, () -> { - new LegacySemanticTextField.ModelSettings(null, 10, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT); - }); - assertThat(npe.getMessage(), equalTo("task type must not be null")); - - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { - new LegacySemanticTextField.ModelSettings( - TaskType.COMPLETION, - 10, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ); - }); - assertThat(ex.getMessage(), containsString("Wrong [task_type]")); - - ex = expectThrows(IllegalArgumentException.class, () -> { - new LegacySemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, 10, null, null); - }); - assertThat(ex.getMessage(), containsString("[dimensions] is not allowed")); - - ex = expectThrows(IllegalArgumentException.class, () -> { - new LegacySemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, SimilarityMeasure.COSINE, null); - }); - assertThat(ex.getMessage(), containsString("[similarity] is not allowed")); - - ex = expectThrows(IllegalArgumentException.class, () -> { - new LegacySemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, DenseVectorFieldMapper.ElementType.FLOAT); - }); - assertThat(ex.getMessage(), containsString("[element_type] is not allowed")); - - ex = expectThrows(IllegalArgumentException.class, () -> { - new LegacySemanticTextField.ModelSettings( - TaskType.TEXT_EMBEDDING, - null, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ); - }); - assertThat(ex.getMessage(), containsString("required [dimensions] field is missing")); - - ex = expectThrows(IllegalArgumentException.class, () -> { - new LegacySemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, 10, null, DenseVectorFieldMapper.ElementType.FLOAT); - }); - assertThat(ex.getMessage(), containsString("required [similarity] field is missing")); - - ex = expectThrows(IllegalArgumentException.class, () -> { - new LegacySemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, 10, SimilarityMeasure.COSINE, null); - }); - assertThat(ex.getMessage(), containsString("required [element_type] field is missing")); - } - - public static InferenceChunkedTextEmbeddingFloatResults randomInferenceChunkedTextEmbeddingFloatResults( - Model model, - List inputs - ) throws IOException { - List chunks = new ArrayList<>(); - for (String input : inputs) { - float[] values = new float[model.getServiceSettings().dimensions()]; - for (int j = 0; j < values.length; j++) { - values[j] = (float) randomDouble(); - } - chunks.add(new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk(input, values)); - } - return new InferenceChunkedTextEmbeddingFloatResults(chunks); - } - - public static InferenceChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { - List chunks = new ArrayList<>(); - for (String input : inputs) { - var tokens = new ArrayList(); - for (var token : 
input.split("\\s+")) { - tokens.add(new WeightedToken(token, randomFloat())); - } - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(input, tokens)); - } - return new InferenceChunkedSparseEmbeddingResults(chunks); - } - - public static LegacySemanticTextField randomSemanticText(String fieldName, Model model, List inputs, XContentType contentType) - throws IOException { - ChunkedInferenceServiceResults results = switch (model.getTaskType()) { - case TEXT_EMBEDDING -> randomInferenceChunkedTextEmbeddingFloatResults(model, inputs); - case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); - default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); - }; - return semanticTextFieldFromChunkedInferenceResults(fieldName, model, inputs, results, contentType); - } - - public static LegacySemanticTextField semanticTextFieldFromChunkedInferenceResults( - String fieldName, - Model model, - List inputs, - ChunkedInferenceServiceResults results, - XContentType contentType - ) { - return new LegacySemanticTextField( - fieldName, - inputs, - new LegacySemanticTextField.InferenceResult( - model.getInferenceEntityId(), - new LegacySemanticTextField.ModelSettings(model), - LegacySemanticTextField.toSemanticTextFieldChunks(List.of(results), contentType) - ), - contentType - ); - } - - /** - * Returns a randomly generated object for Semantic Text tests purpose. - */ - public static Object randomSemanticTextInput() { - if (rarely()) { - return switch (randomIntBetween(0, 4)) { - case 0 -> randomInt(); - case 1 -> randomLong(); - case 2 -> randomFloat(); - case 3 -> randomBoolean(); - case 4 -> randomDouble(); - default -> throw new IllegalStateException("Illegal state while generating random semantic text input"); - }; - } else { - return randomAlphaOfLengthBetween(10, 20); - } - } - - public static ChunkedInferenceServiceResults toChunkedResult(LegacySemanticTextField field) throws IOException { - switch (field.inference().modelSettings().taskType()) { - case SPARSE_EMBEDDING -> { - List chunks = new ArrayList<>(); - for (var chunk : field.inference().chunks()) { - var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); - // TODO - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(null, tokens)); - } - return new InferenceChunkedSparseEmbeddingResults(chunks); - } - case TEXT_EMBEDDING -> { - List chunks = new ArrayList<>(); - for (var chunk : field.inference().chunks()) { - double[] values = parseDenseVector( - chunk.rawEmbeddings(), - field.inference().modelSettings().dimensions(), - field.contentType() - ); - // TODO - chunks.add( - new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk( - null, - FloatConversionUtils.floatArrayOf(values) - ) - ); - } - return new InferenceChunkedTextEmbeddingFloatResults(chunks); - } - default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); - } - } - - private static double[] parseDenseVector(BytesReference value, int numDims, XContentType contentType) { - try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { - parser.nextToken(); - assertThat(parser.currentToken(), equalTo(XContentParser.Token.START_ARRAY)); - double[] values = new double[numDims]; - for (int i = 0; i < numDims; i++) { - assertThat(parser.nextToken(), equalTo(XContentParser.Token.VALUE_NUMBER)); - values[i] = parser.doubleValue(); - } - assertThat(parser.nextToken(), 
equalTo(XContentParser.Token.END_ARRAY)); - return values; - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private static List parseWeightedTokens(BytesReference value, XContentType contentType) { - try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, value, contentType)) { - Map map = parser.map(); - List weightedTokens = new ArrayList<>(); - for (var entry : map.entrySet()) { - weightedTokens.add(new WeightedToken(entry.getKey(), ((Number) entry.getValue()).floatValue())); - } - return weightedTokens; - } catch (IOException e) { - throw new RuntimeException(e); - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index ea948ec65ad82..9e34715e4f83a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -23,7 +23,6 @@ import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; -import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Strings; @@ -32,10 +31,10 @@ import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.LuceneDocument; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.Mapper; @@ -57,9 +56,8 @@ import org.elasticsearch.search.LeafNestedDocuments; import org.elasticsearch.search.NestedDocuments; import org.elasticsearch.search.SearchHit; -import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.inference.InferencePlugin; @@ -67,22 +65,22 @@ import org.junit.AssumptionViolatedException; import java.io.IOException; -import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import java.util.function.BiConsumer; -import java.util.stream.Stream; import static java.util.Collections.singletonList; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD; import static 
org.elasticsearch.xpack.inference.mapper.SemanticTextField.MODEL_SETTINGS_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.SEARCH_INFERENCE_ID_FIELD; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.TEXT_FIELD; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_ELSER_2_INFERENCE_ID; @@ -97,21 +95,6 @@ protected Collection getPlugins() { return singletonList(new InferencePlugin(Settings.EMPTY)); } - @Override - protected Settings getIndexSettings() { - return Settings.builder() - .put( - IndexMetadata.SETTING_VERSION_CREATED, - IndexVersionUtils.randomVersionBetween(random(), IndexVersions.INFERENCE_METADATA_FIELDS, IndexVersion.current()) - ) - .build(); - } - - @Override - protected IndexVersion getVersion() { - return IndexVersionUtils.randomVersionBetween(random(), IndexVersions.INFERENCE_METADATA_FIELDS, IndexVersion.current()); - } - @Override protected void minimalMapping(XContentBuilder b) throws IOException { b.field("type", "semantic_text"); @@ -164,7 +147,15 @@ protected IngestScriptSupport ingestScriptSupport() { @Override public MappedFieldType getMappedFieldType() { - return new SemanticTextFieldMapper.SemanticTextFieldType("field", "fake-inference-id", null, null, null, Map.of()); + return new SemanticTextFieldMapper.SemanticTextFieldType( + "field", + "fake-inference-id", + null, + null, + null, + IndexVersion.current(), + Map.of() + ); } @Override @@ -475,6 +466,12 @@ private static void assertSemanticTextField(MapperService mapperService, String .get(getChunksFieldName(fieldName)); assertThat(chunksMapper, equalTo(semanticFieldMapper.fieldType().getChunksField())); assertThat(chunksMapper.fullPath(), equalTo(getChunksFieldName(fieldName))); + Mapper textMapper = chunksMapper.getMapper(TEXT_FIELD); + assertNotNull(textMapper); + assertThat(textMapper, instanceOf(KeywordFieldMapper.class)); + KeywordFieldMapper textFieldMapper = (KeywordFieldMapper) textMapper; + assertFalse(textFieldMapper.fieldType().isIndexed()); + assertFalse(textFieldMapper.fieldType().hasDocValues()); if (expectedModelSettings) { assertNotNull(semanticFieldMapper.fieldType().getModelSettings()); Mapper embeddingsMapper = chunksMapper.getMapper(CHUNKED_EMBEDDINGS_FIELD); @@ -631,9 +628,11 @@ public void testMissingInferenceId() throws IOException { () -> documentMapper.parse( source( b -> b.startObject("field") + .startObject(INFERENCE_FIELD) .field(MODEL_SETTINGS_FIELD, new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)) .field(CHUNKS_FIELD, List.of()) .endObject() + .endObject() ) ) ); @@ -645,7 +644,9 @@ public void testMissingModelSettings() throws IOException { IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, - () -> documentMapper.parse(source(b -> b.startObject("field").field(INFERENCE_ID_FIELD, "my_id").endObject())) + () -> documentMapper.parse( + source(b -> b.startObject("field").startObject(INFERENCE_FIELD).field(INFERENCE_ID_FIELD, "my_id").endObject().endObject()) + ) ); assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); } @@ -657,7 +658,13 @@ public void testMissingTaskType() throws IOException { IllegalArgumentException.class, () -> documentMapper.parse( source( - b -> 
b.startObject("field").field(INFERENCE_ID_FIELD, "my_id").startObject(MODEL_SETTINGS_FIELD).endObject().endObject() + b -> b.startObject("field") + .startObject(INFERENCE_FIELD) + .field(INFERENCE_ID_FIELD, "my_id") + .startObject(MODEL_SETTINGS_FIELD) + .endObject() + .endObject() + .endObject() ) ) ); @@ -726,7 +733,12 @@ private MapperService mapperServiceForFieldWithModelSettings( MapperService.MergeReason.MAPPING_UPDATE ); - SemanticTextField semanticTextField = new SemanticTextField(fieldName, inferenceId, modelSettings, List.of(), XContentType.JSON); + SemanticTextField semanticTextField = new SemanticTextField( + fieldName, + List.of(), + new SemanticTextField.InferenceResult(inferenceId, modelSettings, List.of()), + XContentType.JSON + ); XContentBuilder builder = JsonXContent.contentBuilder().startObject(); builder.field(semanticTextField.fieldName()); builder.value(semanticTextField); @@ -861,68 +873,4 @@ private static void assertSparseFeatures(LuceneDocument doc, String fieldName, i } assertThat(count, equalTo(expectedCount)); } - - private Map toSourceMap(String source) throws IOException { - try (XContentParser parser = createParser(JsonXContent.jsonXContent, source)) { - return parser.map(); - } - } - - private static Object getMapValue(Map map, String key) { - // Split the path on unescaped "." chars and then unescape the escaped "." chars - final String[] pathElements = Arrays.stream(key.split("(? k.replace("\\.", ".")).toArray(String[]::new); - - Object value = null; - Object nextLayer = map; - for (int i = 0; i < pathElements.length; i++) { - if (nextLayer instanceof Map nextMap) { - value = nextMap.get(pathElements[i]); - } else if (nextLayer instanceof List nextList) { - final String pathElement = pathElements[i]; - List values = nextList.stream().flatMap(v -> { - Stream.Builder streamBuilder = Stream.builder(); - if (v instanceof List innerList) { - traverseList(innerList, streamBuilder); - } else { - streamBuilder.add(v); - } - return streamBuilder.build(); - }).filter(v -> v instanceof Map).map(v -> ((Map) v).get(pathElement)).filter(Objects::nonNull).toList(); - - if (values.isEmpty()) { - return null; - } else if (values.size() > 1) { - throw new AssertionError("List " + nextList + " contains multiple values for [" + pathElement + "]"); - } else { - value = values.getFirst(); - } - } else if (nextLayer == null) { - break; - } else { - throw new AssertionError( - "Path [" - + String.join(".", Arrays.copyOfRange(pathElements, 0, i)) - + "] has value [" - + value - + "] of type [" - + value.getClass().getSimpleName() - + "], which cannot be traversed into further" - ); - } - - nextLayer = value; - } - - return value; - } - - private static void traverseList(List list, Stream.Builder streamBuilder) { - for (Object value : list) { - if (value instanceof List innerList) { - traverseList(innerList, streamBuilder); - } else { - streamBuilder.add(value); - } - } - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 43eec68c28cd9..60f3c89caf994 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -48,20 +48,21 @@ protected Predicate getRandomFieldsExcludeFilter() { @Override protected void 
assertEqualInstances(SemanticTextField expectedInstance, SemanticTextField newInstance) { assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName())); - assertThat(newInstance.modelSettings(), equalTo(expectedInstance.modelSettings())); - assertThat(newInstance.chunks().size(), equalTo(expectedInstance.chunks().size())); - SemanticTextField.ModelSettings modelSettings = newInstance.modelSettings(); - for (int i = 0; i < newInstance.chunks().size(); i++) { - assertThat(newInstance.chunks().get(i).offset(), equalTo(expectedInstance.chunks().get(i).offset())); + assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); + assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); + assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); + SemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings(); + for (int i = 0; i < newInstance.inference().chunks().size(); i++) { + assertThat(newInstance.inference().chunks().get(i).text(), equalTo(expectedInstance.inference().chunks().get(i).text())); switch (modelSettings.taskType()) { case TEXT_EMBEDDING -> { double[] expectedVector = parseDenseVector( - expectedInstance.chunks().get(i).rawEmbeddings(), + expectedInstance.inference().chunks().get(i).rawEmbeddings(), modelSettings.dimensions(), expectedInstance.contentType() ); double[] newVector = parseDenseVector( - newInstance.chunks().get(i).rawEmbeddings(), + newInstance.inference().chunks().get(i).rawEmbeddings(), modelSettings.dimensions(), newInstance.contentType() ); @@ -69,11 +70,11 @@ protected void assertEqualInstances(SemanticTextField expectedInstance, Semantic } case SPARSE_EMBEDDING -> { List<WeightedToken> expectedTokens = parseWeightedTokens( - expectedInstance.chunks().get(i).rawEmbeddings(), + expectedInstance.inference().chunks().get(i).rawEmbeddings(), expectedInstance.contentType() ); List<WeightedToken> newTokens = parseWeightedTokens( - newInstance.chunks().get(i).rawEmbeddings(), + newInstance.inference().chunks().get(i).rawEmbeddings(), newInstance.contentType() ); assertThat(newTokens, equalTo(expectedTokens)); @@ -191,27 +192,25 @@ public static SemanticTextField randomSemanticText(String fieldName, Model model case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; - return semanticTextFieldFromChunkedInferenceResults( - fieldName, - model, - SemanticTextField.nodeStringValues(fieldName, inputs), - results, - contentType - ); + return semanticTextFieldFromChunkedInferenceResults(fieldName, model, inputs, results, contentType); } public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( String fieldName, Model model, - String input, + List<String> inputs, ChunkedInferenceServiceResults results, XContentType contentType ) { return new SemanticTextField( fieldName, - model.getInferenceEntityId(), - new SemanticTextField.ModelSettings(model), - toSemanticTextFieldChunks(fieldName, input, List.of(results), contentType), + inputs, + new SemanticTextField.InferenceResult( + model.getInferenceEntityId(), + new SemanticTextField.ModelSettings(model), + // TODO + toSemanticTextFieldChunks(fieldName, inputs.get(0), List.of(results), contentType) + ), contentType ); } @@ -235,31 +234,33 @@ public static Object randomSemanticTextInput() { } public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField field) throws IOException { 
- switch (field.modelSettings().taskType()) { + switch (field.inference().modelSettings().taskType()) { case SPARSE_EMBEDDING -> { List<MlChunkedTextExpansionResults.ChunkedResult> chunks = new ArrayList<>(); - for (var chunk : field.chunks()) { + for (var chunk : field.inference().chunks()) { var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); - // TODO - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(null, tokens)); + chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunk.text(), tokens)); } return new InferenceChunkedSparseEmbeddingResults(chunks); } case TEXT_EMBEDDING -> { List<InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk> chunks = new ArrayList<>(); - for (var chunk : field.chunks()) { - double[] values = parseDenseVector(chunk.rawEmbeddings(), field.modelSettings().dimensions(), field.contentType()); - // TODO + for (var chunk : field.inference().chunks()) { + double[] values = parseDenseVector( + chunk.rawEmbeddings(), + field.inference().modelSettings().dimensions(), + field.contentType() + ); chunks.add( new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk( - null, + chunk.text(), FloatConversionUtils.floatArrayOf(values) ) ); } return new InferenceChunkedTextEmbeddingFloatResults(chunks); } - default -> throw new AssertionError("Invalid task_type: " + field.modelSettings().taskType().name()); + default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextUtilsTests.java new file mode 100644 index 0000000000000..f89b00f593519 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextUtilsTests.java @@ -0,0 +1,351 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.json.JsonXContent; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class SemanticTextUtilsTests extends ESTestCase { + public void testInsertValueMapTraversal() throws IOException { + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("test", "value").endObject(); + + Map<String, Object> map = toSourceMap(Strings.toString(builder)); + SemanticTextUtils.insertValue("test", map, "value2"); + assertThat(getMapValue(map, "test"), equalTo("value2")); + SemanticTextUtils.insertValue("something.else", map, "something_else_value"); + assertThat(getMapValue(map, "something\\.else"), equalTo("something_else_value")); + } + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.startObject("path1").startObject("path2").field("test", "value").endObject().endObject(); + builder.endObject(); + + Map<String, Object> map = toSourceMap(Strings.toString(builder)); + SemanticTextUtils.insertValue("path1.path2.test", map, "value2"); + assertThat(getMapValue(map, "path1.path2.test"), equalTo("value2")); + SemanticTextUtils.insertValue("path1.path2.test_me", map, "test_me_value"); + assertThat(getMapValue(map, "path1.path2.test_me"), equalTo("test_me_value")); + SemanticTextUtils.insertValue("path1.non_path2.test", map, "test_value"); + assertThat(getMapValue(map, "path1.non_path2\\.test"), equalTo("test_value")); + + SemanticTextUtils.insertValue("path1.path2", map, Map.of("path3", "bar")); + assertThat(getMapValue(map, "path1.path2"), equalTo(Map.of("path3", "bar"))); + + SemanticTextUtils.insertValue("path1", map, "baz"); + assertThat(getMapValue(map, "path1"), equalTo("baz")); + + SemanticTextUtils.insertValue("path3.path4", map, Map.of("test", "foo")); + assertThat(getMapValue(map, "path3\\.path4"), equalTo(Map.of("test", "foo"))); + } + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.startObject("path1").array("test", "value1", "value2").endObject(); + builder.endObject(); + Map<String, Object> map = toSourceMap(Strings.toString(builder)); + + SemanticTextUtils.insertValue("path1.test", map, List.of("value3", "value4", "value5")); + assertThat(getMapValue(map, "path1.test"), equalTo(List.of("value3", "value4", "value5"))); + + SemanticTextUtils.insertValue("path2.test", map, List.of("value6", "value7", "value8")); + assertThat(getMapValue(map, "path2\\.test"), equalTo(List.of("value6", "value7", "value8"))); + } + } + + public void testInsertValueListTraversal() throws IOException { + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + { + builder.startObject("path1"); + { + builder.startArray("path2"); + builder.startObject().field("test", "value1").endObject(); + builder.endArray(); + } + builder.endObject(); + } + { + builder.startObject("path3"); + { + builder.startArray("path4"); + builder.startObject().field("test", "value1").endObject(); + builder.endArray(); + } + builder.endObject(); + } + builder.endObject(); + Map<String, Object> map = toSourceMap(Strings.toString(builder)); + + 
SemanticTextUtils.insertValue("path1.path2.test", map, "value2"); + assertThat(getMapValue(map, "path1.path2.test"), equalTo("value2")); + SemanticTextUtils.insertValue("path1.path2.test2", map, "value3"); + assertThat(getMapValue(map, "path1.path2.test2"), equalTo("value3")); + assertThat(getMapValue(map, "path1.path2"), equalTo(List.of(Map.of("test", "value2", "test2", "value3")))); + + SemanticTextUtils.insertValue("path3.path4.test", map, "value4"); + assertThat(getMapValue(map, "path3.path4.test"), equalTo("value4")); + } + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + { + builder.startObject("path1"); + { + builder.startArray("path2"); + builder.startArray(); + builder.startObject().field("test", "value1").endObject(); + builder.endArray(); + builder.endArray(); + } + builder.endObject(); + } + builder.endObject(); + Map map = toSourceMap(Strings.toString(builder)); + + SemanticTextUtils.insertValue("path1.path2.test", map, "value2"); + assertThat(getMapValue(map, "path1.path2.test"), equalTo("value2")); + SemanticTextUtils.insertValue("path1.path2.test2", map, "value3"); + assertThat(getMapValue(map, "path1.path2.test2"), equalTo("value3")); + assertThat(getMapValue(map, "path1.path2"), equalTo(List.of(List.of(Map.of("test", "value2", "test2", "value3"))))); + } + } + + public void testInsertValueFieldsWithDots() throws IOException { + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("xxx.yyy", "value1").endObject(); + Map map = toSourceMap(Strings.toString(builder)); + + SemanticTextUtils.insertValue("xxx.yyy", map, "value2"); + assertThat(getMapValue(map, "xxx\\.yyy"), equalTo("value2")); + + SemanticTextUtils.insertValue("xxx", map, "value3"); + assertThat(getMapValue(map, "xxx"), equalTo("value3")); + } + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + { + builder.startObject("path1.path2"); + { + builder.startObject("path3.path4"); + builder.field("test", "value1"); + builder.endObject(); + } + builder.endObject(); + } + builder.endObject(); + Map map = toSourceMap(Strings.toString(builder)); + + SemanticTextUtils.insertValue("path1.path2.path3.path4.test", map, "value2"); + assertThat(getMapValue(map, "path1\\.path2.path3\\.path4.test"), equalTo("value2")); + + SemanticTextUtils.insertValue("path1.path2.path3.path4.test2", map, "value3"); + assertThat(getMapValue(map, "path1\\.path2.path3\\.path4.test2"), equalTo("value3")); + assertThat(getMapValue(map, "path1\\.path2.path3\\.path4"), equalTo(Map.of("test", "value2", "test2", "value3"))); + } + } + + public void testInsertValueAmbiguousPath() throws IOException { + // Mixed dotted object notation + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + { + builder.startObject("path1.path2"); + { + builder.startObject("path3"); + builder.field("test1", "value1"); + builder.endObject(); + } + builder.endObject(); + } + { + builder.startObject("path1"); + { + builder.startObject("path2.path3"); + builder.field("test2", "value2"); + builder.endObject(); + } + builder.endObject(); + } + builder.endObject(); + Map map = toSourceMap(Strings.toString(builder)); + final Map originalMap = Collections.unmodifiableMap(toSourceMap(Strings.toString(builder))); + + IllegalArgumentException ex = assertThrows( + IllegalArgumentException.class, + () -> SemanticTextUtils.insertValue("path1.path2.path3.test1", map, "value3") + ); + assertThat( + ex.getMessage(), + equalTo("Path [path1.path2.path3.test1] could be inserted in 2 distinct 
ways, it is ambiguous which one to use") + ); + + ex = assertThrows( + IllegalArgumentException.class, + () -> SemanticTextUtils.insertValue("path1.path2.path3.test3", map, "value4") + ); + assertThat( + ex.getMessage(), + equalTo("Path [path1.path2.path3.test3] could be inserted in 2 distinct ways, it is ambiguous which one to use") + ); + + assertThat(map, equalTo(originalMap)); + } + + // traversal through lists + { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + { + builder.startObject("path1.path2"); + { + builder.startArray("path3"); + builder.startObject().field("test1", "value1").endObject(); + builder.endArray(); + } + builder.endObject(); + } + { + builder.startObject("path1"); + { + builder.startArray("path2.path3"); + builder.startObject().field("test2", "value2").endObject(); + builder.endArray(); + } + builder.endObject(); + } + builder.endObject(); + Map<String, Object> map = toSourceMap(Strings.toString(builder)); + final Map<String, Object> originalMap = Collections.unmodifiableMap(toSourceMap(Strings.toString(builder))); + + IllegalArgumentException ex = assertThrows( + IllegalArgumentException.class, + () -> SemanticTextUtils.insertValue("path1.path2.path3.test1", map, "value3") + ); + assertThat( + ex.getMessage(), + equalTo("Path [path1.path2.path3.test1] could be inserted in 2 distinct ways, it is ambiguous which one to use") + ); + + ex = assertThrows( + IllegalArgumentException.class, + () -> SemanticTextUtils.insertValue("path1.path2.path3.test3", map, "value4") + ); + assertThat( + ex.getMessage(), + equalTo("Path [path1.path2.path3.test3] could be inserted in 2 distinct ways, it is ambiguous which one to use") + ); + + assertThat(map, equalTo(originalMap)); + } + } + + public void testInsertValueCannotTraversePath() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + { + builder.startObject("path1"); + { + builder.startArray("path2"); + builder.startArray(); + builder.startObject().field("test", "value1").endObject(); + builder.endArray(); + builder.endArray(); + } + builder.endObject(); + } + builder.endObject(); + Map<String, Object> map = toSourceMap(Strings.toString(builder)); + final Map<String, Object> originalMap = Collections.unmodifiableMap(toSourceMap(Strings.toString(builder))); + + IllegalArgumentException ex = assertThrows( + IllegalArgumentException.class, + () -> SemanticTextUtils.insertValue("path1.path2.test.test2", map, "value2") + ); + assertThat( + ex.getMessage(), + equalTo("Path [path1.path2.test] has value [value1] of type [String], which cannot be traversed into further") + ); + + assertThat(map, equalTo(originalMap)); + } + + private Map<String, Object> toSourceMap(String source) throws IOException { + try (XContentParser parser = createParser(JsonXContent.jsonXContent, source)) { + return parser.map(); + } + } + + private static Object getMapValue(Map<String, Object> map, String key) { + // Split the path on unescaped "." chars and then unescape the escaped "." chars + final String[] pathElements = Arrays.stream(key.split("(?<!\\\\)\\.")).map(k -> k.replace("\\.", ".")).toArray(String[]::new); + + Object value = null; + Object nextLayer = map; + for (int i = 0; i < pathElements.length; i++) { + if (nextLayer instanceof Map<?, ?> nextMap) { + value = nextMap.get(pathElements[i]); + } else if (nextLayer instanceof List<?> nextList) { + final String pathElement = pathElements[i]; + List<Object> values = nextList.stream().flatMap(v -> { + Stream.Builder<Object> streamBuilder = Stream.builder(); + if (v instanceof List<?> innerList) { + traverseList(innerList, streamBuilder); + } else { + streamBuilder.add(v); + } + return streamBuilder.build(); + }).filter(v -> v instanceof Map).map(v -> ((Map<?, ?>) v).get(pathElement)).filter(Objects::nonNull).toList(); + + if (values.isEmpty()) { + return null; + } else if (values.size() > 1) { + throw new AssertionError("List " + nextList + " contains multiple values for [" + pathElement + "]"); + } else { + value = values.getFirst(); + } + } else if (nextLayer == null) { + break; + } else { + throw new AssertionError( + "Path [" + + String.join(".", Arrays.copyOfRange(pathElements, 0, i)) + + "] has value [" + + value + + "] of type [" + + value.getClass().getSimpleName() + + "], which cannot be traversed into further" + ); + } + + nextLayer = value; + } + + return value; + } + + private static void traverseList(List<?> list, Stream.Builder<Object> streamBuilder) { + for (Object value : list) { + if (value instanceof List<?> innerList) { + traverseList(innerList, streamBuilder); + } else { + streamBuilder.add(value); + } + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index 67494ed10e471..2015d1c938c56 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQuery; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SparseVectorQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SparseVectorQueryBuilderTests.java index a2d3eb9fc1198..832514929e42c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SparseVectorQueryBuilderTests.java @@ -40,6 +40,9 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQuery; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig; import 
org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; @@ -49,7 +52,7 @@ import java.util.Collection; import java.util.List; -import static org.elasticsearch.xpack.inference.queries.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD; +import static org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder.QUERY_VECTOR_FIELD; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/TextExpansionQueryBuilderTests.java index 090a4ec8556d2..d2a8544a53b92 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/TextExpansionQueryBuilderTests.java @@ -35,7 +35,10 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.TextExpansionQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig; import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/TokenPruningConfigTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/TokenPruningConfigTests.java index a5e569950c319..41f5116db25bb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/TokenPruningConfigTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/TokenPruningConfigTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractXContentSerializingTestCase; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/WeightedTokensQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/WeightedTokensQueryBuilderTests.java index 6833dd37a445d..20cdf7c27dc3d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/WeightedTokensQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/WeightedTokensQueryBuilderTests.java @@ -35,7 +35,10 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQuery; +import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig; import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; import 
java.io.IOException; @@ -43,7 +46,7 @@ import java.util.Collection; import java.util.List; -import static org.elasticsearch.xpack.inference.queries.WeightedTokensQueryBuilder.TOKENS_FIELD; +import static org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder.TOKENS_FIELD; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.Matchers.either;
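
Note: with these query classes relocated to x-pack core under org.elasticsearch.xpack.core.ml.search, code outside the inference plugin can build token-pruned sparse queries against the new package directly. A minimal sketch of such a caller follows; the field name, tokens, and pruning thresholds are illustrative only, and it assumes the constructors keep the signatures they had before the move:

    import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
    import org.elasticsearch.xpack.core.ml.search.WeightedToken;
    import org.elasticsearch.xpack.core.ml.search.WeightedTokensQueryBuilder;

    import java.util.List;

    class WeightedTokensQuerySketch {
        static WeightedTokensQueryBuilder exampleQuery() {
            // Prune tokens whose frequency exceeds 5x the average token frequency
            // or whose weight falls below 0.4; pruned tokens are not scored.
            TokenPruningConfig pruning = new TokenPruningConfig(5f, 0.4f, false);
            List<WeightedToken> tokens = List.of(new WeightedToken("foo", 0.42f), new WeightedToken("bar", 0.05f));
            return new WeightedTokensQueryBuilder("ml.tokens", tokens, pruning);
        }
    }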