-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added initial implementation of create connector
Signed-off-by: Owais Kazi <[email protected]>
- Loading branch information
1 parent
9784f41
commit 4e7670f
Showing
4 changed files
with
252 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
149 changes: 149 additions & 0 deletions
149
src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.flowframework.client.MLClient; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.connector.ConnectorAction; | ||
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; | ||
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; | ||
|
||
import java.io.IOException; | ||
import java.security.AccessController; | ||
import java.security.PrivilegedActionException; | ||
import java.security.PrivilegedExceptionAction; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Map.Entry; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.stream.Stream; | ||
|
||
import static org.opensearch.flowframework.common.CommonValue.*; | ||
|
||
public class CreateConnectorStep implements WorkflowStep { | ||
|
||
private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); | ||
|
||
private Client client; | ||
|
||
static final String NAME = "create_connector"; | ||
|
||
/** | ||
* Instantiate this class | ||
* @param client client to instantiate MLClient | ||
*/ | ||
public CreateConnectorStep(Client client) { | ||
this.client = client; | ||
} | ||
|
||
@Override | ||
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) throws IOException { | ||
CompletableFuture<WorkflowData> createConnectorFuture = new CompletableFuture<>(); | ||
|
||
MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); | ||
|
||
ActionListener<MLCreateConnectorResponse> actionListener = new ActionListener<>() { | ||
|
||
@Override | ||
public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { | ||
logger.info("Created connector successfully"); | ||
createConnectorFuture.complete( | ||
new WorkflowData(Map.ofEntries(Map.entry("connector-id", mlCreateConnectorResponse.getConnectorId()))) | ||
); | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
logger.error("Failed to create connector"); | ||
createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); | ||
} | ||
}; | ||
|
||
String name = null; | ||
String description = null; | ||
String version = null; | ||
String protocol = null; | ||
Map<String, String> parameters = new HashMap<>(); | ||
Map<String, String> credentials = new HashMap<>(); | ||
List<ConnectorAction> actions = null; | ||
|
||
for (WorkflowData workflowData : data) { | ||
Map<String, Object> content = workflowData.getContent(); | ||
|
||
for (Entry<String, Object> entry : content.entrySet()) { | ||
switch (entry.getKey()) { | ||
case NAME_FIELD: | ||
name = (String) content.get(NAME_FIELD); | ||
break; | ||
case DESCRIPTION: | ||
description = (String) content.get(DESCRIPTION); | ||
break; | ||
case VERSION_FIELD: | ||
version = (String) content.get(VERSION_FIELD); | ||
break; | ||
case PROTOCOL_FIELD: | ||
protocol = (String) content.get(PROTOCOL_FIELD); | ||
case PARAMETERS_FIELD: | ||
parameters = getParameterMap((Map<String, String>) content.get(PARAMETERS_FIELD)); | ||
case CREDENTIALS_FIELD: | ||
credentials = (Map<String, String>) content.get(CREDENTIALS_FIELD); | ||
case ACTIONS_FIELD: | ||
actions = (List<ConnectorAction>) content.get(ACTIONS_FIELD); | ||
} | ||
|
||
} | ||
} | ||
|
||
if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) { | ||
MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder() | ||
.name(name) | ||
.description(description) | ||
.version(version) | ||
.protocol(protocol) | ||
.parameters(parameters) | ||
.credential(credentials) | ||
.actions(actions) | ||
.build(); | ||
|
||
machineLearningNodeClient.createConnector(mlInput, actionListener); | ||
} | ||
|
||
return createConnectorFuture; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
|
||
private static Map<String, String> getParameterMap(Map<String, String> params) { | ||
|
||
Map<String, String> parameters = new HashMap<>(); | ||
for (String key : params.keySet()) { | ||
String value = params.get(key); | ||
try { | ||
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> { | ||
parameters.put(key, value); | ||
return null; | ||
}); | ||
} catch (PrivilegedActionException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
return parameters; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.opensearch.client.Client; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; | ||
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; | ||
import org.opensearch.test.OpenSearchTestCase; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
import org.mockito.Answers; | ||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.Mock; | ||
import org.mockito.MockitoAnnotations; | ||
|
||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.Mockito.doAnswer; | ||
import static org.mockito.Mockito.verify; | ||
|
||
public class CreateConnectorStepTests extends OpenSearchTestCase { | ||
private WorkflowData inputData = WorkflowData.EMPTY; | ||
|
||
@Mock(answer = Answers.RETURNS_DEEP_STUBS) | ||
private Client client; | ||
|
||
@Mock | ||
ActionListener<MLCreateConnectorResponse> registerModelActionListener; | ||
|
||
@Mock | ||
MachineLearningNodeClient machineLearningNodeClient; | ||
|
||
@Override | ||
public void setUp() throws Exception { | ||
super.setUp(); | ||
|
||
Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); | ||
Map<String, String> credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); | ||
|
||
MockitoAnnotations.openMocks(this); | ||
|
||
inputData = new WorkflowData( | ||
Map.ofEntries( | ||
Map.entry("name", "test"), | ||
Map.entry("description", "description"), | ||
Map.entry("version", "1"), | ||
Map.entry("protocol", "test"), | ||
Map.entry("params", params), | ||
Map.entry("credentials", credentials), | ||
Map.entry("actions", List.of("actions")) | ||
) | ||
); | ||
|
||
} | ||
|
||
public void testCreateConnector() throws IOException { | ||
|
||
String connectorId = "connect"; | ||
CreateConnectorStep createConnectorStep = new CreateConnectorStep(client); | ||
|
||
ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(2); | ||
MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId); | ||
actionListener.onResponse(output); | ||
return null; | ||
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); | ||
|
||
CompletableFuture<WorkflowData> future = createConnectorStep.execute(List.of(inputData)); | ||
|
||
verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); | ||
|
||
} | ||
|
||
} |