From f204cc327b527152481451de7f148c221c78c421 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Mon, 2 Dec 2024 09:24:33 +0000 Subject: [PATCH] iter --- .../inference/mapper/SemanticTextField.java | 31 ++++---- .../mapper/SemanticTextFieldTests.java | 77 +++++++++++-------- 2 files changed, 63 insertions(+), 45 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index e158ce35b052f..27369adf75c70 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -51,14 +51,14 @@ * the inference results under the {@link SemanticTextField#INFERENCE_FIELD}. * * @param fieldName The original field name. - * @param originalValues The original values associated with the field name. + * @param originalValues The original values associated with the field name for indices created before {@link IndexVersions#INFERENCE_METADATA_FIELDS}, null otherwise. * @param inference The inference result. * @param contentType The {@link XContentType} used to store the embeddings chunks. 
*/ public record SemanticTextField( IndexVersion indexCreatedVersion, String fieldName, - List originalValues, + @Nullable List originalValues, InferenceResult inference, XContentType contentType ) implements ToXContentObject { @@ -274,17 +274,22 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @SuppressWarnings("unchecked") private static final ConstructingObjectParser SEMANTIC_TEXT_FIELD_PARSER = - new ConstructingObjectParser<>( - SemanticTextFieldMapper.CONTENT_TYPE, - true, - (args, context) -> new SemanticTextField( + new ConstructingObjectParser<>(SemanticTextFieldMapper.CONTENT_TYPE, true, (args, context) -> { + List originalValues = (List) args[0]; + if (context.indexVersionCreated.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)) { + if (originalValues != null && originalValues.isEmpty() == false) { + throw new IllegalArgumentException("Unknown field [" + TEXT_FIELD + "]"); + } + originalValues = null; + } + return new SemanticTextField( context.indexVersionCreated(), context.fieldName(), - (List) (args[0] == null ? 
List.of() : args[0]), + originalValues, (InferenceResult) args[1], context.xContentType() - ) - ); + ); + }); @SuppressWarnings("unchecked") private static final ConstructingObjectParser INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>( @@ -332,13 +337,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws (p, c) -> MODEL_SETTINGS_PARSER.parse(p, null), new ParseField(MODEL_SETTINGS_FIELD) ); - INFERENCE_RESULT_PARSER.declareObject(constructorArg(), (p, c) -> { + INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> { if (c.indexVersionCreated.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS)) { return parseChunksMap(p); } else { return Map.of(c.fieldName, parseChunksArrayLegacy(p)); } - }, new ParseField(CHUNKS_FIELD)); + }, new ParseField(CHUNKS_FIELD), ObjectParser.ValueType.OBJECT_ARRAY); CHUNKS_PARSER.declareString(optionalConstructorArg(), new ParseField(TEXT_FIELD)); CHUNKS_PARSER.declareInt(optionalConstructorArg(), new ParseField(CHUNKED_START_OFFSET_FIELD)); @@ -372,7 +377,7 @@ private static Map> parseChunksMap(XContentParser parser) th private static List parseChunksArrayLegacy(XContentParser parser) throws IOException { List results = new ArrayList<>(); - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { results.add(CHUNKS_PARSER.parse(parser, null)); } @@ -397,7 +402,7 @@ public static List toSemanticTextFieldChunks( chunks.add( new Chunk( withOffsets ? null : input, - startOffset, + withOffsets ? startOffset : -1, withOffsets ? 
startOffset + chunkAsByteReference.matchedText().length() : -1, chunkAsByteReference.bytesReference() ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index e3600c03c53fe..03eadaa30b437 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; @@ -41,6 +42,8 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase { private static final String NAME = "field"; + private IndexVersion currentIndexVersion; + @Override protected Predicate getRandomFieldsExcludeFilter() { return n -> n.endsWith(CHUNKED_EMBEDDINGS_FIELD); @@ -48,49 +51,59 @@ protected Predicate getRandomFieldsExcludeFilter() { @Override protected void assertEqualInstances(SemanticTextField expectedInstance, SemanticTextField newInstance) { + assertThat(newInstance.indexCreatedVersion(), equalTo(expectedInstance.indexCreatedVersion())); assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName())); assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); SemanticTextField.ModelSettings modelSettings = 
newInstance.inference().modelSettings(); - for (int i = 0; i < newInstance.inference().chunks().size(); i++) { - /* assertThat(newInstance.inference().chunks().get(i).text(), equalTo(expectedInstance.inference().chunks().get(i).text())); - switch (modelSettings.taskType()) { - case TEXT_EMBEDDING -> { - double[] expectedVector = parseDenseVector( - expectedInstance.inference().chunks().get(i).rawEmbeddings(), - modelSettings.dimensions(), - expectedInstance.contentType() - ); - double[] newVector = parseDenseVector( - newInstance.inference().chunks().get(i).rawEmbeddings(), - modelSettings.dimensions(), - newInstance.contentType() - ); - assertArrayEquals(expectedVector, newVector, 0.0000001f); - } - case SPARSE_EMBEDDING -> { - List expectedTokens = parseWeightedTokens( - expectedInstance.inference().chunks().get(i).rawEmbeddings(), - expectedInstance.contentType() - ); - List newTokens = parseWeightedTokens( - newInstance.inference().chunks().get(i).rawEmbeddings(), - newInstance.contentType() - ); - assertThat(newTokens, equalTo(expectedTokens)); + for (var entry : newInstance.inference().chunks().entrySet()) { + var expectedChunks = expectedInstance.inference().chunks().get(entry.getKey()); + assertNotNull(expectedChunks); + assertThat(entry.getValue().size(), equalTo(expectedChunks.size())); + for (int i = 0; i < entry.getValue().size(); i++) { + var actualChunk = entry.getValue().get(i); + assertThat(actualChunk.text(), equalTo(expectedChunks.get(i).text())); + assertThat(actualChunk.startOffset(), equalTo(expectedChunks.get(i).startOffset())); + assertThat(actualChunk.endOffset(), equalTo(expectedChunks.get(i).endOffset())); + switch (modelSettings.taskType()) { + case TEXT_EMBEDDING -> { + double[] expectedVector = parseDenseVector( + expectedChunks.get(i).rawEmbeddings(), + modelSettings.dimensions(), + expectedInstance.contentType() + ); + double[] newVector = parseDenseVector( + actualChunk.rawEmbeddings(), + modelSettings.dimensions(), + 
newInstance.contentType() + ); + assertArrayEquals(expectedVector, newVector, 0.0000001f); + } + case SPARSE_EMBEDDING -> { + List expectedTokens = parseWeightedTokens( + expectedChunks.get(i).rawEmbeddings(), + expectedInstance.contentType() + ); + List newTokens = parseWeightedTokens(actualChunk.rawEmbeddings(), newInstance.contentType()); + assertThat(newTokens, equalTo(expectedTokens)); + } + default -> throw new AssertionError("Invalid task type " + modelSettings.taskType()); } - default -> throw new AssertionError("Invalid task type " + modelSettings.taskType()); - }**/ + } } } @Override protected SemanticTextField createTestInstance() { + currentIndexVersion = randomFrom( + IndexVersionUtils.randomPreviousCompatibleVersion(random(), IndexVersions.INFERENCE_METADATA_FIELDS), + IndexVersionUtils.randomVersionBetween(random(), IndexVersions.INFERENCE_METADATA_FIELDS, IndexVersion.current()) + ); List rawValues = randomList(1, 5, () -> randomSemanticTextInput().toString()); try { // try catch required for override return randomSemanticText( - IndexVersion.current(), + currentIndexVersion, NAME, TestModel.createRandomInstance(), rawValues, @@ -104,12 +117,12 @@ protected SemanticTextField createTestInstance() { @Override protected SemanticTextField doParseInstance(XContentParser parser) throws IOException { - return SemanticTextField.parse(parser, new SemanticTextField.ParserContext(IndexVersion.current(), NAME, parser.contentType())); + return SemanticTextField.parse(parser, new SemanticTextField.ParserContext(currentIndexVersion, NAME, parser.contentType())); } @Override protected boolean supportsUnknownFields() { - return true; + return false; } public void testModelSettingsValidation() { @@ -218,7 +231,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( return new SemanticTextField( indexVersion, fieldName, - inputs, + indexVersion.onOrAfter(IndexVersions.INFERENCE_METADATA_FIELDS) ? 
null : inputs, new SemanticTextField.InferenceResult( model.getInferenceEntityId(), new SemanticTextField.ModelSettings(model),