Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add new fields in the memory and refactor transport actions #1619

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ public class ActionConstants {
public final static String PROMPT_TEMPLATE_FIELD = "prompt_template";
/** name of metadata field in all requests */
public final static String ADDITIONAL_INFO_FIELD = "additional_info";
/** name of metadata field in all requests */
public final static String PARENT_INTERACTION_ID_FIELD = "parent_interaction_id";
/** name of metadata field in all requests */
public final static String TRACE_NUMBER_FIELD = "trace_number";
/** name of success field in all requests */
public final static String SUCCESS_FIELD = "success";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public class ConversationalIndexConstants {
public final static String META_NAME_FIELD = "name";
/** Name of the owning user field in all indices */
public final static String USER_FIELD = "user";
/** Name of the application that created this conversation */
public final static String APPLICATION_TYPE_FIELD = "application_type";
/** Mappings for the conversational metadata index */
public final static String META_MAPPING = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -47,6 +49,9 @@ public class ConversationalIndexConstants {
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ USER_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ APPLICATION_TYPE_FIELD
+ "\": {\"type\": \"keyword\"}\n"
+ " }\n"
+ "}";
Expand All @@ -69,6 +74,10 @@ public class ConversationalIndexConstants {
public final static String INTERACTIONS_ADDITIONAL_INFO_FIELD = "additional_info";
/** Name of the interaction field for the timestamp */
public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time";
/** Name of the interaction id */
public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_interaction_id";
/** The trace number of an interaction */
public final static String INTERACTIONS_TRACE_NUMBER_FIELD = "trace_number";
/** Mappings for the interactions index */
public final static String INTERACTIONS_MAPPINGS = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -95,7 +104,13 @@ public class ConversationalIndexConstants {
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_ADDITIONAL_INFO_FIELD
+ "\": {\"type\": \"text\"}\n"
+ "\": {\"type\": \"flat_object\"},\n"
+ " \""
+ PARENT_INTERACTIONS_ID_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_TRACE_NUMBER_FIELD
+ "\": {\"type\": \"long\"}\n"
+ " }\n"
+ "}";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -54,7 +55,7 @@ public class Interaction implements Writeable, ToXContentObject {
@Getter
private String origin;
@Getter
private String additionalInfo;
private Map<String, String> additionalInfo;

/**
* Creates an Interaction object from a map of fields in the OS index
Expand All @@ -69,7 +70,7 @@ public static Interaction fromMap(String id, Map<String, Object> fields) {
String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD);
String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD);
String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD);
String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
Map<String,String> additionalInfo = (Map<String,String>) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
}

Expand Down Expand Up @@ -97,7 +98,10 @@ public static Interaction fromStream(StreamInput in) throws IOException {
String promptTemplate = in.readString();
String response = in.readString();
String origin = in.readString();
String additionalInfo = in.readOptionalString();
Map<String, String> additionalInfo = new HashMap<>();
if (in.readBoolean()) {
additionalInfo = in.readMap(s -> s.readString(), s -> s.readString());
}
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
}

Expand All @@ -111,7 +115,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(promptTemplate);
out.writeString(response);
out.writeString(origin);
out.writeOptionalString(additionalInfo);
if (additionalInfo != null) {
out.writeBoolean(true);
out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.opensearch.ml.memory;

import java.util.List;
import java.util.Map;

import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -49,6 +50,14 @@ public interface ConversationalMemoryHandler {
*/
public void createConversation(String name, ActionListener<String> listener);

/**
* Create a new conversation
* @param name the name of the new conversation
* @param applicationType the application that creates this conversation
* @param listener listener to wait for this op to finish, gets unique id of new conversation
*/
public void createConversation(String name, String applicationType, ActionListener<String> listener);

/**
* Create a new conversation
* @param name the name of the new conversation
Expand All @@ -72,10 +81,34 @@ public void createInteraction(
String promptTemplate,
String response,
String origin,
String additionalInfo,
Map<String, String> additionalInfo,
ActionListener<String> listener
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
* @param input the human input for the interaction
* @param promptTemplate the prompt template used for this interaction
* @param response the Gen AI response for this interaction
* @param origin the name of the GenAI agent in this interaction
* @param additionalInfo additional information used in constructing the LLM prompt
* @param interactionId the parent interactionId of this interaction
* @param traceNumber the trace number for a parent interaction
* @param listener gets the ID of the new interaction
*/
public void createInteraction(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
Map<String, String> additionalInfo,
ActionListener<String> listener,
String interactionId,
Integer traceNumber
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
Expand All @@ -92,7 +125,7 @@ public ActionFuture<String> createInteraction(
String promptTemplate,
String response,
String origin,
String additionalInfo
Map<String, String> additionalInfo
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
public class CreateConversationRequest extends ActionRequest {
@Getter
private String name = null;
@Getter
private String applicationType = null;

/**
* Constructor
Expand All @@ -44,6 +46,7 @@ public class CreateConversationRequest extends ActionRequest {
public CreateConversationRequest(StreamInput in) throws IOException {
super(in);
this.name = in.readOptionalString();
this.applicationType = in.readOptionalString();
}

/**
Expand All @@ -55,6 +58,16 @@ public CreateConversationRequest(String name) {
this.name = name;
}

/**
* Constructor
* @param name name of the conversation
*/
public CreateConversationRequest(String name, String applicationType) {
super();
this.name = name;
this.applicationType = applicationType;
}

/**
* Constructor
* name will be null
Expand All @@ -65,6 +78,7 @@ public CreateConversationRequest() {}
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(name);
out.writeOptionalString(applicationType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
return;
}
String name = request.getName();
String applicationType = request.getApplicationType();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateConversationResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> {
Expand All @@ -92,7 +93,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
if (name == null) {
cmHandler.createConversation(al);
} else {
cmHandler.createConversation(name, al);
cmHandler.createConversation(name, applicationType, al);
}
} catch (Exception e) {
log.error("Failed to create new conversation with name " + request.getName(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

Expand All @@ -48,7 +52,27 @@ public class CreateInteractionRequest extends ActionRequest {
@Getter
private String origin;
@Getter
private String additionalInfo;
private Map<String, String> additionalInfo;
@Getter
private String parent_interaction_id;
@Getter
private Integer trace_number;

public CreateInteractionRequest(
String conversationId,
String input,
String promptTemplate,
String response,
String origin,
Map<String, String> additionalInfo
) {
this.conversationId = conversationId;
this.input = input;
this.promptTemplate = promptTemplate;
this.response = response;
this.origin = origin;
this.additionalInfo = additionalInfo;
}

/**
* Constructor
Expand All @@ -62,7 +86,11 @@ public CreateInteractionRequest(StreamInput in) throws IOException {
this.promptTemplate = in.readString();
this.response = in.readString();
this.origin = in.readOptionalString();
this.additionalInfo = in.readOptionalString();
if (in.readBoolean()) {
this.additionalInfo = in.readMap(s -> s.readString(), s -> s.readString());
}
this.parent_interaction_id = in.readOptionalString();
this.trace_number = in.readOptionalInt();
}

@Override
Expand All @@ -73,7 +101,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(promptTemplate);
out.writeString(response);
out.writeOptionalString(origin);
out.writeOptionalString(additionalInfo);
if (additionalInfo != null) {
out.writeBoolean(true);
out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(parent_interaction_id);
out.writeOptionalInt(trace_number);
}

@Override
Expand All @@ -92,14 +127,52 @@ public ActionRequestValidationException validate() {
* @throws IOException if something goes wrong reading from request
*/
public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException {
Map<String, String> body = request.contentParser().mapStrings();
String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD);
String inp = body.get(ActionConstants.INPUT_FIELD);
String prmpt = body.get(ActionConstants.PROMPT_TEMPLATE_FIELD);
String rsp = body.get(ActionConstants.AI_RESPONSE_FIELD);
String ogn = body.get(ActionConstants.RESPONSE_ORIGIN_FIELD);
String addinf = body.get(ActionConstants.ADDITIONAL_INFO_FIELD);
return new CreateInteractionRequest(cid, inp, prmpt, rsp, ogn, addinf);
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);

String input = null;
String prompt = null;
String rep = null;
String origin = null;
Map<String, String> addinf = new HashMap<>();
String parintid = null;
Integer tracenum = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case ActionConstants.INPUT_FIELD:
input = parser.text();
break;
case ActionConstants.PROMPT_TEMPLATE_FIELD:
prompt = parser.text();
break;
case ActionConstants.AI_RESPONSE_FIELD:
rep = parser.text();
break;
case ActionConstants.RESPONSE_ORIGIN_FIELD:
origin = parser.text();
break;
case ActionConstants.ADDITIONAL_INFO_FIELD:
addinf = getParameterMap(parser.map());
break;
case ActionConstants.PARENT_INTERACTION_ID_FIELD:
parintid = parser.text();
break;
case ActionConstants.TRACE_NUMBER_FIELD:
tracenum = parser.intValue(false);
break;
default:
parser.skipChildren();
break;
}
}

return new CreateInteractionRequest(cid, input, prompt, rep, origin, addinf, parintid, tracenum);
}

}
Loading
Loading