Skip to content

Commit

Permalink
Added initial implementation of create connector
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 19, 2023
1 parent 9784f41 commit 4e7670f
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 6 deletions.
14 changes: 12 additions & 2 deletions src/main/java/org/opensearch/flowframework/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ private CommonValue() {}
public static final String MODEL_ID = "model_id";
/** Function Name field */
public static final String FUNCTION_NAME = "function_name";
/** Model Name field */
public static final String MODEL_NAME = "name";
/** Name field */
public static final String NAME_FIELD = "name";
/** Model Version field */
public static final String MODEL_VERSION = "model_version";
/** Model Group Id field */
Expand All @@ -62,4 +62,14 @@ private CommonValue() {}
public static final String MODEL_FORMAT = "model_format";
/** Model config field */
public static final String MODEL_CONFIG = "model_config";
/** Version field */
public static final String VERSION_FIELD = "version";
/** Connector protocol field */
public static final String PROTOCOL_FIELD = "protocol";
/** Connector parameters field */
public static final String PARAMETERS_FIELD = "parameters";
/** Connector credentials field */
public static final String CREDENTIALS_FIELD = "credentials";
/** Connector actions field */
public static final String ACTIONS_FIELD = "actions";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.*;

public class CreateConnectorStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class);

private Client client;

static final String NAME = "create_connector";

/**
* Instantiate this class
* @param client client to instantiate MLClient
*/
public CreateConnectorStep(Client client) {
this.client = client;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException {
CompletableFuture<WorkflowData> createConnectorFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLCreateConnectorResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) {
logger.info("Created connector successfully");
createConnectorFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("connector-id", mlCreateConnectorResponse.getConnectorId())))
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to create connector");
createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR));
}
};

String name = null;
String description = null;
String version = null;
String protocol = null;
Map<String, String> parameters = new HashMap<>();
Map<String, String> credentials = new HashMap<>();
List<ConnectorAction> actions = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) content.get(NAME_FIELD);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
case VERSION_FIELD:
version = (String) content.get(VERSION_FIELD);
break;
case PROTOCOL_FIELD:
protocol = (String) content.get(PROTOCOL_FIELD);
case PARAMETERS_FIELD:
parameters = getParameterMap((Map<String, String>) content.get(PARAMETERS_FIELD));
case CREDENTIALS_FIELD:
credentials = (Map<String, String>) content.get(CREDENTIALS_FIELD);
case ACTIONS_FIELD:
actions = (List<ConnectorAction>) content.get(ACTIONS_FIELD);
}

}
}

if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) {
MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder()
.name(name)
.description(description)
.version(version)
.protocol(protocol)
.parameters(parameters)
.credential(credentials)
.actions(actions)
.build();

machineLearningNodeClient.createConnector(mlInput, actionListener);
}

return createConnectorFuture;
}

@Override
public String getName() {
return NAME;
}

private static Map<String, String> getParameterMap(Map<String, String> params) {

Map<String, String> parameters = new HashMap<>();
for (String key : params.keySet()) {
String value = params.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
parameters.put(key, value);
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID;
import static org.opensearch.flowframework.common.CommonValue.MODEL_NAME;
import static org.opensearch.flowframework.common.CommonValue.MODEL_VERSION;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;

/**
* Step to register a remote model
Expand Down Expand Up @@ -80,7 +80,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
@Override
public void onFailure(Exception e) {
logger.error("Failed to register model");
registerModelFuture.completeExceptionally(new IOException("Failed to register model "));
registerModelFuture.completeExceptionally(new IOException("Failed to register model"));
}
};

Expand All @@ -101,8 +101,8 @@ public void onFailure(Exception e) {
case FUNCTION_NAME:
functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT));
break;
case MODEL_NAME:
modelName = (String) content.get(MODEL_NAME);
case NAME_FIELD:
modelName = (String) content.get(NAME_FIELD);
break;
case MODEL_VERSION:
modelVersion = (String) content.get(MODEL_VERSION);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

public class CreateConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private Client client;

@Mock
ActionListener<MLCreateConnectorResponse> registerModelActionListener;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
Map<String, String> credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2"));

MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "test"),
Map.entry("description", "description"),
Map.entry("version", "1"),
Map.entry("protocol", "test"),
Map.entry("params", params),
Map.entry("credentials", credentials),
Map.entry("actions", List.of("actions"))
)
);

}

public void testCreateConnector() throws IOException {

String connectorId = "connect";
CreateConnectorStep createConnectorStep = new CreateConnectorStep(client);

ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(2);
MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());

CompletableFuture<WorkflowData> future = createConnectorStep.execute(List.of(inputData));

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());

}

}

0 comments on commit 4e7670f

Please sign in to comment.