Skip to content

Commit

Permalink
Adds support for input_type field to Vertex inference service (elas…
Browse files Browse the repository at this point in the history
…tic#116431)

* Adding input type to google vertex ai service

* Update docs/changelog/116431.yaml

* PR feedback - backwards compatibility

* Fix lint error
  • Loading branch information
ymao1 authored Nov 12, 2024
1 parent a71c132 commit 7039a1d
Show file tree
Hide file tree
Showing 18 changed files with 697 additions and 106 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/116431.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 116431
summary: Adds support for `input_type` field to Vertex inference service
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ROLE_MONITOR_STATS = def(8_787_00_0);
public static final TransportVersion DATA_STREAM_INDEX_VERSION_DEPRECATION_CHECK = def(8_788_00_0);
public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO = def(8_789_00_0);
public static final TransportVersion VERTEX_AI_INPUT_TYPE_ADDED = def(8_790_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.action.googlevertexai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.GoogleVertexAiEmbeddingsRequestManager;
Expand All @@ -33,9 +34,10 @@ public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceCompo
}

@Override
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings) {
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, taskSettings, inputType);
var requestManager = new GoogleVertexAiEmbeddingsRequestManager(
model,
overriddenModel,
serviceComponents.truncator(),
serviceComponents.threadPool()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.external.action.googlevertexai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
Expand All @@ -15,7 +16,7 @@

public interface GoogleVertexAiActionVisitor {

ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings);
ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(model.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), model.getTaskSettings().autoTruncate()))
Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), model.getTaskSettings()))
.getBytes(StandardCharsets.UTF_8)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,35 @@

package org.elasticsearch.xpack.inference.external.request.googlevertexai;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record GoogleVertexAiEmbeddingsRequestEntity(List<String> inputs, @Nullable Boolean autoTruncation) implements ToXContentObject {
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;

public record GoogleVertexAiEmbeddingsRequestEntity(List<String> inputs, GoogleVertexAiEmbeddingsTaskSettings taskSettings)
implements
ToXContentObject {

private static final String INSTANCES_FIELD = "instances";
private static final String CONTENT_FIELD = "content";
private static final String PARAMETERS_FIELD = "parameters";
private static final String AUTO_TRUNCATE_FIELD = "autoTruncate";
private static final String TASK_TYPE_FIELD = "task_type";

private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION";
private static final String CLUSTERING_TASK_TYPE = "CLUSTERING";
private static final String RETRIEVAL_DOCUMENT_TASK_TYPE = "RETRIEVAL_DOCUMENT";
private static final String RETRIEVAL_QUERY_TASK_TYPE = "RETRIEVAL_QUERY";

public GoogleVertexAiEmbeddingsRequestEntity {
Objects.requireNonNull(inputs);
Objects.requireNonNull(taskSettings);
}

@Override
Expand All @@ -35,21 +47,38 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
{
builder.field(CONTENT_FIELD, input);

if (taskSettings.getInputType() != null) {
builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType()));
}
}
builder.endObject();
}

builder.endArray();

if (autoTruncation != null) {
if (taskSettings.autoTruncate() != null) {
builder.startObject(PARAMETERS_FIELD);
{
builder.field(AUTO_TRUNCATE_FIELD, autoTruncation);
builder.field(AUTO_TRUNCATE_FIELD, taskSettings.autoTruncate());
}
builder.endObject();
}
builder.endObject();

return builder;
}

static String convertToString(InputType inputType) {
return switch (inputType) {
case INGEST -> RETRIEVAL_DOCUMENT_TASK_TYPE;
case SEARCH -> RETRIEVAL_QUERY_TASK_TYPE;
case CLASSIFICATION -> CLASSIFICATION_TASK_TYPE;
case CLUSTERING -> CLUSTERING_TASK_TYPE;
default -> {
assert false : invalidInputTypeMessage(inputType);
yield null;
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,25 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionVisitor;

import java.net.URI;
import java.util.Map;
import java.util.Objects;

public abstract class GoogleVertexAiModel extends Model {

private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;

protected URI uri;

public GoogleVertexAiModel(
ModelConfigurations configurations,
ModelSecrets secrets,
Expand All @@ -34,13 +39,24 @@ public GoogleVertexAiModel(
public GoogleVertexAiModel(GoogleVertexAiModel model, ServiceSettings serviceSettings) {
super(model, serviceSettings);

uri = model.uri();
rateLimitServiceSettings = model.rateLimitServiceSettings();
}

public GoogleVertexAiModel(GoogleVertexAiModel model, TaskSettings taskSettings) {
super(model, taskSettings);

uri = model.uri();
rateLimitServiceSettings = model.rateLimitServiceSettings();
}

public abstract ExecutableAction accept(GoogleVertexAiActionVisitor creator, Map<String, Object> taskSettings);
public abstract ExecutableAction accept(GoogleVertexAiActionVisitor creator, Map<String, Object> taskSettings, InputType inputType);

public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

public URI uri() {
return uri;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ protected void doInfer(

var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());

var action = googleVertexAiModel.accept(actionCreator, taskSettings);
var action = googleVertexAiModel.accept(actionCreator, taskSettings, inputType);
action.execute(inputs, timeout, listener);
}

Expand All @@ -235,7 +235,7 @@ protected void doChunkedInfer(
).batchRequestsWithListeners(listener);

for (var request : batchedRequests) {
var action = googleVertexAiModel.accept(actionCreator, taskSettings);
var action = googleVertexAiModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionVisitor;
import org.elasticsearch.xpack.inference.external.request.googlevertexai.GoogleVertexAiUtils;
Expand All @@ -29,13 +31,25 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Stream;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE;
import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE;

public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel {

private URI uri;
public static GoogleVertexAiEmbeddingsModel of(
GoogleVertexAiEmbeddingsModel model,
Map<String, Object> taskSettings,
InputType inputType
) {
var requestTaskSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(taskSettings);
return new GoogleVertexAiEmbeddingsModel(
model,
GoogleVertexAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType)
);
}

public GoogleVertexAiEmbeddingsModel(
String inferenceEntityId,
Expand All @@ -62,6 +76,10 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google
super(model, serviceSettings);
}

public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, GoogleVertexAiEmbeddingsTaskSettings taskSettings) {
super(model, taskSettings);
}

// Should only be used directly for testing
GoogleVertexAiEmbeddingsModel(
String inferenceEntityId,
Expand Down Expand Up @@ -126,13 +144,9 @@ public GoogleVertexAiEmbeddingsRateLimitServiceSettings rateLimitServiceSettings
return (GoogleVertexAiEmbeddingsRateLimitServiceSettings) super.rateLimitServiceSettings();
}

public URI uri() {
return uri;
}

@Override
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
return visitor.create(this, taskSettings);
public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
return visitor.create(this, taskSettings, inputType);
}

public static URI buildUri(String location, String projectId, String modelId) throws URISyntaxException {
Expand Down Expand Up @@ -161,11 +175,32 @@ public static Map<String, SettingsConfiguration> get() {
new LazyInitializable<>(() -> {
var configurationMap = new HashMap<String, SettingsConfiguration>();

configurationMap.put(
INPUT_TYPE,
new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.DROPDOWN)
.setLabel("Input Type")
.setOrder(1)
.setRequired(false)
.setSensitive(false)
.setTooltip("Specifies the type of input passed to the model.")
.setType(SettingsConfigurationFieldType.STRING)
.setOptions(
Stream.of(
InputType.CLASSIFICATION.toString(),
InputType.CLUSTERING.toString(),
InputType.INGEST.toString(),
InputType.SEARCH.toString()
).map(v -> new SettingsConfigurationSelectOption.Builder().setLabelAndValue(v).build()).toList()
)
.setValue("")
.build()
);

configurationMap.put(
AUTO_TRUNCATE,
new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.TOGGLE)
.setLabel("Auto Truncate")
.setOrder(1)
.setOrder(2)
.setRequired(false)
.setSensitive(false)
.setTooltip("Specifies if the API truncates inputs longer than the maximum token length automatically.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,46 @@

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;

import java.util.Map;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE;
import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.VALID_REQUEST_VALUES;

public record GoogleVertexAiEmbeddingsRequestTaskSettings(@Nullable Boolean autoTruncate) {
public record GoogleVertexAiEmbeddingsRequestTaskSettings(@Nullable Boolean autoTruncate, @Nullable InputType inputType) {

public static final GoogleVertexAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsRequestTaskSettings(null);
public static final GoogleVertexAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsRequestTaskSettings(
null,
null
);

public static GoogleVertexAiEmbeddingsRequestTaskSettings fromMap(Map<String, Object> map) {
if (map.isEmpty()) {
return GoogleVertexAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS;
if (map == null || map.isEmpty()) {
return EMPTY_SETTINGS;
}

ValidationException validationException = new ValidationException();

InputType inputType = extractOptionalEnum(
map,
INPUT_TYPE,
ModelConfigurations.TASK_SETTINGS,
InputType::fromString,
VALID_REQUEST_VALUES,
validationException
);

Boolean autoTruncate = extractOptionalBoolean(map, GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, validationException);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate);
return new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate, inputType);
}

}
Loading

0 comments on commit 7039a1d

Please sign in to comment.