Skip to content

Commit

Permalink
Add connector tool (#2512)
Browse files Browse the repository at this point in the history
* expose connector action parameter

Signed-off-by: Yaliang Wu <[email protected]>

* add connector tool

Signed-off-by: Yaliang Wu <[email protected]>

* fix ut

Signed-off-by: Yaliang Wu <[email protected]>

* fix it

Signed-off-by: Yaliang Wu <[email protected]>

* fix flaky test

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Jun 6, 2024
1 parent 8331fe6 commit a0272f2
Show file tree
Hide file tree
Showing 30 changed files with 1,279 additions and 204 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public enum FunctionName {
SPARSE_TOKENIZE,
TEXT_SIMILARITY,
QUESTION_ANSWERING,
AGENT;
AGENT,
CONNECTOR;

public static FunctionName from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public abstract class AbstractConnector implements Connector {
@Setter
protected ConnectorClientConfig connectorClientConfig;

protected Map<String, String> createPredictDecryptedHeaders(Map<String, String> headers) {
protected Map<String, String> createDecryptedHeaders(Map<String, String> headers) {
if (headers == null) {
return null;
}
Expand Down Expand Up @@ -116,9 +116,9 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
}

@Override
public Optional<ConnectorAction> findPredictAction() {
public Optional<ConnectorAction> findAction(String action) {
if (actions != null) {
return actions.stream().filter(a -> a.getActionType() == ConnectorAction.ActionType.PREDICT).findFirst();
return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst();
}
return Optional.empty();
}
Expand All @@ -131,12 +131,12 @@ public void removeCredential() {
}

@Override
public String getPredictEndpoint(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (!predictAction.isPresent()) {
public String getActionEndpoint(String action, Map<String, String> parameters) {
Optional<ConnectorAction> actionEndpoint = findAction(action);
if (!actionEndpoint.isPresent()) {
return null;
}
String predictEndpoint = predictAction.get().getUrl();
String predictEndpoint = actionEndpoint.get().getUrl();
if (parameters != null && parameters.size() > 0) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
predictEndpoint = substitutor.replace(predictEndpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,18 @@ public interface Connector extends ToXContentObject, Writeable {

ConnectorClientConfig getConnectorClientConfig();

String getPredictEndpoint(Map<String, String> parameters);
String getActionEndpoint(String action, Map<String, String> parameters);

String getPredictHttpMethod();
String getActionHttpMethod(String action);

<T> T createPredictPayload(Map<String, String> parameters);
<T> T createPayload(String action, Map<String, String> parameters);

void decrypt(Function<String, String> function);
void decrypt(String action, Function<String, String> function);
void encrypt(Function<String, String> function);

Connector cloneConnector();

Optional<ConnectorAction> findPredictAction();
Optional<ConnectorAction> findAction(String action);

void removeCredential();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
}

public enum ActionType {
PREDICT
PREDICT,
EXECUTE
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.regex.Pattern;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP;
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
Expand Down Expand Up @@ -307,10 +308,10 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
}

@Override
public <T> T createPredictPayload(Map<String, String> parameters) {
Optional<ConnectorAction> predictAction = findPredictAction();
if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) {
String payload = predictAction.get().getRequestBody();
public <T> T createPayload(String action, Map<String, String> parameters) {
Optional<ConnectorAction> connectorAction = findAction(action);
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
String payload = connectorAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);
Expand Down Expand Up @@ -348,15 +349,15 @@ private List<String> findStringParametersWithNullDefaultValue(String input) {
}

@Override
public void decrypt(Function<String, String> function) {
public void decrypt(String action, Function<String, String> function) {
Map<String, String> decrypted = new HashMap<>();
for (String key : credential.keySet()) {
decrypted.put(key, function.apply(credential.get(key)));
}
this.decryptedCredential = decrypted;
Optional<ConnectorAction> predictAction = findPredictAction();
Map<String, String> headers = predictAction.isPresent() ? predictAction.get().getHeaders() : null;
this.decryptedHeaders = createPredictDecryptedHeaders(headers);
Optional<ConnectorAction> connectorAction = findAction(action);
Map<String, String> headers = connectorAction.isPresent() ? connectorAction.get().getHeaders() : null;
this.decryptedHeaders = createDecryptedHeaders(headers);
}

@Override
Expand All @@ -378,8 +379,9 @@ public void encrypt(Function<String, String> function) {
}
}

public String getPredictHttpMethod() {
return findPredictAction().get().getMethod();
@Override
public String getActionHttpMethod(String action) {
return findAction(action).get().getMethod();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.opensearch.action.ActionType;
import org.opensearch.ml.common.transport.MLTaskResponse;

public class MLExecuteConnectorAction extends ActionType<MLTaskResponse> {
public static final MLExecuteConnectorAction INSTANCE = new MLExecuteConnectorAction();
public static final String NAME = "cluster:admin/opensearch/ml/connectors/execute";

private MLExecuteConnectorAction() {
super(NAME, MLTaskResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskRequest;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(level = AccessLevel.PRIVATE)
@ToString
public class MLExecuteConnectorRequest extends MLTaskRequest {

String connectorId;
MLInput mlInput;

@Builder
public MLExecuteConnectorRequest(String connectorId, MLInput mlInput, boolean dispatchTask) {
super(dispatchTask);
this.mlInput = mlInput;
this.connectorId = connectorId;
}

public MLExecuteConnectorRequest(String connectorId, MLInput mlInput) {
this(connectorId, mlInput, true);
}

public MLExecuteConnectorRequest(StreamInput in) throws IOException {
super(in);
this.connectorId = in.readString();
this.mlInput = new MLInput(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.connectorId);
this.mlInput.writeTo(out);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (this.mlInput == null) {
exception = addValidationError("ML input can't be null", exception);
} else if (this.mlInput.getInputDataset() == null) {
exception = addValidationError("input data can't be null", exception);
}

return exception;
}


public static MLExecuteConnectorRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLExecuteConnectorRequest) {
return (MLExecuteConnectorRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLExecuteConnectorRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLPredictionTaskRequest", e);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;

Expand Down Expand Up @@ -110,7 +111,7 @@ public void constructor_NoPredictAction() {
Assert.assertNotNull(connector);

connector.encrypt(encryptFunction);
connector.decrypt(decryptFunction);
connector.decrypt(PREDICT.name(), decryptFunction);
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
Assert.assertEquals(null, connector.getSessionToken());
Expand Down Expand Up @@ -149,13 +150,13 @@ public void constructor() {

AwsConnector connector = createAwsConnector(parameters, credential, url);
connector.encrypt(encryptFunction);
connector.decrypt(decryptFunction);
connector.decrypt(PREDICT.name(), decryptFunction);
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken());
Assert.assertEquals("test_service", connector.getServiceName());
Assert.assertEquals("us-west-2", connector.getRegion());
Assert.assertEquals("https://test.com/model1", connector.getPredictEndpoint(parameters));
Assert.assertEquals("https://test.com/model1", connector.getActionEndpoint(PREDICT.name(), parameters));
}

@Test
Expand All @@ -170,13 +171,13 @@ public void constructor_NoParameter() {
String url = "https://test.com";
AwsConnector connector = createAwsConnector(null, credential, url);
connector.encrypt(encryptFunction);
connector.decrypt(decryptFunction);
connector.decrypt(PREDICT.name(), decryptFunction);
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SERVICE", connector.getServiceName());
Assert.assertEquals("decrypted: ENCRYPTED: US-WEST-2", connector.getRegion());
Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null));
Assert.assertEquals("https://test.com", connector.getActionEndpoint(PREDICT.name(), null));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;

public class HttpConnectorTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -119,7 +120,7 @@ public void cloneConnector() {
@Test
public void decrypt() {
HttpConnector connector = createHttpConnector();
connector.decrypt(decryptFunction);
connector.decrypt(PREDICT.name(), decryptFunction);
Map<String, String> decryptedCredential = connector.getDecryptedCredential();
Assert.assertEquals(1, decryptedCredential.size());
Assert.assertEquals("decrypted: TEST_KEY_VALUE", decryptedCredential.get("key"));
Expand Down Expand Up @@ -148,42 +149,42 @@ public void encrypted() {
}

@Test
public void getPredictEndpoint() {
public void getActionEndpoint() {
HttpConnector connector = createHttpConnector();
Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null));
Assert.assertEquals("https://test.com", connector.getActionEndpoint(PREDICT.name(), null));
}

@Test
public void getPredictHttpMethod() {
public void getActionHttpMethod() {
HttpConnector connector = createHttpConnector();
Assert.assertEquals("POST", connector.getPredictHttpMethod());
Assert.assertEquals("POST", connector.getActionHttpMethod(PREDICT.name()));
}

@Test
public void createPredictPayload_Invalid() {
public void createPayload_Invalid() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Some parameter placeholder not filled in payload: input");
HttpConnector connector = createHttpConnector();
String predictPayload = connector.createPredictPayload(null);
String predictPayload = connector.createPayload(PREDICT.name(), null);
connector.validatePayload(predictPayload);
}

@Test
public void createPredictPayload_InvalidJson() {
public void createPayload_InvalidJson() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Invalid payload: {\"input\": ${parameters.input} }");
String requestBody = "{\"input\": ${parameters.input} }";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
String predictPayload = connector.createPredictPayload(null);
String predictPayload = connector.createPayload(PREDICT.name(), null);
connector.validatePayload(predictPayload);
}

@Test
public void createPredictPayload() {
public void createPayload() {
HttpConnector connector = createHttpConnector();
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "test input value");
String predictPayload = connector.createPredictPayload(parameters);
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
Assert.assertEquals("{\"input\": \"test input value\"}", predictPayload);
}
Expand Down
Loading

0 comments on commit a0272f2

Please sign in to comment.