forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
register agent rest and transport actions (opensearch-project#1795)
* register agent rest and transport actions Signed-off-by: Xun Zhang <[email protected]> * add the register agent action into the ml plugin Signed-off-by: Xun Zhang <[email protected]> --------- Signed-off-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
9ba7c60
commit b200122
Showing
5 changed files
with
466 additions
and
0 deletions.
There are no files selected for viewing
91 changes: 91 additions & 0 deletions
91
plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.agents; | ||
|
||
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; | ||
|
||
import java.time.Instant; | ||
|
||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.action.ActionRequest; | ||
import org.opensearch.action.index.IndexRequest; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.action.support.HandledTransportAction; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.inject.Inject; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.common.xcontent.XContentType; | ||
import org.opensearch.commons.authuser.User; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.xcontent.ToXContent; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; | ||
import org.opensearch.ml.engine.indices.MLIndicesHandler; | ||
import org.opensearch.ml.utils.RestActionUtils; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.transport.TransportService; | ||
|
||
import lombok.extern.log4j.Log4j2; | ||
|
||
@Log4j2 | ||
public class TransportRegisterAgentAction extends HandledTransportAction<ActionRequest, MLRegisterAgentResponse> { | ||
MLIndicesHandler mlIndicesHandler; | ||
Client client; | ||
|
||
@Inject | ||
public TransportRegisterAgentAction( | ||
TransportService transportService, | ||
ActionFilters actionFilters, | ||
Client client, | ||
MLIndicesHandler mlIndicesHandler | ||
) { | ||
super(MLRegisterAgentAction.NAME, transportService, actionFilters, MLRegisterAgentRequest::new); | ||
this.client = client; | ||
this.mlIndicesHandler = mlIndicesHandler; | ||
} | ||
|
||
@Override | ||
protected void doExecute(Task task, ActionRequest request, ActionListener<MLRegisterAgentResponse> listener) { | ||
User user = RestActionUtils.getUserContext(client);// TODO: check access | ||
MLRegisterAgentRequest registerAgentRequest = MLRegisterAgentRequest.fromActionRequest(request); | ||
MLAgent mlAgent = registerAgentRequest.getMlAgent(); | ||
registerAgent(mlAgent, listener); | ||
} | ||
|
||
private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse> listener) { | ||
Instant now = Instant.now(); | ||
MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).build(); | ||
mlIndicesHandler.initMLAgentIndex(ActionListener.wrap(result -> { | ||
if (result) { | ||
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { | ||
IndexRequest indexRequest = new IndexRequest(ML_AGENT_INDEX); | ||
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); | ||
mlAgent.toXContent(builder, ToXContent.EMPTY_PARAMS); | ||
indexRequest.source(builder); | ||
client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> { | ||
listener.onResponse(new MLRegisterAgentResponse(r.getId())); | ||
}, e -> { | ||
log.error("Failed to index ML agent", e); | ||
listener.onFailure(e); | ||
}), context::restore)); | ||
} catch (Exception e) { | ||
log.error("Failed to index ML agent", e); | ||
listener.onFailure(e); | ||
} | ||
} else { | ||
log.error("Failed to create ML agent index"); | ||
listener.onFailure(new OpenSearchException("Failed to create ML agent index")); | ||
} | ||
}, e -> { | ||
log.error("Failed to create ML agent index", e); | ||
listener.onFailure(e); | ||
})); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterAgentAction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Locale; | ||
|
||
import org.opensearch.client.node.NodeClient; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; | ||
import org.opensearch.rest.BaseRestHandler; | ||
import org.opensearch.rest.RestRequest; | ||
import org.opensearch.rest.action.RestToXContentListener; | ||
|
||
import com.google.common.annotations.VisibleForTesting; | ||
import com.google.common.collect.ImmutableList; | ||
|
||
public class RestMLRegisterAgentAction extends BaseRestHandler { | ||
private static final String ML_REGISTER_AGENT_ACTION = "ml_register_agent_action"; | ||
|
||
/** | ||
* Constructor | ||
*/ | ||
public RestMLRegisterAgentAction() {} | ||
|
||
@Override | ||
public String getName() { | ||
return ML_REGISTER_AGENT_ACTION; | ||
} | ||
|
||
@Override | ||
public List<Route> routes() { | ||
return ImmutableList.of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/agents/_register", ML_BASE_URI))); | ||
} | ||
|
||
@Override | ||
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { | ||
MLRegisterAgentRequest registerAgentRequest = getRequest(request); | ||
return channel -> client.execute(MLRegisterAgentAction.INSTANCE, registerAgentRequest, new RestToXContentListener<>(channel)); | ||
} | ||
|
||
/** | ||
* Creates a MLTrainingTaskRequest from a RestRequest | ||
* | ||
* @param request RestRequest | ||
* @return MLTrainingTaskRequest | ||
*/ | ||
@VisibleForTesting | ||
MLRegisterAgentRequest getRequest(RestRequest request) throws IOException { | ||
XContentParser parser = request.contentParser(); | ||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); | ||
MLAgent mlAgent = MLAgent.parse(parser); | ||
return new MLRegisterAgentRequest(mlAgent); | ||
} | ||
} |
184 changes: 184 additions & 0 deletions
184
plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.action.agents; | ||
|
||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.Mockito.doAnswer; | ||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.verify; | ||
import static org.mockito.Mockito.when; | ||
|
||
import java.io.IOException; | ||
import java.util.HashMap; | ||
|
||
import org.junit.Before; | ||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.Mock; | ||
import org.mockito.MockitoAnnotations; | ||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.action.index.IndexResponse; | ||
import org.opensearch.action.support.ActionFilters; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.commons.ConfigConstants; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.index.shard.ShardId; | ||
import org.opensearch.ml.common.agent.LLMSpec; | ||
import org.opensearch.ml.common.agent.MLAgent; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; | ||
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; | ||
import org.opensearch.ml.engine.indices.MLIndicesHandler; | ||
import org.opensearch.tasks.Task; | ||
import org.opensearch.test.OpenSearchTestCase; | ||
import org.opensearch.threadpool.ThreadPool; | ||
import org.opensearch.transport.TransportService; | ||
|
||
public class RegisterAgentTransportActionTests extends OpenSearchTestCase { | ||
|
||
@Mock | ||
private Client client; | ||
|
||
@Mock | ||
private MLIndicesHandler mlIndicesHandler; | ||
|
||
@Mock | ||
private ActionFilters actionFilters; | ||
|
||
@Mock | ||
private TransportService transportService; | ||
|
||
@Mock | ||
private Task task; | ||
|
||
@Mock | ||
private ActionListener<MLRegisterAgentResponse> actionListener; | ||
|
||
@Mock | ||
private ThreadPool threadPool; | ||
|
||
private TransportRegisterAgentAction transportRegisterAgentAction; | ||
private Settings settings; | ||
private ThreadContext threadContext; | ||
|
||
@Before | ||
public void setup() throws IOException { | ||
MockitoAnnotations.openMocks(this); | ||
settings = Settings.builder().build(); | ||
threadContext = new ThreadContext(settings); | ||
|
||
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); | ||
|
||
when(client.threadPool()).thenReturn(threadPool); | ||
when(threadPool.getThreadContext()).thenReturn(threadContext); | ||
transportRegisterAgentAction = new TransportRegisterAgentAction(transportService, actionFilters, client, mlIndicesHandler); | ||
} | ||
|
||
public void test_execute_registerAgent_success() { | ||
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); | ||
MLAgent mlAgent = MLAgent | ||
.builder() | ||
.name("agent") | ||
.type("some type") | ||
.description("description") | ||
.llm(new LLMSpec("model_id", new HashMap<>())) | ||
.build(); | ||
when(request.getMlAgent()).thenReturn(mlAgent); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<Boolean> listener = invocation.getArgument(0); | ||
listener.onResponse(true); | ||
return null; | ||
}).when(mlIndicesHandler).initMLAgentIndex(any()); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<IndexResponse> al = invocation.getArgument(1); | ||
IndexResponse indexResponse = new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true); | ||
al.onResponse(indexResponse); | ||
return null; | ||
}).when(client).index(any(), any()); | ||
|
||
transportRegisterAgentAction.doExecute(task, request, actionListener); | ||
ArgumentCaptor<MLRegisterAgentResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class); | ||
verify(actionListener).onResponse(argumentCaptor.capture()); | ||
} | ||
|
||
public void test_execute_registerAgent_AgentIndexNotInitialized() { | ||
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); | ||
MLAgent mlAgent = MLAgent | ||
.builder() | ||
.name("agent") | ||
.type("some type") | ||
.description("description") | ||
.llm(new LLMSpec("model_id", new HashMap<>())) | ||
.build(); | ||
when(request.getMlAgent()).thenReturn(mlAgent); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<Boolean> listener = invocation.getArgument(0); | ||
listener.onResponse(false); | ||
return null; | ||
}).when(mlIndicesHandler).initMLAgentIndex(any()); | ||
|
||
transportRegisterAgentAction.doExecute(task, request, actionListener); | ||
ArgumentCaptor<OpenSearchException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchException.class); | ||
verify(actionListener).onFailure(argumentCaptor.capture()); | ||
assertEquals("Failed to create ML agent index", argumentCaptor.getValue().getMessage()); | ||
} | ||
|
||
public void test_execute_registerAgent_IndexFailure() { | ||
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); | ||
MLAgent mlAgent = MLAgent | ||
.builder() | ||
.name("agent") | ||
.type("some type") | ||
.description("description") | ||
.llm(new LLMSpec("model_id", new HashMap<>())) | ||
.build(); | ||
when(request.getMlAgent()).thenReturn(mlAgent); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<Boolean> listener = invocation.getArgument(0); | ||
listener.onResponse(true); | ||
return null; | ||
}).when(mlIndicesHandler).initMLAgentIndex(any()); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<IndexResponse> al = invocation.getArgument(1); | ||
al.onFailure(new RuntimeException("index failure")); | ||
return null; | ||
}).when(client).index(any(), any()); | ||
|
||
transportRegisterAgentAction.doExecute(task, request, actionListener); | ||
ArgumentCaptor<RuntimeException> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); | ||
verify(actionListener).onFailure(argumentCaptor.capture()); | ||
|
||
assertEquals("index failure", argumentCaptor.getValue().getMessage()); | ||
} | ||
|
||
public void test_execute_registerAgent_InitAgentIndexFailure() { | ||
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class); | ||
MLAgent mlAgent = MLAgent | ||
.builder() | ||
.name("agent") | ||
.type("some type") | ||
.description("description") | ||
.llm(new LLMSpec("model_id", new HashMap<>())) | ||
.build(); | ||
when(request.getMlAgent()).thenReturn(mlAgent); | ||
|
||
doAnswer(invocation -> { | ||
ActionListener<Boolean> listener = invocation.getArgument(0); | ||
listener.onFailure(new RuntimeException("agent index initialization failed")); | ||
return null; | ||
}).when(mlIndicesHandler).initMLAgentIndex(any()); | ||
|
||
transportRegisterAgentAction.doExecute(task, request, actionListener); | ||
ArgumentCaptor<RuntimeException> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); | ||
verify(actionListener).onFailure(argumentCaptor.capture()); | ||
assertEquals("agent index initialization failed", argumentCaptor.getValue().getMessage()); | ||
} | ||
} |
Oops, something went wrong.