Skip to content

Commit

Permalink
Add diff support for model for fields, changed implementation to Set<String>
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 21, 2023
1 parent 38dcb93 commit ab97838
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ public Iterator<Setting<?>> settings() {
private final Double writeLoadForecast;
@Nullable
private final Long shardSizeInBytesForecast;
private final Map<String, List<String>> inferenceModelsForFields;
private final Map<String, Set<String>> inferenceModelsForFields;

private IndexMetadata(
final Index index,
Expand Down Expand Up @@ -685,7 +685,7 @@ private IndexMetadata(
@Nullable final IndexMetadataStats stats,
@Nullable final Double writeLoadForecast,
@Nullable Long shardSizeInBytesForecast,
final Map<String, List<String>> inferenceModelsForFields
final Map<String, Set<String>> inferenceModelsForFields
) {
this.index = index;
this.version = version;
Expand Down Expand Up @@ -1224,7 +1224,7 @@ public OptionalLong getForecastedShardSizeInBytes() {
return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast);
}

public Map<String, List<String>> getInferenceModelsForFields() {
public Map<String, Set<String>> getInferenceModelsForFields() {
return inferenceModelsForFields;
}

Expand Down Expand Up @@ -1492,6 +1492,7 @@ private static class IndexMetadataDiff implements Diff<IndexMetadata> {
private final IndexMetadataStats stats;
private final Double indexWriteLoadForecast;
private final Long shardSizeInBytesForecast;
private final Diff<Map<String, Set<String>>> modelsForFields;

IndexMetadataDiff(IndexMetadata before, IndexMetadata after) {
index = after.index.getName();
Expand Down Expand Up @@ -1528,6 +1529,12 @@ private static class IndexMetadataDiff implements Diff<IndexMetadata> {
stats = after.stats;
indexWriteLoadForecast = after.writeLoadForecast;
shardSizeInBytesForecast = after.shardSizeInBytesForecast;
modelsForFields = DiffableUtils.diff(
before.inferenceModelsForFields,
after.inferenceModelsForFields,
DiffableUtils.getStringKeySerializer(),
DiffableUtils.StringSetValueSerializer.getInstance()
);
}

private static final DiffableUtils.DiffableValueReader<String, AliasMetadata> ALIAS_METADATA_DIFF_VALUE_READER =
Expand Down Expand Up @@ -1587,6 +1594,15 @@ private static class IndexMetadataDiff implements Diff<IndexMetadata> {
indexWriteLoadForecast = null;
shardSizeInBytesForecast = null;
}
if (in.getTransportVersion().onOrAfter(SEMANTIC_TEXT_FIELD)) {
modelsForFields = DiffableUtils.readJdkMapDiff(
in,
DiffableUtils.getStringKeySerializer(),
DiffableUtils.StringSetValueSerializer.getInstance()
);
} else {
modelsForFields = DiffableUtils.emptyDiff();
}
}

@Override
Expand Down Expand Up @@ -1622,6 +1638,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalDouble(indexWriteLoadForecast);
out.writeOptionalLong(shardSizeInBytesForecast);
}
if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_FIELD)) {
modelsForFields.writeTo(out);
}
}

@Override
Expand Down Expand Up @@ -1651,6 +1670,7 @@ public IndexMetadata apply(IndexMetadata part) {
builder.stats(stats);
builder.indexWriteLoadForecast(indexWriteLoadForecast);
builder.shardSizeInBytesForecast(shardSizeInBytesForecast);
builder.inferenceModelsForFields(modelsForFields.apply(part.inferenceModelsForFields));
return builder.build(true);
}
}
Expand Down Expand Up @@ -1719,7 +1739,9 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function<String,
builder.shardSizeInBytesForecast(in.readOptionalLong());
}
if (in.getTransportVersion().onOrAfter(SEMANTIC_TEXT_FIELD)) {
builder.inferenceModelsForfields(in.readImmutableMap(StreamInput::readStringCollectionAsImmutableList));
builder.inferenceModelsForfields(
in.readImmutableMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString))
);
}
return builder.build(true);
}
Expand Down Expand Up @@ -1819,7 +1841,8 @@ public static class Builder {
private IndexMetadataStats stats = null;
private Double indexWriteLoadForecast = null;
private Long shardSizeInBytesForecast = null;
private Map<String, List<String>> inferenceModelsForFields = Map.of();

private Map<String, Set<String>> inferenceModelsForFields = Map.of();

public Builder(String index) {
this.index = index;
Expand Down Expand Up @@ -1933,7 +1956,7 @@ public Builder putMapping(String source) {

public Builder putMapping(MappingMetadata mappingMd) {
mapping = mappingMd;
Map<String, List<String>> fieldsForModels = mappingMd.getFieldsForModels();
Map<String, Set<String>> fieldsForModels = mappingMd.getFieldsForModels();
if (fieldsForModels != null) {
inferenceModelsForFields = fieldsForModels;
}
Expand Down Expand Up @@ -2085,11 +2108,16 @@ public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) {
return this;
}

public Builder inferenceModelsForfields(Map<String, List<String>> inferenceModelsForfields) {
public Builder inferenceModelsForfields(Map<String, Set<String>> inferenceModelsForfields) {
this.inferenceModelsForFields = inferenceModelsForfields;
return this;
}

/**
 * Sets the mapping from inference model ID to the set of field names inferred with that model,
 * replacing any previously configured value.
 *
 * NOTE(review): a near-duplicate setter spelled {@code inferenceModelsForfields} (lowercase "f")
 * also exists in this Builder with identical behavior — consider consolidating on one spelling.
 *
 * @param inferenceModelsForFields map of model ID to the fields that use it
 * @return this builder, for chaining
 */
public Builder inferenceModelsForFields(Map<String, Set<String>> inferenceModelsForFields) {
    this.inferenceModelsForFields = inferenceModelsForFields;
    return this;
}

/**
 * Builds the {@link IndexMetadata}, delegating to {@code build(false)}.
 * NOTE(review): the boolean's meaning is not visible in this chunk; deserialization paths
 * ({@code readFrom}, {@code fromXContent}, diff {@code apply}) call {@code build(true)} — confirm semantics.
 */
public IndexMetadata build() {
    return build(false);
}
Expand Down Expand Up @@ -2411,7 +2439,7 @@ public static void toXContent(IndexMetadata indexMetadata, XContentBuilder build
builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast);
}

Map<String, List<String>> inferenceModelsForFields = indexMetadata.getInferenceModelsForFields();
Map<String, Set<String>> inferenceModelsForFields = indexMetadata.getInferenceModelsForFields();
if ((inferenceModelsForFields != null) && (inferenceModelsForFields.isEmpty() == false)) {
builder.field(INFERENCE_MODELS_FIELDS, indexMetadata.getInferenceModelsForFields());
}
Expand Down Expand Up @@ -2494,10 +2522,15 @@ public static IndexMetadata fromXContent(XContentParser parser, Map<String, Mapp
builder.stats(IndexMetadataStats.fromXContent(parser));
break;
case INFERENCE_MODELS_FIELDS:
Map<String, List<String>> inferenceModels = parser.map(HashMap::new, XContentParser::list)
Map<String, Set<String>> inferenceModels = parser.map(HashMap::new, XContentParser::list)
.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().stream().map(Object::toString).toList()));
.collect(
Collectors.toMap(
Map.Entry::getKey,
e -> e.getValue().stream().map(Object::toString).collect(Collectors.toSet())
)
);
builder.inferenceModelsForfields(inferenceModels);
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue;

Expand All @@ -43,7 +43,7 @@ public class MappingMetadata implements SimpleDiffable<MappingMetadata> {

private final boolean routingRequired;

private final Map<String, List<String>> fieldsForModels;
private final Map<String, Set<String>> fieldsForModels;

public MappingMetadata(DocumentMapper docMapper) {
this.type = docMapper.type();
Expand Down Expand Up @@ -127,7 +127,7 @@ public CompressedXContent source() {
return this.source;
}

public Map<String, List<String>> getFieldsForModels() {
public Map<String, Set<String>> getFieldsForModels() {
return fieldsForModels;
}

Expand Down Expand Up @@ -205,7 +205,7 @@ public MappingMetadata(StreamInput in) throws IOException {
source = CompressedXContent.readCompressedString(in);
routingRequired = in.readBoolean();
if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD)) {
fieldsForModels = in.readMapOfLists(StreamInput::readString);
fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString));
} else {
fieldsForModels = Map.of();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,15 +523,6 @@ public Iterator<Setting<?>> settings() {
Property.ServerlessPublic
);

// Private, index-scoped setting "index.inference_pipeline" defaulting to
// IngestService.NOOP_PIPELINE_NAME (identity parser — the raw string value is used as-is).
// NOTE(review): this declaration appears to be deleted by the surrounding commit in favor of
// deriving inference processors from IndexMetadata.getInferenceModelsForFields() — verify no
// remaining references to INFERENCE_PIPELINE before removal.
public static final Setting<String> INFERENCE_PIPELINE = new Setting<>(
    "index.inference_pipeline",
    IngestService.NOOP_PIPELINE_NAME,
    Function.identity(),
    Property.PrivateIndex,
    Property.IndexScope,
    Property.ServerlessPublic
);

/**
* Marks an index to be searched throttled. This means that never more than one shard of such an index will be searched concurrently
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@

import org.elasticsearch.common.regex.Regex;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
Expand All @@ -38,7 +36,7 @@ final class FieldTypeLookup {
*/
private final Map<String, Set<String>> fieldToCopiedFields;

private final Map<String, List<String>> fieldsForModel;
private final Map<String, Set<String>> fieldsForModel;

private final int maxParentPathDots;

Expand All @@ -52,7 +50,7 @@ final class FieldTypeLookup {
final Map<String, String> fullSubfieldNameToParentPath = new HashMap<>();
final Map<String, DynamicFieldType> dynamicFieldTypes = new HashMap<>();
final Map<String, Set<String>> fieldToCopiedFields = new HashMap<>();
final Map<String, List<String>> fieldsForModel = new HashMap<>();
final Map<String, Set<String>> fieldsForModel = new HashMap<>();
for (FieldMapper fieldMapper : fieldMappers) {
String fieldName = fieldMapper.name();
MappedFieldType fieldType = fieldMapper.fieldType();
Expand All @@ -71,7 +69,7 @@ final class FieldTypeLookup {
fieldToCopiedFields.get(targetField).add(fieldName);
}
if (fieldType.hasInferenceModel()) {
Collection<String> fields = fieldsForModel.computeIfAbsent(fieldType.getInferenceModel(), v -> new ArrayList<>());
Collection<String> fields = fieldsForModel.computeIfAbsent(fieldType.getInferenceModel(), v -> new HashSet<>());
fields.add(fieldName);
}
}
Expand Down Expand Up @@ -119,11 +117,11 @@ public static int dotCount(String path) {
return dotCount;
}

List<String> fieldsForModel(String modelName) {
Set<String> fieldsForModel(String modelName) {
return this.fieldsForModel.get(modelName);
}

Map<String, List<String>> fieldsForModel() {
Map<String, Set<String>> fieldsForModel() {
return this.fieldsForModel;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,11 +492,7 @@ public void validateDoesNotShadow(String name) {
}
}

public List<String> fieldsForModel(String modelName) {
return fieldTypeLookup.fieldsForModel(modelName);
}

public Map<String, List<String>> fieldsForModels() {
public Map<String, Set<String>> fieldsForModels() {
return fieldTypeLookup.fieldsForModel();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ private static Optional<Pipelines> resolvePipelinesFromIndexTemplates(IndexReque
defaultPipeline = Objects.requireNonNullElse(defaultPipeline, NOOP_PIPELINE_NAME);
finalPipeline = Objects.requireNonNullElse(finalPipeline, NOOP_PIPELINE_NAME);

return Optional.of(new Pipelines(defaultPipeline, finalPipeline, null));
return Optional.of(new Pipelines(defaultPipeline, finalPipeline, NOOP_PIPELINE_NAME));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
Expand Down Expand Up @@ -2287,13 +2288,13 @@ public Optional<Pipeline> getIngestPipeline(IndexMetadata indexMetadata, Process
return Optional.empty();
}

Map<String, List<String>> inferenceModelsForFields = indexMetadata.getInferenceModelsForFields();
Map<String, Set<String>> inferenceModelsForFields = indexMetadata.getInferenceModelsForFields();
if (inferenceModelsForFields.isEmpty()) {
return Optional.empty();
}

Collection<Processor> inferenceProcessors = new ArrayList<>();
for (Map.Entry<String, List<String>> modelsForFieldsEntry : inferenceModelsForFields.entrySet()) {
for (Map.Entry<String, Set<String>> modelsForFieldsEntry : inferenceModelsForFields.entrySet()) {
Map<String, Object> inferenceConfig = new HashMap<>();
String modelId = modelsForFieldsEntry.getKey();
inferenceConfig.put("model_id", modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;

public class SemanticTextInferenceProcessor extends AbstractProcessor implements WrappingProcessor {

public static final String TYPE = "semanticTextInference";
public static final String TAG = "semantic_text";

private final Map<String, List<String>> fieldsForModels;
private final Map<String, Set<String>> fieldsForModels;

private final Processor wrappedProcessor;

Expand All @@ -36,7 +37,7 @@ public SemanticTextInferenceProcessor(
Client client,
InferenceAuditor inferenceAuditor,
String description,
Map<String, List<String>> fieldsForModels
Map<String, Set<String>> fieldsForModels
) {
super(TAG, description);
this.client = client;
Expand All @@ -54,7 +55,7 @@ private Processor createWrappedProcessor() {
return new CompoundProcessor(inferenceProcessors);
}

private InferenceProcessor createInferenceProcessor(String modelId, List<String> fields) {
private InferenceProcessor createInferenceProcessor(String modelId, Set<String> fields) {
List<InferenceProcessor.Factory.InputConfig> inputConfigs = fields.stream()
.map(f -> new InferenceProcessor.Factory.InputConfig(f, "ml.inference", f, Map.of()))
.toList();
Expand Down

0 comments on commit ab97838

Please sign in to comment.