diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 2244622445ecc..ee8d4b0fbbc6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -59,6 +59,7 @@ import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.ArrayList; @@ -504,30 +505,34 @@ public void checkModelConfig(Model model, ActionListener listener) { @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { - if (model instanceof CustomElandEmbeddingModel embeddingsModel) { - var serviceSettings = embeddingsModel.getServiceSettings(); - - var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( - serviceSettings.getNumAllocations(), - serviceSettings.getNumThreads(), - serviceSettings.modelId(), - serviceSettings.getAdaptiveAllocationsSettings(), - embeddingSize, - serviceSettings.similarity(), - serviceSettings.elementType() - ); + if (model instanceof ElasticsearchInternalModel) { + if (model instanceof CustomElandEmbeddingModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + + var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( + serviceSettings.getNumAllocations(), + serviceSettings.getNumThreads(), + serviceSettings.modelId(), + serviceSettings.getAdaptiveAllocationsSettings(), + embeddingSize, + serviceSettings.similarity(), + serviceSettings.elementType() + ); - return new CustomElandEmbeddingModel( - model.getInferenceEntityId(), - model.getTaskType(), - model.getConfigurations().getService(), - updatedServiceSettings, - model.getConfigurations().getChunkingSettings() - ); + return new CustomElandEmbeddingModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + updatedServiceSettings, + model.getConfigurations().getChunkingSettings() + ); + } else { + // TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do + // need to update the dimensions? + return model; + } } else { - // TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do - // need to update the dimensions? - return model; + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index cad33b56ce235..510a093e9c162 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -67,11 +67,13 @@ import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; +import java.io.IOException; import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; @@ -1509,6 +1511,84 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { assertThat(model, is(expectedModel)); } + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { + var client = mock(Client.class); + try (var service = createService(client)) { + var model = OpenAiChatCompletionModelTests.createChatCompletionModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10) + ); + assertThrows( + ElasticsearchStatusException.class, + () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); } + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_NonElandModelProvided() throws IOException { + var client = mock(Client.class); + try (var service = createService(client)) { + var originalModel = new MultilingualE5SmallModel( + randomAlphaOfLength(10), + TaskType.TEXT_EMBEDDING, + randomAlphaOfLength(10), + new MultilingualE5SmallInternalServiceSettings( + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + null + ), + null + ); + + var updatedModel = service.updateModelWithEmbeddingDetails(originalModel, randomNonNegativeInt()); + assertEquals(originalModel, updatedModel); + } + } + + public void testUpdateModelWithEmbeddingDetails_ElandModelProvided() throws IOException { + var client = mock(Client.class); + try (var service = createService(client)) { + var originalServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + null + ); + var originalModel = new CustomElandEmbeddingModel( + randomAlphaOfLength(10), + TaskType.TEXT_EMBEDDING, + randomAlphaOfLength(10), + originalServiceSettings, + ChunkingSettingsTests.createRandomChunkingSettings() + ); + + var embeddingSize = randomNonNegativeInt(); + var expectedUpdatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( + originalServiceSettings.getNumAllocations(), + originalServiceSettings.getNumThreads(), + originalServiceSettings.modelId(), + originalServiceSettings.getAdaptiveAllocationsSettings(), + embeddingSize, + originalServiceSettings.similarity(), + originalServiceSettings.elementType() + ); + var expectedUpdatedModel = new CustomElandEmbeddingModel( + originalModel.getInferenceEntityId(), + originalModel.getTaskType(), + originalModel.getConfigurations().getService(), + expectedUpdatedServiceSettings, + originalModel.getConfigurations().getChunkingSettings() + ); + + var actualUpdatedModel = service.updateModelWithEmbeddingDetails(originalModel, embeddingSize); + assertEquals(expectedUpdatedModel, actualUpdatedModel); + } + } + public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() { { assertFalse(