From 2e40ed6c6ef990fae69ec4584b9f683f71ffd4ad Mon Sep 17 00:00:00 2001 From: Rithin Pullela Date: Tue, 24 Dec 2024 12:26:04 -0800 Subject: [PATCH] Enhance Message and Memory API Validation and storage (#3283) * Enchance Message and Memory API Validation and storage Throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Skip saving empty fields in interactions and conversations to optimize storage usage. Modify GET requests for interactions and conversations to return only non-null fields. Throw an exception if all fields in a create interaction call are empty or null. Add unit tests to cover the above cases. Signed-off-by: rithin-pullela-aws * Update unit test to check for null instead of empty map Signed-off-by: rithin-pullela-aws * Refactored userstr to Camel Case Signed-off-by: rithin-pullela-aws * Addressing comments Used assertThrows and added promptTemplate with empty string in test_ToXContent to ensure well rounded testing of expected functionality Signed-off-by: rithin-pullela-aws * Undo: throw an error when an unknown field is provided in CreateConversation or CreateInteraction. Signed-off-by: rithin-pullela-aws --------- Signed-off-by: rithin-pullela-aws --- .../ml/common/conversation/Interaction.java | 16 +++- .../common/conversation/InteractionTests.java | 4 +- .../CreateConversationRequest.java | 26 ++++-- .../CreateInteractionRequest.java | 10 +++ .../memory/index/ConversationMetaIndex.java | 67 +++++++-------- .../ml/memory/index/InteractionsIndex.java | 81 ++++++++++--------- .../CreateConversationRequestTests.java | 1 + .../CreateInteractionRequestTests.java | 28 +++++++ .../index/ConversationMetaIndexITTests.java | 4 +- .../index/ConversationMetaIndexTests.java | 4 +- .../memory/index/InteractionsIndexTests.java | 4 +- .../ml/engine/memory/MLMemoryManager.java | 4 +- .../engine/memory/MLMemoryManagerTests.java | 4 +- 13 files changed, 160 insertions(+), 93 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 93afbb52a3..2bffc21b01 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -184,10 +184,18 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId); builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id); builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime); - builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); - builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); - builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); - builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + if (input != null && !input.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } if (additionalInfo != null) { builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); } diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java index 9ef58dd394..480998a0e7 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -122,15 +122,17 @@ public void test_ToXContent() throws IOException { .builder() .conversationId("conversation id") .origin("amazon bedrock") + .promptTemplate(" ") .parentInteractionId("parant id") .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .response("sample response") .traceNum(1) .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); interaction.toXContent(builder, EMPTY_PARAMS); String interactionContent = TestHelper.xContentBuilderToString(builder); assertEquals( - "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", + "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"response\":\"sample response\",\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", interactionContent ); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index 991ddde2a7..e64658054c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -137,12 +137,28 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) } try (XContentParser parser = restRequest.contentParser()) { Map body = parser.map(); + String name = null; + String applicationType = null; + Map additionalInfo = null; + + for (String key : body.keySet()) { + switch (key) { + case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD: + name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD); + break; + case APPLICATION_TYPE_FIELD: + applicationType = (String) body.get(APPLICATION_TYPE_FIELD); + break; + case META_ADDITIONAL_INFO_FIELD: + additionalInfo = (Map) body.get(META_ADDITIONAL_INFO_FIELD); + break; + default: + parser.skipChildren(); + break; + } + } if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { - return new CreateConversationRequest( - (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), - body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD), - body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map) body.get(META_ADDITIONAL_INFO_FIELD) - ); + return new CreateConversationRequest(name, applicationType, additionalInfo); } else { return new CreateConversationRequest(); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index fe4a05bc0c..3927312d9c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -171,6 +171,16 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro } } + boolean allFieldsEmpty = (input == null || input.trim().isEmpty()) + && (prompt == null || prompt.trim().isEmpty()) + && (response == null || response.trim().isEmpty()) + && (origin == null || origin.trim().isEmpty()) + && (addinf == null || addinf.isEmpty()); + if (allFieldsEmpty) { + throw new IllegalArgumentException( + "At least one of the following parameters must be non-empty: " + "input, prompt_template, response, origin, additional_info" + ); + } return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 5e128e4d6f..8b18092a2a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -139,24 +140,24 @@ public void createConversation( ) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { - String userstr = getUserStrFromThreadContext(); + String userStr = getUserStrFromThreadContext(); Instant now = Instant.now(); - IndexRequest request = Requests - .indexRequest(META_INDEX_NAME) - .source( - ConversationalIndexConstants.META_CREATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_UPDATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_NAME_FIELD, - name, - ConversationalIndexConstants.USER_FIELD, - userstr == null ? null : User.parse(userstr).getName(), - ConversationalIndexConstants.APPLICATION_TYPE_FIELD, - applicationType, - ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, - additionalInfos == null ? Map.of() : additionalInfos - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now); + sourceMap.put(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, now); + if (name != null && !name.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name); + } + if (userStr != null && !userStr.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName()); + } + if (applicationType != null && !applicationType.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType); + } + if (additionalInfos != null && !additionalInfos.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, additionalInfos); + } + IndexRequest request = Requests.indexRequest(META_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { @@ -210,12 +211,12 @@ public void getConversations(int from, int maxResults, ActionListener li return; } DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); - String userstr = getUserStrFromThreadContext(); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String userStr = getUserStrFromThreadContext(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); this.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { @@ -308,7 +309,7 @@ public void checkAccess(String conversationId, ActionListener listener) listener.onResponse(true); return; } - String userstr = getUserStrFromThreadContext(); + String userStr = getUserStrFromThreadContext(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -318,12 +319,12 @@ public void checkAccess(String conversationId, ActionListener listener) throw new ResourceNotFoundException("Memory [" + conversationId + "] not found"); } // If security is off - User doesn't exist - you have permission - if (userstr == null || User.parse(userstr) == null) { + if (userStr == null || User.parse(userStr) == null) { internalListener.onResponse(true); return; } ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap()); - String user = User.parse(userstr).getName(); + String user = User.parse(userStr).getName(); // If you're not the owner of this conversation, you do not have permission if (!user.equals(conversation.getUser())) { internalListener.onResponse(false); @@ -353,9 +354,9 @@ public void searchConversations(SearchRequest request, ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -432,12 +433,12 @@ public void getConversation(String conversationId, ActionListener listener) { * @param origin the origin of the response for this interaction * @param additionalInfo additional information used for constructing the LLM prompt * @param timestamp when this interaction happened - * @param parintid the parent interactionId of this interaction + * @param parentId the parent interactionId of this interaction * @param traceNumber the trace number for a parent interaction * @param listener gets the id of the newly created interaction record */ @@ -149,40 +150,40 @@ public void createInteraction( Map additionalInfo, Instant timestamp, ActionListener listener, - String parintid, + String parentId, Integer traceNumber ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { - IndexRequest request = Requests - .indexRequest(INTERACTIONS_INDEX_NAME) - .source( - ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, - origin, - ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, - conversationId, - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - input, - ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, - promptTemplate, - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - response, - ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, - additionalInfo, - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - timestamp, - ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, - parintid, - ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, - traceNumber - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp); + sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNumber); + + if (input != null && !input.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } + if (additionalInfo != null && !additionalInfo.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); + } + IndexRequest request = Requests.indexRequest(INTERACTIONS_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { @@ -272,11 +273,11 @@ public void getInteractions(String conversationId, int from, int maxResults, Act if (access) { innerGetInteractions(conversationId, from, maxResults, listener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to memory " + conversationId, RestStatus.UNAUTHORIZED @@ -361,13 +362,13 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList if (access) { innerGetTraces(interactionId, from, maxResults, listener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); + : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED @@ -482,8 +483,8 @@ public void deleteConversation(String conversationId, ActionListener li listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener> searchListener = ActionListener.wrap(interactions -> { @@ -549,11 +550,11 @@ public void searchInteractions(String conversationId, SearchRequest request, Act listener.onFailure(e); } } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to memory " + conversationId, RestStatus.UNAUTHORIZED @@ -628,13 +629,13 @@ public void updateInteraction(String interactionId, UpdateRequest updateRequest, if (access) { innerUpdateInteraction(updateRequest, internalListener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); + : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED @@ -671,11 +672,11 @@ private void checkInteractionPermission(String interactionId, Interaction intera internalListener.onResponse(interaction); log.info("Successfully get the message : {}", interactionId); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 0f2dd2b5ce..28b529d360 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -132,4 +132,5 @@ public void testRestRequest_WithAdditionalInfo() throws IOException { Assert.assertEquals("value1", request.getAdditionalInfos().get("key1")); Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } + } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index 8068a85dfb..5f274fec82 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import org.junit.Before; @@ -153,4 +154,31 @@ public void testFromRestRequest_Trace() throws IOException { assert (request.getParentIid().equals("parentId")); assert (request.getTraceNumber().equals(1)); } + + public void testFromRestRequest_WithAllFieldsEmpty_Fails() throws IOException { + Map params = new HashMap<>(); + + params.put(ActionConstants.INPUT_FIELD, ""); + params.put(ActionConstants.PROMPT_TEMPLATE_FIELD, null); + params.put(ActionConstants.AI_RESPONSE_FIELD, " "); + params.put(ActionConstants.RESPONSE_ORIGIN_FIELD, null); + params.put(ActionConstants.ADDITIONAL_INFO_FIELD, Collections.emptyMap()); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + IllegalArgumentException exception = assertThrows( + "Expected IllegalArgumentException due to all fields empty", + IllegalArgumentException.class, + () -> CreateInteractionRequest.fromRestRequest(rrequest) + ); + + assertEquals( + exception.getMessage(), + "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info" + ); + + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index 5baefa358d..8f6057e57e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -564,8 +564,8 @@ public void testCanGetAConversationById() { assert (cid2.result().equals(get2.result().getId())); assert (get1.result().getName().equals("convo1")); assert (get2.result().getName().equals("convo2")); - Assert.assertTrue(convo2.getAdditionalInfos().isEmpty()); - Assert.assertTrue(get1.result().getAdditionalInfos().isEmpty()); + Assert.assertTrue(convo2.getAdditionalInfos() == null); + Assert.assertTrue(get1.result().getAdditionalInfos() == null); cdl.countDown(); }, e -> { cdl.countDown(); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index a27653a4e2..ccb7fd112f 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -134,10 +134,10 @@ private void blanketGrantAccess() { } private void setupUser(String user) { - String userstr = user == null ? "" : user + "||"; + String userStr = user == null ? "" : user + "||"; doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 042a4a3a91..f18aec2e33 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -156,7 +156,7 @@ private void setupGrantAccess() { } private void setupDenyAccess(String user) { - String userstr = user == null ? "" : user + "||"; + String userStr = user == null ? "" : user + "||"; doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse(false); @@ -164,7 +164,7 @@ private void setupDenyAccess(String user) { }).when(conversationMetaIndex).checkAccess(anyString(), any()); doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index ec7a805c9e..c084b47eec 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -150,11 +150,11 @@ public void getFinalInteractions(String conversationId, int lastNInteraction, Ac if (access) { innerGetFinalInteractions(conversationId, lastNInteraction, actionListener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? "" : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? "" : User.parse(userStr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } }, e -> { actionListener.onFailure(e); }); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java index 68355d9a68..a3b7bd76fa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -249,7 +249,7 @@ public void testGetInteractions_SearchFails_ThenFail() { @Test public void testGetInteractions_NoAccessNoUser_ThenFail() { doReturn(true).when(metadata).hasIndex(anyString()); - String userstr = ""; + String userStr = ""; doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse(false); @@ -258,7 +258,7 @@ public void testGetInteractions_NoAccessNoUser_ThenFail() { doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener);