From 03423e3b2f2954d6f2da0a42385da21a96f31c91 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 5 Dec 2023 12:27:16 +0100 Subject: [PATCH] Added tests --- .../ml/mapper/SemanticTextFieldMapper.java | 40 ++++-- .../mapper/SemanticTextFieldMapperTests.java | 126 ++++++++++++++++++ 2 files changed, 152 insertions(+), 14 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java index df1447c5368bb..b3bc399c04a4c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapper.java @@ -9,6 +9,8 @@ import org.apache.lucene.search.Query; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.fielddata.FieldDataContext; +import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; @@ -17,7 +19,6 @@ import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; -import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; import java.io.IOException; @@ -32,13 +33,18 @@ private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; } + private static Builder builder(FieldMapper in) { + return ((SemanticTextFieldMapper) in).builder; + } + public static class Builder extends FieldMapper.Builder { - final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).modelId, null).addValidator(v -> { - if (Strings.isEmpty(v)) { - throw new IllegalArgumentException("field [model_id] must be specified"); - } - }); + private final Parameter modelId = Parameter.stringParam("model_id", false, m -> builder(m).modelId.get(), null) + .addValidator(v -> { + if (Strings.isEmpty(v)) { + throw new IllegalArgumentException("field [model_id] must be specified"); + } + }); private final Parameter> meta = Parameter.metaParam(); @@ -62,7 +68,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { name(), new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), modelId.getValue(), - copyTo + copyTo, + this ); } } @@ -71,13 +78,10 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) { public static class SemanticTextFieldType extends SimpleMappedFieldType { - private final SparseVectorFieldMapper.SparseVectorFieldType sparseVectorFieldType; - private final String modelId; public SemanticTextFieldType(String name, String modelId, Map meta) { - super(name, true, false, false, TextSearchInfo.NONE, meta); - this.sparseVectorFieldType = new SparseVectorFieldMapper.SparseVectorFieldType(name + "." + "inference", meta); + super(name, false, false, false, TextSearchInfo.NONE, meta); this.modelId = modelId; } @@ -93,19 +97,27 @@ public String getInferenceModel() { @Override public Query termQuery(Object value, SearchExecutionContext context) { - return null; + throw new IllegalArgumentException("termQuery not implemented yet"); } @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { - return SourceValueFetcher.identity(name(), context, format); + return SourceValueFetcher.toString(name(), context, format); + } + + @Override + public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) { + throw new IllegalArgumentException("[semantic_text] fields do not support sorting, scripting or aggregating"); } } private final String modelId; - private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, String modelId, CopyTo copyTo) { + private final Builder builder; + + private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, String modelId, CopyTo copyTo, Builder builder) { super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + this.builder = builder; this.modelId = modelId; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java new file mode 100644 index 0000000000000..0f08abfd6fa59 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/mapper/SemanticTextFieldMapperTests.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.mapper; + +import org.apache.lucene.index.IndexableField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.MapperParsingException; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.junit.AssumptionViolatedException; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; + +import static java.util.Collections.singletonList; +import static org.hamcrest.Matchers.containsString; + +public class SemanticTextFieldMapperTests extends MapperTestCase { + + public void testDefaults() throws Exception { + DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); + assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); + + ParsedDocument doc1 = mapper.parse(source(this::writeField)); + List fields = doc1.rootDoc().getFields("field"); + + // No indexable fields + assertTrue(fields.isEmpty()); + } + + public void testModelIdNotPresent() throws IOException { + Exception e = expectThrows( + MapperParsingException.class, + () -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text"))) + ); + assertThat(e.getMessage(), containsString("field [model_id] must be specified")); + } + + public void testCannotBeUsedInMultiFields() { + Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> { + b.field("type", "text"); + b.startObject("fields"); + b.startObject("semantic"); + b.field("type", "semantic_text"); + b.endObject(); + b.endObject(); + }))); + assertThat(e.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields")); + } + + public void testUpdatesToModelIdNotSupported() throws IOException { + MapperService mapperService = createMapperService( + fieldMapping(b -> b.field("type", "semantic_text").field("model_id", "test_model")) + ); + Exception e = expectThrows( + IllegalArgumentException.class, + () -> merge( + mapperService, + fieldMapping( + b -> b.field("type", "semantic_text") + .field("model_id", "another_model") + ) + ) + ); + assertThat(e.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]")); + } + + @Override + protected Collection getPlugins() { + return singletonList(new MachineLearning(Settings.EMPTY)); + } + + + @Override + protected void minimalMapping(XContentBuilder b) throws IOException { + b.field("type", "semantic_text").field("model_id", "test_model"); + } + + @Override + protected Object getSampleValueForDocument() { + return "value"; + } + + @Override + protected boolean supportsIgnoreMalformed() { + return false; + } + + @Override + protected boolean supportsStoredFields() { + return false; + } + + @Override + protected void registerParameters(ParameterChecker checker) throws IOException { + } + + @Override + protected Object generateRandomInputValue(MappedFieldType ft) { + assumeFalse("doc_values are not supported in semantic_text", true); + return null; + } + + @Override + protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { + throw new AssumptionViolatedException("not supported"); + } + + @Override + protected IngestScriptSupport ingestScriptSupport() { + throw new AssumptionViolatedException("not supported"); + } +}