Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Mar 20, 2024
1 parent 7b578d1 commit 2be50d7
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public static TestServiceSettings fromMap(Map<String, Object> map) {
SimilarityMeasure similarity = null;
String similarityStr = (String) map.remove("similarity");
if (similarityStr != null) {
similarity = SimilarityMeasure.valueOf(similarityStr);
similarity = SimilarityMeasure.fromString(similarityStr);
}

return new TestServiceSettings(model, dimensions, similarity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
continue;
}
final Map<String, Object> docMap = indexRequest.sourceAsMap();
boolean hasInput = false;
for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) {
String field = entry.getKey();
String inferenceId = entry.getValue().inferenceId();
Expand All @@ -315,6 +316,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
if (value instanceof String valueStr) {
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr));
hasInput = true;
} else {
inferenceResults.get(item.id()).failures.add(
new ElasticsearchStatusException(
Expand All @@ -326,6 +328,12 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
);
}
}
if (hasInput == false) {
// remove the existing _inference field (if present) since none of the content requires inference.
if (docMap.remove(InferenceMetadataFieldMapper.NAME) != null) {
indexRequest.source(docMap);
}
}
}
return fieldRequestsMap;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.mapper.DocumentParserContext;
import org.elasticsearch.index.mapper.DocumentParsingException;
Expand All @@ -37,7 +36,6 @@
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentLocation;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.support.MapXContentParser;
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
Expand Down Expand Up @@ -67,8 +65,8 @@
* "my_semantic_text_field": "these are not the droids you're looking for",
* "_inference": {
* "my_semantic_text_field": {
* "inference_id": "my_inference_id",
* "model_settings": {
* "inference_id": "my_inference_id",
* "task_type": "SPARSE_EMBEDDING"
* },
* "results" [
Expand Down Expand Up @@ -118,6 +116,7 @@ public class InferenceMetadataFieldMapper extends MetadataFieldMapper {
public static final String NAME = "_inference";
public static final String CONTENT_TYPE = "_inference";

private static final String INFERENCE_ID = "inference_id";
public static final String RESULTS = "results";
public static final String INFERENCE_CHUNKS_RESULTS = "inference";
public static final String INFERENCE_CHUNKS_TEXT = "text";
Expand Down Expand Up @@ -178,19 +177,20 @@ private NestedObjectMapper updateSemanticTextFieldMapper(
MapperBuilderContext mapperBuilderContext,
ObjectMapper parent,
SemanticTextFieldMapper original,
String inferenceId,
SemanticTextModelSettings modelSettings,
XContentLocation xContentLocation
) {
if (modelSettings.inferenceId().equals(original.fieldType().getInferenceId()) == false) {
if (inferenceId.equals(original.fieldType().getInferenceId()) == false) {
throw new DocumentParsingException(
xContentLocation,
Strings.format(
"The configured %s [%s] for field [%s] doesn't match the %s [%s] reported in the document.",
SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(),
modelSettings.inferenceId(),
INFERENCE_ID,
inferenceId,
original.name(),
SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(),
modelSettings.inferenceId()
INFERENCE_ID,
original.fieldType().getInferenceId()
)
);
}
Expand All @@ -208,7 +208,7 @@ private NestedObjectMapper updateSemanticTextFieldMapper(
original.simpleName(),
docContext.indexSettings().getIndexVersionCreated(),
docContext.indexAnalyzers()
).setInferenceId(modelSettings.inferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext);
).setInferenceId(original.fieldType().getInferenceId()).setModelSettings(modelSettings).build(mapperBuilderContext);
docContext.addDynamicMapper(newMapper);
return newMapper.getSubMappers();
} else {
Expand Down Expand Up @@ -238,8 +238,18 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex

// record the location of the inference field in the original source
XContentLocation xContentLocation = parser.getTokenLocation();
// parse eagerly to extract the model settings first
// parse eagerly to extract the inference id and the model settings first
Map<String, Object> map = parser.mapOrdered();
logger.info("map=" + map.toString());

// inference_id
Object inferenceIdObj = map.remove(INFERENCE_ID);
final String inferenceId = XContentMapValues.nodeStringValue(inferenceIdObj, null);
if (inferenceId == null) {
throw new IllegalArgumentException("required [" + INFERENCE_ID + "] is missing");
}

// model_settings
Object modelSettingsObj = map.remove(SemanticTextModelSettings.NAME);
if (modelSettingsObj == null) {
throw new DocumentParsingException(
Expand All @@ -252,12 +262,9 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex
)
);
}
Map<String, Object> modelSettingsMap = XContentMapValues.nodeMapValue(modelSettingsObj, "model_settings");
final SemanticTextModelSettings modelSettings;
try {
modelSettings = SemanticTextModelSettings.parse(
XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, modelSettingsMap)
);
modelSettings = SemanticTextModelSettings.fromMap(modelSettingsObj);
} catch (Exception exc) {
throw new DocumentParsingException(
xContentLocation,
Expand All @@ -270,11 +277,13 @@ private void parseSingleField(DocumentParserContext context, MapperBuilderContex
exc
);
}

var nestedObjectMapper = updateSemanticTextFieldMapper(
context,
mapperBuilderContext,
res.parent,
(SemanticTextFieldMapper) res.mapper,
inferenceId,
modelSettings,
xContentLocation
);
Expand Down Expand Up @@ -406,8 +415,9 @@ public static void applyFieldInference(
);
}
Map<String, Object> fieldMap = new LinkedHashMap<>();
fieldMap.put(INFERENCE_ID, model.getInferenceEntityId());
fieldMap.putAll(new SemanticTextModelSettings(model).asMap());
fieldMap.put(InferenceMetadataFieldMapper.RESULTS, chunks);
fieldMap.put(RESULTS, chunks);
inferenceMap.put(field, fieldMap);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ private static Mapper.Builder createInferenceMapperBuilder(
);
}
}
Integer dimensions = modelSettings.dimensions();
denseVectorMapperBuilder.dimensions(dimensions);
denseVectorMapperBuilder.dimensions(modelSettings.dimensions());
yield denseVectorMapperBuilder;
}
default -> throw new IllegalArgumentException(
Expand All @@ -287,24 +286,6 @@ private static Mapper.Builder createInferenceMapperBuilder(
};
}

@Override
protected void checkIncomingMergeType(FieldMapper mergeWith) {
    // When merging with another semantic_text mapper, reject the merge if the
    // incoming mapper's model settings disagree with its own declared inference id.
    // All remaining merge-type checks are delegated to the parent implementation.
    if (mergeWith instanceof SemanticTextFieldMapper incoming) {
        SemanticTextModelSettings incomingSettings = incoming.modelSettings;
        String declaredInferenceId = incoming.fieldType().getInferenceId();
        if (incomingSettings != null && incomingSettings.inferenceId().equals(declaredInferenceId) == false) {
            throw new IllegalArgumentException(
                "mapper [" + name() + "] refers to different model ids [" + incomingSettings.inferenceId() + "] and [" + declaredInferenceId + "]"
            );
        }
    }
    super.checkIncomingMergeType(mergeWith);
}

static boolean canMergeModelSettings(
SemanticTextModelSettings previous,
SemanticTextModelSettings current,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,40 @@ public class SemanticTextModelSettings implements ToXContentObject {

public static final String NAME = "model_settings";
public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type");
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
public static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
public static final ParseField SIMILARITY_FIELD = new ParseField("similarity");
private final TaskType taskType;
private final String inferenceId;
private final Integer dimensions;
private final SimilarityMeasure similarity;

public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) {
/**
 * Builds model settings from an inference {@link Model}: the task type plus the
 * dimensions and similarity reported by the model's service settings.
 */
public SemanticTextModelSettings(Model model) {
this(model.getTaskType(), model.getServiceSettings().dimensions(), model.getServiceSettings().similarity());
}

/**
 * Main constructor.
 *
 * @param taskType   the inference task type; must not be null
 * @param dimensions embedding dimension count, may be null (e.g. sparse embeddings)
 * @param similarity similarity measure, may be null
 * @throws IllegalArgumentException if the combination is invalid for the task type
 */
public SemanticTextModelSettings(TaskType taskType, Integer dimensions, SimilarityMeasure similarity) {
    Objects.requireNonNull(taskType, "task type must not be null");
    this.taskType = taskType;
    this.dimensions = dimensions;
    this.similarity = similarity;
    // Fail fast: reject incomplete settings at construction time (see validate()).
    validate();
}

public SemanticTextModelSettings(Model model) {
this(
model.getTaskType(),
model.getInferenceEntityId(),
model.getServiceSettings().dimensions(),
model.getServiceSettings().similarity()
);
validate();
}

/**
 * Parses a {@code model_settings} object from the given XContent parser.
 *
 * @throws IOException if reading from the parser fails
 */
public static SemanticTextModelSettings parse(XContentParser parser) throws IOException {
    final SemanticTextModelSettings settings = PARSER.apply(parser, null);
    return settings;
}

// Lenient (ignore-unknown-fields) parser for the model_settings object.
// Constructor args, in declaration order: task_type (required), dimensions
// (optional), similarity (optional). NOTE: declaration order below must match
// the args[] indices used in the lambda.
private static final ConstructingObjectParser<SemanticTextModelSettings, Void> PARSER = new ConstructingObjectParser<>(
    NAME,
    true,
    args -> {
        TaskType taskType = TaskType.fromString((String) args[0]);
        Integer dimensions = (Integer) args[1];
        SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
        return new SemanticTextModelSettings(taskType, dimensions, similarity);
    }
);
static {
    PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD);
    PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), DIMENSIONS_FIELD);
    PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), SIMILARITY_FIELD);
}
Expand All @@ -88,11 +81,6 @@ public static SemanticTextModelSettings fromMap(Object node) {
}
try {
Map<String, Object> map = XContentMapValues.nodeMapValue(node, NAME);
if (map.containsKey(INFERENCE_ID_FIELD.getPreferredName()) == false) {
throw new IllegalArgumentException(
"Failed to parse [" + NAME + "], required [" + INFERENCE_ID_FIELD.getPreferredName() + "] is missing"
);
}
if (map.containsKey(TASK_TYPE_FIELD.getPreferredName()) == false) {
throw new IllegalArgumentException(
"Failed to parse [" + NAME + "], required [" + TASK_TYPE_FIELD.getPreferredName() + "] is missing"
Expand All @@ -113,7 +101,6 @@ public static SemanticTextModelSettings fromMap(Object node) {
public Map<String, Object> asMap() {
Map<String, Object> attrsMap = new HashMap<>();
attrsMap.put(TASK_TYPE_FIELD.getPreferredName(), taskType.toString());
attrsMap.put(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
if (dimensions != null) {
attrsMap.put(DIMENSIONS_FIELD.getPreferredName(), dimensions);
}
Expand All @@ -127,10 +114,6 @@ public TaskType taskType() {
return taskType;
}

/** Returns the id of the inference endpoint these settings were produced by; never null (enforced at construction). */
public String inferenceId() {
    return inferenceId;
}

/** Returns the embedding dimension count, or {@code null} when not set (optional in the parsed settings). */
public Integer dimensions() {
    return dimensions;
}
Expand All @@ -143,7 +126,6 @@ public SimilarityMeasure similarity() {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TASK_TYPE_FIELD.getPreferredName(), taskType.toString());
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
if (dimensions != null) {
builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions);
}
Expand All @@ -156,12 +138,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
/**
 * Checks that the settings are complete for the configured task type:
 * text embeddings require both dimensions and similarity; sparse embeddings
 * require neither.
 *
 * @throws IllegalArgumentException if a required field is missing or the task type is unsupported
 */
public void validate() {
    switch (taskType) {
        case TEXT_EMBEDDING:
            if (dimensions == null) {
                throw new IllegalArgumentException(
                    // use getPreferredName() for a readable field name, consistent with the default branch
                    "required [" + DIMENSIONS_FIELD.getPreferredName() + "] field is missing for task_type [" + taskType.name() + "]"
                );
            }
            if (similarity == null) {
                throw new IllegalArgumentException(
                    "required [" + SIMILARITY_FIELD.getPreferredName() + "] field is missing for task_type [" + taskType.name() + "]"
                );
            }
            break;
        case SPARSE_EMBEDDING:
            break;
        default:
            throw new IllegalArgumentException(
                "Wrong ["
                    + TASK_TYPE_FIELD.getPreferredName()
                    + "], expected "
                    + TEXT_EMBEDDING
                    // fix: original concatenated "or " without a leading space,
                    // producing e.g. "expected text_embeddingor sparse_embedding"
                    + " or "
                    + SPARSE_EMBEDDING
                    + ", got "
                    + taskType.name()
            );
    }
}

Expand All @@ -170,14 +171,11 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SemanticTextModelSettings that = (SemanticTextModelSettings) o;
return taskType == that.taskType
&& inferenceId.equals(that.inferenceId)
&& Objects.equals(dimensions, that.dimensions)
&& similarity == that.similarity;
return taskType == that.taskType && Objects.equals(dimensions, that.dimensions) && similarity == that.similarity;
}

@Override
public int hashCode() {
    // Must stay consistent with equals(): taskType, dimensions, similarity only.
    return Objects.hash(taskType, dimensions, similarity);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ public void testUpdateModelSettings() throws IOException {
.field("type", "semantic_text")
.field("inference_id", "test_model")
.startObject("model_settings")
.field("inference_id", "test_model")
.field("task_type", "sparse_embedding")
.endObject()
.endObject()
Expand All @@ -186,10 +185,7 @@ public void testUpdateModelSettings() throws IOException {
);
assertThat(
exc.getMessage(),
containsString(
"Cannot update parameter [model_settings] "
+ "from [{\"task_type\":\"sparse_embedding\",\"inference_id\":\"test_model\"}] to [null]"
)
containsString("Cannot update parameter [model_settings] " + "from [{\"task_type\":\"sparse_embedding\"}] to [null]")
);
}
{
Expand All @@ -202,9 +198,9 @@ public void testUpdateModelSettings() throws IOException {
.field("type", "semantic_text")
.field("inference_id", "test_model")
.startObject("model_settings")
.field("inference_id", "test_model")
.field("task_type", "text_embedding")
.field("dimensions", 10)
.field("similarity", "cosine")
.endObject()
.endObject()
)
Expand All @@ -214,8 +210,8 @@ public void testUpdateModelSettings() throws IOException {
exc.getMessage(),
containsString(
"Cannot update parameter [model_settings] "
+ "from [{\"task_type\":\"sparse_embedding\",\"inference_id\":\"test_model\"}] "
+ "to [{\"task_type\":\"text_embedding\",\"inference_id\":\"test_model\",\"dimensions\":10}]"
+ "from [{\"task_type\":\"sparse_embedding\"}] "
+ "to [{\"task_type\":\"text_embedding\",\"dimensions\":10,\"similarity\":\"cosine\"}]"
)
);
}
Expand Down
Loading

0 comments on commit 2be50d7

Please sign in to comment.