From 83fdc3cc78dc0319d3cf008e0a14adc1d5a83558 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 19 Dec 2024 14:26:52 +0000 Subject: [PATCH] add parametrized tests to ensure that the legacy and new format are always tested --- .../ShardBulkInferenceActionFilterIT.java | 24 +- .../ShardBulkInferenceActionFilterTests.java | 16 +- .../mapper/SemanticTextFieldMapperTests.java | 431 +++++++++--------- .../mapper/SemanticTextFieldTests.java | 14 +- .../queries/SemanticQueryBuilderTests.java | 14 +- 5 files changed, 264 insertions(+), 235 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index 304e46f3f7665..90a9dd3355b3c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.action.filter; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -19,13 +21,12 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; @@ -35,6 +36,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; @@ -45,7 +47,16 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase { public static final String INDEX_NAME = "test-index"; - private IndexVersion indexVersion; + private final boolean useLegacyFormat; + + public ShardBulkInferenceActionFilterIT(boolean useLegacyFormat) { + this.useLegacyFormat = useLegacyFormat; + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return List.of(new Object[] { true }, new Object[] { false }); + } @Before public void setup() throws Exception { @@ -68,16 +79,13 @@ protected Collection> nodePlugins() { @Override public Settings indexSettings() { return Settings.builder() - .put(IndexMetadata.SETTING_VERSION_CREATED, indexVersion) + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10)) + .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat) .build(); } public void testBulkOperations() throws Exception { - this.indexVersion = randomFrom( - IndexVersionUtils.randomPreviousCompatibleVersion(random(), IndexVersions.INFERENCE_METADATA_FIELDS), - IndexVersionUtils.randomVersionBetween(random(), IndexVersions.INFERENCE_METADATA_FIELDS, IndexVersion.current()) - ); indicesAdmin().prepareCreate(INDEX_NAME) .setMapping( String.format( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 6310d46495695..478c81f7c5a32 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.action.filter; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.bulk.BulkItemRequest; @@ -79,8 +81,18 @@ import static org.mockito.Mockito.when; public class ShardBulkInferenceActionFilterTests extends ESTestCase { + private final boolean useLegacyFormat; private ThreadPool threadPool; + public ShardBulkInferenceActionFilterTests(boolean useLegacyFormat) { + this.useLegacyFormat = useLegacyFormat; + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return List.of(new Object[] { true }, new Object[] { false }); + } + @Before public void setupThreadPool() { threadPool = new TestThreadPool(getTestName()); @@ -93,7 +105,6 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - boolean useLegacyFormat = randomBoolean(); ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { @@ -120,7 +131,6 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { StaticModel model = StaticModel.createRandomInstance(); - boolean useLegacyFormat = randomBoolean(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), @@ -166,7 +176,6 @@ public void testInferenceNotFound() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { StaticModel model = StaticModel.createRandomInstance(); - boolean useLegacyFormat = randomBoolean(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -225,7 +234,6 @@ public void testItemFailures() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { - boolean useLegacyFormat = randomBoolean(); Map inferenceModelMap = new HashMap<>(); int numModels = randomIntBetween(1, 3); for (int i = 0; i < numModels; i++) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 5f44641343fd0..7d25abbed8d2b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.IndexableField; @@ -92,6 +94,17 @@ import static org.hamcrest.Matchers.instanceOf; public class SemanticTextFieldMapperTests extends MapperTestCase { + private final boolean useLegacyFormat; + + public SemanticTextFieldMapperTests(boolean useLegacyFormat) { + this.useLegacyFormat = useLegacyFormat; + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return List.of(new Object[] { true }, new Object[] { false }); + } + @Override protected Collection getPlugins() { return List.of(new InferencePlugin(Settings.EMPTY), new XPackClientPlugin()); @@ -104,6 +117,14 @@ private MapperService createMapperService(XContentBuilder mappings, boolean useL return createMapperService(settings, mappings); } + @Override + protected Settings getIndexSettings() { + return Settings.builder() + .put(super.getIndexSettings()) + .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat) + .build(); + } + @Override protected void minimalMapping(XContentBuilder b) throws IOException { b.field("type", "semantic_text"); @@ -176,23 +197,21 @@ protected void assertSearchable(MappedFieldType fieldType) { } public void testDefaults() throws Exception { - for (boolean useLegacyFormat : new boolean[] { true, false }) { - final String fieldName = "field"; - final XContentBuilder fieldMapping = fieldMapping(this::minimalMapping); - final XContentBuilder expectedMapping = fieldMapping(this::metaMapping); + final String fieldName = "field"; + final XContentBuilder fieldMapping = fieldMapping(this::minimalMapping); + final XContentBuilder expectedMapping = fieldMapping(this::metaMapping); - MapperService mapperService = createMapperService(fieldMapping, useLegacyFormat); - DocumentMapper mapper = mapperService.documentMapper(); - assertEquals(Strings.toString(expectedMapping), mapper.mappingSource().toString()); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID); + MapperService mapperService = createMapperService(fieldMapping, useLegacyFormat); + DocumentMapper mapper = mapperService.documentMapper(); + assertEquals(Strings.toString(expectedMapping), mapper.mappingSource().toString()); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID); - ParsedDocument doc1 = mapper.parse(source(this::writeField)); - List fields = doc1.rootDoc().getFields("field"); + ParsedDocument doc1 = mapper.parse(source(this::writeField)); + List fields = doc1.rootDoc().getFields("field"); - // No indexable fields - assertTrue(fields.isEmpty()); - } + // No indexable fields + assertTrue(fields.isEmpty()); } @Override @@ -212,41 +231,37 @@ public void testSetInferenceEndpoints() throws IOException { assertEquals(Strings.toString(expectedMapping), mapper.mappingSource().toString()); }; - for (boolean useLegacyFormat : new boolean[] { true, false }) { - { - final XContentBuilder fieldMapping = fieldMapping( - b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, inferenceId) - ); - final MapperService mapperService = createMapperService(fieldMapping, useLegacyFormat); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - assertSerialization.accept(fieldMapping, mapperService); - } - { - final XContentBuilder fieldMapping = fieldMapping( - b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) - ); - final XContentBuilder expectedMapping = fieldMapping( - b -> b.field("type", "semantic_text") - .field(INFERENCE_ID_FIELD, DEFAULT_ELSER_2_INFERENCE_ID) - .field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) - ); - final MapperService mapperService = createMapperService(fieldMapping, useLegacyFormat); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId); - assertSerialization.accept(expectedMapping, mapperService); - } - { - final XContentBuilder fieldMapping = fieldMapping( - b -> b.field("type", "semantic_text") - .field(INFERENCE_ID_FIELD, inferenceId) - .field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) - ); - MapperService mapperService = createMapperService(fieldMapping, useLegacyFormat); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId); - assertSerialization.accept(fieldMapping, mapperService); - } + { + final XContentBuilder fieldMapping = fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, inferenceId)); + final MapperService mapperService = createMapperService(fieldMapping, useLegacyFormat); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); + assertSerialization.accept(fieldMapping, mapperService); + } + { + final XContentBuilder fieldMapping = fieldMapping( + b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) + ); + final XContentBuilder expectedMapping = fieldMapping( + b -> b.field("type", "semantic_text") + .field(INFERENCE_ID_FIELD, DEFAULT_ELSER_2_INFERENCE_ID) + .field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) + ); + final MapperService mapperService = createMapperService(fieldMapping, useLegacyFormat); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId); + assertSerialization.accept(expectedMapping, mapperService); + } + { + final XContentBuilder fieldMapping = fieldMapping( + b -> b.field("type", "semantic_text") + .field(INFERENCE_ID_FIELD, inferenceId) + .field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId) + ); + MapperService mapperService = createMapperService(fieldMapping, useLegacyFormat); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId); + assertSerialization.accept(fieldMapping, mapperService); } } @@ -314,14 +329,12 @@ public void testDynamicUpdate() throws IOException { final String fieldName = "semantic"; final String inferenceId = "test_service"; final String searchInferenceId = "search_test_service"; - final boolean useLegacyFormat = randomBoolean(); { MapperService mapperService = mapperServiceForFieldWithModelSettings( fieldName, inferenceId, - new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null), - useLegacyFormat + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) ); assertSemanticTextField(mapperService, fieldName, true); assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); @@ -332,8 +345,7 @@ public void testDynamicUpdate() throws IOException { fieldName, inferenceId, searchInferenceId, - new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null), - useLegacyFormat + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) ); assertSemanticTextField(mapperService, fieldName, true); assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId); @@ -343,7 +355,6 @@ public void testDynamicUpdate() throws IOException { public void testUpdateModelSettings() throws IOException { for (int depth = 1; depth < 5; depth++) { String fieldName = randomFieldName(depth); - final boolean useLegacyFormat = randomBoolean(); MapperService mapperService = createMapperService( mapping(b -> b.startObject(fieldName).field("type", "semantic_text").field("inference_id", "test_model").endObject()), useLegacyFormat @@ -433,46 +444,43 @@ public void testUpdateSearchInferenceId() throws IOException { b.endObject(); }); - for (boolean useLegacyFormat : new boolean[] { true, false }) { - for (int depth = 1; depth < 5; depth++) { - String fieldName = randomFieldName(depth); - MapperService mapperService = createMapperService(buildMapping.apply(fieldName, null), useLegacyFormat); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - - merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); - - merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2); - - merge(mapperService, buildMapping.apply(fieldName, null)); - assertSemanticTextField(mapperService, fieldName, false); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - - mapperService = mapperServiceForFieldWithModelSettings( - fieldName, - inferenceId, - new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null), - useLegacyFormat - ); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); + for (int depth = 1; depth < 5; depth++) { + String fieldName = randomFieldName(depth); + MapperService mapperService = createMapperService(buildMapping.apply(fieldName, null), useLegacyFormat); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); + merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); - merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2); + merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2); - merge(mapperService, buildMapping.apply(fieldName, null)); - assertSemanticTextField(mapperService, fieldName, true); - assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); - } + merge(mapperService, buildMapping.apply(fieldName, null)); + assertSemanticTextField(mapperService, fieldName, false); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); + + mapperService = mapperServiceForFieldWithModelSettings( + fieldName, + inferenceId, + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) + ); + assertSemanticTextField(mapperService, fieldName, true); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); + + merge(mapperService, buildMapping.apply(fieldName, searchInferenceId1)); + assertSemanticTextField(mapperService, fieldName, true); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId1); + + merge(mapperService, buildMapping.apply(fieldName, searchInferenceId2)); + assertSemanticTextField(mapperService, fieldName, true); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId2); + + merge(mapperService, buildMapping.apply(fieldName, null)); + assertSemanticTextField(mapperService, fieldName, true); + assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId); } } @@ -540,130 +548,123 @@ private static void assertInferenceEndpoints( } public void testSuccessfulParse() throws IOException { - for (boolean useLegacyFormat : new boolean[] { true, false }) { - for (int depth = 1; depth < 4; depth++) { - final String fieldName1 = randomFieldName(depth); - final String fieldName2 = randomFieldName(depth + 1); - final String searchInferenceId = randomAlphaOfLength(8); - final boolean setSearchInferenceId = randomBoolean(); - - Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - XContentBuilder mapping = mapping(b -> { - addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); - addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); - }); - - MapperService mapperService = createMapperService(mapping, useLegacyFormat); - assertSemanticTextField(mapperService, fieldName1, false); - assertInferenceEndpoints( - mapperService, - fieldName1, - model1.getInferenceEntityId(), - setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId() - ); - assertSemanticTextField(mapperService, fieldName2, false); - assertInferenceEndpoints( - mapperService, - fieldName2, - model2.getInferenceEntityId(), - setSearchInferenceId ? searchInferenceId : model2.getInferenceEntityId() - ); + for (int depth = 1; depth < 4; depth++) { + final String fieldName1 = randomFieldName(depth); + final String fieldName2 = randomFieldName(depth + 1); + final String searchInferenceId = randomAlphaOfLength(8); + final boolean setSearchInferenceId = randomBoolean(); + + Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + XContentBuilder mapping = mapping(b -> { + addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); + addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); + }); + + MapperService mapperService = createMapperService(mapping, useLegacyFormat); + assertSemanticTextField(mapperService, fieldName1, false); + assertInferenceEndpoints( + mapperService, + fieldName1, + model1.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : model1.getInferenceEntityId() + ); + assertSemanticTextField(mapperService, fieldName2, false); + assertInferenceEndpoints( + mapperService, + fieldName2, + model2.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : model2.getInferenceEntityId() + ); - DocumentMapper documentMapper = mapperService.documentMapper(); - ParsedDocument doc = documentMapper.parse( - source( - b -> addSemanticTextInferenceResults( - useLegacyFormat, - b, - List.of( - randomSemanticText(useLegacyFormat, fieldName1, model1, List.of("a b", "c"), XContentType.JSON), - randomSemanticText(useLegacyFormat, fieldName2, model2, List.of("d e f"), XContentType.JSON) - ) + DocumentMapper documentMapper = mapperService.documentMapper(); + ParsedDocument doc = documentMapper.parse( + source( + b -> addSemanticTextInferenceResults( + useLegacyFormat, + b, + List.of( + randomSemanticText(useLegacyFormat, fieldName1, model1, List.of("a b", "c"), XContentType.JSON), + randomSemanticText(useLegacyFormat, fieldName2, model2, List.of("d e f"), XContentType.JSON) ) ) + ) + ); + + List luceneDocs = doc.docs(); + assertEquals(4, luceneDocs.size()); + for (int i = 0; i < 3; i++) { + assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); + } + // nested docs are in reversed order + assertSparseFeatures(luceneDocs.get(0), getEmbeddingsFieldName(fieldName1), 2); + assertSparseFeatures(luceneDocs.get(1), getEmbeddingsFieldName(fieldName1), 1); + assertSparseFeatures(luceneDocs.get(2), getEmbeddingsFieldName(fieldName2), 3); + assertEquals(doc.rootDoc(), luceneDocs.get(3)); + assertNull(luceneDocs.get(3).getParent()); + + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { + NestedDocuments nested = new NestedDocuments( + mapperService.mappingLookup(), + QueryBitSetProducer::new, + IndexVersion.current() + ); + LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); + + Set visitedNestedIdentities = new HashSet<>(); + Set expectedVisitedNestedIdentities = Set.of( + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 0, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 1, null), + new SearchHit.NestedIdentity(getChunksFieldName(fieldName2), 0, null) ); - List luceneDocs = doc.docs(); - assertEquals(4, luceneDocs.size()); - for (int i = 0; i < 3; i++) { - assertEquals(doc.rootDoc(), luceneDocs.get(i).getParent()); + assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); + assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); + assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); + + assertNull(leaf.advance(3)); + assertEquals(3, leaf.doc()); + assertEquals(3, leaf.rootDoc()); + assertNull(leaf.nestedIdentity()); + + IndexSearcher searcher = newSearcher(reader); + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), + 10 + ); + assertEquals(1, topDocs.totalHits.value()); + assertEquals(3, topDocs.scoreDocs[0].doc); } - // nested docs are in reversed order - assertSparseFeatures(luceneDocs.get(0), getEmbeddingsFieldName(fieldName1), 2); - assertSparseFeatures(luceneDocs.get(1), getEmbeddingsFieldName(fieldName1), 1); - assertSparseFeatures(luceneDocs.get(2), getEmbeddingsFieldName(fieldName2), 3); - assertEquals(doc.rootDoc(), luceneDocs.get(3)); - assertNull(luceneDocs.get(3).getParent()); - - withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), reader -> { - NestedDocuments nested = new NestedDocuments( - mapperService.mappingLookup(), - QueryBitSetProducer::new, - IndexVersion.current() + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a", "b")), + 10 ); - LeafNestedDocuments leaf = nested.getLeafNestedDocuments(reader.leaves().get(0)); - - Set visitedNestedIdentities = new HashSet<>(); - Set expectedVisitedNestedIdentities = Set.of( - new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 0, null), - new SearchHit.NestedIdentity(getChunksFieldName(fieldName1), 1, null), - new SearchHit.NestedIdentity(getChunksFieldName(fieldName2), 0, null) + assertEquals(1, topDocs.totalHits.value()); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), + 10 ); - - assertChildLeafNestedDocument(leaf, 0, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 1, 3, visitedNestedIdentities); - assertChildLeafNestedDocument(leaf, 2, 3, visitedNestedIdentities); - assertEquals(expectedVisitedNestedIdentities, visitedNestedIdentities); - - assertNull(leaf.advance(3)); - assertEquals(3, leaf.doc()); - assertEquals(3, leaf.rootDoc()); - assertNull(leaf.nestedIdentity()); - - IndexSearcher searcher = newSearcher(reader); - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName1, List.of("a")), - 10 - ); - assertEquals(1, topDocs.totalHits.value()); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery( - mapperService.mappingLookup().nestedLookup(), - fieldName1, - List.of("a", "b") - ), - 10 - ); - assertEquals(1, topDocs.totalHits.value()); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("d")), - 10 - ); - assertEquals(1, topDocs.totalHits.value()); - assertEquals(3, topDocs.scoreDocs[0].doc); - } - { - TopDocs topDocs = searcher.search( - generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), - 10 - ); - assertEquals(0, topDocs.totalHits.value()); - } - }); - } + assertEquals(1, topDocs.totalHits.value()); + assertEquals(3, topDocs.scoreDocs[0].doc); + } + { + TopDocs topDocs = searcher.search( + generateNestedTermSparseVectorQuery(mapperService.mappingLookup().nestedLookup(), fieldName2, List.of("z")), + 10 + ); + assertEquals(0, topDocs.totalHits.value()); + } + }); } } public void testMissingInferenceId() throws IOException { - boolean useLegacyFormat = randomBoolean(); final MapperService mapperService = createMapperService( mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat @@ -689,7 +690,6 @@ public void testMissingInferenceId() throws IOException { } public void testMissingModelSettings() throws IOException { - boolean useLegacyFormat = randomBoolean(); MapperService mapperService = createMapperService(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, @@ -706,7 +706,6 @@ public void testMissingModelSettings() throws IOException { } public void testMissingTaskType() throws IOException { - boolean useLegacyFormat = randomBoolean(); MapperService mapperService = createMapperService(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, @@ -747,8 +746,7 @@ public void testDenseVectorElementType() throws IOException { 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT - ), - randomBoolean() + ) ); assertMapperService.accept(floatMapperService, DenseVectorFieldMapper.ElementType.FLOAT); @@ -760,8 +758,7 @@ public void testDenseVectorElementType() throws IOException { 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.BYTE - ), - randomBoolean() + ) ); assertMapperService.accept(byteMapperService, DenseVectorFieldMapper.ElementType.BYTE); } @@ -769,18 +766,16 @@ public void testDenseVectorElementType() throws IOException { private MapperService mapperServiceForFieldWithModelSettings( String fieldName, String inferenceId, - SemanticTextField.ModelSettings modelSettings, - boolean useLegacyFormat + SemanticTextField.ModelSettings modelSettings ) throws IOException { - return mapperServiceForFieldWithModelSettings(fieldName, inferenceId, null, modelSettings, useLegacyFormat); + return mapperServiceForFieldWithModelSettings(fieldName, inferenceId, null, modelSettings); } private MapperService mapperServiceForFieldWithModelSettings( String fieldName, String inferenceId, String searchInferenceId, - SemanticTextField.ModelSettings modelSettings, - boolean useLegacyFormat + SemanticTextField.ModelSettings modelSettings ) throws IOException { String mappingParams = "type=semantic_text,inference_id=" + inferenceId; if (searchInferenceId != null) { @@ -827,8 +822,7 @@ public void testExistsQuerySparseVector() throws IOException { MapperService mapperService = mapperServiceForFieldWithModelSettings( fieldName, inferenceId, - new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null), - randomBoolean() + new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null) ); Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); @@ -850,8 +844,7 @@ public void testExistsQueryDenseVector() throws IOException { 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT - ), - randomBoolean() + ) ); Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 46f9aabfe974e..29ca71d38e1b2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.mapper; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -40,7 +42,16 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase { private static final String NAME = "field"; - private boolean useLegacyFormat; + private final boolean useLegacyFormat; + + public SemanticTextFieldTests(boolean useLegacyFormat) { + this.useLegacyFormat = useLegacyFormat; + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return List.of(new Object[] { true }, new Object[] { false }); + } @Override protected Predicate getRandomFieldsExcludeFilter() { @@ -94,7 +105,6 @@ protected void assertEqualInstances(SemanticTextField expectedInstance, Semantic @Override protected SemanticTextField createTestInstance() { - useLegacyFormat = randomBoolean(); List rawValues = randomList(1, 5, () -> randomSemanticTextInput().toString()); try { // try catch required for override return randomSemanticText( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index 12d8af947a265..d5042643013e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.queries; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; @@ -86,7 +88,7 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase parameters() throws Exception { + return List.of(new Object[] { true }, new Object[] { false }); + } + @BeforeClass public static void setInferenceResultType() { // These are class variables because they are used when initializing additional mappings, which happens once per test suite run in @@ -106,7 +117,6 @@ public static void setInferenceResultType() { () -> randomFrom(DenseVectorFieldMapper.ElementType.values()) ); // TODO: Support bit elements once KNN bit vector queries are available useSearchInferenceId = randomBoolean(); - useLegacyFormat = randomBoolean(); } @Override