Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Dec 5, 2023
1 parent 085751e commit 03423e3
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import org.apache.lucene.search.Query;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.DocumentParserContext;
import org.elasticsearch.index.mapper.FieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
Expand All @@ -17,7 +19,6 @@
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.query.SearchExecutionContext;

import java.io.IOException;
Expand All @@ -32,13 +33,18 @@ private static SemanticTextFieldMapper toType(FieldMapper in) {
return (SemanticTextFieldMapper) in;
}

/**
 * Retrieves the {@link Builder} held by the given mapper.
 *
 * @param in the mapper to read from; it is cast to {@code SemanticTextFieldMapper}
 * @return the value of that mapper's {@code builder} field
 */
private static Builder builder(FieldMapper in) {
    final SemanticTextFieldMapper semanticTextMapper = (SemanticTextFieldMapper) in;
    return semanticTextMapper.builder;
}

public static class Builder extends FieldMapper.Builder {

final Parameter<String> modelId = Parameter.stringParam("model_id", false, m -> toType(m).modelId, null).addValidator(v -> {
if (Strings.isEmpty(v)) {
throw new IllegalArgumentException("field [model_id] must be specified");
}
});
private final Parameter<String> modelId = Parameter.stringParam("model_id", false, m -> builder(m).modelId.get(), null)
.addValidator(v -> {
if (Strings.isEmpty(v)) {
throw new IllegalArgumentException("field [model_id] must be specified");
}
});

private final Parameter<Map<String, String>> meta = Parameter.metaParam();

Expand All @@ -62,7 +68,8 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
name(),
new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()),
modelId.getValue(),
copyTo
copyTo,
this
);
}
}
Expand All @@ -71,13 +78,10 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {

public static class SemanticTextFieldType extends SimpleMappedFieldType {

private final SparseVectorFieldMapper.SparseVectorFieldType sparseVectorFieldType;

private final String modelId;

public SemanticTextFieldType(String name, String modelId, Map<String, String> meta) {
super(name, true, false, false, TextSearchInfo.NONE, meta);
this.sparseVectorFieldType = new SparseVectorFieldMapper.SparseVectorFieldType(name + "." + "inference", meta);
super(name, false, false, false, TextSearchInfo.NONE, meta);
this.modelId = modelId;
}

Expand All @@ -93,19 +97,27 @@ public String getInferenceModel() {

@Override
public Query termQuery(Object value, SearchExecutionContext context) {
return null;
throw new IllegalArgumentException("termQuery not implemented yet");
}

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
return SourceValueFetcher.identity(name(), context, format);
return SourceValueFetcher.toString(name(), context, format);
}

/**
 * Always rejects the request for field data.
 *
 * @param fieldDataContext unused
 * @return never returns normally
 * @throws IllegalArgumentException on every call
 */
@Override
public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) {
    final String reason = "[semantic_text] fields do not support sorting, scripting or aggregating";
    throw new IllegalArgumentException(reason);
}
}

private final String modelId;

private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, String modelId, CopyTo copyTo) {
private final Builder builder;

private SemanticTextFieldMapper(String simpleName, MappedFieldType mappedFieldType, String modelId, CopyTo copyTo, Builder builder) {
super(simpleName, mappedFieldType, MultiFields.empty(), copyTo);
this.builder = builder;
this.modelId = modelId;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.mapper;

import org.apache.lucene.index.IndexableField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.DocumentMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.MapperTestCase;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.junit.AssumptionViolatedException;

import java.io.IOException;
import java.util.Collection;
import java.util.List;

import static java.util.Collections.singletonList;
import static org.hamcrest.Matchers.containsString;

/**
 * Mapper tests for the {@code semantic_text} field type, run with the {@link MachineLearning} plugin installed.
 */
public class SemanticTextFieldMapperTests extends MapperTestCase {

    /**
     * The minimal mapping must serialize back unchanged, and parsing a document with a value
     * for the field must not add any indexable field named {@code field} to the root document.
     */
    public void testDefaults() throws Exception {
        final DocumentMapper documentMapper = createDocumentMapper(fieldMapping(this::minimalMapping));
        final String expectedMapping = Strings.toString(fieldMapping(this::minimalMapping));
        assertEquals(expectedMapping, documentMapper.mappingSource().toString());

        final ParsedDocument parsed = documentMapper.parse(source(this::writeField));
        final List<IndexableField> indexedFields = parsed.rootDoc().getFields("field");

        // Nothing is expected to be indexed under the field name.
        assertTrue(indexedFields.isEmpty());
    }

    /**
     * A mapping that omits {@code model_id} must be rejected with the validator's message.
     */
    public void testModelIdNotPresent() throws IOException {
        final Exception failure = expectThrows(MapperParsingException.class, () -> {
            createMapperService(fieldMapping(mapping -> mapping.field("type", "semantic_text")));
        });
        assertThat(failure.getMessage(), containsString("field [model_id] must be specified"));
    }

    /**
     * Declaring a {@code semantic_text} sub-field under a {@code text} field's {@code fields} must fail.
     */
    public void testCannotBeUsedInMultiFields() {
        final Exception failure = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(mapping -> {
            mapping.field("type", "text");
            mapping.startObject("fields");
            {
                mapping.startObject("semantic").field("type", "semantic_text").endObject();
            }
            mapping.endObject();
        })));
        assertThat(failure.getMessage(), containsString("Field [semantic] of type [semantic_text] can't be used in multifields"));
    }

    /**
     * Merging a mapping with a different {@code model_id} into an existing one must be refused.
     */
    public void testUpdatesToModelIdNotSupported() throws IOException {
        final MapperService service = createMapperService(fieldMapping(mapping -> {
            mapping.field("type", "semantic_text");
            mapping.field("model_id", "test_model");
        }));

        final Exception failure = expectThrows(IllegalArgumentException.class, () -> merge(service, fieldMapping(mapping -> {
            mapping.field("type", "semantic_text");
            mapping.field("model_id", "another_model");
        })));
        assertThat(failure.getMessage(), containsString("Cannot update parameter [model_id] from [test_model] to [another_model]"));
    }

    /** Supplies a single {@link MachineLearning} plugin built from empty settings. */
    @Override
    protected Collection<? extends Plugin> getPlugins() {
        final MachineLearning machineLearning = new MachineLearning(Settings.EMPTY);
        return singletonList(machineLearning);
    }

    /** Writes the smallest accepted mapping: the field type plus a {@code model_id}. */
    @Override
    protected void minimalMapping(XContentBuilder builder) throws IOException {
        builder.field("type", "semantic_text");
        builder.field("model_id", "test_model");
    }

    /** Sample value written into test documents for this field. */
    @Override
    protected Object getSampleValueForDocument() {
        return "value";
    }

    @Override
    protected boolean supportsIgnoreMalformed() {
        return false;
    }

    @Override
    protected boolean supportsStoredFields() {
        return false;
    }

    @Override
    protected void registerParameters(ParameterChecker checker) throws IOException {
        // Intentionally empty: no parameter checks are registered here.
    }

    /** Skips any inherited test that asks for a random input value. */
    @Override
    protected Object generateRandomInputValue(MappedFieldType ft) {
        // The assumption always fails, so the calling test is skipped before the null is used.
        assumeFalse("doc_values are not supported in semantic_text", true);
        return null;
    }

    @Override
    protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) {
        throw new AssumptionViolatedException("not supported");
    }

    @Override
    protected IngestScriptSupport ingestScriptSupport() {
        throw new AssumptionViolatedException("not supported");
    }
}

0 comments on commit 03423e3

Please sign in to comment.