From 9c705032fa4f5706874a15417eadd112be848824 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Tue, 22 Aug 2023 15:01:17 -0700 Subject: [PATCH] more ut test coverage Signed-off-by: Xun Zhang --- .../MLUpdateConnectorRequestTests.java | 128 +++++++++++ .../UpdateConnectorTransportAction.java | 14 +- .../ml/rest/RestMLUpdateConnectorAction.java | 6 +- .../TransportUpdateConnectorActionTests.java | 214 ++++++++++++++++++ .../ml/rest/RestMLDeployModelActionTests.java | 11 - .../RestMLUpdateConnectorActionTests.java | 181 +++++++++++++++ 6 files changed, 529 insertions(+), 25 deletions(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java new file mode 100644 index 0000000000..e017009983 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.when; + +public class MLUpdateConnectorRequestTests { + private String connectorId; + private Map updateContent; + private MLUpdateConnectorRequest mlUpdateConnectorRequest; + + @Mock + XContentParser parser; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + this.connectorId = "test-connector_id"; + this.updateContent = Map.of("description", "new description"); + mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + .connectorId(connectorId) + .updateContent(updateContent) + .build(); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlUpdateConnectorRequest.writeTo(bytesStreamOutput); + MLUpdateConnectorRequest parsedUpdateRequest = new MLUpdateConnectorRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(connectorId, parsedUpdateRequest.getConnectorId()); + assertEquals(updateContent, parsedUpdateRequest.getUpdateContent()); + } + + @Test + public void validate_Success() { + assertNull(mlUpdateConnectorRequest.validate()); + } + + @Test + public void validate_Exception_NullConnectorId() { + MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build(); + Exception exception = updateConnectorRequest.validate(); + + assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage()); + } + + @Test + public void parse_success() throws IOException { + RestRequest.Method method = RestRequest.Method.POST; + final Map updatefields = Map.of("version", "new version", "description", "new description"); + when(parser.map()).thenReturn(updatefields); + + MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); + assertEquals(updateConnectorRequest.getConnectorId(), connectorId); + assertEquals(updateConnectorRequest.getUpdateContent(), updatefields); + } + + @Test + public void fromActionRequest_Success() { + MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + .connectorId(connectorId) + .updateContent(updateContent) + .build(); + assertSame(MLUpdateConnectorRequest.fromActionRequest(mlUpdateConnectorRequest), mlUpdateConnectorRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + .connectorId(connectorId) + .updateContent(updateContent) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlUpdateConnectorRequest.writeTo(out); + } + }; + MLUpdateConnectorRequest request = MLUpdateConnectorRequest.fromActionRequest(actionRequest); + assertNotSame(request, mlUpdateConnectorRequest); + assertEquals(mlUpdateConnectorRequest.getConnectorId(), request.getConnectorId()); + assertEquals(mlUpdateConnectorRequest.getUpdateContent(), request.getUpdateContent()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLUpdateConnectorRequest.fromActionRequest(actionRequest); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index 866c91bdc6..21df3d31ba 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -17,7 +17,6 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; @@ -32,7 +31,6 @@ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class UpdateConnectorTransportAction extends HandledTransportAction { Client client; - NamedXContentRegistry xContentRegistry; ConnectorAccessControlHelper connectorAccessControlHelper; @@ -41,12 +39,10 @@ public UpdateConnectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, - NamedXContentRegistry xContentRegistry, ConnectorAccessControlHelper connectorAccessControlHelper ) { super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new); this.client = client; - this.xContentRegistry = xContentRegistry; this.connectorAccessControlHelper = connectorAccessControlHelper; } @@ -69,11 +65,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - log.error("You don't have permission to update the connector for connector id: " + connectorId, exception); + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception); listener.onFailure(exception); })); } catch (Exception e) { - log.error("Failed to update ML connector " + connectorId, e); + log.error("Failed to update ML connector for connector id {}. Details {}:", connectorId, e); listener.onFailure(e); } } @@ -85,14 +81,14 @@ private ActionListener getUpdateResponseListener( ) { return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { - log.info("Connector id:{} failed update", connectorId); + log.info("Failed to update the connector with ID: {}", connectorId); actionListener.onResponse(updateResponse); return; } - log.info("Completed Update Connector Request, connector id:{} updated", connectorId); + log.info("Successfully updated the connector with ID: {}", connectorId); actionListener.onResponse(updateResponse); }, exception -> { - log.error("Failed to update ML connector: " + connectorId, exception); + log.error("Failed to update ML connector with ID {}. Details: {}", connectorId, exception); actionListener.onFailure(exception); }), context::restore); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java index 11fe8bc920..a74ed27ecc 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java @@ -15,7 +15,6 @@ import java.util.List; import java.util.Locale; -import org.apache.logging.log4j.util.Strings; import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; @@ -66,13 +65,10 @@ private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOExcept } if (!request.hasContent()) { - throw new IOException("Update Connector request has empty body"); + throw new IOException("Failed to update connector: Request body is empty"); } String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID); - if (Strings.isBlank(connectorId)) { - throw new IOException("Update Connector request has no connector Id"); - } XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java new file mode 100644 index 0000000000..7024666715 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java @@ -0,0 +1,214 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; + +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { + + private UpdateConnectorTransportAction transportUpdateConnectorAction; + + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Mock + private Task task; + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private ClusterService clusterService; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private MLUpdateConnectorRequest updateRequest; + + @Mock + private UpdateResponse updateResponse; + + @Mock + ActionListener actionListener; + + ThreadContext threadContext; + + private Settings settings; + + private ShardId shardId; + + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList + .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + settings = Settings + .builder() + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .build(); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, + ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED + ); + + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + threadContext = new ThreadContext(settings); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + String connector_id = "test_connector_id"; + Map updateContent = Map.of("version", "2", "description", "updated description"); + when(updateRequest.getConnectorId()).thenReturn(connector_id); + when(updateRequest.getUpdateContent()).thenReturn(updateContent); + + transportUpdateConnectorAction = new UpdateConnectorTransportAction( + transportService, + actionFilters, + client, + connectorAccessControlHelper + ); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + } + + public void test_execute_connectorAccessControl_success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_connectorAccessControl_NoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(false); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You don't have permission to update the connector, connector id: test_connector_id", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_execute_connectorAccessControl_AccessError() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("Connector Access Control Error")); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Connector Access Control Error", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_connectorAccessControl_Exception() { + doThrow(new RuntimeException("exception in access control")) + .when(connectorAccessControlHelper) + .validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("exception in access control", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_UpdateWrongStatus() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_UpdateException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("update document failure")); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("update document failure", argumentCaptor.getValue().getMessage()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionTests.java index eb9caba417..eff2f2d69f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionTests.java @@ -8,8 +8,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; -import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.common.utils.StringUtils.gson; import java.util.*; @@ -17,7 +15,6 @@ import org.junit.Ignore; import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; @@ -37,7 +34,6 @@ import org.opensearch.threadpool.ThreadPool; import com.google.gson.Gson; -import com.google.gson.JsonParser; public class RestMLDeployModelActionTests extends OpenSearchTestCase { @@ -133,14 +129,7 @@ private RestRequest getRestRequest() { .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); - UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, "12222"); - updateRequest.doc(model); - UpdateRequest updateRequest1 = new UpdateRequest(ML_CONNECTOR_INDEX, "12222"); - updateRequest.doc(gson.fromJson(JsonParser.parseString(requestContent), Map.class)); - - System.out.println(updateRequest); - System.out.println(updateRequest1); return request; } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java new file mode 100644 index 0000000000..814402fb66 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMLUpdateConnectorActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMLUpdateConnectorAction restMLUpdateConnectorAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLUpdateConnectorAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLUpdateConnectorAction updateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); + assertNotNull(updateConnectorAction); + } + + public void testGetName() { + String actionName = restMLUpdateConnectorAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_connector_action", actionName); + } + + public void testRoutes() { + List routes = restMLUpdateConnectorAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/connectors/_update/{connector_id}", route.getPath()); + } + + public void testUpdateConnectorRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateConnectorRequest.class); + verify(client, times(1)).execute(eq(MLUpdateConnectorAction.INSTANCE), argumentCaptor.capture(), any()); + MLUpdateConnectorRequest updateConnectorRequest = argumentCaptor.getValue(); + assertEquals("test_connectorId", updateConnectorRequest.getConnectorId()); + assertEquals("This is test description", updateConnectorRequest.getUpdateContent().get("description")); + assertEquals("2", updateConnectorRequest.getUpdateContent().get("version")); + } + + public void testUpdateConnectorRequestWithEmptyContent() throws Exception { + exceptionRule.expect(IOException.class); + exceptionRule.expectMessage("Failed to update connector: Request body is empty"); + RestRequest request = getRestRequestWithEmptyContent(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); + } + + public void testUpdateConnectorRequestWithNullConnectorId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain connector_id"); + RestRequest request = getRestRequestWithNullConnectorId(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); + } + + public void testPrepareRequestFeatureDisabled() throws Exception { + exceptionRule.expect(IllegalStateException.class); + exceptionRule.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); + + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of("version", "2", "description", "This is test description"); + String requestContent = new Gson().toJson(updateContent).toString(); + Map params = new HashMap<>(); + params.put("connector_id", "test_connectorId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.POST; + Map params = new HashMap<>(); + params.put("connector_id", "test_connectorId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullConnectorId() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of("version", "2", "description", "This is test description"); + String requestContent = new Gson().toJson(updateContent).toString(); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + +}