diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 06f476c809..419e460c95 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -30,8 +30,20 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.model.ModelTensor; -import org.opensearch.ml.common.utils.GsonUtil; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; /** * Connector defines how to connect to a remote service. @@ -108,7 +120,7 @@ static Connector createConnector(XContentParser parser) throws IOException { Map connectorMap = parser.map(); String jsonStr; try { - jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> GsonUtil.toJson(connectorMap)); + jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(connectorMap)); } catch (PrivilegedActionException e) { throw new IllegalArgumentException("wrong connector"); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 663ee6b031..9d9ba90171 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -19,7 +19,7 @@ public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; - public static final String NEURAL_SEARCH_EMBEDDING = "connector.post_process.neural_search.text_embedding"; + public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); @@ -29,25 +29,25 @@ public class MLPostProcessFunction { static { JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); - JSON_PATH_EXPRESSION.put(NEURAL_SEARCH_EMBEDDING, "$[*]"); + JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]"); POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList()); POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList()); - POST_PROCESS_FUNCTIONS.put(NEURAL_SEARCH_EMBEDDING, buildModelTensorList()); + POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList()); } public static Function>, List> buildModelTensorList() { - return numbersList -> { + return embeddings -> { List modelTensors = new ArrayList<>(); - if (numbersList == null) { - throw new IllegalArgumentException("NumbersList is null when applying build-in post process function!"); + if (embeddings == null) { + throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); } - numbersList.forEach(numbers -> modelTensors.add( + embeddings.forEach(embedding -> modelTensors.add( ModelTensor .builder() .name("sentence_embedding") .dataType(MLResultDataType.FLOAT32) - .shape(new long[]{numbers.size()}) - .data(numbers.toArray(new Number[0])) + .shape(new long[]{embedding.size()}) + .data(embedding.toArray(new Number[0])) .build() )); return modelTensors; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 23a575e860..0a41e17a9b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -16,7 +16,7 @@ public class MLPreProcessFunction { public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; - public static final String NEURAL_SEARCH_EMBEDDING_INPUT = "connector.pre_process.neural_search.text_embedding"; + public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; private static Function, Map> cohereTextEmbeddingPreProcess() { return inputs -> Map.of("parameters", Map.of("texts", inputs)); @@ -29,7 +29,7 @@ private static Function, Map> openAiTextEmbeddingPr static { PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess()); PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); - PRE_PROCESS_FUNCTIONS.put(NEURAL_SEARCH_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); } public static boolean contains(String functionName) { diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index 412e7a8e7e..da4a9ad73d 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -5,10 +5,6 @@ package org.opensearch.ml.common.input.remote; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -import java.io.IOException; -import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; @@ -17,6 +13,11 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + @org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE}) public class RemoteInferenceMLInput extends MLInput { public static final String PARAMETERS_FIELD = "parameters"; diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 5b35000807..edbd94b37f 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -5,26 +5,30 @@ package org.opensearch.ml.common.utils; +import com.google.gson.Gson; import com.google.gson.JsonElement; import com.google.gson.JsonParser; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.xcontent.XContentBuilder; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; public class StringUtils { + public static final Gson gson; + + static { + gson = new Gson(); + } + public static boolean isJson(String Json) { try { new JSONObject(Json); @@ -45,16 +49,13 @@ public static String toUTF8(String rawString) { return utf8EncodedString; } - public static Map fromJson(String input, String defaultKey) { - if (!isJson(input)) { - return Collections.singletonMap(defaultKey, input); - } + public static Map fromJson(String jsonStr, String defaultKey) { Map result; - JsonElement jsonElement = JsonParser.parseString(input); + JsonElement jsonElement = JsonParser.parseString(jsonStr); if (jsonElement.isJsonObject()) { - result = GsonUtil.fromJson(jsonElement, Map.class); + result = gson.fromJson(jsonElement, Map.class); } else if (jsonElement.isJsonArray()) { - List list = GsonUtil.fromJson(jsonElement, List.class); + List list = gson.fromJson(jsonElement, List.class); result = new HashMap<>(); result.put(defaultKey, list); } else { @@ -72,7 +73,7 @@ public static Map getParameterMap(Map parameterObjs) if (value instanceof String) { parameters.put(key, (String)value); } else { - parameters.put(key, GsonUtil.toJson(value)); + parameters.put(key, gson.toJson(value)); } return null; }); @@ -82,8 +83,4 @@ public static Map getParameterMap(Map parameterObjs) } return parameters; } - - public static String xContentBuilderToString(XContentBuilder builder) { - return BytesReference.bytes(builder).utf8ToString(); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 2582c8a6df..ab32c1580b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -32,6 +32,7 @@ import java.util.zip.ZipEntry; import java.util.zip.ZipFile; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash; import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index e8795a2e79..ac3f8a7eda 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -18,7 +18,6 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.utils.GsonUtil; import org.opensearch.script.ScriptService; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentials; @@ -40,6 +39,7 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction; @@ -87,9 +87,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat throw new IllegalArgumentException("no predict action found"); } String preProcessFunction = predictAction.get().getPreProcessFunction(); - if (preProcessFunction == null) { - throw new IllegalArgumentException("Must provide pre_process_function for predict action to process text docs input."); - } + preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; if (MLPreProcessFunction.contains(preProcessFunction)) { Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); @@ -102,7 +100,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat if (processedInput.isEmpty()) { throw new IllegalArgumentException("Wrong input"); } - Map map = GsonUtil.fromJson(processedInput.get(), Map.class); + Map map = gson.fromJson(processedInput.get(), Map.class); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); } } @@ -116,7 +114,7 @@ private static Map convertScriptStringToJsonString(Map executeBuildInPostProcessFunction(List executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { - Map result = org.opensearch.ml.common.utils.StringUtils.fromJson(resultJson, "result"); + Map result = StringUtils.fromJson(resultJson, "result"); if (postProcessFunction != null) { return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result)); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 98e90e668f..8e046b151c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -24,7 +24,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.utils.GsonUtil; import org.opensearch.script.ScriptService; import java.io.IOException; @@ -37,6 +36,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; public class ConnectorUtilsTest { @@ -60,8 +60,6 @@ public void processInput_NullInput() { @Test public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test1", "test2")).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build(); @@ -126,7 +124,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec @Test public void processInput_TextDocsInputDataSet_PreprocessFunction_OneTextDoc() { List input = Collections.singletonList("test_value"); - String inputJson = GsonUtil.toJson(input); + String inputJson = gson.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( "{\"input\": \"${parameters.input}\"}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "texts"); } @@ -136,7 +134,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() List input = new ArrayList<>(); input.add("test_value1"); input.add("test_value2"); - String inputJson = GsonUtil.toJson(input); + String inputJson = gson.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( "{\"input\": ${parameters.input}}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "input"); } @@ -166,24 +164,6 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); } - @Test - public void processOutput_noPostProcessFunction_nonJsonResponse() throws IOException { - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Map parameters = new HashMap<>(); - parameters.put("input", "value1"); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); - ModelTensors tensors = ConnectorUtils.processOutput("test response", connector, scriptService, parameters); - Assert.assertEquals(1, tensors.getMlModelTensors().size()); - Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); - Assert.assertEquals(1, tensors.getMlModelTensors().get(0).getDataAsMap().size()); - Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response")); - } - @Test public void processOutput_PostprocessFunction() throws IOException { String postprocessResult = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[1536],\"data\":[-0.014555434, -2.135904E-4, 0.0035105038]}"; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index a8290f0692..9caf621087 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -29,12 +29,15 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import java.io.IOException; import java.util.Arrays; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -94,19 +97,25 @@ public void executePredict_RemoteInferenceInput() throws IOException { } @Test - public void executePredict_TextDocsInput_NoPreprocessFunction() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); + public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") + .requestBody("{\"input\": ${parameters.input}}") .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); + Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java index 24daf5ea8c..6ca1401efd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java @@ -39,7 +39,7 @@ public void test_executePreprocessFunction() { @Test public void test_executeBuildInPostProcessFunction() { List> input = Arrays.asList(Arrays.asList(1.0f, 2.0f), Arrays.asList(3.0f, 4.0f)); - List modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.NEURAL_SEARCH_EMBEDDING)); + List modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING)); assertNotNull(modelTensors); assertEquals(2, modelTensors.size()); }