diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/TransportUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/TransportUpdateModelAction.java index 18b52b0577..06f2f6f3c2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/TransportUpdateModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/TransportUpdateModelAction.java @@ -223,25 +223,6 @@ private void updateRemoteOrTextEmbeddingModel( } } - private ActionListener getUpdateResponseListener( - String modelId, - ActionListener actionListener, - ThreadContext.StoredContext context - ) { - return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { - if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { - log.info("Model id:{} failed update", modelId); - actionListener.onResponse(updateResponse); - return; - } - log.info("Completed Update Model Request, model id:{} updated", modelId); - actionListener.onResponse(updateResponse); - }, exception -> { - log.error("Failed to update ML model: " + modelId, exception); - actionListener.onFailure(exception); - }), context::restore); - } - private void updateModelWithOrWithoutRelinkModelGroup( String modelId, String relinkModelGroupId, @@ -271,4 +252,23 @@ private void updateModelWithOrWithoutRelinkModelGroup( client.update(updateRequest, getUpdateResponseListener(modelId, actionListener, context)); } } + + private ActionListener getUpdateResponseListener( + String modelId, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.info("Model id:{} failed update", modelId); + actionListener.onResponse(updateResponse); + return; + } + log.info("Completed Update Model Request, model id:{} updated", modelId); + actionListener.onResponse(updateResponse); + }, exception -> { + log.error("Failed to update ML model: " + modelId, exception); + actionListener.onFailure(exception); + }), context::restore); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/TransportUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/TransportUpdateModelActionTests.java index 5f2749e4d3..80cbc594b5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/TransportUpdateModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/TransportUpdateModelActionTests.java @@ -23,6 +23,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; @@ -35,13 +36,14 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; @@ -72,9 +74,6 @@ public class TransportUpdateModelActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - @Mock - UpdateResponse updateResponse; - @Mock GetResponse getResponse; @@ -84,9 +83,15 @@ public class TransportUpdateModelActionTests extends OpenSearchTestCase { @Mock ClusterService clusterService; + @Mock + private ModelAccessControlHelper modelAccessControlHelper; + + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); - + UpdateResponse updateResponse; TransportUpdateModelAction transportUpdateModelAction; MLUpdateModelRequest updateLocalModelRequest; MLUpdateModelInput updateLocalModelInput; @@ -94,11 +99,7 @@ public class TransportUpdateModelActionTests extends OpenSearchTestCase { MLUpdateModelInput updateRemoteModelInput; MLModel mlModelWithNullFunctionName; ThreadContext threadContext; - @Mock - private ModelAccessControlHelper modelAccessControlHelper; - - @Mock - private ConnectorAccessControlHelper connectorAccessControlHelper; + private ShardId shardId; @Before public void setup() throws IOException { @@ -123,13 +124,13 @@ public void setup() throws IOException { updateRemoteModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateRemoteModelInput).build(); mlModelWithNullFunctionName = MLModel - .builder() - .name("test_name") - .modelId("test_model_id") - .modelGroupId("test_model_group_id") - .description("test_description") - .modelState(MLModelState.REGISTERED) - .build(); + .builder() + .name("test_name") + .modelId("test_model_id") + .modelGroupId("test_model_group_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .build(); Settings settings = Settings.builder().build(); @@ -147,6 +148,9 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); } @Test @@ -176,7 +180,34 @@ public void testUpdateLocalModelSuccess() throws IOException { } @Test - public void testUpdateLocalModelWithoutRelinkModelGroupSuccess() throws IOException { + public void testUpdateModelWithNoUpdateModelInput() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + MLUpdateModelRequest noUpdateModelInputUpdateModelRequest = MLUpdateModelRequest.builder().updateModelInput(null).build(); + transportUpdateModelAction.doExecute(task, noUpdateModelInputUpdateModelRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelWithoutRelinkModelGroupSuccess() throws IOException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); @@ -196,7 +227,7 @@ public void testUpdateLocalModelWithoutRelinkModelGroupSuccess() throws IOExcept listener.onResponse(getResponse); return null; }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - + updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); verify(actionListener).onResponse(updateResponse); @@ -260,6 +291,64 @@ public void testUpdateRemoteModelWithRemoteInformationSuccess() throws IOExcepti verify(actionListener).onResponse(updateResponse); } + @Test + public void testGetUpdateResponseListenerWrongStatus() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateWrongResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + verify(actionListener).onResponse(updateWrongResponse); + } + + @Test + public void testGetUpdateResponseListenerOtherException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.")); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testUpdateRemoteModelWithNoStandAloneConnectorFound() throws IOException { doAnswer(invocation -> { @@ -334,6 +423,46 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl ); } + @Test + public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlOtherException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel remoteModel = prepareMLModel(FunctionName.REMOTE); + GetResponse getResponse = prepareGetResponse(remoteModel); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener + .onFailure( + new RuntimeException("Any other connector access control Exception occurred. Please check log for more details.") + ); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateRemoteModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other connector access control Exception occurred. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testUpdateModelWithModelAccessControlNoPermission() throws IOException { doAnswer(invocation -> { @@ -366,6 +495,43 @@ public void testUpdateModelWithModelAccessControlNoPermission() throws IOExcepti ); } + @Test + public void testUpdateModelWithModelAccessControlOtherException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onFailure( + new RuntimeException( + "Any other model access control Exception occurred during update the model. Please check log for more details." + ) + ); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other model access control Exception occurred during update the model. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testUpdateModelWithRelinkModelGroupModelAccessControlNoPermission() throws IOException { doAnswer(invocation -> { @@ -406,6 +572,51 @@ public void testUpdateModelWithRelinkModelGroupModelAccessControlNoPermission() ); } + @Test + public void testUpdateModelWithRelinkModelGroupModelAccessControlOtherException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onFailure( + new RuntimeException( + "Any other model access control Exception occurred during re-linking the model group. Please check log for more details." + ) + ); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + MLModel localModel = prepareMLModel(FunctionName.TEXT_EMBEDDING); + GetResponse getResponse = prepareGetResponse(localModel); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Any other model access control Exception occurred during re-linking the model group. Please check log for more details.", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testUpdateModelWithModelNotFound() throws IOException { doAnswer(invocation -> { @@ -554,23 +765,23 @@ private MLModel prepareUnsupportedMLModel(FunctionName unsupportedCase) throws I switch (unsupportedCase) { case REMOTE: mlModel = MLModel - .builder() - .name("test_name") - .modelId("test_model_id") - .description("test_description") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.REMOTE) - .connector(HttpConnector.builder().name("test_connector").protocol("http").build()) - .build(); + .builder() + .name("test_name") + .modelId("test_model_id") + .description("test_description") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.REMOTE) + .connector(HttpConnector.builder().name("test_connector").protocol("http").build()) + .build(); return mlModel; case KMEANS: mlModel = MLModel - .builder() - .name("test_name") - .modelId("test_model_id") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.KMEANS) - .build(); + .builder() + .name("test_name") + .modelId("test_model_id") + .modelState(MLModelState.REGISTERED) + .algorithm(FunctionName.KMEANS) + .build(); return mlModel; default: throw new IllegalArgumentException("Please choose from FunctionName.REMOTE and FunctionName.KMEANS");