From 093fed8331c4da2f44ab727f4e51b8efbebf33d7 Mon Sep 17 00:00:00 2001 From: tkykenmt Date: Sun, 15 Dec 2024 23:17:12 +0900 Subject: [PATCH 1/5] update error handling to throw exception when post processing function recieve empty result from a model. Signed-off-by: tkykenmt --- ...erank-m3-v2_model_deployed_on_Sagemaker.md | 126 ++++++++++++++++-- 1 file changed, 114 insertions(+), 12 deletions(-) diff --git a/docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md b/docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md index 238f62300e..278a96e80d 100644 --- a/docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md +++ b/docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md @@ -59,10 +59,38 @@ result = predictor.predict(data={ ] }) -print(json.dumps(sorted(result, key=lambda x: x['index']), indent=2)) +print(json.dumps(result, indent=2)) ``` -The reranking results are as follows: +The reranking result is ordering by the highest score first: +``` +[ + { + "index": 2, + "score": 0.92879725 + }, + { + "index": 0, + "score": 0.013636836 + }, + { + "index": 1, + "score": 0.000593021 + }, + { + "index": 3, + "score": 0.00012148176 + } +] +``` + +You can sort the result by index number. + +```python +print(json.dumps(result, indent=2)) +``` + +The results are as follows: ``` [ @@ -121,9 +149,46 @@ POST /_plugins/_ml/connectors/_create "headers": { "content-type": "application/json" }, + "pre_process_function": """ + def query_text = params.query_text; + def text_docs = params.text_docs; + def textDocsBuilder = new StringBuilder('['); + for (int i=0; i 0) { + throw new IllegalArgumentException("Post process function input is empty."); + } + def outputs = params.result; + def scores = new Double[outputs.length]; + for (int i=0; i 0) { + throw new IllegalArgumentException("Post process function input is empty."); + } + def outputs = params.result; + def scores = new Double[outputs.length]; + for (int i=0; i Date: Wed, 18 Dec 2024 03:12:02 +0530 Subject: [PATCH 2/5] [Enhancement] Enhance validation for create connector API (#3260) This PR addresses the first part of this enhancement "Validate if connector payload has all the required fields. If not provided, throw the illegal argument exception". Validation of fields description, parameters, credential, and request_body are missing. That validations are added in this fix. Added new test cases correspong to these validations and fixed all failing test cases because of these new validations. Partially Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu --- .../ml/common/connector/ConnectorAction.java | 6 +- .../connector/MLCreateConnectorInput.java | 3 + .../common/connector/ConnectorActionTest.java | 136 +++++----- .../MLCreateConnectorInputTests.java | 237 +++++++++++------- .../TransportCreateConnectorActionTests.java | 2 + 5 files changed, 215 insertions(+), 169 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 4a7555d69b..835c6a6c47 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -55,13 +55,13 @@ public ConnectorAction( String postProcessFunction ) { if (actionType == null) { - throw new IllegalArgumentException("action type can't null"); + throw new IllegalArgumentException("action type can't be null"); } if (url == null) { - throw new IllegalArgumentException("url can't null"); + throw new IllegalArgumentException("url can't be null"); } if (method == null) { - throw new IllegalArgumentException("method can't null"); + throw new IllegalArgumentException("method can't be null"); } this.actionType = actionType; this.method = method; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 697f27494f..7029fccb7e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -93,6 +93,9 @@ public MLCreateConnectorInput( if (protocol == null) { throw new IllegalArgumentException("Connector protocol is null"); } + if (credential == null || credential.isEmpty()) { + throw new IllegalArgumentException("Connector credential is null or empty list"); + } } this.name = name; this.description = description; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java index 1539b9b432..ed1981e1be 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction; import java.io.IOException; @@ -12,10 +14,7 @@ import java.util.HashMap; import java.util.Map; -import org.junit.Assert; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; @@ -27,107 +26,100 @@ import org.opensearch.search.SearchModule; public class ConnectorActionTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); + + // Shared test data for the class + private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT; + private static final String TEST_METHOD_POST = "post"; + private static final String TEST_METHOD_HTTP = "http"; + private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}"; + private static final String URL = "https://test.com"; @Test public void constructor_NullActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("action type can't null"); - ConnectorAction.ActionType actionType = null; - String method = "post"; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("action type can't be null", exception.getMessage()); + } @Test public void constructor_NullUrl() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("url can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "post"; - String url = null; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("url can't be null", exception.getMessage()); } @Test public void constructor_NullMethod() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("method can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = null; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows( + IllegalArgumentException.class, + () -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null) + ); + assertEquals("method can't be null", exception.getMessage()); } @Test public void writeTo_NullValue() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test public void writeTo() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; Map headers = new HashMap<>(); headers.put("key1", "value1"); - String requestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( - actionType, - method, - url, + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + URL, headers, - requestBody, + TEST_REQUEST_BODY, preProcessFunction, postProcessFunction ); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test public void toXContent_NullValue() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com",\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}"}\ + """; + assertEquals(expctedContent, content); } @Test public void toXContent() throws IOException { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "http"; - String url = "https://test.com"; Map headers = new HashMap<>(); headers.put("key1", "value1"); - String requestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( - actionType, - method, - url, + TEST_ACTION_TYPE, + TEST_METHOD_HTTP, + URL, headers, - requestBody, + TEST_REQUEST_BODY, preProcessFunction, postProcessFunction ); @@ -135,22 +127,23 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert - .assertEquals( - "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}", - content - ); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}\ + """; + assertEquals(expctedContent, content); } @Test public void parse() throws IOException { - String jsonStr = "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}"; + String jsonStr = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}"\ + """; XContentParser parser = XContentType.JSON .xContent() .createParser( @@ -160,24 +153,23 @@ public void parse() throws IOException { ); parser.nextToken(); ConnectorAction action = ConnectorAction.parse(parser); - Assert.assertEquals("http", action.getMethod()); - Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); - Assert.assertEquals("https://test.com", action.getUrl()); - Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody()); - Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); - Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); + assertEquals(TEST_METHOD_HTTP, action.getMethod()); + assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); + assertEquals(URL, action.getUrl()); + assertEquals(TEST_REQUEST_BODY, action.getRequestBody()); + assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); + assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); } @Test public void test_wrongActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Action Type"); - ConnectorAction.ActionType.from("badAction"); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); }); + assertEquals("Wrong Action Type of badAction", exception.getMessage()); } @Test public void test_invalidActionInModelPrediction() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute"); - Assert.assertEquals(isValidActionInModelPrediction(actionType), false); + assertEquals(isValidActionInModelPrediction(actionType), false); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index 28e597e186..ab08907b84 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; @@ -19,9 +20,7 @@ import java.util.function.Consumer; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -46,20 +45,29 @@ public class MLCreateConnectorInputTests { private MLCreateConnectorInput mlCreateConnectorInput; private MLCreateConnectorInput mlCreateDryRunConnectorInput; - @Rule - public final ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\",\"client_config\":{\"max_connection\":20," - + "\"connection_timeout\":10000,\"read_timeout\":10000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + private static final String TEST_CONNECTOR_NAME = "test_connector_name"; + private static final String TEST_CONNECTOR_DESCRIPTION = "this is a test connector"; + private static final String TEST_CONNECTOR_VERSION = "1"; + private static final String TEST_CONNECTOR_PROTOCOL = "http"; + private static final String TEST_PARAM_KEY = "input"; + private static final String TEST_PARAM_VALUE = "test input value"; + private static final String TEST_CREDENTIAL_KEY = "key"; + private static final String TEST_CREDENTIAL_VALUE = "test_key_value"; + private static final String TEST_ROLE1 = "role1"; + private static final String TEST_ROLE2 = "role2"; + private final String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1","protocol":"http",\ + "parameters":{"input":"test input value"},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,\ + "access_mode":"PUBLIC","client_config":{"max_connection":20,\ + "connection_timeout":10000,"read_timeout":10000,\ + "retry_backoff_millis":10,"retry_timeout_seconds":10,"max_retry_times":-1,"retry_backoff_policy":"constant"}}\ + """; @Before public void setUp() { @@ -84,15 +92,15 @@ public void setUp() { mlCreateConnectorInput = MLCreateConnectorInput .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .actions(List.of(action)) .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) .addAllBackendRoles(false) .connectorClientConfig(connectorClientConfig) .build(); @@ -102,59 +110,102 @@ public void setUp() { @Test public void constructorMLCreateConnectorInput_NullName() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - MLCreateConnectorInput - .builder() - .name(null) - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(null) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector name is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullVersion() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector version is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version(null) - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(null) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector version is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullProtocol() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector protocol is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol(null) - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(null) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector protocol is null", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(null) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_EmptyCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of()) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); } @Test @@ -178,7 +229,7 @@ public void testToXContent_NullFields() throws Exception { @Test public void testParse() throws Exception { testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("test_connector_name", parsedInput.getName()); + assertEquals(TEST_CONNECTOR_NAME, parsedInput.getName()); assertEquals(20, parsedInput.getConnectorClientConfig().getMaxConnections().intValue()); assertEquals(10000, parsedInput.getConnectorClientConfig().getReadTimeout().intValue()); assertEquals(10000, parsedInput.getConnectorClientConfig().getConnectionTimeout().intValue()); @@ -187,18 +238,17 @@ public void testParse() throws Exception { @Test public void testParse_ArrayParameter() throws Exception { - String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\"}"; + String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1",\ + "protocol":"http","parameters":{"input":["test input value"]},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,"access_mode":"PUBLIC"};\ + """; testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("test_connector_name", parsedInput.getName()); + assertEquals(TEST_CONNECTOR_NAME, parsedInput.getName()); assertEquals(1, parsedInput.getParameters().size()); assertEquals("[\"test input value\"]", parsedInput.getParameters().get("input")); }); @@ -222,9 +272,10 @@ public void readInputStream_Success() throws IOException { public void readInputStream_SuccessWithNullFields() throws IOException { MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput .builder() - .name("test_connector_name") - .version("1") - .protocol("http") + .name(TEST_CONNECTOR_NAME) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .build(); readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); @@ -238,15 +289,15 @@ public void testBuilder_NullActions_ShouldNotThrowException() { // Actions can be null for a connector without any specific actions defined. MLCreateConnectorInput input = MLCreateConnectorInput .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol(TEST_CONNECTOR_PROTOCOL) + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) .actions(null) // Setting actions to null .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) .addAllBackendRoles(false) .build(); @@ -258,10 +309,8 @@ public void testParse_MissingNameField_ShouldThrowException() throws IOException String jsonMissingName = "{\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"}"; XContentParser parser = createParser(jsonMissingName); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - - MLCreateConnectorInput.parse(parser); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { MLCreateConnectorInput.parse(parser); }); + assertEquals("Connector name is null", exception.getMessage()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index e16400bc56..1a0cd7716f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -450,12 +450,14 @@ public void test_execute_URL_notMatchingExpression_exception() { .build() ); + Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput .builder() .name(randomAlphaOfLength(5)) .description(randomAlphaOfLength(10)) .version("1") .protocol(ConnectorProtocols.HTTP) + .credential(credential) .actions(actions) .build(); MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput); From 7334b5527a18280bef87ad39fd71b68a890356f4 Mon Sep 17 00:00:00 2001 From: Rithin Pullela Date: Tue, 24 Dec 2024 11:55:19 -0800 Subject: [PATCH 3/5] Add application_type to ConversationMeta; update tests (#3282) Modify getMemory(Conversation) to return the application_type parameter. Include application_type in the ConversationMeta data model. Update existing tests to validate the new parameter. Signed-off-by: rithin-pullela-aws --- .../common/conversation/ConversationMeta.java | 22 +++++++++++++++---- .../conversation/ConversationMetaTests.java | 12 +++++++--- .../GetConversationResponseTests.java | 6 ++--- .../GetConversationTransportActionTests.java | 2 +- .../GetConversationsResponseTests.java | 6 ++--- .../GetConversationsTransportActionTests.java | 10 ++++----- ...earchConversationalMemoryHandlerTests.java | 2 +- 7 files changed, 40 insertions(+), 20 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java index 21ed608654..5d847fbcd7 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java @@ -51,6 +51,8 @@ public class ConversationMeta implements Writeable, ToXContentObject { @Getter private String user; @Getter + private String applicationType; + @Getter private Map additionalInfos; /** @@ -74,8 +76,9 @@ public static ConversationMeta fromMap(String id, Map docFields) Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_TIME_FIELD)); String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD); String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD); + String applicationType = (String) docFields.get(ConversationalIndexConstants.APPLICATION_TYPE_FIELD); Map additionalInfos = (Map) docFields.get(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD); - return new ConversationMeta(id, created, updated, name, user, additionalInfos); + return new ConversationMeta(id, created, updated, name, user, applicationType, additionalInfos); } /** @@ -91,13 +94,14 @@ public static ConversationMeta fromStream(StreamInput in) throws IOException { Instant updated = in.readInstant(); String name = in.readString(); String user = in.readOptionalString(); + String applicationType = in.readOptionalString(); Map additionalInfos = null; if (in.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { if (in.readBoolean()) { additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString); } } - return new ConversationMeta(id, created, updated, name, user, additionalInfos); + return new ConversationMeta(id, created, updated, name, user, applicationType, additionalInfos); } @Override @@ -107,6 +111,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInstant(updatedTime); out.writeString(name); out.writeOptionalString(user); + out.writeOptionalString(applicationType); if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { if (additionalInfos == null) { out.writeBoolean(false); @@ -129,6 +134,10 @@ public String toString() { + updatedTime.toString() + ", user=" + user + + ", applicationType=" + + applicationType + + ", additionalInfos=" + + additionalInfos + "}"; } @@ -142,7 +151,10 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para if (this.user != null) { builder.field(ConversationalIndexConstants.USER_FIELD, this.user); } - if (this.additionalInfos != null) { + if (this.applicationType != null && !this.applicationType.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, this.applicationType); + } + if (this.additionalInfos != null && !additionalInfos.isEmpty()) { builder.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, this.additionalInfos); } builder.endObject(); @@ -159,7 +171,9 @@ public boolean equals(Object other) { && Objects.equals(this.user, otherConversation.user) && Objects.equals(this.createdTime, otherConversation.createdTime) && Objects.equals(this.updatedTime, otherConversation.updatedTime) - && Objects.equals(this.name, otherConversation.name); + && Objects.equals(this.name, otherConversation.name) + && Objects.equals(this.applicationType, otherConversation.applicationType) + && Objects.equals(this.additionalInfos, otherConversation.additionalInfos); } } diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java index aaa52ffcff..7666ab3bf2 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java @@ -30,7 +30,7 @@ public class ConversationMetaTests { @Before public void setUp() { time = Instant.now(); - conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin", null); + conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin", "conversational-search", null); } @Test @@ -41,6 +41,7 @@ public void test_fromSearchHit() throws IOException { content.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, time); content.field(ConversationalIndexConstants.META_NAME_FIELD, "meta name"); content.field(ConversationalIndexConstants.USER_FIELD, "admin"); + content.field(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, "conversational-search"); content.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, Map.of("test_key", "test_value")); content.endObject(); @@ -51,6 +52,7 @@ public void test_fromSearchHit() throws IOException { assertEquals(conversationMeta.getId(), "cId"); assertEquals(conversationMeta.getName(), "meta name"); assertEquals(conversationMeta.getUser(), "admin"); + assertEquals(conversationMeta.getApplicationType(), "conversational-search"); assertEquals(conversationMeta.getAdditionalInfos().get("test_key"), "test_value"); } @@ -83,6 +85,7 @@ public void test_fromStream() throws IOException { assertEquals(meta.getId(), conversationMeta.getId()); assertEquals(meta.getName(), conversationMeta.getName()); assertEquals(meta.getUser(), conversationMeta.getUser()); + assertEquals(meta.getApplicationType(), conversationMeta.getApplicationType()); } @Test @@ -93,6 +96,7 @@ public void test_ToXContent() throws IOException { Instant.ofEpochMilli(123), "test meta", "admin", + "neural-search", null ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); @@ -100,7 +104,7 @@ public void test_ToXContent() throws IOException { String content = TestHelper.xContentBuilderToString(builder); assertEquals( content, - "{\"memory_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}" + "{\"memory_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\",\"application_type\":\"neural-search\"}" ); } @@ -112,10 +116,11 @@ public void test_toString() { Instant.ofEpochMilli(123), "test meta", "admin", + "conversational-search", null ); assertEquals( - "{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", + "{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin, applicationType=conversational-search, additionalInfos=null}", conversationMeta.toString() ); } @@ -128,6 +133,7 @@ public void test_equal() { Instant.ofEpochMilli(123), "test meta", "admin", + "conversational-search", null ); assertEquals(meta.equals(conversationMeta), false); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java index 0b39d546f8..08a285ec90 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -38,7 +38,7 @@ public class GetConversationResponseTests extends OpenSearchTestCase { public void testGetConversationResponseStreaming() throws IOException { - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null, null); GetConversationResponse response = new GetConversationResponse(convo); assert (response.getConversation().equals(convo)); @@ -51,7 +51,7 @@ public void testGetConversationResponseStreaming() throws IOException { } public void testToXContent() throws IOException { - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null, null); GetConversationResponse response = new GetConversationResponse(convo); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -68,7 +68,7 @@ public void testToXContent() throws IOException { public void testToXContent_withAdditionalInfo() throws IOException { Map additionalInfos = Map.of("key1", "value1"); - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, additionalInfos); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null, additionalInfos); GetConversationResponse response = new GetConversationResponse(convo); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java index 558ecd9b65..aa85a8507d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java @@ -107,7 +107,7 @@ public void setup() throws IOException { } public void testGetConversation() { - ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), Instant.now(), "name", null, null); + ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), Instant.now(), "name", null, null, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(result); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java index b28ed26d0f..71982e6a76 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java @@ -46,9 +46,9 @@ public class GetConversationsResponseTests extends OpenSearchTestCase { public void setup() { conversations = List .of( - new ConversationMeta("0", Instant.now(), Instant.now(), "name0", "user0", null), - new ConversationMeta("1", Instant.now(), Instant.now(), "name1", "user0", null), - new ConversationMeta("2", Instant.now(), Instant.now(), "name2", "user2", null) + new ConversationMeta("0", Instant.now(), Instant.now(), "name0", "user0", null, null), + new ConversationMeta("1", Instant.now(), Instant.now(), "name1", "user0", null, null), + new ConversationMeta("2", Instant.now(), Instant.now(), "name2", "user2", null, null) ); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java index a866167d37..257c74f1bb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java @@ -114,8 +114,8 @@ public void testGetConversations() { log.info("testing get conversations transport"); List testResult = List .of( - new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null), - new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null) + new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null, null), + new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null, null) ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); @@ -132,9 +132,9 @@ public void testGetConversations() { public void testPagination() { List testResult = List .of( - new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null), - new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null), - new ConversationMeta("testcid3", Instant.now(), Instant.now(), "testname", null, null) + new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null, null), + new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null, null), + new ConversationMeta("testcid3", Instant.now(), Instant.now(), "testname", null, null, null) ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index 903be08338..fc63811d2c 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -315,7 +315,7 @@ public void testSearchInteractions_Future() { } public void testGetAConversation_Future() { - ConversationMeta response = new ConversationMeta("cid", Instant.now(), Instant.now(), "boring name", null, null); + ConversationMeta response = new ConversationMeta("cid", Instant.now(), Instant.now(), "boring name", null, null, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(response); From 2e40ed6c6ef990fae69ec4584b9f683f71ffd4ad Mon Sep 17 00:00:00 2001 From: Rithin Pullela Date: Tue, 24 Dec 2024 12:26:04 -0800 Subject: [PATCH 4/5] Enhance Message and Memory API Validation and storage (#3283) * Enchance Message and Memory API Validation and storage Throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Skip saving empty fields in interactions and conversations to optimize storage usage. Modify GET requests for interactions and conversations to return only non-null fields. Throw an exception if all fields in a create interaction call are empty or null. Add unit tests to cover the above cases. Signed-off-by: rithin-pullela-aws * Update unit test to check for null instead of empty map Signed-off-by: rithin-pullela-aws * Refactored userstr to Camel Case Signed-off-by: rithin-pullela-aws * Addressing comments Used assertThrows and added promptTemplate with empty string in test_ToXContent to ensure well rounded testing of expected functionality Signed-off-by: rithin-pullela-aws * Undo: throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Signed-off-by: rithin-pullela-aws --------- Signed-off-by: rithin-pullela-aws --- .../ml/common/conversation/Interaction.java | 16 +++- .../common/conversation/InteractionTests.java | 4 +- .../CreateConversationRequest.java | 26 ++++-- .../CreateInteractionRequest.java | 10 +++ .../memory/index/ConversationMetaIndex.java | 67 +++++++-------- .../ml/memory/index/InteractionsIndex.java | 81 ++++++++++--------- .../CreateConversationRequestTests.java | 1 + .../CreateInteractionRequestTests.java | 28 +++++++ .../index/ConversationMetaIndexITTests.java | 4 +- .../index/ConversationMetaIndexTests.java | 4 +- .../memory/index/InteractionsIndexTests.java | 4 +- .../ml/engine/memory/MLMemoryManager.java | 4 +- .../engine/memory/MLMemoryManagerTests.java | 4 +- 13 files changed, 160 insertions(+), 93 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 93afbb52a3..2bffc21b01 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -184,10 +184,18 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId); builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id); builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime); - builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); - builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); - builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); - builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + if (input != null && !input.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } if (additionalInfo != null) { builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); } diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java index 9ef58dd394..480998a0e7 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -122,15 +122,17 @@ public void test_ToXContent() throws IOException { .builder() .conversationId("conversation id") .origin("amazon bedrock") + .promptTemplate(" ") .parentInteractionId("parant id") .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .response("sample response") .traceNum(1) .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); interaction.toXContent(builder, EMPTY_PARAMS); String interactionContent = TestHelper.xContentBuilderToString(builder); assertEquals( - "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", + "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"response\":\"sample response\",\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", interactionContent ); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index 991ddde2a7..e64658054c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -137,12 +137,28 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) } try (XContentParser parser = restRequest.contentParser()) { Map body = parser.map(); + String name = null; + String applicationType = null; + Map additionalInfo = null; + + for (String key : body.keySet()) { + switch (key) { + case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD: + name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD); + break; + case APPLICATION_TYPE_FIELD: + applicationType = (String) body.get(APPLICATION_TYPE_FIELD); + break; + case META_ADDITIONAL_INFO_FIELD: + additionalInfo = (Map) body.get(META_ADDITIONAL_INFO_FIELD); + break; + default: + parser.skipChildren(); + break; + } + } if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { - return new CreateConversationRequest( - (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), - body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD), - body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map) body.get(META_ADDITIONAL_INFO_FIELD) - ); + return new CreateConversationRequest(name, applicationType, additionalInfo); } else { return new CreateConversationRequest(); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index fe4a05bc0c..3927312d9c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -171,6 +171,16 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro } } + boolean allFieldsEmpty = (input == null || input.trim().isEmpty()) + && (prompt == null || prompt.trim().isEmpty()) + && (response == null || response.trim().isEmpty()) + && (origin == null || origin.trim().isEmpty()) + && (addinf == null || addinf.isEmpty()); + if (allFieldsEmpty) { + throw new IllegalArgumentException( + "At least one of the following parameters must be non-empty: " + "input, prompt_template, response, origin, additional_info" + ); + } return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 5e128e4d6f..8b18092a2a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -139,24 +140,24 @@ public void createConversation( ) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { - String userstr = getUserStrFromThreadContext(); + String userStr = getUserStrFromThreadContext(); Instant now = Instant.now(); - IndexRequest request = Requests - .indexRequest(META_INDEX_NAME) - .source( - ConversationalIndexConstants.META_CREATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_UPDATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_NAME_FIELD, - name, - ConversationalIndexConstants.USER_FIELD, - userstr == null ? null : User.parse(userstr).getName(), - ConversationalIndexConstants.APPLICATION_TYPE_FIELD, - applicationType, - ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, - additionalInfos == null ? Map.of() : additionalInfos - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now); + sourceMap.put(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, now); + if (name != null && !name.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name); + } + if (userStr != null && !userStr.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName()); + } + if (applicationType != null && !applicationType.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType); + } + if (additionalInfos != null && !additionalInfos.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, additionalInfos); + } + IndexRequest request = Requests.indexRequest(META_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { @@ -210,12 +211,12 @@ public void getConversations(int from, int maxResults, ActionListener li return; } DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); - String userstr = getUserStrFromThreadContext(); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String userStr = getUserStrFromThreadContext(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); this.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { @@ -308,7 +309,7 @@ public void checkAccess(String conversationId, ActionListener listener) listener.onResponse(true); return; } - String userstr = getUserStrFromThreadContext(); + String userStr = getUserStrFromThreadContext(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -318,12 +319,12 @@ public void checkAccess(String conversationId, ActionListener listener) throw new ResourceNotFoundException("Memory [" + conversationId + "] not found"); } // If security is off - User doesn't exist - you have permission - if (userstr == null || User.parse(userstr) == null) { + if (userStr == null || User.parse(userStr) == null) { internalListener.onResponse(true); return; } ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap()); - String user = User.parse(userstr).getName(); + String user = User.parse(userStr).getName(); // If you're not the owner of this conversation, you do not have permission if (!user.equals(conversation.getUser())) { internalListener.onResponse(false); @@ -353,9 +354,9 @@ public void searchConversations(SearchRequest request, ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -432,12 +433,12 @@ public void getConversation(String conversationId, ActionListener listener) { * @param origin the origin of the response for this interaction * @param additionalInfo additional information used for constructing the LLM prompt * @param timestamp when this interaction happened - * @param parintid the parent interactionId of this interaction + * @param parentId the parent interactionId of this interaction * @param traceNumber the trace number for a parent interaction * @param listener gets the id of the newly created interaction record */ @@ -149,40 +150,40 @@ public void createInteraction( Map additionalInfo, Instant timestamp, ActionListener listener, - String parintid, + String parentId, Integer traceNumber ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { - IndexRequest request = Requests - .indexRequest(INTERACTIONS_INDEX_NAME) - .source( - ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, - origin, - ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, - conversationId, - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - input, - ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, - promptTemplate, - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - response, - ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, - additionalInfo, - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - timestamp, - ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, - parintid, - ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, - traceNumber - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp); + sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNumber); + + if (input != null && !input.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } + if (additionalInfo != null && !additionalInfo.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); + } + IndexRequest request = Requests.indexRequest(INTERACTIONS_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { @@ -272,11 +273,11 @@ public void getInteractions(String conversationId, int from, int maxResults, Act if (access) { innerGetInteractions(conversationId, from, maxResults, listener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to memory " + conversationId, RestStatus.UNAUTHORIZED @@ -361,13 +362,13 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList if (access) { innerGetTraces(interactionId, from, maxResults, listener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); + : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED @@ -482,8 +483,8 @@ public void deleteConversation(String conversationId, ActionListener li listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener> searchListener = ActionListener.wrap(interactions -> { @@ -549,11 +550,11 @@ public void searchInteractions(String conversationId, SearchRequest request, Act listener.onFailure(e); } } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to memory " + conversationId, RestStatus.UNAUTHORIZED @@ -628,13 +629,13 @@ public void updateInteraction(String interactionId, UpdateRequest updateRequest, if (access) { innerUpdateInteraction(updateRequest, internalListener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); + : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED @@ -671,11 +672,11 @@ private void checkInteractionPermission(String interactionId, Interaction intera internalListener.onResponse(interaction); log.info("Successfully get the message : {}", interactionId); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 0f2dd2b5ce..28b529d360 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -132,4 +132,5 @@ public void testRestRequest_WithAdditionalInfo() throws IOException { Assert.assertEquals("value1", request.getAdditionalInfos().get("key1")); Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } + } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index 8068a85dfb..5f274fec82 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import org.junit.Before; @@ -153,4 +154,31 @@ public void testFromRestRequest_Trace() throws IOException { assert (request.getParentIid().equals("parentId")); assert (request.getTraceNumber().equals(1)); } + + public void testFromRestRequest_WithAllFieldsEmpty_Fails() throws IOException { + Map params = new HashMap<>(); + + params.put(ActionConstants.INPUT_FIELD, ""); + params.put(ActionConstants.PROMPT_TEMPLATE_FIELD, null); + params.put(ActionConstants.AI_RESPONSE_FIELD, " "); + params.put(ActionConstants.RESPONSE_ORIGIN_FIELD, null); + params.put(ActionConstants.ADDITIONAL_INFO_FIELD, Collections.emptyMap()); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + IllegalArgumentException exception = assertThrows( + "Expected IllegalArgumentException due to all fields empty", + IllegalArgumentException.class, + () -> CreateInteractionRequest.fromRestRequest(rrequest) + ); + + assertEquals( + exception.getMessage(), + "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info" + ); + + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index 5baefa358d..8f6057e57e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -564,8 +564,8 @@ public void testCanGetAConversationById() { assert (cid2.result().equals(get2.result().getId())); assert (get1.result().getName().equals("convo1")); assert (get2.result().getName().equals("convo2")); - Assert.assertTrue(convo2.getAdditionalInfos().isEmpty()); - Assert.assertTrue(get1.result().getAdditionalInfos().isEmpty()); + Assert.assertTrue(convo2.getAdditionalInfos() == null); + Assert.assertTrue(get1.result().getAdditionalInfos() == null); cdl.countDown(); }, e -> { cdl.countDown(); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index a27653a4e2..ccb7fd112f 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -134,10 +134,10 @@ private void blanketGrantAccess() { } private void setupUser(String user) { - String userstr = user == null ? "" : user + "||"; + String userStr = user == null ? "" : user + "||"; doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 042a4a3a91..f18aec2e33 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -156,7 +156,7 @@ private void setupGrantAccess() { } private void setupDenyAccess(String user) { - String userstr = user == null ? "" : user + "||"; + String userStr = user == null ? "" : user + "||"; doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse(false); @@ -164,7 +164,7 @@ private void setupDenyAccess(String user) { }).when(conversationMetaIndex).checkAccess(anyString(), any()); doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index ec7a805c9e..c084b47eec 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -150,11 +150,11 @@ public void getFinalInteractions(String conversationId, int lastNInteraction, Ac if (access) { innerGetFinalInteractions(conversationId, lastNInteraction, actionListener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? "" : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? "" : User.parse(userStr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } }, e -> { actionListener.onFailure(e); }); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java index 68355d9a68..a3b7bd76fa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -249,7 +249,7 @@ public void testGetInteractions_SearchFails_ThenFail() { @Test public void testGetInteractions_NoAccessNoUser_ThenFail() { doReturn(true).when(metadata).hasIndex(anyString()); - String userstr = ""; + String userStr = ""; doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse(false); @@ -258,7 +258,7 @@ public void testGetInteractions_NoAccessNoUser_ThenFail() { doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); From 3f47a1f8caf4d2d3e2163380ca67f336a6b68952 Mon Sep 17 00:00:00 2001 From: tkykenmt Date: Wed, 25 Dec 2024 15:48:06 +0900 Subject: [PATCH 5/5] fixup! update error handling to throw exception when post processing function recieve empty result from a model. Signed-off-by: tkykenmt --- ...erank-m3-v2_model_deployed_on_Sagemaker.md | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md b/docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md index 278a96e80d..3a9a62c9d9 100644 --- a/docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md +++ b/docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md @@ -154,20 +154,25 @@ POST /_plugins/_ml/connectors/_create def text_docs = params.text_docs; def textDocsBuilder = new StringBuilder('['); for (int i=0; i 0) { + if (params.result == null || params.result.length == 0) { throw new IllegalArgumentException("Post process function input is empty."); } def outputs = params.result; @@ -178,8 +183,8 @@ POST /_plugins/_ml/connectors/_create } def resultBuilder = new StringBuilder('['); for (int i=0; i 0) { + if (params.result == null || params.result.length == 0) { throw new IllegalArgumentException("Post process function input is empty."); } def outputs = params.result; @@ -246,8 +256,8 @@ POST /_plugins/_ml/connectors/_create } def resultBuilder = new StringBuilder('['); for (int i=0; i