From 6432cd0d6cdbe13ca712369ef3bf76cfa5687208 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Wed, 25 Sep 2024 16:21:42 -0400 Subject: [PATCH] Adding chunking settings to MistralService, GoogleAiStudioService, and HuggingFaceService --- .../inference/ModelConfigurations.java | 10 + .../googleaistudio/GoogleAiStudioService.java | 48 ++- .../GoogleAiStudioEmbeddingsModel.java | 29 +- .../huggingface/HuggingFaceBaseService.java | 25 ++ .../huggingface/HuggingFaceService.java | 25 +- .../elser/HuggingFaceElserService.java | 2 + .../HuggingFaceEmbeddingsModel.java | 7 +- .../services/mistral/MistralService.java | 57 +++- .../embeddings/MistralEmbeddingsModel.java | 6 +- .../GoogleAiStudioServiceTests.java | 283 +++++++++++++++++- .../GoogleAiStudioEmbeddingsModelTests.java | 15 + .../HuggingFaceBaseServiceTests.java | 15 - .../huggingface/HuggingFaceServiceTests.java | 209 +++++++++++++ .../HuggingFaceEmbeddingsModelTests.java | 4 + .../services/mistral/MistralServiceTests.java | 268 ++++++++++++++++- .../MistralEmbeddingModelTests.java | 23 ++ 16 files changed, 986 insertions(+), 40 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java index e5bd5a629a912..9f7a247d00a3a 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java @@ -74,6 +74,16 @@ public ModelConfigurations(String inferenceEntityId, TaskType taskType, String s this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE); } + public ModelConfigurations( + String inferenceEntityId, + TaskType taskType, + String service, + ServiceSettings serviceSettings, + ChunkingSettings chunkingSettings + ) { + this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings); + } + public ModelConfigurations( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 08eb67ca744a4..7b4c3d4b66bd5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -23,6 +24,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -72,11 +75,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + GoogleAiStudioModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -98,6 +109,7 @@ private static GoogleAiStudioModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -118,6 +130,7 @@ private static GoogleAiStudioModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -136,11 +149,17 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -151,6 +170,7 @@ private static GoogleAiStudioModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage ) { @@ -159,6 +179,7 @@ private static GoogleAiStudioModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT @@ -170,11 +191,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -247,11 +274,22 @@ protected void doChunkedInfer( GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + googleAiStudioModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java index af19e26f3e97a..5d46a8e129dff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java @@ -9,6 +9,7 @@ import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.ModelConfigurations; @@ -38,6 +39,7 @@ public GoogleAiStudioEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secrets, ConfigurationParseContext context ) { @@ -47,6 +49,7 @@ public GoogleAiStudioEmbeddingsModel( service, GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -62,10 +65,11 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google String service, GoogleAiStudioEmbeddingsServiceSettings serviceSettings, TaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secrets), serviceSettings ); @@ -98,6 +102,29 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google } } + // Should only be used directly for testing + GoogleAiStudioEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + String uri, + GoogleAiStudioEmbeddingsServiceSettings serviceSettings, + TaskSettings taskSettings, + ChunkingSettings chunkingsettings, + @Nullable DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingsettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = new URI(uri); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + @Override public GoogleAiStudioEmbeddingsServiceSettings getServiceSettings() { return (GoogleAiStudioEmbeddingsServiceSettings) super.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java index d129a0c44e835..a4116d14bf595 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java @@ -9,12 +9,15 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -27,6 +30,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -54,10 +58,18 @@ public void parseRequestConfig( try { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + var model = createModel( inferenceEntityId, taskType, serviceSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, name()), ConfigurationParseContext.REQUEST @@ -82,10 +94,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModel( inferenceEntityId, taskType, serviceSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, name()), ConfigurationParseContext.PERSISTENT @@ -96,10 +114,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets( public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModel( inferenceEntityId, taskType, serviceSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, name()), ConfigurationParseContext.PERSISTENT @@ -110,6 +134,7 @@ protected abstract HuggingFaceModel createModel( String inferenceEntityId, TaskType taskType, Map serviceSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage, ConfigurationParseContext context diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index bdfa87e77b708..cea3e530bd9fd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -15,11 +15,13 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -47,6 +49,7 @@ protected HuggingFaceModel createModel( String inferenceEntityId, TaskType taskType, Map serviceSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -57,6 +60,7 @@ protected HuggingFaceModel createModel( taskType, NAME, serviceSettings, + chunkingSettings, secretSettings, context ); @@ -113,11 +117,22 @@ protected void doChunkedInfer( var huggingFaceModel = (HuggingFaceModel) model; var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + huggingFaceModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index b9540cab17a9a..a8de51c23831f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -56,6 +57,7 @@ protected HuggingFaceModel createModel( String inferenceEntityId, TaskType taskType, Map serviceSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java index fedd6380d035f..7c4d8094fc213 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.huggingface.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -26,6 +27,7 @@ public HuggingFaceEmbeddingsModel( TaskType taskType, String service, Map serviceSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -34,6 +36,7 @@ public HuggingFaceEmbeddingsModel( taskType, service, HuggingFaceServiceSettings.fromMap(serviceSettings, context), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -44,10 +47,11 @@ public HuggingFaceEmbeddingsModel( TaskType taskType, String service, HuggingFaceServiceSettings serviceSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings), new ModelSecrets(secrets), serviceSettings, secrets @@ -60,6 +64,7 @@ public HuggingFaceEmbeddingsModel(HuggingFaceEmbeddingsModel model, HuggingFaceS model.getTaskType(), model.getConfigurations().getService(), serviceSettings, + model.getConfigurations().getChunkingSettings(), model.getSecretSettings() ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 1acc13f50778b..cd85a391a2fc0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -22,6 +23,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.mistral.MistralActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -84,11 +87,21 @@ protected void doChunkedInfer( var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - MistralConstants.MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + MistralConstants.MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + mistralEmbeddingsModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + MistralConstants.MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); @@ -116,11 +129,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + MistralEmbeddingsModel model = createModel( modelId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -147,11 +168,17 @@ public Model parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( modelId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(modelId, NAME) ); @@ -162,11 +189,17 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( modelId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(modelId, NAME) ); @@ -182,12 +215,22 @@ private static MistralEmbeddingsModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context ) { if (taskType == TaskType.TEXT_EMBEDDING) { - return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); + return new MistralEmbeddingsModel( + modelId, + taskType, + NAME, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + context + ); } throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); @@ -198,6 +241,7 @@ private MistralEmbeddingsModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage ) { @@ -206,6 +250,7 @@ private MistralEmbeddingsModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java index d883e7c687c20..11f6c456b88cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.mistral.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -37,6 +38,7 @@ public MistralEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -46,6 +48,7 @@ public MistralEmbeddingsModel( service, MistralEmbeddingsServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, // no task settings for Mistral embeddings + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -61,10 +64,11 @@ public MistralEmbeddingsModel( String service, MistralEmbeddingsServiceSettings serviceSettings, TaskSettings taskSettings, + ChunkingSettings chunkingSettings, DefaultSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings), new ModelSecrets(secrets) ); setPropertiesFromServiceSettings(serviceSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index a8882bb244512..1477fde55ae82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -29,6 +30,7 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; @@ -60,6 +62,8 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; @@ -155,6 +159,96 @@ public void testParseRequestConfig_CreatesAGoogleAiStudioEmbeddingsModel() throw } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + var apiKey = "apiKey"; + var modelId = "model"; + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createGoogleAiStudioService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, Matchers.instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(apiKey) + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var apiKey = "apiKey"; + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(apiKey) + ), + Set.of(), + modelListener + ); + } + } + + public void testParseRequestConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var apiKey = "apiKey"; + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + getSecretSettingsMap(apiKey) + ), + Set.of(), + modelListener + ); + } + } + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createGoogleAiStudioService()) { var failureListener = getModelListenerForException( @@ -300,6 +394,98 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddings } } + public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), Matchers.instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), Matchers.instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var modelId = "model"; var apiKey = "apiKey"; @@ -431,6 +617,74 @@ public void testParsePersistedConfig_CreatesAGoogleAiStudioCompletionModel() thr } } + public void testParsePersistedConfig_CreatesAGoogleAiEmbeddingsModelWithoutChunkingSettingsWhenChunkingSettingsFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAGoogleAiStudioEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), getTaskSettingsMapEmpty()); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var modelId = "model"; @@ -667,6 +921,22 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { public void testChunkedInfer_Batches() throws IOException { var modelId = "modelId"; var apiKey = "apiKey"; + var model = GoogleAiStudioEmbeddingsModelTests.createModel(modelId, apiKey, getUrl(webServer)); + + testChunkedInfer(modelId, apiKey, model); + } + + public void testChunkedInfer_ChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var modelId = "modelId"; + var apiKey = "apiKey"; + var model = GoogleAiStudioEmbeddingsModelTests.createModel(modelId, createRandomChunkingSettings(), apiKey, getUrl(webServer)); + + testChunkedInfer(modelId, apiKey, model); + } + + private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException { + var input = List.of("foo", "bar"); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -693,7 +963,6 @@ public void testChunkedInfer_Batches() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = GoogleAiStudioEmbeddingsModelTests.createModel(modelId, apiKey, getUrl(webServer)); PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, @@ -928,6 +1197,18 @@ private static ActionListener getModelListenerForException(Class excep }); } + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java index 5ea9bbfc9d970..32bd95c954292 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -29,6 +30,19 @@ public static GoogleAiStudioEmbeddingsModel createModel(String model, String api ); } + public static GoogleAiStudioEmbeddingsModel createModel(String model, ChunkingSettings chunkingSettings, String apiKey, String url) { + return new GoogleAiStudioEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + url, + new GoogleAiStudioEmbeddingsServiceSettings(model, null, null, SimilarityMeasure.DOT_PRODUCT, null), + EmptyTaskSettings.INSTANCE, + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + public static GoogleAiStudioEmbeddingsModel createModel( String url, String model, @@ -59,6 +73,7 @@ public static GoogleAiStudioEmbeddingsModel createModel( "service", new GoogleAiStudioEmbeddingsServiceSettings(model, tokenLimit, dimensions, SimilarityMeasure.DOT_PRODUCT, null), EmptyTaskSettings.INSTANCE, + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 22c3b7895460a..644209c29be5f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -13,13 +13,11 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.junit.After; import org.junit.Before; @@ -27,7 +25,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; @@ -104,17 +101,5 @@ public String name() { public TransportVersion getMinimalSupportedVersion() { return TransportVersion.current(); } - - @Override - protected HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ) { - return null; - } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index f68aedd69f365..f31cd116bd642 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -28,6 +29,7 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; @@ -57,6 +59,7 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; @@ -64,6 +67,7 @@ import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -110,6 +114,75 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createHuggingFaceService()) { + var config = getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")); + config.put("extra_key", "value"); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret")), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap((model) -> { + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, (e) -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret")), + Set.of(), + modelVerificationActionListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + ActionListener modelVerificationActionListener = ActionListener.wrap((model) -> { + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, (e) -> fail("parse request should not fail " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")), + Set.of(), + modelVerificationActionListener + ); + } + } + public void testParseRequestConfig_CreatesAnElserModel() throws IOException { try (var service = createHuggingFaceService()) { ActionListener modelVerificationActionListener = ActionListener.wrap((model) -> { @@ -213,6 +286,82 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throw } } + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + new HashMap<>(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("url"), + new HashMap<>(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret")); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testParsePersistedConfigWithSecrets_CreatesAnElserModel() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret")); @@ -352,6 +501,55 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModel() throws IOExcepti } } + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenChunkingSettingsFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap()); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), createRandomChunkingSettingsMap()); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url")); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(HuggingFaceEmbeddingsModel.class)); + + var embeddingsModel = (HuggingFaceEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + public void testParsePersistedConfig_CreatesAnElserModel() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>()); @@ -709,6 +907,17 @@ private HuggingFaceService createHuggingFaceService() { return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } + private Map getRequestConfigMap( + Map serviceSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { var builtServiceSettings = new HashMap<>(); builtServiceSettings.putAll(serviceSettings); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java index baf5467d8fe06..b81eabb20edc3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java @@ -31,6 +31,7 @@ public static HuggingFaceEmbeddingsModel createModel(String url, String apiKey) TaskType.TEXT_EMBEDDING, "service", new HuggingFaceServiceSettings(url), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -41,6 +42,7 @@ public static HuggingFaceEmbeddingsModel createModel(String url, String apiKey, TaskType.TEXT_EMBEDDING, "service", new HuggingFaceServiceSettings(createUri(url), null, null, tokenLimit, null), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -51,6 +53,7 @@ public static HuggingFaceEmbeddingsModel createModel(String url, String apiKey, TaskType.TEXT_EMBEDDING, "service", new HuggingFaceServiceSettings(createUri(url), null, dimensions, tokenLimit, null), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -67,6 +70,7 @@ public static HuggingFaceEmbeddingsModel createModel( TaskType.TEXT_EMBEDDING, "service", new HuggingFaceServiceSettings(createUri(url), similarityMeasure, dimensions, tokenLimit, null), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index c833f00c4c433..4729340d3816e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -28,6 +29,7 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -45,7 +47,6 @@ import org.junit.Before; import java.io.IOException; -import java.net.URISyntaxException; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -57,6 +58,8 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -118,6 +121,90 @@ public void testParseRequestConfig_CreatesAMistralEmbeddingsModel() throws IOExc } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + var serviceSettings = (MistralEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(serviceSettings.modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + var serviceSettings = (MistralEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(serviceSettings.modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap( @@ -238,6 +325,76 @@ public void testParsePersistedConfig_CreatesAMistralEmbeddingsModel() throws IOE } } + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap( @@ -360,6 +517,74 @@ public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() thro } } + public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + createRandomChunkingSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + public void testCheckModelConfig_ForEmbeddingsModel_Works() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -423,7 +648,31 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I verifyNoMoreInteractions(sender); } - public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException { + public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throws IOException { + var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var model = MistralEmbeddingModelTests.createModel( + "id", + "mistral-embed", + createRandomChunkingSettings(), + "apikey", + null, + null, + null, + null + ); + model.setURI(getUrl(webServer)); + + testChunkedInfer(model); + } + + public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { @@ -458,9 +707,6 @@ public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throw """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); - model.setURI(getUrl(webServer)); - PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, @@ -546,6 +792,18 @@ private MistralService createService() { return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java index 0fe8723664c6e..6f8b40fd7f19c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -24,6 +25,7 @@ public static MistralEmbeddingsModel createModel(String inferenceId, String mode public static MistralEmbeddingsModel createModel( String inferenceId, String model, + ChunkingSettings chunkingSettings, String apiKey, @Nullable Integer dimensions, @Nullable Integer maxTokens, @@ -36,6 +38,27 @@ public static MistralEmbeddingsModel createModel( "mistral", new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), EmptyTaskSettings.INSTANCE, + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static MistralEmbeddingsModel createModel( + String inferenceId, + String model, + String apiKey, + @Nullable Integer dimensions, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings + ) { + return new MistralEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + "mistral", + new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), + EmptyTaskSettings.INSTANCE, + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); }