forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
register agent rest and transport actions
Signed-off-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
9ba7c60
commit f623726
Showing
3 changed files
with
322 additions
and
0 deletions.
There are no files selected for viewing
137 changes: 137 additions & 0 deletions
137
plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.agents; | ||
|
||
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; | ||
|
||
import java.time.Instant; | ||
|
||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.index.IndexRequest; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.commons.authuser.User; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.ToXContent; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.ml.cluster.DiscoveryNodeHelper; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; | ||
import org.opensearch.ml.engine.ModelHelper; | ||
import org.opensearch.ml.engine.indices.MLIndicesHandler; | ||
import org.opensearch.ml.helper.ConnectorAccessControlHelper; | ||
import org.opensearch.ml.helper.ModelAccessControlHelper; | ||
import org.opensearch.ml.model.MLModelGroupManager; | ||
import org.opensearch.ml.model.MLModelManager; | ||
import org.opensearch.ml.stats.MLStats; | ||
import org.opensearch.ml.task.MLTaskDispatcher; | ||
import org.opensearch.ml.task.MLTaskManager; | ||
import org.opensearch.ml.utils.RestActionUtils; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.threadpool.ThreadPool; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class TransportRegisterAgentAction extends HandledTransportAction<ActionRequest, MLRegisterAgentResponse> { | ||
TransportService transportService; | ||
ModelHelper modelHelper; | ||
MLIndicesHandler mlIndicesHandler; | ||
MLModelManager mlModelManager; | ||
MLTaskManager mlTaskManager; | ||
ClusterService clusterService; | ||
ThreadPool threadPool; | ||
Client client; | ||
DiscoveryNodeHelper nodeFilter; | ||
MLTaskDispatcher mlTaskDispatcher; | ||
MLStats mlStats; | ||
ModelAccessControlHelper modelAccessControlHelper; | ||
ConnectorAccessControlHelper connectorAccessControlHelper; | ||
MLModelGroupManager mlModelGroupManager; | ||
|
||
@Inject | ||
public TransportRegisterAgentAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
ModelHelper modelHelper, | ||
MLIndicesHandler mlIndicesHandler, | ||
MLModelManager mlModelManager, | ||
MLTaskManager mlTaskManager, | ||
ClusterService clusterService, | ||
Settings settings, | ||
ThreadPool threadPool, | ||
Client client, | ||
DiscoveryNodeHelper nodeFilter, | ||
MLTaskDispatcher mlTaskDispatcher, | ||
MLStats mlStats, | ||
ModelAccessControlHelper modelAccessControlHelper, | ||
ConnectorAccessControlHelper connectorAccessControlHelper, | ||
MLModelGroupManager mlModelGroupManager | ||
) { | ||
super(MLRegisterAgentAction.NAME, transportService, actionFilters, MLRegisterAgentRequest::new); | ||
this.transportService = transportService; | ||
this.modelHelper = modelHelper; | ||
this.mlIndicesHandler = mlIndicesHandler; | ||
this.mlModelManager = mlModelManager; | ||
this.mlTaskManager = mlTaskManager; | ||
this.clusterService = clusterService; | ||
this.threadPool = threadPool; | ||
this.client = client; | ||
this.nodeFilter = nodeFilter; | ||
this.mlTaskDispatcher = mlTaskDispatcher; | ||
this.mlStats = mlStats; | ||
this.modelAccessControlHelper = modelAccessControlHelper; | ||
this.connectorAccessControlHelper = connectorAccessControlHelper; | ||
this.mlModelGroupManager = mlModelGroupManager; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterAgentResponse> listener) { | ||
User user = RestActionUtils.getUserContext(client);// TODO: check access | ||
MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.fromActionRequest(request); | ||
MLAgent mlAgent = registerAgentRequest.getMlAgent(); | ||
registerAgent(mlAgent, listener); | ||
} | ||
|
||
private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse> listener) { | ||
Instant now = Instant.now(); | ||
MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).build(); | ||
mlIndicesHandler.initMLAgentIndex(ActionListener.wrap(result -> { | ||
if (result) { | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
IndexRequest indexRequest = new IndexRequest(ML_AGENT_INDEX); | ||
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); | ||
mlAgent.toXContent(builder, ToXContent.EMPTY_PARAMS); | ||
indexRequest.source(builder); | ||
client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> { | ||
listener.onResponse(new MLRegisterAgentResponse(r.getId())); | ||
}, e -> { | ||
log.error("Failed to index ML agent", e); | ||
listener.onFailure(e); | ||
}), context::restore)); | ||
} catch (Exception e) { | ||
log.error("Failed to index ML agent", e); | ||
listener.onFailure(e); | ||
} | ||
} else { | ||
log.error("Failed to create ML agent index"); | ||
} | ||
}, e -> { | ||
log.error("Failed to create ML agent index", e); | ||
listener.onFailure(e); | ||
})); | ||
} | ||
|
||
} |
64 changes: 64 additions & 0 deletions
64
plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Locale; | ||
|
||
import org.opensearch.client.node.NodeClient; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; | ||
import org.opensearch.rest.BaseRestHandler; | ||
import org.opensearch.rest.RestRequest; | ||
import org.opensearch.rest.action.RestToXContentListener; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
import com.google.common.collect.ImmutableList; | ||
|
||
public class RestMLRegisterAgentAction extends BaseRestHandler { | ||
private static final String ML_REGISTER_AGENT_ACTION = "ml_register_agent_action"; | ||
|
||
/** | ||
* Constructor | ||
*/ | ||
public RestMLRegisterAgentAction() {} | ||
|
||
@Override | ||
public String getName() { | ||
return ML_REGISTER_AGENT_ACTION; | ||
} | ||
|
||
@Override | ||
public List<Route> routes() { | ||
return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/agents/_register", ML_BASE_URI))); | ||
} | ||
|
||
@Override | ||
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { | ||
MLRegisterAgentRequest registerAgentRequest = getRequest(request); | ||
return channel -> client.execute(MLRegisterAgentAction.INSTANCE, registerAgentRequest, new RestToXContentListener<>(channel)); | ||
} | ||
|
||
/** | ||
* Creates a MLTrainingTaskRequest from a RestRequest | ||
* | ||
* @param request RestRequest | ||
* @return MLTrainingTaskRequest | ||
*/ | ||
@VisibleForTesting | ||
MLRegisterAgentRequest getRequest(RestRequest request) throws IOException { | ||
XContentParser parser = request.contentParser(); | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); | ||
MLAgent mlAgent = MLAgent.parse(parser); | ||
return new MLRegisterAgentRequest(mlAgent); | ||
} | ||
} |
121 changes: 121 additions & 0 deletions
121
plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterAgentActionTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.ArgumentMatchers.eq; | ||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.times; | ||
import static org.mockito.Mockito.verify; | ||
import static org.opensearch.ml.common.agent.LLMSpec.MODEL_ID_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.AGENT_NAME_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.AGENT_TYPE_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.CREATED_TIME_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.DESCRIPTION_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.LAST_UPDATED_TIME_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.LLM_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.MEMORY_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.MEMORY_ID_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.PARAMETERS_FIELD; | ||
import static org.opensearch.ml.common.agent.MLAgent.TOOLS_FIELD; | ||
import static org.opensearch.ml.common.agent.MLMemorySpec.MEMORY_TYPE_FIELD; | ||
import static org.opensearch.ml.common.agent.MLMemorySpec.SESSION_ID_FIELD; | ||
import static org.opensearch.ml.common.agent.MLMemorySpec.WINDOW_SIZE_FIELD; | ||
|
||
import java.time.Instant; | ||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.junit.Before; | ||
import org.mockito.ArgumentCaptor; | ||
import org.opensearch.client.node.NodeClient; | ||
import org.opensearch.core.common.bytes.BytesArray; | ||
import org.opensearch.core.xcontent.MediaTypeRegistry; | ||
import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; | ||
import org.opensearch.rest.RestChannel; | ||
import org.opensearch.rest.RestHandler; | ||
import org.opensearch.rest.RestRequest; | ||
import org.opensearch.test.OpenSearchTestCase; | ||
import org.opensearch.test.rest.FakeRestRequest; | ||
|
||
import com.google.gson.Gson; | ||
|
||
public class RestMLRegisterAgentActionTests extends OpenSearchTestCase { | ||
|
||
Gson gson; | ||
Instant time; | ||
|
||
@Before | ||
public void setup() { | ||
gson = new Gson(); | ||
time = Instant.ofEpochMilli(123); | ||
} | ||
|
||
public void test_GetName_Routes() { | ||
RestMLRegisterAgentAction action = new RestMLRegisterAgentAction(); | ||
assert (action.getName().equals("ml_register_agent_action")); | ||
List<RestHandler.Route> routes = action.routes(); | ||
assert (routes.size() == 1); | ||
assert (routes.get(0).equals(new RestHandler.Route(RestRequest.Method.POST, "/_plugins/_ml/agents/_register"))); | ||
} | ||
|
||
public void testPrepareRequest() throws Exception { | ||
RestMLRegisterAgentAction action = new RestMLRegisterAgentAction(); | ||
final Map<String, Object> llmSpec = Map.of(MODEL_ID_FIELD, "id", PARAMETERS_FIELD, new HashMap<>()); | ||
final Map<String, Object> memorySpec = Map.of(MEMORY_TYPE_FIELD, "conversation", SESSION_ID_FIELD, "sid", WINDOW_SIZE_FIELD, 2); | ||
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) | ||
.withContent( | ||
new BytesArray( | ||
gson | ||
.toJson( | ||
Map | ||
.of( | ||
AGENT_NAME_FIELD, | ||
"agent-name", | ||
AGENT_TYPE_FIELD, | ||
"agent-type", | ||
DESCRIPTION_FIELD, | ||
"description", | ||
LLM_FIELD, | ||
llmSpec, | ||
TOOLS_FIELD, | ||
new ArrayList<>(), | ||
PARAMETERS_FIELD, | ||
new HashMap<>(), | ||
MEMORY_FIELD, | ||
memorySpec, | ||
MEMORY_ID_FIELD, | ||
"memory_id", | ||
CREATED_TIME_FIELD, | ||
time.getEpochSecond(), | ||
LAST_UPDATED_TIME_FIELD, | ||
time.getEpochSecond() | ||
) | ||
) | ||
), | ||
MediaTypeRegistry.JSON | ||
) | ||
.build(); | ||
|
||
NodeClient client = mock(NodeClient.class); | ||
RestChannel channel = mock(RestChannel.class); | ||
action.handleRequest(request, channel, client); | ||
|
||
ArgumentCaptor<MLRegisterAgentRequest> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentRequest.class); | ||
verify(client, times(1)).execute(eq(MLRegisterAgentAction.INSTANCE), argumentCaptor.capture(), any()); | ||
assert (argumentCaptor.getValue().getMlAgent().getName().equals("agent-name")); | ||
assert (argumentCaptor.getValue().getMlAgent().getType().equals("agent-type")); | ||
assert (argumentCaptor.getValue().getMlAgent().getDescription().equals("description")); | ||
assert (argumentCaptor.getValue().getMlAgent().getTools().equals(new ArrayList<>())); | ||
assert (argumentCaptor.getValue().getMlAgent().getLlm().getModelId().equals("id")); | ||
assert (argumentCaptor.getValue().getMlAgent().getParameters().equals(new HashMap<>())); | ||
assert (argumentCaptor.getValue().getMlAgent().getMemory().getType().equals("conversation")); | ||
} | ||
} |