Skip to content

Commit

Permalink
add more UT for model/task rest actions, remove support for JDK8 (ope…
Browse files Browse the repository at this point in the history
…nsearch-project#230)

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored Mar 17, 2022
1 parent fb9ac75 commit d6f9e96
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
Build-ml:
strategy:
matrix:
java: [8, 11, 14]
java: [11, 14]

name: Build and Test MLCommons Plugin
runs-on: ubuntu-latest
Expand Down
8 changes: 3 additions & 5 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,8 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.task.MLPredictTaskRunner',
'org.opensearch.ml.rest.RestMLPredictionAction',
'org.opensearch.ml.rest.AbstractMLSearchAction*',
'org.opensearch.ml.rest.RestMLDeleteTaskAction', //0.5
'org.opensearch.ml.rest.RestMLGetModelAction', //0.5
'org.opensearch.ml.rest.RestMLExecuteAction', //0.3
'org.opensearch.ml.rest.RestMLDeleteModelAction', //0.5
'org.opensearch.ml.rest.RestMLTrainAndPredictAction', //0.3
'org.opensearch.ml.rest.RestMLGetTaskAction' //0.5
'org.opensearch.ml.rest.RestMLTrainAndPredictAction' //0.3
]

jacocoTestCoverageVerification {
Expand Down Expand Up @@ -333,3 +329,5 @@ tasks.withType(licenseHeaders.class) {
checkstyle {
toolVersion = '8.29'
}
sourceCompatibility = JavaVersion.VERSION_1_9
targetCompatibility = JavaVersion.VERSION_1_9
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,64 @@

package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.times;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.opensearch.action.ActionListener;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;

public class RestMLDeleteModelActionTests extends OpenSearchTestCase {

private RestMLDeleteModelAction restMLDeleteModelAction;

NodeClient client;
private ThreadPool threadPool;

@Mock
RestChannel channel;

@Before
public void setup() {
restMLDeleteModelAction = new RestMLDeleteModelAction();

threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
return null;
}).when(client).execute(eq(MLModelDeleteAction.INSTANCE), any(), any());

}

@Override
public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdown();
client.close();
}

public void testConstructor() {
Expand All @@ -41,4 +84,21 @@ public void testRoutes() {
assertEquals(RestRequest.Method.DELETE, route.getMethod());
assertEquals("/_plugins/_ml/models/{model_id}", route.getPath());
}

public void test_PrepareRequest() throws Exception {
RestRequest request = getRestRequest();
restMLDeleteModelAction.handleRequest(request, channel, client);

ArgumentCaptor<MLModelDeleteRequest> argumentCaptor = ArgumentCaptor.forClass(MLModelDeleteRequest.class);
verify(client, times(1)).execute(eq(MLModelDeleteAction.INSTANCE), argumentCaptor.capture(), any());
String taskId = argumentCaptor.getValue().getModelId();
assertEquals(taskId, "test_id");
}

private RestRequest getRestRequest() {
Map<String, String> params = new HashMap<>();
params.put(PARAMETER_MODEL_ID, "test_id");
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,63 @@

package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.times;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TASK_ID;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.opensearch.action.ActionListener;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;

public class RestMLDeleteTaskActionTests extends OpenSearchTestCase {

private RestMLDeleteTaskAction restMLDeleteTaskAction;

NodeClient client;
private ThreadPool threadPool;

@Mock
RestChannel channel;

@Before
public void setup() {
restMLDeleteTaskAction = new RestMLDeleteTaskAction();

threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
return null;
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());
}

@Override
public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdown();
client.close();
}

public void testConstructor() {
Expand All @@ -41,4 +83,21 @@ public void testRoutes() {
assertEquals(RestRequest.Method.DELETE, route.getMethod());
assertEquals("/_plugins/_ml/tasks/{task_id}", route.getPath());
}

public void test_PrepareRequest() throws Exception {
RestRequest request = getRestRequest();
restMLDeleteTaskAction.handleRequest(request, channel, client);

ArgumentCaptor<MLTaskDeleteRequest> argumentCaptor = ArgumentCaptor.forClass(MLTaskDeleteRequest.class);
verify(client, times(1)).execute(eq(MLTaskDeleteAction.INSTANCE), argumentCaptor.capture(), any());
String taskId = argumentCaptor.getValue().getTaskId();
assertEquals(taskId, "test_id");
}

private RestRequest getRestRequest() {
Map<String, String> params = new HashMap<>();
params.put(PARAMETER_TASK_ID, "test_id");
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,68 @@

package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.times;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.opensearch.action.ActionListener;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;

public class RestMLGetModelActionTests extends OpenSearchTestCase {
@Rule
public ExpectedException thrown = ExpectedException.none();

private RestMLGetModelAction restMLGetModelAction;

NodeClient client;
private ThreadPool threadPool;

@Mock
RestChannel channel;

@Before
public void setup() {
restMLGetModelAction = new RestMLGetModelAction();

threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));

doAnswer(invocation -> {
ActionListener<MLModelGetResponse> actionListener = invocation.getArgument(2);
return null;
}).when(client).execute(eq(MLModelGetAction.INSTANCE), any(), any());

}

@Override
public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdown();
client.close();
}

public void testConstructor() {
Expand All @@ -45,4 +88,21 @@ public void testRoutes() {
assertEquals(RestRequest.Method.GET, route.getMethod());
assertEquals("/_plugins/_ml/models/{model_id}", route.getPath());
}

public void test_PrepareRequest() throws Exception {
RestRequest request = getRestRequest();
restMLGetModelAction.handleRequest(request, channel, client);

ArgumentCaptor<MLModelGetRequest> argumentCaptor = ArgumentCaptor.forClass(MLModelGetRequest.class);
verify(client, times(1)).execute(eq(MLModelGetAction.INSTANCE), argumentCaptor.capture(), any());
String taskId = argumentCaptor.getValue().getModelId();
assertEquals(taskId, "test_id");
}

private RestRequest getRestRequest() {
Map<String, String> params = new HashMap<>();
params.put(PARAMETER_MODEL_ID, "test_id");
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,61 @@

package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.times;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TASK_ID;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.opensearch.action.ActionListener;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.task.*;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;

public class RestMLGetTaskActionTests extends OpenSearchTestCase {

private RestMLGetTaskAction restMLGetTaskAction;

NodeClient client;
private ThreadPool threadPool;

@Mock
RestChannel channel;

@Before
public void setup() {
restMLGetTaskAction = new RestMLGetTaskAction();

threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool");
client = spy(new NodeClient(Settings.EMPTY, threadPool));

doAnswer(invocation -> {
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
return null;
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());
}

@Override
public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdown();
client.close();
}

public void testConstructor() {
Expand All @@ -41,4 +81,21 @@ public void testRoutes() {
assertEquals(RestRequest.Method.GET, route.getMethod());
assertEquals("/_plugins/_ml/tasks/{task_id}", route.getPath());
}

public void test_PrepareRequest() throws Exception {
RestRequest request = getRestRequest();
restMLGetTaskAction.handleRequest(request, channel, client);

ArgumentCaptor<MLTaskGetRequest> argumentCaptor = ArgumentCaptor.forClass(MLTaskGetRequest.class);
verify(client, times(1)).execute(eq(MLTaskGetAction.INSTANCE), argumentCaptor.capture(), any());
String taskId = argumentCaptor.getValue().getTaskId();
assertEquals(taskId, "test_id");
}

private RestRequest getRestRequest() {
Map<String, String> params = new HashMap<>();
params.put(PARAMETER_TASK_ID, "test_id");
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();
return request;
}
}
Loading

0 comments on commit d6f9e96

Please sign in to comment.