Skip to content

Commit

Permalink
First version - uses subfields (not multifields) for storing text and sparse_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Oct 26, 2023
1 parent 3c855f6 commit 6da59a3
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ public class SemanticTextFieldMapper extends FieldMapper {
public static final String CONTENT_TYPE = "semantic_text";
private static final String SPARSE_VECTOR_SUFFIX = "_inference";

private static ParseField TEXT_FIELD = new ParseField("text");
private static ParseField INFERENCE_FIELD = new ParseField("inference");
private static final String TEXT_SUBFIELD_NAME = "text";
private static final String SPARSE_VECTOR_SUBFIELD_NAME = "inference";

private static SemanticTextFieldMapper toType(FieldMapper in) {
return (SemanticTextFieldMapper) in;
Expand Down Expand Up @@ -161,24 +161,49 @@ public FieldMapper.Builder getMergeBuilder() {
}

@Override
protected void parseCreateField(DocumentParserContext context) throws IOException {
public void parse(DocumentParserContext context) throws IOException {

XContentParser parser = context.parser();
final String value = parser.textOrNull();
context.parser();
if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) {
throw new IllegalArgumentException(
"[semantic_text] fields must be a json object, expected a START_OBJECT but got: " + context.parser().currentToken()
);
}

if (value == null) {
return;
boolean textFound = false;
boolean inferenceFound = false;
for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser().nextToken()) {
if (token != XContentParser.Token.FIELD_NAME) {
throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token);
}

String fieldName = context.parser().currentName();
XContentParser.Token valueToken = context.parser().nextToken();
switch (fieldName) {
case TEXT_SUBFIELD_NAME:
context.doc().add(new StringField(name() + TEXT_SUBFIELD_NAME, context.parser().textOrNull(), Field.Store.NO));
textFound = true;
break;
case SPARSE_VECTOR_SUBFIELD_NAME:
sparseVectorFieldInfo.sparseVectorFieldMapper.parse(context);
inferenceFound = true;
break;
default:
throw new IllegalArgumentException("Unexpected subfield value: " + fieldName);
}
}

// Create field for original text
context.doc().add(new StringField(name(), value, Field.Store.NO));
if (textFound == false) {
throw new IllegalArgumentException("[semantic_text] value does not contain [" + TEXT_SUBFIELD_NAME + "] subfield");
}
if (inferenceFound == false) {
throw new IllegalArgumentException("[semantic_text] value does not contain [" + SPARSE_VECTOR_SUBFIELD_NAME + "] subfield");
}
}

// Parses inference field, for now a separate field in the doc
// TODO make inference field a multifield / child field?
context.path().add(simpleName() + SPARSE_VECTOR_SUFFIX);
parser.nextToken();
sparseVectorFieldInfo.sparseVectorFieldMapper.parse(context);
context.path().remove();
@Override
protected void parseCreateField(DocumentParserContext context) {
    // Never invoked: this mapper overrides parse() and consumes the whole object
    // value there, so reaching this method indicates a programming error.
    throw new AssertionError("parse is implemented directly");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,18 @@ protected void processIndexRequest(

String index = indexRequest.index();
Map<String, Object> sourceMap = indexRequest.sourceAsMap();
sourceMap.entrySet().stream().filter(entry -> fieldNeedsInference(index, entry.getKey())).forEach(entry -> {
sourceMap.entrySet().stream().filter(entry -> fieldNeedsInference(index, entry.getKey(), entry.getValue())).forEach(entry -> {
runInferenceForField(indexRequest, entry.getKey(), refs, slot, onFailure);
});
}

@Override
public boolean needsProcessing(DocWriteRequest<?> docWriteRequest, IndexRequest indexRequest, Metadata metadata) {
return (indexRequest.isFieldInferenceDone() == false)
&& indexRequest.sourceAsMap().keySet().stream().anyMatch(fieldName -> fieldNeedsInference(indexRequest.index(), fieldName));
&& indexRequest.sourceAsMap()
.entrySet()
.stream()
.anyMatch(entry -> fieldNeedsInference(indexRequest.index(), entry.getKey(), entry.getValue()));
}

@Override
Expand All @@ -67,9 +70,11 @@ public boolean shouldExecuteOnIngestNode() {
return false;
}

// TODO actual mapping check here
private boolean fieldNeedsInference(String index, String fieldName) {
return fieldName.startsWith("infer_");
private boolean fieldNeedsInference(String index, String fieldName, Object fieldValue) {
// TODO actual mapping check here
return fieldName.startsWith("infer_")
// We want to perform inference when we don't have already calculated it
&& (fieldValue instanceof String);
}

private void runInferenceForField(
Expand All @@ -87,10 +92,11 @@ private void runInferenceForField(
refs.acquire();

// TODO Hardcoding model ID and task type
final String fieldValue = ingestDocument.getFieldValue(fieldName, String.class);
InferenceAction.Request inferenceRequest = new InferenceAction.Request(
TaskType.SPARSE_EMBEDDING,
"my-elser-model",
ingestDocument.getFieldValue(fieldName, String.class),
fieldValue,
Map.of()
);

Expand All @@ -99,7 +105,10 @@ private void runInferenceForField(
client.execute(InferenceAction.INSTANCE, inferenceRequest, ActionListener.runAfter(new ActionListener<InferenceAction.Response>() {
@Override
public void onResponse(InferenceAction.Response response) {
ingestDocument.setFieldValue(fieldName + "_inference", response.getResult().asMap(fieldName).get(fieldName));
ingestDocument.removeField(fieldName);
// Transform into two subfields, one with the actual text and other with the inference
ingestDocument.setFieldValue(fieldName + "._text", fieldValue);
ingestDocument.setFieldValue(fieldName + "._inference", response.getResult().asMap(fieldName).get(fieldName));
updateIndexRequestSource(indexRequest, ingestDocument);
}

Expand Down

0 comments on commit 6da59a3

Please sign in to comment.