Skip to content

Commit

Permalink
register agent rest and transport actions
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Dec 20, 2023
1 parent 9ba7c60 commit f623726
Show file tree
Hide file tree
Showing 3 changed files with 322 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.agents;

import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;

import java.time.Instant;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelGroupManager;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class TransportRegisterAgentAction extends HandledTransportAction<ActionRequest, MLRegisterAgentResponse> {
TransportService transportService;
ModelHelper modelHelper;
MLIndicesHandler mlIndicesHandler;
MLModelManager mlModelManager;
MLTaskManager mlTaskManager;
ClusterService clusterService;
ThreadPool threadPool;
Client client;
DiscoveryNodeHelper nodeFilter;
MLTaskDispatcher mlTaskDispatcher;
MLStats mlStats;
ModelAccessControlHelper modelAccessControlHelper;
ConnectorAccessControlHelper connectorAccessControlHelper;
MLModelGroupManager mlModelGroupManager;

@Inject
public TransportRegisterAgentAction(
TransportService transportService,
ActionFilters actionFilters,
ModelHelper modelHelper,
MLIndicesHandler mlIndicesHandler,
MLModelManager mlModelManager,
MLTaskManager mlTaskManager,
ClusterService clusterService,
Settings settings,
ThreadPool threadPool,
Client client,
DiscoveryNodeHelper nodeFilter,
MLTaskDispatcher mlTaskDispatcher,
MLStats mlStats,
ModelAccessControlHelper modelAccessControlHelper,
ConnectorAccessControlHelper connectorAccessControlHelper,
MLModelGroupManager mlModelGroupManager
) {
super(MLRegisterAgentAction.NAME, transportService, actionFilters, MLRegisterAgentRequest::new);
this.transportService = transportService;
this.modelHelper = modelHelper;
this.mlIndicesHandler = mlIndicesHandler;
this.mlModelManager = mlModelManager;
this.mlTaskManager = mlTaskManager;
this.clusterService = clusterService;
this.threadPool = threadPool;
this.client = client;
this.nodeFilter = nodeFilter;
this.mlTaskDispatcher = mlTaskDispatcher;
this.mlStats = mlStats;
this.modelAccessControlHelper = modelAccessControlHelper;
this.connectorAccessControlHelper = connectorAccessControlHelper;
this.mlModelGroupManager = mlModelGroupManager;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterAgentResponse> listener) {
User user = RestActionUtils.getUserContext(client);// TODO: check access
MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.fromActionRequest(request);
MLAgent mlAgent = registerAgentRequest.getMlAgent();
registerAgent(mlAgent, listener);
}

private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse> listener) {
Instant now = Instant.now();
MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).build();
mlIndicesHandler.initMLAgentIndex(ActionListener.wrap(result -> {
if (result) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
IndexRequest indexRequest = new IndexRequest(ML_AGENT_INDEX);
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlAgent.toXContent(builder, ToXContent.EMPTY_PARAMS);
indexRequest.source(builder);
client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
listener.onResponse(new MLRegisterAgentResponse(r.getId()));
}, e -> {
log.error("Failed to index ML agent", e);
listener.onFailure(e);
}), context::restore));
} catch (Exception e) {
log.error("Failed to index ML agent", e);
listener.onFailure(e);
}
} else {
log.error("Failed to create ML agent index");
}
}, e -> {
log.error("Failed to create ML agent index", e);
listener.onFailure(e);
}));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;

public class RestMLRegisterAgentAction extends BaseRestHandler {
private static final String ML_REGISTER_AGENT_ACTION = "ml_register_agent_action";

/**
* Constructor
*/
public RestMLRegisterAgentAction() {}

@Override
public String getName() {
return ML_REGISTER_AGENT_ACTION;
}

@Override
public List<Route> routes() {
return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/agents/_register", ML_BASE_URI)));
}

@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
MLRegisterAgentRequest registerAgentRequest = getRequest(request);
return channel -> client.execute(MLRegisterAgentAction.INSTANCE, registerAgentRequest, new RestToXContentListener<>(channel));
}

/**
* Creates a MLTrainingTaskRequest from a RestRequest
*
* @param request RestRequest
* @return MLTrainingTaskRequest
*/
@VisibleForTesting
MLRegisterAgentRequest getRequest(RestRequest request) throws IOException {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLAgent mlAgent = MLAgent.parse(parser);
return new MLRegisterAgentRequest(mlAgent);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.agent.LLMSpec.MODEL_ID_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.AGENT_NAME_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.AGENT_TYPE_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.CREATED_TIME_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.DESCRIPTION_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.LAST_UPDATED_TIME_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.LLM_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.MEMORY_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.MEMORY_ID_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.PARAMETERS_FIELD;
import static org.opensearch.ml.common.agent.MLAgent.TOOLS_FIELD;
import static org.opensearch.ml.common.agent.MLMemorySpec.MEMORY_TYPE_FIELD;
import static org.opensearch.ml.common.agent.MLMemorySpec.SESSION_ID_FIELD;
import static org.opensearch.ml.common.agent.MLMemorySpec.WINDOW_SIZE_FIELD;

import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;

import com.google.gson.Gson;

public class RestMLRegisterAgentActionTests extends OpenSearchTestCase {

Gson gson;
Instant time;

@Before
public void setup() {
gson = new Gson();
time = Instant.ofEpochMilli(123);
}

public void test_GetName_Routes() {
RestMLRegisterAgentAction action = new RestMLRegisterAgentAction();
assert (action.getName().equals("ml_register_agent_action"));
List<RestHandler.Route> routes = action.routes();
assert (routes.size() == 1);
assert (routes.get(0).equals(new RestHandler.Route(RestRequest.Method.POST, "/_plugins/_ml/agents/_register")));
}

public void testPrepareRequest() throws Exception {
RestMLRegisterAgentAction action = new RestMLRegisterAgentAction();
final Map<String, Object> llmSpec = Map.of(MODEL_ID_FIELD, "id", PARAMETERS_FIELD, new HashMap<>());
final Map<String, Object> memorySpec = Map.of(MEMORY_TYPE_FIELD, "conversation", SESSION_ID_FIELD, "sid", WINDOW_SIZE_FIELD, 2);
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withContent(
new BytesArray(
gson
.toJson(
Map
.of(
AGENT_NAME_FIELD,
"agent-name",
AGENT_TYPE_FIELD,
"agent-type",
DESCRIPTION_FIELD,
"description",
LLM_FIELD,
llmSpec,
TOOLS_FIELD,
new ArrayList<>(),
PARAMETERS_FIELD,
new HashMap<>(),
MEMORY_FIELD,
memorySpec,
MEMORY_ID_FIELD,
"memory_id",
CREATED_TIME_FIELD,
time.getEpochSecond(),
LAST_UPDATED_TIME_FIELD,
time.getEpochSecond()
)
)
),
MediaTypeRegistry.JSON
)
.build();

NodeClient client = mock(NodeClient.class);
RestChannel channel = mock(RestChannel.class);
action.handleRequest(request, channel, client);

ArgumentCaptor<MLRegisterAgentRequest> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentRequest.class);
verify(client, times(1)).execute(eq(MLRegisterAgentAction.INSTANCE), argumentCaptor.capture(), any());
assert (argumentCaptor.getValue().getMlAgent().getName().equals("agent-name"));
assert (argumentCaptor.getValue().getMlAgent().getType().equals("agent-type"));
assert (argumentCaptor.getValue().getMlAgent().getDescription().equals("description"));
assert (argumentCaptor.getValue().getMlAgent().getTools().equals(new ArrayList<>()));
assert (argumentCaptor.getValue().getMlAgent().getLlm().getModelId().equals("id"));
assert (argumentCaptor.getValue().getMlAgent().getParameters().equals(new HashMap<>()));
assert (argumentCaptor.getValue().getMlAgent().getMemory().getType().equals("conversation"));
}
}

0 comments on commit f623726

Please sign in to comment.