diff --git a/common/src/main/java/org/opensearch/ml/common/conversational/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversational/ActionConstants.java index 6a8f9873ce..be2e1b2d3b 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversational/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversational/ActionConstants.java @@ -44,6 +44,10 @@ public class ActionConstants { public final static String AI_RESPONSE_FIELD = "response"; /** name of origin field in all requests */ public final static String RESPONSE_ORIGIN_FIELD = "origin"; + /** name of prompt template field in all requests */ + public final static String PROMPT_TEMPLATE_FIELD = "prompt_template"; + /** name of metadata field in all requests */ + public final static String METADATA_FIELD = "metadata"; /** name of success field in all requests */ public final static String SUCCESS_FIELD = "success"; diff --git a/common/src/main/java/org/opensearch/ml/common/conversational/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversational/ConversationalIndexConstants.java index 0979d625cf..c28c150a22 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversational/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversational/ConversationalIndexConstants.java @@ -57,10 +57,14 @@ public class ConversationalIndexConstants { public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "conversation_id"; /** Name of the interaction field for the human input */ public final static String INTERACTIONS_INPUT_FIELD = "input"; + /** Name of the interaction field for the prompt template */ + public final static String INTERACTIONS_PROMPT_TEMPLATE_FIELD = "prompt_template"; /** Name of the interaction field for the AI response */ public final static String INTERACTIONS_RESPONSE_FIELD = "response"; /** Name of the interaction field for the response's origin */ public final static String INTERACTIONS_ORIGIN_FIELD = "origin"; + /** Name of the interaction field for additional metadata */ + public final static String INTERACTIONS_METADATA_FIELD = "metadata"; /** Name of the interaction field for the timestamp */ public final static String INTERACTIONS_TIMESTAMP_FIELD = "timestamp"; /** Mappings for the interactions index */ @@ -79,12 +83,18 @@ public class ConversationalIndexConstants { + INTERACTIONS_INPUT_FIELD + "\": {\"type\": \"text\"},\n" + " \"" + + INTERACTIONS_PROMPT_TEMPLATE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + INTERACTIONS_RESPONSE_FIELD + "\": {\"type\": \"text\"},\n" + " \"" + INTERACTIONS_ORIGIN_FIELD + "\": {\"type\": \"keyword\"},\n" + " \"" + + INTERACTIONS_METADATA_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + USER_FIELD + "\": {\"type\": \"keyword\"}\n" + " }\n" diff --git a/common/src/main/java/org/opensearch/ml/common/conversational/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversational/Interaction.java index be5e6c094c..52131b5656 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversational/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversational/Interaction.java @@ -48,9 +48,13 @@ public class Interaction implements Writeable, ToXContentObject { @Getter private String input; @Getter + private String promptTemplate; + @Getter private String response; @Getter private String origin; + @Getter + private String metadata; /** * Creates an Interaction object from a map of fields in the OS index @@ -62,9 +66,11 @@ public static Interaction fromMap(String id, Map fields) { Instant timestamp = Instant.parse((String) fields.get(ConversationalIndexConstants.INTERACTIONS_TIMESTAMP_FIELD)); String conversationId = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD); String input = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD); + String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD); String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD); String agent = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD); - return new Interaction(id, timestamp, conversationId, input, response, agent); + String metadata = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_METADATA_FIELD); + return new Interaction(id, timestamp, conversationId, input, promptTemplate, response, agent, metadata); } /** @@ -88,9 +94,11 @@ public static Interaction fromStream(StreamInput in) throws IOException { Instant timestamp = in.readInstant(); String conversationId = in.readString(); String input = in.readString(); + String promptTemplate = in.readString(); String response = in.readString(); String origin = in.readString(); - return new Interaction(id, timestamp, conversationId, input, response, origin); + String metadata = in.readOptionalString(); + return new Interaction(id, timestamp, conversationId, input, promptTemplate, response, origin, metadata); } @@ -100,8 +108,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInstant(timestamp); out.writeString(conversationId); out.writeString(input); + out.writeString(promptTemplate); out.writeString(response); out.writeString(origin); + out.writeOptionalString(metadata); } @Override @@ -111,8 +121,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id); builder.field(ConversationalIndexConstants.INTERACTIONS_TIMESTAMP_FIELD, timestamp); builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + if(metadata != null) { + builder.field(ConversationalIndexConstants.INTERACTIONS_METADATA_FIELD, metadata); + } builder.endObject(); return builder; } @@ -125,8 +139,11 @@ public boolean equals(Object other) { ((Interaction) other).conversationId.equals(this.conversationId) && ((Interaction) other).timestamp.equals(this.timestamp) && ((Interaction) other).input.equals(this.input) && + ((Interaction) other).promptTemplate.equals(this.promptTemplate) && ((Interaction) other).response.equals(this.response) && - ((Interaction) other).origin.equals(this.origin) + ((Interaction) other).origin.equals(this.origin) && + ( (((Interaction) other).metadata == null && this.metadata == null) || + ((Interaction) other).metadata.equals(this.metadata)) ); } @@ -138,7 +155,9 @@ public String toString() { + ",timestamp=" + timestamp + ",origin=" + origin + ",input=" + input + + ",promt_template=" + promptTemplate + ",response=" + response + + ",metadata=" + metadata + "}"; } diff --git a/conversational-memory/src/main/java/org/opensearch/ml/conversational/ConversationalMemoryHandler.java b/conversational-memory/src/main/java/org/opensearch/ml/conversational/ConversationalMemoryHandler.java index 0719a857cb..d5b5606a6c 100644 --- a/conversational-memory/src/main/java/org/opensearch/ml/conversational/ConversationalMemoryHandler.java +++ b/conversational-memory/src/main/java/org/opensearch/ml/conversational/ConversationalMemoryHandler.java @@ -60,21 +60,40 @@ public interface ConversationalMemoryHandler { * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction * @param response the Gen AI response for this interaction * @param origin the name of the GenAI agent in this interaction + * @param metadata additional inofrmation used in constructing the LLM prompt * @param listener gets the ID of the new interaction */ - public void createInteraction(String conversationId, String input, String response, String origin, ActionListener listener); + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String metadata, + ActionListener listener + ); /** * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to * @param input the human input for the interaction + * @param promptTemplate the prompt template used in this interaction * @param response the Gen AI response for this interaction * @param origin the name of the GenAI agent in this interaction + * @param metadata arbitrary JSON string of extra stuff * @return ActionFuture for the interactionId of the new interaction */ - public ActionFuture createInteraction(String conversationId, String input, String response, String origin); + public ActionFuture createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String metadata + ); /** * Adds an interaction to the index, updating the associated Conversational Metadata diff --git a/conversational-memory/src/main/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRequest.java b/conversational-memory/src/main/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRequest.java index 92e0530553..4fe65db94e 100644 --- a/conversational-memory/src/main/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRequest.java +++ b/conversational-memory/src/main/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRequest.java @@ -41,9 +41,13 @@ public class CreateInteractionRequest extends ActionRequest { @Getter private String input; @Getter + private String promptTemplate; + @Getter private String response; @Getter private String origin; + @Getter + private String metadata; /** * Constructor @@ -54,8 +58,10 @@ public CreateInteractionRequest(StreamInput in) throws IOException { super(in); this.conversationId = in.readString(); this.input = in.readString(); + this.promptTemplate = in.readString(); this.response = in.readString(); this.origin = in.readOptionalString(); + this.metadata = in.readOptionalString(); } @Override @@ -63,8 +69,10 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(conversationId); out.writeString(input); + out.writeString(promptTemplate); out.writeString(response); out.writeOptionalString(origin); + out.writeOptionalString(metadata); } @Override @@ -85,9 +93,11 @@ public ActionRequestValidationException validate() { public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException { String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD); String inp = request.param(ActionConstants.INPUT_FIELD); + String prmpt = request.param(ActionConstants.PROMPT_TEMPLATE_FIELD); String rsp = request.param(ActionConstants.AI_RESPONSE_FIELD); String ogn = request.param(ActionConstants.RESPONSE_ORIGIN_FIELD); - return new CreateInteractionRequest(cid, inp, rsp, ogn); + String metadata = request.param(ActionConstants.METADATA_FIELD); + return new CreateInteractionRequest(cid, inp, prmpt, rsp, ogn, metadata); } } diff --git a/conversational-memory/src/main/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionTransportAction.java b/conversational-memory/src/main/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionTransportAction.java index e3d060d4d1..d3d37e3386 100644 --- a/conversational-memory/src/main/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionTransportAction.java +++ b/conversational-memory/src/main/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionTransportAction.java @@ -64,13 +64,15 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList String inp = request.getInput(); String rsp = request.getResponse(); String ogn = request.getOrigin(); + String prompt = request.getPromptTemplate(); + String metadata = request.getMetadata(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); ActionListener al = ActionListener .wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> { internalListener.onFailure(e); }); - cmHandler.createInteraction(cid, inp, rsp, ogn, al); + cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, metadata, al); } catch (Exception e) { log.error(e.toString()); actionListener.onFailure(e); diff --git a/conversational-memory/src/main/java/org/opensearch/ml/conversational/index/InteractionsIndex.java b/conversational-memory/src/main/java/org/opensearch/ml/conversational/index/InteractionsIndex.java index 93dd6681ec..ba5bbb6412 100644 --- a/conversational-memory/src/main/java/org/opensearch/ml/conversational/index/InteractionsIndex.java +++ b/conversational-memory/src/main/java/org/opensearch/ml/conversational/index/InteractionsIndex.java @@ -112,16 +112,20 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { * Add an interaction to this index. Return the ID of the newly created interaction * @param conversationId The id of the conversation this interaction belongs to * @param input the user (human) input into this interaction + * @param promptTemplate the prompt template used for this interaction * @param response the GenAI response for this interaction * @param origin the origin of the response for this interaction + * @param metadata additional information used for constructing the LLM prompt * @param timestamp when this interaction happened * @param listener gets the id of the newly created interaction record */ public void createInteraction( String conversationId, String input, + String promptTemplate, String response, String origin, + String metadata, Instant timestamp, ActionListener listener ) { @@ -138,8 +142,12 @@ public void createInteraction( conversationId, ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input, + ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, + promptTemplate, ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response, + ConversationalIndexConstants.INTERACTIONS_METADATA_FIELD, + metadata, ConversationalIndexConstants.INTERACTIONS_TIMESTAMP_FIELD, timestamp ); @@ -177,12 +185,22 @@ public void createInteraction( * Add an interaction to this index, timestamped now. Return the id of the newly created interaction * @param conversationId The id of the converation this interaction belongs to * @param input the user (human) input into this interaction + * @param promptTemplate the prompt template used for this interaction * @param response the GenAI response for this interaction * @param origin the name of the GenAI agent this interaction belongs to + * @param metadata additional information used to construct the LLM prompt * @param listener gets the id of the newly created interaction record */ - public void createInteraction(String conversationId, String input, String response, String origin, ActionListener listener) { - createInteraction(conversationId, input, response, origin, Instant.now(), listener); + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String metadata, + ActionListener listener + ) { + createInteraction(conversationId, input, promptTemplate, response, origin, metadata, Instant.now(), listener); } /** diff --git a/conversational-memory/src/main/java/org/opensearch/ml/conversational/index/OpenSearchConversationalMemoryHandler.java b/conversational-memory/src/main/java/org/opensearch/ml/conversational/index/OpenSearchConversationalMemoryHandler.java index 0919b16878..f72506ec91 100644 --- a/conversational-memory/src/main/java/org/opensearch/ml/conversational/index/OpenSearchConversationalMemoryHandler.java +++ b/conversational-memory/src/main/java/org/opensearch/ml/conversational/index/OpenSearchConversationalMemoryHandler.java @@ -99,28 +99,45 @@ public ActionFuture createConversation(String name) { * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction * @param response the Gen AI response for this interaction * @param origin the name of the GenAI agent in this interaction + * @param metadata additional inofrmation used in constructing the LLM prompt * @param listener gets the ID of the new interaction */ - public void createInteraction(String conversationId, String input, String response, String origin, ActionListener listener) { + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String metadata, + ActionListener listener + ) { Instant time = Instant.now(); - interactionsIndex.createInteraction(conversationId, input, response, origin, time, listener); + interactionsIndex.createInteraction(conversationId, input, promptTemplate, response, origin, metadata, time, listener); } /** * Adds an interaction to the conversation indicated, updating the conversational metadata * @param conversationId the conversation to add the interaction to * @param input the human input for the interaction - * @param prompt the prompt template used in this interaction + * @param promptTemplate the prompt template used in this interaction * @param response the Gen AI response for this interaction - * @param agent the name of the GenAI agent in this interaction + * @param origin the name of the GenAI agent in this interaction * @param metadata arbitrary JSON string of extra stuff * @return ActionFuture for the interactionId of the new interaction */ - public ActionFuture createInteraction(String conversationId, String input, String response, String origin) { + public ActionFuture createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String metadata + ) { PlainActionFuture fut = PlainActionFuture.newFuture(); - createInteraction(conversationId, input, response, origin, fut); + createInteraction(conversationId, input, promptTemplate, response, origin, metadata, fut); return fut; } @@ -136,8 +153,10 @@ public void createInteraction(InteractionBuilder builder, ActionListener .createInteraction( interaction.getConversationId(), interaction.getInput(), + interaction.getPromptTemplate(), interaction.getResponse(), interaction.getOrigin(), + interaction.getMetadata(), interaction.getTimestamp(), listener ); diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/ConversationalMemoryHandlerITTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/ConversationalMemoryHandlerITTests.java index 2db0f5cbf7..5ae9aaced1 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/ConversationalMemoryHandlerITTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/ConversationalMemoryHandlerITTests.java @@ -109,15 +109,16 @@ public void testCanAddNewInteractionsToConversation() { cmHandler.createConversation("test", cidListener); StepListener iid1Listener = new StepListener<>(); - cidListener - .whenComplete(cid -> { cmHandler.createInteraction(cid, "test input1", "test response", "test origin", iid1Listener); }, e -> { - cdl.countDown(); - assert (false); - }); + cidListener.whenComplete(cid -> { + cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1Listener); + }, e -> { + cdl.countDown(); + assert (false); + }); StepListener iid2Listener = new StepListener<>(); iid1Listener.whenComplete(iid -> { - cmHandler.createInteraction(cidListener.result(), "test input1", "test response", "test origin", iid2Listener); + cmHandler.createInteraction(cidListener.result(), "test input1", "pt", "test response", "test origin", "meta", iid2Listener); }, e -> { cdl.countDown(); assert (false); @@ -142,15 +143,16 @@ public void testCanGetInteractionsBackOut() { cmHandler.createConversation("test", cidListener); StepListener iid1Listener = new StepListener<>(); - cidListener - .whenComplete(cid -> { cmHandler.createInteraction(cid, "test input1", "test response", "test origin", iid1Listener); }, e -> { - cdl.countDown(); - assert (false); - }); + cidListener.whenComplete(cid -> { + cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1Listener); + }, e -> { + cdl.countDown(); + assert (false); + }); StepListener iid2Listener = new StepListener<>(); iid1Listener.whenComplete(iid -> { - cmHandler.createInteraction(cidListener.result(), "test input1", "test response", "test origin", iid2Listener); + cmHandler.createInteraction(cidListener.result(), "test input1", "pt", "test response", "test origin", "meta", iid2Listener); }, e -> { cdl.countDown(); assert (false); @@ -193,15 +195,19 @@ public void testCanDeleteConversations() { cmHandler.createConversation("test", cid1); StepListener iid1 = new StepListener<>(); - cid1.whenComplete(cid -> { cmHandler.createInteraction(cid, "test input1", "test response", "test origin", iid1); }, e -> { - cdl.countDown(); - assert (false); - }); + cid1 + .whenComplete( + cid -> { cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1); }, + e -> { + cdl.countDown(); + assert (false); + } + ); StepListener iid2 = new StepListener<>(); iid1 .whenComplete( - iid -> { cmHandler.createInteraction(cid1.result(), "test input1", "test response", "test origin", iid2); }, + iid -> { cmHandler.createInteraction(cid1.result(), "test input1", "pt", "test response", "test origin", "meta", iid2); }, e -> { cdl.countDown(); assert (false); @@ -215,10 +221,14 @@ public void testCanDeleteConversations() { }); StepListener iid3 = new StepListener<>(); - cid2.whenComplete(cid -> { cmHandler.createInteraction(cid, "test input1", "test response", "test origin", iid3); }, e -> { - cdl.countDown(); - assert (false); - }); + cid2 + .whenComplete( + cid -> { cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid3); }, + e -> { + cdl.countDown(); + assert (false); + } + ); StepListener del = new StepListener<>(); iid3.whenComplete(iid -> { cmHandler.deleteConversation(cid1.result(), del); }, e -> { @@ -318,19 +328,25 @@ public void testDifferentUsers_DifferentConversations() { cid2 .whenComplete( - cid -> { cmHandler.createInteraction(cid1.result(), "test input1", "test response", "test origin", iid1); }, + cid -> { + cmHandler.createInteraction(cid1.result(), "test input1", "pt", "test response", "test origin", "meta", iid1); + }, onFail ); iid1 .whenComplete( - iid -> { cmHandler.createInteraction(cid1.result(), "test input2", "test response", "test origin", iid2); }, + iid -> { + cmHandler.createInteraction(cid1.result(), "test input2", "pt", "test response", "test origin", "meta", iid2); + }, onFail ); iid2 .whenComplete( - iid -> { cmHandler.createInteraction(cid2.result(), "test input3", "test response", "test origin", iid3); }, + iid -> { + cmHandler.createInteraction(cid2.result(), "test input3", "pt", "test response", "test origin", "meta", iid3); + }, onFail ); @@ -341,26 +357,28 @@ public void testDifferentUsers_DifferentConversations() { cid3 .whenComplete( - cid -> { cmHandler.createInteraction(cid3.result(), "test input4", "test response", "test origin", iid4); }, + cid -> { + cmHandler.createInteraction(cid3.result(), "test input4", "pt", "test response", "test origin", "meta", iid4); + }, onFail ); iid4 .whenComplete( - iid -> { cmHandler.createInteraction(cid3.result(), "test input5", "test response", "test origin", iid5); }, + iid -> { + cmHandler.createInteraction(cid3.result(), "test input5", "pt", "test response", "test origin", "meta", iid5); + }, onFail ); - iid5 - .whenComplete( - iid -> { cmHandler.createInteraction(cid1.result(), "test inputf1", "test response", "test origin", failiid1); }, - onFail - ); + iid5.whenComplete(iid -> { + cmHandler.createInteraction(cid1.result(), "test inputf1", "pt", "test response", "test origin", "meta", failiid1); + }, onFail); failiid1.whenComplete(shouldHaveFailedAsString, e -> { if (e instanceof OpenSearchSecurityException && e.getMessage().startsWith("User [" + user2 + "] does not have access to conversation ")) { - cmHandler.createInteraction(cid1.result(), "test inputf2", "test response", "test origin", failiid2); + cmHandler.createInteraction(cid1.result(), "test inputf2", "pt", "test response", "test origin", "meta", failiid2); } else { onFail.accept(e); } @@ -430,7 +448,7 @@ public void testDifferentUsers_DifferentConversations() { failInter3.whenComplete(shouldHaveFailedAsInterList, e -> { if (e instanceof OpenSearchSecurityException && e.getMessage().startsWith("User [" + user1 + "] does not have access to conversation ")) { - cmHandler.createInteraction(cid3.result(), "test inputf3", "test response", "test origin", failiid3); + cmHandler.createInteraction(cid3.result(), "test inputf3", "pt", "test response", "test origin", "meta", failiid3); } else { onFail.accept(e); } diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRequestTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRequestTests.java index dc38a7643c..2a47dd9511 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRequestTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRequestTests.java @@ -35,7 +35,7 @@ public class CreateInteractionRequestTests extends OpenSearchTestCase { public void testConstructorsAndStreaming() throws IOException { - CreateInteractionRequest request = new CreateInteractionRequest("cid", "input", "response", "origin"); + CreateInteractionRequest request = new CreateInteractionRequest("cid", "input", "pt", "response", "origin", "metadata"); assert (request.validate() == null); assert (request.getConversationId().equals("cid")); assert (request.getInput().equals("input")); @@ -55,7 +55,7 @@ public void testConstructorsAndStreaming() throws IOException { } public void testNullCID_thenFail() { - CreateInteractionRequest request = new CreateInteractionRequest(null, "input", "response", "origin"); + CreateInteractionRequest request = new CreateInteractionRequest(null, "input", "pt", "response", "origin", "metadata"); assert (request.validate() != null); assert (request.validate().validationErrors().size() == 1); assert (request.validate().validationErrors().get(0).equals("Interaction MUST belong to a conversation ID")); @@ -68,17 +68,23 @@ public void testFromRestRequest() throws IOException { "cid", ActionConstants.INPUT_FIELD, "input", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "pt", ActionConstants.AI_RESPONSE_FIELD, "response", ActionConstants.RESPONSE_ORIGIN_FIELD, - "origin" + "origin", + ActionConstants.METADATA_FIELD, + "metadata" ); RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); assert (request.validate() == null); assert (request.getConversationId().equals("cid")); assert (request.getInput().equals("input")); + assert (request.getPromptTemplate().equals("pt")); assert (request.getResponse().equals("response")); assert (request.getOrigin().equals("origin")); + assert (request.getMetadata().equals("metadata")); } } diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRestActionTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRestActionTests.java index 7c1b652272..237b52e7e8 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRestActionTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionRestActionTests.java @@ -53,10 +53,14 @@ public void testPrepareRequest() throws Exception { "cid", ActionConstants.INPUT_FIELD, "input", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "pt", ActionConstants.AI_RESPONSE_FIELD, "response", ActionConstants.RESPONSE_ORIGIN_FIELD, - "origin" + "origin", + ActionConstants.METADATA_FIELD, + "metadata" ); CreateInteractionRestAction action = new CreateInteractionRestAction(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); @@ -70,7 +74,9 @@ public void testPrepareRequest() throws Exception { CreateInteractionRequest req = argCaptor.getValue(); assert (req.getConversationId().equals("cid")); assert (req.getInput().equals("input")); + assert (req.getPromptTemplate().equals("pt")); assert (req.getResponse().equals("response")); assert (req.getOrigin().equals("origin")); + assert (req.getMetadata().equals("metadata")); } } diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionTransportActionTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionTransportActionTests.java index daed865443..8c502097a3 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionTransportActionTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/CreateInteractionTransportActionTests.java @@ -88,7 +88,7 @@ public void setup() throws IOException { this.actionListener = al; this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); - this.request = new CreateInteractionRequest("test-cid", "input", "response", "origin"); + this.request = new CreateInteractionRequest("test-cid", "input", "pt", "response", "origin", "metadata"); this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client)); Settings settings = Settings.builder().build(); @@ -100,10 +100,10 @@ public void setup() throws IOException { public void testCreateInteraction() { log.info("testing create interaction transport"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); + ActionListener listener = invocation.getArgument(6); listener.onResponse("testID"); return null; - }).when(cmHandler).createInteraction(any(), any(), any(), any(), any()); + }).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateInteractionResponse.class); verify(actionListener).onResponse(argCaptor.capture()); @@ -113,10 +113,10 @@ public void testCreateInteraction() { public void testCreateInteractionFails_thenFail() { log.info("testing create interaction transport"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(4); + ActionListener listener = invocation.getArgument(6); listener.onFailure(new Exception("Testing Failure")); return null; - }).when(cmHandler).createInteraction(any(), any(), any(), any(), any()); + }).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); @@ -125,7 +125,9 @@ public void testCreateInteractionFails_thenFail() { public void testDoExecuteFails_thenFail() { log.info("testing create interaction transport"); - doThrow(new RuntimeException("Failure in doExecute")).when(cmHandler).createInteraction(any(), any(), any(), any(), any()); + doThrow(new RuntimeException("Failure in doExecute")) + .when(cmHandler) + .createInteraction(any(), any(), any(), any(), any(), any(), any()); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argCaptor.capture()); diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/GetInteractionsResponseTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/GetInteractionsResponseTests.java index 74be5fa56a..02409ad543 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/GetInteractionsResponseTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/GetInteractionsResponseTests.java @@ -45,9 +45,9 @@ public class GetInteractionsResponseTests extends OpenSearchTestCase { public void setup() { interactions = List .of( - new Interaction("id0", Instant.now(), "cid", "input", "response", "origin"), - new Interaction("id1", Instant.now(), "cid", "input", "response", "origin"), - new Interaction("id2", Instant.now(), "cid", "input", "response", "origin") + new Interaction("id0", Instant.now(), "cid", "input", "pt", "response", "origin", "metadata"), + new Interaction("id1", Instant.now(), "cid", "input", "pt", "response", "origin", "mteadata"), + new Interaction("id2", Instant.now(), "cid", "input", "pt", "response", "origin", "metadata") ); } @@ -74,7 +74,7 @@ public void testToXContent_MoreTokens() throws IOException { String result = BytesReference.bytes(builder).utf8ToString(); String expected = "{\"interactions\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"timestamp\":\"" + interaction.getTimestamp() - + "\",\"input\":\"input\",\"response\":\"response\",\"origin\":\"origin\"}],\"next_token\":2}"; + + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"metadata\":\"metadata\"}],\"next_token\":2}"; log.info(result); log.info(expected); // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness @@ -91,7 +91,7 @@ public void testToXContent_NoMoreTokens() throws IOException { String result = BytesReference.bytes(builder).utf8ToString(); String expected = "{\"interactions\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"timestamp\":\"" + interaction.getTimestamp() - + "\",\"input\":\"input\",\"response\":\"response\",\"origin\":\"origin\"}]}"; + + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"metadata\":\"metadata\"}]}"; log.info(result); log.info(expected); // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/GetInteractionsTransportActionTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/GetInteractionsTransportActionTests.java index 5b15c6b8a1..95877c8b92 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/GetInteractionsTransportActionTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/action/memory/interaction/GetInteractionsTransportActionTests.java @@ -103,7 +103,16 @@ public void setup() throws IOException { public void testGetInteractions_noMorePages() { log.info("test get interactions transport"); - Interaction testInteraction = new Interaction("test-iid", Instant.now(), "test-cid", "test-input", "test-response", "test-origin"); + Interaction testInteraction = new Interaction( + "test-iid", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + "metadata" + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); listener.onResponse(List.of(testInteraction)); @@ -121,7 +130,16 @@ public void testGetInteractions_noMorePages() { public void testGetInteractions_MorePages() { log.info("test get interactions transport"); - Interaction testInteraction = new Interaction("test-iid", Instant.now(), "test-cid", "test-input", "test-response", "test-origin"); + Interaction testInteraction = new Interaction( + "test-iid", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + "metadata" + ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(3); listener.onResponse(List.of(testInteraction)); diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/InteractionsIndexITTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/InteractionsIndexITTests.java index 7d2b0658c1..7550df6590 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/InteractionsIndexITTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/InteractionsIndexITTests.java @@ -91,22 +91,38 @@ public void testCanAddNewInteraction() { CountDownLatch cdl = new CountDownLatch(2); String[] ids = new String[2]; index - .createInteraction("test", "test input", "test response", "test origin", new LatchedActionListener<>(ActionListener.wrap(id -> { - ids[0] = id; - }, e -> { - cdl.countDown(); - log.error(e); - assert (false); - }), cdl)); + .createInteraction( + "test", + "test input", + "pt", + "test response", + "test origin", + "metadata", + new LatchedActionListener<>(ActionListener.wrap(id -> { + ids[0] = id; + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }), cdl) + ); index - .createInteraction("test", "test input", "test response", "test origin", new LatchedActionListener<>(ActionListener.wrap(id -> { - ids[1] = id; - }, e -> { - cdl.countDown(); - log.error(e); - assert (false); - }), cdl)); + .createInteraction( + "test", + "test input", + "pt", + "test response", + "test origin", + "metadata", + new LatchedActionListener<>(ActionListener.wrap(id -> { + ids[1] = id; + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }), cdl) + ); try { cdl.await(); } catch (InterruptedException e) { @@ -122,7 +138,7 @@ public void testGetInteractions() { final String conversation = "test-conversation"; CountDownLatch cdl = new CountDownLatch(1); StepListener id1Listener = new StepListener<>(); - index.createInteraction(conversation, "test input", "test response", "test origin", id1Listener); + index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", id1Listener); StepListener id2Listener = new StepListener<>(); id1Listener.whenComplete(id -> { @@ -130,8 +146,10 @@ public void testGetInteractions() { .createInteraction( conversation, "test input", + "pt", "test response", "test origin", + "metadata", Instant.now().plus(3, ChronoUnit.MINUTES), id2Listener ); @@ -169,7 +187,7 @@ public void testGetInteractionPages() { final String conversation = "test-conversation"; CountDownLatch cdl = new CountDownLatch(1); StepListener id1Listener = new StepListener<>(); - index.createInteraction(conversation, "test input", "test response", "test origin", id1Listener); + index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", id1Listener); StepListener id2Listener = new StepListener<>(); id1Listener.whenComplete(id -> { @@ -177,8 +195,10 @@ public void testGetInteractionPages() { .createInteraction( conversation, "test input1", + "pt", "test response", "test origin", + "metadata", Instant.now().plus(3, ChronoUnit.MINUTES), id2Listener ); @@ -194,8 +214,10 @@ public void testGetInteractionPages() { .createInteraction( conversation, "test input2", + "pt", "test response", "test origin", + "metadata", Instant.now().plus(4, ChronoUnit.MINUTES), id3Listener ); @@ -247,28 +269,40 @@ public void testDeleteConversation() { final String conversation2 = "conversation2"; CountDownLatch cdl = new CountDownLatch(1); StepListener iid1 = new StepListener<>(); - index.createInteraction(conversation1, "test input", "test response", "test origin", iid1); + index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid1); StepListener iid2 = new StepListener<>(); - iid1.whenComplete(r -> { index.createInteraction(conversation1, "test input", "test response", "test origin", iid2); }, e -> { - cdl.countDown(); - log.error(e); - assert (false); - }); + iid1 + .whenComplete( + r -> { index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid2); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); StepListener iid3 = new StepListener<>(); - iid2.whenComplete(r -> { index.createInteraction(conversation2, "test input", "test response", "test origin", iid3); }, e -> { - cdl.countDown(); - log.error(e); - assert (false); - }); + iid2 + .whenComplete( + r -> { index.createInteraction(conversation2, "test input", "pt", "test response", "test origin", "metadata", iid3); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); StepListener iid4 = new StepListener<>(); - iid3.whenComplete(r -> { index.createInteraction(conversation1, "test input", "test response", "test origin", iid4); }, e -> { - cdl.countDown(); - log.error(e); - assert (false); - }); + iid3 + .whenComplete( + r -> { index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid4); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); StepListener deleteListener = new StepListener<>(); iid4.whenComplete(r -> { index.deleteConversation(conversation1, deleteListener); }, e -> { diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/InteractionsIndexTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/InteractionsIndexTests.java index d090f925f1..8cb95b5f57 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/InteractionsIndexTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/InteractionsIndexTests.java @@ -239,7 +239,7 @@ public void testCreate_NoIndex_ThenFail() { setupDoesNotMakeIndex(); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "rsp", "ogn", createInteractionListener); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("no index to add conversation to")); @@ -257,7 +257,7 @@ public void testCreate_BadRestStatus_ThenFail() { }).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "rsp", "ogn", createInteractionListener); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("failed to create conversation")); @@ -273,7 +273,7 @@ public void testCreate_InternalFailure_ThenFail() { }).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "rsp", "ogn", createInteractionListener); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Test Failure")); @@ -285,7 +285,7 @@ public void testCreate_ClientFails_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "rsp", "ogn", createInteractionListener); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Test Client Failure")); @@ -297,7 +297,7 @@ public void testCreate_NoAccessNoUser_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "rsp", "ogn", createInteractionListener); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor @@ -312,7 +312,7 @@ public void testCreate_NoAccessWithUser_ThenFail() { doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "rsp", "ogn", createInteractionListener); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("User [user] does not have access to conversation cid")); @@ -326,7 +326,7 @@ public void testCreate_CreateIndexFails_ThenFail() { }).when(interactionsIndex).initInteractionsIndexIfAbsent(any()); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); - interactionsIndex.createInteraction("cid", "inp", "rsp", "ogn", createInteractionListener); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Fail in Index Creation")); @@ -414,10 +414,10 @@ public void testGetAll_BadMaxResults_ThenFail() { public void testGetAll_Recursion() { List interactions = List .of( - new Interaction("iid1", Instant.now(), "cid", "inp", "rsp", "ogn"), - new Interaction("iid2", Instant.now(), "cid", "inp", "rsp", "ogn"), - new Interaction("iid3", Instant.now(), "cid", "inp", "rsp", "ogn"), - new Interaction("iid4", Instant.now(), "cid", "inp", "rsp", "ogn") + new Interaction("iid1", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), + new Interaction("iid2", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), + new Interaction("iid3", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), + new Interaction("iid4", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta") ); doAnswer(invocation -> { ActionListener> al = invocation.getArgument(3); diff --git a/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/OpenSearchConversationalMemoryHandlerTests.java b/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/OpenSearchConversationalMemoryHandlerTests.java index 385774ed81..21c9dafe9d 100644 --- a/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/conversational-memory/src/test/java/org/opensearch/ml/conversational/index/OpenSearchConversationalMemoryHandlerTests.java @@ -76,21 +76,32 @@ public void testCreateConversation_Named_FutureSucess() { public void testCreateInteraction_Future() { doAnswer(invocation -> { - ActionListener al = invocation.getArgument(5); + ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), any(), any()); - ActionFuture result = cmHandler.createInteraction("cid", "inp", "rsp", "ogn"); + }) + .when(interactionsIndex) + .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + ActionFuture result = cmHandler.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta"); assert (result.actionGet(200).equals("iid")); } public void testCreateInteraction_FromBuilder_Success() { doAnswer(invocation -> { - ActionListener al = invocation.getArgument(5); + ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), any(), any()); - InteractionBuilder builder = Interaction.builder().conversationId("cid").input("inp").origin("origin").response("rsp"); + }) + .when(interactionsIndex) + .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + InteractionBuilder builder = Interaction + .builder() + .conversationId("cid") + .input("inp") + .origin("origin") + .response("rsp") + .promptTemplate("pt") + .metadata("meta"); @SuppressWarnings("unchecked") ActionListener createInteractionListener = mock(ActionListener.class); cmHandler.createInteraction(builder, createInteractionListener); @@ -101,11 +112,20 @@ public void testCreateInteraction_FromBuilder_Success() { public void testCreateInteraction_FromBuilder_Future() { doAnswer(invocation -> { - ActionListener al = invocation.getArgument(5); + ActionListener al = invocation.getArgument(7); al.onResponse("iid"); return null; - }).when(interactionsIndex).createInteraction(anyString(), anyString(), anyString(), anyString(), any(), any()); - InteractionBuilder builder = Interaction.builder().origin("ogn").conversationId("cid").input("inp").response("rsp"); + }) + .when(interactionsIndex) + .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + InteractionBuilder builder = Interaction + .builder() + .origin("ogn") + .conversationId("cid") + .input("inp") + .response("rsp") + .promptTemplate("pt") + .metadata("meta"); ActionFuture result = cmHandler.createInteraction(builder); assert (result.actionGet(200).equals("iid")); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalCreateInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalCreateInteractionActionIT.java index a5a04d5642..b6c7812677 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalCreateInteractionActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalCreateInteractionActionIT.java @@ -45,7 +45,11 @@ public void testCreateInteraction() throws IOException { ActionConstants.AI_RESPONSE_FIELD, "response", ActionConstants.RESPONSE_ORIGIN_FIELD, - "origin" + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.METADATA_FIELD, + "some metadata" ); Response response = TestHelper .makeRequest(client(), "POST", ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", id), params, "", null); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalDeleteConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalDeleteConversationActionIT.java index 5eb28adcae..1b204e1e5e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalDeleteConversationActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalDeleteConversationActionIT.java @@ -93,7 +93,11 @@ public void testDeleteConversation_WithInteractions() throws IOException { ActionConstants.AI_RESPONSE_FIELD, "response", ActionConstants.RESPONSE_ORIGIN_FIELD, - "origin" + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.METADATA_FIELD, + "some metadata" ); Response ciresponse = TestHelper .makeRequest( diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalGetInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalGetInteractionsActionIT.java index 239345b68d..7d53d98111 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalGetInteractionsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestConversationalGetInteractionsActionIT.java @@ -88,7 +88,11 @@ public void testGetInteractions_LastPage() throws IOException { ActionConstants.AI_RESPONSE_FIELD, "response", ActionConstants.RESPONSE_ORIGIN_FIELD, - "origin" + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.METADATA_FIELD, + "some metadata" ); Response response = TestHelper .makeRequest( @@ -139,7 +143,11 @@ public void testGetInteractions_MorePages() throws IOException { ActionConstants.AI_RESPONSE_FIELD, "response", ActionConstants.RESPONSE_ORIGIN_FIELD, - "origin" + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.METADATA_FIELD, + "some metadata" ); Response response = TestHelper .makeRequest( @@ -198,7 +206,11 @@ public void testGetInteractions_NextPage() throws IOException { ActionConstants.AI_RESPONSE_FIELD, "response", ActionConstants.RESPONSE_ORIGIN_FIELD, - "origin" + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.METADATA_FIELD, + "some metadata" ); Response response = TestHelper .makeRequest(