Skip to content

Commit

Permalink
re-add prompt temlplate and metadata fields at interaction level
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Aug 29, 2023
1 parent bd0ff99 commit 65dda47
Show file tree
Hide file tree
Showing 20 changed files with 347 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ public class ActionConstants {
public final static String AI_RESPONSE_FIELD = "response";
/** name of origin field in all requests */
public final static String RESPONSE_ORIGIN_FIELD = "origin";
/** name of prompt template field in all requests */
public final static String PROMPT_TEMPLATE_FIELD = "prompt_template";
/** name of metadata field in all requests */
public final static String METADATA_FIELD = "metadata";
/** name of success field in all requests */
public final static String SUCCESS_FIELD = "success";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ public class ConversationalIndexConstants {
public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "conversation_id";
/** Name of the interaction field for the human input */
public final static String INTERACTIONS_INPUT_FIELD = "input";
/** Name of the interaction field for the prompt template */
public final static String INTERACTIONS_PROMPT_TEMPLATE_FIELD = "prompt_template";
/** Name of the interaction field for the AI response */
public final static String INTERACTIONS_RESPONSE_FIELD = "response";
/** Name of the interaction field for the response's origin */
public final static String INTERACTIONS_ORIGIN_FIELD = "origin";
/** Name of the interaction field for additional metadata */
public final static String INTERACTIONS_METADATA_FIELD = "metadata";
/** Name of the interaction field for the timestamp */
public final static String INTERACTIONS_TIMESTAMP_FIELD = "timestamp";
/** Mappings for the interactions index */
Expand All @@ -79,12 +83,18 @@ public class ConversationalIndexConstants {
+ INTERACTIONS_INPUT_FIELD
+ "\": {\"type\": \"text\"},\n"
+ " \""
+ INTERACTIONS_PROMPT_TEMPLATE_FIELD
+ "\": {\"type\": \"text\"},\n"
+ " \""
+ INTERACTIONS_RESPONSE_FIELD
+ "\": {\"type\": \"text\"},\n"
+ " \""
+ INTERACTIONS_ORIGIN_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_METADATA_FIELD
+ "\": {\"type\": \"text\"},\n"
+ " \""
+ USER_FIELD
+ "\": {\"type\": \"keyword\"}\n"
+ " }\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ public class Interaction implements Writeable, ToXContentObject {
@Getter
private String input;
@Getter
private String promptTemplate;
@Getter
private String response;
@Getter
private String origin;
@Getter
private String metadata;

/**
* Creates an Interaction object from a map of fields in the OS index
Expand All @@ -62,9 +66,11 @@ public static Interaction fromMap(String id, Map<String, Object> fields) {
Instant timestamp = Instant.parse((String) fields.get(ConversationalIndexConstants.INTERACTIONS_TIMESTAMP_FIELD));
String conversationId = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD);
String input = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD);
String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD);
String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD);
String agent = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD);
return new Interaction(id, timestamp, conversationId, input, response, agent);
String metadata = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_METADATA_FIELD);
return new Interaction(id, timestamp, conversationId, input, promptTemplate, response, agent, metadata);
}

/**
Expand All @@ -88,9 +94,11 @@ public static Interaction fromStream(StreamInput in) throws IOException {
Instant timestamp = in.readInstant();
String conversationId = in.readString();
String input = in.readString();
String promptTemplate = in.readString();
String response = in.readString();
String origin = in.readString();
return new Interaction(id, timestamp, conversationId, input, response, origin);
String metadata = in.readOptionalString();
return new Interaction(id, timestamp, conversationId, input, promptTemplate, response, origin, metadata);
}


Expand All @@ -100,8 +108,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInstant(timestamp);
out.writeString(conversationId);
out.writeString(input);
out.writeString(promptTemplate);
out.writeString(response);
out.writeString(origin);
out.writeOptionalString(metadata);
}

@Override
Expand All @@ -111,8 +121,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id);
builder.field(ConversationalIndexConstants.INTERACTIONS_TIMESTAMP_FIELD, timestamp);
builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
if(metadata != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_METADATA_FIELD, metadata);
}
builder.endObject();
return builder;
}
Expand All @@ -125,8 +139,11 @@ public boolean equals(Object other) {
((Interaction) other).conversationId.equals(this.conversationId) &&
((Interaction) other).timestamp.equals(this.timestamp) &&
((Interaction) other).input.equals(this.input) &&
((Interaction) other).promptTemplate.equals(this.promptTemplate) &&
((Interaction) other).response.equals(this.response) &&
((Interaction) other).origin.equals(this.origin)
((Interaction) other).origin.equals(this.origin) &&
( (((Interaction) other).metadata == null && this.metadata == null) ||
((Interaction) other).metadata.equals(this.metadata))
);
}

Expand All @@ -138,7 +155,9 @@ public String toString() {
+ ",timestamp=" + timestamp
+ ",origin=" + origin
+ ",input=" + input
+ ",promt_template=" + promptTemplate
+ ",response=" + response
+ ",metadata=" + metadata
+ "}";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,40 @@ public interface ConversationalMemoryHandler {
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
* @param input the human input for the interaction
* @param promptTemplate the prompt template used for this interaction
* @param response the Gen AI response for this interaction
* @param origin the name of the GenAI agent in this interaction
* @param metadata additional inofrmation used in constructing the LLM prompt
* @param listener gets the ID of the new interaction
*/
public void createInteraction(String conversationId, String input, String response, String origin, ActionListener<String> listener);
public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
String metadata,
ActionListener<String> listener
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
* @param input the human input for the interaction
* @param promptTemplate the prompt template used in this interaction
* @param response the Gen AI response for this interaction
* @param origin the name of the GenAI agent in this interaction
* @param metadata arbitrary JSON string of extra stuff
* @return ActionFuture for the interactionId of the new interaction
*/
public ActionFuture<String> createInteraction(String conversationId, String input, String response, String origin);
public ActionFuture<String> createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
String metadata
);

/**
* Adds an interaction to the index, updating the associated Conversational Metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@ public class CreateInteractionRequest extends ActionRequest {
@Getter
private String input;
@Getter
private String promptTemplate;
@Getter
private String response;
@Getter
private String origin;
@Getter
private String metadata;

/**
* Constructor
Expand All @@ -54,17 +58,21 @@ public CreateInteractionRequest(StreamInput in) throws IOException {
super(in);
this.conversationId = in.readString();
this.input = in.readString();
this.promptTemplate = in.readString();
this.response = in.readString();
this.origin = in.readOptionalString();
this.metadata = in.readOptionalString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(conversationId);
out.writeString(input);
out.writeString(promptTemplate);
out.writeString(response);
out.writeOptionalString(origin);
out.writeOptionalString(metadata);
}

@Override
Expand All @@ -85,9 +93,11 @@ public ActionRequestValidationException validate() {
public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException {
String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String inp = request.param(ActionConstants.INPUT_FIELD);
String prmpt = request.param(ActionConstants.PROMPT_TEMPLATE_FIELD);
String rsp = request.param(ActionConstants.AI_RESPONSE_FIELD);
String ogn = request.param(ActionConstants.RESPONSE_ORIGIN_FIELD);
return new CreateInteractionRequest(cid, inp, rsp, ogn);
String metadata = request.param(ActionConstants.METADATA_FIELD);
return new CreateInteractionRequest(cid, inp, prmpt, rsp, ogn, metadata);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ protected void doExecute(Task task, CreateInteractionRequest request, ActionList
String inp = request.getInput();
String rsp = request.getResponse();
String ogn = request.getOrigin();
String prompt = request.getPromptTemplate();
String metadata = request.getMetadata();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateInteractionResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener
.wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> {
internalListener.onFailure(e);
});
cmHandler.createInteraction(cid, inp, rsp, ogn, al);
cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, metadata, al);
} catch (Exception e) {
log.error(e.toString());
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,20 @@ public void initInteractionsIndexIfAbsent(ActionListener<Boolean> listener) {
* Add an interaction to this index. Return the ID of the newly created interaction
* @param conversationId The id of the conversation this interaction belongs to
* @param input the user (human) input into this interaction
* @param promptTemplate the prompt template used for this interaction
* @param response the GenAI response for this interaction
* @param origin the origin of the response for this interaction
* @param metadata additional information used for constructing the LLM prompt
* @param timestamp when this interaction happened
* @param listener gets the id of the newly created interaction record
*/
public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
String metadata,
Instant timestamp,
ActionListener<String> listener
) {
Expand All @@ -138,8 +142,12 @@ public void createInteraction(
conversationId,
ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD,
input,
ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD,
promptTemplate,
ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD,
response,
ConversationalIndexConstants.INTERACTIONS_METADATA_FIELD,
metadata,
ConversationalIndexConstants.INTERACTIONS_TIMESTAMP_FIELD,
timestamp
);
Expand Down Expand Up @@ -177,12 +185,22 @@ public void createInteraction(
* Add an interaction to this index, timestamped now. Return the id of the newly created interaction
* @param conversationId The id of the converation this interaction belongs to
* @param input the user (human) input into this interaction
* @param promptTemplate the prompt template used for this interaction
* @param response the GenAI response for this interaction
* @param origin the name of the GenAI agent this interaction belongs to
* @param metadata additional information used to construct the LLM prompt
* @param listener gets the id of the newly created interaction record
*/
public void createInteraction(String conversationId, String input, String response, String origin, ActionListener<String> listener) {
createInteraction(conversationId, input, response, origin, Instant.now(), listener);
public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
String metadata,
ActionListener<String> listener
) {
createInteraction(conversationId, input, promptTemplate, response, origin, metadata, Instant.now(), listener);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,28 +99,45 @@ public ActionFuture<String> createConversation(String name) {
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
* @param input the human input for the interaction
* @param promptTemplate the prompt template used for this interaction
* @param response the Gen AI response for this interaction
* @param origin the name of the GenAI agent in this interaction
* @param metadata additional inofrmation used in constructing the LLM prompt
* @param listener gets the ID of the new interaction
*/
public void createInteraction(String conversationId, String input, String response, String origin, ActionListener<String> listener) {
public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
String metadata,
ActionListener<String> listener
) {
Instant time = Instant.now();
interactionsIndex.createInteraction(conversationId, input, response, origin, time, listener);
interactionsIndex.createInteraction(conversationId, input, promptTemplate, response, origin, metadata, time, listener);
}

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
* @param input the human input for the interaction
* @param prompt the prompt template used in this interaction
* @param promptTemplate the prompt template used in this interaction
* @param response the Gen AI response for this interaction
* @param agent the name of the GenAI agent in this interaction
* @param origin the name of the GenAI agent in this interaction
* @param metadata arbitrary JSON string of extra stuff
* @return ActionFuture for the interactionId of the new interaction
*/
public ActionFuture<String> createInteraction(String conversationId, String input, String response, String origin) {
public ActionFuture<String> createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
String metadata
) {
PlainActionFuture<String> fut = PlainActionFuture.newFuture();
createInteraction(conversationId, input, response, origin, fut);
createInteraction(conversationId, input, promptTemplate, response, origin, metadata, fut);
return fut;
}

Expand All @@ -136,8 +153,10 @@ public void createInteraction(InteractionBuilder builder, ActionListener<String>
.createInteraction(
interaction.getConversationId(),
interaction.getInput(),
interaction.getPromptTemplate(),
interaction.getResponse(),
interaction.getOrigin(),
interaction.getMetadata(),
interaction.getTimestamp(),
listener
);
Expand Down
Loading

0 comments on commit 65dda47

Please sign in to comment.