Skip to content

Commit

Permalink
TransportBulkAction uses the inference service directly
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Jan 18, 2024
1 parent 8389572 commit 5ae163f
Show file tree
Hide file tree
Showing 30 changed files with 381 additions and 376 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.indices.SystemIndices;
import org.elasticsearch.inference.InferenceProvider;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.SimulateIngestService;
import org.elasticsearch.tasks.Task;
Expand Down Expand Up @@ -57,7 +56,8 @@ public TransportSimulateBulkAction(
indexingPressure,
systemIndices,
System::nanoTime,
new InferenceProvider.NoopInferenceProvider()
null,
null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.index.mapper.DocumentMapper;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.MappingLookup;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue;

Expand All @@ -44,15 +42,10 @@ public class MappingMetadata implements SimpleDiffable<MappingMetadata> {

private final boolean routingRequired;

private final Map<String, Set<String>> fieldsForModels;

public MappingMetadata(DocumentMapper docMapper) {
this.type = docMapper.type();
this.source = docMapper.mappingSource();
this.routingRequired = docMapper.routingFieldMapper().required();

MappingLookup mappingLookup = docMapper.mappers();
this.fieldsForModels = mappingLookup != null ? mappingLookup.getFieldsForModels() : Map.of();
}

@SuppressWarnings({ "this-escape", "unchecked" })
Expand All @@ -64,7 +57,6 @@ public MappingMetadata(CompressedXContent mapping) {
}
this.type = mappingMap.keySet().iterator().next();
this.routingRequired = routingRequired((Map<String, Object>) mappingMap.get(this.type));
this.fieldsForModels = Map.of();
}

@SuppressWarnings({ "this-escape", "unchecked" })
Expand All @@ -80,7 +72,6 @@ public MappingMetadata(String type, Map<String, Object> mapping) {
withoutType = (Map<String, Object>) mapping.get(type);
}
this.routingRequired = routingRequired(withoutType);
this.fieldsForModels = Map.of();
}

public static void writeMappingMetadata(StreamOutput out, Map<String, MappingMetadata> mappings) throws IOException {
Expand Down Expand Up @@ -167,19 +158,12 @@ public String getSha256() {
return source.getSha256();
}

public Map<String, Set<String>> getFieldsForModels() {
return fieldsForModels;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(type());
source().writeTo(out);
// routing
out.writeBoolean(routingRequired);
if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) {
out.writeMap(fieldsForModels, StreamOutput::writeStringCollection);
}
}

@Override
Expand All @@ -192,25 +176,19 @@ public boolean equals(Object o) {
if (Objects.equals(this.routingRequired, that.routingRequired) == false) return false;
if (source.equals(that.source) == false) return false;
if (type.equals(that.type) == false) return false;
if (Objects.equals(this.fieldsForModels, that.fieldsForModels) == false) return false;

return true;
}

@Override
public int hashCode() {
return Objects.hash(type, source, routingRequired, fieldsForModels);
return Objects.hash(type, source, routingRequired);
}

public MappingMetadata(StreamInput in) throws IOException {
type = in.readString();
source = CompressedXContent.readCompressedString(in);
routingRequired = in.readBoolean();
if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) {
fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString));
} else {
fieldsForModels = Map.of();
}
}

public static Diff<MappingMetadata> readDiffFrom(StreamInput in) throws IOException {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.inference;

import org.elasticsearch.action.ActionListener;

import java.util.List;
import java.util.Map;

/**
 * Registry for persisting and retrieving inference model configurations.
 * Implementations supply the storage backend; all operations are asynchronous
 * and report their result through the supplied {@link ActionListener}.
 */
public abstract class ModelRegistry {
// Raw, unparsed storage representation of a model: its configuration map plus
// a separate map for sensitive settings (kept apart so secrets can be handled
// with different access rules — presumably stored in a secure index; confirm
// against the concrete implementation).
public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {}

/**
 * Semi parsed model where model id, task type and service
 * are known but the settings are not parsed.
 */
public record UnparsedModel(
String modelId,
TaskType taskType,
String service,
Map<String, Object> settings,
Map<String, Object> secrets
) {}

// Retrieve a model by id, including its secret settings in the result.
// NOTE(review): the secrets/no-secrets distinction is inferred from the method
// names only — verify against the implementing class.
public abstract void getModelWithSecrets(String modelId, ActionListener<UnparsedModel> listener);

// Retrieve a model by id without exposing its secret settings.
public abstract void getModel(String modelId, ActionListener<UnparsedModel> listener);

// Retrieve all models registered for the given task type.
public abstract void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener);

// Persist the given model; the listener receives true on success.
// NOTE(review): semantics of a false result (vs. failure via onFailure) are
// not visible here — confirm with the implementation.
public abstract void storeModel(Model model, ActionListener<Boolean> listener);

// Delete the model with the given id; the listener receives true on success.
public abstract void deleteModel(String modelId, ActionListener<Boolean> listener);

// Retrieve every registered model.
public abstract void getAllModels(ActionListener<List<UnparsedModel>> listener);
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;
package org.elasticsearch.inference;

import java.util.Locale;

Expand Down
12 changes: 0 additions & 12 deletions server/src/main/java/org/elasticsearch/node/NodeConstruction.java
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@
import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService;
import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService;
import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService;
import org.elasticsearch.inference.InferenceProvider;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.monitor.MonitorService;
import org.elasticsearch.monitor.fs.FsHealthService;
Expand All @@ -142,7 +141,6 @@
import org.elasticsearch.plugins.ClusterPlugin;
import org.elasticsearch.plugins.DiscoveryPlugin;
import org.elasticsearch.plugins.HealthPlugin;
import org.elasticsearch.plugins.InferenceProviderPlugin;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.MetadataUpgrader;
Expand Down Expand Up @@ -1082,16 +1080,6 @@ record PluginServiceInstances(
);
}

InferenceProvider inferenceProvider = null;
Optional<InferenceProviderPlugin> inferenceProviderPlugin = getSinglePlugin(InferenceProviderPlugin.class);
if (inferenceProviderPlugin.isPresent()) {
inferenceProvider = inferenceProviderPlugin.get().getInferenceProvider();
} else {
logger.warn("No inference provider found. Inference for semantic_text field types won't be available");
inferenceProvider = new InferenceProvider.NoopInferenceProvider();
}
modules.bindToInstance(InferenceProvider.class, inferenceProvider);

injector = modules.createInjector();

postInjection(clusterModule, actionModule, clusterService, transportService, featureService);
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.index.VersionType;
import org.elasticsearch.indices.EmptySystemIndices;
import org.elasticsearch.inference.InferenceProvider;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockUtils;
Expand Down Expand Up @@ -125,7 +124,8 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state)
indexNameExpressionResolver,
new IndexingPressure(Settings.EMPTY),
EmptySystemIndices.INSTANCE,
new InferenceProvider.NoopInferenceProvider()
null,
null
) {
@Override
void executeBulk(
Expand Down
Loading

0 comments on commit 5ae163f

Please sign in to comment.