Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] [Enhancement] Enhance validation for create connector API #3287

Open
wants to merge 1 commit into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ public ConnectorAction(
String postProcessFunction
) {
if (actionType == null) {
throw new IllegalArgumentException("action type can't null");
throw new IllegalArgumentException("action type can't be null");
}
if (url == null) {
throw new IllegalArgumentException("url can't null");
throw new IllegalArgumentException("url can't be null");
}
if (method == null) {
throw new IllegalArgumentException("method can't null");
throw new IllegalArgumentException("method can't be null");
}
this.actionType = actionType;
this.method = method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ public MLCreateConnectorInput(
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
if (credential == null || credential.isEmpty()) {
throw new IllegalArgumentException("Connector credential is null or empty list");
}
}
this.name = name;
this.description = description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

package org.opensearch.ml.common.connector;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
Expand All @@ -27,130 +26,124 @@
import org.opensearch.search.SearchModule;

public class ConnectorActionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

// Shared test data for the class
private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT;
private static final String TEST_METHOD_POST = "post";
private static final String TEST_METHOD_HTTP = "http";
private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}";
private static final String URL = "https://test.com";

@Test
public void constructor_NullActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("action type can't null");
ConnectorAction.ActionType actionType = null;
String method = "post";
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("action type can't be null", exception.getMessage());

}

@Test
public void constructor_NullUrl() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("url can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = null;
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("url can't be null", exception.getMessage());
}

@Test
public void constructor_NullMethod() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("method can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = null;
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("method can't be null", exception.getMessage());
}

@Test
public void writeTo_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void writeTo() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void toXContent_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content);
String expctedContent = """
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't support in Java 17. In main branch, we use JAVA 21 which support this string. So for you I would suggest to do:

  1. Fix main branch code with not using """
  2. And then backport both of the commits together in 2.x branch.

Copy link
Contributor

@akolarkunnu akolarkunnu Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Text blocks supports from JDK15 onwards - https://openjdk.org/jeps/378
I think real issue is, 2.x branch's build.gradle still points to JDK11 - https://github.com/opensearch-project/ml-commons/blob/2.x/build.gradle#L64 . Is it suppose to be JDK17 ?
According to developer guide, it's JDK 17 https://github.com/opensearch-project/ml-commons/blob/2.x/DEVELOPER_GUIDE.md#install-java

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

D:\a\ml-commons\ml-commons\common\src\test\java\org\opensearch\ml\common\connector\ConnectorActionTest.java:103: error: text blocks are not supported in -source 11
For more on this, please refer to https://docs.gradle.org/8.11.1/userguide/command_line_interface.html#sec:command_line_warnings in the Gradle documentation.
          String expctedContent = """
31 actionable tasks: 31 executed

My bad, I meant 11. For now let's change the text block.

Upgrading that to 17 requires more discussion and plugin level campaign. As other plugins are also in the same verison: flow framework, knn, neural-search, skills

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, raised a new PR to revert Text Block changes from test cases - #3329
Please review.

Copy link
Collaborator

@ylwu-amzn ylwu-amzn Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks , PR #3329 merged. Can you fix this PR to let it pass CI?

{"action_type":"PREDICT","method":"http","url":"https://test.com",\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}"}\
""";
assertEquals(expctedContent, content);
}

@Test
public void toXContent() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert
.assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}",
content
);
String expctedContent = """
{"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\
"pre_process_function":"connector.pre_process.openai.embedding",\
"post_process_function":"connector.post_process.openai.embedding"}\
""";
assertEquals(expctedContent, content);
}

@Test
public void parse() throws IOException {
String jsonStr = "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}";
String jsonStr = """
{"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\
"pre_process_function":"connector.pre_process.openai.embedding",\
"post_process_function":"connector.post_process.openai.embedding"}"\
""";
XContentParser parser = XContentType.JSON
.xContent()
.createParser(
Expand All @@ -160,24 +153,23 @@ public void parse() throws IOException {
);
parser.nextToken();
ConnectorAction action = ConnectorAction.parse(parser);
Assert.assertEquals("http", action.getMethod());
Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
Assert.assertEquals("https://test.com", action.getUrl());
Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody());
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
assertEquals(TEST_METHOD_HTTP, action.getMethod());
assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
assertEquals(URL, action.getUrl());
assertEquals(TEST_REQUEST_BODY, action.getRequestBody());
assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
}

@Test
public void test_wrongActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Action Type");
ConnectorAction.ActionType.from("badAction");
Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); });
assertEquals("Wrong Action Type of badAction", exception.getMessage());
}

@Test
public void test_invalidActionInModelPrediction() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
assertEquals(isValidActionInModelPrediction(actionType), false);
}
}
Loading
Loading