Skip to content

Commit

Permalink
Adding ModelValidators to OpenAiService
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-rubinstein committed Sep 18, 2024
1 parent a244fd9 commit 82770d4
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,17 @@ default void checkModelConfig(Model model, ActionListener<Model> listener) {
listener.onResponse(model);
};

/**
* Update a text embedding model's dimensions based on a provided embedding
* size and set the default similarity if required. The default behaviour is to just return the model.
* @param model The original model without updated embedding details
* @param embeddingSize The embedding size to update the model with
* @return The model with updated embedding details
*/
default Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
return model;
}

/**
* Return true if this model is hosted in the local Elasticsearch cluster
* @return True if in cluster
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ default SimilarityMeasure similarity() {
return null;
}

default void setSimilarity(SimilarityMeasure similarity) {}

/**
* Number of dimensions the service works with. Will be null if not applicable.
*
Expand All @@ -35,6 +33,11 @@ default Integer dimensions() {
return null;
}

/**
* Boolean signifying whether the dimensions were set by the user
*
* @return boolean signifying whether the dimensions were set by the user
*/
default Boolean dimensionsSetByUser() {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class AmazonBedrockEmbeddingsServiceSettings extends AmazonBedrockService
private final Integer dimensions;
private final Boolean dimensionsSetByUser;
private final Integer maxInputTokens;
private SimilarityMeasure similarity;
private final SimilarityMeasure similarity;

public static AmazonBedrockEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
Expand Down Expand Up @@ -178,17 +178,11 @@ public SimilarityMeasure similarity() {
return similarity;
}

@Override
public void setSimilarity(SimilarityMeasure similarity) {
this.similarity = similarity;
}

@Override
public Integer dimensions() {
return dimensions;
}

@Override
public Boolean dimensionsSetByUser() {
return this.dimensionsSetByUser;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,13 @@ private AzureAiStudioEmbeddingsServiceSettings(AzureAiStudioEmbeddingCommonField
private final Integer dimensions;
private final Boolean dimensionsSetByUser;
private final Integer maxInputTokens;
private SimilarityMeasure similarity;
private final SimilarityMeasure similarity;

@Override
public SimilarityMeasure similarity() {
return similarity;
}

@Override
public void setSimilarity(SimilarityMeasure similarity) {
this.similarity = similarity;
}

@Override
public Boolean dimensionsSetByUser() {
return this.dimensionsSetByUser;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ private record CommonFields(
private final Integer dimensions;
private final Boolean dimensionsSetByUser;
private final Integer maxInputTokens;
private SimilarityMeasure similarity;
private final SimilarityMeasure similarity;
private final RateLimitSettings rateLimitSettings;

public AzureOpenAiEmbeddingsServiceSettings(
Expand Down Expand Up @@ -229,7 +229,6 @@ public Integer dimensions() {
return dimensions;
}

@Override
public Boolean dimensionsSetByUser() {
return dimensionsSetByUser;
}
Expand All @@ -243,11 +242,6 @@ public SimilarityMeasure similarity() {
return similarity;
}

@Override
public void setSimilarity(SimilarityMeasure similarity) {
this.similarity = similarity;
}

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return DenseVectorFieldMapper.ElementType.FLOAT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object

private final Integer dims;

private SimilarityMeasure similarity;
private final SimilarityMeasure similarity;
private final Integer maxInputTokens;

private final RateLimitSettings rateLimitSettings;
Expand Down Expand Up @@ -169,7 +169,6 @@ public String modelId() {
return modelId;
}

@Override
public Boolean dimensionsSetByUser() {
return dimensionsSetByUser;
}
Expand All @@ -193,11 +192,6 @@ public SimilarityMeasure similarity() {
return similarity;
}

@Override
public void setSimilarity(SimilarityMeasure similarity) {
this.similarity = similarity;
}

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return DenseVectorFieldMapper.ElementType.FLOAT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
Expand All @@ -31,10 +30,10 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -255,48 +254,32 @@ protected void doChunkedInfer(
*/
@Override
public void checkModelConfig(Model model, ActionListener<Model> listener) {
if (model instanceof OpenAiEmbeddingsModel embeddingsModel) {
ServiceUtils.getEmbeddingSize(
model,
this,
listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
);
} else {
listener.onResponse(model);
}
// TODO: Remove this function once all services have been updated to use the new model validators
ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
}

private OpenAiEmbeddingsModel updateModelWithEmbeddingDetails(OpenAiEmbeddingsModel model, int embeddingSize) {
if (model.getServiceSettings().dimensionsSetByUser()
&& model.getServiceSettings().dimensions() != null
&& model.getServiceSettings().dimensions() != embeddingSize) {
throw new ElasticsearchStatusException(
Strings.format(
"The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
+ "Please recreate the [%s] configuration with the correct dimensions",
embeddingSize,
model.getServiceSettings().dimensions(),
model.getConfigurations().getInferenceEntityId()
),
RestStatus.BAD_REQUEST
@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof OpenAiEmbeddingsModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();
var similarityFromModel = serviceSettings.similarity();
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

var updatedServiceSettings = new OpenAiEmbeddingsServiceSettings(
serviceSettings.modelId(),
serviceSettings.uri(),
serviceSettings.organizationId(),
similarityToUse,
embeddingSize,
serviceSettings.maxInputTokens(),
serviceSettings.dimensionsSetByUser(),
serviceSettings.rateLimitSettings()
);
}

var similarityFromModel = model.getServiceSettings().similarity();
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

OpenAiEmbeddingsServiceSettings serviceSettings = new OpenAiEmbeddingsServiceSettings(
model.getServiceSettings().modelId(),
model.getServiceSettings().uri(),
model.getServiceSettings().organizationId(),
similarityToUse,
embeddingSize,
model.getServiceSettings().maxInputTokens(),
model.getServiceSettings().dimensionsSetByUser(),
model.getServiceSettings().rateLimitSettings()
);

return new OpenAiEmbeddingsModel(model, serviceSettings);
return new OpenAiEmbeddingsModel(embeddingsModel, updatedServiceSettings);
} else {
throw new ElasticsearchStatusException("Cannot update model with embedding details", RestStatus.BAD_REQUEST);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ private record CommonFields(
private final URI uri;
private final String organizationId;
private SimilarityMeasure similarity;
private final Integer dimensions;
private Integer dimensions;
private final Integer maxInputTokens;
private final Boolean dimensionsSetByUser;
private final RateLimitSettings rateLimitSettings;
Expand Down Expand Up @@ -242,11 +242,6 @@ public SimilarityMeasure similarity() {
return similarity;
}

@Override
public void setSimilarity(SimilarityMeasure similarity) {
this.similarity = similarity;
}

@Override
public Integer dimensions() {
return dimensions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbedding;

import java.util.Objects;

public class TextEmbeddingModelValidator implements ModelValidator {

private final ServiceIntegrationValidator serviceIntegrationValidator;
Expand All @@ -32,48 +28,50 @@ public TextEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegratio
@Override
public void validate(InferenceService service, Model model, ActionListener<Model> listener) {
serviceIntegrationValidator.validate(service, model, listener.delegateFailureAndWrap((delegate, r) -> {
delegate.onResponse(postValidate(model, r));
delegate.onResponse(postValidate(service, model, r));
}));
}

private Model postValidate(Model model, InferenceServiceResults results) {
private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
if (results instanceof TextEmbedding embeddingResults) {
try {
if (model.getServiceSettings().dimensionsSetByUser()
&& model.getServiceSettings().dimensions() != null
&& Objects.equals(model.getServiceSettings().dimensions(), embeddingResults.getFirstEmbeddingSize()) == false) {
throw new ElasticsearchStatusException(
Strings.format(
"The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
+ "Please recreate the [%s] configuration with the correct dimensions",
embeddingResults.getFirstEmbeddingSize(),
model.getServiceSettings().dimensions(),
model.getConfigurations().getInferenceEntityId()
),
RestStatus.BAD_REQUEST
);
}
} catch (Exception e) {
var serviceSettings = model.getServiceSettings();
var dimensions = serviceSettings.dimensions();
int embeddingSize = getEmbeddingSize(embeddingResults);

if (serviceSettings.dimensionsSetByUser() && dimensions != null && dimensions.equals(embeddingSize) == false) {
throw new ElasticsearchStatusException(
"Could not determine embedding size. "
+ "Expected a result of type ["
+ InferenceTextEmbeddingFloatResults.NAME
+ "] got ["
+ results.getWriteableName()
+ "]",
Strings.format(
"The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
+ "Please recreate the [%s] configuration with the correct dimensions",
embeddingResults.getFirstEmbeddingSize(),
serviceSettings.dimensions(),
model.getInferenceEntityId()
),
RestStatus.BAD_REQUEST
);
}

var similarityFromModel = model.getServiceSettings().similarity();
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

ServiceSettings serviceSettings = model.getServiceSettings();
serviceSettings.setSimilarity(similarityToUse);

return model;
return service.updateModelWithEmbeddingDetails(model, embeddingSize);
} else {
throw new ElasticsearchStatusException("Validation call did not return text embedding response", RestStatus.BAD_REQUEST);
throw new ElasticsearchStatusException(
"Validation call did not return expected results type."
+ "Expected a result of type ["
+ InferenceTextEmbeddingFloatResults.NAME
+ "] got ["
+ (results == null ? "null" : results.getWriteableName())
+ "]",
RestStatus.BAD_REQUEST
);
}
}

private int getEmbeddingSize(TextEmbedding embeddingResults) {
int embeddingSize;
try {
embeddingSize = embeddingResults.getFirstEmbeddingSize();
} catch (Exception e) {
throw new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e);
}
return embeddingSize;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.elser.ElserModels;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.After;
Expand Down Expand Up @@ -311,7 +310,13 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists
public void testCheckModelConfig_ReturnsNewModelReference() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
try (
var service = new ElasticInferenceService(
senderFactory,
createWithEmptySettings(threadPool),
new ElasticInferenceServiceComponents(getUrl(webServer))
)
) {
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer));
PlainActionFuture<Model> listener = new PlainActionFuture<>();
service.checkModelConfig(model, listener);
Expand Down
Loading

0 comments on commit 82770d4

Please sign in to comment.