Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…rch-project#3353)

Backporting these two PRs together becasue auto backporting of 3260 failed becasue of usage of Text Blocks, 3329 is to revert the usage of Text Blocks.

Signed-off-by: Abdul Muneer Kolarkunnu <[email protected]>
  • Loading branch information
akolarkunnu authored Jan 9, 2025
1 parent 225ca40 commit 01d1349
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ public ConnectorAction(
String postProcessFunction
) {
if (actionType == null) {
throw new IllegalArgumentException("action type can't null");
throw new IllegalArgumentException("action type can't be null");
}
if (url == null) {
throw new IllegalArgumentException("url can't null");
throw new IllegalArgumentException("url can't be null");
}
if (method == null) {
throw new IllegalArgumentException("method can't null");
throw new IllegalArgumentException("method can't be null");
}
this.actionType = actionType;
this.method = method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ public MLCreateConnectorInput(
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
if (credential == null || credential.isEmpty()) {
throw new IllegalArgumentException("Connector credential is null or empty list");
}
}
this.name = name;
this.description = description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

package org.opensearch.ml.common.connector;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
Expand All @@ -27,122 +26,114 @@
import org.opensearch.search.SearchModule;

public class ConnectorActionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

// Shared test data for the class
private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT;
private static final String TEST_METHOD_POST = "post";
private static final String TEST_METHOD_HTTP = "http";
private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}";
private static final String URL = "https://test.com";

@Test
public void constructor_NullActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("action type can't null");
ConnectorAction.ActionType actionType = null;
String method = "post";
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("action type can't be null", exception.getMessage());

}

@Test
public void constructor_NullUrl() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("url can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = null;
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("url can't be null", exception.getMessage());
}

@Test
public void constructor_NullMethod() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("method can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = null;
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("method can't be null", exception.getMessage());
}

@Test
public void writeTo_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void writeTo() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void toXContent_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content);
assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"}",
content
);
}

@Test
public void toXContent() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert
.assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}",
content
);
assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}",
content
);
}

@Test
Expand All @@ -160,24 +151,23 @@ public void parse() throws IOException {
);
parser.nextToken();
ConnectorAction action = ConnectorAction.parse(parser);
Assert.assertEquals("http", action.getMethod());
Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
Assert.assertEquals("https://test.com", action.getUrl());
Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody());
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
assertEquals(TEST_METHOD_HTTP, action.getMethod());
assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
assertEquals(URL, action.getUrl());
assertEquals(TEST_REQUEST_BODY, action.getRequestBody());
assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
}

@Test
public void test_wrongActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Action Type");
ConnectorAction.ActionType.from("badAction");
Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); });
assertEquals("Wrong Action Type of badAction", exception.getMessage());
}

@Test
public void test_invalidActionInModelPrediction() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
assertEquals(isValidActionInModelPrediction(actionType), false);
}
}
Loading

0 comments on commit 01d1349

Please sign in to comment.