diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index d742d05970..92c6b263d1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -11,7 +11,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -49,11 +48,18 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); } - // This is to support some model which takes N text docs and embedding size is less than N-1. + // This is to support some model which takes N text docs and embedding size is less than N. // We need to tell executor what's the step size for each model run. Map parameters = getConnector().getParameters(); if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) { - processedDocs += Integer.parseInt(parameters.get("input_docs_processed_step_size")); + int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size")); + // We need to check the parameter on runtime as parameter can be passed into predict request + if (stepSize <= 0) { + throw new IllegalArgumentException( + "Invalid parameter: input_docs_processed_step_size. It must be positive integer." + ); + } + processedDocs += stepSize; } else { processedDocs += Math.max(tensorCount, 1); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index ba1f434689..a4bc766aa2 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -6,13 +6,12 @@ package org.opensearch.ml.engine.algorithms.remote; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import java.io.IOException; import java.util.Arrays; +import java.util.Map; import org.apache.http.HttpEntity; import org.apache.http.ProtocolVersion; @@ -44,9 +43,6 @@ import org.opensearch.script.ScriptService; import com.google.common.collect.ImmutableMap; -import java.util.Map; - - public class HttpJsonConnectorExecutorTest { @Rule @@ -197,15 +193,22 @@ public void executePredict_TextDocsInput() throws IOException { .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) - .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": ${parameters.input}}") - .build(); - HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) + .requestBody("{\"input\": ${parameters.input}}") + .build(); + HttpConnector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); @@ -266,29 +269,49 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) - .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": ${parameters.input}}") - .build(); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) + .requestBody("{\"input\": ${parameters.input}}") + .build(); Map parameters = ImmutableMap.of("input_docs_processed_step_size", "2"); - HttpConnector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + HttpConnector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); // model takes 2 input docs, but only output 1 embedding - String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" - + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" - + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" - + " } ],\n" - + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" - + " \"total_tokens\": 5\n" + " }\n" + "}"; + String modelResponse = "{\n" + + " \"object\": \"list\",\n" + + " \"data\": [\n" + + " {\n" + + " \"object\": \"embedding\",\n" + + " \"index\": 0,\n" + + " \"embedding\": [\n" + + " -0.014555434,\n" + + " -0.002135904,\n" + + " 0.0035105038\n" + + " ]\n" + + " } ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + + " \"usage\": {\n" + + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + + " }\n" + + "}"; StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); @@ -296,10 +319,77 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE when(executor.getHttpClient()).thenReturn(httpClient); when(executor.getConnector()).thenReturn(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); + Assert + .assertArrayEquals( + new Number[] { -0.014555434, -0.002135904, 0.0035105038 }, + modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData() + ); + } + + @Test + public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepSize() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Invalid parameter: input_docs_processed_step_size. It must be positive integer."); + String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; + String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); + + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) + .requestBody("{\"input\": ${parameters.input}}") + .build(); + // step size must be positive integer, here we set it as -1, should trigger IllegalArgumentException + Map parameters = ImmutableMap.of("input_docs_processed_step_size", "-1"); + HttpConnector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .actions(Arrays.asList(predictAction)) + .build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + executor.setScriptService(scriptService); + when(httpClient.execute(any())).thenReturn(response); + // model takes 2 input docs, but only output 1 embedding + String modelResponse = "{\n" + + " \"object\": \"list\",\n" + + " \"data\": [\n" + + " {\n" + + " \"object\": \"embedding\",\n" + + " \"index\": 0,\n" + + " \"embedding\": [\n" + + " -0.014555434,\n" + + " -0.002135904,\n" + + " 0.0035105038\n" + + " ]\n" + + " } ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + + " \"usage\": {\n" + + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + + " }\n" + + "}"; + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); + HttpEntity entity = new StringEntity(modelResponse); + when(response.getEntity()).thenReturn(entity); + when(executor.getHttpClient()).thenReturn(httpClient); + when(executor.getConnector()).thenReturn(connector); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); } }