From 3ed8d3229eef5642a85de3e16f59b84f43c405cc Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 25 Oct 2023 14:14:58 +0800 Subject: [PATCH] Add bedrock pre/post process function Signed-off-by: zane-neo --- .../connector/MLPostProcessFunction.java | 48 ++++++++--- .../connector/MLPreProcessFunction.java | 32 ++++++- .../connector/MLPostProcessFunctionTest.java | 28 +++++-- .../algorithms/remote/ConnectorUtils.java | 30 +++---- .../remote/RemoteConnectorExecutor.java | 83 ++++++++++++------- .../engine/algorithms/remote/RemoteModel.java | 3 + .../ml/engine/utils/ScriptUtils.java | 2 +- .../algorithms/remote/ConnectorUtilsTest.java | 8 +- .../remote/HttpJsonConnectorExecutorTest.java | 26 +++++- 9 files changed, 182 insertions(+), 78 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 9d9ba90171..b62213bc0d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -18,30 +18,32 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; - + public static final String BEDROCK_EMBEDDING = "connector.post_process.bedrock.embedding"; public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); - private static final Map>, List>> POST_PROCESS_FUNCTIONS = new HashMap<>(); + private static final Map, List>> POST_PROCESS_FUNCTIONS = new HashMap<>(); static { JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); - POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList()); + JSON_PATH_EXPRESSION.put(BEDROCK_EMBEDDING, "$.embedding"); + POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildListResultModelTensors()); + POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildListResultModelTensors()); + POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildListResultModelTensors()); + POST_PROCESS_FUNCTIONS.put(BEDROCK_EMBEDDING, buildSingleResultModelTensor()); } - public static Function>, List> buildModelTensorList() { - return embeddings -> { - List modelTensors = new ArrayList<>(); - if (embeddings == null) { - throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); + public static Function, List> buildSingleResultModelTensor() { + return embedding -> { + if (embedding == null) { + throw new IllegalArgumentException("The embeddings is null when using the built-in post-processing function."); } - embeddings.forEach(embedding -> modelTensors.add( + List modelTensors = new ArrayList<>(); + modelTensors.add( ModelTensor .builder() .name("sentence_embedding") @@ -49,7 +51,27 @@ public static Function>, List> buildModelTensorLis .shape(new long[]{embedding.size()}) .data(embedding.toArray(new Number[0])) .build() - )); + ); + return modelTensors; + }; + } + + public static Function, List> buildListResultModelTensors() { + return embeddings -> { + List modelTensors = new ArrayList<>(); + if (embeddings == null) { + throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); + } + embeddings.forEach(embedding -> { + List eachEmbedding = (List) embedding; + modelTensors.add(ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[] { eachEmbedding.size() }) + .data(eachEmbedding.toArray(new Number[0])) + .build()); + }); return modelTensors; }; } @@ -58,7 +80,7 @@ public static String getResponseFilter(String postProcessFunction) { return JSON_PATH_EXPRESSION.get(postProcessFunction); } - public static Function>, List> get(String postProcessFunction) { + public static Function, List> get(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.get(postProcessFunction); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 0a41e17a9b..d4b0e5d302 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.connector; +import org.apache.http.impl.cookie.BasicCommentHandler; + import java.util.HashMap; import java.util.List; import java.util.Map; @@ -13,9 +15,11 @@ public class MLPreProcessFunction { private static final Map, Map>> PRE_PROCESS_FUNCTIONS = new HashMap<>(); + + private static final Map BATCH_EMBEDDING_SUPPORT = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; - + public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; private static Function, Map> cohereTextEmbeddingPreProcess() { @@ -26,17 +30,39 @@ private static Function, Map> openAiTextEmbeddingPr return inputs -> Map.of("parameters", Map.of("input", inputs)); } + private static Function, Map> bedrockTextEmbeddingPreProcess() { + return inputs -> { + if (inputs.size() != 1) { + throw new IllegalArgumentException("The length of inputs is not 1 when using the bedrock pre-processing function."); + } + return Map.of("parameters", Map.of("inputText", inputs.get(0))); + }; + } + static { PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess()); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockTextEmbeddingPreProcess()); + BATCH_EMBEDDING_SUPPORT.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, true); + BATCH_EMBEDDING_SUPPORT.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, true); + BATCH_EMBEDDING_SUPPORT.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, false); + BATCH_EMBEDDING_SUPPORT.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, true); } public static boolean contains(String functionName) { return PRE_PROCESS_FUNCTIONS.containsKey(functionName); } - public static Function, Map> get(String postProcessFunction) { - return PRE_PROCESS_FUNCTIONS.get(postProcessFunction); + public static Function, Map> get(String preProcessFunction) { + return PRE_PROCESS_FUNCTIONS.get(preProcessFunction); + } + + public static boolean getBatchEmbeddingSupportFlag(String preProcessFunction) { + //by default, set the batch embedding support to false. + if (!BATCH_EMBEDDING_SUPPORT.containsKey(preProcessFunction)) { + return false; + } + return BATCH_EMBEDDING_SUPPORT.get(preProcessFunction); } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index 5d4c0c88d7..58558a122a 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -9,11 +9,14 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.opensearch.ml.common.output.MLOutput; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import static org.mockito.Mockito.verify; import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; public class MLPostProcessFunctionTest { @@ -40,16 +43,31 @@ public void test_getResponseFilter() { } @Test - public void test_buildModelTensorList() { - Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList()); + public void test_buildListResultModelTensors() { + Assert.assertNotNull(MLPostProcessFunction.buildListResultModelTensors()); List> numbersList = new ArrayList<>(); numbersList.add(Collections.singletonList(1.0f)); - Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList)); + Assert.assertNotNull(MLPostProcessFunction.buildListResultModelTensors().apply(numbersList)); } @Test - public void test_buildModelTensorList_exception() { + public void test_buildListResultModelTensors_exception() { exceptionRule.expect(IllegalArgumentException.class); - MLPostProcessFunction.buildModelTensorList().apply(null); + MLPostProcessFunction.buildListResultModelTensors().apply(null); + } + + @Test + public void test_buildSingleResultModelTensors() { + Assert.assertNotNull(MLPostProcessFunction.buildSingleResultModelTensor()); + List numbersList = Collections.singletonList(1.0f); + Assert.assertNotNull(MLPostProcessFunction.buildSingleResultModelTensor().apply(numbersList)); + } + + @Test + public void test_buildSingleResultModelTensors_exception() { + exceptionRule.expect(IllegalArgumentException.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + MLPostProcessFunction.buildSingleResultModelTensor().apply(null); + verify(argumentCaptor.capture().getMessage()); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index cbfad6fd5f..5674d16c12 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -52,18 +52,7 @@ public class ConnectorUtils { signer = Aws4Signer.create(); } - public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connector connector, Map parameters, ScriptService scriptService) { - if (mlInput == null) { - throw new IllegalArgumentException("Input is null"); - } - RemoteInferenceInputDataSet inputData; - if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService); - } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { - inputData = (RemoteInferenceInputDataSet)mlInput.getInputDataset(); - } else { - throw new IllegalArgumentException("Wrong input type"); - } + private static RemoteInferenceInputDataSet escapeInput(RemoteInferenceInputDataSet inputData) { if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); inputData.getParameters().forEach((key, value) -> { @@ -80,13 +69,15 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } return inputData; } - private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map parameters, ScriptService scriptService) { - Optional predictAction = connector.findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no predict action found"); + + public static RemoteInferenceInputDataSet processRemoteInput(MLInput input) { + return escapeInput((RemoteInferenceInputDataSet) input.getInputDataset()); + } + public static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, String preProcessFunction, Map parameters, ScriptService scriptService) { + if (inputDataSet == null) { + throw new IllegalArgumentException("Input is null"); } - String preProcessFunction = predictAction.get().getPreProcessFunction(); - preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; + preProcessFunction = Optional.ofNullable(preProcessFunction).orElse(MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT); if (MLPreProcessFunction.contains(preProcessFunction)) { Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(inputDataSet.getDocs()); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); @@ -111,7 +102,8 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat throw new IllegalArgumentException("Wrong input"); } Map map = gson.fromJson(processedInput.get(), Map.class); - return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + return escapeInput(remoteInferenceInputDataSet); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 2366b150d4..231feeeb7a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -8,8 +8,9 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -22,31 +23,31 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; -import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processRemoteInput; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processTextDocsInput; public interface RemoteConnectorExecutor { default ModelTensorOutput executePredict(MLInput mlInput) { List tensorOutputs = new ArrayList<>(); - - if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - int processedDocs = 0; - while(processedDocs < textDocsInputDataSet.getDocs().size()) { - List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); - List tempTensorOutputs = new ArrayList<>(); - preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs); - int tensorCount = 0; - if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { - tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); - } - processedDocs += Math.max(tensorCount, 1); - tensorOutputs.addAll(tempTensorOutputs); - } - + Connector connector = getConnector(); + Optional predictAction = connector.findPredictAction(); + if (predictAction.isEmpty()) { + throw new IllegalArgumentException("no predict action found"); + } + Map parameters = new HashMap<>(); + if (connector.getParameters() != null) { + parameters.putAll(connector.getParameters()); + } + MLInputDataset inputDataset = mlInput.getInputDataset(); + if (inputDataset instanceof RemoteInferenceInputDataSet) { + processRemoteInputDataSetInvocation(mlInput, parameters, connector, tensorOutputs); + } else if (inputDataset instanceof TextDocsInputDataSet) { + processTextDocsInputDataSetInvocation(mlInput, parameters, connector, tensorOutputs); } else { - preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs); + throw new IllegalArgumentException("Wrong input type"); } return new ModelTensorOutput(tensorOutputs); } @@ -57,25 +58,47 @@ default void setClient(Client client){} default void setXContentRegistry(NamedXContentRegistry xContentRegistry){} default void setClusterService(ClusterService clusterService){} - default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List tensorOutputs) { - Connector connector = getConnector(); - - Map parameters = new HashMap<>(); - if (connector.getParameters() != null) { - parameters.putAll(connector.getParameters()); + private void processTextDocsInputDataSetInvocation(MLInput mlInput, Map parameters, Connector connector, List tensorOutputs) { + if (((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs() == null || ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().isEmpty()) { + throw new IllegalArgumentException("Input text docs size is empty, can not invoke remote model"); } - MLInputDataset inputDataset = mlInput.getInputDataset(); - if (inputDataset instanceof RemoteInferenceInputDataSet && ((RemoteInferenceInputDataSet) inputDataset).getParameters() != null) { - parameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); + String preProcessFunction = connector.findPredictAction() + .map(ConnectorAction::getPreProcessFunction) + .orElse(MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT); + boolean batchEmbeddingSupportFlag = MLPreProcessFunction.getBatchEmbeddingSupportFlag(preProcessFunction); + if (batchEmbeddingSupportFlag) { + RemoteInferenceInputDataSet inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), preProcessFunction, parameters, getScriptService()); + String payload = createPayload(inputData, connector, parameters); + invokeRemoteModel(mlInput, parameters, payload, tensorOutputs); + } else { + TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); + int size = Optional.ofNullable(textDocsInputDataSet).map(TextDocsInputDataSet::getDocs).map(List::size).orElse(0); + for (int i = 0; i < size; i++) { + TextDocsInputDataSet singleDocTextDocInputDataSet = + TextDocsInputDataSet.builder().docs(List.of(textDocsInputDataSet.getDocs().get(i))).build(); + RemoteInferenceInputDataSet inputData = processTextDocsInput(singleDocTextDocInputDataSet, preProcessFunction, parameters, getScriptService()); + String payload = createPayload(inputData, connector, parameters); + invokeRemoteModel(mlInput, parameters, payload, tensorOutputs); + } } + } + + private void processRemoteInputDataSetInvocation(MLInput mlInput, Map parameters, Connector connector, List tensorOutputs) { + RemoteInferenceInputDataSet remoteInferenceInputDataSet = processRemoteInput(mlInput); + if (remoteInferenceInputDataSet.getParameters() != null) { + parameters.putAll(remoteInferenceInputDataSet.getParameters()); + } + String payload = createPayload(remoteInferenceInputDataSet, connector, parameters); + invokeRemoteModel(mlInput, parameters, payload, tensorOutputs); + } - RemoteInferenceInputDataSet inputData = processInput(mlInput, connector, parameters, getScriptService()); + private String createPayload(RemoteInferenceInputDataSet inputData, Connector connector, Map parameters) { if (inputData.getParameters() != null) { parameters.putAll(inputData.getParameters()); } String payload = connector.createPredictPayload(parameters); connector.validatePayload(payload); - invokeRemoteModel(mlInput, parameters, payload, tensorOutputs); + return payload; } void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 4449ee6996..b92973ecc4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -49,6 +49,9 @@ public MLOutput predict(MLInput mlInput, MLModel model) { @Override public MLOutput predict(MLInput mlInput) { + if (mlInput == null) { + throw new IllegalArgumentException("Input is null"); + } if (!isModelReady()) { throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models//_deploy"); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java index 28053ddca6..2ac426cae1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ScriptUtils.java @@ -25,7 +25,7 @@ public static Optional executePreprocessFunction(ScriptService scriptSer return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences))); } - public static List executeBuildInPostProcessFunction(List> vectors, Function>, List> function) { + public static List executeBuildInPostProcessFunction(List vectors, Function, List> function) { return function.apply(vectors); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 8e046b151c..1bb72ec68d 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -55,7 +55,7 @@ public void setUp() { public void processInput_NullInput() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Input is null"); - ConnectorUtils.processInput(null, null, new HashMap<>(), null); + ConnectorUtils.processTextDocsInput(null, null, new HashMap<>(), null); } @Test @@ -70,7 +70,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + ConnectorUtils.processTextDocsInput(dataSet, null, new HashMap<>(), scriptService); } @Test @@ -117,7 +117,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + ConnectorUtils.processRemoteInput(mlInput); Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input")); } @@ -204,7 +204,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request Map parameters = new HashMap<>(); parameters.put("key1", "value1"); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); - RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processTextDocsInput(dataSet, preProcessName, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 11ae20c470..39b50e3fe9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -39,6 +39,9 @@ import java.io.IOException; import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -112,7 +115,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti .requestBody("{\"input\": ${parameters.input}}") .build(); when(httpClient.execute(any())).thenReturn(response); - HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + HttpEntity entity = new StringEntity("[{\"response\": \"test result1\"}, {\"response\": \"test result2\"}]"); when(response.getEntity()).thenReturn(entity); StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); @@ -121,10 +124,27 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); + // If TextDocsInputDataSet has no preprocess function, the preprocess function will be set to MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT. + // This default preprocess function will process the input into a list of strings format and for now all TextDocsInputDataSet is for text embedding + // including dense embedding and sparse embedding which both case accepts list of string format. For this input format, the result will + // always be a single MLModelOutput with a single MLModelTensor with a dataAsMap with key "response" and value is the original result from + // remote interface including a list of embeddings or a list of objects(sparse embedding). + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); - Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")); + Assert.assertEquals(2, ((List)modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")).size()); + Assert.assertEquals("test result1", + Optional.of(modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")) + .map(x -> ((List) x).get(0)) + .map(x -> ((Map) x).get("response")) + .get() + ); + Assert.assertEquals("test result2", + Optional.of(modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")) + .map(x -> ((List) x).get(1)) + .map(x -> ((Map) x).get("response")) + .get() + ); } @Test