Skip to content

Commit

Permalink
Adding updateModelWithEmbeddingDetails tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
dan-rubinstein committed Nov 8, 2024
1 parent 9cbc5d3 commit f80f054
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

import java.util.ArrayList;
Expand Down Expand Up @@ -504,30 +505,34 @@ public void checkModelConfig(Model model, ActionListener<Model> listener) {

@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof CustomElandEmbeddingModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();

var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
serviceSettings.getNumAllocations(),
serviceSettings.getNumThreads(),
serviceSettings.modelId(),
serviceSettings.getAdaptiveAllocationsSettings(),
embeddingSize,
serviceSettings.similarity(),
serviceSettings.elementType()
);
if (model instanceof ElasticsearchInternalModel) {
if (model instanceof CustomElandEmbeddingModel embeddingsModel) {
var serviceSettings = embeddingsModel.getServiceSettings();

var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
serviceSettings.getNumAllocations(),
serviceSettings.getNumThreads(),
serviceSettings.modelId(),
serviceSettings.getAdaptiveAllocationsSettings(),
embeddingSize,
serviceSettings.similarity(),
serviceSettings.elementType()
);

return new CustomElandEmbeddingModel(
model.getInferenceEntityId(),
model.getTaskType(),
model.getConfigurations().getService(),
updatedServiceSettings,
model.getConfigurations().getChunkingSettings()
);
return new CustomElandEmbeddingModel(
model.getInferenceEntityId(),
model.getTaskType(),
model.getConfigurations().getService(),
updatedServiceSettings,
model.getConfigurations().getChunkingSettings()
);
} else {
// TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do
// need to update the dimensions?
return model;
}
} else {
// TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do
// need to update the dimensions?
return model;
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
Expand Down Expand Up @@ -1509,6 +1511,84 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() {
assertThat(model, is(expectedModel));
}

/**
 * A model that does not belong to this service (here an OpenAI chat completion model) must be
 * rejected with an {@link ElasticsearchStatusException}.
 */
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
    try (var internalService = createService(mock(Client.class))) {
        var foreignModel = OpenAiChatCompletionModelTests.createChatCompletionModel(
            randomAlphaOfLength(10),
            randomAlphaOfLength(10),
            randomAlphaOfLength(10),
            randomAlphaOfLength(10),
            randomAlphaOfLength(10)
        );
        // Drawn after the model is built, matching the order in which random values are consumed.
        var embeddingSize = randomNonNegativeInt();

        assertThrows(
            ElasticsearchStatusException.class,
            () -> internalService.updateModelWithEmbeddingDetails(foreignModel, embeddingSize)
        );
    }
}

/**
 * A multilingual E5 model is passed through: the result must equal the model that was supplied.
 */
public void testUpdateModelWithEmbeddingDetails_NonElandModelProvided() throws IOException {
    try (var internalService = createService(mock(Client.class))) {
        var e5Model = new MultilingualE5SmallModel(
            randomAlphaOfLength(10),
            TaskType.TEXT_EMBEDDING,
            randomAlphaOfLength(10),
            new MultilingualE5SmallInternalServiceSettings(
                randomNonNegativeInt(),
                randomNonNegativeInt(),
                randomAlphaOfLength(10),
                null
            ),
            null
        );

        assertEquals(e5Model, internalService.updateModelWithEmbeddingDetails(e5Model, randomNonNegativeInt()));
    }
}

/**
 * A custom Eland embedding model must come back as a new model whose service settings carry the
 * supplied embedding size, with every other setting copied from the original.
 */
public void testUpdateModelWithEmbeddingDetails_ElandModelProvided() throws IOException {
    try (var internalService = createService(mock(Client.class))) {
        // Random inputs are drawn in the same order as the constructor arguments consume them.
        var startingSettings = new CustomElandInternalTextEmbeddingServiceSettings(
            randomNonNegativeInt(),
            randomNonNegativeInt(),
            randomAlphaOfLength(10),
            null
        );
        var inferenceEntityId = randomAlphaOfLength(10);
        var serviceName = randomAlphaOfLength(10);
        var chunkingSettings = ChunkingSettingsTests.createRandomChunkingSettings();
        var startingModel = new CustomElandEmbeddingModel(
            inferenceEntityId,
            TaskType.TEXT_EMBEDDING,
            serviceName,
            startingSettings,
            chunkingSettings
        );

        var embeddingSize = randomNonNegativeInt();

        // Expected result: identical settings except for the embedding size.
        var wantedSettings = new CustomElandInternalTextEmbeddingServiceSettings(
            startingSettings.getNumAllocations(),
            startingSettings.getNumThreads(),
            startingSettings.modelId(),
            startingSettings.getAdaptiveAllocationsSettings(),
            embeddingSize,
            startingSettings.similarity(),
            startingSettings.elementType()
        );
        var wantedModel = new CustomElandEmbeddingModel(
            startingModel.getInferenceEntityId(),
            startingModel.getTaskType(),
            startingModel.getConfigurations().getService(),
            wantedSettings,
            startingModel.getConfigurations().getChunkingSettings()
        );

        var resultModel = internalService.updateModelWithEmbeddingDetails(startingModel, embeddingSize);
        assertEquals(wantedModel, resultModel);
    }
}

public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() {
{
assertFalse(
Expand Down

0 comments on commit f80f054

Please sign in to comment.