Skip to content

Commit

Permalink
Add bedrock pre/post process function
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Oct 25, 2023
1 parent e15221d commit 3ed8d32
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,60 @@ public class MLPostProcessFunction {

public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";

public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

private static final Map<String, Function<List<List<Float>>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, Function<List<?>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();


static {
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildListResultModelTensors());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildListResultModelTensors());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildListResultModelTensors());
POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildSingleResultModelTensor());
}

public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorList() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
public static Function<List<?>, List<ModelTensor>> buildSingleResultModelTensor() {
return embedding -> {
if (embedding == null) {
throw new IllegalArgumentException("The embeddings is null when using the built-in post-processing function.");
}
embeddings.forEach(embedding -> modelTensors.add(
List<ModelTensor> modelTensors = new ArrayList<>();
modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
));
);
return modelTensors;
};
}

/**
 * Builds the post-process function used when the remote model returns a list of embeddings.
 *
 * <p>Every element of the incoming list is cast to a {@code List<Number>} and converted into one
 * {@code ModelTensor} named {@code sentence_embedding} with data type {@code FLOAT32} and a
 * one-dimensional shape equal to the element's size.
 *
 * @return a function mapping a list of embeddings to a list of model tensors, one per embedding;
 *         the function throws {@link IllegalArgumentException} when the list itself is null
 */
public static Function<List<?>, List<ModelTensor>> buildListResultModelTensors() {
    return rawEmbeddings -> {
        if (rawEmbeddings == null) {
            throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
        }
        List<ModelTensor> tensors = new ArrayList<>();
        for (Object rawEmbedding : rawEmbeddings) {
            // Unchecked cast: each element is expected to be a list of numbers.
            List<Number> vector = (List<Number>) rawEmbedding;
            long[] shape = new long[] { vector.size() };
            Number[] values = vector.toArray(new Number[0]);
            ModelTensor tensor = ModelTensor
                .builder()
                .name("sentence_embedding")
                .dataType(MLResultDataType.FLOAT32)
                .shape(shape)
                .data(values)
                .build();
            tensors.add(tensor);
        }
        return tensors;
    };
}
Expand All @@ -58,7 +80,7 @@ public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static Function<List<List<Float>>, List<ModelTensor>> get(String postProcessFunction) {
public static Function<List<?>, List<ModelTensor>> get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.common.connector;

import org.apache.http.impl.cookie.BasicCommentHandler;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -13,9 +15,11 @@
public class MLPreProcessFunction {

private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();

private static final Map<String, Boolean> BATCH_EMBEDDING_SUPPORT = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";

public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
Expand All @@ -26,17 +30,39 @@ private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPr
return inputs -> Map.of("parameters", Map.of("input", inputs));
}

/**
 * Builds the pre-process function that turns text docs into the Bedrock embedding request
 * parameters. The function accepts exactly one document per call.
 *
 * @return a function mapping a single-element list of texts to
 *         {@code {"parameters": {"inputText": <text>}}}; the function throws
 *         {@link IllegalArgumentException} when the list is null, does not contain exactly one
 *         element, or contains a null element
 */
private static Function<List<String>, Map<String, Object>> bedrockTextEmbeddingPreProcess() {
    return inputs -> {
        // A null list is rejected the same way as any list that does not hold exactly one text.
        if (inputs == null || inputs.size() != 1) {
            throw new IllegalArgumentException("The length of inputs is not 1 when using the bedrock pre-processing function.");
        }
        String inputText = inputs.get(0);
        if (inputText == null) {
            // Map.of rejects null values with a bare NullPointerException; report a clear error instead.
            throw new IllegalArgumentException("The input text is null when using the bedrock pre-processing function.");
        }
        return Map.of("parameters", Map.of("inputText", inputText));
    };
}

// Registers each built-in pre-process function under its public name and records a
// batch-embedding support flag for the same name.
static {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
// The default name is backed by the same function as the OpenAI name.
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess());
BATCH_EMBEDDING_SUPPORT.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, true);
BATCH_EMBEDDING_SUPPORT.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, true);
// Bedrock is the only entry flagged false; its pre-process function rejects input lists whose size is not 1.
BATCH_EMBEDDING_SUPPORT.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, false);
BATCH_EMBEDDING_SUPPORT.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, true);
}

/**
 * Reports whether a built-in pre-process function is registered under the given name.
 *
 * @param functionName the pre-process function name to look up
 * @return {@code true} if {@code PRE_PROCESS_FUNCTIONS} has an entry for {@code functionName}
 */
public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<List<String>, Map<String, Object>> get(String postProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
public static Function<List<String>, Map<String, Object>> get(String preProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(preProcessFunction);
}

/**
 * Looks up the batch-embedding support flag registered for a pre-process function.
 *
 * @param preProcessFunction the pre-process function name to look up
 * @return the flag registered for {@code preProcessFunction}, or {@code false} when no flag
 *         is registered under that name
 */
public static boolean getBatchEmbeddingSupportFlag(String preProcessFunction) {
    // Names without a registered flag fall back to false.
    return BATCH_EMBEDDING_SUPPORT.getOrDefault(preProcessFunction, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.opensearch.ml.common.output.MLOutput;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;

public class MLPostProcessFunctionTest {
Expand All @@ -40,16 +43,31 @@ public void test_getResponseFilter() {
}

@Test
public void test_buildModelTensorList() {
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList());
public void test_buildListResultModelTensors() {
Assert.assertNotNull(MLPostProcessFunction.buildListResultModelTensors());
List<List<Float>> numbersList = new ArrayList<>();
numbersList.add(Collections.singletonList(1.0f));
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList));
Assert.assertNotNull(MLPostProcessFunction.buildListResultModelTensors().apply(numbersList));
}

@Test
public void test_buildModelTensorList_exception() {
public void test_buildListResultModelTensors_exception() {
exceptionRule.expect(IllegalArgumentException.class);
MLPostProcessFunction.buildModelTensorList().apply(null);
MLPostProcessFunction.buildListResultModelTensors().apply(null);
}

@Test
public void test_buildSingleResultModelTensors() {
    // A single embedding holding one float value is enough to exercise the function.
    List<Float> singleEmbedding = Collections.singletonList(1.0f);
    // Both the factory result and the result of applying it must be non-null.
    Assert.assertNotNull(MLPostProcessFunction.buildSingleResultModelTensor());
    Assert.assertNotNull(MLPostProcessFunction.buildSingleResultModelTensor().apply(singleEmbedding));
}

@Test
public void test_buildSingleResultModelTensors_exception() {
    // Applying the function to null must fail with IllegalArgumentException and the built-in message.
    // The expectations are declared on the rule up front because nothing after the throwing call runs.
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("The embeddings is null when using the built-in post-processing function.");
    MLPostProcessFunction.buildSingleResultModelTensor().apply(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,7 @@ public class ConnectorUtils {
signer = Aws4Signer.create();
}

public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
if (mlInput == null) {
throw new IllegalArgumentException("Input is null");
}
RemoteInferenceInputDataSet inputData;
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService);
} else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
inputData = (RemoteInferenceInputDataSet)mlInput.getInputDataset();
} else {
throw new IllegalArgumentException("Wrong input type");
}
private static RemoteInferenceInputDataSet escapeInput(RemoteInferenceInputDataSet inputData) {
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().forEach((key, value) -> {
Expand All @@ -80,13 +69,15 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
}
return inputData;
}
private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map<String, String> parameters, ScriptService scriptService) {
Optional<ConnectorAction> predictAction = connector.findPredictAction();
if (predictAction.isEmpty()) {
throw new IllegalArgumentException("no predict action found");

/**
 * Extracts the remote-inference dataset carried by {@code input} and passes it through
 * {@code escapeInput}.
 *
 * @param input the ML input whose dataset is expected to be a {@code RemoteInferenceInputDataSet}
 * @return the result of {@code escapeInput} for that dataset
 * @throws IllegalArgumentException if {@code input} is null or its dataset is not a
 *         {@code RemoteInferenceInputDataSet}
 */
public static RemoteInferenceInputDataSet processRemoteInput(MLInput input) {
    if (input == null) {
        throw new IllegalArgumentException("Input is null");
    }
    Object dataset = input.getInputDataset();
    // Check the type before casting so a wrong or missing dataset fails with a clear message
    // instead of a ClassCastException or NullPointerException.
    if (!(dataset instanceof RemoteInferenceInputDataSet)) {
        throw new IllegalArgumentException("Wrong input type");
    }
    return escapeInput((RemoteInferenceInputDataSet) dataset);
}
public static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, String preProcessFunction, Map<String, String> parameters, ScriptService scriptService) {
if (inputDataSet == null) {
throw new IllegalArgumentException("Input is null");
}
String preProcessFunction = predictAction.get().getPreProcessFunction();
preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction;
preProcessFunction = Optional.ofNullable(preProcessFunction).orElse(MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT);
if (MLPreProcessFunction.contains(preProcessFunction)) {
Map<String, Object> buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(inputDataSet.getDocs());
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build();
Expand All @@ -111,7 +102,8 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat
throw new IllegalArgumentException("Wrong input");
}
Map<String, Object> map = gson.fromJson(processedInput.get(), Map.class);
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build();
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build();
return escapeInput(remoteInferenceInputDataSet);
}
}

Expand Down
Loading

0 comments on commit 3ed8d32

Please sign in to comment.