From 68dead2b0057f8f5832828d86c004a6eb5b05a86 Mon Sep 17 00:00:00 2001 From: Muneer Kolarkunnu <33829651+akolarkunnu@users.noreply.github.com> Date: Wed, 18 Dec 2024 03:12:02 +0530 Subject: [PATCH] [Enhancement] Enhance validation for create connector API (#3260) This PR addresses the first part of this enhancement "Validate if connector payload has all the required fields. If not provided, throw the illegal argument exception". Validation of fields description, parameters, credential, and request_body are missing. That validations are added in this fix. Added new test cases correspong to these validations and fixed all failing test cases because of these new validations. Partially Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu (cherry picked from commit 58903ba9410edd3f04451c3f7cd55462d79a4279) --- .../ml/common/connector/ConnectorAction.java | 6 +- .../connector/MLCreateConnectorInput.java | 3 + .../common/connector/ConnectorActionTest.java | 136 +++++----- .../MLCreateConnectorInputTests.java | 237 +++++++++++------- .../TransportCreateConnectorActionTests.java | 2 + 5 files changed, 215 insertions(+), 169 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 4a7555d69b..835c6a6c47 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -55,13 +55,13 @@ public ConnectorAction( String postProcessFunction ) { if (actionType == null) { - throw new IllegalArgumentException("action type can't null"); + throw new IllegalArgumentException("action type can't be null"); } if (url == null) { - throw new IllegalArgumentException("url can't null"); + throw new IllegalArgumentException("url can't be null"); } if (method == null) { - throw new IllegalArgumentException("method can't null"); + throw new IllegalArgumentException("method can't be null"); } this.actionType = actionType; this.method = method; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 697f27494f..7029fccb7e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -93,6 +93,9 @@ public MLCreateConnectorInput( if (protocol == null) { throw new IllegalArgumentException("Connector protocol is null"); } + if (credential == null || credential.isEmpty()) { + throw new IllegalArgumentException("Connector credential is null or empty list"); + } } this.name = name; this.description = description; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java index 1539b9b432..ed1981e1be 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction; import java.io.IOException; @@ -12,10 +14,7 @@ import java.util.HashMap; import java.util.Map; -import org.junit.Assert; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; @@ -27,107 +26,100 @@ import org.opensearch.search.SearchModule; public class ConnectorActionTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); + + // Shared test data for the class + private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT; + private static final String TEST_METHOD_POST = "post"; + private static final String TEST_METHOD_HTTP = "http"; + private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}"; + private static final String URL = "https://test.com"; @Test public void constructor_NullActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("action type can't null"); - ConnectorAction.ActionType actionType = null; - String method = "post"; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("action type can't be null", exception.getMessage()); + } @Test public void constructor_NullUrl() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("url can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "post"; - String url = null; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("url can't be null", exception.getMessage()); } @Test public void constructor_NullMethod() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("method can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = null; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("method can't be null", exception.getMessage()); } @Test public void writeTo_NullValue() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test public void writeTo() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; Map headers = new HashMap<>(); headers.put("key1", "value1"); - String requestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( - actionType, - method, - url, + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + URL, headers, - requestBody, + TEST_REQUEST_BODY, preProcessFunction, postProcessFunction ); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test public void toXContent_NullValue() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com",\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}"}\ + """; + assertEquals(expctedContent, content); } @Test public void toXContent() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; Map headers = new HashMap<>(); headers.put("key1", "value1"); - String requestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( - actionType, - method, - url, + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + URL, headers, - requestBody, + TEST_REQUEST_BODY, preProcessFunction, postProcessFunction ); @@ -135,22 +127,23 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert - .assertEquals( - "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}", - content - ); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}\ + """; + assertEquals(expctedContent, content); } @Test public void parse() throws IOException { - String jsonStr = "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}"; + String jsonStr = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}"\ + """; XContentParser parser = XContentType.JSON .xContent() .createParser( @@ -160,24 +153,23 @@ public void parse() throws IOException { ); parser.nextToken(); ConnectorAction action = ConnectorAction.parse(parser); - Assert.assertEquals("http", action.getMethod()); - Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); - Assert.assertEquals("https://test.com", action.getUrl()); - Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody()); - Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); - Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); + assertEquals(TEST_METHOD_HTTP, action.getMethod()); + assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); + assertEquals(URL, action.getUrl()); + assertEquals(TEST_REQUEST_BODY, action.getRequestBody()); + assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); + assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); } @Test public void test_wrongActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Action Type"); - ConnectorAction.ActionType.from("badAction"); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); }); + assertEquals("Wrong Action Type of badAction", exception.getMessage()); } @Test public void test_invalidActionInModelPrediction() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute"); - Assert.assertEquals(isValidActionInModelPrediction(actionType), false); + assertEquals(isValidActionInModelPrediction(actionType), false); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index 28e597e186..ab08907b84 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; @@ -19,9 +20,7 @@ import java.util.function.Consumer; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -46,20 +45,29 @@ public class MLCreateConnectorInputTests { private MLCreateConnectorInput mlCreateConnectorInput; private MLCreateConnectorInput mlCreateDryRunConnectorInput; - @Rule - public final ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\",\"client_config\":{\"max_connection\":20," - + "\"connection_timeout\":10000,\"read_timeout\":10000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + private static final String TEST_CONNECTOR_NAME = "test_connector_name"; + private static final String TEST_CONNECTOR_DESCRIPTION = "this is a test connector"; + private static final String TEST_CONNECTOR_VERSION = "1"; + private static final String TEST_CONNECTOR_PROTOCOL = "http"; + private static final String TEST_PARAM_KEY = "input"; + private static final String TEST_PARAM_VALUE = "test input value"; + private static final String TEST_CREDENTIAL_KEY = "key"; + private static final String TEST_CREDENTIAL_VALUE = "test_key_value"; + private static final String TEST_ROLE1 = "role1"; + private static final String TEST_ROLE2 = "role2"; + private final String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1","protocol":"http",\ + "parameters":{"input":"test input value"},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,\ + "access_mode":"PUBLIC","client_config":{"max_connection":20,\ + "connection_timeout":10000,"read_timeout":10000,\ + "retry_backoff_millis":10,"retry_timeout_seconds":10,"max_retry_times":-1,"retry_backoff_policy":"constant"}}\ + """; @Before public void setUp() { @@ -84,15 +92,15 @@ public void setUp() { mlCreateConnectorInput = MLCreateConnectorInput .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .actions(List.of(action)) .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) .addAllBackendRoles(false) .connectorClientConfig(connectorClientConfig) .build(); @@ -102,59 +110,102 @@ public void setUp() { @Test public void constructorMLCreateConnectorInput_NullName() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - MLCreateConnectorInput - .builder() - .name(null) - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(null) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector name is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullVersion() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector version is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version(null) - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(null) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector version is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullProtocol() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector protocol is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol(null) - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(null) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector protocol is null", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(null) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_EmptyCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of()) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); } @Test @@ -178,7 +229,7 @@ public void testToXContent_NullFields() throws Exception { @Test public void testParse() throws Exception { testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("test_connector_name", parsedInput.getName()); + assertEquals(TEST_CONNECTOR_NAME, parsedInput.getName()); assertEquals(20, parsedInput.getConnectorClientConfig().getMaxConnections().intValue()); assertEquals(10000, parsedInput.getConnectorClientConfig().getReadTimeout().intValue()); assertEquals(10000, parsedInput.getConnectorClientConfig().getConnectionTimeout().intValue()); @@ -187,18 +238,17 @@ public void testParse() throws Exception { @Test public void testParse_ArrayParameter() throws Exception { - String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\"}"; + String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1",\ + "protocol":"http","parameters":{"input":["test input value"]},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,"access_mode":"PUBLIC"};\ + """; testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("test_connector_name", parsedInput.getName()); + assertEquals(TEST_CONNECTOR_NAME, parsedInput.getName()); assertEquals(1, parsedInput.getParameters().size()); assertEquals("[\"test input value\"]", parsedInput.getParameters().get("input")); }); @@ -222,9 +272,10 @@ public void readInputStream_Success() throws IOException { public void readInputStream_SuccessWithNullFields() throws IOException { MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput .builder() - .name("test_connector_name") - .version("1") - .protocol("http") + .name(TEST_CONNECTOR_NAME) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .build(); readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); @@ -238,15 +289,15 @@ public void testBuilder_NullActions_ShouldNotThrowException() { // Actions can be null for a connector without any specific actions defined. MLCreateConnectorInput input = MLCreateConnectorInput .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .actions(null) // Setting actions to null .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) .addAllBackendRoles(false) .build(); @@ -258,10 +309,8 @@ public void testParse_MissingNameField_ShouldThrowException() throws IOException String jsonMissingName = "{\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"}"; XContentParser parser = createParser(jsonMissingName); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - - MLCreateConnectorInput.parse(parser); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { MLCreateConnectorInput.parse(parser); }); + assertEquals("Connector name is null", exception.getMessage()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index e16400bc56..1a0cd7716f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -450,12 +450,14 @@ public void test_execute_URL_notMatchingExpression_exception() { .build() ); + Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput .builder() .name(randomAlphaOfLength(5)) .description(randomAlphaOfLength(10)) .version("1") .protocol(ConnectorProtocols.HTTP) + .credential(credential) .actions(actions) .build(); MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput);