diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 93afbb52a3..2bffc21b01 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -184,10 +184,18 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId); builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id); builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime); - builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); - builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); - builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); - builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + if (input != null && !input.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } if (additionalInfo != null) { builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); } diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java index 9ef58dd394..480998a0e7 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -122,15 +122,17 @@ public void test_ToXContent() throws IOException { .builder() .conversationId("conversation id") .origin("amazon bedrock") + .promptTemplate(" ") .parentInteractionId("parant id") .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .response("sample response") .traceNum(1) .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); interaction.toXContent(builder, EMPTY_PARAMS); String interactionContent = TestHelper.xContentBuilderToString(builder); assertEquals( - "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", + "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"response\":\"sample response\",\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", interactionContent ); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index 991ddde2a7..e64658054c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -137,12 +137,28 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest) } try (XContentParser parser = restRequest.contentParser()) { Map body = parser.map(); + String name = null; + String applicationType = null; + Map additionalInfo = null; + + for (String key : body.keySet()) { + switch (key) { + case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD: + name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD); + break; + case APPLICATION_TYPE_FIELD: + applicationType = (String) body.get(APPLICATION_TYPE_FIELD); + break; + case META_ADDITIONAL_INFO_FIELD: + additionalInfo = (Map) body.get(META_ADDITIONAL_INFO_FIELD); + break; + default: + parser.skipChildren(); + break; + } + } if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) { - return new CreateConversationRequest( - (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD), - body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD), - body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map) body.get(META_ADDITIONAL_INFO_FIELD) - ); + return new CreateConversationRequest(name, applicationType, additionalInfo); } else { return new CreateConversationRequest(); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index fe4a05bc0c..3927312d9c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -171,6 +171,16 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro } } + boolean allFieldsEmpty = (input == null || input.trim().isEmpty()) + && (prompt == null || prompt.trim().isEmpty()) + && (response == null || response.trim().isEmpty()) + && (origin == null || origin.trim().isEmpty()) + && (addinf == null || addinf.isEmpty()); + if (allFieldsEmpty) { + throw new IllegalArgumentException( + "At least one of the following parameters must be non-empty: " + "input, prompt_template, response, origin, additional_info" + ); + } return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 5e128e4d6f..8b18092a2a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -139,24 +140,24 @@ public void createConversation( ) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { - String userstr = getUserStrFromThreadContext(); + String userStr = getUserStrFromThreadContext(); Instant now = Instant.now(); - IndexRequest request = Requests - .indexRequest(META_INDEX_NAME) - .source( - ConversationalIndexConstants.META_CREATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_UPDATED_TIME_FIELD, - now, - ConversationalIndexConstants.META_NAME_FIELD, - name, - ConversationalIndexConstants.USER_FIELD, - userstr == null ? null : User.parse(userstr).getName(), - ConversationalIndexConstants.APPLICATION_TYPE_FIELD, - applicationType, - ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, - additionalInfos == null ? Map.of() : additionalInfos - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now); + sourceMap.put(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, now); + if (name != null && !name.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name); + } + if (userStr != null && !userStr.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName()); + } + if (applicationType != null && !applicationType.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType); + } + if (additionalInfos != null && !additionalInfos.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, additionalInfos); + } + IndexRequest request = Requests.indexRequest(META_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { @@ -210,12 +211,12 @@ public void getConversations(int from, int maxResults, ActionListener li return; } DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); - String userstr = getUserStrFromThreadContext(); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String userStr = getUserStrFromThreadContext(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); this.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { @@ -308,7 +309,7 @@ public void checkAccess(String conversationId, ActionListener listener) listener.onResponse(true); return; } - String userstr = getUserStrFromThreadContext(); + String userStr = getUserStrFromThreadContext(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -318,12 +319,12 @@ public void checkAccess(String conversationId, ActionListener listener) throw new ResourceNotFoundException("Memory [" + conversationId + "] not found"); } // If security is off - User doesn't exist - you have permission - if (userstr == null || User.parse(userstr) == null) { + if (userStr == null || User.parse(userStr) == null) { internalListener.onResponse(true); return; } ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap()); - String user = User.parse(userstr).getName(); + String user = User.parse(userStr).getName(); // If you're not the owner of this conversation, you do not have permission if (!user.equals(conversation.getUser())) { internalListener.onResponse(false); @@ -353,9 +354,9 @@ public void searchConversations(SearchRequest request, ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId); @@ -432,12 +433,12 @@ public void getConversation(String conversationId, ActionListener listener) { * @param origin the origin of the response for this interaction * @param additionalInfo additional information used for constructing the LLM prompt * @param timestamp when this interaction happened - * @param parintid the parent interactionId of this interaction + * @param parentId the parent interactionId of this interaction * @param traceNumber the trace number for a parent interaction * @param listener gets the id of the newly created interaction record */ @@ -149,40 +150,40 @@ public void createInteraction( Map additionalInfo, Instant timestamp, ActionListener listener, - String parintid, + String parentId, Integer traceNumber ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { - IndexRequest request = Requests - .indexRequest(INTERACTIONS_INDEX_NAME) - .source( - ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, - origin, - ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, - conversationId, - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - input, - ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, - promptTemplate, - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - response, - ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, - additionalInfo, - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - timestamp, - ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, - parintid, - ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, - traceNumber - ); + Map sourceMap = new HashMap<>(); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp); + sourceMap.put(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentId); + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNumber); + + if (input != null && !input.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + } + if (promptTemplate != null && !promptTemplate.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + } + if (response != null && !response.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + } + if (origin != null && !origin.trim().isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + } + if (additionalInfo != null && !additionalInfo.isEmpty()) { + sourceMap.put(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); + } + IndexRequest request = Requests.indexRequest(INTERACTIONS_INDEX_NAME).source(sourceMap); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { @@ -272,11 +273,11 @@ public void getInteractions(String conversationId, int from, int maxResults, Act if (access) { innerGetInteractions(conversationId, from, maxResults, listener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to memory " + conversationId, RestStatus.UNAUTHORIZED @@ -361,13 +362,13 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList if (access) { innerGetTraces(interactionId, from, maxResults, listener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); + : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED @@ -482,8 +483,8 @@ public void deleteConversation(String conversationId, ActionListener li listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener> searchListener = ActionListener.wrap(interactions -> { @@ -549,11 +550,11 @@ public void searchInteractions(String conversationId, SearchRequest request, Act listener.onFailure(e); } } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to memory " + conversationId, RestStatus.UNAUTHORIZED @@ -628,13 +629,13 @@ public void updateInteraction(String interactionId, UpdateRequest updateRequest, if (access) { innerUpdateInteraction(updateRequest, internalListener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); + : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED @@ -671,11 +672,11 @@ private void checkInteractionPermission(String interactionId, Interaction intera internalListener.onResponse(interaction); log.info("Successfully get the message : {}", interactionId); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName(); throw new OpenSearchStatusException( "User [" + user + "] does not have access to message " + interactionId, RestStatus.UNAUTHORIZED diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java index 0f2dd2b5ce..28b529d360 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -132,4 +132,5 @@ public void testRestRequest_WithAdditionalInfo() throws IOException { Assert.assertEquals("value1", request.getAdditionalInfos().get("key1")); Assert.assertEquals(123, request.getAdditionalInfos().get("key2")); } + } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index 8068a85dfb..5f274fec82 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import org.junit.Before; @@ -153,4 +154,31 @@ public void testFromRestRequest_Trace() throws IOException { assert (request.getParentIid().equals("parentId")); assert (request.getTraceNumber().equals(1)); } + + public void testFromRestRequest_WithAllFieldsEmpty_Fails() throws IOException { + Map params = new HashMap<>(); + + params.put(ActionConstants.INPUT_FIELD, ""); + params.put(ActionConstants.PROMPT_TEMPLATE_FIELD, null); + params.put(ActionConstants.AI_RESPONSE_FIELD, " "); + params.put(ActionConstants.RESPONSE_ORIGIN_FIELD, null); + params.put(ActionConstants.ADDITIONAL_INFO_FIELD, Collections.emptyMap()); + + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.MEMORY_ID, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + IllegalArgumentException exception = assertThrows( + "Expected IllegalArgumentException due to all fields empty", + IllegalArgumentException.class, + () -> CreateInteractionRequest.fromRestRequest(rrequest) + ); + + assertEquals( + exception.getMessage(), + "At least one of the following parameters must be non-empty: input, prompt_template, response, origin, additional_info" + ); + + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index 5baefa358d..8f6057e57e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -564,8 +564,8 @@ public void testCanGetAConversationById() { assert (cid2.result().equals(get2.result().getId())); assert (get1.result().getName().equals("convo1")); assert (get2.result().getName().equals("convo2")); - Assert.assertTrue(convo2.getAdditionalInfos().isEmpty()); - Assert.assertTrue(get1.result().getAdditionalInfos().isEmpty()); + Assert.assertTrue(convo2.getAdditionalInfos() == null); + Assert.assertTrue(get1.result().getAdditionalInfos() == null); cdl.countDown(); }, e -> { cdl.countDown(); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index a27653a4e2..ccb7fd112f 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -134,10 +134,10 @@ private void blanketGrantAccess() { } private void setupUser(String user) { - String userstr = user == null ? "" : user + "||"; + String userStr = user == null ? "" : user + "||"; doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 042a4a3a91..f18aec2e33 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -156,7 +156,7 @@ private void setupGrantAccess() { } private void setupDenyAccess(String user) { - String userstr = user == null ? "" : user + "||"; + String userStr = user == null ? "" : user + "||"; doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse(false); @@ -164,7 +164,7 @@ private void setupDenyAccess(String user) { }).when(conversationMetaIndex).checkAccess(anyString(), any()); doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java index ec7a805c9e..c084b47eec 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/MLMemoryManager.java @@ -150,11 +150,11 @@ public void getFinalInteractions(String conversationId, int lastNInteraction, Ac if (access) { innerGetFinalInteractions(conversationId, lastNInteraction, actionListener); } else { - String userstr = client + String userStr = client .threadPool() .getThreadContext() .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? "" : User.parse(userstr).getName(); + String user = User.parse(userStr) == null ? "" : User.parse(userStr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } }, e -> { actionListener.onFailure(e); }); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java index 68355d9a68..a3b7bd76fa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/MLMemoryManagerTests.java @@ -249,7 +249,7 @@ public void testGetInteractions_SearchFails_ThenFail() { @Test public void testGetInteractions_NoAccessNoUser_ThenFail() { doReturn(true).when(metadata).hasIndex(anyString()); - String userstr = ""; + String userStr = ""; doAnswer(invocation -> { ActionListener al = invocation.getArgument(1); al.onResponse(false); @@ -258,7 +258,7 @@ public void testGetInteractions_NoAccessNoUser_ThenFail() { doAnswer(invocation -> { ThreadContext tc = new ThreadContext(Settings.EMPTY); - tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userStr); return tc; }).when(threadPool).getThreadContext(); mlMemoryManager.getFinalInteractions("cid", 10, interactionListActionListener);