diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 4bfe4ae535bed..0d4c66dd13849 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -174,6 +174,17 @@ default void checkModelConfig(Model model, ActionListener listener) { listener.onResponse(model); }; + /** + * Update a text embedding model's dimensions based on a provided embedding + * size and set the default similarity if required. The default behaviour is to just return the model. + * @param model The original model without updated embedding details + * @param embeddingSize The embedding size to update the model with + * @return The model with updated embedding details + */ + default Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + return model; + } + /** * Return true if this model is hosted in the local Elasticsearch cluster * @return True if in cluster diff --git a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java index d1aa0035ba8e8..7d1780c2de389 100644 --- a/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java @@ -24,8 +24,6 @@ default SimilarityMeasure similarity() { return null; } - default void setSimilarity(SimilarityMeasure similarity) {} - /** * Number of dimensions the service works with. Will be null if not applicable. * @@ -35,6 +33,11 @@ default Integer dimensions() { return null; } + /** + * Boolean signifying whether the dimensions were set by the user + * + * @return boolean signifying whether the dimensions were set by the user + */ default Boolean dimensionsSetByUser() { return null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java index 0916bf0ca334f..2716c3df97685 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java @@ -39,7 +39,7 @@ public class AmazonBedrockEmbeddingsServiceSettings extends AmazonBedrockService private final Integer dimensions; private final Boolean dimensionsSetByUser; private final Integer maxInputTokens; - private SimilarityMeasure similarity; + private final SimilarityMeasure similarity; public static AmazonBedrockEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); @@ -178,17 +178,11 @@ public SimilarityMeasure similarity() { return similarity; } - @Override - public void setSimilarity(SimilarityMeasure similarity) { - this.similarity = similarity; - } - @Override public Integer dimensions() { return dimensions; } - @Override public Boolean dimensionsSetByUser() { return this.dimensionsSetByUser; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java index 0b13b607fc972..66482e56d9615 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java @@ -139,19 +139,13 @@ private AzureAiStudioEmbeddingsServiceSettings(AzureAiStudioEmbeddingCommonField private final Integer dimensions; private final Boolean dimensionsSetByUser; private final Integer maxInputTokens; - private SimilarityMeasure similarity; + private final SimilarityMeasure similarity; @Override public SimilarityMeasure similarity() { return similarity; } - @Override - public void setSimilarity(SimilarityMeasure similarity) { - this.similarity = similarity; - } - - @Override public Boolean dimensionsSetByUser() { return this.dimensionsSetByUser; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java index 3b7d76f77f417..941a4bdeeb41a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java @@ -153,7 +153,7 @@ private record CommonFields( private final Integer dimensions; private final Boolean dimensionsSetByUser; private final Integer maxInputTokens; - private SimilarityMeasure similarity; + private final SimilarityMeasure similarity; private final RateLimitSettings rateLimitSettings; public AzureOpenAiEmbeddingsServiceSettings( @@ -229,7 +229,6 @@ public Integer dimensions() { return dimensions; } - @Override public Boolean dimensionsSetByUser() { return dimensionsSetByUser; } @@ -243,11 +242,6 @@ public SimilarityMeasure similarity() { return similarity; } - @Override - public void setSimilarity(SimilarityMeasure similarity) { - this.similarity = similarity; - } - @Override public DenseVectorFieldMapper.ElementType elementType() { return DenseVectorFieldMapper.ElementType.FLOAT; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java index b818a2335769c..097ce6240439b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java @@ -117,7 +117,7 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map listener) { - if (model instanceof OpenAiEmbeddingsModel embeddingsModel) { - ServiceUtils.getEmbeddingSize( - model, - this, - listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) - ); - } else { - listener.onResponse(model); - } + // TODO: Remove this function once all services have been updated to use the new model validators + ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); } - private OpenAiEmbeddingsModel updateModelWithEmbeddingDetails(OpenAiEmbeddingsModel model, int embeddingSize) { - if (model.getServiceSettings().dimensionsSetByUser() - && model.getServiceSettings().dimensions() != null - && model.getServiceSettings().dimensions() != embeddingSize) { - throw new ElasticsearchStatusException( - Strings.format( - "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. " - + "Please recreate the [%s] configuration with the correct dimensions", - embeddingSize, - model.getServiceSettings().dimensions(), - model.getConfigurations().getInferenceEntityId() - ), - RestStatus.BAD_REQUEST + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof OpenAiEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + var updatedServiceSettings = new OpenAiEmbeddingsServiceSettings( + serviceSettings.modelId(), + serviceSettings.uri(), + serviceSettings.organizationId(), + similarityToUse, + embeddingSize, + serviceSettings.maxInputTokens(), + serviceSettings.dimensionsSetByUser(), + serviceSettings.rateLimitSettings() ); - } - - var similarityFromModel = model.getServiceSettings().similarity(); - var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; - - OpenAiEmbeddingsServiceSettings serviceSettings = new OpenAiEmbeddingsServiceSettings( - model.getServiceSettings().modelId(), - model.getServiceSettings().uri(), - model.getServiceSettings().organizationId(), - similarityToUse, - embeddingSize, - model.getServiceSettings().maxInputTokens(), - model.getServiceSettings().dimensionsSetByUser(), - model.getServiceSettings().rateLimitSettings() - ); - return new OpenAiEmbeddingsModel(model, serviceSettings); + return new OpenAiEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw new ElasticsearchStatusException("Cannot update model with embedding details", RestStatus.BAD_REQUEST); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index 0a2fc9ef73e71..c59a2385245f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -134,7 +134,7 @@ private record CommonFields( private final URI uri; private final String organizationId; private SimilarityMeasure similarity; - private final Integer dimensions; + private Integer dimensions; private final Integer maxInputTokens; private final Boolean dimensionsSetByUser; private final RateLimitSettings rateLimitSettings; @@ -242,11 +242,6 @@ public SimilarityMeasure similarity() { return similarity; } - @Override - public void setSimilarity(SimilarityMeasure similarity) { - this.similarity = similarity; - } - @Override public Integer dimensions() { return dimensions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java index df0b5c9b2bdfc..58fa21cf37cf0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java @@ -13,14 +13,10 @@ import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbedding; -import java.util.Objects; - public class TextEmbeddingModelValidator implements ModelValidator { private final ServiceIntegrationValidator serviceIntegrationValidator; @@ -32,48 +28,50 @@ public TextEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegratio @Override public void validate(InferenceService service, Model model, ActionListener listener) { serviceIntegrationValidator.validate(service, model, listener.delegateFailureAndWrap((delegate, r) -> { - delegate.onResponse(postValidate(model, r)); + delegate.onResponse(postValidate(service, model, r)); })); } - private Model postValidate(Model model, InferenceServiceResults results) { + private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) { if (results instanceof TextEmbedding embeddingResults) { - try { - if (model.getServiceSettings().dimensionsSetByUser() - && model.getServiceSettings().dimensions() != null - && Objects.equals(model.getServiceSettings().dimensions(), embeddingResults.getFirstEmbeddingSize()) == false) { - throw new ElasticsearchStatusException( - Strings.format( - "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. " - + "Please recreate the [%s] configuration with the correct dimensions", - embeddingResults.getFirstEmbeddingSize(), - model.getServiceSettings().dimensions(), - model.getConfigurations().getInferenceEntityId() - ), - RestStatus.BAD_REQUEST - ); - } - } catch (Exception e) { + var serviceSettings = model.getServiceSettings(); + var dimensions = serviceSettings.dimensions(); + int embeddingSize = getEmbeddingSize(embeddingResults); + + if (serviceSettings.dimensionsSetByUser() && dimensions != null && dimensions.equals(embeddingSize) == false) { throw new ElasticsearchStatusException( - "Could not determine embedding size. " - + "Expected a result of type [" - + InferenceTextEmbeddingFloatResults.NAME - + "] got [" - + results.getWriteableName() - + "]", + Strings.format( + "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. " + + "Please recreate the [%s] configuration with the correct dimensions", + embeddingResults.getFirstEmbeddingSize(), + serviceSettings.dimensions(), + model.getInferenceEntityId() + ), RestStatus.BAD_REQUEST ); } - var similarityFromModel = model.getServiceSettings().similarity(); - var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; - - ServiceSettings serviceSettings = model.getServiceSettings(); - serviceSettings.setSimilarity(similarityToUse); - - return model; + return service.updateModelWithEmbeddingDetails(model, embeddingSize); } else { - throw new ElasticsearchStatusException("Validation call did not return text embedding response", RestStatus.BAD_REQUEST); + throw new ElasticsearchStatusException( + "Validation call did not return expected results type." + + "Expected a result of type [" + + InferenceTextEmbeddingFloatResults.NAME + + "] got [" + + (results == null ? "null" : results.getWriteableName()) + + "]", + RestStatus.BAD_REQUEST + ); + } + } + + private int getEmbeddingSize(TextEmbedding embeddingResults) { + int embeddingSize; + try { + embeddingSize = embeddingResults.getFirstEmbeddingSize(); + } catch (Exception e) { + throw new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e); } + return embeddingSize; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 62416f05800c6..f3bf7413d2553 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -37,7 +37,6 @@ import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.elser.ElserModels; -import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; @@ -311,7 +310,13 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists public void testCheckModelConfig_ReturnsNewModelReference() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new ElasticInferenceService( + senderFactory, + createWithEmptySettings(threadPool), + new ElasticInferenceServiceComponents(getUrl(webServer)) + ) + ) { var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer)); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 9ff175ca9685e..dbc365f3d6919 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -63,6 +63,7 @@ import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettingsTests.getTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; @@ -1151,6 +1152,52 @@ public void testCheckModelConfig_ReturnsNewModelReference_DoesNotOverrideSimilar } } + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { + try (var service = createOpenAiService()) { + var model = createChatCompletionModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10) + ); + assertThrows( + ElasticsearchStatusException.class, + () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); } + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(null); + } + + public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values())); + } + + private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { + try (var service = createOpenAiService()) { + var embeddingSize = randomNonNegativeInt(); + var model = OpenAiEmbeddingsModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + null, + randomNonNegativeInt(), + randomNonNegativeInt(), + randomBoolean() + ); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + public void testInfer_UnauthorisedResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java index 4bc96817204a1..154d613f31220 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; @@ -24,10 +23,10 @@ import org.junit.Before; import org.mockito.Mock; -import java.util.Arrays; import java.util.List; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; @@ -57,8 +56,10 @@ public void setup() { underTest = new TextEmbeddingModelValidator(mockServiceIntegrationValidator); + when(mockInferenceService.updateModelWithEmbeddingDetails(eq(mockModel), anyInt())).thenReturn(mockModel); when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod(); when(mockModel.getServiceSettings()).thenReturn(mockServiceSettings); + when(mockModel.getInferenceEntityId()).thenReturn(randomAlphaOfLength(10)); } public void testValidate_ServiceIntegrationValidatorThrowsException() { @@ -95,60 +96,49 @@ public void testValidate_RetrievingEmbeddingSizeThrowsIllegalStateException() { verifyCallToServiceIntegrationValidator(results); verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class)); - verify(mockModel, times(3)).getServiceSettings(); - verify(mockServiceSettings).dimensionsSetByUser(); - verify(mockServiceSettings, times(2)).dimensions(); + verify(mockModel, times(1)).getServiceSettings(); + verify(mockServiceSettings).dimensions(); verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockModel, mockActionListener, mockServiceSettings); } public void testValidate_DimensionsSetByUserDoNotEqualEmbeddingSize() { InferenceTextEmbeddingByteResults results = InferenceTextEmbeddingByteResultsTests.createRandomResults(); - var dimensions = randomNonNegativeInt(); - while (dimensions == results.getFirstEmbeddingSize()) { - dimensions = randomNonNegativeInt(); - } + var dimensions = randomValueOtherThan(results.getFirstEmbeddingSize(), ESTestCase::randomNonNegativeInt); when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true); when(mockServiceSettings.dimensions()).thenReturn(dimensions); verifyCallToServiceIntegrationValidator(results); verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class)); - verify(mockModel, times(4)).getServiceSettings(); - verify(mockModel).getConfigurations(); + verify(mockModel).getServiceSettings(); + verify(mockModel).getInferenceEntityId(); verify(mockServiceSettings).dimensionsSetByUser(); - verify(mockServiceSettings, times(3)).dimensions(); + verify(mockServiceSettings, times(2)).dimensions(); verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockModel, mockActionListener, mockServiceSettings); } - public void testValidate_NullSimilarityProvided() { - verifySimilarityUpdatedProperly(null, SimilarityMeasure.DOT_PRODUCT); + public void testValidate_DimensionsSetByUserEqualEmbeddingSize() { + verifySuccessfulValidation(true); } - public void testValidate_NonNullSimilarityProvided() { - SimilarityMeasure similarityProvided = randomFrom( - Arrays.stream(SimilarityMeasure.values()) - .filter(similarityMeasure -> similarityMeasure.equals(SimilarityMeasure.DOT_PRODUCT) == false) - .toList() - ); - verifySimilarityUpdatedProperly(similarityProvided, similarityProvided); + public void testValidate_DimensionsNotSetByUser() { + verifySuccessfulValidation(false); } - private void verifySimilarityUpdatedProperly(SimilarityMeasure similarityProvided, SimilarityMeasure updatedSimilarity) { + private void verifySuccessfulValidation(Boolean dimensionsSetByUser) { InferenceTextEmbeddingByteResults results = InferenceTextEmbeddingByteResultsTests.createRandomResults(); when(mockModel.getConfigurations()).thenReturn(ModelConfigurationsTests.createRandomInstance()); when(mockModel.getTaskSettings()).thenReturn(EmptyTaskSettingsTests.createRandom()); - when(mockServiceSettings.dimensionsSetByUser()).thenReturn(true); - when(mockServiceSettings.dimensions()).thenReturn(results.getFirstEmbeddingSize()); - when(mockServiceSettings.similarity()).thenReturn(similarityProvided); + when(mockServiceSettings.dimensionsSetByUser()).thenReturn(dimensionsSetByUser); + when(mockServiceSettings.dimensions()).thenReturn(dimensionsSetByUser ? results.getFirstEmbeddingSize() : null); verifyCallToServiceIntegrationValidator(results); verify(mockActionListener).onResponse(mockModel); - verify(mockModel, times(5)).getServiceSettings(); + verify(mockModel).getServiceSettings(); verify(mockServiceSettings).dimensionsSetByUser(); - verify(mockServiceSettings, times(2)).dimensions(); - verify(mockServiceSettings).similarity(); - verify(mockServiceSettings).setSimilarity(updatedSimilarity); + verify(mockServiceSettings).dimensions(); + verify(mockInferenceService).updateModelWithEmbeddingDetails(mockModel, results.getFirstEmbeddingSize()); verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockModel, mockActionListener, mockServiceSettings); } @@ -164,17 +154,4 @@ private void verifyCallToServiceIntegrationValidator(InferenceServiceResults res verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); verify(mockActionListener).delegateFailureAndWrap(any()); } - - private void verifyPostValidationException(InferenceServiceResults results, Class exceptionClass) { - doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(2); - responseListener.onResponse(results); - return null; - }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); - - assertThrows(exceptionClass, () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); }); - - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); - verify(mockActionListener).delegateFailureAndWrap(any()); - } }