Skip to content

Commit

Permalink
Add Delete Agent to MLClient (opensearch-project#1731)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis authored Dec 4, 2023
1 parent d3f222d commit 8bda65a
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -392,4 +392,17 @@ default ActionFuture<MLRegisterAgentResponse> registerAgent(MLAgent mlAgent) {
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
*/
void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener);

/**
* Delete agent
* @param agentId The id of the agent to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, actionFuture);
return actionFuture;
}

void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
Expand Down Expand Up @@ -282,6 +284,14 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
client.execute(MLRegisterAgentAction.INSTANCE, mlRegisterAgentRequest, getMLRegisterAgentResponseActionListener(listener));
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
listener.onResponse(registerAgentResponse);
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}
};
}

Expand Down Expand Up @@ -405,4 +410,9 @@ public void testRegisterAgent() {
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet());
}

@Test
public void deleteAgent() {
assertEquals(deleteResponse, machineLearningClient.deleteAgent("agentId").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
Expand Down Expand Up @@ -174,6 +176,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<MLRegisterAgentResponse> registerAgentResponseActionListener;

@Mock
ActionListener<DeleteResponse> deleteAgentActionListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -768,6 +773,28 @@ public void testRegisterAgent() {
assertEquals(agentId, (argumentCaptor.getValue()).getAgentId());
}

@Test
public void deleteAgent() {

String agentId = "agentId";

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, agentId, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLAgentDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);

machineLearningNodeClient.deleteAgent(agentId, deleteAgentActionListener);

verify(client).execute(eq(MLAgentDeleteAction.INSTANCE), isA(MLAgentDeleteRequest.class), any());
verify(deleteAgentActionListener).onResponse(argumentCaptor.capture());
assertEquals(agentId, (argumentCaptor.getValue()).getId());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down

0 comments on commit 8bda65a

Please sign in to comment.