diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java new file mode 100644 index 0000000000..de7caed4dd --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agents; + +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; + +import java.time.Instant; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; +import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.ml.task.MLTaskDispatcher; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class TransportRegisterAgentAction extends HandledTransportAction { + TransportService transportService; + ModelHelper modelHelper; + MLIndicesHandler mlIndicesHandler; + MLModelManager mlModelManager; + MLTaskManager mlTaskManager; + ClusterService clusterService; + ThreadPool threadPool; + Client client; + DiscoveryNodeHelper nodeFilter; + MLTaskDispatcher mlTaskDispatcher; + MLStats mlStats; + ModelAccessControlHelper modelAccessControlHelper; + ConnectorAccessControlHelper connectorAccessControlHelper; + MLModelGroupManager mlModelGroupManager; + + @Inject + public TransportRegisterAgentAction( + TransportService transportService, + ActionFilters actionFilters, + ModelHelper modelHelper, + MLIndicesHandler mlIndicesHandler, + MLModelManager mlModelManager, + MLTaskManager mlTaskManager, + ClusterService clusterService, + Settings settings, + ThreadPool threadPool, + Client client, + DiscoveryNodeHelper nodeFilter, + MLTaskDispatcher mlTaskDispatcher, + MLStats mlStats, + ModelAccessControlHelper modelAccessControlHelper, + ConnectorAccessControlHelper connectorAccessControlHelper, + MLModelGroupManager mlModelGroupManager + ) { + super(MLRegisterAgentAction.NAME, transportService, actionFilters, MLRegisterAgentRequest::new); + this.transportService = transportService; + this.modelHelper = modelHelper; + this.mlIndicesHandler = mlIndicesHandler; + this.mlModelManager = mlModelManager; + this.mlTaskManager = mlTaskManager; + this.clusterService = clusterService; + this.threadPool = threadPool; + this.client = client; + this.nodeFilter = nodeFilter; + this.mlTaskDispatcher = mlTaskDispatcher; + this.mlStats = mlStats; + this.modelAccessControlHelper = modelAccessControlHelper; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlModelGroupManager = mlModelGroupManager; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + User user = RestActionUtils.getUserContext(client);// TODO: check access + MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.fromActionRequest(request); + MLAgent mlAgent = registerAgentRequest.getMlAgent(); + registerAgent(mlAgent, listener); + } + + private void registerAgent(MLAgent agent, ActionListener listener) { + Instant now = Instant.now(); + MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).build(); + mlIndicesHandler.initMLAgentIndex(ActionListener.wrap(result -> { + if (result) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + IndexRequest indexRequest = new IndexRequest(ML_AGENT_INDEX); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlAgent.toXContent(builder, ToXContent.EMPTY_PARAMS); + indexRequest.source(builder); + client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> { + listener.onResponse(new MLRegisterAgentResponse(r.getId())); + }, e -> { + log.error("Failed to index ML agent", e); + listener.onFailure(e); + }), context::restore)); + } catch (Exception e) { + log.error("Failed to index ML agent", e); + listener.onFailure(e); + } + } else { + log.error("Failed to create ML agent index"); + } + }, e -> { + log.error("Failed to create ML agent index", e); + listener.onFailure(e); + })); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java new file mode 100644 index 0000000000..068fbc0b72 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLRegisterAgentAction extends BaseRestHandler { + private static final String ML_REGISTER_AGENT_ACTION = "ml_register_agent_action"; + + /** + * Constructor + */ + public RestMLRegisterAgentAction() {} + + @Override + public String getName() { + return ML_REGISTER_AGENT_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/agents/_register", ML_BASE_URI))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLRegisterAgentRequest registerAgentRequest = getRequest(request); + return channel -> client.execute(MLRegisterAgentAction.INSTANCE, registerAgentRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLTrainingTaskRequest from a RestRequest + * + * @param request RestRequest + * @return MLTrainingTaskRequest + */ + @VisibleForTesting + MLRegisterAgentRequest getRequest(RestRequest request) throws IOException { + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLAgent mlAgent = MLAgent.parse(parser); + return new MLRegisterAgentRequest(mlAgent); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterAgentActionTests.java new file mode 100644 index 0000000000..f1a6958bb6 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterAgentActionTests.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.common.agent.LLMSpec.MODEL_ID_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.AGENT_NAME_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.AGENT_TYPE_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.CREATED_TIME_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.DESCRIPTION_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.LAST_UPDATED_TIME_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.LLM_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.MEMORY_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.MEMORY_ID_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.PARAMETERS_FIELD; +import static org.opensearch.ml.common.agent.MLAgent.TOOLS_FIELD; +import static org.opensearch.ml.common.agent.MLMemorySpec.MEMORY_TYPE_FIELD; +import static org.opensearch.ml.common.agent.MLMemorySpec.SESSION_ID_FIELD; +import static org.opensearch.ml.common.agent.MLMemorySpec.WINDOW_SIZE_FIELD; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; +import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.gson.Gson; + +public class RestMLRegisterAgentActionTests extends OpenSearchTestCase { + + Gson gson; + Instant time; + + @Before + public void setup() { + gson = new Gson(); + time = Instant.ofEpochMilli(123); + } + + public void test_GetName_Routes() { + RestMLRegisterAgentAction action = new RestMLRegisterAgentAction(); + assert (action.getName().equals("ml_register_agent_action")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new RestHandler.Route(RestRequest.Method.POST, "/_plugins/_ml/agents/_register"))); + } + + public void testPrepareRequest() throws Exception { + RestMLRegisterAgentAction action = new RestMLRegisterAgentAction(); + final Map llmSpec = Map.of(MODEL_ID_FIELD, "id", PARAMETERS_FIELD, new HashMap<>()); + final Map memorySpec = Map.of(MEMORY_TYPE_FIELD, "conversation", SESSION_ID_FIELD, "sid", WINDOW_SIZE_FIELD, 2); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray( + gson + .toJson( + Map + .of( + AGENT_NAME_FIELD, + "agent-name", + AGENT_TYPE_FIELD, + "agent-type", + DESCRIPTION_FIELD, + "description", + LLM_FIELD, + llmSpec, + TOOLS_FIELD, + new ArrayList<>(), + PARAMETERS_FIELD, + new HashMap<>(), + MEMORY_FIELD, + memorySpec, + MEMORY_ID_FIELD, + "memory_id", + CREATED_TIME_FIELD, + time.getEpochSecond(), + LAST_UPDATED_TIME_FIELD, + time.getEpochSecond() + ) + ) + ), + MediaTypeRegistry.JSON + ) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentRequest.class); + verify(client, times(1)).execute(eq(MLRegisterAgentAction.INSTANCE), argumentCaptor.capture(), any()); + assert (argumentCaptor.getValue().getMlAgent().getName().equals("agent-name")); + assert (argumentCaptor.getValue().getMlAgent().getType().equals("agent-type")); + assert (argumentCaptor.getValue().getMlAgent().getDescription().equals("description")); + assert (argumentCaptor.getValue().getMlAgent().getTools().equals(new ArrayList<>())); + assert (argumentCaptor.getValue().getMlAgent().getLlm().getModelId().equals("id")); + assert (argumentCaptor.getValue().getMlAgent().getParameters().equals(new HashMap<>())); + assert (argumentCaptor.getValue().getMlAgent().getMemory().getType().equals("conversation")); + } +}