Skip to content

Commit

Permalink
Add dimensions parameter support for bedrock titan embedding v2 model
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Oct 14, 2024
1 parent 6277410 commit 2d4c496
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction;

public class MLPreProcessFunction {

private static final Map<String, Function<MLInput, RemoteInferenceInputDataSet>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, PreProcessFunction> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
Expand Down Expand Up @@ -50,7 +48,7 @@ public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<MLInput, RemoteInferenceInputDataSet> get(String postProcessFunction) {
/**
 * Looks up a built-in pre-process function in the {@code PRE_PROCESS_FUNCTIONS} registry.
 *
 * @param postProcessFunction the registry key of the pre-process function to look up
 *                            (despite the parameter name, this is used as a pre-process function key)
 * @return the function stored under that key, or {@code null} if the registry has no such key
 */
public static PreProcessFunction get(String postProcessFunction) {
    final PreProcessFunction registered = PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
    return registered;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.Map;
import java.util.Optional;

import org.apache.commons.lang3.math.NumberUtils;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -24,10 +26,22 @@ public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
}

/**
 * Builds the remote inference input from the first text document only, without a
 * {@code dimensions} parameter.
 *
 * Kept for robustness: this overload is what runs when no connector parameters are available.
 *
 * @param mlInput the input whose dataset is cast to {@link TextDocsInputDataSet}
 * @return a dataset whose parameters hold the first document under {@code inputText}
 */
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
    final TextDocsInputDataSet textDocs = (TextDocsInputDataSet) mlInput.getInputDataset();
    // Only the first document is sent; any further documents in the dataset are not used here.
    final Map<String, Object> innerParameters = Map.of("inputText", textDocs.getDocs().get(0));
    final Map<String, Object> payload = Map.of("parameters", innerParameters);
    return RemoteInferenceInputDataSet
        .builder()
        .parameters(convertScriptStringToJsonString(payload))
        .build();
}

/**
 * Builds the remote inference input from the first text document plus the requested embedding
 * size, so that models accepting a {@code dimensions} field can be configured through the connector.
 *
 * The {@code dimensions} value is read from {@code connectorParams}. When the map is
 * {@code null}, the key is absent, or the value is not a parsable integer, the default of 1024
 * is used instead of failing.
 *
 * @param connectorParams connector/request parameters; may be {@code null}
 * @param mlInput the input whose dataset is cast to {@link TextDocsInputDataSet}
 * @return a dataset whose parameters hold the first document under {@code inputText} and the
 *         resolved size under {@code dimensions}
 */
@Override
public RemoteInferenceInputDataSet process(Map<String, String> connectorParams, MLInput mlInput) {
    TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
    // Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html
    // Default dimension is 1024
    final int defaultDimensions = 1024;
    // Optional.map yields an empty Optional for a null map or a missing key, so both cases
    // (and unparsable values, via NumberUtils.toInt) fall back to the default.
    int dimensions = Optional
        .ofNullable(connectorParams)
        .map(params -> params.get("dimensions"))
        .map(raw -> NumberUtils.toInt(raw, defaultDimensions))
        .orElse(defaultDimensions);
    // Only the first document is sent; any further documents in the dataset are not used here.
    Map<String, Object> processedResult = Map
        .of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions));
    return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand All @@ -29,14 +28,40 @@
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true.
*/
@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {
public abstract class ConnectorPreProcessFunction implements PreProcessFunction {

/**
* This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet.
* If this is true and the input is already of type RemoteInferenceInputDataSet, it will be returned directly, otherwise it will be processed.
*/
protected boolean returnDirectlyForRemoteInferenceInput;

/**
 * Runs this pre-process function on the given input, making the connector parameters available
 * to the processing step.
 *
 * If {@link #returnDirectlyForRemoteInferenceInput} is set and the input dataset already is a
 * {@link RemoteInferenceInputDataSet}, that dataset is returned untouched. Otherwise the input is
 * validated and then processed: with the connector parameters when they are non-null, or with the
 * input alone when they are null.
 *
 * @param connectorParams the connector parameters: including parameters defined in the connector and the parameters from request.
 *                        refer to RemoteConnectorExecutor.preparePayloadAndInvoke for details.
 * @param mlInput the MLInput object to be processed
 * @return RemoteInferenceInputDataSet resulting from the pre-processing function
 * @throws IllegalArgumentException if the input MLInput object is null
 */
@Override
public RemoteInferenceInputDataSet apply(Map<String, String> connectorParams, MLInput mlInput) {
    if (mlInput == null) {
        throw new IllegalArgumentException("Preprocess function input can't be null");
    }

    // Pass-through: an already-remote dataset skips validation and processing entirely.
    if (returnDirectlyForRemoteInferenceInput && mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
        return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
    }

    validate(mlInput);
    return connectorParams == null ? process(mlInput) : process(connectorParams, mlInput);
}

/**
* Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet.
*
Expand All @@ -57,10 +82,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {
}
}

public abstract void validate(MLInput mlInput);

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

/**
* Validates the input of a pre-process function for text documents.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import java.util.Map;

import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
 * Contract for functions that turn an {@link MLInput} into a {@link RemoteInferenceInputDataSet}
 * before it is used for inference.
 *
 * Each of {@code apply} and {@code process} comes in two flavors, with and without connector
 * parameters, and {@link #validate(MLInput)} checks the input data.
 */
public interface PreProcessFunction {

    /**
     * Validates the input data of this pre-process function.
     *
     * @param mlInput the input to validate
     */
    void validate(MLInput mlInput);

    /**
     * Processes the input without any connector parameters.
     *
     * @param mlInput the input to process
     * @return the dataset produced from the input
     */
    RemoteInferenceInputDataSet process(MLInput mlInput);

    /**
     * Processes the input with access to the connector parameters.
     *
     * The default implementation ignores {@code connectorParams} and delegates to
     * {@link #process(MLInput)}. Implementations whose processing needs values from the connector
     * parameters should override this method.
     *
     * @param connectorParams the connector parameters available to the processing step
     * @param mlInput the input to process
     * @return the dataset produced from the input
     */
    default RemoteInferenceInputDataSet process(Map<String, String> connectorParams, MLInput mlInput) {
        return process(mlInput);
    }

    /**
     * Applies the pre-process function to the input without connector parameters.
     *
     * @param mlInput the input to pre-process
     * @return the dataset resulting from the pre-processing
     */
    RemoteInferenceInputDataSet apply(MLInput mlInput);

    /**
     * Applies the pre-process function to the input with the connector parameters available.
     *
     * @param connectorParams the connector parameters available to the function
     * @param mlInput the input to pre-process
     * @return the dataset resulting from the pre-processing
     */
    RemoteInferenceInputDataSet apply(Map<String, String> connectorParams, MLInput mlInput);
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ public void setUp() {
function = new BedrockEmbeddingPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("key1", "value1", "key2", "value2", "dimensions", "1024"))
.build();

textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
Expand Down Expand Up @@ -73,4 +76,12 @@ public void process_RemoteInferenceInput() {
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
assertEquals(remoteInferenceInputDataSet, dataSet);
}

/**
 * Applying the function with connector parameters to a text-docs input must yield exactly two
 * parameters, with the supplied {@code dimensions} value carried through as a string.
 */
@Test
public void process_TextDocsInput_withConnectorParams() {
    Map<String, String> connectorParams = Map.of("dimensions", "1024");
    MLInput textDocsInput = MLInput
        .builder()
        .algorithm(FunctionName.TEXT_EMBEDDING)
        .inputDataset(textDocsInputDataSet)
        .build();

    RemoteInferenceInputDataSet processed = function.apply(connectorParams, textDocsInput);

    assertEquals(2, processed.getParameters().size());
    assertEquals("1024", processed.getParameters().get("dimensions"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
Expand All @@ -34,6 +33,7 @@
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
Expand Down Expand Up @@ -106,8 +106,8 @@ private static RemoteInferenceInputDataSet processMLInput(
} else {
preProcessFunction = fillProcessFunctionParameter(parameters, preProcessFunction);
if (MLPreProcessFunction.contains(preProcessFunction)) {
Function<MLInput, RemoteInferenceInputDataSet> function = MLPreProcessFunction.get(preProcessFunction);
return function.apply(mlInput);
PreProcessFunction function = MLPreProcessFunction.get(preProcessFunction);
return function.apply(parameters, mlInput);
} else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT)
&& Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) {
Expand Down

0 comments on commit 2d4c496

Please sign in to comment.