diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 4a7555d69b..835c6a6c47 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -55,13 +55,13 @@ public ConnectorAction( String postProcessFunction ) { if (actionType == null) { - throw new IllegalArgumentException("action type can't null"); + throw new IllegalArgumentException("action type can't be null"); } if (url == null) { - throw new IllegalArgumentException("url can't null"); + throw new IllegalArgumentException("url can't be null"); } if (method == null) { - throw new IllegalArgumentException("method can't null"); + throw new IllegalArgumentException("method can't be null"); } this.actionType = actionType; this.method = method; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 697f27494f..7029fccb7e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -93,6 +93,9 @@ public MLCreateConnectorInput( if (protocol == null) { throw new IllegalArgumentException("Connector protocol is null"); } + if (credential == null || credential.isEmpty()) { + throw new IllegalArgumentException("Connector credential is null or empty list"); + } } this.name = name; this.description = description; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java index 1539b9b432..ed1981e1be 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction; import java.io.IOException; @@ -12,10 +14,7 @@ import java.util.HashMap; import java.util.Map; -import org.junit.Assert; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; @@ -27,107 +26,100 @@ import org.opensearch.search.SearchModule; public class ConnectorActionTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); + + // Shared test data for the class + private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT; + private static final String TEST_METHOD_POST = "post"; + private static final String TEST_METHOD_HTTP = "http"; + private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}"; + private static final String URL = "https://test.com"; @Test public void constructor_NullActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("action type can't null"); - ConnectorAction.ActionType actionType = null; - String method = "post"; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("action type can't be null", exception.getMessage()); + } @Test public void constructor_NullUrl() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("url can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "post"; - String url = null; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("url can't be null", exception.getMessage()); } @Test public void constructor_NullMethod() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("method can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = null; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("method can't be null", exception.getMessage()); } @Test public void writeTo_NullValue() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test public void writeTo() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; Map headers = new HashMap<>(); headers.put("key1", "value1"); - String requestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( - actionType, - method, - url, + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + URL, headers, - requestBody, + TEST_REQUEST_BODY, preProcessFunction, postProcessFunction ); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test public void toXContent_NullValue() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com",\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}"}\ + """; + assertEquals(expctedContent, content); } @Test public void toXContent() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; Map headers = new HashMap<>(); headers.put("key1", "value1"); - String requestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( - actionType, - method, - url, + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + URL, headers, - requestBody, + TEST_REQUEST_BODY, preProcessFunction, postProcessFunction ); @@ -135,22 +127,23 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert - .assertEquals( - "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}", - content - ); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}\ + """; + assertEquals(expctedContent, content); } @Test public void parse() throws IOException { - String jsonStr = "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}"; + String jsonStr = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}"\ + """; XContentParser parser = XContentType.JSON .xContent() .createParser( @@ -160,24 +153,23 @@ public void parse() throws IOException { ); parser.nextToken(); ConnectorAction action = ConnectorAction.parse(parser); - Assert.assertEquals("http", action.getMethod()); - Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); - Assert.assertEquals("https://test.com", action.getUrl()); - Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody()); - Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); - Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); + assertEquals(TEST_METHOD_HTTP, action.getMethod()); + assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); + assertEquals(URL, action.getUrl()); + assertEquals(TEST_REQUEST_BODY, action.getRequestBody()); + assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); + assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); } @Test public void test_wrongActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Action Type"); - ConnectorAction.ActionType.from("badAction"); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); }); + assertEquals("Wrong Action Type of badAction", exception.getMessage()); } @Test public void test_invalidActionInModelPrediction() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute"); - Assert.assertEquals(isValidActionInModelPrediction(actionType), false); + assertEquals(isValidActionInModelPrediction(actionType), false); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index 28e597e186..ab08907b84 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; @@ -19,9 +20,7 @@ import java.util.function.Consumer; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -46,20 +45,29 @@ public class MLCreateConnectorInputTests { private MLCreateConnectorInput mlCreateConnectorInput; private MLCreateConnectorInput mlCreateDryRunConnectorInput; - @Rule - public final ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\",\"client_config\":{\"max_connection\":20," - + "\"connection_timeout\":10000,\"read_timeout\":10000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + private static final String TEST_CONNECTOR_NAME = "test_connector_name"; + private static final String TEST_CONNECTOR_DESCRIPTION = "this is a test connector"; + private static final String TEST_CONNECTOR_VERSION = "1"; + private static final String TEST_CONNECTOR_PROTOCOL = "http"; + private static final String TEST_PARAM_KEY = "input"; + private static final String TEST_PARAM_VALUE = "test input value"; + private static final String TEST_CREDENTIAL_KEY = "key"; + private static final String TEST_CREDENTIAL_VALUE = "test_key_value"; + private static final String TEST_ROLE1 = "role1"; + private static final String TEST_ROLE2 = "role2"; + private final String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1","protocol":"http",\ + "parameters":{"input":"test input value"},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,\ + "access_mode":"PUBLIC","client_config":{"max_connection":20,\ + "connection_timeout":10000,"read_timeout":10000,\ + "retry_backoff_millis":10,"retry_timeout_seconds":10,"max_retry_times":-1,"retry_backoff_policy":"constant"}}\ + """; @Before public void setUp() { @@ -84,15 +92,15 @@ public void setUp() { mlCreateConnectorInput = MLCreateConnectorInput .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .actions(List.of(action)) .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) .addAllBackendRoles(false) .connectorClientConfig(connectorClientConfig) .build(); @@ -102,59 +110,102 @@ public void setUp() { @Test public void constructorMLCreateConnectorInput_NullName() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - MLCreateConnectorInput - .builder() - .name(null) - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(null) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector name is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullVersion() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector version is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version(null) - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(null) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector version is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullProtocol() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector protocol is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol(null) - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(null) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector protocol is null", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(null) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_EmptyCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of()) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); } @Test @@ -178,7 +229,7 @@ public void testToXContent_NullFields() throws Exception { @Test public void testParse() throws Exception { testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("test_connector_name", parsedInput.getName()); + assertEquals(TEST_CONNECTOR_NAME, parsedInput.getName()); assertEquals(20, parsedInput.getConnectorClientConfig().getMaxConnections().intValue()); assertEquals(10000, parsedInput.getConnectorClientConfig().getReadTimeout().intValue()); assertEquals(10000, parsedInput.getConnectorClientConfig().getConnectionTimeout().intValue()); @@ -187,18 +238,17 @@ public void testParse() throws Exception { @Test public void testParse_ArrayParameter() throws Exception { - String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\"}"; + String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1",\ + "protocol":"http","parameters":{"input":["test input value"]},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,"access_mode":"PUBLIC"};\ + """; testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("test_connector_name", parsedInput.getName()); + assertEquals(TEST_CONNECTOR_NAME, parsedInput.getName()); assertEquals(1, parsedInput.getParameters().size()); assertEquals("[\"test input value\"]", parsedInput.getParameters().get("input")); }); @@ -222,9 +272,10 @@ public void readInputStream_Success() throws IOException { public void readInputStream_SuccessWithNullFields() throws IOException { MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput .builder() - .name("test_connector_name") - .version("1") - .protocol("http") + .name(TEST_CONNECTOR_NAME) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .build(); readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); @@ -238,15 +289,15 @@ public void testBuilder_NullActions_ShouldNotThrowException() { // Actions can be null for a connector without any specific actions defined. MLCreateConnectorInput input = MLCreateConnectorInput .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .actions(null) // Setting actions to null .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) .addAllBackendRoles(false) .build(); @@ -258,10 +309,8 @@ public void testParse_MissingNameField_ShouldThrowException() throws IOException String jsonMissingName = "{\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"}"; XContentParser parser = createParser(jsonMissingName); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - - MLCreateConnectorInput.parse(parser); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { MLCreateConnectorInput.parse(parser); }); + assertEquals("Connector name is null", exception.getMessage()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index e16400bc56..1a0cd7716f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -450,12 +450,14 @@ public void test_execute_URL_notMatchingExpression_exception() { .build() ); + Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput .builder() .name(randomAlphaOfLength(5)) .description(randomAlphaOfLength(10)) .version("1") .protocol(ConnectorProtocols.HTTP) + .credential(credential) .actions(actions) .build(); MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput);