Skip to content

Commit

Permalink
Add UT and ITs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Nov 29, 2024
1 parent 955807c commit ab2d736
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import lombok.extern.slf4j.Slf4j;

@Slf4j
public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public BedrockEmbeddingPreProcessFunction() {
Expand All @@ -40,6 +43,7 @@ public RemoteInferenceInputDataSet process(Map<String, String> connectorParams,
// Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html
// Default dimension is 1024
int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024);
log.error("The bedrock dimensions parameter value is: {}", dimensions);
Map<String, Object> processedResult = Map
.of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,12 @@ public void process_TextDocsInput_withConnectorParams() {
assertEquals(2, dataSet.getParameters().size());
assertEquals("1024", dataSet.getParameters().get("dimensions"));
}

@Test
public void process_TextDocsInput_withoutConnectorParams() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(Map.of(), mlInput);
assertEquals(2, dataSet.getParameters().size());
assertEquals("1024", dataSet.getParameters().get("dimensions"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.junit.Before;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

Expand Down Expand Up @@ -242,6 +243,107 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro
}
}

public void test_bedrock_embedding_v2_model_with_connector_dimensions() throws Exception {
// Skip test if key is null
if (tokenNotSet()) {
return;
}
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json")
.toURI()
)
);
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
String testCaseName = "with_connector_dimensions";
String modelId = registerRemoteModel(
String
.format(
StringUtils.gson.toJson(templateMap.get("with_connector_dimensions")),
GITHUB_CI_AWS_REGION,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN
),
bedrockEmbeddingModelName,
true
);

List<String> input = new ArrayList<>();
input.add("Can you tell me a joke?");
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
String errorMsg = String
.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult));
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 1, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List);
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, 512, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
}

public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exception {
// Skip test if key is null
if (tokenNotSet()) {
return;
}
String templates = Files
.readString(
Path
.of(
RestMLPredictionAction.class
.getClassLoader()
.getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json")
.toURI()
)
);
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
String testCaseName = "with_request_dimensions";
String modelId = registerRemoteModel(
String
.format(
StringUtils.gson.toJson(templateMap.get("with_request_dimensions")),
GITHUB_CI_AWS_REGION,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN
),
bedrockEmbeddingModelName,
true
);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512"))
.build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build();
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
String errorMsg = String
.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult));
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 1, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List);
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, 512, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
}

private boolean tokenNotSet() {
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
log.info("#### The AWS credentials are not set. Skipping test. ####");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
"with_connector_dimensions": {
"name": "Amazon Bedrock Connector: embedding",
"description": "The connector to bedrock Titan embedding model",
"version": 1,
"protocol": "aws_sigv4",
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v2:0",
"input_docs_processed_step_size": "1",
"dimensions": "512"
},
"credential": {
"access_key": "%s",
"secret_key": "%s",
"session_token": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"content-type": "application/json",
"x-amz-content-sha256": "required"
},
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}}",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "connector.post_process.bedrock.embedding"
}
]
},

"with_request_dimensions": {
"name": "Amazon Bedrock Connector: embedding",
"description": "The connector to bedrock Titan embedding model",
"version": 1,
"protocol": "aws_sigv4",
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v2:0",
"input_docs_processed_step_size": "1"
},
"credential": {
"access_key": "%s",
"secret_key": "%s",
"session_token": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"content-type": "application/json",
"x-amz-content-sha256": "required"
},
"request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}}",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "connector.post_process.bedrock.embedding"
}
]
}
}

0 comments on commit ab2d736

Please sign in to comment.