From c0ac08483cf95d99f6aa2831bfb0c344c259e6b8 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Mon, 16 Dec 2024 00:27:41 -0800 Subject: [PATCH] Enchance Message and Memory API Validation and storage Throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Skip saving empty fields in interactions and conversations to optimize storage usage. Modify GET requests for interactions and conversations to return only non-null fields. Throw an exception if all fields in a create interaction call are empty or null. Add unit tests to cover the above cases. Signed-off-by: rithin-pullela-aws --- .../ml/common/conversation/Interaction.java | 16 +++-- .../common/conversation/InteractionTests.java | 3 +- .../CreateConversationRequest.java | 25 ++++++-- .../CreateInteractionRequest.java | 13 +++- .../memory/index/ConversationMetaIndex.java | 33 +++++----- .../ml/memory/index/InteractionsIndex.java | 45 +++++++------- .../CreateConversationRequestTests.java | 21 +++++++ .../CreateInteractionRequestTests.java | 60 +++++++++++++++++++ 8 files changed, 166 insertions(+), 50 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 93afbb52a3..2bffc21b01 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -184,10 +184,18 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId); builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id); builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime); - builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); - builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); - builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); - builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + if (input != null && !input.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } if (additionalInfo != null) { builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); } diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java index 9ef58dd394..2e4dfd5259 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -124,13 +124,14 @@ public void test_ToXContent() throws IOException { .origin("amazon bedrock") .parentInteractionId("parant id") .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .response("sample response") .traceNum(1) .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); interaction.toXContent(builder, EMPTY_PARAMS); String interactionContent = TestHelper.xContentBuilderToString(builder); assertEquals( - "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", + "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"response\":\"sample response\",\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", interactionContent ); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index 991ddde2a7..c4cc2f7448 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -137,12 +137,27 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) } try (XContentParser parser = restRequest.contentParser()) { Map body = parser.map(); + String name = null; + String applicationType = null; + Map additionalInfo = null; + + for (String key : body.keySet()) { + switch (key) { + case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD: + name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD); + break; + case APPLICATION_TYPE_FIELD: + applicationType = (String) body.get(APPLICATION_TYPE_FIELD); + break; + case META_ADDITIONAL_INFO_FIELD: + additionalInfo = (Map) body.get(META_ADDITIONAL_INFO_FIELD); + break; + default: + throw new IllegalArgumentException("Invalid field [" + key + "] found in request body"); + } + } if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { - return new CreateConversationRequest( - (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), - body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD), - body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map) body.get(META_ADDITIONAL_INFO_FIELD) - ); + return new CreateConversationRequest(name, applicationType, additionalInfo); } else { return new CreateConversationRequest(); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index fe4a05bc0c..69c9aa9b44 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -166,11 +166,20 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro tracenum = parser.intValue(false); break; default: - parser.skipChildren(); - break; + throw new IllegalArgumentException("Invalid field [" + fieldName + "] found in request body"); } } + boolean allFieldsEmpty = (input == null || input.trim().isEmpty()) + && (prompt == null || prompt.trim().isEmpty()) + && (response == null || response.trim().isEmpty()) + && (origin == null || origin.trim().isEmpty()) + && (addinf == null || addinf.isEmpty()); + if (allFieldsEmpty) { + throw new IllegalArgumentException( + "At least one of the following parameters must be non-empty: " + "input, prompt_template, response, origin, additional_info" + ); + } return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 5e128e4d6f..bffbcbbef9 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -141,22 +142,22 @@ public void createConversation( if (indexExists) { String userstr = getUserStrFromThreadContext(); Instant now = Instant.now(); - IndexRequest request = Requests - .indexRequest(META_INDEX_NAME) - .source( - ConversationalIndexConstants.META_CREATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_UPDATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_NAME_FIELD, - name, - ConversationalIndexConstants.USER_FIELD, - userstr == null ? null : User.parse(userstr).getName(), - ConversationalIndexConstants.APPLICATION_TYPE_FIELD, - applicationType, - ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, - additionalInfos == null ? Map.of() : additionalInfos - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now); + sourceMap.put(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, now); + if (name != null && !name.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name); + } + if (userstr != null && !userstr.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userstr).getName()); + } + if (applicationType != null && !applicationType.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType); + } + if (additionalInfos != null && !additionalInfos.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, additionalInfos); + } + IndexRequest request = Requests.indexRequest(META_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index e8cf4bd7ae..b6fceb0cf1 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -161,28 +162,28 @@ public void createInteraction( if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { - IndexRequest request = Requests - .indexRequest(INTERACTIONS_INDEX_NAME) - .source( - ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, - origin, - ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, - conversationId, - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - input, - ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, - promptTemplate, - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - response, - ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, - additionalInfo, - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - timestamp, - ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, - parintid, - ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, - traceNumber - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp); + sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parintid); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNumber); + + if (input != null && !input.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } + if (additionalInfo != null && !additionalInfo.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); + } + IndexRequest request = Requests.indexRequest(INTERACTIONS_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 0f2dd2b5ce..1cade08b7a 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; +import org.opensearch.OpenSearchParseException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; @@ -132,4 +133,24 @@ public void testRestRequest_WithAdditionalInfo() throws IOException { Assert.assertEquals("value1", request.getAdditionalInfos().get("key1")); Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } + + public void testRestRequest_WithUnknownFields_Fails() throws IOException { + String name = "test-name"; + RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray(gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, "unknown_field", "some value"))), + MediaTypeRegistry.JSON + ) + .build(); + + try { + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + fail("Expected IllegalArgumentException due to unknown field"); + } catch (OpenSearchParseException e) { + assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body"); + } catch (Exception e) { + fail("Expected OpenSearchParseException due to unknown field, got " + e.getClass().getName()); + } + + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index 8068a85dfb..828d7edc53 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import org.junit.Before; @@ -153,4 +154,63 @@ public void testFromRestRequest_Trace() throws IOException { assert (request.getParentIid().equals("parentId")); assert (request.getTraceNumber().equals(1)); } + + public void testRestRequest_WithUnknownFields_Fails() throws IOException { + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "pt", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.ADDITIONAL_INFO_FIELD, + Collections.singletonMap("metadata", "some meta"), + "unknown_field", + "some value" + ); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + try { + CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + fail("Expected IllegalArgumentException due to unknown field"); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body"); + } catch (Exception e) { + fail("Expected IllegalArgumentException due to unknown field, got " + e.getClass().getName()); + } + } + + public void testFromRestRequest_WithAllFieldsEmpty_Fails() throws IOException { + Map params = new HashMap<>(); + + params.put(ActionConstants.INPUT_FIELD, ""); + params.put(ActionConstants.PROMPT_TEMPLATE_FIELD, null); + params.put(ActionConstants.AI_RESPONSE_FIELD, " "); + params.put(ActionConstants.RESPONSE_ORIGIN_FIELD, null); + params.put(ActionConstants.ADDITIONAL_INFO_FIELD, Collections.emptyMap()); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + try { + CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + fail("Expected IllegalArgumentException due to all fields empty"); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info" + ); + } catch (Exception e) { + fail("Expected IllegalArgumentException due to all fields empty, got " + e.getClass().getName()); + } + } }