Add dimensions parameter support for bedrock titan embedding v2 model #3136

Closed
@@ -7,20 +7,18 @@

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction;

public class MLPreProcessFunction {

private static final Map<String, Function<MLInput, RemoteInferenceInputDataSet>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, PreProcessFunction> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT = "connector.pre_process.cohere.multimodal_embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
@@ -55,7 +53,7 @@ public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<MLInput, RemoteInferenceInputDataSet> get(String postProcessFunction) {
public static PreProcessFunction get(String preProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
}
}
@@ -8,11 +8,16 @@
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.Map;
import java.util.Optional;

import org.apache.commons.lang3.math.NumberUtils;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import lombok.extern.slf4j.Slf4j;

@Slf4j
Collaborator: Can we add @Log4j2 instead, to keep consistency?

public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public BedrockEmbeddingPreProcessFunction() {
@@ -24,10 +29,23 @@ public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
}

// Keep this method for robustness; it is still used when no connector parameters are provided
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0)));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}

@Override
public RemoteInferenceInputDataSet process(Map<String, String> connectorParams, MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
// Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html
// Default dimension is 1024
int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024);
Collaborator: Is BedrockEmbeddingPreProcessFunction only for titan models? What about other bedrock models?

Collaborator: I also notice we have a multimodal embedding model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-multiemb-models.html. Is the multimodal model using MultiModalConnectorPreProcessFunction? If so, can we also add dimensions for it?

Collaborator Author (@zane-neo), Dec 5, 2024:

> Is BedrockEmbeddingPreProcessFunction only for titan models? What about other bedrock models?

Currently this is a Titan-specific process function, but if Bedrock later adds a similar text embedding model (similar meaning the same request body, e.g. inputText and dimensions), this function can be reused for it.
Bedrock integrates multiple model providers (Anthropic, Cohere, Meta, etc.); currently we only have Titan and Cohere pre-process functions in our code. From the Cohere embed API (https://docs.cohere.com/reference/embed) I didn't see support for a dimensions parameter, so that's fine.
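As context for the request-body discussion above, here is a minimal sketch of the body shape this pre-process function targets, based on the Titan Text Embeddings V2 docs linked earlier (the helper name is illustrative, not part of this PR):

```java
import java.util.Map;

public class TitanV2RequestSketch {
    // Sketch of the Titan Text Embeddings V2 request body: inputText is required;
    // dimensions is optional (256, 512, or 1024; the service defaults to 1024).
    static Map<String, Object> requestBody(String text, int dimensions) {
        return Map.of("inputText", text, "dimensions", dimensions);
    }

    public static void main(String[] args) {
        System.out.println(requestBody("Can you tell me a joke?", 512));
    }
}
```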

Collaborator: In that case, do we need to set the default dimension? If a dimension is provided we add it, otherwise not?
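A minimal sketch of that suggestion (the helper name is illustrative, not part of this PR): include "dimensions" in the parameter map only when the connector supplies it, letting the service apply its own default otherwise.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class ConditionalDimensionsSketch {
    // Builds the inner parameter map, adding "dimensions" only when the connector
    // supplies it, so the service applies its own default (1024 for Titan V2)
    // when the key is absent.
    static Map<String, Object> buildParameters(String inputText, String dimensions) {
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("inputText", inputText);
        Optional.ofNullable(dimensions).ifPresent(d -> parameters.put("dimensions", Integer.parseInt(d)));
        return parameters;
    }

    public static void main(String[] args) {
        System.out.println(buildParameters("hello", "512")); // includes dimensions
        System.out.println(buildParameters("hello", null));  // omits dimensions
    }
}
```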

log.error("The bedrock dimensions parameter value is: {}", dimensions);
Map<String, Object> processedResult = Map
.of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
@@ -10,7 +10,6 @@
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -29,14 +28,40 @@
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true.
*/
@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {
public abstract class ConnectorPreProcessFunction implements PreProcessFunction {

/**
* This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet.
* If this is true and the input is already of type RemoteInferenceInputDataSet, it will be returned directly, otherwise it will be processed.
*/
protected boolean returnDirectlyForRemoteInferenceInput;

/**
* Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet.
*
* @param connectorParams the connector parameters: including parameters defined in the connector and the parameters from request.
* refer to RemoteConnectorExecutor.preparePayloadAndInvoke for details.
* @param mlInput the MLInput object to be processed
* @return RemoteInferenceInputDataSet resulting from the pre-processing function
* @throws IllegalArgumentException if the input MLInput object is null
*/
@Override
public RemoteInferenceInputDataSet apply(Map<String, String> connectorParams, MLInput mlInput) {
if (mlInput == null) {
throw new IllegalArgumentException("Preprocess function input can't be null");
}
if (returnDirectlyForRemoteInferenceInput && mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
} else {
validate(mlInput);
if (connectorParams != null) {
return process(connectorParams, mlInput);
} else {
return process(mlInput);
}
}
}

/**
* Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet.
*
@@ -57,10 +82,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {
}
}

public abstract void validate(MLInput mlInput);

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

/**
* Validates the input of a pre-process function for text documents.
*
@@ -0,0 +1,40 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import java.util.Map;

import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
* The PreProcessFunction interface defines methods for preprocessing {@link MLInput} data
* before it is used for inference. It includes methods to apply preprocessing with or without
* additional parameters and to validate the input data.
*/
public interface PreProcessFunction {

RemoteInferenceInputDataSet apply(Map<String, String> connectorParams, MLInput mlInput);

RemoteInferenceInputDataSet apply(MLInput mlInput);

/**
* The default behavior of this method is to invoke the process method with only the MLInput parameter;
* when the processing needs more parameters from the connector, the concrete implementation should
* override this method.
* @param connectorParams the connector parameters, including parameters defined in the connector and from the request
* @param mlInput the MLInput object to be processed
* @return the resulting RemoteInferenceInputDataSet
*/
default RemoteInferenceInputDataSet process(Map<String, String> connectorParams, MLInput mlInput) {
return process(mlInput);
}

RemoteInferenceInputDataSet process(MLInput mlInput);

void validate(MLInput mlInput);
}
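Outside the diff, the default-method pattern used by this interface can be sketched in isolation: the two-argument process falls back to the one-argument form unless an implementation overrides it (the names here are illustrative, not from the PR):

```java
import java.util.Map;

public class DefaultMethodSketch {
    // Mini-interface mirroring PreProcessFunction's default-method pattern:
    // the two-argument process delegates to the one-argument form unless overridden.
    interface Processor {
        default String process(Map<String, String> connectorParams, String input) {
            return process(input); // ignore the extra params by default
        }

        String process(String input);
    }

    public static void main(String[] args) {
        Processor p = input -> "processed:" + input;
        // Both calls produce the same result because the default method delegates.
        System.out.println(p.process(Map.of("dimensions", "512"), "hello")); // prints "processed:hello"
    }
}
```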
@@ -39,7 +39,10 @@ public void setUp() {
function = new BedrockEmbeddingPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("key1", "value1", "key2", "value2", "dimensions", "1024"))
.build();

textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
@@ -73,4 +76,20 @@ public void process_RemoteInferenceInput() {
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
assertEquals(remoteInferenceInputDataSet, dataSet);
}

@Test
public void process_TextDocsInput_withConnectorParams() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(Map.of("dimensions", "1024"), mlInput);
assertEquals(2, dataSet.getParameters().size());
assertEquals("1024", dataSet.getParameters().get("dimensions"));
}

@Test
public void process_TextDocsInput_withoutConnectorParams() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(Map.of(), mlInput);
assertEquals(2, dataSet.getParameters().size());
assertEquals("1024", dataSet.getParameters().get("dimensions"));
}
}
@@ -25,7 +25,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
@@ -34,6 +33,7 @@
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
@@ -106,8 +106,8 @@ private static RemoteInferenceInputDataSet processMLInput(
} else {
preProcessFunction = fillProcessFunctionParameter(parameters, preProcessFunction);
if (MLPreProcessFunction.contains(preProcessFunction)) {
Function<MLInput, RemoteInferenceInputDataSet> function = MLPreProcessFunction.get(preProcessFunction);
return function.apply(mlInput);
PreProcessFunction function = MLPreProcessFunction.get(preProcessFunction);
return function.apply(parameters, mlInput);
} else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT)
&& Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) {
@@ -930,6 +930,25 @@ public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOExc
return parseResponseToMap(response);
}

public Map predictTextEmbeddingModelIgnoreFunctionName(String modelId, MLInput mlInput) throws IOException {
Response response = null;
try {
response = TestHelper
.makeRequest(
client(),
"POST",
"/_plugins/_ml/models/" + modelId + "/_predict",
null,
TestHelper.toJsonString(mlInput),
null
);
} catch (ResponseException e) {
log.error(e.getMessage(), e);
response = e.getResponse();
}
return parseResponseToMap(response);
}

public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
return (modelProfile) -> {
if (modelProfile.containsKey("model_state")) {
@@ -15,7 +15,9 @@

import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

Expand Down Expand Up @@ -242,6 +244,108 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro
}
}

public void test_bedrock_embedding_v2_model_with_connector_dimensions() throws Exception {
// Skip test if key is null
if (tokenNotSet()) {
return;
}
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json")
.toURI()
)
);
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
String testCaseName = "with_connector_dimensions";
String modelId = registerRemoteModel(
String
.format(
StringUtils.gson.toJson(templateMap.get("with_connector_dimensions")),
GITHUB_CI_AWS_REGION,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN
),
bedrockEmbeddingModelName,
true
);

List<String> input = new ArrayList<>();
input.add("Can you tell me a joke?");
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
String errorMsg = String
.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult));
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 1, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List);
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, 512, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
}

public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exception {
// Skip test if key is null
if (tokenNotSet()) {
return;
}
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json")
.toURI()
)
);
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
String testCaseName = "with_request_dimensions";
String modelId = registerRemoteModel(
String
.format(
StringUtils.gson.toJson(templateMap.get("with_request_dimensions")),
GITHUB_CI_AWS_REGION,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN
),
bedrockEmbeddingModelName,
true
);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512"))
.actionType(ConnectorAction.ActionType.PREDICT)
.build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build();
Map inferenceResult = predictTextEmbeddingModelIgnoreFunctionName(modelId, mlInput);
String errorMsg = String
.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult));
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
Collaborator: Why is it showing an error here?

List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 1, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List);
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, 512, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
}

private boolean tokenNotSet() {
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
log.info("#### The AWS credentials are not set. Skipping test. ####");