diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index a9c79708532df..66e6fe0cb6f55 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -54,6 +54,7 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Assertions; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; @@ -64,8 +65,12 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -86,6 +91,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicIntegerArray; +import java.util.function.Consumer; import java.util.function.LongSupplier; import java.util.stream.Collectors; @@ -112,7 +118,11 @@ public class TransportBulkAction extends HandledTransportAction> requ bulkRequest = null; }; - try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { - for (Map.Entry> entry : requestsByShard.entrySet()) { - final ShardId shardId = entry.getKey(); - final List requests = entry.getValue(); - - BulkShardRequest bulkShardRequest = new BulkShardRequest( - shardId, - bulkRequest.getRefreshPolicy(), - requests.toArray(new BulkItemRequest[0]) - ); - bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); - bulkShardRequest.timeout(bulkRequest.timeout()); - bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); - if (task != null) { - bulkShardRequest.setParentTask(nodeId, task.getId()); + Consumer> nextAction = inferenceProviderMap -> { + try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) { + for (Map.Entry> entry : requestsByShard.entrySet()) { + final ShardId shardId = entry.getKey(); + final List requests = entry.getValue(); + + BulkShardRequest bulkShardRequest = new BulkShardRequest( + shardId, + bulkRequest.getRefreshPolicy(), + requests.toArray(new BulkItemRequest[0]) + ); + bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards()); + bulkShardRequest.timeout(bulkRequest.timeout()); + bulkShardRequest.routedBasedOnClusterVersion(clusterState.version()); + if (task != null) { + bulkShardRequest.setParentTask(nodeId, task.getId()); + } + performInferenceAndExecute( + bulkShardRequest, + clusterState, + bulkItemRequestCompleteRefCount.acquire(), + inferenceProviderMap + ); } - - performInferenceAndExecute(bulkShardRequest, clusterState, bulkItemRequestCompleteRefCount.acquire()); } - } + }; + getInferenceProviders(clusterState, requestsByShard.keySet(), nextAction); } - private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, ClusterState clusterState, Releasable releaseOnFinish) { + private void performInferenceAndExecute( + BulkShardRequest bulkShardRequest, + ClusterState clusterState, + Releasable releaseOnFinish, + Map inferenceProviderMap + ) { Map> fieldsForModels = clusterState.metadata() .index(bulkShardRequest.shardId().getIndex()) @@ -783,7 +810,13 @@ private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, Clust try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) { for (BulkItemRequest request : bulkShardRequest.items()) { - performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef.acquire()); + performInferenceOnBulkItemRequest( + bulkShardRequest, + request, + inferenceProviderMap, + fieldsForModels, + bulkItemReqRef.acquire() + ); } } } @@ -791,10 +824,11 @@ private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, Clust private void performInferenceOnBulkItemRequest( BulkShardRequest bulkShardRequest, BulkItemRequest request, + Map inferenceProviderMap, Map> fieldsForModels, Releasable releaseOnFinish ) { - if (inferenceProvider.performsInference() == false) { + if (inferenceServiceRegistry == null) { releaseOnFinish.close(); return; } @@ -836,65 +870,103 @@ private void performInferenceOnBulkItemRequest( ); List inferenceFieldNames = getFieldNamesForInference(fieldModelsEntrySet, docMap); - if (inferenceFieldNames.isEmpty()) { continue; } docRef.acquire(); - - inferenceProvider.textInference( - modelId, - inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), - new ActionListener<>() { - - @Override - public void onResponse(List results) { - - if (results == null) { - throw new IllegalArgumentException( - "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() - ); - } - - int i = 0; - for (InferenceResults inferenceResults : results) { - String fieldName = inferenceFieldNames.get(i++); - @SuppressWarnings("unchecked") - Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( - fieldName, - k -> new HashMap() - ); - - inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); - inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); - } - - docRef.close(); - } - - @Override - public void onFailure(Exception e) { - - final String indexName = request.index(); - DocWriteRequest docWriteRequest = request.request(); - BulkItemResponse.Failure failure = new BulkItemResponse.Failure( - indexName, - docWriteRequest.id(), - new IllegalArgumentException("Error performing inference: " + e.getMessage(), e) + var inferenceProvider = inferenceProviderMap.get(modelId); + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(InferenceServiceResults inferenceServiceResults) { + List results = inferenceServiceResults.transformToCoordinationFormat(); + int i = 0; + for (InferenceResults inferenceResults : results) { + String fieldName = inferenceFieldNames.get(i++); + @SuppressWarnings("unchecked") + Map inferenceFieldMap = (Map) rootInferenceFieldMap.computeIfAbsent( + fieldName, + k -> new HashMap() ); - responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); - // make sure the request gets never processed again - bulkShardRequest.items()[request.id()] = null; - docRef.close(); + inferenceFieldMap.put(INFERENCE_FIELD, inferenceResults.asMap("output").get("output")); + inferenceFieldMap.put(TEXT_FIELD, docMap.get(fieldName)); } } + + @Override + public void onFailure(Exception e) { + final String indexName = request.index(); + DocWriteRequest docWriteRequest = request.request(); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure( + indexName, + docWriteRequest.id(), + new IllegalArgumentException("Error performing inference: " + e.getMessage(), e) + ); + responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure)); + // make sure the request gets never processed again + bulkShardRequest.items()[request.id()] = null; + } + }; + actionListener = ActionListener.releaseAfter(actionListener, docRef); + if (inferenceProvider == null) { + actionListener.onFailure( + new IllegalArgumentException( + "No inference retrieved for model ID " + modelId + " in document " + docWriteRequest.id() + ) + ); + return; + } + inferenceProvider.service.infer( + inferenceProvider.model, + inferenceFieldNames.stream().map(docMap::get).map(String::valueOf).collect(Collectors.toList()), + Map.of(), + actionListener ); } } } + private static class InferenceProvider { + private final Model model; + private final InferenceService service; + + private InferenceProvider(Model model, InferenceService service) { + this.model = model; + this.service = service; + } + } + + private void getInferenceProviders( + ClusterState clusterState, + Set shardIds, + Consumer> action + ) { + Set serviceIds = new HashSet<>(); + shardIds.stream().map(ShardId::getIndex).collect(Collectors.toSet()).stream().forEach(index -> { + var fieldsForModels = clusterState.metadata().index(index).getFieldsForModels(); + serviceIds.addAll(fieldsForModels.keySet()); + }); + final Map inferenceProviderMap = new ConcurrentHashMap<>(); + Runnable onModelLoadingComplete = () -> action.accept(new HashMap<>(inferenceProviderMap)); + try (var refs = new RefCountingRunnable(onModelLoadingComplete)) { + for (var serviceId : serviceIds) { + var serviceOpt = inferenceServiceRegistry.getService(serviceId); + if (serviceOpt.isPresent() == false) { + // We do nothing if the service is not present and let the individual requests fail + // when executing the inference + continue; + } + final var service = serviceOpt.get(); + ActionListener listener = ActionListener.wrap(unparsedModel -> { + var model = service.parsePersistedConfig(serviceId, unparsedModel.taskType(), unparsedModel.settings()); + inferenceProviderMap.put(serviceId, new InferenceProvider(model, service)); + }, e -> {}); + modelRegistry.getModel(serviceId, ActionListener.releaseAfter(listener, refs)); + } + } + } + private static List getFieldNamesForInference( Map.Entry> fieldModelsEntrySet, Map docMap diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java index 868d3babd3edc..df22d44e5a41a 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportSimulateBulkAction.java @@ -21,7 +21,6 @@ import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.SimulateIngestService; import org.elasticsearch.tasks.Task; @@ -57,7 +56,8 @@ public TransportSimulateBulkAction( indexingPressure, systemIndices, System::nanoTime, - new InferenceProvider.NoopInferenceProvider() + null, + null ); } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java index 64a61f854b9da..b629ab5d5f710 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MappingMetadata.java @@ -18,13 +18,11 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.MapperService; -import org.elasticsearch.index.mapper.MappingLookup; import java.io.IOException; import java.io.UncheckedIOException; import java.util.Map; import java.util.Objects; -import java.util.Set; import static org.elasticsearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; @@ -44,15 +42,10 @@ public class MappingMetadata implements SimpleDiffable { private final boolean routingRequired; - private final Map> fieldsForModels; - public MappingMetadata(DocumentMapper docMapper) { this.type = docMapper.type(); this.source = docMapper.mappingSource(); this.routingRequired = docMapper.routingFieldMapper().required(); - - MappingLookup mappingLookup = docMapper.mappers(); - this.fieldsForModels = mappingLookup != null ? mappingLookup.getFieldsForModels() : Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -64,7 +57,6 @@ public MappingMetadata(CompressedXContent mapping) { } this.type = mappingMap.keySet().iterator().next(); this.routingRequired = routingRequired((Map) mappingMap.get(this.type)); - this.fieldsForModels = Map.of(); } @SuppressWarnings({ "this-escape", "unchecked" }) @@ -80,7 +72,6 @@ public MappingMetadata(String type, Map mapping) { withoutType = (Map) mapping.get(type); } this.routingRequired = routingRequired(withoutType); - this.fieldsForModels = Map.of(); } public static void writeMappingMetadata(StreamOutput out, Map mappings) throws IOException { @@ -167,19 +158,12 @@ public String getSha256() { return source.getSha256(); } - public Map> getFieldsForModels() { - return fieldsForModels; - } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(type()); source().writeTo(out); // routing out.writeBoolean(routingRequired); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - out.writeMap(fieldsForModels, StreamOutput::writeStringCollection); - } } @Override @@ -192,25 +176,19 @@ public boolean equals(Object o) { if (Objects.equals(this.routingRequired, that.routingRequired) == false) return false; if (source.equals(that.source) == false) return false; if (type.equals(that.type) == false) return false; - if (Objects.equals(this.fieldsForModels, that.fieldsForModels) == false) return false; return true; } @Override public int hashCode() { - return Objects.hash(type, source, routingRequired, fieldsForModels); + return Objects.hash(type, source, routingRequired); } public MappingMetadata(StreamInput in) throws IOException { type = in.readString(); source = CompressedXContent.readCompressedString(in); routingRequired = in.readBoolean(); - if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_TEXT_FIELD_ADDED)) { - fieldsForModels = in.readMap(StreamInput::readString, i -> i.readCollectionAsImmutableSet(StreamInput::readString)); - } else { - fieldsForModels = Map.of(); - } } public static Diff readDiffFrom(StreamInput in) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java b/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java deleted file mode 100644 index a0b282d327ae8..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/InferenceProvider.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.action.ActionListener; - -import java.util.List; - -/** - * Provides NLP text inference results. Plugins can implement this interface to provide their own inference results. - */ -public interface InferenceProvider { - /** - * Returns InferenceResults for a given model ID and list of texts. - * - * @param modelId model identifier - * @param texts texts to perform inference on - * @param listener listener to be called when inference is complete - */ - void textInference(String modelId, List texts, ActionListener> listener); - - /** - * Returns true if this inference provider can perform inference - * - * @return true if this inference provider can perform inference - */ - boolean performsInference(); - - class NoopInferenceProvider implements InferenceProvider { - - @Override - public void textInference(String modelId, List texts, ActionListener> listener) { - throw new UnsupportedOperationException("No inference provider has been registered"); - } - - @Override - public boolean performsInference() { - return false; - } - } -} diff --git a/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java new file mode 100644 index 0000000000000..d53127d453501 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ModelRegistry.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.action.ActionListener; + +import java.util.List; +import java.util.Map; + +public abstract class ModelRegistry { + public record ModelConfigMap(Map config, Map secrets) {} + + /** + * Semi parsed model where model id, task type and service + * are known but the settings are not parsed. + */ + public record UnparsedModel( + String modelId, + TaskType taskType, + String service, + Map settings, + Map secrets + ) {} + + public abstract void getModelWithSecrets(String modelId, ActionListener listener); + + public abstract void getModel(String modelId, ActionListener listener); + + public abstract void getModelsByTaskType(TaskType taskType, ActionListener> listener); + + public abstract void storeModel(Model model, ActionListener listener); + + public abstract void deleteModel(String modelId, ActionListener listener); + + public abstract void getAllModels(ActionListener> listener); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SimilarityMeasure.java b/server/src/main/java/org/elasticsearch/inference/SimilarityMeasure.java similarity index 59% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SimilarityMeasure.java rename to server/src/main/java/org/elasticsearch/inference/SimilarityMeasure.java index 3028ecd078597..83e4f326e4e74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SimilarityMeasure.java +++ b/server/src/main/java/org/elasticsearch/inference/SimilarityMeasure.java @@ -1,3 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License @@ -5,7 +13,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.common; +package org.elasticsearch.inference; import java.util.Locale; diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index fbb56cac21e28..d61d09cdac498 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -123,7 +123,6 @@ import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -142,7 +141,6 @@ import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.DiscoveryPlugin; import org.elasticsearch.plugins.HealthPlugin; -import org.elasticsearch.plugins.InferenceProviderPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -1082,16 +1080,6 @@ record PluginServiceInstances( ); } - InferenceProvider inferenceProvider = null; - Optional inferenceProviderPlugin = getSinglePlugin(InferenceProviderPlugin.class); - if (inferenceProviderPlugin.isPresent()) { - inferenceProvider = inferenceProviderPlugin.get().getInferenceProvider(); - } else { - logger.warn("No inference provider found. Inference for semantic_text field types won't be available"); - inferenceProvider = new InferenceProvider.NoopInferenceProvider(); - } - modules.bindToInstance(InferenceProvider.class, inferenceProvider); - injector = modules.createInjector(); postInjection(clusterModule, actionModule, clusterService, transportService, featureService); diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java deleted file mode 100644 index ebd307d3d02c0..0000000000000 --- a/server/src/main/java/org/elasticsearch/plugins/InferenceProviderPlugin.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.plugins; - -import org.elasticsearch.inference.InferenceProvider; - -/** - * An extension point for {@link Plugin} implementations to add inference plugins for use on document ingestion - */ -public interface InferenceProviderPlugin { - - /** - * Returns the inference provider added by this plugin. - * - * @return InferenceProvider added by the plugin - */ - InferenceProvider getInferenceProvider(); - -} diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index 763039e77025f..d846d727f5827 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -29,7 +29,6 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.EmptySystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockUtils; @@ -125,7 +124,8 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) indexNameExpressionResolver, new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - new InferenceProvider.NoopInferenceProvider() + null, + null ) { @Override void executeBulk( diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index 08948d76ed8a1..6f7a9ea9edbee 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.action.bulk; +import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.DocWriteRequest; @@ -27,15 +28,20 @@ import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.Tuple; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.TestIndexNameExpressionResolver; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelRegistry; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.TestInferenceResults; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.test.ClusterServiceUtils; @@ -43,27 +49,30 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentBuilder; import org.junit.After; import org.junit.Before; +import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; +import java.io.IOException; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class TransportBulkActionInferenceTests extends ESTestCase { @@ -78,8 +87,7 @@ public class TransportBulkActionInferenceTests extends ESTestCase { private ThreadPool threadPool; private NodeClient nodeClient; private TransportBulkAction transportBulkAction; - - private InferenceProvider inferenceProvider; + private InferenceService inferenceService; @Before public void setup() { @@ -116,9 +124,12 @@ public void setup() { clusterService = ClusterServiceUtils.createClusterService(state, threadPool); - inferenceProvider = mock(InferenceProvider.class); - when(inferenceProvider.performsInference()).thenReturn(true); - + inferenceService = null; + InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); + when(inferenceServiceRegistry.getService(anyString())).thenAnswer( + (Answer>) invocation -> Optional.ofNullable(inferenceService) + ); + ModelRegistry modelRegistry = null; transportBulkAction = new TransportBulkAction( threadPool, mock(TransportService.class), @@ -129,16 +140,9 @@ public void setup() { TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(Settings.builder().put(AutoCreateIndex.AUTO_CREATE_INDEX_SETTING.getKey(), true).build()), EmptySystemIndices.INSTANCE, - inferenceProvider + inferenceServiceRegistry, + modelRegistry ); - - // Default answers to avoid hanging tests due to unexpected invocations - doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[2]; - listener.onFailure(new Exception("Unexpected invocation")); - return Void.TYPE; - }).when(inferenceProvider).textInference(any(), any(), any()); when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), any(), any())).thenAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocation.getArguments()[2]; @@ -169,7 +173,6 @@ public void testBulkRequestWithoutInference() { assertThat(response.getItems().length, equalTo(1)); assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); - verifyInferenceExecuted(never()); } public void testBulkRequestWithInference() { @@ -189,7 +192,6 @@ public void testBulkRequestWithInference() { assertThat(response.getItems().length, equalTo(1)); assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); - verifyInferenceExecuted(times(1)); } public void testBulkRequestWithMultipleFieldsInference() { @@ -277,7 +279,7 @@ public void testFailingInference() { } private void verifyInferenceExecuted(VerificationMode verificationMode) { - verify(inferenceProvider, verificationMode).textInference(any(), any(), any()); + // verify(inferenceProvider, verificationMode).textInference(any(), any(), any()); } private void expectTransportShardBulkActionRequest(int requestSize) { @@ -308,33 +310,98 @@ private boolean matchBulkShardRequest(ActionRequest request, int requestSize) { } @SuppressWarnings("unchecked") - private void expectInferenceRequest(String modelId, String... inferenceTexts) { - doAnswer(invocation -> { - List texts = (List) invocation.getArguments()[1]; - var listener = (ActionListener>) invocation.getArguments()[2]; - listener.onResponse( - texts.stream() - .map( - text -> new TestInferenceResults( - "test_field", - randomMap(1, 10, () -> new Tuple<>(randomAlphaOfLengthBetween(1, 10), randomFloat())) - ) - ) - .collect(Collectors.toList()) - ); - return Void.TYPE; - }).when(inferenceProvider) - .textInference(eq(modelId), argThat(texts -> texts.containsAll(Arrays.stream(inferenceTexts).toList())), any()); - } + private void expectInferenceRequest(String modelId, String... inferenceTexts) {} - private void expectInferenceRequestFails(String modelId, String... inferenceTexts) { - doAnswer(invocation -> { - @SuppressWarnings("unchecked") - var listener = (ActionListener>) invocation.getArguments()[2]; - listener.onFailure(new Exception("Inference failed")); - return Void.TYPE; - }).when(inferenceProvider) - .textInference(eq(modelId), argThat(texts -> texts.containsAll(Arrays.stream(inferenceTexts).toList())), any()); + private void expectInferenceRequestFails(String modelId, String... inferenceTexts) {} + + private static class MockInferenceService implements InferenceService { + private Map resultsPerModel = new HashMap<>(); + private Exception exception; + + void clearModels() { + resultsPerModel.clear(); + } + + void setResults(String text, TestInferenceResults results) { + resultsPerModel.put(text, results); + } + + @Override + public String name() { + return "mock"; + } + + @Override + public Model parseRequestConfig(String modelId, TaskType taskType, Map config, Set platfromArchitectures) { + return null; + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + return null; + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + return null; + } + + @Override + public void infer( + Model model, + List input, + Map taskSettings, + ActionListener listener + ) { + var results = input.stream().map(text -> resultsPerModel.get(text)).collect(Collectors.toList()); + listener.onResponse(new InferenceServiceResults() { + @Override + public List transformToCoordinationFormat() { + return results; + } + + @Override + public List transformToLegacyFormat() { + return results; + } + + @Override + public Map asMap() { + throw new IllegalStateException("not implemented"); + } + + @Override + public String getWriteableName() { + throw new IllegalStateException("not implemented"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IllegalStateException("not implemented"); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + throw new IllegalStateException("not implemented"); + } + }); + } + + @Override + public void start(Model model, ActionListener listener) {} + + @Override + public TransportVersion getMinimalSupportedVersion() { + return null; + } + + @Override + public void close() throws IOException {} } } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index eae8554f3f394..b09c95dffd582 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -41,7 +41,6 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.TestIndexNameExpressionResolver; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -137,7 +136,8 @@ class TestTransportBulkAction extends TransportBulkAction { TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(SETTINGS), EmptySystemIndices.INSTANCE, - new InferenceProvider.NoopInferenceProvider() + null, + null ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index 01fbbff173cd5..04a639c8dfc28 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -40,7 +40,6 @@ import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.SystemIndexDescriptorUtils; import org.elasticsearch.indices.SystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; import org.elasticsearch.test.index.IndexVersionUtils; @@ -89,7 +88,8 @@ class TestTransportBulkAction extends TransportBulkAction { new Resolver(), new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, - new InferenceProvider.NoopInferenceProvider() + null, + null ); } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index d94a1cb092bc4..d577a7dcc0313 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -32,7 +32,6 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.EmptySystemIndices; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; @@ -254,7 +253,8 @@ static class TestTransportBulkAction extends TransportBulkAction { new IndexingPressure(Settings.EMPTY), EmptySystemIndices.INSTANCE, relativeTimeProvider, - new InferenceProvider.NoopInferenceProvider() + null, + null ); } } diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index f8fdcbd09ce78..1ecc2782ca858 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -155,7 +155,6 @@ import org.elasticsearch.indices.recovery.RecoverySettings; import org.elasticsearch.indices.recovery.SnapshotFilesProvider; import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.StatusInfo; import org.elasticsearch.node.ResponseCollectorService; @@ -1944,7 +1943,8 @@ protected void assertSnapshotOrGenericThread() { indexNameExpressionResolver, new IndexingPressure(settings), EmptySystemIndices.INSTANCE, - new InferenceProvider.NoopInferenceProvider() + null, + null ) ); final TransportShardBulkAction transportShardBulkAction = new TransportShardBulkAction( diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 50647ca328b23..e9476eeb22442 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -15,6 +15,7 @@ import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; @@ -25,7 +26,7 @@ import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeModel; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettingsTests; @@ -60,7 +61,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase { @Before public void createComponents() { - modelRegistry = new ModelRegistry(client()); + modelRegistry = new ModelRegistryImpl(client()); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java deleted file mode 100644 index 7886590c768cb..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceActionInferenceProvider.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.client.internal.OriginSettingClient; -import org.elasticsearch.inference.InferenceProvider; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; - -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; - -/** - * InferenceProvider implementation that uses the inference action to retrieve inference results. - */ -public class InferenceActionInferenceProvider implements InferenceProvider { - - private final Client client; - - public InferenceActionInferenceProvider(Client client) { - this.client = new OriginSettingClient(client, INFERENCE_ORIGIN); - } - - @Override - public void textInference(String modelId, List texts, ActionListener> listener) { - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.SPARSE_EMBEDDING, // TODO Change when task type doesn't need to be specified - modelId, - texts, - Map.of(), - InputType.INGEST - ); - - client.execute(InferenceAction.INSTANCE, inferenceRequest, listener.delegateFailure((l, response) -> { - InferenceServiceResults results = response.getResults(); - if (results == null) { - throw new IllegalArgumentException("No inference retrieved for model ID " + modelId); - } - - @SuppressWarnings("unchecked") - List result = (List) results.transformToLegacyFormat(); - l.onResponse(result); - })); - } - - @Override - public boolean performsInference() { - return true; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 3c99b9caac221..082556dc845a0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -21,12 +21,11 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.indices.SystemIndexDescriptor; -import org.elasticsearch.inference.InferenceProvider; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; -import org.elasticsearch.plugins.InferenceProviderPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; @@ -50,7 +49,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.registry.ModelRegistryImpl; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; @@ -68,7 +67,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin, InferenceProviderPlugin { +public class InferencePlugin extends Plugin implements ActionPlugin, ExtensiblePlugin, SystemIndexPlugin { public static final String NAME = "inference"; public static final String UTILITY_THREAD_POOL_NAME = "inference_utility"; @@ -80,7 +79,6 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final SetOnce inferenceServiceRegistry = new SetOnce<>(); - private final SetOnce inferenceProvider = new SetOnce<>(); private List inferenceServiceExtensions; public InferencePlugin(Settings settings) { @@ -133,7 +131,7 @@ public Collection createComponents(PluginServices services) { ); httpFactory.set(httpRequestSenderFactory); - ModelRegistry modelRegistry = new ModelRegistry(services.client()); + ModelRegistry modelRegistry = new ModelRegistryImpl(services.client()); if (inferenceServiceExtensions == null) { inferenceServiceExtensions = new ArrayList<>(); @@ -146,10 +144,7 @@ public Collection createComponents(PluginServices services) { registry.init(services.client()); inferenceServiceRegistry.set(registry); - var provider = new InferenceActionInferenceProvider(services.client()); - inferenceProvider.set(provider); - - return List.of(modelRegistry, registry, provider); + return List.of(modelRegistry, registry); } @Override @@ -243,9 +238,4 @@ public void close() { IOUtils.closeWhileHandlingException(httpManager.get(), throttlerToClose); } - - @Override - public InferenceProvider getInferenceProvider() { - return inferenceProvider.get(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java index cb728120d2f0b..beedb84fa5178 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java @@ -23,12 +23,12 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index a7f5fb6c6c9a0..ca79a8aaa1bb3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -24,7 +25,6 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index db98aeccc556b..c94281d787137 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -16,11 +16,11 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; public class TransportInferenceAction extends HandledTransportAction { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 8bcc07a6322bc..d9f3af6097328 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -28,6 +28,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -42,7 +43,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.io.IOException; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java similarity index 89% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java index 3cc83f2f4ddc5..8dfe95fb87a69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryImpl.java @@ -31,6 +31,7 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; @@ -55,41 +56,14 @@ import static org.elasticsearch.core.Strings.format; -public class ModelRegistry { - public record ModelConfigMap(Map config, Map secrets) {} - - /** - * Semi parsed model where model id, task type and service - * are known but the settings are not parsed. - */ - public record UnparsedModel( - String modelId, - TaskType taskType, - String service, - Map settings, - Map secrets - ) { - - public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { - if (modelConfigMap.config() == null) { - throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); - } - String modelId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); - String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); - String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); - TaskType taskType = TaskType.fromString(taskTypeStr); - - return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); - } - } - +public class ModelRegistryImpl extends ModelRegistry { private static final String TASK_TYPE_FIELD = "task_type"; private static final String MODEL_ID_FIELD = "model_id"; private static final Logger logger = LogManager.getLogger(ModelRegistry.class); private final OriginSettingClient client; - public ModelRegistry(Client client) { + public ModelRegistryImpl(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); } @@ -98,6 +72,7 @@ public ModelRegistry(Client client) { * @param modelId Model to get * @param listener Model listener */ + @Override public void getModelWithSecrets(String modelId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -106,7 +81,7 @@ public void getModelWithSecrets(String modelId, ActionListener li return; } - delegate.onResponse(UnparsedModel.unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), modelId))); + delegate.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), modelId))); }); QueryBuilder queryBuilder = documentIdQuery(modelId); @@ -124,6 +99,7 @@ public void getModelWithSecrets(String modelId, ActionListener li * @param modelId Model to get * @param listener Model listener */ + @Override public void getModel(String modelId, ActionListener listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets @@ -132,7 +108,7 @@ public void getModel(String modelId, ActionListener listener) { return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; delegate.onResponse(modelConfigs.get(0)); }); @@ -153,6 +129,7 @@ public void getModel(String modelId, ActionListener listener) { * @param taskType The task type * @param listener Models listener */ + @Override public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -161,7 +138,7 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // Not an error if no models of this task_type @@ -190,7 +168,7 @@ public void getAllModels(ActionListener> listener) { return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistryImpl::unparsedModelFromMap).toList(); delegate.onResponse(modelConfigs); }); @@ -252,6 +230,7 @@ private ModelConfigMap createModelConfigMap(SearchHits hits, String modelId) { ); } + @Override public void storeModel(Model model, ActionListener listener) { ActionListener bulkResponseActionListener = getStoreModelListener(model, listener); @@ -348,6 +327,7 @@ private static BulkItemResponse.Failure getFirstBulkFailure(BulkResponse bulkRes return null; } + @Override public void deleteModel(String modelId, ActionListener listener) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndex.INDEX_PATTERN, InferenceSecretsIndex.INDEX_PATTERN); @@ -372,4 +352,16 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T private QueryBuilder documentIdQuery(String modelId) { return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(modelId))); } + + private static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); + } + String modelId = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.MODEL_ID); + String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(modelId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 1686cd32d4a6b..b43d090343391 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -14,10 +14,10 @@ import org.elasticsearch.core.Strings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; -import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import java.net.URI; import java.net.URISyntaxException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java index 6464ca0e0fda8..a25c5cd145fce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java @@ -15,8 +15,8 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import java.io.IOException; import java.net.URI; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 1bdd1abce0b45..07bcab10f71d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -17,9 +17,9 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory; import org.elasticsearch.xpack.inference.services.SenderService; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java index 5ade2aad0acb4..47bde33d2d821 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettings.java @@ -15,8 +15,8 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import java.io.IOException; import java.net.URI; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index a8ea237ba8b0c..3f83686beb517 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.inference.ModelRegistry; import org.elasticsearch.inference.TaskType; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; @@ -65,7 +66,7 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() var client = mockClient(); mockClientExecuteSearch(client, mockSearchResponse(SearchHits.EMPTY)); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); @@ -79,7 +80,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn var unknownIndexHit = SearchHit.createFromMap(Map.of("_index", "unknown_index")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { unknownIndexHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); @@ -96,7 +97,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceSecretsHit = SearchHit.createFromMap(Map.of("_index", ".secrets-inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceSecretsHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); @@ -113,7 +114,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind var inferenceHit = SearchHit.createFromMap(Map.of("_index", ".inference")); mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); @@ -147,7 +148,7 @@ public void testGetModelWithSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit, inferenceSecretsHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.getModelWithSecrets("1", listener); @@ -176,7 +177,7 @@ public void testGetModelNoSecrets() { mockClientExecuteSearch(client, mockSearchResponse(new SearchHit[] { inferenceHit })); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.getModel("1", listener); @@ -201,7 +202,7 @@ public void testStoreModel_ReturnsTrue_WhenNoFailuresOccur() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -218,7 +219,7 @@ public void testStoreModel_ThrowsException_WhenBulkResponseIsEmpty() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -249,7 +250,7 @@ public void testStoreModel_ThrowsResourceAlreadyExistsException_WhenFailureIsAVe mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); @@ -272,7 +273,7 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() { mockClientExecuteBulk(client, bulkResponse); var model = TestModel.createRandomInstance(); - var registry = new ModelRegistry(client); + var registry = new ModelRegistryImpl(client); var listener = new PlainActionFuture(); registry.storeModel(model, listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java index 7e2a333685321..f32fafd493395 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java @@ -10,8 +10,8 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ServiceUtils; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettingsTests.java index 81bbb4b041c51..90aea2041d15d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceSettingsTests.java @@ -11,8 +11,8 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ServiceUtils; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java index 10e856ec8a27e..caaef58ebb721 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java @@ -9,9 +9,9 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.common.SimilarityMeasure; import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;