Skip to content

Commit

Permalink
Fix UT failures
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Sep 27, 2023
1 parent 9557283 commit 01c4ef7
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -44,9 +45,12 @@ public static String toUTF8(String rawString) {
return utf8EncodedString;
}

public static Map<String, Object> fromJson(String jsonStr, String defaultKey) {
public static Map<String, Object> fromJson(String input, String defaultKey) {
if (!isJson(input)) {
return Collections.singletonMap(defaultKey, input);
}
Map<String, Object> result;
JsonElement jsonElement = JsonParser.parseString(jsonStr);
JsonElement jsonElement = JsonParser.parseString(input);
if (jsonElement.isJsonObject()) {
result = GsonUtil.fromJson(jsonElement, Map.class);
} else if (jsonElement.isJsonArray()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
if (inputData.getParameters() != null) {
Map<String, String> newParameters = new HashMap<>();
inputData.getParameters().forEach((key, value) -> {
if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) {
if (value == null) {
newParameters.put(key, null);
} else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) {
// no need to escape if it's already valid json
newParameters.put(key, value);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ public void processOutput_noPostProcessFunction_nonJsonResponse() throws IOExcep
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> parameters = new HashMap<>();
parameters.put("key1", "value1");
parameters.put("input", "value1");
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build();
ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, ImmutableMap.of());
ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, parameters);
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,36 +112,39 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() {
@Test
public void executePredict_TextDocsInput() throws IOException {
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
String postprocessResult1 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[1, 2, 3]}";
String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }";
String postprocessResult2 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[4, 5, 6]}";
when(scriptService.compile(any(), any()))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult1))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult2));
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2));

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING)
.requestBody("{\"input\": \"${parameters.input}\"}")
.requestBody("{\"input\": ${parameters.input}}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
+ " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
+ " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
+ " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n"
+ " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n"
+ " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n" + " }\n" + "}";
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertArrayEquals(new Number[] {1, 2, 3}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {4, 5, 6}, modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
}
}

0 comments on commit 01c4ef7

Please sign in to comment.