From b200122128955b5238fcf688416b9434320bd657 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Thu, 21 Dec 2023 16:39:24 -0800 Subject: [PATCH] register agent rest and transport actions (#1795) * register agent rest and transport actions Signed-off-by: Xun Zhang * add the register agent action into the ml plugin Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang --- .../agents/TransportRegisterAgentAction.java | 91 +++++++++ .../ml/plugin/MachineLearningPlugin.java | 6 + .../ml/rest/RestMLRegisterAgentAction.java | 64 ++++++ .../RegisterAgentTransportActionTests.java | 184 ++++++++++++++++++ .../RegisterAgentTransportActionTests.java | 121 ++++++++++++ 5 files changed, 466 insertions(+) create mode 100644 plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RegisterAgentTransportActionTests.java diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java new file mode 100644 index 0000000000..fa87a0d1de --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agents; + +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; + +import java.time.Instant; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class TransportRegisterAgentAction extends HandledTransportAction { + MLIndicesHandler mlIndicesHandler; + Client client; + + @Inject + public TransportRegisterAgentAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + MLIndicesHandler mlIndicesHandler + ) { + super(MLRegisterAgentAction.NAME, transportService, actionFilters, MLRegisterAgentRequest::new); + this.client = client; + this.mlIndicesHandler = mlIndicesHandler; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + User user = RestActionUtils.getUserContext(client);// TODO: check access + MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.fromActionRequest(request); + MLAgent mlAgent = registerAgentRequest.getMlAgent(); + registerAgent(mlAgent, listener); + } + + private void registerAgent(MLAgent agent, ActionListener listener) { + Instant now = Instant.now(); + MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).build(); + mlIndicesHandler.initMLAgentIndex(ActionListener.wrap(result -> { + if (result) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + IndexRequest indexRequest = new IndexRequest(ML_AGENT_INDEX); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlAgent.toXContent(builder, ToXContent.EMPTY_PARAMS); + indexRequest.source(builder); + client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + listener.onResponse(new MLRegisterAgentResponse(r.getId())); + }, e -> { + log.error("Failed to index ML agent", e); + listener.onFailure(e); + }), context::restore)); + } catch (Exception e) { + log.error("Failed to index ML agent", e); + listener.onFailure(e); + } + } else { + log.error("Failed to create ML agent index"); + listener.onFailure(new OpenSearchException("Failed to create ML agent index")); + } + }, e -> { + log.error("Failed to create ML agent index", e); + listener.onFailure(e); + })); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 5fb925e077..f5ef7708b3 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -38,6 +38,7 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.ml.action.agents.DeleteAgentTransportAction; import org.opensearch.ml.action.agents.GetAgentTransportAction; +import org.opensearch.ml.action.agents.TransportRegisterAgentAction; import org.opensearch.ml.action.connector.DeleteConnectorTransportAction; import org.opensearch.ml.action.connector.GetConnectorTransportAction; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; @@ -95,6 +96,7 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; import org.opensearch.ml.common.transport.agent.MLAgentGetAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; @@ -181,6 +183,7 @@ import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLProfileAction; +import org.opensearch.ml.rest.RestMLRegisterAgentAction; import org.opensearch.ml.rest.RestMLRegisterModelAction; import org.opensearch.ml.rest.RestMLRegisterModelGroupAction; import org.opensearch.ml.rest.RestMLRegisterModelMetaAction; @@ -336,6 +339,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(GetInteractionsAction.INSTANCE, GetInteractionsTransportAction.class), new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class), new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class), + new ActionHandler<>(MLRegisterAgentAction.INSTANCE, TransportRegisterAgentAction.class), new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class), new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), @@ -576,6 +580,7 @@ public List getRestHandlers( settings, mlFeatureEnabledSetting ); + RestMLRegisterAgentAction restMLRegisterAgentAction = new RestMLRegisterAgentAction(); RestMLDeployModelAction restMLDeployModelAction = new RestMLDeployModelAction(); RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); @@ -621,6 +626,7 @@ public List getRestHandlers( restMLSearchTaskAction, restMLProfileAction, restMLRegisterModelAction, + restMLRegisterAgentAction, restMLDeployModelAction, restMLUndeployModelAction, restMLRegisterModelMetaAction, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java new file mode 100644 index 0000000000..068fbc0b72 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLRegisterAgentAction extends BaseRestHandler { + private static final String ML_REGISTER_AGENT_ACTION = "ml_register_agent_action"; + + /** + * Constructor + */ + public RestMLRegisterAgentAction() {} + + @Override + public String getName() { + return ML_REGISTER_AGENT_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/agents/_register", ML_BASE_URI))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLRegisterAgentRequest registerAgentRequest = getRequest(request); + return channel -> client.execute(MLRegisterAgentAction.INSTANCE, registerAgentRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLTrainingTaskRequest from a RestRequest + * + * @param request RestRequest + * @return MLTrainingTaskRequest + */ + @VisibleForTesting + MLRegisterAgentRequest getRequest(RestRequest request) throws IOException { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLAgent mlAgent = MLAgent.parse(parser); + return new MLRegisterAgentRequest(mlAgent); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java new file mode 100644 index 0000000000..592bca752d --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java @@ -0,0 +1,184 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agents; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.HashMap; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchException; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class RegisterAgentTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private MLIndicesHandler mlIndicesHandler; + + @Mock + private ActionFilters actionFilters; + + @Mock + private TransportService transportService; + + @Mock + private Task task; + + @Mock + private ActionListener actionListener; + + @Mock + private ThreadPool threadPool; + + private TransportRegisterAgentAction transportRegisterAgentAction; + private Settings settings; + private ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + transportRegisterAgentAction = new TransportRegisterAgentAction(transportService, actionFilters, client, mlIndicesHandler); + } + + public void test_execute_registerAgent_success() { + MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + MLAgent mlAgent = MLAgent + .builder() + .name("agent") + .type("some type") + .description("description") + .llm(new LLMSpec("model_id", new HashMap<>())) + .build(); + when(request.getMlAgent()).thenReturn(mlAgent); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLAgentIndex(any()); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + IndexResponse indexResponse = new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true); + al.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + + transportRegisterAgentAction.doExecute(task, request, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void test_execute_registerAgent_AgentIndexNotInitialized() { + MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + MLAgent mlAgent = MLAgent + .builder() + .name("agent") + .type("some type") + .description("description") + .llm(new LLMSpec("model_id", new HashMap<>())) + .build(); + when(request.getMlAgent()).thenReturn(mlAgent); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(false); + return null; + }).when(mlIndicesHandler).initMLAgentIndex(any()); + + transportRegisterAgentAction.doExecute(task, request, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to create ML agent index", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_registerAgent_IndexFailure() { + MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + MLAgent mlAgent = MLAgent + .builder() + .name("agent") + .type("some type") + .description("description") + .llm(new LLMSpec("model_id", new HashMap<>())) + .build(); + when(request.getMlAgent()).thenReturn(mlAgent); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLAgentIndex(any()); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new RuntimeException("index failure")); + return null; + }).when(client).index(any(), any()); + + transportRegisterAgentAction.doExecute(task, request, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + + assertEquals("index failure", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_registerAgent_InitAgentIndexFailure() { + MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); + MLAgent mlAgent = MLAgent + .builder() + .name("agent") + .type("some type") + .description("description") + .llm(new LLMSpec("model_id", new HashMap<>())) + .build(); + when(request.getMlAgent()).thenReturn(mlAgent); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("agent index initialization failed")); + return null; + }).when(mlIndicesHandler).initMLAgentIndex(any()); + + transportRegisterAgentAction.doExecute(task, request, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("agent index initialization failed", argumentCaptor.getValue().getMessage()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RegisterAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RegisterAgentTransportActionTests.java new file mode 100644 index 0000000000..b8fb97fb85 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RegisterAgentTransportActionTests.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.agent.LLMSpec.MODEL_ID_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.AGENT_NAME_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.AGENT_TYPE_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.CREATED_TIME_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.DESCRIPTION_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.LAST_UPDATED_TIME_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.LLM_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.MEMORY_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.MEMORY_ID_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.PARAMETERS_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.TOOLS_FIELD; +import static org.opensearch.ml.common.agent.MLMemorySpec.MEMORY_TYPE_FIELD; +import static org.opensearch.ml.common.agent.MLMemorySpec.SESSION_ID_FIELD; +import static org.opensearch.ml.common.agent.MLMemorySpec.WINDOW_SIZE_FIELD; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.gson.Gson; + +public class RegisterAgentTransportActionTests extends OpenSearchTestCase { + + Gson gson; + Instant time; + + @Before + public void setup() { + gson = new Gson(); + time = Instant.ofEpochMilli(123); + } + + public void test_GetName_Routes() { + RestMLRegisterAgentAction action = new RestMLRegisterAgentAction(); + assert (action.getName().equals("ml_register_agent_action")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new RestHandler.Route(RestRequest.Method.POST, "/_plugins/_ml/agents/_register"))); + } + + public void testPrepareRequest() throws Exception { + RestMLRegisterAgentAction action = new RestMLRegisterAgentAction(); + final Map llmSpec = Map.of(MODEL_ID_FIELD, "id", PARAMETERS_FIELD, new HashMap<>()); + final Map memorySpec = Map.of(MEMORY_TYPE_FIELD, "conversation", SESSION_ID_FIELD, "sid", WINDOW_SIZE_FIELD, 2); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray( + gson + .toJson( + Map + .of( + AGENT_NAME_FIELD, + "agent-name", + AGENT_TYPE_FIELD, + "agent-type", + DESCRIPTION_FIELD, + "description", + LLM_FIELD, + llmSpec, + TOOLS_FIELD, + new ArrayList<>(), + PARAMETERS_FIELD, + new HashMap<>(), + MEMORY_FIELD, + memorySpec, + MEMORY_ID_FIELD, + "memory_id", + CREATED_TIME_FIELD, + time.getEpochSecond(), + LAST_UPDATED_TIME_FIELD, + time.getEpochSecond() + ) + ) + ), + MediaTypeRegistry.JSON + ) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentRequest.class); + verify(client, times(1)).execute(eq(MLRegisterAgentAction.INSTANCE), argumentCaptor.capture(), any()); + assert (argumentCaptor.getValue().getMlAgent().getName().equals("agent-name")); + assert (argumentCaptor.getValue().getMlAgent().getType().equals("agent-type")); + assert (argumentCaptor.getValue().getMlAgent().getDescription().equals("description")); + assert (argumentCaptor.getValue().getMlAgent().getTools().equals(new ArrayList<>())); + assert (argumentCaptor.getValue().getMlAgent().getLlm().getModelId().equals("id")); + assert (argumentCaptor.getValue().getMlAgent().getParameters().equals(new HashMap<>())); + assert (argumentCaptor.getValue().getMlAgent().getMemory().getType().equals("conversation")); + } +}