Skip to content

Commit

Permalink
register agent rest and transport actions (opensearch-project#1795)
Browse files Browse the repository at this point in the history
* register agent rest and transport actions

Signed-off-by: Xun Zhang <[email protected]>

* add the register agent action into the ml plugin

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Dec 22, 2023
1 parent 9ba7c60 commit b200122
Show file tree
Hide file tree
Showing 5 changed files with 466 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.agents;

import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;

import java.time.Instant;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class TransportRegisterAgentAction extends HandledTransportAction<ActionRequest, MLRegisterAgentResponse> {
MLIndicesHandler mlIndicesHandler;
Client client;

@Inject
public TransportRegisterAgentAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
MLIndicesHandler mlIndicesHandler
) {
super(MLRegisterAgentAction.NAME, transportService, actionFilters, MLRegisterAgentRequest::new);
this.client = client;
this.mlIndicesHandler = mlIndicesHandler;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterAgentResponse> listener) {
User user = RestActionUtils.getUserContext(client);// TODO: check access
MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.fromActionRequest(request);
MLAgent mlAgent = registerAgentRequest.getMlAgent();
registerAgent(mlAgent, listener);
}

private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse> listener) {
Instant now = Instant.now();
MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).build();
mlIndicesHandler.initMLAgentIndex(ActionListener.wrap(result -> {
if (result) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
IndexRequest indexRequest = new IndexRequest(ML_AGENT_INDEX);
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlAgent.toXContent(builder, ToXContent.EMPTY_PARAMS);
indexRequest.source(builder);
client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
listener.onResponse(new MLRegisterAgentResponse(r.getId()));
}, e -> {
log.error("Failed to index ML agent", e);
listener.onFailure(e);
}), context::restore));
} catch (Exception e) {
log.error("Failed to index ML agent", e);
listener.onFailure(e);
}
} else {
log.error("Failed to create ML agent index");
listener.onFailure(new OpenSearchException("Failed to create ML agent index"));
}
}, e -> {
log.error("Failed to create ML agent index", e);
listener.onFailure(e);
}));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.env.NodeEnvironment;
import org.opensearch.ml.action.agents.DeleteAgentTransportAction;
import org.opensearch.ml.action.agents.GetAgentTransportAction;
import org.opensearch.ml.action.agents.TransportRegisterAgentAction;
import org.opensearch.ml.action.connector.DeleteConnectorTransportAction;
import org.opensearch.ml.action.connector.GetConnectorTransportAction;
import org.opensearch.ml.action.connector.SearchConnectorTransportAction;
Expand Down Expand Up @@ -95,6 +96,7 @@
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction;
import org.opensearch.ml.common.transport.agent.MLAgentGetAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorGetAction;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
Expand Down Expand Up @@ -181,6 +183,7 @@
import org.opensearch.ml.rest.RestMLGetTaskAction;
import org.opensearch.ml.rest.RestMLPredictionAction;
import org.opensearch.ml.rest.RestMLProfileAction;
import org.opensearch.ml.rest.RestMLRegisterAgentAction;
import org.opensearch.ml.rest.RestMLRegisterModelAction;
import org.opensearch.ml.rest.RestMLRegisterModelGroupAction;
import org.opensearch.ml.rest.RestMLRegisterModelMetaAction;
Expand Down Expand Up @@ -336,6 +339,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc
new ActionHandler<>(GetInteractionsAction.INSTANCE, GetInteractionsTransportAction.class),
new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class),
new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class),
new ActionHandler<>(MLRegisterAgentAction.INSTANCE, TransportRegisterAgentAction.class),
new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class),
new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class),
new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class),
Expand Down Expand Up @@ -576,6 +580,7 @@ public List<RestHandler> getRestHandlers(
settings,
mlFeatureEnabledSetting
);
RestMLRegisterAgentAction restMLRegisterAgentAction = new RestMLRegisterAgentAction();
RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction();
RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings);
Expand Down Expand Up @@ -621,6 +626,7 @@ public List<RestHandler> getRestHandlers(
restMLSearchTaskAction,
restMLProfileAction,
restMLRegisterModelAction,
restMLRegisterAgentAction,
restMLDeployModelAction,
restMLUndeployModelAction,
restMLRegisterModelMetaAction,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

public class RestMLRegisterAgentAction extends BaseRestHandler {
private static final String ML_REGISTER_AGENT_ACTION = "ml_register_agent_action";

/**
* Constructor
*/
public RestMLRegisterAgentAction() {}

@Override
public String getName() {
return ML_REGISTER_AGENT_ACTION;
}

@Override
public List<Route> routes() {
return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/agents/_register", ML_BASE_URI)));
}

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
MLRegisterAgentRequest registerAgentRequest = getRequest(request);
return channel -> client.execute(MLRegisterAgentAction.INSTANCE, registerAgentRequest, new RestToXContentListener<>(channel));
}

/**
* Creates a MLTrainingTaskRequest from a RestRequest
*
* @param request RestRequest
* @return MLTrainingTaskRequest
*/
@VisibleForTesting
MLRegisterAgentRequest getRequest(RestRequest request) throws IOException {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLAgent mlAgent = MLAgent.parse(parser);
return new MLRegisterAgentRequest(mlAgent);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.agents;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.HashMap;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchException;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.ConfigConstants;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class RegisterAgentTransportActionTests extends OpenSearchTestCase {

@Mock
private Client client;

@Mock
private MLIndicesHandler mlIndicesHandler;

@Mock
private ActionFilters actionFilters;

@Mock
private TransportService transportService;

@Mock
private Task task;

@Mock
private ActionListener<MLRegisterAgentResponse> actionListener;

@Mock
private ThreadPool threadPool;

private TransportRegisterAgentAction transportRegisterAgentAction;
private Settings settings;
private ThreadContext threadContext;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
settings = Settings.builder().build();
threadContext = new ThreadContext(settings);

threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations");

when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
transportRegisterAgentAction = new TransportRegisterAgentAction(transportService, actionFilters, client, mlIndicesHandler);
}

public void test_execute_registerAgent_success() {
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class);
MLAgent mlAgent = MLAgent
.builder()
.name("agent")
.type("some type")
.description("description")
.llm(new LLMSpec("model_id", new HashMap<>()))
.build();
when(request.getMlAgent()).thenReturn(mlAgent);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLAgentIndex(any());

doAnswer(invocation -> {
ActionListener<IndexResponse> al = invocation.getArgument(1);
IndexResponse indexResponse = new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true);
al.onResponse(indexResponse);
return null;
}).when(client).index(any(), any());

transportRegisterAgentAction.doExecute(task, request, actionListener);
ArgumentCaptor<MLRegisterAgentResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void test_execute_registerAgent_AgentIndexNotInitialized() {
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class);
MLAgent mlAgent = MLAgent
.builder()
.name("agent")
.type("some type")
.description("description")
.llm(new LLMSpec("model_id", new HashMap<>()))
.build();
when(request.getMlAgent()).thenReturn(mlAgent);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(false);
return null;
}).when(mlIndicesHandler).initMLAgentIndex(any());

transportRegisterAgentAction.doExecute(task, request, actionListener);
ArgumentCaptor<OpenSearchException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to create ML agent index", argumentCaptor.getValue().getMessage());
}

public void test_execute_registerAgent_IndexFailure() {
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class);
MLAgent mlAgent = MLAgent
.builder()
.name("agent")
.type("some type")
.description("description")
.llm(new LLMSpec("model_id", new HashMap<>()))
.build();
when(request.getMlAgent()).thenReturn(mlAgent);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLAgentIndex(any());

doAnswer(invocation -> {
ActionListener<IndexResponse> al = invocation.getArgument(1);
al.onFailure(new RuntimeException("index failure"));
return null;
}).when(client).index(any(), any());

transportRegisterAgentAction.doExecute(task, request, actionListener);
ArgumentCaptor<RuntimeException> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class);
verify(actionListener).onFailure(argumentCaptor.capture());

assertEquals("index failure", argumentCaptor.getValue().getMessage());
}

public void test_execute_registerAgent_InitAgentIndexFailure() {
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class);
MLAgent mlAgent = MLAgent
.builder()
.name("agent")
.type("some type")
.description("description")
.llm(new LLMSpec("model_id", new HashMap<>()))
.build();
when(request.getMlAgent()).thenReturn(mlAgent);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onFailure(new RuntimeException("agent index initialization failed"));
return null;
}).when(mlIndicesHandler).initMLAgentIndex(any());

transportRegisterAgentAction.doExecute(task, request, actionListener);
ArgumentCaptor<RuntimeException> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("agent index initialization failed", argumentCaptor.getValue().getMessage());
}
}
Loading

0 comments on commit b200122

Please sign in to comment.