Skip to content

Commit

Permalink
Fix conflicts when backport
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Sep 27, 2023
1 parent 01c4ef7 commit cb83022
Show file tree
Hide file tree
Showing 11 changed files with 68 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,20 @@
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.utils.GsonUtil;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;

/**
* Connector defines how to connect to a remote service.
Expand Down Expand Up @@ -108,7 +120,7 @@ static Connector createConnector(XContentParser parser) throws IOException {
Map<String, Object> connectorMap = parser.map();
String jsonStr;
try {
jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> GsonUtil.toJson(connectorMap));
jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(connectorMap));
} catch (PrivilegedActionException e) {
throw new IllegalArgumentException("wrong connector");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class MLPostProcessFunction {
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";

public static final String NEURAL_SEARCH_EMBEDDING = "connector.post_process.neural_search.text_embedding";
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

Expand All @@ -29,25 +29,25 @@ public class MLPostProcessFunction {
static {
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(NEURAL_SEARCH_EMBEDDING, "$[*]");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(NEURAL_SEARCH_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
}

public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorList() {
return numbersList -> {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (numbersList == null) {
throw new IllegalArgumentException("NumbersList is null when applying build-in post process function!");
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
numbersList.forEach(numbers -> modelTensors.add(
embeddings.forEach(embedding -> modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{numbers.size()})
.data(numbers.toArray(new Number[0]))
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
));
return modelTensors;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class MLPreProcessFunction {
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";

public static final String NEURAL_SEARCH_EMBEDDING_INPUT = "connector.pre_process.neural_search.text_embedding";
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("texts", inputs));
Expand All @@ -29,7 +29,7 @@ private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPr
static {
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(NEURAL_SEARCH_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
}

public static boolean contains(String functionName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

package org.opensearch.ml.common.input.remote;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.Map;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
Expand All @@ -17,6 +13,11 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
public class RemoteInferenceMLInput extends MLInput {
public static final String PARAMETERS_FIELD = "parameters";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,30 @@

package org.opensearch.ml.common.utils;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonParser;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class StringUtils {

public static final Gson gson;

static {
gson = new Gson();
}

public static boolean isJson(String Json) {
try {
new JSONObject(Json);
Expand All @@ -45,16 +49,13 @@ public static String toUTF8(String rawString) {
return utf8EncodedString;
}

public static Map<String, Object> fromJson(String input, String defaultKey) {
if (!isJson(input)) {
return Collections.singletonMap(defaultKey, input);
}
public static Map<String, Object> fromJson(String jsonStr, String defaultKey) {
Map<String, Object> result;
JsonElement jsonElement = JsonParser.parseString(input);
JsonElement jsonElement = JsonParser.parseString(jsonStr);
if (jsonElement.isJsonObject()) {
result = GsonUtil.fromJson(jsonElement, Map.class);
result = gson.fromJson(jsonElement, Map.class);
} else if (jsonElement.isJsonArray()) {
List<Object> list = GsonUtil.fromJson(jsonElement, List.class);
List<Object> list = gson.fromJson(jsonElement, List.class);
result = new HashMap<>();
result.put(defaultKey, list);
} else {
Expand All @@ -72,7 +73,7 @@ public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs)
if (value instanceof String) {
parameters.put(key, (String)value);
} else {
parameters.put(key, GsonUtil.toJson(value));
parameters.put(key, gson.toJson(value));
}
return null;
});
Expand All @@ -82,8 +83,4 @@ public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs)
}
return parameters;
}

/**
 * Converts the bytes held by an {@link XContentBuilder} into a string.
 *
 * @param builder the builder whose bytes are read via {@code BytesReference.bytes}
 * @return the builder's bytes decoded with {@code utf8ToString()}
 */
public static String xContentBuilderToString(XContentBuilder builder) {
    BytesReference content = BytesReference.bytes(builder);
    return content.utf8ToString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash;
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.GsonUtil;
import org.opensearch.script.ScriptService;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
Expand All @@ -40,6 +39,7 @@

import static org.apache.commons.text.StringEscapeUtils.escapeJson;
import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction;
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;
import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction;
Expand Down Expand Up @@ -87,9 +87,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat
throw new IllegalArgumentException("no predict action found");
}
String preProcessFunction = predictAction.get().getPreProcessFunction();
if (preProcessFunction == null) {
throw new IllegalArgumentException("Must provide pre_process_function for predict action to process text docs input.");
}
preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction;
if (MLPreProcessFunction.contains(preProcessFunction)) {
Map<String, Object> buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs);
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build();
Expand All @@ -102,7 +100,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat
if (processedInput.isEmpty()) {
throw new IllegalArgumentException("Wrong input");
}
Map<String, Object> map = GsonUtil.fromJson(processedInput.get(), Map.class);
Map<String, Object> map = gson.fromJson(processedInput.get(), Map.class);
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build();
}
}
Expand All @@ -116,7 +114,7 @@ private static Map<String, String> convertScriptStringToJsonString(Map<String, O
if (parametersMap.get(key) instanceof String) {
parameterStringMap.put(key, (String) parametersMap.get(key));
} else {
parameterStringMap.put(key, GsonUtil.toJson(parametersMap.get(key)));
parameterStringMap.put(key, gson.toJson(parametersMap.get(key)));
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.google.common.collect.ImmutableMap;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptService;
import org.opensearch.script.ScriptType;
Expand All @@ -29,7 +30,7 @@ public static List<ModelTensor> executeBuildInPostProcessFunction(List<List<Floa
}

public static Optional<String> executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) {
Map<String, Object> result = org.opensearch.ml.common.utils.StringUtils.fromJson(resultJson, "result");
Map<String, Object> result = StringUtils.fromJson(resultJson, "result");
if (postProcessFunction != null) {
return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.GsonUtil;
import org.opensearch.script.ScriptService;

import java.io.IOException;
Expand All @@ -37,6 +36,7 @@

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.utils.StringUtils.gson;

public class ConnectorUtilsTest {

Expand All @@ -60,8 +60,6 @@ public void processInput_NullInput() {

@Test
public void processInput_TextDocsInputDataSet_NoPreprocessFunction() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input.");
TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test1", "test2")).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build();

Expand Down Expand Up @@ -126,7 +124,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec
@Test
public void processInput_TextDocsInputDataSet_PreprocessFunction_OneTextDoc() {
List<String> input = Collections.singletonList("test_value");
String inputJson = GsonUtil.toJson(input);
String inputJson = gson.toJson(input);
processInput_TextDocsInputDataSet_PreprocessFunction(
"{\"input\": \"${parameters.input}\"}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "texts");
}
Expand All @@ -136,7 +134,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc()
List<String> input = new ArrayList<>();
input.add("test_value1");
input.add("test_value2");
String inputJson = GsonUtil.toJson(input);
String inputJson = gson.toJson(input);
processInput_TextDocsInputDataSet_PreprocessFunction(
"{\"input\": ${parameters.input}}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "input");
}
Expand Down Expand Up @@ -166,24 +164,6 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio
Assert.assertEquals("test response", tensors.getMlModelTensors().get(0).getDataAsMap().get("response"));
}

/**
 * With no post-process function on the predict action, a plain-text (non-JSON)
 * response is expected to come back as exactly one tensor named "response"
 * whose data map has a single entry mapping "response" to the raw text.
 */
@Test
public void processOutput_noPostProcessFunction_nonJsonResponse() throws IOException {
    // Connector-level parameters, also passed to processOutput below.
    Map<String, String> connectorParams = new HashMap<>();
    connectorParams.put("input", "value1");

    ConnectorAction action = ConnectorAction.builder()
        .actionType(ConnectorAction.ActionType.PREDICT)
        .method("POST")
        .url("http://test.com/mock")
        .requestBody("{\"input\": \"${parameters.input}\"}")
        .build();
    Connector httpConnector = HttpConnector.builder()
        .name("test connector")
        .version("1")
        .protocol("http")
        .parameters(connectorParams)
        .actions(Arrays.asList(action))
        .build();

    // The raw response is deliberately not JSON.
    ModelTensors output = ConnectorUtils.processOutput("test response", httpConnector, scriptService, connectorParams);

    Assert.assertEquals(1, output.getMlModelTensors().size());
    Assert.assertEquals("response", output.getMlModelTensors().get(0).getName());
    Assert.assertEquals(1, output.getMlModelTensors().get(0).getDataAsMap().size());
    Assert.assertEquals("test response", output.getMlModelTensors().get(0).getDataAsMap().get("response"));
}

@Test
public void processOutput_PostprocessFunction() throws IOException {
String postprocessResult = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[1536],\"data\":[-0.014555434, -2.135904E-4, 0.0035105038]}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import java.io.IOException;
import java.util.Arrays;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -94,19 +97,25 @@ public void executePredict_RemoteInferenceInput() throws IOException {
}

@Test
public void executePredict_TextDocsInput_NoPreprocessFunction() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input.");
public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.requestBody("{\"input\": ${parameters.input}}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void test_executePreprocessFunction() {
@Test
public void test_executeBuildInPostProcessFunction() {
List<List<Float>> input = Arrays.asList(Arrays.asList(1.0f, 2.0f), Arrays.asList(3.0f, 4.0f));
List<ModelTensor> modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.NEURAL_SEARCH_EMBEDDING));
List<ModelTensor> modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING));
assertNotNull(modelTensors);
assertEquals(2, modelTensors.size());
}
Expand Down

0 comments on commit cb83022

Please sign in to comment.