diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java index ec4dc87273dd3..09284b4223073 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java @@ -26,8 +26,8 @@ public class SemanticTextFieldMapper extends FieldMapper { public static final String CONTENT_TYPE = "semantic_text"; private static final String SPARSE_VECTOR_SUFFIX = "_inference"; - private static ParseField TEXT_FIELD = new ParseField("text"); - private static ParseField INFERENCE_FIELD = new ParseField("inference"); + private static final String TEXT_SUBFIELD_NAME = "text"; + private static final String SPARSE_VECTOR_SUBFIELD_NAME = "inference"; private static SemanticTextFieldMapper toType(FieldMapper in) { return (SemanticTextFieldMapper) in; @@ -161,24 +161,49 @@ public FieldMapper.Builder getMergeBuilder() { } @Override - protected void parseCreateField(DocumentParserContext context) throws IOException { + public void parse(DocumentParserContext context) throws IOException { - XContentParser parser = context.parser(); - final String value = parser.textOrNull(); + context.parser(); + if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) { + throw new IllegalArgumentException( + "[semantic_text] fields must be a json object, expected a START_OBJECT but got: " + context.parser().currentToken() + ); + } - if (value == null) { - return; + boolean textFound = false; + boolean inferenceFound = false; + for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser().nextToken()) { + if (token != XContentParser.Token.FIELD_NAME) { + throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token); + } + + String fieldName = 
context.parser().currentName(); + context.parser().nextToken(); + switch (fieldName) { + case TEXT_SUBFIELD_NAME: + context.doc().add(new StringField(name() + "." + TEXT_SUBFIELD_NAME, context.parser().textOrNull(), Field.Store.NO)); + textFound = true; + break; + case SPARSE_VECTOR_SUBFIELD_NAME: + sparseVectorFieldInfo.sparseVectorFieldMapper.parse(context); + inferenceFound = true; + break; + default: + throw new IllegalArgumentException("Unexpected subfield value: " + fieldName); + } + } - // Create field for original text - context.doc().add(new StringField(name(), value, Field.Store.NO)); + if (textFound == false) { + throw new IllegalArgumentException("[semantic_text] value does not contain [" + TEXT_SUBFIELD_NAME + "] subfield"); + } + if (inferenceFound == false) { + throw new IllegalArgumentException("[semantic_text] value does not contain [" + SPARSE_VECTOR_SUBFIELD_NAME + "] subfield"); + } + } - // Parses inference field, for now a separate field in the doc - // TODO make inference field a multifield / child field? 
- context.path().add(simpleName() + SPARSE_VECTOR_SUFFIX); - parser.nextToken(); - sparseVectorFieldInfo.sparseVectorFieldMapper.parse(context); - context.path().remove(); + @Override + protected void parseCreateField(DocumentParserContext context) { + throw new AssertionError("parse is implemented directly"); } @Override diff --git a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java index beea8c77d868c..2cd0a2f3d4d1b 100644 --- a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java +++ b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java @@ -46,7 +46,7 @@ protected void processIndexRequest( String index = indexRequest.index(); Map sourceMap = indexRequest.sourceAsMap(); - sourceMap.entrySet().stream().filter(entry -> fieldNeedsInference(index, entry.getKey())).forEach(entry -> { + sourceMap.entrySet().stream().filter(entry -> fieldNeedsInference(index, entry.getKey(), entry.getValue())).forEach(entry -> { runInferenceForField(indexRequest, entry.getKey(), refs, slot, onFailure); }); } @@ -54,7 +54,10 @@ protected void processIndexRequest( @Override public boolean needsProcessing(DocWriteRequest docWriteRequest, IndexRequest indexRequest, Metadata metadata) { return (indexRequest.isFieldInferenceDone() == false) - && indexRequest.sourceAsMap().keySet().stream().anyMatch(fieldName -> fieldNeedsInference(indexRequest.index(), fieldName)); + && indexRequest.sourceAsMap() + .entrySet() + .stream() + .anyMatch(entry -> fieldNeedsInference(indexRequest.index(), entry.getKey(), entry.getValue())); } @Override @@ -67,9 +70,11 @@ public boolean shouldExecuteOnIngestNode() { return false; } - // TODO actual mapping check here - private boolean fieldNeedsInference(String index, String fieldName) { - return fieldName.startsWith("infer_"); + private boolean 
fieldNeedsInference(String index, String fieldName, Object fieldValue) { + // TODO actual mapping check here + return fieldName.startsWith("infer_") + // We want to perform inference when we haven't already calculated it + && (fieldValue instanceof String); } private void runInferenceForField( @@ -87,10 +92,11 @@ private void runInferenceForField( refs.acquire(); // TODO Hardcoding model ID and task type + final String fieldValue = ingestDocument.getFieldValue(fieldName, String.class); InferenceAction.Request inferenceRequest = new InferenceAction.Request( TaskType.SPARSE_EMBEDDING, "my-elser-model", - ingestDocument.getFieldValue(fieldName, String.class), + fieldValue, Map.of() ); @@ -99,7 +105,10 @@ private void runInferenceForField( client.execute(InferenceAction.INSTANCE, inferenceRequest, ActionListener.runAfter(new ActionListener() { @Override public void onResponse(InferenceAction.Response response) { - ingestDocument.setFieldValue(fieldName + "_inference", response.getResult().asMap(fieldName).get(fieldName)); + ingestDocument.removeField(fieldName); + // Transform into two subfields, one with the actual text and the other with the inference; + // subfield names must match the mapper's TEXT_SUBFIELD_NAME / SPARSE_VECTOR_SUBFIELD_NAME + ingestDocument.setFieldValue(fieldName + ".text", fieldValue); + ingestDocument.setFieldValue(fieldName + ".inference", response.getResult().asMap(fieldName).get(fieldName)); updateIndexRequestSource(indexRequest, ingestDocument); }