From 8780d1cbd28b0889a69b9b0f9ece2ff4d98104ef Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 26 Sep 2023 10:08:32 +0800 Subject: [PATCH] Rebase main code to feature branch (#1386) * Add Auto Release Workflow (#1306) * Add Auto Release Workflow Signed-off-by: Sicheng Song * Fix release note address Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song * Bump aws-encryption-sdk-java to fix CVE-2023-33201 (#1309) Signed-off-by: Sicheng Song * Add release note for 2.10.0 release (#1312) * Add release note for 2.10.0 Signed-off-by: Sicheng Song * Add CVE fix Signed-off-by: Sicheng Song --------- Signed-off-by: Sicheng Song * fixing doc link (#1318) * fixing doc link Signed-off-by: Dhrubo Saha * fixing indentation Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha * Fix unassigned ml system shard replicas (#1315) (#1324) * Fix unassigned ml system shard replicas * Adjust auto replica settings to keep it consistent with AOS default setting * Update plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java * Modify exception handling * Modify exception messages * Add response check * Add response check and exception handling * Keep error message consistent * Keep error message consistent * Keep error message consistent --------- Signed-off-by: Sicheng Song Co-authored-by: Yaliang Wu * Adjust index replicas settings to keep consistent with AOS 2.9 (#1325) Signed-off-by: Sicheng Song * Make 2.10 release notes up to date (#1345) Signed-off-by: Sicheng Song * fix spelling (#1363) Signed-off-by: Kalyan * Add neural search default processor for non OpenAI/Cohere scenario (#1274) * Add neural search default pre/post process function support Signed-off-by: zane-neo * Fix UT failures Signed-off-by: zane-neo * Address PR comment to remove nonJson response case Signed-off-by: zane-neo * Fix low code coverage issue Signed-off-by: zane-neo * fix format issue Signed-off-by: zane-neo * Try to fix classNotFound issue in IT Signed-off-by: zane-neo * 
revert Try to fix classNotFound issue in IT Signed-off-by: zane-neo * Change gson dependency to compileOnly Signed-off-by: zane-neo * Change default pre/post process function name Signed-off-by: zane-neo * Address code review comments Signed-off-by: zane-neo * Make preprocess function to default Signed-off-by: zane-neo * Remove GsonUtil since there already a single instance in StringUtils Signed-off-by: zane-neo * Fix UT failures Signed-off-by: zane-neo * Address comments Signed-off-by: zane-neo * use import instead of fully qualified name Signed-off-by: zane-neo --------- Signed-off-by: zane-neo --------- Signed-off-by: Sicheng Song Signed-off-by: Dhrubo Saha Signed-off-by: Kalyan Signed-off-by: zane-neo Co-authored-by: Sicheng Song Co-authored-by: Dhrubo Saha Co-authored-by: Yaliang Wu Co-authored-by: Kalyan --- .github/workflows/auto-release.yml | 28 ++++ README.md | 2 +- .../org/opensearch/ml/common/CommonValue.java | 10 +- .../ml/common/connector/Connector.java | 22 +-- .../connector/MLPostProcessFunction.java | 85 +++++------ .../connector/MLPreProcessFunction.java | 41 +++--- .../input/remote/RemoteInferenceMLInput.java | 5 - .../ml/common/utils/StringUtils.java | 1 + .../connector/MLPostProcessFunctionTest.java | 29 ++++ .../text_embedding_model_examples.md | 62 ++++---- docs/tutorials/remote_inference.md | 4 +- ml-algorithms/build.gradle | 2 +- .../org/opensearch/ml/engine/ModelHelper.java | 4 +- .../algorithms/remote/ConnectorUtils.java | 135 ++++++++++-------- .../remote/RemoteConnectorExecutor.java | 12 +- .../ml/engine/utils/ScriptUtils.java | 33 ++--- .../remote/AwsConnectorExecutorTest.java | 34 +++++ .../algorithms/remote/ConnectorUtilsTest.java | 32 +++-- .../remote/HttpJsonConnectorExecutorTest.java | 44 +++--- .../ml/engine/utils/ScriptUtilsTest.java | 59 ++++++++ .../ml/indices/MLIndicesHandler.java | 32 +++-- ...search-ml-common.release-notes-2.10.0.0.md | 49 +++++++ 22 files changed, 473 insertions(+), 252 deletions(-) create mode 100644 
.github/workflows/auto-release.yml create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java create mode 100644 release-notes/opensearch-ml-common.release-notes-2.10.0.0.md diff --git a/.github/workflows/auto-release.yml b/.github/workflows/auto-release.yml new file mode 100644 index 0000000000..214283feeb --- /dev/null +++ b/.github/workflows/auto-release.yml @@ -0,0 +1,28 @@ +name: Releases + +on: + push: + tags: + - '*' + +jobs: + build: + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: GitHub App token + id: github_app_token + uses: tibdex/github-app-token@v1.5.0 + with: + app_id: ${{ secrets.APP_ID }} + private_key: ${{ secrets.APP_PRIVATE_KEY }} + installation_id: 22958780 + - name: Get tag + id: tag + uses: dawidd6/action-get-tag@v1 + - uses: actions/checkout@v2 + - uses: ncipollo/release-action@v1 + with: + github_token: ${{ steps.github_app_token.outputs.token }} + bodyFile: release-notes/opensearch-ml-common.release-notes-${{steps.tag.outputs.tag}}.md \ No newline at end of file diff --git a/README.md b/README.md index 8e009d272f..2136b3fbed 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Machine Learning Commons for OpenSearch is a new solution that make it easy to d Until today, the challenge is significant to build a new machine learning feature inside OpenSearch. The reasons include: * **Disruption to OpenSearch Core features**. Machine learning is very computationally intensive. But currently there is no way to add dedicated computation resources in OpenSearch for machine learning jobs, hence these jobs have to share same resources with Core features, such as: indexing and searching. That might cause the latency increasing on search request, and cause circuit breaker exception on memory usage. To address this, we have to carefully distribute models and limit the data size to run the AD job. 
When more and more ML features are added into OpenSearch, it will become much harder to manage. -* **Lack of support for machine learning algorithms.** Customers need more algorighms within Opensearch, otherwise the data need be exported to outside of elasticsearch, such as s3 first to do the job, which will bring extra cost and latency. +* **Lack of support for machine learning algorithms.** Customers need more algorithms within Opensearch, otherwise the data need be exported to outside of elasticsearch, such as s3 first to do the job, which will bring extra cost and latency. * **Lack of resource management mechanism between multiple machine learning jobs.** It's hard to coordinate the resources between multi features. diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 16554933b5..dab60ce986 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -35,13 +35,13 @@ public class CommonValue { public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group"; public static final String ML_MODEL_INDEX = ".plugins-ml-model"; public static final String ML_TASK_INDEX = ".plugins-ml-task"; - public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 1; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 6; + public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 7; public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; - public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1; - public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1; + public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2; public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; - public static final Integer 
ML_CONFIG_INDEX_SCHEMA_VERSION = 1; + public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; public static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER + "\": {\n" diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 81278ceb89..b3f9aafad8 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -5,6 +5,17 @@ package org.opensearch.ml.common.connector; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -20,17 +31,6 @@ import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.model.ModelTensor; -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.gson; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 662db37341..9d9ba90171 100644 --- 
a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -5,61 +5,64 @@ package org.opensearch.ml.common.connector; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; public class MLPostProcessFunction { - private static Map POST_PROCESS_FUNCTIONS; public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; + public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; + + private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); + + private static final Map>, List>> POST_PROCESS_FUNCTIONS = new HashMap<>(); + + static { - POST_PROCESS_FUNCTIONS = new HashMap<>(); - POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, "\n def name = \"sentence_embedding\";\n" + - " def dataType = \"FLOAT32\";\n" + - " if (params.embeddings == null || params.embeddings.length == 0) {\n" + - " return null;\n" + - " }\n" + - " def embeddings = params.embeddings;\n" + - " StringBuilder builder = new StringBuilder(\"[\");\n" + - " for (int i=0; i>, List> buildModelTensorList() { + return embeddings -> { + List modelTensors = new ArrayList<>(); + if (embeddings == null) { + throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); + } + embeddings.forEach(embedding -> modelTensors.add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[]{embedding.size()}) + .data(embedding.toArray(new Number[0])) + .build() + )); + return modelTensors; + }; } - public static boolean contains(String functionName) 
{ - return POST_PROCESS_FUNCTIONS.containsKey(functionName); + public static String getResponseFilter(String postProcessFunction) { + return JSON_PATH_EXPRESSION.get(postProcessFunction); } - public static String get(String postProcessFunction) { + public static Function>, List> get(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.get(postProcessFunction); } + + public static boolean contains(String postProcessFunction) { + return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index fdbfb52d0f..0a41e17a9b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -6,44 +6,37 @@ package org.opensearch.ml.common.connector; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; public class MLPreProcessFunction { - private static Map PRE_PROCESS_FUNCTIONS; + private static final Map, Map>> PRE_PROCESS_FUNCTIONS = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; + public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; + + private static Function, Map> cohereTextEmbeddingPreProcess() { + return inputs -> Map.of("parameters", Map.of("texts", inputs)); + } + + private static Function, Map> openAiTextEmbeddingPreProcess() { + return inputs -> Map.of("parameters", Map.of("input", inputs)); + } + static { - PRE_PROCESS_FUNCTIONS = new HashMap<>(); - //TODO: change to java for openAI, embedding and Titan - 
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" + - " builder.append(\"[\");\n" + - " for (int i=0; i< params.text_docs.length; i++) {\n" + - " builder.append(\"\\\"\");\n" + - " builder.append(params.text_docs[i]);\n" + - " builder.append(\"\\\"\");\n" + - " if (i < params.text_docs.length - 1) {\n" + - " builder.append(\",\")\n" + - " }\n" + - " }\n" + - " builder.append(\"]\");\n" + - " def parameters = \"{\" +\"\\\"prompt\\\":\" + builder + \"}\";\n" + - " return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"); - - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" + - " builder.append(\"\\\"\");\n" + - " builder.append(params.text_docs[0]);\n" + - " builder.append(\"\\\"\");\n" + - " def parameters = \"{\" +\"\\\"input\\\":\" + builder + \"}\";\n" + - " return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); } public static boolean contains(String functionName) { return PRE_PROCESS_FUNCTIONS.containsKey(functionName); } - public static String get(String postProcessFunction) { + public static Function, Map> get(String postProcessFunction) { return PRE_PROCESS_FUNCTIONS.get(postProcessFunction); } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index 5992e77a24..da4a9ad73d 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -14,14 +14,9 @@ import 
org.opensearch.ml.common.utils.StringUtils; import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.HashMap; import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.gson; @org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE}) public class RemoteInferenceMLInput extends MLInput { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 968cda1575..edbd94b37f 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -24,6 +24,7 @@ public class StringUtils { public static final Gson gson; + static { gson = new Gson(); } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index 346d5901a8..5d4c0c88d7 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -6,12 +6,21 @@ package org.opensearch.ml.common.connector; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; public class MLPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + @Test public void contains() { Assert.assertTrue(MLPostProcessFunction.contains(OPENAI_EMBEDDING)); @@ -23,4 +32,24 @@ 
public void get() { Assert.assertNotNull(MLPostProcessFunction.get(OPENAI_EMBEDDING)); Assert.assertNull(MLPostProcessFunction.get("wrong value")); } + + @Test + public void test_getResponseFilter() { + assert null != MLPostProcessFunction.getResponseFilter(OPENAI_EMBEDDING); + assert null == MLPostProcessFunction.getResponseFilter("wrong value"); + } + + @Test + public void test_buildModelTensorList() { + Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList()); + List> numbersList = new ArrayList<>(); + numbersList.add(Collections.singletonList(1.0f)); + Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList)); + } + + @Test + public void test_buildModelTensorList_exception() { + exceptionRule.expect(IllegalArgumentException.class); + MLPostProcessFunction.buildModelTensorList().apply(null); + } } diff --git a/docs/model_serving_framework/text_embedding_model_examples.md b/docs/model_serving_framework/text_embedding_model_examples.md index ee04e11ae8..5c77f2a584 100644 --- a/docs/model_serving_framework/text_embedding_model_examples.md +++ b/docs/model_serving_framework/text_embedding_model_examples.md @@ -79,7 +79,7 @@ POST /_plugins/_ml/models/_register { "name": "huggingface/sentence-transformers/all-MiniLM-L12-v2", "version": "1.0.1", - "model_format": "TORCH_SCRIPT" + "model_format": "TORCH_SCRIPT", "model_group_id": "7IjOsYgBFp6IJxCceZ1-" } @@ -129,12 +129,12 @@ Now we can register model using URL upload: POST /_plugins/_ml/models/_register { "name": "sentence-transformers/all-MiniLM-L6-v2", - "version": "1.0.1", - "description": "This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.", - "model_task_type": "TEXT_EMBEDDING", - "model_format": "TORCH_SCRIPT", - "model_content_hash_value": "c15f0d2e62d872be5b5bc6c84d2e0f4921541e29fefbef51d59cc10a8ae30e0f", - "model_config": { + "version": "1.0.1", + 
"description": "This is a sentence-transformers model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.", + "model_task_type": "TEXT_EMBEDDING", + "model_format": "TORCH_SCRIPT", + "model_content_hash_value": "c15f0d2e62d872be5b5bc6c84d2e0f4921541e29fefbef51d59cc10a8ae30e0f", + "model_config": { "model_type": "bert", "embedding_dimension": 384, "framework_type": "sentence_transformers", @@ -303,18 +303,18 @@ The only difference is the uploading model input, for load/predict/profile/unloa # Sample request POST /_plugins/_ml/models/_upload { - "name": "sentence-transformers/all-MiniLM-L6-v2", - "version": "1.0.0", - "description": "test model", - "model_format": "TORCH_SCRIPT", - "model_config": { - "model_type": "bert", - "embedding_dimension": 384, - "framework_type": "huggingface_transformers", - "pooling_mode":"mean", - "normalize_result":"true" - }, - "url": "https://github.com/opensearch-project/ml-commons/raw/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_torchscript_huggingface.zip?raw=true" + "name": "sentence-transformers/all-MiniLM-L6-v2", + "version": "1.0.0", + "description": "test model", + "model_format": "TORCH_SCRIPT", + "model_config": { + "model_type": "bert", + "embedding_dimension": 384, + "framework_type": "huggingface_transformers", + "pooling_mode":"mean", + "normalize_result":"true" + }, + "url": "https://github.com/opensearch-project/ml-commons/raw/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_torchscript_huggingface.zip?raw=true" } ``` @@ -330,17 +330,17 @@ The only difference is the uploading model input, for load/predict/profile/unloa # Sample request POST /_plugins/_ml/models/_upload { - "name": "sentence-transformers/all-MiniLM-L6-v2", - "version": "1.0.0", - "description": "test model", - "model_format": "ONNX", - "model_config": { - 
"model_type": "bert", - "embedding_dimension": 384, - "framework_type": "huggingface_transformers", - "pooling_mode":"mean", - "normalize_result":"true" - }, - "url": "https://github.com/opensearch-project/ml-commons/raw/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_onnx.zip?raw=true" + "name": "sentence-transformers/all-MiniLM-L6-v2", + "version": "1.0.0", + "description": "test model", + "model_format": "ONNX", + "model_config": { + "model_type": "bert", + "embedding_dimension": 384, + "framework_type": "huggingface_transformers", + "pooling_mode":"mean", + "normalize_result":"true" + }, + "url": "https://github.com/opensearch-project/ml-commons/raw/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_onnx.zip?raw=true" } ``` diff --git a/docs/tutorials/remote_inference.md b/docs/tutorials/remote_inference.md index 8cd653b3a9..8120c90cdc 100644 --- a/docs/tutorials/remote_inference.md +++ b/docs/tutorials/remote_inference.md @@ -54,7 +54,7 @@ That means if user have access to a model group, they have access to all model v # Connector -A connector defines protocol between ml-commons and external ML service. Read this [connector doc](https://opensearch.org/docs/latest/ml-commons-plugin/connectors/) for more details. +A connector defines protocol between ml-commons and external ML service. Read this [connector doc](https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/) for more details. 
A connector consists of 7 parts @@ -473,4 +473,4 @@ POST _plugins/_ml/models/AFeTb4kBJ1eYAeTMSVl0/_predict ] } } -``` \ No newline at end of file +``` diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 98fc9d5909..6108da8190 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -60,7 +60,7 @@ dependencies { implementation platform('software.amazon.awssdk:bom:2.20.19') implementation 'software.amazon.awssdk:auth' implementation 'software.amazon.awssdk:apache-client' - implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.0' + implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1' implementation 'com.jayway.jsonpath:json-path:2.8.0' implementation group: 'org.json', name: 'json', version: '20230227' } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 28c9dc20d5..b712b66975 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -7,7 +7,6 @@ import ai.djl.training.util.DownloadUtils; import ai.djl.training.util.ProgressBar; -import com.google.gson.Gson; import com.google.gson.stream.JsonReader; import lombok.extern.log4j.Log4j2; import org.opensearch.core.action.ActionListener; @@ -32,6 +31,7 @@ import java.util.zip.ZipEntry; import java.util.zip.ZipFile; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash; import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks; @@ -48,11 +48,9 @@ public class ModelHelper { public static final String PYTORCH_ENGINE = "PyTorch"; public static final String ONNX_ENGINE = "OnnxRuntime"; private final MLEngine mlEngine; - private Gson gson; public ModelHelper(MLEngine mlEngine) { this.mlEngine = 
mlEngine; - gson = new Gson(); } public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput registerModelInput, ActionListener listener) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index cd3038f49c..ac3f8a7eda 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -5,17 +5,19 @@ package org.opensearch.ml.engine.algorithms.remote; -import com.google.common.collect.ImmutableMap; import com.jayway.jsonpath.JsonPath; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.script.ScriptService; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentials; @@ -37,10 +39,12 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; -import static org.opensearch.ml.engine.utils.ScriptUtils.executePostprocessFunction; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static 
org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction; +import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction; -import static org.opensearch.ml.engine.utils.ScriptUtils.gson; +@Log4j2 public class ConnectorUtils { private static final Aws4Signer signer; @@ -54,43 +58,7 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } RemoteInferenceInputDataSet inputData; if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - TextDocsInputDataSet inputDataSet = (TextDocsInputDataSet)mlInput.getInputDataset(); - List docs = new ArrayList<>(inputDataSet.getDocs()); - Map params = ImmutableMap.of("text_docs", docs); - Optional predictAction = connector.findPredictAction(); - if (!predictAction.isPresent()) { - throw new IllegalArgumentException("no predict action found"); - } - String preProcessFunction = predictAction.get().getPreProcessFunction(); - if (preProcessFunction == null) { - throw new IllegalArgumentException("Must provide pre_process_function for predict action to process text docs input."); - } - if (preProcessFunction != null && preProcessFunction.contains("${parameters")) { - StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); - preProcessFunction = substitutor.replace(preProcessFunction); - } - Optional processedResponse = executePreprocessFunction(scriptService, preProcessFunction, params); - if (!processedResponse.isPresent()) { - throw new IllegalArgumentException("Wrong input"); - } - Map map = gson.fromJson(processedResponse.get(), Map.class); - Map parametersMap = (Map) map.get("parameters"); - Map processedParameters = new HashMap<>(); - for (String key : parametersMap.keySet()) { - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - if (parametersMap.get(key) instanceof String) { - processedParameters.put(key, 
(String) parametersMap.get(key)); - } else { - processedParameters.put(key, gson.toJson(parametersMap.get(key))); - } - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } - } - inputData = RemoteInferenceInputDataSet.builder().parameters(processedParameters).build(); + inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService); } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { inputData = (RemoteInferenceInputDataSet)mlInput.getInputDataset(); } else { @@ -98,20 +66,65 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); - inputData.getParameters().entrySet().forEach(entry -> { - if (entry.getValue() == null) { - newParameters.put(entry.getKey(), entry.getValue()); - } else if (StringUtils.isJson(entry.getValue())) { + inputData.getParameters().forEach((key, value) -> { + if (value == null) { + newParameters.put(key, null); + } else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { // no need to escape if it's already valid json - newParameters.put(entry.getKey(), entry.getValue()); + newParameters.put(key, value); } else { - newParameters.put(entry.getKey(), escapeJson(entry.getValue())); + newParameters.put(key, escapeJson(value)); } }); inputData.setParameters(newParameters); } return inputData; } + private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map parameters, ScriptService scriptService) { + List docs = new ArrayList<>(inputDataSet.getDocs()); + Optional predictAction = connector.findPredictAction(); + if (predictAction.isEmpty()) { + throw new IllegalArgumentException("no predict action found"); + } + String preProcessFunction = predictAction.get().getPreProcessFunction(); + preProcessFunction = preProcessFunction == null ? 
MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; + if (MLPreProcessFunction.contains(preProcessFunction)) { + Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); + } else { + if (preProcessFunction.contains("${parameters")) { + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + preProcessFunction = substitutor.replace(preProcessFunction); + } + Optional processedInput = executePreprocessFunction(scriptService, preProcessFunction, docs); + if (processedInput.isEmpty()) { + throw new IllegalArgumentException("Wrong input"); + } + Map map = gson.fromJson(processedInput.get(), Map.class); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + } + } + + private static Map convertScriptStringToJsonString(Map processedInput) { + Map parameterStringMap = new HashMap<>(); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Map parametersMap = (Map) processedInput.get("parameters"); + for (String key : parametersMap.keySet()) { + if (parametersMap.get(key) instanceof String) { + parameterStringMap.put(key, (String) parametersMap.get(key)); + } else { + parameterStringMap.put(key, gson.toJson(parametersMap.get(key))); + } + } + return null; + }); + } catch (PrivilegedActionException e) { + log.error("Error processing parameters", e); + throw new RuntimeException(e); + } + return parameterStringMap; + } public static ModelTensors processOutput(String modelResponse, Connector connector, ScriptService scriptService, Map parameters) throws IOException { if (modelResponse == null) { @@ -119,26 +132,36 @@ public static ModelTensors processOutput(String modelResponse, Connector connect } List modelTensors = new ArrayList<>(); Optional predictAction = connector.findPredictAction(); - 
if (!predictAction.isPresent()) { + if (predictAction.isEmpty()) { throw new IllegalArgumentException("no predict action found"); } - String postProcessFunction = predictAction.get().getPostProcessFunction(); + ConnectorAction connectorAction = predictAction.get(); + String postProcessFunction = connectorAction.getPostProcessFunction(); if (postProcessFunction != null && postProcessFunction.contains("${parameters")) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); postProcessFunction = substitutor.replace(postProcessFunction); } - Optional<String> processedResponse = executePostprocessFunction(scriptService, postProcessFunction, modelResponse); + String responseFilter = parameters.get(RESPONSE_FILTER_FIELD); + if (MLPostProcessFunction.contains(postProcessFunction)) { + // in this case, we can use jsonpath to build a List<List<Float>> result from model response. + if (StringUtils.isBlank(responseFilter)) responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction); + List<List<Float>> vectors = JsonPath.read(modelResponse, responseFilter); + List<ModelTensor> processedResponse = executeBuildInPostProcessFunction(vectors, MLPostProcessFunction.get(postProcessFunction)); + return ModelTensors.builder().mlModelTensors(processedResponse).build(); + } + + // execute user defined painless script.
+ Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); - if (parameters.get(RESPONSE_FILTER_FIELD) == null) { - connector.parseResponse(response, modelTensors, postProcessFunction != null && processedResponse.isPresent()); + boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent(); + if (responseFilter == null) { + connector.parseResponse(response, modelTensors, scriptReturnModelTensor); } else { Object filteredResponse = JsonPath.parse(response).read(parameters.get(RESPONSE_FILTER_FIELD)); - connector.parseResponse(filteredResponse, modelTensors, postProcessFunction != null && processedResponse.isPresent()); + connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor); } - - ModelTensors tensors = ModelTensors.builder().mlModelTensors(modelTensors).build(); - return tensors; + return ModelTensors.builder().mlModelTensors(modelTensors).build(); } public static SdkHttpFullRequest signRequest(SdkHttpFullRequest request, String accessKey, String secretKey, String sessionToken, String signingName, String region) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index c9b6e78873..8712f771c7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -32,14 +32,8 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - List textDocs = new ArrayList(textDocsInputDataSet.getDocs()); - for (int i = 0; i < 
textDocsInputDataSet.getDocs().size(); i++) { - preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs); - if (tensorOutputs.size() >= textDocsInputDataSet.getDocs().size()) { - break; - } - textDocs.remove(0); - } + List textDocs = new ArrayList<>(textDocsInputDataSet.getDocs()); + preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs); } else { preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs); } @@ -65,7 +59,7 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List executePreprocessFunction(ScriptService scriptService, String preProcessFunction, List inputSentences) { + return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences))); } - public static Optional executePreprocessFunction(ScriptService scriptService, - String preProcessFunction, - Map params) { - if (MLPreProcessFunction.contains(preProcessFunction)) { - preProcessFunction = MLPreProcessFunction.get(preProcessFunction); - } - if (preProcessFunction != null) { - return Optional.ofNullable(executeScript(scriptService, preProcessFunction, params)); - } - return Optional.empty(); + public static List executeBuildInPostProcessFunction(List> vectors, Function>, List> function) { + return function.apply(vectors); } - public static Optional executePostprocessFunction(ScriptService scriptService, - String postProcessFunction, - String resultJson) { + public static Optional executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { Map result = StringUtils.fromJson(resultJson, "result"); - if (MLPostProcessFunction.contains(postProcessFunction)) { - postProcessFunction = MLPostProcessFunction.get(postProcessFunction); - } if 
(postProcessFunction != null) { return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result)); } return Optional.empty(); } - public static String executeScript(ScriptService scriptService, String painlessScript, Map params) { Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index ecc143ea6f..5dbbf2090e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.junit.Assert; import org.junit.Before; @@ -18,7 +19,9 @@ import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -136,4 +139,35 @@ public void executePredict_RemoteInferenceInput() throws IOException { Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("value", 
modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); } + + @Test + public void executePredict_TextDocsInferenceInput() throws IOException { + String jsonString = "{\"key\":\"value\"}"; + InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); + AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); + when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + when(httpRequest.call()).thenReturn(response); + when(httpClient.prepareRequest(any())).thenReturn(httpRequest); + + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .build(); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); + Assert.assertEquals("response", 
modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); + Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 2a84e2fee1..8e046b151c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -27,12 +27,16 @@ import org.opensearch.script.ScriptService; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; public class ConnectorUtilsTest { @@ -56,8 +60,6 @@ public void processInput_NullInput() { @Test public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test1", "test2")).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build(); @@ -121,18 +123,20 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec @Test public void processInput_TextDocsInputDataSet_PreprocessFunction_OneTextDoc() { + List input = Collections.singletonList("test_value"); + String inputJson = 
gson.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( - "{\"input\": \"${parameters.input}\"}", - "{\"parameters\": { \"input\": \"test_value\" } }", - "test_value"); + "{\"input\": \"${parameters.input}\"}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "texts"); } @Test public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() { + List input = new ArrayList<>(); + input.add("test_value1"); + input.add("test_value2"); + String inputJson = gson.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( - "{\"input\": ${parameters.input}}", - "{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }", - "[\"test_value1\",\"test_value2\"]"); + "{\"input\": ${parameters.input}}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "input"); } @Test @@ -143,7 +147,7 @@ public void processOutput_NullResponse() throws IOException { } @Test - public void processOutput_NoPostprocessFunction() throws IOException { + public void processOutput_NoPostprocessFunction_jsonResponse() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") @@ -186,10 +190,8 @@ public void processOutput_PostprocessFunction() throws IOException { Assert.assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]); } - private void processInput_TextDocsInputDataSet_PreprocessFunction(String requestBody, String preprocessResult, String expectedProcessedInput) { - when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); - - TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test1", "test2")).build(); + private void processInput_TextDocsInputDataSet_PreprocessFunction(String requestBody, List inputs, String expectedProcessedInput, String preProcessName, String resultKey) { + 
TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(inputs).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build(); ConnectorAction predictAction = ConnectorAction.builder() @@ -197,7 +199,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request .method("POST") .url("http://test.com/mock") .requestBody(requestBody) - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT) + .preProcessFunction(preProcessName) .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); @@ -205,6 +207,6 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); - Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get("input")); + Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 8d04603d2a..9caf621087 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -29,12 +29,15 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import 
org.opensearch.script.ScriptService; import java.io.IOException; import java.util.Arrays; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -94,32 +97,34 @@ public void executePredict_RemoteInferenceInput() throws IOException { } @Test - public void executePredict_TextDocsInput_NoPreprocessFunction() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); + public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") + .requestBody("{\"input\": ${parameters.input}}") .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + 
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); + Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")); } @Test public void executePredict_TextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; - String postprocessResult1 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[1, 2, 3]}"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; - String postprocessResult2 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[4, 5, 6]}"; when(scriptService.compile(any(), any())) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult2)); + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) @@ -127,21 +132,28 @@ public void executePredict_TextDocsInput() throws IOException { .url("http://test.com/mock") .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": \"${parameters.input}\"}") + .requestBody("{\"input\": ${parameters.input}}") .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); 
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; + String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" + + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" + + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" + + " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n" + + " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n" + + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + " }\n" + "}"; HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertArrayEquals(new Number[] {1, 2, 3}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); - Assert.assertArrayEquals(new Number[] {4, 5, 6}, modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0).getData()); + Assert.assertArrayEquals(new 
Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); + Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java new file mode 100644 index 0000000000..6ca1401efd --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java @@ -0,0 +1,59 @@ +package org.opensearch.ml.engine.utils; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ingest.TestTemplateService; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.script.ScriptService; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class ScriptUtilsTest { + + @Mock + ScriptService scriptService; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("test result")); + } + + @Test + public void test_executePreprocessFunction() { + Optional resultOpt = ScriptUtils.executePreprocessFunction(scriptService, "any function", Collections.singletonList("any input")); + assertEquals("test result", resultOpt.get()); + } + + @Test + public void test_executeBuildInPostProcessFunction() { + List> input = Arrays.asList(Arrays.asList(1.0f, 2.0f), Arrays.asList(3.0f, 
4.0f)); + List modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING)); + assertNotNull(modelTensors); + assertEquals(2, modelTensors.size()); + } + + @Test + public void test_executePostProcessFunction() { + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"test result\"}")); + Optional resultOpt = ScriptUtils.executePostProcessFunction(scriptService, "any function", "{\"result\": \"test result\"}"); + assertEquals("{\"result\": \"test result\"}", resultOpt.get()); + } + + @Test + public void test_executeScript() { + String result = ScriptUtils.executeScript(scriptService, "any function", Collections.singletonMap("key", "value")); + assertEquals("test result", result); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java index 5d5a00cab8..d278fa6415 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java @@ -6,8 +6,6 @@ package org.opensearch.ml.indices; import static org.opensearch.ml.common.CommonValue.META; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.common.CommonValue.SCHEMA_VERSION_FIELD; import java.util.HashMap; @@ -17,6 +15,7 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; @@ -38,11 +37,13 
@@ public class MLIndicesHandler { ClusterService clusterService; Client client; - + private static final Map indexSettings = Map.of("index.auto_expand_replicas", "0-1"); private static final Map indexMappingUpdated = new HashMap<>(); + static { - indexMappingUpdated.put(ML_MODEL_INDEX, new AtomicBoolean(false)); - indexMappingUpdated.put(ML_TASK_INDEX, new AtomicBoolean(false)); + for (MLIndex mlIndex : MLIndex.values()) { + indexMappingUpdated.put(mlIndex.getIndexName(), new AtomicBoolean(false)); + } } public void initModelGroupIndexIfAbsent(ActionListener listener) { @@ -83,7 +84,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) log.error("Failed to create index " + indexName, e); internalListener.onFailure(e); }); - CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping); + CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping).settings(indexSettings); client.admin().indices().create(request, actionListener); } else { log.debug("index:{} is already created", indexName); @@ -98,8 +99,23 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) new PutMappingRequest().indices(indexName).source(mapping, XContentType.JSON), ActionListener.wrap(response -> { if (response.isAcknowledged()) { - indexMappingUpdated.get(indexName).set(true); - internalListener.onResponse(true); + UpdateSettingsRequest updateSettingRequest = new UpdateSettingsRequest(); + updateSettingRequest.indices(indexName).settings(indexSettings); + client + .admin() + .indices() + .updateSettings(updateSettingRequest, ActionListener.wrap(updateResponse -> { + if (updateResponse.isAcknowledged()) { + indexMappingUpdated.get(indexName).set(true); + internalListener.onResponse(true); + } else { + internalListener + .onFailure(new MLException("Failed to update index setting for: " + indexName)); + } + }, exception -> { + log.error("Failed to update index setting for: " + indexName, exception); +
internalListener.onFailure(exception); + })); } else { internalListener.onFailure(new MLException("Failed to update index: " + indexName)); } diff --git a/release-notes/opensearch-ml-common.release-notes-2.10.0.0.md b/release-notes/opensearch-ml-common.release-notes-2.10.0.0.md new file mode 100644 index 0000000000..2b4e522c75 --- /dev/null +++ b/release-notes/opensearch-ml-common.release-notes-2.10.0.0.md @@ -0,0 +1,49 @@ +## Version 2.10.0.0 Release Notes + +Compatible with OpenSearch 2.10.0 + + +### Experimental Features +* Conversations and Generative AI in OpenSearch ([#1150](https://github.com/opensearch-project/ml-commons/issues/1150)) + +### Enhancements +* Add feature flags for remote inference ([#1223](https://github.com/opensearch-project/ml-commons/pull/1223)) +* Add eligible node role settings ([#1197](https://github.com/opensearch-project/ml-commons/pull/1197)) +* Add more stats: connector count, connector/config index status ([#1180](https://github.com/opensearch-project/ml-commons/pull/1180)) + +### Infrastructure +* Updates demo certs used in integ tests ([#1291](https://github.com/opensearch-project/ml-commons/pull/1291)) +* Add Auto Release Workflow ([#1306](https://github.com/opensearch-project/ml-commons/pull/1306)) + +### Bug Fixes +* Fixing metrics ([#1194](https://github.com/opensearch-project/ml-commons/pull/1194)) +* Fix null pointer exception when input parameter is null. 
([#1192](https://github.com/opensearch-project/ml-commons/pull/1192)) +* Fix admin with no backend role on AOS unable to create restricted model group ([#1188](https://github.com/opensearch-project/ml-commons/pull/1188)) +* Fix parameter parsing bug for create connector input ([#1185](https://github.com/opensearch-project/ml-commons/pull/1185)) +* Handle escaping string parameters explicitly ([#1174](https://github.com/opensearch-project/ml-commons/pull/1174)) +* Fix model count bug ([#1180](https://github.com/opensearch-project/ml-commons/pull/1180)) +* Fix core package name to address compilation errors ([#1157](https://github.com/opensearch-project/ml-commons/pull/1157)) +* Fix system index access bug ([#1320](https://github.com/opensearch-project/ml-commons/pull/1320)) +* Fix unassigned ml system shard replicas ([#1315](https://github.com/opensearch-project/ml-commons/pull/1315)) +* Adjust index replicas settings to keep consistent with AOS 2.9 ([#1325](https://github.com/opensearch-project/ml-commons/pull/1325)) +* Fix GetInteractions returned different results in security-enabled and -disabled settings ([#1334](https://github.com/opensearch-project/ml-commons/pull/1334)) + +### Documentation +* Updating cohere blueprint doc ([#1213](https://github.com/opensearch-project/ml-commons/pull/1213)) +* Fixing docs ([#1193](https://github.com/opensearch-project/ml-commons/pull/1193)) +* Add model auto redeploy tutorial ([#1175](https://github.com/opensearch-project/ml-commons/pull/1175)) +* Add remote inference tutorial ([#1158](https://github.com/opensearch-project/ml-commons/pull/1158)) +* Adding blueprint examples for remote inference ([#1155](https://github.com/opensearch-project/ml-commons/pull/1155)) +* Updating developer guide for CCI contributors ([#1049](https://github.com/opensearch-project/ml-commons/pull/1049)) + +### Maintenance +* Bump checkstyle version for CVE fix ([#1216](https://github.com/opensearch-project/ml-commons/pull/1216)) +* Correct imports 
for new location with regard to core refactoring ([#1206](https://github.com/opensearch-project/ml-commons/pull/1206)) +* Fix breaking change caused by opensearch core ([#1187](https://github.com/opensearch-project/ml-commons/pull/1187)) +* Bump OpenSearch snapshot version to 2.10 ([#1157](https://github.com/opensearch-project/ml-commons/pull/1157)) +* Bump aws-encryption-sdk-java to fix CVE-2023-33201 ([#1309](https://github.com/opensearch-project/ml-commons/pull/1309)) + +### Refactoring +* Renaming metrics ([#1224](https://github.com/opensearch-project/ml-commons/pull/1224)) +* Changing messaging for IllegalArgumentException on duplicate model groups ([#1294](https://github.com/opensearch-project/ml-commons/pull/1294)) +* Fixing some error message handling ([#1222](https://github.com/opensearch-project/ml-commons/pull/1222))