Skip to content

Commit

Permalink
Backport #3260 and #3329 (#3353)
Browse files Browse the repository at this point in the history
Backporting these two PRs together becasue auto backporting of 3260 failed becasue of usage of Text Blocks, 3329 is to revert the usage of Text Blocks.

Signed-off-by: Abdul Muneer Kolarkunnu <[email protected]>
  • Loading branch information
akolarkunnu authored Jan 9, 2025
1 parent 225ca40 commit 01d1349
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ public ConnectorAction(
String postProcessFunction
) {
if (actionType == null) {
throw new IllegalArgumentException("action type can't null");
throw new IllegalArgumentException("action type can't be null");
}
if (url == null) {
throw new IllegalArgumentException("url can't null");
throw new IllegalArgumentException("url can't be null");
}
if (method == null) {
throw new IllegalArgumentException("method can't null");
throw new IllegalArgumentException("method can't be null");
}
this.actionType = actionType;
this.method = method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ public MLCreateConnectorInput(
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
if (credential == null || credential.isEmpty()) {
throw new IllegalArgumentException("Connector credential is null or empty list");
}
}
this.name = name;
this.description = description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

package org.opensearch.ml.common.connector;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
Expand All @@ -27,122 +26,114 @@
import org.opensearch.search.SearchModule;

public class ConnectorActionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

// Shared test data for the class
private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT;
private static final String TEST_METHOD_POST = "post";
private static final String TEST_METHOD_HTTP = "http";
private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}";
private static final String URL = "https://test.com";

@Test
public void constructor_NullActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("action type can't null");
ConnectorAction.ActionType actionType = null;
String method = "post";
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("action type can't be null", exception.getMessage());

}

@Test
public void constructor_NullUrl() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("url can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = null;
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("url can't be null", exception.getMessage());
}

@Test
public void constructor_NullMethod() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("method can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = null;
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("method can't be null", exception.getMessage());
}

@Test
public void writeTo_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void writeTo() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void toXContent_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content);
assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}",
content
);
}

@Test
public void toXContent() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert
.assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}",
content
);
assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}",
content
);
}

@Test
Expand All @@ -160,24 +151,23 @@ public void parse() throws IOException {
);
parser.nextToken();
ConnectorAction action = ConnectorAction.parse(parser);
Assert.assertEquals("http", action.getMethod());
Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
Assert.assertEquals("https://test.com", action.getUrl());
Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody());
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
assertEquals(TEST_METHOD_HTTP, action.getMethod());
assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
assertEquals(URL, action.getUrl());
assertEquals(TEST_REQUEST_BODY, action.getRequestBody());
assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
}

@Test
public void test_wrongActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Action Type");
ConnectorAction.ActionType.from("badAction");
Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); });
assertEquals("Wrong Action Type of badAction", exception.getMessage());
}

@Test
public void test_invalidActionInModelPrediction() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
assertEquals(isValidActionInModelPrediction(actionType), false);
}
}
Loading

0 comments on commit 01d1349

Please sign in to comment.