From 01c4ef792d9513cacfa45bd0a815e462f6596e53 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 22 Aug 2023 12:42:32 +0800 Subject: [PATCH] Fix UT failures Signed-off-by: zane-neo --- .../ml/common/utils/StringUtils.java | 8 +++++-- .../algorithms/remote/ConnectorUtils.java | 4 +++- .../algorithms/remote/ConnectorUtilsTest.java | 4 ++-- .../remote/HttpJsonConnectorExecutorTest.java | 23 +++++++++++-------- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 479adf7a79..5b35000807 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -18,6 +18,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,9 +45,12 @@ public static String toUTF8(String rawString) { return utf8EncodedString; } - public static Map fromJson(String jsonStr, String defaultKey) { + public static Map fromJson(String input, String defaultKey) { + if (!isJson(input)) { + return Collections.singletonMap(defaultKey, input); + } Map result; - JsonElement jsonElement = JsonParser.parseString(jsonStr); + JsonElement jsonElement = JsonParser.parseString(input); if (jsonElement.isJsonObject()) { result = GsonUtil.fromJson(jsonElement, Map.class); } else if (jsonElement.isJsonArray()) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 7b307d98fa..e8795a2e79 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -67,7 +67,9 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); inputData.getParameters().forEach((key, value) -> { - if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { + if (value == null) { + newParameters.put(key, null); + } else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { // no need to escape if it's already valid json newParameters.put(key, value); } else { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 86e1568137..98e90e668f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -175,9 +175,9 @@ public void processOutput_noPostProcessFunction_nonJsonResponse() throws IOExcep .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Map parameters = new HashMap<>(); - parameters.put("key1", "value1"); + parameters.put("input", "value1"); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); - ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, ImmutableMap.of()); + ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, parameters); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 8d04603d2a..a8290f0692 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -112,14 +112,10 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() { @Test public void executePredict_TextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; - String postprocessResult1 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[1, 2, 3]}"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; - String postprocessResult2 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[4, 5, 6]}"; when(scriptService.compile(any(), any())) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult2)); + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) @@ -127,21 +123,28 @@ public void executePredict_TextDocsInput() throws IOException { .url("http://test.com/mock") .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": \"${parameters.input}\"}") + .requestBody("{\"input\": ${parameters.input}}") .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; + String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" + + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" + + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" + + " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n" + + " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n" + + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + " }\n" + "}"; HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertArrayEquals(new Number[] {1, 2, 3}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); - Assert.assertArrayEquals(new Number[] {4, 5, 6}, modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0).getData()); + Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); + Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()); } }