From 955807cffcc9b5fd1391bde62bde577e889c7c23 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 14 Oct 2024 12:05:01 +0800 Subject: [PATCH 1/3] Add dimensions parameter support for bedrock titan embedding v2 model Signed-off-by: zane-neo --- .../connector/MLPreProcessFunction.java | 8 ++-- .../BedrockEmbeddingPreProcessFunction.java | 14 +++++++ .../ConnectorPreProcessFunction.java | 33 ++++++++++++--- .../preprocess/PreProcessFunction.java | 40 +++++++++++++++++++ ...edrockEmbeddingPreProcessFunctionTest.java | 13 +++++- .../algorithms/remote/ConnectorUtils.java | 6 +-- 6 files changed, 99 insertions(+), 15 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/PreProcessFunction.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 723da8c07d..ab1c575055 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -7,7 +7,6 @@ import java.util.HashMap; import java.util.Map; -import java.util.function.Function; import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; @@ -15,12 +14,11 @@ import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction; public class MLPreProcessFunction { - private static final Map> PRE_PROCESS_FUNCTIONS = new HashMap<>(); + private static final Map PRE_PROCESS_FUNCTIONS = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT = "connector.pre_process.cohere.multimodal_embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; @@ -55,7 +53,7 @@ public static boolean contains(String functionName) { return PRE_PROCESS_FUNCTIONS.containsKey(functionName); } - public static Function get(String postProcessFunction) { + public static PreProcessFunction get(String postProcessFunction) { return PRE_PROCESS_FUNCTIONS.get(postProcessFunction); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java index b6a95be042..cbc140fcc1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java @@ -8,7 +8,9 @@ import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import java.util.Map; +import java.util.Optional; +import org.apache.commons.lang3.math.NumberUtils; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -24,10 +26,22 @@ public void validate(MLInput mlInput) { validateTextDocsInput(mlInput); } + // Keep this method for robust @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); Map processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0))); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); } + + @Override + public RemoteInferenceInputDataSet process(Map connectorParams, MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + // Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html + // Default dimension is 1024 + int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024); + Map processedResult = Map + .of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions)); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index 387ac27467..e8305ab5c3 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -10,7 +10,6 @@ import java.util.Collections; import java.util.Locale; import java.util.Map; -import java.util.function.Function; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -29,7 +28,7 @@ * If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true. */ @Log4j2 -public abstract class ConnectorPreProcessFunction implements Function { +public abstract class ConnectorPreProcessFunction implements PreProcessFunction { /** * This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet. @@ -37,6 +36,32 @@ public abstract class ConnectorPreProcessFunction implements Function connectorParams, MLInput mlInput) { + if (mlInput == null) { + throw new IllegalArgumentException("Preprocess function input can't be null"); + } + if (returnDirectlyForRemoteInferenceInput && mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + } else { + validate(mlInput); + if (connectorParams != null) { + return process(connectorParams, mlInput); + } else { + return process(mlInput); + } + } + } + /** * Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet. * @@ -57,10 +82,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { } } - public abstract void validate(MLInput mlInput); - - public abstract RemoteInferenceInputDataSet process(MLInput mlInput); - /** * Validates the input of a pre-process function for text documents. * diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/PreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/PreProcessFunction.java new file mode 100644 index 0000000000..dd24d3c6a3 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/PreProcessFunction.java @@ -0,0 +1,40 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import java.util.Map; + +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +/** + * The PreProcessFunction interface defines methods for preprocessing {@link MLInput} data + * before it is used for inference. It includes methods to apply preprocessing with or without + * additional parameters and to validate the input data. + */ +public interface PreProcessFunction { + + RemoteInferenceInputDataSet apply(Map connectorParams, MLInput mlInput); + + RemoteInferenceInputDataSet apply(MLInput mlInput); + + /** + * The default behavior of this method is to invoke process method with only the MLInput parameter, when the process + * needs more parameters from the connector parameters, the concrete implementation should override this method. + * @param connectorParams + * @param mlInput + * @return + */ + default RemoteInferenceInputDataSet process(Map connectorParams, MLInput mlInput) { + return process(mlInput); + } + + RemoteInferenceInputDataSet process(MLInput mlInput); + + void validate(MLInput mlInput); +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java index 851d7eaab7..eb6e023c34 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java @@ -39,7 +39,10 @@ public void setUp() { function = new BedrockEmbeddingPreProcessFunction(); textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); - remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build(); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("key1", "value1", "key2", "value2", "dimensions", "1024")) + .build(); textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); @@ -73,4 +76,12 @@ public void process_RemoteInferenceInput() { RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); assertEquals(remoteInferenceInputDataSet, dataSet); } + + @Test + public void process_TextDocsInput_withConnectorParams() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(Map.of("dimensions", "1024"), mlInput); + assertEquals(2, dataSet.getParameters().size()); + assertEquals("1024", dataSet.getParameters().get("dimensions")); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index f2c93ef5fd..89af9ed6a2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -25,7 +25,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Function; import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; @@ -34,6 +33,7 @@ import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; @@ -106,8 +106,8 @@ private static RemoteInferenceInputDataSet processMLInput( } else { preProcessFunction = fillProcessFunctionParameter(parameters, preProcessFunction); if (MLPreProcessFunction.contains(preProcessFunction)) { - Function function = MLPreProcessFunction.get(preProcessFunction); - return function.apply(mlInput); + PreProcessFunction function = MLPreProcessFunction.get(preProcessFunction); + return function.apply(parameters, mlInput); } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT) && Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) { From ab2d736b3f4871a8c596bc8746bb9edbfbbc7f26 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 14 Oct 2024 15:30:06 +0800 Subject: [PATCH 2/3] Add UT and ITs Signed-off-by: zane-neo --- .../BedrockEmbeddingPreProcessFunction.java | 4 + ...edrockEmbeddingPreProcessFunctionTest.java | 8 ++ .../ml/rest/RestBedRockInferenceIT.java | 102 ++++++++++++++++++ .../BedRockEmbeddingV2ModelBodies.json | 66 ++++++++++++ 4 files changed, 180 insertions(+) create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java index cbc140fcc1..34b72bee97 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java @@ -15,6 +15,9 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import lombok.extern.slf4j.Slf4j; + +@Slf4j public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { public BedrockEmbeddingPreProcessFunction() { @@ -40,6 +43,7 @@ public RemoteInferenceInputDataSet process(Map connectorParams, // Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html // Default dimension is 1024 int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024); + log.error("The bedrock dimensions parameter value is: {}", dimensions); Map processedResult = Map .of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions)); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java index eb6e023c34..228baec782 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java @@ -84,4 +84,12 @@ public void process_TextDocsInput_withConnectorParams() { assertEquals(2, dataSet.getParameters().size()); assertEquals("1024", dataSet.getParameters().get("dimensions")); } + + @Test + public void process_TextDocsInput_withoutConnectorParams() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(Map.of(), mlInput); + assertEquals(2, dataSet.getParameters().size()); + assertEquals("1024", dataSet.getParameters().get("dimensions")); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 286d45d308..d8e21471d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -16,6 +16,7 @@ import org.junit.Before; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; @@ -242,6 +243,107 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro } } + public void test_bedrock_embedding_v2_model_with_connector_dimensions() throws Exception { + // Skip test if key is null + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = "with_connector_dimensions"; + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateMap.get("with_connector_dimensions")), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + List input = new ArrayList<>(); + input.add("Can you tell me a joke?"); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 512, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + + public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exception { + // Skip test if key is null + if (tokenNotSet()) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = "with_request_dimensions"; + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateMap.get("with_request_dimensions")), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512")) + .build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + String errorMsg = String + .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 1, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 512, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + private boolean tokenNotSet() { if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { log.info("#### The AWS credentials are not set. Skipping test. ####"); diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json new file mode 100644 index 0000000000..a674843b94 --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json @@ -0,0 +1,66 @@ +{ + "with_connector_dimensions": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v2:0", + "input_docs_processed_step_size": "1", + "dimensions": "512" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}}", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + + "with_request_dimensions": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v2:0", + "input_docs_processed_step_size": "1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}}", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +} From 04e1d8ec24541738e8286a6646ee3cf93e48b648 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 3 Dec 2024 20:24:08 +0800 Subject: [PATCH 3/3] fix IT failure Signed-off-by: zane-neo --- .../ml/rest/MLCommonsRestTestCase.java | 19 +++++++++++++++++++ .../ml/rest/RestBedRockInferenceIT.java | 4 +++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 2092b9f4b4..8ab548afa8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -930,6 +930,25 @@ public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOExc return parseResponseToMap(response); } + public Map predictTextEmbeddingModelIgnoreFunctionName(String modelId, MLInput mlInput) throws IOException { + Response response = null; + try { + response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/models/" + modelId + "/_predict", + null, + TestHelper.toJsonString(mlInput), + null + ); + } catch (ResponseException e) { + log.error(e.getMessage(), e); + response = e.getResponse(); + } + return parseResponseToMap(response); + } + public Consumer> verifyTextEmbeddingModelDeployed() { return (modelProfile) -> { if (modelProfile.containsKey("model_state")) { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index d8e21471d9..bf074a1073 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -15,6 +15,7 @@ import org.junit.Before; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -327,9 +328,10 @@ public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exc RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder() .parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512")) + .actionType(ConnectorAction.ActionType.PREDICT) .build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build(); - Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + Map inferenceResult = predictTextEmbeddingModelIgnoreFunctionName(modelId, mlInput); String errorMsg = String .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));