-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dimensions parameter support for bedrock titan embedding v2 model #3136
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,11 +8,16 @@ | |
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; | ||
|
||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
import org.apache.commons.lang3.math.NumberUtils; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
import lombok.extern.slf4j.Slf4j; | ||
|
||
@Slf4j | ||
public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { | ||
|
||
public BedrockEmbeddingPreProcessFunction() { | ||
|
@@ -24,10 +29,23 @@ public void validate(MLInput mlInput) { | |
validateTextDocsInput(mlInput); | ||
} | ||
|
||
// Keep this method for robustness | ||
@Override | ||
public RemoteInferenceInputDataSet process(MLInput mlInput) { | ||
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); | ||
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0))); | ||
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); | ||
} | ||
|
||
@Override | ||
public RemoteInferenceInputDataSet process(Map<String, String> connectorParams, MLInput mlInput) { | ||
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); | ||
// Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html | ||
// Default dimension is 1024 | ||
int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is BedrockEmbeddingPreProcessFunction only for titan models? What about other bedrock models? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I also notice we have a multimodal embedding model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-multiemb-models.html. Is the multimodal model using There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Currently this is a Titan-specific process function, but if in the future Bedrock has another similar (similar means they have the same request body, e.g. inputText & dimensions, etc.) text embedding model, then this can be reused for that one. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In that case do we need to set the default dimension? If the dimension is provided then we will add it, otherwise not? |
||
log.error("The bedrock dimensions parameter value is: {}", dimensions); | ||
Map<String, Object> processedResult = Map | ||
.of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions)); | ||
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* | ||
* * Copyright OpenSearch Contributors | ||
* * SPDX-License-Identifier: Apache-2.0 | ||
* | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import java.util.Map; | ||
|
||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
/** | ||
* The PreProcessFunction interface defines methods for preprocessing {@link MLInput} data | ||
* before it is used for inference. It includes methods to apply preprocessing with or without | ||
* additional parameters and to validate the input data. | ||
*/ | ||
public interface PreProcessFunction { | ||
|
||
RemoteInferenceInputDataSet apply(Map<String, String> connectorParams, MLInput mlInput); | ||
|
||
RemoteInferenceInputDataSet apply(MLInput mlInput); | ||
|
||
/** | ||
* The default behavior of this method is to invoke process method with only the MLInput parameter, when the process | ||
* needs more parameters from the connector parameters, the concrete implementation should override this method. | ||
* @param connectorParams | ||
* @param mlInput | ||
* @return | ||
*/ | ||
default RemoteInferenceInputDataSet process(Map<String, String> connectorParams, MLInput mlInput) { | ||
return process(mlInput); | ||
} | ||
|
||
RemoteInferenceInputDataSet process(MLInput mlInput); | ||
|
||
void validate(MLInput mlInput); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,9 @@ | |
|
||
import org.junit.Before; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.connector.ConnectorAction; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.utils.StringUtils; | ||
|
||
|
@@ -242,6 +244,108 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro | |
} | ||
} | ||
|
||
public void test_bedrock_embedding_v2_model_with_connector_dimensions() throws Exception { | ||
// Skip test if key is null | ||
if (tokenNotSet()) { | ||
return; | ||
} | ||
String templates = Files | ||
.readString( | ||
Path | ||
.of( | ||
RestMLPredictionAction.class | ||
.getClassLoader() | ||
.getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json") | ||
.toURI() | ||
) | ||
); | ||
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class); | ||
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); | ||
String testCaseName = "with_connector_dimensions"; | ||
String modelId = registerRemoteModel( | ||
String | ||
.format( | ||
StringUtils.gson.toJson(templateMap.get("with_connector_dimensions")), | ||
GITHUB_CI_AWS_REGION, | ||
AWS_ACCESS_KEY_ID, | ||
AWS_SECRET_ACCESS_KEY, | ||
AWS_SESSION_TOKEN | ||
), | ||
bedrockEmbeddingModelName, | ||
true | ||
); | ||
|
||
List<String> input = new ArrayList<>(); | ||
input.add("Can you tell me a joke?"); | ||
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); | ||
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); | ||
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); | ||
String errorMsg = String | ||
.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); | ||
assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); | ||
List output = (List) inferenceResult.get("inference_results"); | ||
assertEquals(errorMsg, 1, output.size()); | ||
assertTrue(errorMsg, output.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List); | ||
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output"); | ||
assertEquals(errorMsg, 1, outputList.size()); | ||
assertTrue(errorMsg, outputList.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List); | ||
assertEquals(errorMsg, 512, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size()); | ||
} | ||
|
||
public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exception { | ||
// Skip test if key is null | ||
if (tokenNotSet()) { | ||
return; | ||
} | ||
String templates = Files | ||
.readString( | ||
Path | ||
.of( | ||
RestMLPredictionAction.class | ||
.getClassLoader() | ||
.getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json") | ||
.toURI() | ||
) | ||
); | ||
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class); | ||
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); | ||
String testCaseName = "with_request_dimensions"; | ||
String modelId = registerRemoteModel( | ||
String | ||
.format( | ||
StringUtils.gson.toJson(templateMap.get("with_request_dimensions")), | ||
GITHUB_CI_AWS_REGION, | ||
AWS_ACCESS_KEY_ID, | ||
AWS_SECRET_ACCESS_KEY, | ||
AWS_SESSION_TOKEN | ||
), | ||
bedrockEmbeddingModelName, | ||
true | ||
); | ||
|
||
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet | ||
.builder() | ||
.parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512")) | ||
.actionType(ConnectorAction.ActionType.PREDICT) | ||
.build(); | ||
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build(); | ||
Map inferenceResult = predictTextEmbeddingModelIgnoreFunctionName(modelId, mlInput); | ||
String errorMsg = String | ||
.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult)); | ||
assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is it showing an error here? |
||
List output = (List) inferenceResult.get("inference_results"); | ||
assertEquals(errorMsg, 1, output.size()); | ||
assertTrue(errorMsg, output.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List); | ||
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output"); | ||
assertEquals(errorMsg, 1, outputList.size()); | ||
assertTrue(errorMsg, outputList.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List); | ||
assertEquals(errorMsg, 512, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size()); | ||
} | ||
|
||
private boolean tokenNotSet() { | ||
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { | ||
log.info("#### The AWS credentials are not set. Skipping test. ####"); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we add @log4j instead to keep consistency?