Skip to content

Commit

Permalink
Adding chunking settings to MistralService, GoogleAiStudioService, and HuggingFaceService
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-rubinstein committed Sep 26, 2024
1 parent c18c531 commit 6432cd0
Show file tree
Hide file tree
Showing 16 changed files with 986 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ public ModelConfigurations(String inferenceEntityId, TaskType taskType, String s
this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE);
}

/**
 * Convenience constructor for models that carry chunking settings but no task settings.
 * Delegates to the full constructor, substituting {@link EmptyTaskSettings#INSTANCE}
 * for the task settings.
 *
 * @param inferenceEntityId the unique id of the inference entity
 * @param taskType          the task this model performs
 * @param service           the name of the service backing this model
 * @param serviceSettings   service-specific settings for the model
 * @param chunkingSettings  settings controlling how input is chunked before inference
 */
public ModelConfigurations(
    String inferenceEntityId,
    TaskType taskType,
    String service,
    ServiceSettings serviceSettings,
    ChunkingSettings chunkingSettings
) {
    this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings);
}

public ModelConfigurations(
String inferenceEntityId,
TaskType taskType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand All @@ -23,6 +24,8 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
Expand Down Expand Up @@ -72,11 +75,19 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

GoogleAiStudioModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
Expand All @@ -98,6 +109,7 @@ private static GoogleAiStudioModel createModel(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand All @@ -118,6 +130,7 @@ private static GoogleAiStudioModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
Expand All @@ -136,11 +149,17 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
Expand All @@ -151,6 +170,7 @@ private static GoogleAiStudioModel createModelFromPersistent(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage
) {
Expand All @@ -159,6 +179,7 @@ private static GoogleAiStudioModel createModelFromPersistent(
taskType,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
failureMessage,
ConfigurationParseContext.PERSISTENT
Expand All @@ -170,11 +191,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
Expand Down Expand Up @@ -247,11 +274,22 @@ protected void doChunkedInfer(
GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model;
var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents());

var batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT,
googleAiStudioModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
}

for (var request : batchedRequests) {
var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
Expand Down Expand Up @@ -38,6 +39,7 @@ public GoogleAiStudioEmbeddingsModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secrets,
ConfigurationParseContext context
) {
Expand All @@ -47,6 +49,7 @@ public GoogleAiStudioEmbeddingsModel(
service,
GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
EmptyTaskSettings.INSTANCE,
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
Expand All @@ -62,10 +65,11 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google
String service,
GoogleAiStudioEmbeddingsServiceSettings serviceSettings,
TaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secrets
) {
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secrets),
serviceSettings
);
Expand Down Expand Up @@ -98,6 +102,29 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google
}
}

/**
 * Test-only constructor that accepts the target URI as a raw string instead of deriving
 * it from the service settings. Should only be used directly for testing.
 *
 * @param inferenceEntityId the unique id of the inference entity
 * @param taskType          the task this model performs
 * @param service           the name of the service backing this model
 * @param uri               the endpoint URI as a string; must be a valid {@link URI}
 * @param serviceSettings   embeddings service settings for the model
 * @param taskSettings      task settings for the model
 * @param chunkingSettings  settings controlling how input is chunked before inference
 * @param secrets           secret settings, may be null
 */
GoogleAiStudioEmbeddingsModel(
    String inferenceEntityId,
    TaskType taskType,
    String service,
    String uri,
    GoogleAiStudioEmbeddingsServiceSettings serviceSettings,
    TaskSettings taskSettings,
    ChunkingSettings chunkingSettings,
    @Nullable DefaultSecretSettings secrets
) {
    super(
        new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
        new ModelSecrets(secrets),
        serviceSettings
    );
    try {
        this.uri = new URI(uri);
    } catch (URISyntaxException e) {
        // Test-only path: surface a malformed URI immediately rather than forcing callers
        // to handle a checked exception.
        throw new RuntimeException(e);
    }
}

@Override
public GoogleAiStudioEmbeddingsServiceSettings getServiceSettings() {
return (GoogleAiStudioEmbeddingsServiceSettings) super.getServiceSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
Expand All @@ -27,6 +30,7 @@

import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;

Expand Down Expand Up @@ -54,10 +58,18 @@ public void parseRequestConfig(
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

var model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
Expand All @@ -82,10 +94,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
Expand All @@ -96,10 +114,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
Expand All @@ -110,6 +134,7 @@ protected abstract HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
Expand Down Expand Up @@ -47,6 +49,7 @@ protected HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand All @@ -57,6 +60,7 @@ protected HuggingFaceModel createModel(
taskType,
NAME,
serviceSettings,
chunkingSettings,
secretSettings,
context
);
Expand Down Expand Up @@ -113,11 +117,22 @@ protected void doChunkedInfer(
var huggingFaceModel = (HuggingFaceModel) model;
var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());

var batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT,
huggingFaceModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
}

for (var request : batchedRequests) {
var action = huggingFaceModel.accept(actionCreator);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -56,6 +57,7 @@ protected HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand Down
Loading

0 comments on commit 6432cd0

Please sign in to comment.