diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index b4d4431b90..e7f3344639 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -11,7 +11,7 @@ jobs: Build-ml: strategy: matrix: - java: [8, 11, 14] + java: [11, 14] name: Build and Test MLCommons Plugin runs-on: ubuntu-latest diff --git a/plugin/build.gradle b/plugin/build.gradle index 8370c08223..17fb568dc4 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -214,12 +214,8 @@ List jacocoExclusions = [ 'org.opensearch.ml.task.MLPredictTaskRunner', 'org.opensearch.ml.rest.RestMLPredictionAction', 'org.opensearch.ml.rest.AbstractMLSearchAction*', - 'org.opensearch.ml.rest.RestMLDeleteTaskAction', //0.5 - 'org.opensearch.ml.rest.RestMLGetModelAction', //0.5 'org.opensearch.ml.rest.RestMLExecuteAction', //0.3 - 'org.opensearch.ml.rest.RestMLDeleteModelAction', //0.5 - 'org.opensearch.ml.rest.RestMLTrainAndPredictAction', //0.3 - 'org.opensearch.ml.rest.RestMLGetTaskAction' //0.5 + 'org.opensearch.ml.rest.RestMLTrainAndPredictAction' //0.3 ] jacocoTestCoverageVerification { @@ -333,3 +329,5 @@ tasks.withType(licenseHeaders.class) { checkstyle { toolVersion = '8.29' } +sourceCompatibility = JavaVersion.VERSION_1_9 +targetCompatibility = JavaVersion.VERSION_1_9 diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java index 15e69048e7..0a5c00a838 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java @@ -5,21 +5,64 @@ package org.opensearch.ml.rest; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.times; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; + +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.ActionListener; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.client.node.NodeClient; import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.model.MLModelDeleteAction; +import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; +import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; public class RestMLDeleteModelActionTests extends OpenSearchTestCase { private RestMLDeleteModelAction restMLDeleteModelAction; + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + @Before public void setup() { restMLDeleteModelAction = new RestMLDeleteModelAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLModelDeleteAction.INSTANCE), any(), any()); + + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); } public void testConstructor() { @@ -41,4 +84,21 @@ public void testRoutes() { assertEquals(RestRequest.Method.DELETE, route.getMethod()); assertEquals("/_plugins/_ml/models/{model_id}", route.getPath()); } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLDeleteModelAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelDeleteRequest.class); + verify(client, times(1)).execute(eq(MLModelDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getModelId(); + assertEquals(taskId, "test_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_ID, "test_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteTaskActionTests.java index f2c3f9a19e..6fb77a6410 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteTaskActionTests.java @@ -5,21 +5,63 @@ package org.opensearch.ml.rest; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.times; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TASK_ID; + +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.ActionListener; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.client.node.NodeClient; import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.task.MLTaskDeleteAction; +import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest; +import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; public class RestMLDeleteTaskActionTests extends OpenSearchTestCase { private RestMLDeleteTaskAction restMLDeleteTaskAction; + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + @Before public void setup() { restMLDeleteTaskAction = new RestMLDeleteTaskAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); } public void testConstructor() { @@ -41,4 +83,21 @@ public void testRoutes() { assertEquals(RestRequest.Method.DELETE, route.getMethod()); assertEquals("/_plugins/_ml/tasks/{task_id}", route.getPath()); } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLDeleteTaskAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLTaskDeleteRequest.class); + verify(client, times(1)).execute(eq(MLTaskDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getTaskId(); + assertEquals(taskId, "test_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TASK_ID, "test_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java index a784e814c6..a431c72c04 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java @@ -5,15 +5,36 @@ package org.opensearch.ml.rest; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.times; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; + +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.ActionListener; +import org.opensearch.client.node.NodeClient; import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.model.MLModelGetAction; +import org.opensearch.ml.common.transport.model.MLModelGetRequest; +import org.opensearch.ml.common.transport.model.MLModelGetResponse; +import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; public class RestMLGetModelActionTests extends OpenSearchTestCase { @Rule @@ -21,9 +42,31 @@ public class RestMLGetModelActionTests extends OpenSearchTestCase { private RestMLGetModelAction restMLGetModelAction; + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + @Before public void setup() { restMLGetModelAction = new RestMLGetModelAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLModelGetAction.INSTANCE), any(), any()); + + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); } public void testConstructor() { @@ -45,4 +88,21 @@ public void testRoutes() { assertEquals(RestRequest.Method.GET, route.getMethod()); assertEquals("/_plugins/_ml/models/{model_id}", route.getPath()); } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLGetModelAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelGetRequest.class); + verify(client, times(1)).execute(eq(MLModelGetAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getModelId(); + assertEquals(taskId, "test_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_ID, "test_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetTaskActionTests.java index 229615ee3a..f79694102d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetTaskActionTests.java @@ -5,21 +5,61 @@ package org.opensearch.ml.rest; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.times; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TASK_ID; + +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.ActionListener; +import org.opensearch.client.node.NodeClient; import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.task.*; +import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; public class RestMLGetTaskActionTests extends OpenSearchTestCase { private RestMLGetTaskAction restMLGetTaskAction; + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + @Before public void setup() { restMLGetTaskAction = new RestMLGetTaskAction(); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); } public void testConstructor() { @@ -41,4 +81,21 @@ public void testRoutes() { assertEquals(RestRequest.Method.GET, route.getMethod()); assertEquals("/_plugins/_ml/tasks/{task_id}", route.getPath()); } + + public void test_PrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLGetTaskAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLTaskGetRequest.class); + verify(client, times(1)).execute(eq(MLTaskGetAction.INSTANCE), argumentCaptor.capture(), any()); + String taskId = argumentCaptor.getValue().getTaskId(); + assertEquals(taskId, "test_id"); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TASK_ID, "test_id"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + return request; + } } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 71515b610d..6f209d0042 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -10,6 +10,7 @@ import static org.mockito.Mockito.*; import static org.mockito.Mockito.spy; +import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; @@ -17,6 +18,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.Version; @@ -25,8 +27,14 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; import org.opensearch.index.get.GetResult; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; @@ -95,9 +103,10 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase { String indexName = "testIndex"; String errorMessage = "test error"; GetResponse getResponse; + MLInput mlInputWithDataFrame; @Before - public void setup() { + public void setup() throws IOException { MockitoAnnotations.openMocks(this); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); @@ -134,7 +143,7 @@ public void setup() { MLInputDataset dataFrameInputDataSet = new DataFrameInputDataset(dataFrame); BatchRCFParams batchRCFParams = BatchRCFParams.builder().build(); - MLInput mlInputWithDataFrame = MLInput + mlInputWithDataFrame = MLInput .builder() .algorithm(FunctionName.BATCH_RCF) .parameters(batchRCFParams) @@ -156,20 +165,27 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "myuser|role1,role2|myTenant"); when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + MLModel mlModel = MLModel + .builder() + .user(new User()) + .version(111) + .name("test") + .algorithm(FunctionName.BATCH_RCF) + .content("content") + .build(); + XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult(indexName, "type", "111", 111l, 111l, 111l, true, null, null, null); + GetResult getResult = new GetResult(indexName, "type", "111", 111l, 111l, 111l, true, bytesReference, null, null); getResponse = new GetResponse(getResult); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - when(threadPool.getThreadContext()).thenReturn(threadContext); } public void testExecuteTask_OnLocalNode() { - setupMocks(true, false); + setupMocks(true, false, false, false); taskRunner.executeTask(requestWithDataFrame, transportService, listener); verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); @@ -180,7 +196,7 @@ public void testExecuteTask_OnLocalNode() { } public void testExecuteTask_OnLocalNode_QueryInput() { - setupMocks(true, false); + setupMocks(true, false, false, false); taskRunner.executeTask(requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); @@ -191,7 +207,7 @@ public void testExecuteTask_OnLocalNode_QueryInput() { } public void testExecuteTask_OnLocalNode_QueryInput_Failure() { - setupMocks(true, true); + setupMocks(true, true, false, false); taskRunner.executeTask(requestWithQuery, transportService, listener); verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); @@ -201,12 +217,54 @@ public void testExecuteTask_OnLocalNode_QueryInput_Failure() { } public void testExecuteTask_OnRemoteNode() { - setupMocks(false, false); + setupMocks(false, false, false, false); taskRunner.executeTask(requestWithDataFrame, transportService, listener); verify(transportService).sendRequest(eq(remoteNode), eq(MLPredictionTaskAction.NAME), eq(requestWithDataFrame), any()); } - private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput) { + public void testExecuteTask_OnLocalNode_GetModelFail() { + setupMocks(true, false, true, false); + + taskRunner.executeTask(requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager).add(any(MLTask.class)); + verify(client).get(any(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); + } + + public void testExecuteTask_OnLocalNode_NullModelIdException() { + setupMocks(true, false, false, false); + requestWithDataFrame = MLPredictionTaskRequest.builder().mlInput(mlInputWithDataFrame).build(); + + taskRunner.executeTask(requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager).add(any(MLTask.class)); + verify(client, never()).get(any(), any()); + verify(mlTaskManager).remove(anyString()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("ModelId is invalid", argumentCaptor.getValue().getMessage()); + } + + public void testExecuteTask_OnLocalNode_NullGetResponse() { + setupMocks(true, false, false, true); + + taskRunner.executeTask(requestWithDataFrame, transportService, listener); + verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any()); + verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset()); + verify(mlTaskManager).add(any(MLTask.class)); + verify(client).get(any(), any()); + verify(mlTaskManager).remove(anyString()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("No model found, please check the modelId.", argumentCaptor.getValue().getMessage()); + } + + private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput, boolean failedToGetModel, boolean nullGetResponse) { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(0); if (runOnLocalNode) { @@ -230,5 +288,23 @@ private void setupMocks(boolean runOnLocalNode, boolean failedToParseQueryInput) return null; }).when(mlInputDatasetHandler).parseSearchQueryInput(any(), any()); } + + if (nullGetResponse) { + getResponse = null; + } + + if (failedToGetModel) { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException(errorMessage)); + return null; + }).when(client).get(any(), any()); + } else { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + } } }