From c4729fd427448604053e84faaa2a3009ba954cc4 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Fri, 10 Nov 2023 22:04:55 -0800 Subject: [PATCH] add new fields in the memory and refactor transport actions Signed-off-by: Xun Zhang --- .../common/conversation/ActionConstants.java | 4 + .../ConversationalIndexConstants.java | 17 +- .../ml/common/conversation/Interaction.java | 17 +- .../memory/ConversationalMemoryHandler.java | 37 ++- .../CreateConversationRequest.java | 14 + .../CreateConversationTransportAction.java | 3 +- .../CreateInteractionRequest.java | 93 ++++++- .../CreateInteractionTransportAction.java | 12 +- .../memory/index/ConversationMetaIndex.java | 18 +- .../ml/memory/index/InteractionsIndex.java | 43 +++- ...OpenSearchConversationalMemoryHandler.java | 54 +++- .../ConversationalMemoryHandlerITTests.java | 242 +++++++++++++----- .../CreateInteractionRequestTests.java | 23 +- ...CreateInteractionTransportActionTests.java | 12 +- .../GetInteractionsResponseTests.java | 34 ++- .../GetInteractionsTransportActionTests.java | 5 +- .../index/InteractionsIndexITTests.java | 119 ++++++--- .../memory/index/InteractionsIndexTests.java | 30 ++- ...earchConversationalMemoryHandlerTests.java | 20 +- .../opensearch/ml/memory/MLMemoryManager.java | 13 + .../GenerativeQAResponseProcessor.java | 2 +- .../client/ConversationalMemoryClient.java | 3 +- .../ConversationalMemoryClientTests.java | 4 +- 23 files changed, 654 insertions(+), 165 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/memory/MLMemoryManager.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 5bb8334bc1..7484165fe8 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -48,6 +48,10 @@ public class ActionConstants { public final static String PROMPT_TEMPLATE_FIELD = "prompt_template"; /** name of metadata field in all requests */ public final static String ADDITIONAL_INFO_FIELD = "additional_info"; + /** name of metadata field in all requests */ + public final static String PARENT_INTERACTION_ID_FIELD = "parent_interaction_id"; + /** name of metadata field in all requests */ + public final static String TRACE_NUMBER_FIELD = "trace_number"; /** name of success field in all requests */ public final static String SUCCESS_FIELD = "success"; diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java index c8e652265b..3650bae8c7 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -33,6 +33,8 @@ public class ConversationalIndexConstants { public final static String META_NAME_FIELD = "name"; /** Name of the owning user field in all indices */ public final static String USER_FIELD = "user"; + /** Name of the application that created this conversation */ + public final static String APPLICATION_TYPE_FIELD = "application_type"; /** Mappings for the conversational metadata index */ public final static String META_MAPPING = "{\n" + " \"_meta\": {\n" @@ -47,6 +49,9 @@ public class ConversationalIndexConstants { + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + " \"" + USER_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + APPLICATION_TYPE_FIELD + "\": {\"type\": \"keyword\"}\n" + " }\n" + "}"; @@ -69,6 +74,10 @@ public class ConversationalIndexConstants { public final static String INTERACTIONS_ADDITIONAL_INFO_FIELD = "additional_info"; /** Name of the interaction field for the timestamp */ public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time"; + /** Name of the interaction id */ + public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_interaction_id"; + /** The trace number of an interaction */ + public final static String INTERACTIONS_TRACE_NUMBER_FIELD = "trace_number"; /** Mappings for the interactions index */ public final static String INTERACTIONS_MAPPINGS = "{\n" + " \"_meta\": {\n" @@ -95,7 +104,13 @@ public class ConversationalIndexConstants { + "\": {\"type\": \"keyword\"},\n" + " \"" + INTERACTIONS_ADDITIONAL_INFO_FIELD - + "\": {\"type\": \"text\"}\n" + + "\": {\"type\": \"flat_object\"},\n" + + " \"" + + PARENT_INTERACTIONS_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_TRACE_NUMBER_FIELD + + "\": {\"type\": \"long\"}\n" + " }\n" + "}"; diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 9b6ec636bd..5e5878da85 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.HashMap; import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; @@ -54,7 +55,7 @@ public class Interaction implements Writeable, ToXContentObject { @Getter private String origin; @Getter - private String additionalInfo; + private Map additionalInfo; /** * Creates an Interaction object from a map of fields in the OS index @@ -69,7 +70,7 @@ public static Interaction fromMap(String id, Map fields) { String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD); String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD); String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD); - String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); + Map additionalInfo = (Map) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo); } @@ -97,7 +98,10 @@ public static Interaction fromStream(StreamInput in) throws IOException { String promptTemplate = in.readString(); String response = in.readString(); String origin = in.readString(); - String additionalInfo = in.readOptionalString(); + Map additionalInfo = new HashMap<>(); + if (in.readBoolean()) { + additionalInfo = in.readMap(s -> s.readString(), s -> s.readString()); + } return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo); } @@ -111,7 +115,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(promptTemplate); out.writeString(response); out.writeString(origin); - out.writeOptionalString(additionalInfo); + if (additionalInfo != null) { + out.writeBoolean(true); + out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } } @Override diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index 18d23eff0d..d874195c9b 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -18,6 +18,7 @@ package org.opensearch.ml.memory; import java.util.List; +import java.util.Map; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; @@ -49,6 +50,14 @@ public interface ConversationalMemoryHandler { */ public void createConversation(String name, ActionListener listener); + /** + * Create a new conversation + * @param name the name of the new conversation + * @param applicationType the application that creates this conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation(String name, String applicationType, ActionListener listener); + /** * Create a new conversation * @param name the name of the new conversation @@ -72,10 +81,34 @@ public void createInteraction( String promptTemplate, String response, String origin, - String additionalInfo, + Map additionalInfo, ActionListener listener ); + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @param interactionId the parent interactionId of this interaction + * @param traceNumber the trace number for a parent interaction + * @param listener gets the ID of the new interaction + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + ActionListener listener, + String interactionId, + Integer traceNumber + ); + /** * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to @@ -92,7 +125,7 @@ public ActionFuture createInteraction( String promptTemplate, String response, String origin, - String additionalInfo + Map additionalInfo ); /** diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java index e0a03f13eb..741e743c7e 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -35,6 +35,8 @@ public class CreateConversationRequest extends ActionRequest { @Getter private String name = null; + @Getter + private String applicationType = null; /** * Constructor @@ -44,6 +46,7 @@ public class CreateConversationRequest extends ActionRequest { public CreateConversationRequest(StreamInput in) throws IOException { super(in); this.name = in.readOptionalString(); + this.applicationType = in.readOptionalString(); } /** @@ -55,6 +58,16 @@ public CreateConversationRequest(String name) { this.name = name; } + /** + * Constructor + * @param name name of the conversation + */ + public CreateConversationRequest(String name, String applicationType) { + super(); + this.name = name; + this.applicationType = applicationType; + } + /** * Constructor * name will be null @@ -65,6 +78,7 @@ public CreateConversationRequest() {} public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(name); + out.writeOptionalString(applicationType); } @Override diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java index f6856b7c66..c9c26c6e20 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java @@ -82,6 +82,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis return; } String name = request.getName(); + String applicationType = request.getApplicationType(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); ActionListener al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> { @@ -92,7 +93,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis if (name == null) { cmHandler.createConversation(al); } else { - cmHandler.createConversation(name, al); + cmHandler.createConversation(name, applicationType, al); } } catch (Exception e) { log.error("Failed to create new conversation with name " + request.getName(), e); diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java index 52344b3792..2b274d8026 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -18,14 +18,18 @@ package org.opensearch.ml.memory.action.conversation; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.rest.RestRequest; @@ -48,7 +52,27 @@ public class CreateInteractionRequest extends ActionRequest { @Getter private String origin; @Getter - private String additionalInfo; + private Map additionalInfo; + @Getter + private String parent_interaction_id; + @Getter + private Integer trace_number; + + public CreateInteractionRequest( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo + ) { + this.conversationId = conversationId; + this.input = input; + this.promptTemplate = promptTemplate; + this.response = response; + this.origin = origin; + this.additionalInfo = additionalInfo; + } /** * Constructor @@ -62,7 +86,11 @@ public CreateInteractionRequest(StreamInput in) throws IOException { this.promptTemplate = in.readString(); this.response = in.readString(); this.origin = in.readOptionalString(); - this.additionalInfo = in.readOptionalString(); + if (in.readBoolean()) { + this.additionalInfo = in.readMap(s -> s.readString(), s -> s.readString()); + } + this.parent_interaction_id = in.readOptionalString(); + this.trace_number = in.readOptionalInt(); } @Override @@ -73,7 +101,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(promptTemplate); out.writeString(response); out.writeOptionalString(origin); - out.writeOptionalString(additionalInfo); + if (additionalInfo != null) { + out.writeBoolean(true); + out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(parent_interaction_id); + out.writeOptionalInt(trace_number); } @Override @@ -92,14 +127,52 @@ public ActionRequestValidationException validate() { * @throws IOException if something goes wrong reading from request */ public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException { - Map body = request.contentParser().mapStrings(); String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD); - String inp = body.get(ActionConstants.INPUT_FIELD); - String prmpt = body.get(ActionConstants.PROMPT_TEMPLATE_FIELD); - String rsp = body.get(ActionConstants.AI_RESPONSE_FIELD); - String ogn = body.get(ActionConstants.RESPONSE_ORIGIN_FIELD); - String addinf = body.get(ActionConstants.ADDITIONAL_INFO_FIELD); - return new CreateInteractionRequest(cid, inp, prmpt, rsp, ogn, addinf); + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + String input = null; + String prompt = null; + String rep = null; + String origin = null; + Map addinf = new HashMap<>(); + String parintid = null; + Integer tracenum = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case ActionConstants.INPUT_FIELD: + input = parser.text(); + break; + case ActionConstants.PROMPT_TEMPLATE_FIELD: + prompt = parser.text(); + break; + case ActionConstants.AI_RESPONSE_FIELD: + rep = parser.text(); + break; + case ActionConstants.RESPONSE_ORIGIN_FIELD: + origin = parser.text(); + break; + case ActionConstants.ADDITIONAL_INFO_FIELD: + addinf = getParameterMap(parser.map()); + break; + case ActionConstants.PARENT_INTERACTION_ID_FIELD: + parintid = parser.text(); + break; + case ActionConstants.TRACE_NUMBER_FIELD: + tracenum = parser.intValue(false); + break; + default: + parser.skipChildren(); + break; + } + } + + return new CreateInteractionRequest(cid, input, prompt, rep, origin, addinf, parintid, tracenum); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java index 2273cc32e8..0f73b0059e 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.action.conversation; +import java.util.Map; + import org.opensearch.OpenSearchException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -86,14 +88,20 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList String rsp = request.getResponse(); String ogn = request.getOrigin(); String prompt = request.getPromptTemplate(); - String additionalInfo = request.getAdditionalInfo(); + Map additionalInfo = request.getAdditionalInfo(); + String parintid = request.getParent_interaction_id(); + Integer traceNumber = request.getTrace_number(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); ActionListener al = ActionListener .wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> { internalListener.onFailure(e); }); - cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al); + if (parintid == null || traceNumber == null) { + cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al); + } else { + cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al, parintid, traceNumber); + } } catch (Exception e) { log.error("Failed to create interaction for conversation " + cid, e); actionListener.onFailure(e); diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index e36c296066..e9319018d2 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -112,9 +112,10 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) /** * Adds a new conversation with the specified name to the index * @param name user-specified name of the conversation to be added + * @param applicationType the application type that creates this conversation * @param listener listener to wait for this to finish */ - public void createConversation(String name, ActionListener listener) { + public void createConversation(String name, String applicationType, ActionListener listener) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { String userstr = client @@ -129,7 +130,9 @@ public void createConversation(String name, ActionListener listener) { ConversationalIndexConstants.META_NAME_FIELD, name, ConversationalIndexConstants.USER_FIELD, - userstr == null ? null : User.parse(userstr).getName() + userstr == null ? null : User.parse(userstr).getName(), + ConversationalIndexConstants.APPLICATION_TYPE_FIELD, + applicationType ); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); @@ -159,7 +162,16 @@ public void createConversation(String name, ActionListener listener) { * @param listener listener to wait for this to finish */ public void createConversation(ActionListener listener) { - createConversation("", listener); + createConversation("", "", listener); + } + + /** + * Adds a new conversation named "" + * @param name user-specified name of the conversation to be added + * @param listener listener to wait for this to finish + */ + public void createConversation(String name, ActionListener listener) { + createConversation(name, "", listener); } /** diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 5a2a6d10a9..9acdada7fd 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -21,6 +21,7 @@ import java.time.Instant; import java.util.LinkedList; import java.util.List; +import java.util.Map; import org.opensearch.OpenSearchSecurityException; import org.opensearch.OpenSearchWrapperException; @@ -116,6 +117,8 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { * @param origin the origin of the response for this interaction * @param additionalInfo additional information used for constructing the LLM prompt * @param timestamp when this interaction happened + * @param parintid the parent interactionId of this interaction + * @param traceNumber the trace number for a parent interaction * @param listener gets the id of the newly created interaction record */ public void createInteraction( @@ -124,9 +127,11 @@ public void createInteraction( String promptTemplate, String response, String origin, - String additionalInfo, + Map additionalInfo, Instant timestamp, - ActionListener listener + ActionListener listener, + String parintid, + Integer traceNumber ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { String userstr = client @@ -153,7 +158,11 @@ public void createInteraction( ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo, ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - timestamp + timestamp, + ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, + parintid, + ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, + traceNumber ); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); @@ -178,6 +187,30 @@ public void createInteraction( }, e -> { listener.onFailure(e); })); } + /** + * Add an interaction to this index. Return the ID of the newly created interaction + * @param conversationId The id of the conversation this interaction belongs to + * @param input the user (human) input into this interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the GenAI response for this interaction + * @param origin the origin of the response for this interaction + * @param additionalInfo additional information used for constructing the LLM prompt + * @param timestamp when this interaction happened + * @param listener gets the id of the newly created interaction record + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + Instant timestamp, + ActionListener listener + ) { + createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, timestamp, listener, "", null); + } + /** * Add an interaction to this index, timestamped now. Return the id of the newly created interaction * @param conversationId The id of the converation this interaction belongs to @@ -194,10 +227,10 @@ public void createInteraction( String promptTemplate, String response, String origin, - String additionalInfo, + Map additionalInfo, ActionListener listener ) { - createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, Instant.now(), listener); + createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, Instant.now(), listener, "", null); } /** diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index e89e9765c0..ba650df4d2 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -19,6 +19,7 @@ import java.time.Instant; import java.util.List; +import java.util.Map; import org.opensearch.action.StepListener; import org.opensearch.action.support.PlainActionFuture; @@ -86,6 +87,16 @@ public void createConversation(String name, ActionListener listener) { conversationMetaIndex.createConversation(name, listener); } + /** + * Create a new conversation + * @param name the name of the new conversation + * @param applicationType the application that creates this conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation(String name, String applicationType, ActionListener listener) { + conversationMetaIndex.createConversation(name, applicationType, listener); + } + /** * Create a new conversation * @param name the name of the new conversation @@ -113,13 +124,52 @@ public void createInteraction( String promptTemplate, String response, String origin, - String additionalInfo, + Map additionalInfo, ActionListener listener ) { Instant time = Instant.now(); interactionsIndex.createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, time, listener); } + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @param interactionId the parent interactionId of this interaction + * @param traceNumber the trace number for a parent interaction + * @param listener gets the ID of the new interaction + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo, + ActionListener listener, + String interactionId, + Integer traceNumber + ) { + Instant time = Instant.now(); + interactionsIndex + .createInteraction( + conversationId, + input, + promptTemplate, + response, + origin, + additionalInfo, + time, + listener, + interactionId, + traceNumber + ); + } + /** * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to @@ -136,7 +186,7 @@ public ActionFuture createInteraction( String promptTemplate, String response, String origin, - String additionalInfo + Map additionalInfo ) { PlainActionFuture fut = PlainActionFuture.newFuture(); createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, fut); diff --git a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java index 6ee0d4cc31..d8cde2ee53 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java @@ -17,6 +17,7 @@ */ package org.opensearch.ml.memory; +import java.util.Collections; import java.util.List; import java.util.Stack; import java.util.concurrent.CountDownLatch; @@ -110,7 +111,16 @@ public void testCanAddNewInteractionsToConversation() { StepListener iid1Listener = new StepListener<>(); cidListener.whenComplete(cid -> { - cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1Listener); + cmHandler + .createInteraction( + cid, + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid1Listener + ); }, e -> { cdl.countDown(); assert (false); @@ -118,7 +128,16 @@ public void testCanAddNewInteractionsToConversation() { StepListener iid2Listener = new StepListener<>(); iid1Listener.whenComplete(iid -> { - cmHandler.createInteraction(cidListener.result(), "test input1", "pt", "test response", "test origin", "meta", iid2Listener); + cmHandler + .createInteraction( + cidListener.result(), + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid2Listener + ); }, e -> { cdl.countDown(); assert (false); @@ -144,7 +163,16 @@ public void testCanGetInteractionsBackOut() { StepListener iid1Listener = new StepListener<>(); cidListener.whenComplete(cid -> { - cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1Listener); + cmHandler + .createInteraction( + cid, + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid1Listener + ); }, e -> { cdl.countDown(); assert (false); @@ -152,7 +180,16 @@ public void testCanGetInteractionsBackOut() { StepListener iid2Listener = new StepListener<>(); iid1Listener.whenComplete(iid -> { - cmHandler.createInteraction(cidListener.result(), "test input1", "pt", "test response", "test origin", "meta", iid2Listener); + cmHandler + .createInteraction( + cidListener.result(), + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid2Listener + ); }, e -> { cdl.countDown(); assert (false); @@ -195,24 +232,38 @@ public void testCanDeleteConversations() { cmHandler.createConversation("test", cid1); StepListener iid1 = new StepListener<>(); - cid1 - .whenComplete( - cid -> { cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1); }, - e -> { - cdl.countDown(); - assert (false); - } - ); + cid1.whenComplete(cid -> { + cmHandler + .createInteraction( + cid, + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid1 + ); + }, e -> { + cdl.countDown(); + assert (false); + }); StepListener iid2 = new StepListener<>(); - iid1 - .whenComplete( - iid -> { cmHandler.createInteraction(cid1.result(), "test input1", "pt", "test response", "test origin", "meta", iid2); }, - e -> { - cdl.countDown(); - assert (false); - } - ); + iid1.whenComplete(iid -> { + cmHandler + .createInteraction( + cid1.result(), + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid2 + ); + }, e -> { + cdl.countDown(); + assert (false); + }); StepListener cid2 = new StepListener<>(); iid2.whenComplete(iid -> { cmHandler.createConversation(cid2); }, e -> { @@ -221,14 +272,21 @@ public void testCanDeleteConversations() { }); StepListener iid3 = new StepListener<>(); - cid2 - .whenComplete( - cid -> { cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid3); }, - e -> { - cdl.countDown(); - assert (false); - } - ); + cid2.whenComplete(cid -> { + cmHandler + .createInteraction( + cid, + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid3 + ); + }, e -> { + cdl.countDown(); + assert (false); + }); StepListener del = new StepListener<>(); iid3.whenComplete(iid -> { cmHandler.deleteConversation(cid1.result(), del); }, e -> { @@ -328,59 +386,102 @@ public void testDifferentUsers_DifferentConversations() { cid1.whenComplete(cid -> { cmHandler.createConversation("conversation2", cid2); }, onFail); - cid2 - .whenComplete( - cid -> { - cmHandler.createInteraction(cid1.result(), "test input1", "pt", "test response", "test origin", "meta", iid1); - }, - onFail - ); + cid2.whenComplete(cid -> { + cmHandler + .createInteraction( + cid1.result(), + "test input1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid1 + ); + }, onFail); - iid1 - .whenComplete( - iid -> { - cmHandler.createInteraction(cid1.result(), "test input2", "pt", "test response", "test origin", "meta", iid2); - }, - onFail - ); + iid1.whenComplete(iid -> { + cmHandler + .createInteraction( + cid1.result(), + "test input2", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid2 + ); + }, onFail); - iid2 - .whenComplete( - iid -> { - cmHandler.createInteraction(cid2.result(), "test input3", "pt", "test response", "test origin", "meta", iid3); - }, - onFail - ); + iid2.whenComplete(iid -> { + cmHandler + .createInteraction( + cid2.result(), + "test input3", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid3 + ); + }, onFail); iid3.whenComplete(iid -> { contextStack.push(setUser(user2)); cmHandler.createConversation("conversation3", cid3); }, onFail); - cid3 - .whenComplete( - cid -> { - cmHandler.createInteraction(cid3.result(), "test input4", "pt", "test response", "test origin", "meta", iid4); - }, - onFail - ); + cid3.whenComplete(cid -> { + cmHandler + .createInteraction( + cid3.result(), + "test input4", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid4 + ); + }, onFail); - iid4 - .whenComplete( - iid -> { - cmHandler.createInteraction(cid3.result(), "test input5", "pt", "test response", "test origin", "meta", iid5); - }, - onFail - ); + iid4.whenComplete(iid -> { + cmHandler + .createInteraction( + cid3.result(), + "test input5", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + iid5 + ); + }, onFail); iid5.whenComplete(iid -> { - cmHandler.createInteraction(cid1.result(), "test inputf1", "pt", "test response", "test origin", "meta", failiid1); + cmHandler + .createInteraction( + cid1.result(), + "test inputf1", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + failiid1 + ); }, onFail); failiid1.whenComplete(shouldHaveFailedAsString, e -> { if (e instanceof OpenSearchSecurityException && e.getMessage().startsWith("User [" + user2 + "] does not have access to conversation ")) { - cmHandler.createInteraction(cid1.result(), "test inputf2", "pt", "test response", "test origin", "meta", failiid2); + cmHandler + .createInteraction( + cid1.result(), + "test inputf2", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + failiid2 + ); } else { onFail.accept(e); } @@ -450,7 +551,16 @@ public void testDifferentUsers_DifferentConversations() { failInter3.whenComplete(shouldHaveFailedAsInterList, e -> { if (e instanceof OpenSearchSecurityException && e.getMessage().startsWith("User [" + user1 + "] does not have access to conversation ")) { - cmHandler.createInteraction(cid3.result(), "test inputf3", "pt", "test response", "test origin", "meta", failiid3); + cmHandler + .createInteraction( + cid3.result(), + "test inputf3", + "pt", + "test response", + "test origin", + Collections.singletonMap("meta", "some meta"), + failiid3 + ); } else { onFail.accept(e); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java index cf027aef79..e154d6b59d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -18,6 +18,7 @@ package org.opensearch.ml.memory.action.conversation; import java.io.IOException; +import java.util.Collections; import java.util.Map; import org.junit.Before; @@ -47,7 +48,16 @@ public void setup() { } public void testConstructorsAndStreaming() throws IOException { - CreateInteractionRequest request = new CreateInteractionRequest("cid", "input", "pt", "response", "origin", "metadata"); + CreateInteractionRequest request = new CreateInteractionRequest( + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "interaction_id", + 1 + ); assert (request.validate() == null); assert (request.getConversationId().equals("cid")); assert (request.getInput().equals("input")); @@ -67,7 +77,16 @@ public void testConstructorsAndStreaming() throws IOException { } public void testNullCID_thenFail() { - CreateInteractionRequest request = new CreateInteractionRequest(null, "input", "pt", "response", "origin", "metadata"); + CreateInteractionRequest request = new CreateInteractionRequest( + null, + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "interaction_id", + 1 + ); assert (request.validate() != null); assert (request.validate().validationErrors().size() == 1); assert (request.validate().validationErrors().get(0).equals("Interaction MUST belong to a conversation ID")); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java index 8321a0b65e..589ecdf570 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Collections; import java.util.Set; import org.junit.Before; @@ -91,7 +92,16 @@ public void setup() throws IOException { this.actionListener = al; this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); - this.request = new CreateInteractionRequest("test-cid", "input", "pt", "response", "origin", "metadata"); + this.request = new CreateInteractionRequest( + "test-cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta"), + "interaction_id", + 1 + ); Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); this.threadContext = new ThreadContext(settings); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java index bbd17b2603..a897da6f90 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; import java.util.List; import org.apache.lucene.search.spell.LevenshteinDistance; @@ -45,9 +46,36 @@ public class GetInteractionsResponseTests extends OpenSearchTestCase { public void setup() { interactions = List .of( - new Interaction("id0", Instant.now(), "cid", "input", "pt", "response", "origin", "metadata"), - new Interaction("id1", Instant.now(), "cid", "input", "pt", "response", "origin", "mteadata"), - new Interaction("id2", Instant.now(), "cid", "input", "pt", "response", "origin", "metadata") + new Interaction( + "id0", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ), + new Interaction( + "id1", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ), + new Interaction( + "id2", + Instant.now(), + "cid", + "input", + "pt", + "response", + "origin", + Collections.singletonMap("metadata", "some meta") + ) ); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java index a7a245b680..7b4b62df15 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; import java.util.List; import java.util.Set; @@ -118,7 +119,7 @@ public void testGetInteractions_noMorePages() { "pt", "test-response", "test-origin", - "metadata" + Collections.singletonMap("metadata", "some meta") ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); @@ -145,7 +146,7 @@ public void testGetInteractions_MorePages() { "pt", "test-response", "test-origin", - "metadata" + Collections.singletonMap("metadata", "some meta") ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index c23177bc2f..4d42a4314c 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -19,6 +19,7 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -97,7 +98,7 @@ public void testCanAddNewInteraction() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), new LatchedActionListener<>(ActionListener.wrap(id -> { ids[0] = id; }, e -> { @@ -114,7 +115,7 @@ public void testCanAddNewInteraction() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), new LatchedActionListener<>(ActionListener.wrap(id -> { ids[1] = id; }, e -> { @@ -138,7 +139,16 @@ public void testGetInteractions() { final String conversation = "test-conversation"; CountDownLatch cdl = new CountDownLatch(1); StepListener id1Listener = new StepListener<>(); - index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", id1Listener); + index + .createInteraction( + conversation, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + id1Listener + ); StepListener id2Listener = new StepListener<>(); id1Listener.whenComplete(id -> { @@ -149,7 +159,7 @@ public void testGetInteractions() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), Instant.now().plus(3, ChronoUnit.MINUTES), id2Listener ); @@ -187,7 +197,16 @@ public void testGetInteractionPages() { final String conversation = "test-conversation"; CountDownLatch cdl = new CountDownLatch(1); StepListener id1Listener = new StepListener<>(); - index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", id1Listener); + index + .createInteraction( + conversation, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + id1Listener + ); StepListener id2Listener = new StepListener<>(); id1Listener.whenComplete(id -> { @@ -198,7 +217,7 @@ public void testGetInteractionPages() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), Instant.now().plus(3, ChronoUnit.MINUTES), id2Listener ); @@ -217,7 +236,7 @@ public void testGetInteractionPages() { "pt", "test response", "test origin", - "metadata", + Collections.singletonMap("metadata", "some meta"), Instant.now().plus(4, ChronoUnit.MINUTES), id3Listener ); @@ -269,40 +288,70 @@ public void testDeleteConversation() { final String conversation2 = "conversation2"; CountDownLatch cdl = new CountDownLatch(1); StepListener iid1 = new StepListener<>(); - index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid1); + index + .createInteraction( + conversation1, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid1 + ); StepListener iid2 = new StepListener<>(); - iid1 - .whenComplete( - r -> { index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid2); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); + iid1.whenComplete(r -> { + index + .createInteraction( + conversation1, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid2 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); StepListener iid3 = new StepListener<>(); - iid2 - .whenComplete( - r -> { index.createInteraction(conversation2, "test input", "pt", "test response", "test origin", "metadata", iid3); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); + iid2.whenComplete(r -> { + index + .createInteraction( + conversation2, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid3 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); StepListener iid4 = new StepListener<>(); - iid3 - .whenComplete( - r -> { index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid4); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); + iid3.whenComplete(r -> { + index + .createInteraction( + conversation1, + "test input", + "pt", + "test response", + "test origin", + Collections.singletonMap("metadata", "some meta"), + iid4 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); StepListener deleteListener = new StepListener<>(); iid4.whenComplete(r -> { index.deleteConversation(conversation1, deleteListener); }, e -> { diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 0e97c7e9f6..40d33b0cef 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -30,6 +30,7 @@ import static org.mockito.Mockito.verify; import java.time.Instant; +import java.util.Collections; import java.util.List; import org.junit.Before; @@ -239,7 +240,8 @@ public void testCreate_NoIndex_ThenFail() { setupDoesNotMakeIndex(); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("no index to add conversation to")); @@ -257,7 +259,8 @@ public void testCreate_BadRestStatus_ThenFail() { }).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Failed to create interaction")); @@ -273,7 +276,8 @@ public void testCreate_InternalFailure_ThenFail() { }).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Test Failure")); @@ -285,7 +289,8 @@ public void testCreate_ClientFails_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Test Client Failure")); @@ -297,7 +302,8 @@ public void testCreate_NoAccessNoUser_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor @@ -312,7 +318,8 @@ public void testCreate_NoAccessWithUser_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("User [user] does not have access to conversation cid")); @@ -326,7 +333,8 @@ public void testCreate_CreateIndexFails_ThenFail() { }).when(interactionsIndex).initInteractionsIndexIfAbsent(any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + interactionsIndex + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta"), createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Fail in Index Creation")); @@ -414,10 +422,10 @@ public void testGetAll_BadMaxResults_ThenFail() { public void testGetAll_Recursion() { List interactions = List .of( - new Interaction("iid1", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), - new Interaction("iid2", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), - new Interaction("iid3", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), - new Interaction("iid4", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta") + new Interaction("iid1", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")), + new Interaction("iid2", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")), + new Interaction("iid3", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")), + new Interaction("iid4", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")) ); doAnswer(invocation -> { ActionListener> al = invocation.getArgument(3); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index e39513d2d8..b588aee068 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import java.util.Collections; import java.util.List; import org.junit.Before; @@ -79,10 +80,9 @@ public void testCreateInteraction_Future() { ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }) - .when(interactionsIndex) - .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); - ActionFuture result = cmHandler.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta"); + }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), any(), any(), any()); + ActionFuture result = cmHandler + .createInteraction("cid", "inp", "pt", "rsp", "ogn", Collections.singletonMap("meta", "some meta")); assert (result.actionGet(200).equals("iid")); } @@ -91,9 +91,7 @@ public void testCreateInteraction_FromBuilder_Success() { ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }) - .when(interactionsIndex) - .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), any(), any(), any()); InteractionBuilder builder = Interaction .builder() .conversationId("cid") @@ -101,7 +99,7 @@ public void testCreateInteraction_FromBuilder_Success() { .origin("origin") .response("rsp") .promptTemplate("pt") - .additionalInfo("meta"); + .additionalInfo(Collections.singletonMap("meta", "some meta")); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); cmHandler.createInteraction(builder, createInteractionListener); @@ -115,9 +113,7 @@ public void testCreateInteraction_FromBuilder_Future() { ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }) - .when(interactionsIndex) - .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), any(), any(), any()); InteractionBuilder builder = Interaction .builder() .origin("ogn") @@ -125,7 +121,7 @@ public void testCreateInteraction_FromBuilder_Future() { .input("inp") .response("rsp") .promptTemplate("pt") - .additionalInfo("meta"); + .additionalInfo(Collections.singletonMap("meta", "some meta")); ActionFuture result = cmHandler.createInteraction(builder); assert (result.actionGet(200).equals("iid")); } diff --git a/plugin/src/main/java/org/opensearch/ml/memory/MLMemoryManager.java b/plugin/src/main/java/org/opensearch/ml/memory/MLMemoryManager.java new file mode 100644 index 0000000000..e66ef6ad3b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/memory/MLMemoryManager.java @@ -0,0 +1,13 @@ +package org.opensearch.ml.memory; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Manager class for Memory related calls. It contains ML memory related operations like CRUD Conversation/Interaction etc. + */ +@Log4j2 +@AllArgsConstructor +public class MLMemoryManager { + +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 111437ab0f..09d9b0d896 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -171,7 +171,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp PromptUtil.getPromptTemplate(systemPrompt, userInstructions), answer, GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, - jsonArrayToString(searchResults) + Collections.singletonMap("metadata", jsonArrayToString(searchResults)) ); log.info("Created a new interaction: {} ({})", interactionId, getDuration(start)); } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java index a5b74dbfdc..c8b7b28c30 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -66,7 +67,7 @@ public String createInteraction( String promptTemplate, String response, String origin, - String additionalInfo + Map additionalInfo ) { Preconditions.checkNotNull(conversationId); Preconditions.checkNotNull(input); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java index 7241ba40ed..43c84d4aec 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java @@ -21,6 +21,7 @@ import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.UUID; import java.util.stream.IntStream; @@ -182,7 +183,8 @@ public void testCreateInteraction() { ActionFuture future = mock(ActionFuture.class); when(future.actionGet(anyLong())).thenReturn(res); when(client.execute(eq(CreateInteractionAction.INSTANCE), any())).thenReturn(future); - String actual = memoryClient.createInteraction("cid", "input", "prompt", "answer", "origin", "hits"); + String actual = memoryClient + .createInteraction("cid", "input", "prompt", "answer", "origin", Collections.singletonMap("metadata", "hits")); assertEquals(id, actual); } }