diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 76dc55e7e3..cf308f1d8d 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -30,7 +30,8 @@ public enum FunctionName { SPARSE_TOKENIZE, TEXT_SIMILARITY, QUESTION_ANSWERING, - AGENT; + AGENT, + CONNECTOR; public static FunctionName from(String value) { try { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index fadab3ef9a..90837425c4 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -70,7 +70,7 @@ public abstract class AbstractConnector implements Connector { @Setter protected ConnectorClientConfig connectorClientConfig; - protected Map createPredictDecryptedHeaders(Map headers) { + protected Map createDecryptedHeaders(Map headers) { if (headers == null) { return null; } @@ -116,9 +116,9 @@ public void parseResponse(T response, List modelTensors, boolea } @Override - public Optional findPredictAction() { + public Optional findAction(String action) { if (actions != null) { - return actions.stream().filter(a -> a.getActionType() == ConnectorAction.ActionType.PREDICT).findFirst(); + return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst(); } return Optional.empty(); } @@ -131,12 +131,12 @@ public void removeCredential() { } @Override - public String getPredictEndpoint(Map parameters) { - Optional predictAction = findPredictAction(); - if (!predictAction.isPresent()) { + public String getActionEndpoint(String action, Map parameters) { + Optional actionEndpoint = findAction(action); + if (!actionEndpoint.isPresent()) { return null; } - String predictEndpoint = predictAction.get().getUrl(); + String predictEndpoint = actionEndpoint.get().getUrl(); if (parameters != null && parameters.size() > 0) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); predictEndpoint = substitutor.replace(predictEndpoint); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index e74a453dc9..12f8ca0eba 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -56,18 +56,18 @@ public interface Connector extends ToXContentObject, Writeable { ConnectorClientConfig getConnectorClientConfig(); - String getPredictEndpoint(Map parameters); + String getActionEndpoint(String action, Map parameters); - String getPredictHttpMethod(); + String getActionHttpMethod(String action); - T createPredictPayload(Map parameters); + T createPayload(String action, Map parameters); - void decrypt(Function function); + void decrypt(String action, Function function); void encrypt(Function function); Connector cloneConnector(); - Optional findPredictAction(); + Optional findAction(String action); void removeCredential(); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index ae43c10867..e424914b4f 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -182,6 +182,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { } public enum ActionType { - PREDICT + PREDICT, + EXECUTE } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 5bb00560a2..fc01ffad38 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -30,6 +30,7 @@ import java.util.regex.Pattern; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; @@ -307,10 +308,10 @@ public void update(MLCreateConnectorInput updateContent, Function T createPredictPayload(Map parameters) { - Optional predictAction = findPredictAction(); - if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) { - String payload = predictAction.get().getRequestBody(); + public T createPayload(String action, Map parameters) { + Optional connectorAction = findAction(action); + if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) { + String payload = connectorAction.get().getRequestBody(); payload = fillNullParameters(parameters, payload); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); @@ -348,15 +349,15 @@ private List findStringParametersWithNullDefaultValue(String input) { } @Override - public void decrypt(Function function) { + public void decrypt(String action, Function function) { Map decrypted = new HashMap<>(); for (String key : credential.keySet()) { decrypted.put(key, function.apply(credential.get(key))); } this.decryptedCredential = decrypted; - Optional predictAction = findPredictAction(); - Map headers = predictAction.isPresent() ? predictAction.get().getHeaders() : null; - this.decryptedHeaders = createPredictDecryptedHeaders(headers); + Optional connectorAction = findAction(action); + Map headers = connectorAction.isPresent() ? connectorAction.get().getHeaders() : null; + this.decryptedHeaders = createDecryptedHeaders(headers); } @Override @@ -378,8 +379,9 @@ public void encrypt(Function function) { } } - public String getPredictHttpMethod() { - return findPredictAction().get().getMethod(); + @Override + public String getActionHttpMethod(String action) { + return findAction(action).get().getMethod(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java new file mode 100644 index 0000000000..02e1c59cb4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorAction.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.opensearch.action.ActionType; +import org.opensearch.ml.common.transport.MLTaskResponse; + +public class MLExecuteConnectorAction extends ActionType { + public static final MLExecuteConnectorAction INSTANCE = new MLExecuteConnectorAction(); + public static final String NAME = "cluster:admin/opensearch/ml/connectors/execute"; + + private MLExecuteConnectorAction() { + super(NAME, MLTaskResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java new file mode 100644 index 0000000000..ab7ffa9c9f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.transport.MLTaskRequest; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +@FieldDefaults(level = AccessLevel.PRIVATE) +@ToString +public class MLExecuteConnectorRequest extends MLTaskRequest { + + String connectorId; + MLInput mlInput; + + @Builder + public MLExecuteConnectorRequest(String connectorId, MLInput mlInput, boolean dispatchTask) { + super(dispatchTask); + this.mlInput = mlInput; + this.connectorId = connectorId; + } + + public MLExecuteConnectorRequest(String connectorId, MLInput mlInput) { + this(connectorId, mlInput, true); + } + + public MLExecuteConnectorRequest(StreamInput in) throws IOException { + super(in); + this.connectorId = in.readString(); + this.mlInput = new MLInput(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.connectorId); + this.mlInput.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (this.mlInput == null) { + exception = addValidationError("ML input can't be null", exception); + } else if (this.mlInput.getInputDataset() == null) { + exception = addValidationError("input data can't be null", exception); + } + + return exception; + } + + + public static MLExecuteConnectorRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLExecuteConnectorRequest) { + return (MLExecuteConnectorRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLExecuteConnectorRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLPredictionTaskRequest", e); + } + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index a242c213ea..36a964cef1 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -31,6 +31,7 @@ import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; @@ -110,7 +111,7 @@ public void constructor_NoPredictAction() { Assert.assertNotNull(connector); connector.encrypt(encryptFunction); - connector.decrypt(decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals(null, connector.getSessionToken()); @@ -149,13 +150,13 @@ public void constructor() { AwsConnector connector = createAwsConnector(parameters, credential, url); connector.encrypt(encryptFunction); - connector.decrypt(decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken()); Assert.assertEquals("test_service", connector.getServiceName()); Assert.assertEquals("us-west-2", connector.getRegion()); - Assert.assertEquals("https://test.com/model1", connector.getPredictEndpoint(parameters)); + Assert.assertEquals("https://test.com/model1", connector.getActionEndpoint(PREDICT.name(), parameters)); } @Test @@ -170,13 +171,13 @@ public void constructor_NoParameter() { String url = "https://test.com"; AwsConnector connector = createAwsConnector(null, credential, url); connector.encrypt(encryptFunction); - connector.decrypt(decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken()); Assert.assertEquals("decrypted: ENCRYPTED: TEST_SERVICE", connector.getServiceName()); Assert.assertEquals("decrypted: ENCRYPTED: US-WEST-2", connector.getRegion()); - Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null)); + Assert.assertEquals("https://test.com", connector.getActionEndpoint(PREDICT.name(), null)); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 4f1df76da2..c25f9653c3 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -30,9 +30,10 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Optional; import java.util.function.Function; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; + public class HttpConnectorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -119,7 +120,7 @@ public void cloneConnector() { @Test public void decrypt() { HttpConnector connector = createHttpConnector(); - connector.decrypt(decryptFunction); + connector.decrypt(PREDICT.name(), decryptFunction); Map decryptedCredential = connector.getDecryptedCredential(); Assert.assertEquals(1, decryptedCredential.size()); Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key")); @@ -148,42 +149,42 @@ public void encrypted() { } @Test - public void getPredictEndpoint() { + public void getActionEndpoint() { HttpConnector connector = createHttpConnector(); - Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null)); + Assert.assertEquals("https://test.com", connector.getActionEndpoint(PREDICT.name(), null)); } @Test - public void getPredictHttpMethod() { + public void getActionHttpMethod() { HttpConnector connector = createHttpConnector(); - Assert.assertEquals("POST", connector.getPredictHttpMethod()); + Assert.assertEquals("POST", connector.getActionHttpMethod(PREDICT.name())); } @Test - public void createPredictPayload_Invalid() { + public void createPayload_Invalid() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Some parameter placeholder not filled in payload: input"); HttpConnector connector = createHttpConnector(); - String predictPayload = connector.createPredictPayload(null); + String predictPayload = connector.createPayload(PREDICT.name(), null); connector.validatePayload(predictPayload); } @Test - public void createPredictPayload_InvalidJson() { + public void createPayload_InvalidJson() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Invalid payload: {\"input\": ${parameters.input} }"); String requestBody = "{\"input\": ${parameters.input} }"; HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - String predictPayload = connector.createPredictPayload(null); + String predictPayload = connector.createPayload(PREDICT.name(), null); connector.validatePayload(predictPayload); } @Test - public void createPredictPayload() { + public void createPayload() { HttpConnector connector = createHttpConnector(); Map parameters = new HashMap<>(); parameters.put("input", "test input value"); - String predictPayload = connector.createPredictPayload(parameters); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); connector.validatePayload(predictPayload); Assert.assertEquals("{\"input\": \"test input value\"}", predictPayload); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java new file mode 100644 index 0000000000..f95b236259 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +public class MLExecuteConnectorRequestTests { + private MLExecuteConnectorRequest mlExecuteConnectorRequest; + private MLInput mlInput; + private String connectorId; + + @Before + public void setUp(){ + MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build(); + connectorId = "test_connector"; + mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build(); + mlExecuteConnectorRequest = MLExecuteConnectorRequest.builder().connectorId(connectorId).mlInput(mlInput).build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + mlExecuteConnectorRequest.writeTo(output); + MLExecuteConnectorRequest parsedRequest = new MLExecuteConnectorRequest(output.bytes().streamInput()); + assertEquals(mlExecuteConnectorRequest.getConnectorId(), parsedRequest.getConnectorId()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getAlgorithm(), parsedRequest.getMlInput().getAlgorithm()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getInputDataset().getInputDataType(), parsedRequest.getMlInput().getInputDataset().getInputDataType()); + assertEquals("hello", ((RemoteInferenceInputDataSet)parsedRequest.getMlInput().getInputDataset()).getParameters().get("input")); + } + + @Test + public void validateSuccess() { + assertNull(mlExecuteConnectorRequest.validate()); + } + + @Test + public void testConstructor() { + MLExecuteConnectorRequest executeConnectorRequest = new MLExecuteConnectorRequest(connectorId, mlInput); + assertTrue(executeConnectorRequest.isDispatchTask()); + } + + @Test + public void validateWithNullMLInputException() { + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder() + .build(); + ActionRequestValidationException exception = executeConnectorRequest.validate(); + assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLInputDataSetException() { + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().mlInput(new MLInput()) + .build(); + ActionRequestValidationException exception = executeConnectorRequest.validate(); + assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequestWithMLExecuteConnectorRequestSuccess() { + assertSame(MLExecuteConnectorRequest.fromActionRequest(mlExecuteConnectorRequest), mlExecuteConnectorRequest); + } + + @Test + public void fromActionRequestWithNonMLExecuteConnectorRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlExecuteConnectorRequest.writeTo(out); + } + }; + MLExecuteConnectorRequest result = MLExecuteConnectorRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlExecuteConnectorRequest); + assertEquals(mlExecuteConnectorRequest.getConnectorId(), result.getConnectorId()); + assertEquals(mlExecuteConnectorRequest.getMlInput().getFunctionName(), result.getMlInput().getFunctionName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLExecuteConnectorRequest.fromActionRequest(actionRequest); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 035b6a6d8d..2ebc7ce563 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -77,7 +77,8 @@ public Logger getLogger() { @SuppressWarnings("removal") @Override - public void invokeRemoteModel( + public void invokeRemoteService( + String action, MLInput mlInput, Map parameters, String payload, @@ -85,22 +86,30 @@ public void invokeRemoteModel( ActionListener> actionListener ) { try { - SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); + SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); AsyncExecuteRequest executeRequest = AsyncExecuteRequest .builder() .request(signRequest(request)) .requestContentPublisher(new SimpleHttpContentPublisher(request)) .responseHandler( - new MLSdkAsyncHttpResponseHandler(executionContext, actionListener, parameters, connector, scriptService, mlGuard) + new MLSdkAsyncHttpResponseHandler( + executionContext, + actionListener, + parameters, + connector, + scriptService, + mlGuard, + action + ) ) .build(); AccessController.doPrivileged((PrivilegedExceptionAction>) () -> httpClient.execute(executeRequest)); } catch (RuntimeException exception) { - log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception); + log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception); actionListener.onFailure(exception); } catch (Throwable e) { - log.error("Failed to execute predict in aws connector", e); - actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e)); + log.error("Failed to execute {} in aws connector", action, e); + actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e)); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index a6181e1b2f..cad0278a6d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -63,6 +63,7 @@ public class ConnectorUtils { } public static RemoteInferenceInputDataSet processInput( + String action, MLInput mlInput, Connector connector, Map parameters, @@ -71,22 +72,23 @@ public static RemoteInferenceInputDataSet processInput( if (mlInput == null) { throw new IllegalArgumentException("Input is null"); } - Optional predictAction = connector.findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no predict action found"); + Optional connectorAction = connector.findAction(action); + if (connectorAction.isEmpty()) { + throw new IllegalArgumentException("no " + action + " action found"); } - RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService); + RemoteInferenceInputDataSet inputData = processMLInput(action, mlInput, connector, parameters, scriptService); escapeRemoteInferenceInputData(inputData); return inputData; } private static RemoteInferenceInputDataSet processMLInput( + String action, MLInput mlInput, Connector connector, Map parameters, ScriptService scriptService ) { - String preProcessFunction = getPreprocessFunction(mlInput, connector); + String preProcessFunction = getPreprocessFunction(action, mlInput, connector); if (preProcessFunction == null) { if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); @@ -168,9 +170,9 @@ public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet in } } - private static String getPreprocessFunction(MLInput mlInput, Connector connector) { - Optional predictAction = connector.findPredictAction(); - String preProcessFunction = predictAction.get().getPreProcessFunction(); + private static String getPreprocessFunction(String action, MLInput mlInput, Connector connector) { + Optional connectorAction = connector.findAction(action); + String preProcessFunction = connectorAction.get().getPreProcessFunction(); if (preProcessFunction != null) { return preProcessFunction; } @@ -181,6 +183,7 @@ private static String getPreprocessFunction(MLInput mlInput, Connector connector } public static ModelTensors processOutput( + String action, String modelResponse, Connector connector, ScriptService scriptService, @@ -194,12 +197,11 @@ public static ModelTensors processOutput( throw new IllegalArgumentException("guardrails triggered for LLM output"); } List modelTensors = new ArrayList<>(); - Optional predictAction = connector.findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no predict action found"); + Optional connectorAction = connector.findAction(action); + if (connectorAction.isEmpty()) { + throw new IllegalArgumentException("no " + action + " action found"); } - ConnectorAction connectorAction = predictAction.get(); - String postProcessFunction = connectorAction.getPostProcessFunction(); + String postProcessFunction = connectorAction.get().getPostProcessFunction(); postProcessFunction = fillProcessFunctionParameter(parameters, postProcessFunction); String responseFilter = parameters.get(RESPONSE_FILTER_FIELD); @@ -263,6 +265,7 @@ public static SdkHttpFullRequest signRequest( } public static SdkHttpFullRequest buildSdkRequest( + String action, Connector connector, Map parameters, String payload, @@ -279,7 +282,7 @@ public static SdkHttpFullRequest buildSdkRequest( log.error("Content length is 0. Aborting request to remote model"); throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); } - String endpoint = connector.getPredictEndpoint(parameters); + String endpoint = connector.getActionEndpoint(action, parameters); SdkHttpFullRequest.Builder builder = SdkHttpFullRequest .builder() .method(method) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index be15740bdb..ee29f67a43 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -81,7 +81,8 @@ public Logger getLogger() { @SuppressWarnings("removal") @Override - public void invokeRemoteModel( + public void invokeRemoteService( + String action, MLInput mlInput, Map parameters, String payload, @@ -90,15 +91,15 @@ public void invokeRemoteModel( ) { try { SdkHttpFullRequest request; - switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { + switch (connector.getActionHttpMethod(action).toUpperCase(Locale.ROOT)) { case "POST": log.debug("original payload to remote model: " + payload); - validateHttpClientParameters(parameters); - request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); + validateHttpClientParameters(action, parameters); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - validateHttpClientParameters(parameters); - request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET); + validateHttpClientParameters(action, parameters); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); break; default: throw new IllegalArgumentException("unsupported http method"); @@ -108,7 +109,15 @@ public void invokeRemoteModel( .request(request) .requestContentPublisher(new SimpleHttpContentPublisher(request)) .responseHandler( - new MLSdkAsyncHttpResponseHandler(executionContext, actionListener, parameters, connector, scriptService, mlGuard) + new MLSdkAsyncHttpResponseHandler( + executionContext, + actionListener, + parameters, + connector, + scriptService, + mlGuard, + action + ) ) .build(); AccessController.doPrivileged((PrivilegedExceptionAction>) () -> httpClient.execute(executeRequest)); @@ -121,8 +130,8 @@ public void invokeRemoteModel( } } - private void validateHttpClientParameters(Map parameters) throws Exception { - String endpoint = connector.getPredictEndpoint(parameters); + private void validateHttpClientParameters(String action, Map parameters) throws Exception { + String endpoint = connector.getActionEndpoint(action, parameters); URL url = new URL(endpoint); String protocol = url.getProtocol(); String host = url.getHost(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index b289a76157..6ea03058f0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -55,6 +55,8 @@ public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandle private final Connector connector; + private final String action; + private final ScriptService scriptService; private final MLGuard mlGuard; @@ -68,7 +70,8 @@ public MLSdkAsyncHttpResponseHandler( Map parameters, Connector connector, ScriptService scriptService, - MLGuard mlGuard + MLGuard mlGuard, + String action ) { this.executionContext = executionContext; this.actionListener = actionListener; @@ -76,6 +79,7 @@ public MLSdkAsyncHttpResponseHandler( this.connector = connector; this.scriptService = scriptService; this.mlGuard = mlGuard; + this.action = action; } @Override @@ -184,12 +188,12 @@ private void response() { } try { - ModelTensors tensors = processOutput(body, connector, scriptService, parameters, mlGuard); + ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard); tensors.setStatusCode(statusCode); actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors)); } catch (Exception e) { log.error("Failed to process response body: {}", body, e); - actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e)); + actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e)); } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index f52532bcd1..11e43cef85 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -50,7 +50,7 @@ public interface RemoteConnectorExecutor { public String RETRY_EXECUTOR = "opensearch_ml_predict_remote"; - default void executePredict(MLInput mlInput, ActionListener actionListener) { + default void executeAction(String action, MLInput mlInput, ActionListener actionListener) { ActionListener>> tensorActionListener = ActionListener.wrap(r -> { // Only all sub-requests success will call logics here ModelTensors[] modelTensors = new ModelTensors[r.size()]; @@ -61,7 +61,7 @@ default void executePredict(MLInput mlInput, ActionListener acti try { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - Tuple calculatedChunkSize = calculateChunkSize(textDocsInputDataSet); + Tuple calculatedChunkSize = calculateChunkSize(action, textDocsInputDataSet); GroupedActionListener> groupedActionListener = new GroupedActionListener<>( tensorActionListener, calculatedChunkSize.v1() @@ -72,7 +72,8 @@ default void executePredict(MLInput mlInput, ActionListener acti List textDocs = textDocsInputDataSet .getDocs() .subList(processedDocs, Math.min(processedDocs + calculatedChunkSize.v2(), textDocsInputDataSet.getDocs().size())); - preparePayloadAndInvokeRemoteModel( + preparePayloadAndInvoke( + action, MLInput .builder() .algorithm(FunctionName.TEXT_EMBEDDING) @@ -83,7 +84,7 @@ default void executePredict(MLInput mlInput, ActionListener acti ); } } else { - preparePayloadAndInvokeRemoteModel(mlInput, new ExecutionContext(0), new GroupedActionListener<>(tensorActionListener, 1)); + preparePayloadAndInvoke(action, mlInput, new ExecutionContext(0), new GroupedActionListener<>(tensorActionListener, 1)); } } catch (Exception e) { actionListener.onFailure(e); @@ -95,12 +96,12 @@ default void executePredict(MLInput mlInput, ActionListener acti * @param textDocsInputDataSet * @return Tuple of chunk size and step size. */ - private Tuple calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) { + private Tuple calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) { int textDocsLength = textDocsInputDataSet.getDocs().size(); Map parameters = getConnector().getParameters(); if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) { int stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size")); - // We need to check the parameter on runtime as parameter can be passed into predict request + // We need to check the parameter on runtime as parameter can be passed into action request if (stepSize <= 0) { throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer."); } else { @@ -111,11 +112,11 @@ private Tuple calculateChunkSize(TextDocsInputDataSet textDocs return Tuple.tuple(textDocsLength / stepSize + 1, stepSize); } } else { - Optional predictAction = getConnector().findPredictAction(); - if (predictAction.isEmpty()) { - throw new IllegalArgumentException("no predict action found"); + Optional connectorAction = getConnector().findAction(action); + if (connectorAction.isEmpty()) { + throw new IllegalArgumentException("no " + action + " action found"); } - String preProcessFunction = predictAction.get().getPreProcessFunction(); + String preProcessFunction = connectorAction.get().getPreProcessFunction(); if (preProcessFunction != null && !MLPreProcessFunction.contains(preProcessFunction)) { // user defined preprocess script, this case, the chunk size is always equals to text docs length. return Tuple.tuple(textDocsLength, 1); @@ -155,7 +156,8 @@ default void setUserRateLimiterMap(Map userRateLimiterMap) default void setMlGuard(MLGuard mlGuard) {} - default void preparePayloadAndInvokeRemoteModel( + default void preparePayloadAndInvoke( + String action, MLInput mlInput, ExecutionContext executionContext, ActionListener> actionListener @@ -173,13 +175,13 @@ default void preparePayloadAndInvokeRemoteModel( inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); } parameters.putAll(inputParameters); - RemoteInferenceInputDataSet inputData = processInput(mlInput, connector, parameters, getScriptService()); + RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService()); if (inputData.getParameters() != null) { parameters.putAll(inputData.getParameters()); } // override again to always prioritize the input parameter parameters.putAll(inputParameters); - String payload = connector.createPredictPayload(parameters); + String payload = connector.createPayload(action, parameters); connector.validatePayload(payload); String userStr = getClient() .threadPool() @@ -201,9 +203,9 @@ && getUserRateLimiterMap().get(user.getName()) != null throw new IllegalArgumentException("guardrails triggered for user input"); } if (getConnectorClientConfig().getMaxRetryTimes() != 0) { - invokeRemoteModelWithRetry(mlInput, parameters, payload, executionContext, actionListener); + invokeRemoteServiceWithRetry(action, mlInput, parameters, payload, executionContext, actionListener); } else { - invokeRemoteModel(mlInput, parameters, payload, executionContext, actionListener); + invokeRemoteService(action, mlInput, parameters, payload, executionContext, actionListener); } } } @@ -230,7 +232,8 @@ default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClien } } - default void invokeRemoteModelWithRetry( + default void invokeRemoteServiceWithRetry( + String action, MLInput mlInput, Map parameters, String payload, @@ -252,7 +255,7 @@ default void invokeRemoteModelWithRetry( public void tryAction(ActionListener> listener) { // the listener here is RetryingListener // If the request success, or can not retry, will call delegate listener - invokeRemoteModel(mlInput, parameters, payload, executionContext, listener); + invokeRemoteService(action, mlInput, parameters, payload, executionContext, listener); } @Override @@ -272,7 +275,8 @@ public boolean shouldRetry(Exception e) { invokeRemoteModelAction.run(); }; - void invokeRemoteModel( + void invokeRemoteService( + String action, MLInput mlInput, Map parameters, String payload, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 5828395641..c8685c010e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -5,6 +5,8 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; + import java.util.Map; import org.opensearch.client.Client; @@ -66,7 +68,7 @@ public void asyncPredict(MLInput mlInput, ActionListener actionL return; } try { - connectorExecutor.executePredict(mlInput, actionListener); + connectorExecutor.executeAction(PREDICT.name(), mlInput, actionListener); } catch (RuntimeException e) { log.error("Failed to call remote model.", e); actionListener.onFailure(e); @@ -90,7 +92,7 @@ public boolean isModelReady() { public void initModel(MLModel model, Map params, Encryptor encryptor) { try { Connector connector = model.getConnector().cloneConnector(); - connector.decrypt((credential) -> encryptor.decrypt(credential)); + connector.decrypt(PREDICT.name(), (credential) -> encryptor.decrypt(credential)); this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java new file mode 100644 index 0000000000..cb8b231ebf --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import java.util.List; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports running connector. + */ +@Log4j2 +@ToolAnnotation(ConnectorTool.TYPE) +public class ConnectorTool implements Tool { + public static final String TYPE = "ConnectorTool"; + public static final String CONNECTOR_ID = "connector_id"; + public static final String CONNECTOR_ACTION = "connector_action"; + + @Setter + @Getter + private String name = ConnectorTool.TYPE; + @Getter + @Setter + private String description = Factory.DEFAULT_DESCRIPTION; + @Getter + private String version; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + private Client client; + private String connectorId; + + public ConnectorTool(Client client, String connectorId) { + this.client = client; + if (connectorId == null) { + throw new IllegalArgumentException("connector_id can't be null"); + } + this.connectorId = connectorId; + + outputParser = new Parser() { + @Override + public Object parse(Object o) { + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + @Override + public void run(Map parameters, ActionListener listener) { + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + MLInput mlInput = RemoteInferenceMLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build(); + ActionRequest request = new MLExecuteConnectorRequest(connectorId, mlInput); + + client.execute(MLExecuteConnectorAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) r.getOutput(); + modelTensorOutput.getMlModelOutputs(); + if (outputParser == null) { + listener.onResponse((T) modelTensorOutput.getMlModelOutputs()); + } else { + listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); + } + }, e -> { + log.error("Failed to run model " + connectorId, e); + listener.onFailure(e); + })); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + return true; + } + + public static class Factory implements Tool.Factory { + public static final String TYPE = "ConnectorTool"; + public static final String DEFAULT_DESCRIPTION = "This tool will invoke external service."; + private Client client; + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (ConnectorTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + @Override + public ConnectorTool create(Map map) { + return new ConnectorTool(client, (String) map.get(CONNECTOR_ID)); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index bf13c9f68c..cb192e83f9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -7,6 +7,7 @@ import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -15,6 +16,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; @@ -114,7 +116,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { exceptionRule.expectMessage("Missing credential"); ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -132,7 +134,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { public void executePredict_RemoteInferenceInput_EmptyIpAddress() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http:///mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -150,7 +152,7 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -159,7 +161,12 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); + executor + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); assert exceptionCaptor.getValue() instanceof NullPointerException; @@ -170,7 +177,7 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { public void executePredict_TextDocsInferenceInput() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -188,7 +195,7 @@ public void executePredict_TextDocsInferenceInput() { .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -198,14 +205,18 @@ public void executePredict_TextDocsInferenceInput() { MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); } @Test public void executePredict_TextDocsInferenceInput_withStepSize() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -225,7 +236,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -235,21 +246,29 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); MLInputDataset inputDataSet1 = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet1).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet1).build(), + actionListener + ); Mockito.verify(actionListener, times(0)).onFailure(any()); - Mockito.verify(executor, times(3)).preparePayloadAndInvokeRemoteModel(any(), any(), any()); + Mockito.verify(executor, times(3)).preparePayloadAndInvoke(anyString(), any(), any(), any()); } @Test public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResults() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -269,7 +288,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -277,17 +296,21 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); Integer idx = Integer.parseInt(doc.substring(doc.length() - 1)); actionListener.onResponse(new Tuple<>(3 - idx, new ModelTensors(modelTensors.subList(3 - idx, 4 - idx)))); return null; - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); Mockito.verify(actionListener, times(1)).onResponse(responseCaptor.capture()); @@ -305,7 +328,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_returnOrderedResu public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_thenFail() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -325,7 +348,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -333,8 +356,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); if (doc.endsWith("1")) { actionListener.onFailure(new OpenSearchStatusException("test failure", RestStatus.BAD_REQUEST)); @@ -342,11 +365,15 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t actionListener.onResponse(new Tuple<>(0, new ModelTensors(modelTensors.subList(0, 1)))); } return null; - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); @@ -358,7 +385,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_partiallyFailed_t public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleFailures() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -378,7 +405,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF .actions(Arrays.asList(predictAction)) .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -386,8 +413,8 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); doAnswer(invocation -> { - MLInput mlInput = invocation.getArgument(0); - ActionListener> actionListener = invocation.getArgument(4); + MLInput mlInput = invocation.getArgument(1); + ActionListener> actionListener = invocation.getArgument(5); String doc = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().get(0); if (!doc.endsWith("1")) { actionListener.onFailure(new OpenSearchStatusException("test failure", RestStatus.BAD_REQUEST)); @@ -395,11 +422,15 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF actionListener.onResponse(new Tuple<>(0, new ModelTensors(modelTensors.subList(0, 1)))); } return null; - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); @@ -414,7 +445,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize_failWithMultipleF public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException() throws NoSuchFieldException, IllegalAccessException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -431,7 +462,7 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor0 = new AwsConnectorExecutor(connector); Field httpClientField = AwsConnectorExecutor.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); @@ -444,7 +475,12 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( when(threadPool.getThreadContext()).thenReturn(threadContext); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); + executor + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); assert exceptionCaptor.getValue() instanceof NullPointerException; @@ -454,7 +490,7 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException( public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArgumentException() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -472,7 +508,7 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -482,7 +518,11 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); assert exceptionCaptor.getValue() instanceof IllegalArgumentException; @@ -492,7 +532,7 @@ public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArg public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictionAction() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -509,7 +549,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictio .parameters(parameters) .credential(credential) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -519,18 +559,22 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_emptyPredictio MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); ArgumentCaptor exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture()); assert exceptionArgumentCaptor.getValue() instanceof IllegalArgumentException; - assert "no predict action found".equals(exceptionArgumentCaptor.getValue().getMessage()); + assert "no PREDICT action found".equals(exceptionArgumentCaptor.getValue().getMessage()); } @Test public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPreProcessFunction() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -550,7 +594,7 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre .credential(credential) .actions(Arrays.asList(predictAction)) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -561,14 +605,18 @@ public void executePredict_TextDocsInferenceInput_withoutStepSize_userDefinedPre MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); } @Test - public void executePredict_whenRetryEnabled_thenInvokeRemoteModelWithRetry() { + public void executePredict_whenRetryEnabled_thenInvokeRemoteServiceWithRetry() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("{\"input\": ${parameters.input}}") @@ -590,7 +638,7 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteModelWithRetry() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(connectorClientConfig) .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); + connector.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); @@ -603,10 +651,14 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteModelWithRetry() { MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); - Mockito.verify(executor, times(0)).invokeRemoteModelWithRetry(any(), any(), any(), any(), any()); - Mockito.verify(executor, times(1)).invokeRemoteModel(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(0)).invokeRemoteServiceWithRetry(any(), any(), any(), any(), any(), any()); + Mockito.verify(executor, times(1)).invokeRemoteService(any(), any(), any(), any(), any(), any()); // execute with retry enabled ConnectorClientConfig connectorClientConfig2 = new ConnectorClientConfig(10, 10, 10, 1, 1, 1, RetryBackoffPolicy.CONSTANT); @@ -620,12 +672,16 @@ public void executePredict_whenRetryEnabled_thenInvokeRemoteModelWithRetry() { .actions(Arrays.asList(predictAction)) .connectorClientConfig(connectorClientConfig2) .build(); - connector2.decrypt((c) -> encryptor.decrypt(c)); + connector2.decrypt(PREDICT.name(), (c) -> encryptor.decrypt(c)); executor.initialize(connector2); executor - .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + .executeAction( + PREDICT.name(), + MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), + actionListener + ); - Mockito.verify(executor, times(1)).invokeRemoteModelWithRetry(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(1)).invokeRemoteServiceWithRetry(any(), any(), any(), any(), any(), any()); Mockito.verify(actionListener, times(0)).onFailure(any()); } @@ -659,7 +715,7 @@ public void testGetRetryBackoffPolicy() { } @Test - public void invokeRemoteModelWithRetry_whenRetryableException_thenRetryUntilSuccess() { + public void invokeRemoteServiceWithRetry_whenRetryableException_thenRetryUntilSuccess() { MLInput mlInput = mock(MLInput.class); Map parameters = Map.of(); String payload = ""; @@ -674,7 +730,7 @@ public void invokeRemoteModelWithRetry_whenRetryableException_thenRetryUntilSucc @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 10 invocation, then success if (countOfInvocation++ < 10) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); @@ -683,7 +739,7 @@ public Void answer(InvocationOnMock invocation) { } return null; } - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); when(executor.getConnectorClientConfig()).thenReturn(connectorClientConfig); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); @@ -699,14 +755,14 @@ public Void answer(InvocationOnMock invocation) { return null; }).when(executorService).execute(any()); - executor.invokeRemoteModelWithRetry(mlInput, parameters, payload, executionContext, actionListener); + executor.invokeRemoteServiceWithRetry(PREDICT.name(), mlInput, parameters, payload, executionContext, actionListener); Mockito.verify(actionListener, times(0)).onFailure(any()); Mockito.verify(actionListener, times(1)).onResponse(any()); - Mockito.verify(executor, times(11)).invokeRemoteModel(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(11)).invokeRemoteService(any(), any(), any(), any(), any(), any()); } @Test - public void invokeRemoteModelWithRetry_whenRetryExceedMaxRetryTimes_thenCallOnFailure() { + public void invokeRemoteServiceWithRetry_whenRetryExceedMaxRetryTimes_thenCallOnFailure() { MLInput mlInput = mock(MLInput.class); Map parameters = Map.of(); String payload = ""; @@ -721,7 +777,7 @@ public void invokeRemoteModelWithRetry_whenRetryExceedMaxRetryTimes_thenCallOnFa @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 10 invocation, then success if (countOfInvocation++ < 10) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); @@ -730,7 +786,7 @@ public Void answer(InvocationOnMock invocation) { } return null; } - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); when(executor.getConnectorClientConfig()).thenReturn(connectorClientConfig); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); @@ -746,14 +802,14 @@ public Void answer(InvocationOnMock invocation) { return null; }).when(executorService).execute(any()); - executor.invokeRemoteModelWithRetry(mlInput, parameters, payload, executionContext, actionListener); + executor.invokeRemoteServiceWithRetry(PREDICT.name(), mlInput, parameters, payload, executionContext, actionListener); Mockito.verify(actionListener, times(1)).onFailure(any()); Mockito.verify(actionListener, times(0)).onResponse(any()); - Mockito.verify(executor, times(6)).invokeRemoteModel(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(6)).invokeRemoteService(any(), any(), any(), any(), any(), any()); } @Test - public void invokeRemoteModelWithRetry_whenNonRetryableException_thenCallOnFailure() { + public void invokeRemoteServiceWithRetry_whenNonRetryableException_thenCallOnFailure() { MLInput mlInput = mock(MLInput.class); Map parameters = Map.of(); String payload = ""; @@ -768,7 +824,7 @@ public void invokeRemoteModelWithRetry_whenNonRetryableException_thenCallOnFailu @Override public Void answer(InvocationOnMock invocation) { - ActionListener> actionListener = invocation.getArgument(4); + ActionListener> actionListener = invocation.getArgument(5); // fail the first 2 invocation with retryable exception, then fail with non-retryable exception if (countOfInvocation++ < 2) { actionListener.onFailure(new RemoteConnectorThrottlingException("test failure retryable", RestStatus.BAD_REQUEST)); @@ -777,7 +833,7 @@ public Void answer(InvocationOnMock invocation) { } return null; } - }).when(executor).invokeRemoteModel(any(), any(), any(), any(), any()); + }).when(executor).invokeRemoteService(any(), any(), any(), any(), any(), any()); when(executor.getConnectorClientConfig()).thenReturn(connectorClientConfig); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); @@ -795,10 +851,10 @@ public Void answer(InvocationOnMock invocation) { ArgumentCaptor exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class); - executor.invokeRemoteModelWithRetry(mlInput, parameters, payload, executionContext, actionListener); + executor.invokeRemoteServiceWithRetry(PREDICT.name(), mlInput, parameters, payload, executionContext, actionListener); Mockito.verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture()); Mockito.verify(actionListener, times(0)).onResponse(any()); - Mockito.verify(executor, times(3)).invokeRemoteModel(any(), any(), any(), any(), any()); + Mockito.verify(executor, times(3)).invokeRemoteService(any(), any(), any(), any(), any(), any()); assert exceptionArgumentCaptor.getValue() instanceof OpenSearchStatusException; assertEquals("test failure", exceptionArgumentCaptor.getValue().getMessage()); assertEquals("test failure retryable", exceptionArgumentCaptor.getValue().getSuppressed()[0].getMessage()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index cb7f8a4fe8..31b0f5e420 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -7,6 +7,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; @@ -56,7 +57,7 @@ public void setUp() { public void processInput_NullInput() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Input is null"); - ConnectorUtils.processInput(null, null, new HashMap<>(), null); + ConnectorUtils.processInput(PREDICT.name(), null, null, new HashMap<>(), null); } @Test @@ -66,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -78,7 +79,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); } @Test @@ -120,7 +121,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -132,7 +133,7 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); + ConnectorUtils.processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); Assert.assertEquals(expectedInput, ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters().get("input")); } @@ -168,14 +169,14 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() public void processOutput_NullResponse() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model response is null"); - ConnectorUtils.processOutput(null, null, null, null, null); + ConnectorUtils.processOutput(PREDICT.name(), null, null, null, null, null); } @Test public void processOutput_NoPostprocessFunction_jsonResponse() throws IOException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -192,7 +193,8 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null); + ModelTensors tensors = ConnectorUtils + .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size()); @@ -206,7 +208,7 @@ public void processOutput_PostprocessFunction() throws IOException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -224,7 +226,8 @@ public void processOutput_PostprocessFunction() throws IOException { .build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; - ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null); + ModelTensors tensors = ConnectorUtils + .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName()); Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap()); @@ -246,7 +249,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction( ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody(requestBody) @@ -263,7 +266,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction( .actions(Arrays.asList(predictAction)) .build(); RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils - .processInput(mlInput, connector, new HashMap<>(), scriptService); + .processInput(PREDICT.name(), mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index f0efd2efc2..8f920ffeba 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import java.lang.reflect.Field; import java.util.Arrays; @@ -47,10 +48,10 @@ public void setUp() { } @Test - public void invokeRemoteModel_WrongHttpMethod() { + public void invokeRemoteService_WrongHttpMethod() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("wrong_method") .url("http://openai.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -63,17 +64,17 @@ public void invokeRemoteModel_WrongHttpMethod() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(null, null, null, null, actionListener); + executor.invokeRemoteService(PREDICT.name(), null, null, null, null, actionListener); ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); assertEquals("unsupported http method", captor.getValue().getMessage()); } @Test - public void invokeRemoteModel_invalidIpAddress() { + public void invokeRemoteService_invalidIpAddress() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://127.0.0.1/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -87,7 +88,14 @@ public void invokeRemoteModel_invalidIpAddress() { .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor - .invokeRemoteModel(createMLInput(), new HashMap<>(), "{\"input\": \"hello world\"}", new ExecutionContext(0), actionListener); + .invokeRemoteService( + PREDICT.name(), + createMLInput(), + new HashMap<>(), + "{\"input\": \"hello world\"}", + new ExecutionContext(0), + actionListener + ); ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof IllegalArgumentException; @@ -95,10 +103,10 @@ public void invokeRemoteModel_invalidIpAddress() { } @Test - public void invokeRemoteModel_Empty_payload() { + public void invokeRemoteService_Empty_payload() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("") @@ -111,7 +119,7 @@ public void invokeRemoteModel_Empty_payload() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); + executor.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); Mockito.verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof IllegalArgumentException; @@ -119,10 +127,10 @@ public void invokeRemoteModel_Empty_payload() { } @Test - public void invokeRemoteModel_get_request() { + public void invokeRemoteService_get_request() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("GET") .url("http://openai.com/mock") .requestBody("") @@ -135,14 +143,14 @@ public void invokeRemoteModel_get_request() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); + executor.invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), null, new ExecutionContext(0), actionListener); } @Test - public void invokeRemoteModel_post_request() { + public void invokeRemoteService_post_request() { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("hello world") @@ -155,14 +163,15 @@ public void invokeRemoteModel_post_request() { .actions(Arrays.asList(predictAction)) .build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); - executor.invokeRemoteModel(createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); + executor + .invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); } @Test - public void invokeRemoteModel_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException { + public void invokeRemoteService_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException { ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://openai.com/mock") .requestBody("hello world") @@ -178,7 +187,8 @@ public void invokeRemoteModel_nullHttpClient_throwMLException() throws NoSuchFie Field httpClientField = HttpJsonConnectorExecutor.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); httpClientField.set(executor, null); - executor.invokeRemoteModel(createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); + executor + .invokeRemoteService(PREDICT.name(), createMLInput(), new HashMap<>(), "hello world", new ExecutionContext(0), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(argumentCaptor.capture()); assert argumentCaptor.getValue() instanceof NullPointerException; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java index 11990e36d7..f6c9b76071 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.engine.algorithms.remote.MLSdkAsyncHttpResponseHandler.AMZ_ERROR_HEADER; import java.nio.ByteBuffer; @@ -59,6 +60,7 @@ public class MLSdkAsyncHttpResponseHandlerTest { private SdkHttpFullResponse sdkHttpResponse; @Mock private ScriptService scriptService; + private String action; private MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler; @@ -70,7 +72,7 @@ public void setup() { when(sdkHttpResponse.statusCode()).thenReturn(HttpStatusCode.OK); ConnectorAction predictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .postProcessFunction(MLPostProcessFunction.BEDROCK_EMBEDDING) .url("http://test.com/mock") @@ -86,7 +88,7 @@ public void setup() { ConnectorAction noProcessFunctionPredictAction = ConnectorAction .builder() - .actionType(ConnectorAction.ActionType.PREDICT) + .actionType(PREDICT) .method("POST") .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") @@ -98,13 +100,15 @@ public void setup() { .protocol("http") .actions(Arrays.asList(noProcessFunctionPredictAction)) .build(); + action = PREDICT.name(); mlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler( executionContext, actionListener, parameters, connector, scriptService, - null + null, + action ); responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber(); headersMap = Map.of(AMZ_ERROR_HEADER, Arrays.asList("ThrottlingException:request throttled!")); @@ -171,7 +175,8 @@ public void test_OnStream_without_postProcessFunction() { parameters, noProcessFunctionConnector, scriptService, - null + null, + action ); noProcessFunctionMlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); noProcessFunctionMlSdkAsyncHttpResponseHandler.onStream(stream); @@ -261,7 +266,8 @@ public void test_onComplete_failed() { parameters, connector, scriptService, - null + null, + action ); SdkHttpFullResponse sdkHttpResponse = mock(SdkHttpFullResponse.class); @@ -357,7 +363,8 @@ public void test_onComplete_throttle_exception_onFailure() { parameters, connector, scriptService, - null + null, + action ); SdkHttpFullResponse sdkHttpResponse = mock(SdkHttpFullResponse.class); @@ -397,7 +404,8 @@ public void test_onComplete_processOutputFail_onFailure() { parameters, testConnector, scriptService, - null + null, + action ); mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); @@ -414,6 +422,6 @@ public void test_onComplete_processOutputFail_onFailure() { ArgumentCaptor captor = ArgumentCaptor.forClass(MLException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor.getValue().getMessage().equals("Fail to execute predict in aws connector"); + assert captor.getValue().getMessage().equals("Fail to execute PREDICT in aws connector"); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java new file mode 100644 index 0000000000..05b3426f4d --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ConnectorToolTests.java @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +public class ConnectorToolTests { + + @Mock + private Client client; + private Map otherParams; + + @Mock + private Parser mockOutputParser; + + @Mock + private ActionListener listener; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + ConnectorTool.Factory.getInstance().init(client); + + otherParams = Map.of("other", "[\"bar\"]"); + } + + @Test + public void testConnectorTool_NullConnectorId() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any()); + + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> ConnectorTool.Factory.getInstance().create(Map.of("test1", "value1")) + ); + MatcherAssert.assertThat(exception.getMessage(), containsString("connector_id can't be null")); + } + + @Test + public void testConnectorTool_DefaultOutputParser() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any()); + + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test_connector")); + tool.run(null, ActionListener.wrap(r -> { assertEquals("response 1", r); }, e -> { throw new RuntimeException("Test failed"); })); + } + + @Test + public void testConnectorTool_NullOutputParser() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "response 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLExecuteConnectorAction.INSTANCE), any(), any()); + + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test_connector")); + tool.setOutputParser(null); + + tool.run(null, ActionListener.wrap(r -> { + List response = (List) r; + assertEquals(1, response.size()); + assertEquals(1, ((ModelTensors) response.get(0)).getMlModelTensors().size()); + ModelTensor modelTensor1 = ((ModelTensors) response.get(0)).getMlModelTensors().get(0); + assertEquals(2, modelTensor1.getDataAsMap().size()); + assertEquals("response 1", modelTensor1.getDataAsMap().get("response")); + assertEquals("action1", modelTensor1.getDataAsMap().get("action")); + }, e -> { throw new RuntimeException("Test failed"); })); + } + + @Test + public void testConnectorTool_NotNullParameters() { + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertTrue(tool.validate(Map.of("key1", "value1"))); + } + + @Test + public void testConnectorTool_NullParameters() { + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertFalse(tool.validate(Map.of())); + } + + @Test + public void testConnectorTool_EmptyParameters() { + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertFalse(tool.validate(null)); + } + + @Test + public void testConnectorTool_GetType() { + ConnectorTool.Factory.getInstance().init(client); + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertEquals("ConnectorTool", tool.getType()); + } + + @Test + public void testRunWithError() { + // Mocking the client.execute to simulate an error + String errorMessage = "Test Exception"; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException(errorMessage)); + return null; + }).when(client).execute(any(), any(), any()); + + // Running the test + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + tool.setOutputParser(mockOutputParser); + tool.run(otherParams, listener); + + // Verifying that onFailure was called + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); + } + + @Test + public void testTool() { + Tool tool = ConnectorTool.Factory.getInstance().create(Map.of("connector_id", "test1")); + assertEquals(ConnectorTool.TYPE, tool.getName()); + assertEquals(ConnectorTool.TYPE, tool.getType()); + assertNull(tool.getVersion()); + assertTrue(tool.validate(otherParams)); + assertEquals(ConnectorTool.Factory.DEFAULT_DESCRIPTION, tool.getDescription()); + assertEquals(ConnectorTool.Factory.DEFAULT_DESCRIPTION, ConnectorTool.Factory.getInstance().getDefaultDescription()); + assertEquals(ConnectorTool.TYPE, ConnectorTool.Factory.getInstance().getDefaultType()); + assertNull(ConnectorTool.Factory.getInstance().getDefaultVersion()); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java new file mode 100644 index 0000000000..497e6768d8 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; + +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.script.ScriptService; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ExecuteConnectorTransportAction extends HandledTransportAction { + + Client client; + ClusterService clusterService; + ScriptService scriptService; + NamedXContentRegistry xContentRegistry; + + ConnectorAccessControlHelper connectorAccessControlHelper; + EncryptorImpl encryptor; + + @Inject + public ExecuteConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + ConnectorAccessControlHelper connectorAccessControlHelper, + EncryptorImpl encryptor + ) { + super(MLExecuteConnectorAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new); + this.client = client; + this.clusterService = clusterService; + this.scriptService = scriptService; + this.xContentRegistry = xContentRegistry; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.encryptor = encryptor; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.fromActionRequest(request); + String connectorId = executeConnectorRequest.getConnectorId(); + String connectorAction = ConnectorAction.ActionType.EXECUTE.name(); + + if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener listener = ActionListener.wrap(connector -> { + if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { + connector.decrypt(connectorAction, (credential) -> encryptor.decrypt(credential)); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader + .initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor + .executeAction(connectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> { + actionListener.onResponse(taskResponse); + }, e -> { actionListener.onFailure(e); })); + } + }, e -> { + log.error("Failed to get connector " + connectorId, e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper.getConnector(client, connectorId, ActionListener.runBefore(listener, threadContext::restore)); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + connectorId)); + } + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java index b50f935774..9337653fc4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteTaskAction.java @@ -42,8 +42,8 @@ public TransportExecuteTaskAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - MLExecuteTaskRequest mlPredictionTaskRequest = MLExecuteTaskRequest.fromActionRequest(request); - FunctionName functionName = mlPredictionTaskRequest.getFunctionName(); - mlExecuteTaskRunner.run(functionName, mlPredictionTaskRequest, transportService, listener); + MLExecuteTaskRequest mlExecuteTaskRequest = MLExecuteTaskRequest.fromActionRequest(request); + FunctionName functionName = mlExecuteTaskRequest.getFunctionName(); + mlExecuteTaskRunner.run(functionName, mlExecuteTaskRequest, transportService, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 6c119d46d2..bde53795a3 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; @@ -300,7 +301,9 @@ private void validateInternalConnector(MLRegisterModelInput registerModelInput) log.error("You must provide connector content when creating a remote model without providing connector id!"); throw new IllegalArgumentException("You must provide connector content when creating a remote model without connector id!"); } - if (registerModelInput.getConnector().getPredictEndpoint(registerModelInput.getConnector().getParameters()) == null) { + if (registerModelInput + .getConnector() + .getActionEndpoint(PREDICT.name(), registerModelInput.getConnector().getParameters()) == null) { log.error("Connector endpoint is required when creating a remote model without connector id!"); throw new IllegalArgumentException("Connector endpoint is required when creating a remote model without connector id!"); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 89b812b613..e9a79236b1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -44,6 +44,7 @@ import org.opensearch.ml.action.agents.TransportSearchAgentAction; import org.opensearch.ml.action.config.GetConfigTransportAction; import org.opensearch.ml.action.connector.DeleteConnectorTransportAction; +import org.opensearch.ml.action.connector.ExecuteConnectorTransportAction; import org.opensearch.ml.action.connector.GetConnectorTransportAction; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; @@ -118,6 +119,7 @@ import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; import org.opensearch.ml.common.transport.controller.MLControllerGetAction; @@ -168,6 +170,7 @@ import org.opensearch.ml.engine.memory.MLMemoryManager; import org.opensearch.ml.engine.tools.AgentTool; import org.opensearch.ml.engine.tools.CatIndexTool; +import org.opensearch.ml.engine.tools.ConnectorTool; import org.opensearch.ml.engine.tools.IndexMappingTool; import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.ml.engine.tools.SearchIndexTool; @@ -398,6 +401,7 @@ public MachineLearningPlugin(Settings settings) { new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class), new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class), new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class), + new ActionHandler<>(MLExecuteConnectorAction.INSTANCE, ExecuteConnectorTransportAction.class), new ActionHandler<>(MLConnectorGetAction.INSTANCE, GetConnectorTransportAction.class), new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class), new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class), @@ -579,6 +583,7 @@ public Collection createComponents( IndexMappingTool.Factory.getInstance().init(client); SearchIndexTool.Factory.getInstance().init(client, xContentRegistry); VisualizationsTool.Factory.getInstance().init(client); + ConnectorTool.Factory.getInstance().init(client); toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance()); toolFactories.put(AgentTool.TYPE, AgentTool.Factory.getInstance()); @@ -586,6 +591,7 @@ public Collection createComponents( toolFactories.put(IndexMappingTool.TYPE, IndexMappingTool.Factory.getInstance()); toolFactories.put(SearchIndexTool.TYPE, SearchIndexTool.Factory.getInstance()); toolFactories.put(VisualizationsTool.TYPE, VisualizationsTool.Factory.getInstance()); + toolFactories.put(ConnectorTool.TYPE, ConnectorTool.Factory.getInstance()); if (externalToolFactories != null) { toolFactories.putAll(externalToolFactories); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java new file mode 100644 index 0000000000..719f168ead --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.action.connector; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorProtocols; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.script.ScriptService; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class ExecuteConnectorTransportActionTests extends OpenSearchTestCase { + + private ExecuteConnectorTransportAction action; + + @Mock + private Client client; + + @Mock + ActionListener actionListener; + @Mock + private ClusterService clusterService; + @Mock + private TransportService transportService; + @Mock + private ActionFilters actionFilters; + @Mock + private ScriptService scriptService; + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + private Metadata metaData; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + @Mock + private MLExecuteConnectorRequest request; + @Mock + private EncryptorImpl encryptor; + @Mock + private HttpConnector connector; + @Mock + private Task task; + @Mock + ThreadPool threadPool; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + ClusterState testState = new ClusterState( + new ClusterName("clusterName"), + 123l, + "111111", + metaData, + null, + null, + null, + Map.of(), + 0, + false + ); + when(clusterService.state()).thenReturn(testState); + + when(request.getConnectorId()).thenReturn("test_connector_id"); + + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + action = new ExecuteConnectorTransportAction( + transportService, + actionFilters, + client, + clusterService, + scriptService, + xContentRegistry, + connectorAccessControlHelper, + encryptor + ); + } + + public void testExecute_NoConnectorIndex() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assertTrue(argCaptor.getValue().getMessage().contains("Can't find connector test_connector_id")); + } + + public void testExecute_FailedToGetConnector() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + when(metaData.hasIndex(anyString())).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("test failure")); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assertTrue(argCaptor.getValue().getMessage().contains("test failure")); + } + + public void testExecute_NullMLInput() { + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + when(metaData.hasIndex(anyString())).thenReturn(true); + when(connector.getProtocol()).thenReturn(ConnectorProtocols.HTTP); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); + + action.doExecute(task, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index d30ef15a5a..d936f199a2 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -522,7 +522,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() { when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); Connector connector = mock(Connector.class); when(input.getConnector()).thenReturn(connector); - when(connector.getPredictEndpoint(any(Map.class))).thenReturn("https://api.openai.com"); + when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn("https://api.openai.com"); MLCreateConnectorResponse mlCreateConnectorResponse = mock(MLCreateConnectorResponse.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -556,7 +556,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi when(request.getRegisterModelInput()).thenReturn(input); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); Connector connector = mock(Connector.class); - when(connector.getPredictEndpoint(any(Map.class))).thenReturn(null); + when(connector.getActionEndpoint(anyString(), any(Map.class))).thenReturn(null); when(input.getConnector()).thenReturn(connector); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index f5da656d06..cf1f87e09e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -961,6 +961,13 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt assertTrue(taskDone.get()); } + public String registerConnector(String createConnectorInput) throws IOException, InterruptedException { + Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + return connectorId; + } + public String registerRemoteModel(String createConnectorInput, String modelName, boolean deploy) throws IOException, InterruptedException { Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java new file mode 100644 index 0000000000..4ae9653d60 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; + +import org.apache.hc.core5.http.ParseException; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.ResponseException; + +public class RestConnectorToolIT extends RestBaseAgentToolsIT { + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + private String bedrockClaudeConnectorId; + private String bedrockClaudeConnectorIdForPredict; + + @Before + public void setUp() throws Exception { + super.setUp(); + Thread.sleep(20000); + this.bedrockClaudeConnectorId = createBedrockClaudeConnector("execute"); + this.bedrockClaudeConnectorIdForPredict = createBedrockClaudeConnector("predict"); + } + + private String createBedrockClaudeConnector(String action) throws IOException, InterruptedException { + String bedrockClaudeConnectorEntity = "{\n" + + " \"name\": \"BedRock Claude instant-v1 Connector \",\n" + + " \"description\": \"The connector to BedRock service for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"" + + action + + "\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman:${parameters.question}\\\\n\\\\nAssistant:\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + return registerConnector(bedrockClaudeConnectorEntity); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + deleteExternalIndices(); + } + + public void testConnectorToolInFlowAgent_WrongAction() throws IOException, ParseException { + String registerAgentRequestBody = "{\n" + + " \"name\": \"Test agent with connector tool\",\n" + + " \"type\": \"flow\",\n" + + " \"description\": \"This is a demo agent for connector tool\",\n" + + " \"app_type\": \"test1\",\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"ConnectorTool\",\n" + + " \"name\": \"bedrock_model\",\n" + + " \"parameters\": {\n" + + " \"connector_id\": \"" + + bedrockClaudeConnectorIdForPredict + + "\",\n" + + " \"connector_action\": \"predict\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, agentInput)); + MatcherAssert.assertThat(exception.getMessage(), containsString("no EXECUTE action found")); + } + + public void testConnectorToolInFlowAgent() throws IOException, ParseException { + String registerAgentRequestBody = "{\n" + + " \"name\": \"Test agent with connector tool\",\n" + + " \"type\": \"flow\",\n" + + " \"description\": \"This is a demo agent for connector tool\",\n" + + " \"app_type\": \"test1\",\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"ConnectorTool\",\n" + + " \"name\": \"bedrock_model\",\n" + + " \"parameters\": {\n" + + " \"connector_id\": \"" + + bedrockClaudeConnectorId + + "\",\n" + + " \"connector_action\": \"execute\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; + String result = executeAgent(agentId, agentInput); + assertNotNull(result); + } + +}