Skip to content

Commit

Permalink
This was supposed to be merged into elastic#105515 but didn't make it
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Mar 6, 2024
1 parent b1a3ee8 commit 2039fb3
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.inference.ModelSettings;
import org.elasticsearch.inference.SemanticTextModelSettings;

import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -270,7 +270,7 @@ public void onResponse(InferenceServiceResults results) {
for (InferenceResults inferenceResults : results.transformToCoordinationFormat()) {
String inferenceFieldName = inferenceFieldNames.get(i++);
Map<String, Object> inferenceFieldResult = new LinkedHashMap<>();
inferenceFieldResult.putAll(new ModelSettings(inferenceProvider.model).asMap());
inferenceFieldResult.putAll(new SemanticTextModelSettings(inferenceProvider.model).asMap());
inferenceFieldResult.put(
INFERENCE_RESULTS,
List.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* Serialization class for specifying the settings of a model from semantic_text inference to field mapper.
* See {@link org.elasticsearch.action.bulk.BulkShardRequestInferenceProvider}
*/
public class ModelSettings {
public class SemanticTextModelSettings {

public static final String NAME = "model_settings";
public static final ParseField TASK_TYPE_FIELD = new ParseField("task_type");
Expand All @@ -33,7 +33,7 @@ public class ModelSettings {
private final Integer dimensions;
private final SimilarityMeasure similarity;

public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) {
public SemanticTextModelSettings(TaskType taskType, String inferenceId, Integer dimensions, SimilarityMeasure similarity) {
Objects.requireNonNull(taskType, "task type must not be null");
Objects.requireNonNull(inferenceId, "inferenceId must not be null");
this.taskType = taskType;
Expand All @@ -42,7 +42,7 @@ public ModelSettings(TaskType taskType, String inferenceId, Integer dimensions,
this.similarity = similarity;
}

public ModelSettings(Model model) {
public SemanticTextModelSettings(Model model) {
this(
model.getTaskType(),
model.getInferenceEntityId(),
Expand All @@ -51,16 +51,16 @@ public ModelSettings(Model model) {
);
}

public static ModelSettings parse(XContentParser parser) throws IOException {
public static SemanticTextModelSettings parse(XContentParser parser) throws IOException {
return PARSER.apply(parser, null);
}

private static final ConstructingObjectParser<ModelSettings, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
private static final ConstructingObjectParser<SemanticTextModelSettings, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
TaskType taskType = TaskType.fromString((String) args[0]);
String inferenceId = (String) args[1];
Integer dimensions = (Integer) args[2];
SimilarityMeasure similarity = args[3] == null ? null : SimilarityMeasure.fromString((String) args[3]);
return new ModelSettings(taskType, inferenceId, dimensions, similarity);
return new SemanticTextModelSettings(taskType, inferenceId, dimensions, similarity);
});
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_TYPE_FIELD);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelRegistry;
import org.elasticsearch.inference.ModelSettings;
import org.elasticsearch.inference.SemanticTextModelSettings;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
Expand Down Expand Up @@ -396,10 +396,10 @@ private static void checkInferenceResults(
Map<String, Object> inferenceService1FieldResults = (Map<String, Object>) inferenceRootResultField.get(inferenceFieldName);
assertNotNull(inferenceService1FieldResults);
assertThat(inferenceService1FieldResults.size(), equalTo(2));
Map<String, Object> modelSettings = (Map<String, Object>) inferenceService1FieldResults.get(ModelSettings.NAME);
Map<String, Object> modelSettings = (Map<String, Object>) inferenceService1FieldResults.get(SemanticTextModelSettings.NAME);
assertNotNull(modelSettings);
assertNotNull(modelSettings.get(ModelSettings.TASK_TYPE_FIELD.getPreferredName()));
assertNotNull(modelSettings.get(ModelSettings.INFERENCE_ID_FIELD.getPreferredName()));
assertNotNull(modelSettings.get(SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName()));
assertNotNull(modelSettings.get(SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName()));

List<Map<String, Object>> inferenceResultElement = (List<Map<String, Object>>) inferenceService1FieldResults.get(
INFERENCE_RESULTS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.ModelSettings;
import org.elasticsearch.inference.SemanticTextModelSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.logging.LogManager;
Expand Down Expand Up @@ -168,7 +168,7 @@ private static void parseSingleField(DocumentParserContext context, MapperBuilde
parser.nextToken();
failIfTokenIsNot(parser, XContentParser.Token.START_OBJECT);
parser.nextToken();
ModelSettings modelSettings = ModelSettings.parse(parser);
SemanticTextModelSettings modelSettings = SemanticTextModelSettings.parse(parser);
for (XContentParser.Token token = parser.nextToken(); token != XContentParser.Token.END_OBJECT; token = parser.nextToken()) {
failIfTokenIsNot(parser, XContentParser.Token.FIELD_NAME);

Expand All @@ -192,7 +192,7 @@ private static void parseFieldInferenceChunks(
DocumentParserContext context,
MapperBuilderContext mapperBuilderContext,
String fieldName,
ModelSettings modelSettings,
SemanticTextModelSettings modelSettings,
NestedObjectMapper nestedObjectMapper
) throws IOException {
XContentParser parser = context.parser();
Expand All @@ -209,7 +209,7 @@ private static void parseFieldInferenceChunks(
private static void parseFieldInferenceChunkElement(
DocumentParserContext context,
ObjectMapper objectMapper,
ModelSettings modelSettings
SemanticTextModelSettings modelSettings
) throws IOException {
XContentParser parser = context.parser();
DocumentParserContext childContext = context.createChildContext(objectMapper);
Expand Down Expand Up @@ -254,7 +254,7 @@ private static NestedObjectMapper createInferenceResultsObjectMapper(
DocumentParserContext context,
MapperBuilderContext mapperBuilderContext,
String fieldName,
ModelSettings modelSettings
SemanticTextModelSettings modelSettings
) {
IndexVersion indexVersionCreated = context.indexSettings().getIndexVersionCreated();
FieldMapper.Builder resultsBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import org.elasticsearch.index.mapper.NestedObjectMapper;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import org.elasticsearch.inference.ModelSettings;
import org.elasticsearch.inference.SemanticTextModelSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.LeafNestedDocuments;
Expand Down Expand Up @@ -417,7 +417,7 @@ private static void addSemanticTextInferenceResults(
Map<String, Map<String, Object>> inferenceResultsMap = new HashMap<>();
for (SemanticTextInferenceResults semanticTextInferenceResult : semanticTextInferenceResults) {
Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put(ModelSettings.NAME, modelSettingsMap());
fieldMap.put(SemanticTextModelSettings.NAME, modelSettingsMap());
List<Map<String, Object>> parsedInferenceResults = new ArrayList<>(semanticTextInferenceResult.text().size());

Iterator<SparseEmbeddingResults.Embedding> embeddingsIterator = semanticTextInferenceResult.sparseEmbeddingResults()
Expand Down Expand Up @@ -451,9 +451,9 @@ private static void addSemanticTextInferenceResults(

private static Map<String, Object> modelSettingsMap() {
return Map.of(
ModelSettings.TASK_TYPE_FIELD.getPreferredName(),
SemanticTextModelSettings.TASK_TYPE_FIELD.getPreferredName(),
TaskType.SPARSE_EMBEDDING.toString(),
ModelSettings.INFERENCE_ID_FIELD.getPreferredName(),
SemanticTextModelSettings.INFERENCE_ID_FIELD.getPreferredName(),
randomAlphaOfLength(8)
);
}
Expand Down

0 comments on commit 2039fb3

Please sign in to comment.