From 032579bd7e4679db4cb86253e419cc5a35d8c801 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Sat, 23 Sep 2023 00:07:35 +0000 Subject: [PATCH] Add more unit tests on Update model API Signed-off-by: Sicheng Song --- .../models/TransportUpdateModelAction.java | 7 +- .../TransportUpdateModelActionTests.java | 464 +++++++++++++++++- .../ml/rest/RestMLUpdateModelActionTests.java | 5 + 3 files changed, 455 insertions(+), 21 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/TransportUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/TransportUpdateModelAction.java index 992c5a7c44..18b52b0577 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/TransportUpdateModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/TransportUpdateModelAction.java @@ -101,7 +101,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { @@ -119,7 +120,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); @@ -130,29 +163,424 @@ public void testUpdateModel_Success() throws IOException { return null; }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); - GetResponse getResponse = prepareMLModel(); + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(getResponse); return null; }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, mlUpdateModelRequest, actionListener); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); verify(actionListener).onResponse(updateResponse); } - public GetResponse prepareMLModel() throws IOException { - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .description("test_description") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); + @Test + public void testUpdateLocalModelWithoutRelinkModelGroupSuccess() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithLocalInformationSuccess() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + GetResponse getResponse = prepareGetResponse(remoteModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationSuccess() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + GetResponse getResponse = prepareGetResponse(remoteModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateRemoteModelWithNoStandAloneConnectorFound() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel remoteModelWithInternalConnector = prepareUnsupportedMLModel(FunctionName.REMOTE); + GetResponse getResponse = prepareGetResponse(remoteModelWithInternalConnector); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "This remote does not have a connector_id field, maybe it uses an internal connector.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlNoPermission() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + GetResponse getResponse = prepareGetResponse(remoteModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(false); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You don't have permission to update the connector, connector id: updated_test_connector_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelAccessControlNoPermission() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this model, model ID test_model_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithRelinkModelGroupModelAccessControlNoPermission() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User Doesn't have privilege to re-link this model to the target model group due to no access to the target model group with model group ID updated_test_model_group_id", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateModelWithModelNotFound() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to find model to update with the provided model id: test_model_id", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateModelWithFunctionNameFieldNotFound() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + GetResponse getResponse = prepareGetResponse(mlModelWithNullFunctionName); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("FUNCTION_NAME_FIELD not found for this model, model ID test_model_id", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateLocalModelWithRemoteInformation() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Trying to update the connector or connector_id field on a local model", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateLocalModelWithUnsupportedFunction() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModelWithUnsupportedFunction = prepareUnsupportedMLModel(FunctionName.KMEANS); + GetResponse getResponse = prepareGetResponse(localModelWithUnsupportedFunction); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "User doesn't have privilege to perform this operation on this function category: KMEANS", + argumentCaptor.getValue().getMessage() + ); + } + + private MLModel prepareMLModel(FunctionName functionName) throws IllegalArgumentException { + MLModel mlModel; + switch (functionName) { + case TEXT_EMBEDDING: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.TEXT_EMBEDDING) + .build(); + return mlModel; + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connectorId("test_connector_id") + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.TEXT_EMBEDDING and FunctionName.REMOTE"); + } + } + + private MLModel prepareUnsupportedMLModel(FunctionName unsupportedCase) throws IllegalArgumentException { + MLModel mlModel; + switch (unsupportedCase) { + case REMOTE: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connector(HttpConnector.builder().name("test_connector").protocol("http").build()) + .build(); + return mlModel; + case KMEANS: + mlModel = MLModel + .builder() + .name("test_name") + .modelId("test_model_id") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.KMEANS) + .build(); + return mlModel; + default: + throw new IllegalArgumentException("Please choose from FunctionName.REMOTE and FunctionName.KMEANS"); + } + } + + private GetResponse prepareGetResponse(MLModel mlModel) throws IOException { XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); - return getResponse; + return new GetResponse(getResult); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java index 8651852b98..d8f164b3f2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -1,3 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + package org.opensearch.ml.rest; import static org.mockito.ArgumentMatchers.any;