Skip to content

Commit

Permalink
Add more unit test for update model API
Browse files Browse the repository at this point in the history
Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
b4sjoo committed Oct 3, 2023
1 parent 68921ae commit b9e45a7
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,9 @@ private void updateUndeployedModel(
if (searchHits.length == 0) {
updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, updateRequest, actionListener, context);
} else {
log.error("Models is deployed, please undeploy the models first!");
actionListener.onFailure(new MLValidationException("Models is deployed, please undeploy the models first!"));
log.error("ML Model " + modelId + " is deployed, please undeploy the models first!");
actionListener
.onFailure(new MLValidationException("ML Model " + modelId + " is deployed, please undeploy the models first!"));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
Expand Down Expand Up @@ -152,6 +153,7 @@ public void setup() throws IOException {
updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED);
}

@Test
public void test_execute_connectorAccessControl_success() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
Expand All @@ -175,6 +177,7 @@ public void test_execute_connectorAccessControl_success() {
verify(actionListener).onResponse(updateResponse);
}

@Test
public void test_execute_connectorAccessControl_NoPermission() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
Expand All @@ -191,6 +194,7 @@ public void test_execute_connectorAccessControl_NoPermission() {
);
}

@Test
public void test_execute_connectorAccessControl_AccessError() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
Expand All @@ -204,6 +208,7 @@ public void test_execute_connectorAccessControl_AccessError() {
assertEquals("Connector Access Control Error", argumentCaptor.getValue().getMessage());
}

@Test
public void test_execute_connectorAccessControl_Exception() {
doThrow(new RuntimeException("exception in access control"))
.when(connectorAccessControlHelper)
Expand All @@ -215,6 +220,7 @@ public void test_execute_connectorAccessControl_Exception() {
assertEquals("exception in access control", argumentCaptor.getValue().getMessage());
}

@Test
public void test_execute_UpdateWrongStatus() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
Expand All @@ -239,6 +245,7 @@ public void test_execute_UpdateWrongStatus() {
verify(actionListener).onResponse(updateResponse);
}

@Test
public void test_execute_UpdateException() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
Expand All @@ -264,6 +271,7 @@ public void test_execute_UpdateException() {
assertEquals("update document failure", argumentCaptor.getValue().getMessage());
}

@Test
public void test_execute_SearchResponseNotEmpty() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
Expand All @@ -283,6 +291,7 @@ public void test_execute_SearchResponseNotEmpty() {
assertEquals("1 models are still using this connector, please undeploy the models first!", argumentCaptor.getValue().getMessage());
}

@Test
public void test_execute_SearchResponseError() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
Expand All @@ -302,6 +311,7 @@ public void test_execute_SearchResponseError() {
assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage());
}

@Test
public void test_execute_SearchIndexNotFoundError() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.get.GetResult;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
Expand All @@ -59,6 +60,7 @@
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -181,14 +183,14 @@ public void setup() throws IOException {
SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1);
searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);

threadContext = new ThreadContext(settings);
Expand Down Expand Up @@ -231,6 +233,114 @@ public void testUpdateLocalModelSuccess() throws IOException {
verify(actionListener).onResponse(updateResponse);
}

@Test
public void testUpdateLocalModelSuccessWithSearchIndexNotFoundError() throws IOException {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(true);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new IndexNotFoundException("Index not found!"));
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING);
GetResponse getResponse = prepareGetResponse(localModel);
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(getResponse);
return null;
}).when(client).get(any(GetRequest.class), isA(ActionListener.class));

transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener);
verify(actionListener).onResponse(updateResponse);
}

@Test
public void testUpdateLocalModelWithSearchResponseNotEmpty() throws IOException {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(true);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(noneEmptySearchResponse());
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING);
GetResponse getResponse = prepareGetResponse(localModel);
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(getResponse);
return null;
}).when(client).get(any(GetRequest.class), isA(ActionListener.class));

transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("ML Model test_model_id is deployed, please undeploy the models first!", argumentCaptor.getValue().getMessage());
}

@Test
public void testUpdateLocalModelWithSearchResponseOtherException() throws IOException {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onResponse(true);
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener
.onFailure(
new RuntimeException(
"Any other Exception occurred during running SearchResponseListener. Please check log for more details."
)
);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING);
GetResponse getResponse = prepareGetResponse(localModel);
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(getResponse);
return null;
}).when(client).get(any(GetRequest.class), isA(ActionListener.class));

transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Any other Exception occurred during running SearchResponseListener. Please check log for more details.",
argumentCaptor.getValue().getMessage()
);
}

@Test
public void testUpdateRequestDocIOException() throws IOException {
doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput();
Expand Down Expand Up @@ -934,4 +1044,26 @@ private GetResponse prepareGetResponse(MLModel mlModel) throws IOException {
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
return new GetResponse(getResult);
}

private SearchResponse noneEmptySearchResponse() throws IOException {
String modelContent =
"{\"model_id\":\"test-model_id\",\"description\":\"description\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":"
+ "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\""
+ "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}";
SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent));
SearchHits hits = new SearchHits(new SearchHit[] { model }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);

return searchResponse;
}
}

0 comments on commit b9e45a7

Please sign in to comment.