Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature/agent_framework] Add Delete Connector Step #211

Merged
7 changes: 6 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Spotless requires JDK 17+
- uses: actions/setup-java@v4
with:
java-version: 17
distribution: temurin
- name: Spotless Check
run: ./gradlew spotlessCheck
build:
Expand Down Expand Up @@ -41,7 +46,7 @@ jobs:
- uses: actions/checkout@v4
- name: Build and Run Tests
run: |
./gradlew check
./gradlew check -x spotlessJava
- name: Upload Coverage Report
if: matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v3
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ dependencies {
configurations.all {
resolutionStrategy {
force("com.google.guava:guava:32.1.3-jre") // CVE for 31.1
force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // CVE for < 3.29.0
force("com.fasterxml.jackson.core:jackson-core:2.16.0") // Dependency Jar Hell
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID;

/**
* Step to delete a connector for a remote model
*/
public class DeleteConnectorStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(DeleteConnectorStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "delete_connector";

/**
* Instantiate this class
* @param mlClient Machine Learning client to perform the deletion
*/
public DeleteConnectorStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) throws IOException {
CompletableFuture<WorkflowData> deleteConnectorFuture = new CompletableFuture<>();

ActionListener<DeleteResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteConnectorFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry("connector_id", deleteResponse.getId())),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to delete connector");
deleteConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String connectorId = null;

// Previous Node inputs defines which step the connector ID came from
Optional<String> previousNode = previousNodeInputs.entrySet()
.stream()
.filter(e -> CONNECTOR_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();
if (previousNode.isPresent()) {
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(CONNECTOR_ID)) {
connectorId = previousNodeOutput.getContent().get(CONNECTOR_ID).toString();
}
}

if (connectorId != null) {
mlClient.deleteConnector(connectorId, actionListener);
} else {
deleteConnectorFuture.completeExceptionally(
new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST)
);
}

return deleteConnectorFuture;
}

@Override
public String getName() {
return NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
public class WorkflowStepFactory {

private final Map<String, WorkflowStep> stepMap = new HashMap<>();
private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;

/**
* Instantiate this class.
Expand All @@ -42,17 +41,6 @@ public WorkflowStepFactory(
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler);
}

private void populateMap(
Settings settings,
ClusterService clusterService,
Client client,
MachineLearningNodeClient mlClient,
FlowFrameworkIndicesHandler flowFrameworkIndicesHandler
) {
stepMap.put(NoOpStep.NAME, new NoOpStep());
stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client));
Expand All @@ -61,6 +49,7 @@ private void populateMap(
stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient));
stepMap.put(GetMLTaskStep.NAME, new GetMLTaskStep(settings, clusterService, mlClient));
}
Expand All @@ -79,7 +68,7 @@ public WorkflowStep createStep(String type) {

/**
* Gets the step map
* @return the step map
* @return a read-only copy of the step map
*/
public Map<String, WorkflowStep> getStepMap() {
return Map.copyOf(this.stepMap);
Expand Down
8 changes: 8 additions & 0 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@
"connector_id"
]
},
"delete_connector": {
"inputs": [
"connector_id"
],
"outputs":[
"connector_id"
]
},
"register_local_model": {
"inputs":[
"name",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

Expand Down Expand Up @@ -81,15 +80,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr
String connectorId = "connect";
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());

CompletableFuture<WorkflowData> future = createConnectorStep.execute(
inputData.getNodeId(),
Expand All @@ -98,8 +94,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr
Collections.emptyMap()
);

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());
assertTrue(future.isDone());
assertEquals(connectorId, future.get().getContent().get("connector_id"));

Expand All @@ -108,14 +103,11 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr
public void testCreateConnectorFailure() throws IOException {
CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler);

@SuppressWarnings("unchecked")
ArgumentCaptor<ActionListener<MLCreateConnectorResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());
}).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());

CompletableFuture<WorkflowData> future = createConnectorStep.execute(
inputData.getNodeId(),
Expand All @@ -124,7 +116,7 @@ public void testCreateConnectorFailure() throws IOException {
Collections.emptyMap()
);

verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture());
verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

public class DeleteConnectorStepTests extends OpenSearchTestCase {
private WorkflowData inputData;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id", "test-node-id");
}

public void testDeleteConnector() throws IOException, ExecutionException, InterruptedException {

String connectorId = randomAlphaOfLength(5);
DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient);

doAnswer(invocation -> {
String connectorIdArg = invocation.getArgument(0);
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, connectorIdArg, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).deleteConnector(any(String.class), any());

CompletableFuture<WorkflowData> future = deleteConnectorStep.execute(
inputData.getNodeId(),
inputData,
Map.of("step_1", new WorkflowData(Map.of("connector_id", connectorId), "workflowId", "nodeId")),
Map.of("step_1", "connector_id")
);
verify(machineLearningNodeClient).deleteConnector(any(String.class), any());

assertTrue(future.isDone());
assertEquals(connectorId, future.get().getContent().get("connector_id"));
}

public void testNoConnectorIdInOutput() throws IOException {
DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient);

CompletableFuture<WorkflowData> future = deleteConnectorStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap()
);

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Required field connector_id is not provided", ex.getCause().getMessage());
}

public void testDeleteConnectorFailure() throws IOException {
DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient);

doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new FlowFrameworkException("Failed to delete connector", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).deleteConnector(any(String.class), any());

CompletableFuture<WorkflowData> future = deleteConnectorStep.execute(
inputData.getNodeId(),
inputData,
Map.of("step_1", new WorkflowData(Map.of("connector_id", "test"), "workflowId", "nodeId")),
Map.of("step_1", "connector_id")
);

verify(machineLearningNodeClient).deleteConnector(any(String.class), any());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Failed to delete connector", ex.getCause().getMessage());
}
}
Loading