From c0ac08483cf95d99f6aa2831bfb0c344c259e6b8 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Mon, 16 Dec 2024 00:27:41 -0800 Subject: [PATCH 1/5] Enchance Message and Memory API Validation and storage Throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Skip saving empty fields in interactions and conversations to optimize storage usage. Modify GET requests for interactions and conversations to return only non-null fields. Throw an exception if all fields in a create interaction call are empty or null. Add unit tests to cover the above cases. Signed-off-by: rithin-pullela-aws --- .../ml/common/conversation/Interaction.java | 16 +++-- .../common/conversation/InteractionTests.java | 3 +- .../CreateConversationRequest.java | 25 ++++++-- .../CreateInteractionRequest.java | 13 +++- .../memory/index/ConversationMetaIndex.java | 33 +++++----- .../ml/memory/index/InteractionsIndex.java | 45 +++++++------- .../CreateConversationRequestTests.java | 21 +++++++ .../CreateInteractionRequestTests.java | 60 +++++++++++++++++++ 8 files changed, 166 insertions(+), 50 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 93afbb52a3..2bffc21b01 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -184,10 +184,18 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId); builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id); builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime); - builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); - builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); - builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); - builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + if (input != null && !input.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } if (additionalInfo != null) { builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); } diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java index 9ef58dd394..2e4dfd5259 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -124,13 +124,14 @@ public void test_ToXContent() throws IOException { .origin("amazon bedrock") .parentInteractionId("parant id") .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .response("sample response") .traceNum(1) .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); interaction.toXContent(builder, EMPTY_PARAMS); String interactionContent = TestHelper.xContentBuilderToString(builder); assertEquals( - "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", + "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"response\":\"sample response\",\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", interactionContent ); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index 991ddde2a7..c4cc2f7448 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -137,12 +137,27 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) } try (XContentParser parser = restRequest.contentParser()) { Map body = parser.map(); + String name = null; + String applicationType = null; + Map additionalInfo = null; + + for (String key : body.keySet()) { + switch (key) { + case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD: + name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD); + break; + case APPLICATION_TYPE_FIELD: + applicationType = (String) body.get(APPLICATION_TYPE_FIELD); + break; + case META_ADDITIONAL_INFO_FIELD: + additionalInfo = (Map) body.get(META_ADDITIONAL_INFO_FIELD); + break; + default: + throw new IllegalArgumentException("Invalid field [" + key + "] found in request body"); + } + } if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { - return new CreateConversationRequest( - (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), - body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD), - body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map) body.get(META_ADDITIONAL_INFO_FIELD) - ); + return new CreateConversationRequest(name, applicationType, additionalInfo); } else { return new CreateConversationRequest(); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index fe4a05bc0c..69c9aa9b44 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -166,11 +166,20 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro tracenum = parser.intValue(false); break; default: - parser.skipChildren(); - break; + throw new IllegalArgumentException("Invalid field [" + fieldName + "] found in request body"); } } + boolean allFieldsEmpty = (input == null || input.trim().isEmpty()) + && (prompt == null || prompt.trim().isEmpty()) + && (response == null || response.trim().isEmpty()) + && (origin == null || origin.trim().isEmpty()) + && (addinf == null || addinf.isEmpty()); + if (allFieldsEmpty) { + throw new IllegalArgumentException( + "At least one of the following parameters must be non-empty: " + "input, prompt_template, response, origin, additional_info" + ); + } return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 5e128e4d6f..bffbcbbef9 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -141,22 +142,22 @@ public void createConversation( if (indexExists) { String userstr = getUserStrFromThreadContext(); Instant now = Instant.now(); - IndexRequest request = Requests - .indexRequest(META_INDEX_NAME) - .source( - ConversationalIndexConstants.META_CREATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_UPDATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_NAME_FIELD, - name, - ConversationalIndexConstants.USER_FIELD, - userstr == null ? null : User.parse(userstr).getName(), - ConversationalIndexConstants.APPLICATION_TYPE_FIELD, - applicationType, - ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, - additionalInfos == null ? Map.of() : additionalInfos - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now); + sourceMap.put(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, now); + if (name != null && !name.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name); + } + if (userstr != null && !userstr.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userstr).getName()); + } + if (applicationType != null && !applicationType.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType); + } + if (additionalInfos != null && !additionalInfos.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, additionalInfos); + } + IndexRequest request = Requests.indexRequest(META_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index e8cf4bd7ae..b6fceb0cf1 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -161,28 +162,28 @@ public void createInteraction( if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { - IndexRequest request = Requests - .indexRequest(INTERACTIONS_INDEX_NAME) - .source( - ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, - origin, - ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, - conversationId, - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - input, - ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, - promptTemplate, - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - response, - ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, - additionalInfo, - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - timestamp, - ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, - parintid, - ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, - traceNumber - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp); + sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parintid); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNumber); + + if (input != null && !input.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } + if (additionalInfo != null && !additionalInfo.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); + } + IndexRequest request = Requests.indexRequest(INTERACTIONS_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 0f2dd2b5ce..1cade08b7a 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; +import org.opensearch.OpenSearchParseException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; @@ -132,4 +133,24 @@ public void testRestRequest_WithAdditionalInfo() throws IOException { Assert.assertEquals("value1", request.getAdditionalInfos().get("key1")); Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } + + public void testRestRequest_WithUnknownFields_Fails() throws IOException { + String name = "test-name"; + RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray(gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, "unknown_field", "some value"))), + MediaTypeRegistry.JSON + ) + .build(); + + try { + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + fail("Expected IllegalArgumentException due to unknown field"); + } catch (OpenSearchParseException e) { + assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body"); + } catch (Exception e) { + fail("Expected OpenSearchParseException due to unknown field, got " + e.getClass().getName()); + } + + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index 8068a85dfb..828d7edc53 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import org.junit.Before; @@ -153,4 +154,63 @@ public void testFromRestRequest_Trace() throws IOException { assert (request.getParentIid().equals("parentId")); assert (request.getTraceNumber().equals(1)); } + + public void testRestRequest_WithUnknownFields_Fails() throws IOException { + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "pt", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.ADDITIONAL_INFO_FIELD, + Collections.singletonMap("metadata", "some meta"), + "unknown_field", + "some value" + ); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + try { + CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + fail("Expected IllegalArgumentException due to unknown field"); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body"); + } catch (Exception e) { + fail("Expected IllegalArgumentException due to unknown field, got " + e.getClass().getName()); + } + } + + public void testFromRestRequest_WithAllFieldsEmpty_Fails() throws IOException { + Map params = new HashMap<>(); + + params.put(ActionConstants.INPUT_FIELD, ""); + params.put(ActionConstants.PROMPT_TEMPLATE_FIELD, null); + params.put(ActionConstants.AI_RESPONSE_FIELD, " "); + params.put(ActionConstants.RESPONSE_ORIGIN_FIELD, null); + params.put(ActionConstants.ADDITIONAL_INFO_FIELD, Collections.emptyMap()); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + try { + CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + fail("Expected IllegalArgumentException due to all fields empty"); + } catch (IllegalArgumentException e) { + assertEquals( + e.getMessage(), + "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info" + ); + } catch (Exception e) { + fail("Expected IllegalArgumentException due to all fields empty, got " + e.getClass().getName()); + } + } } From c8b314ce1a5f261a8108ae811c759c51677518c6 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Mon, 16 Dec 2024 15:57:45 -0800 Subject: [PATCH 2/5] Update unit test to check for null instead of empty map Signed-off-by: rithin-pullela-aws --- .../ml/memory/index/ConversationMetaIndexITTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index 5baefa358d..8f6057e57e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -564,8 +564,8 @@ public void testCanGetAConversationById() { assert (cid2.result().equals(get2.result().getId())); assert (get1.result().getName().equals("convo1")); assert (get2.result().getName().equals("convo2")); - Assert.assertTrue(convo2.getAdditionalInfos().isEmpty()); - Assert.assertTrue(get1.result().getAdditionalInfos().isEmpty()); + Assert.assertTrue(convo2.getAdditionalInfos() == null); + Assert.assertTrue(get1.result().getAdditionalInfos() == null); cdl.countDown(); }, e -> { cdl.countDown(); From 4e91ea662f0497eb4f3cd67815cf76c013e25bb2 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Tue, 17 Dec 2024 16:04:23 -0800 Subject: [PATCH 3/5] Refactored userstr to Camel Case Signed-off-by: rithin-pullela-aws --- .../memory/index/ConversationMetaIndex.java | 38 +++++++++---------- .../ml/memory/index/InteractionsIndex.java | 32 ++++++++-------- .../index/ConversationMetaIndexTests.java | 4 +- .../memory/index/InteractionsIndexTests.java | 4 +- .../ml/engine/memory/MLMemoryManager.java | 4 +- .../engine/memory/MLMemoryManagerTests.java | 4 +- 6 files changed, 43 insertions(+), 43 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index bffbcbbef9..8b18092a2a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -140,7 +140,7 @@ public void createConversation( ) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { - String userstr = getUserStrFromThreadContext(); + String userStr = getUserStrFromThreadContext(); Instant now = Instant.now(); Map sourceMap = new HashMap<>(); sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now); @@ -148,8 +148,8 @@ public void createConversation( if (name != null && !name.trim().isEmpty()) { sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name); } - if (userstr != null && !userstr.trim().isEmpty()) { - sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userstr).getName()); + if (userStr != null && !userStr.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName()); } if (applicationType != null && !applicationType.trim().isEmpty()) { sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType); @@ -211,12 +211,12 @@ public void getConversations(int from, int maxResults, ActionListener li return; } DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); - String userstr = getUserStrFromThreadContext(); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String userStr = getUserStrFromThreadContext(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); this.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { @@ -309,7 +309,7 @@ public void checkAccess(String conversationId, ActionListener listener) listener.onResponse(true); return; } - String userstr = getUserStrFromThreadContext(); + String userStr = getUserStrFromThreadContext(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -319,12 +319,12 @@ public void checkAccess(String conversationId, ActionListener listener) throw new ResourceNotFoundException("Memory [" + conversationId + "] not found"); } // If security is off - User doesn't exist - you have permission - if (userstr == null || User.parse(userstr) == null) { + if (userStr == null || User.parse(userStr) == null) { internalListener.onResponse(true); return; } ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap()); - String user = User.parse(userstr).getName(); + String user = User.parse(userStr).getName(); // If you're not the owner of this conversation, you do not have permission if (!user.equals(conversation.getUser())) { internalListener.onResponse(false); @@ -354,9 +354,9 @@ public void searchConversations(SearchRequest request, ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -433,12 +433,12 @@ public void getConversation(String conversationId, ActionListener { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { @@ -273,11 +273,11 @@ public void getInteractions(String conversationId, int from, int maxResults, Act if (access) { innerGetInteractions(conversationId, from, maxResults, listener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to memory " + conversationId, RestStatus.UNAUTHORIZED @@ -362,13 +362,13 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList if (access) { innerGetTraces(interactionId, from, maxResults, listener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); + : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED @@ -483,8 +483,8 @@ public void deleteConversation(String conversationId, ActionListener li listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener> searchListener = ActionListener.wrap(interactions -> { @@ -550,11 +550,11 @@ public void searchInteractions(String conversationId, SearchRequest request, Act listener.onFailure(e); } } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to memory " + conversationId, RestStatus.UNAUTHORIZED @@ -629,13 +629,13 @@ public void updateInteraction(String interactionId, UpdateRequest updateRequest, if (access) { innerUpdateInteraction(updateRequest, internalListener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); + : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED @@ -672,11 +672,11 @@ private void checkInteractionPermission(String interactionId, Interaction intera internalListener.onResponse(interaction); log.info("Successfully get the message : {}", interactionId); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index a27653a4e2..ccb7fd112f 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -134,10 +134,10 @@ private void blanketGrantAccess() { } private void setupUser(String user) { - String userstr = user == null ? "" : user + "||"; + String userStr = user == null ? "" : user + "||"; doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 042a4a3a91..f18aec2e33 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -156,7 +156,7 @@ private void setupGrantAccess() { } private void setupDenyAccess(String user) { - String userstr = user == null ? "" : user + "||"; + String userStr = user == null ? "" : user + "||"; doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse(false); @@ -164,7 +164,7 @@ private void setupDenyAccess(String user) { }).when(conversationMetaIndex).checkAccess(anyString(), any()); doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index ec7a805c9e..c084b47eec 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -150,11 +150,11 @@ public void getFinalInteractions(String conversationId, int lastNInteraction, Ac if (access) { innerGetFinalInteractions(conversationId, lastNInteraction, actionListener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? "" : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? "" : User.parse(userStr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } }, e -> { actionListener.onFailure(e); }); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java index 68355d9a68..a3b7bd76fa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -249,7 +249,7 @@ public void testGetInteractions_SearchFails_ThenFail() { @Test public void testGetInteractions_NoAccessNoUser_ThenFail() { doReturn(true).when(metadata).hasIndex(anyString()); - String userstr = ""; + String userStr = ""; doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse(false); @@ -258,7 +258,7 @@ public void testGetInteractions_NoAccessNoUser_ThenFail() { doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener); From 41325c4499887a191a1a7dc6c733a5cc20bc766f Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Tue, 17 Dec 2024 18:48:33 -0800 Subject: [PATCH 4/5] Addressing comments Used assertThrows and added promptTemplate with empty string in test_ToXContent to ensure well rounded testing of expected functionality Signed-off-by: rithin-pullela-aws --- .../common/conversation/InteractionTests.java | 1 + .../CreateConversationRequestTests.java | 14 +++---- .../CreateInteractionRequestTests.java | 37 +++++++++---------- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java index 2e4dfd5259..480998a0e7 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -122,6 +122,7 @@ public void test_ToXContent() throws IOException { .builder() .conversationId("conversation id") .origin("amazon bedrock") + .promptTemplate(" ") .parentInteractionId("parant id") .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) .response("sample response") diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 1cade08b7a..08ad91d277 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -143,14 +143,12 @@ public void testRestRequest_WithUnknownFields_Fails() throws IOException { ) .build(); - try { - CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); - fail("Expected IllegalArgumentException due to unknown field"); - } catch (OpenSearchParseException e) { - assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body"); - } catch (Exception e) { - fail("Expected OpenSearchParseException due to unknown field, got " + e.getClass().getName()); - } + OpenSearchParseException exception = assertThrows( + "Expected OpenSearchParseException due to unknown field", + OpenSearchParseException.class, + () -> CreateConversationRequest.fromRestRequest(req) + ); + assertEquals(exception.getMessage(), "Invalid field [unknown_field] found in request body"); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index 828d7edc53..3acb6b6182 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -177,14 +177,13 @@ public void testRestRequest_WithUnknownFields_Fails() throws IOException { .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) .build(); - try { - CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); - fail("Expected IllegalArgumentException due to unknown field"); - } catch (IllegalArgumentException e) { - assertEquals(e.getMessage(), "Invalid field [unknown_field] found in request body"); - } catch (Exception e) { - fail("Expected IllegalArgumentException due to unknown field, got " + e.getClass().getName()); - } + IllegalArgumentException exception = assertThrows( + "Expected IllegalArgumentException due to unknown field", + IllegalArgumentException.class, + () -> CreateInteractionRequest.fromRestRequest(rrequest) + ); + + assertEquals(exception.getMessage(), "Invalid field [unknown_field] found in request body"); } public void testFromRestRequest_WithAllFieldsEmpty_Fails() throws IOException { @@ -201,16 +200,16 @@ public void testFromRestRequest_WithAllFieldsEmpty_Fails() throws IOException { .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) .build(); - try { - CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); - fail("Expected IllegalArgumentException due to all fields empty"); - } catch (IllegalArgumentException e) { - assertEquals( - e.getMessage(), - "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info" - ); - } catch (Exception e) { - fail("Expected IllegalArgumentException due to all fields empty, got " + e.getClass().getName()); - } + IllegalArgumentException exception = assertThrows( + "Expected IllegalArgumentException due to all fields empty", + IllegalArgumentException.class, + () -> CreateInteractionRequest.fromRestRequest(rrequest) + ); + + assertEquals( + exception.getMessage(), + "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info" + ); + } } From e577f6140030a2f891c56d086744c7b9a10a0284 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Thu, 19 Dec 2024 14:06:49 -0800 Subject: [PATCH 5/5] Undo: throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Signed-off-by: rithin-pullela-aws --- .../CreateConversationRequest.java | 3 +- .../CreateInteractionRequest.java | 3 +- .../ml/memory/index/InteractionsIndex.java | 6 ++-- .../CreateConversationRequestTests.java | 18 ----------- .../CreateInteractionRequestTests.java | 31 ------------------- 5 files changed, 7 insertions(+), 54 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index c4cc2f7448..e64658054c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -153,7 +153,8 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) additionalInfo = (Map) body.get(META_ADDITIONAL_INFO_FIELD); break; default: - throw new IllegalArgumentException("Invalid field [" + key + "] found in request body"); + parser.skipChildren(); + break; } } if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index 69c9aa9b44..3927312d9c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -166,7 +166,8 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro tracenum = parser.intValue(false); break; default: - throw new IllegalArgumentException("Invalid field [" + fieldName + "] found in request body"); + parser.skipChildren(); + break; } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index e4310e820d..2c958f152f 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -137,7 +137,7 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { * @param origin the origin of the response for this interaction * @param additionalInfo additional information used for constructing the LLM prompt * @param timestamp when this interaction happened - * @param parintid the parent interactionId of this interaction + * @param parentId the parent interactionId of this interaction * @param traceNumber the trace number for a parent interaction * @param listener gets the id of the newly created interaction record */ @@ -150,7 +150,7 @@ public void createInteraction( Map additionalInfo, Instant timestamp, ActionListener listener, - String parintid, + String parentId, Integer traceNumber ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { @@ -165,7 +165,7 @@ public void createInteraction( Map sourceMap = new HashMap<>(); sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp); - sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parintid); + sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentId); sourceMap.put(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNumber); if (input != null && !input.trim().isEmpty()) { diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 08ad91d277..28b529d360 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -27,7 +27,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; -import org.opensearch.OpenSearchParseException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; @@ -134,21 +133,4 @@ public void testRestRequest_WithAdditionalInfo() throws IOException { Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } - public void testRestRequest_WithUnknownFields_Fails() throws IOException { - String name = "test-name"; - RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withContent( - new BytesArray(gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name, "unknown_field", "some value"))), - MediaTypeRegistry.JSON - ) - .build(); - - OpenSearchParseException exception = assertThrows( - "Expected OpenSearchParseException due to unknown field", - OpenSearchParseException.class, - () -> CreateConversationRequest.fromRestRequest(req) - ); - - assertEquals(exception.getMessage(), "Invalid field [unknown_field] found in request body"); - } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index 3acb6b6182..5f274fec82 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -155,37 +155,6 @@ public void testFromRestRequest_Trace() throws IOException { assert (request.getTraceNumber().equals(1)); } - public void testRestRequest_WithUnknownFields_Fails() throws IOException { - Map params = Map - .of( - ActionConstants.INPUT_FIELD, - "input", - ActionConstants.PROMPT_TEMPLATE_FIELD, - "pt", - ActionConstants.AI_RESPONSE_FIELD, - "response", - ActionConstants.RESPONSE_ORIGIN_FIELD, - "origin", - ActionConstants.ADDITIONAL_INFO_FIELD, - Collections.singletonMap("metadata", "some meta"), - "unknown_field", - "some value" - ); - - RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) - .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) - .build(); - - IllegalArgumentException exception = assertThrows( - "Expected IllegalArgumentException due to unknown field", - IllegalArgumentException.class, - () -> CreateInteractionRequest.fromRestRequest(rrequest) - ); - - assertEquals(exception.getMessage(), "Invalid field [unknown_field] found in request body"); - } - public void testFromRestRequest_WithAllFieldsEmpty_Fails() throws IOException { Map params = new HashMap<>();