Skip to content

Commit

Permalink
init commit for adding additional info for memory metadata (#2750)
Browse files Browse the repository at this point in the history
create conversation support additional info



add test for search conversation



add bwc

Signed-off-by: Hailong Cui <[email protected]>
  • Loading branch information
Hailong-am authored Aug 5, 2024
1 parent 90f3f3d commit 920685d
Show file tree
Hide file tree
Showing 18 changed files with 289 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -541,4 +541,5 @@ public class CommonValue {
public static final Version VERSION_2_13_0 = Version.fromString("2.13.0");
public static final Version VERSION_2_14_0 = Version.fromString("2.14.0");
public static final Version VERSION_2_16_0 = Version.fromString("2.16.0");
public static final Version VERSION_2_17_0 = Version.fromString("2.17.0");
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,26 @@
import java.util.Map;
import java.util.Objects;

import org.opensearch.action.index.IndexRequest;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.search.SearchHit;

import lombok.AllArgsConstructor;
import lombok.Getter;

import static org.opensearch.ml.common.CommonValue.VERSION_2_17_0;

/**
* Class for holding conversational metadata
*/
@AllArgsConstructor
public class ConversationMeta implements Writeable, ToXContentObject {

public static final Version MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO = CommonValue.VERSION_2_17_0;
@Getter
private String id;
@Getter
Expand All @@ -49,6 +52,8 @@ public class ConversationMeta implements Writeable, ToXContentObject {
private String name;
@Getter
private String user;
@Getter
private Map<String, String> additionalInfos;

/**
* Creates a conversationMeta object from a SearchHit object
Expand All @@ -71,7 +76,8 @@ public static ConversationMeta fromMap(String id, Map<String, Object> docFields)
Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_TIME_FIELD));
String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD);
String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD);
return new ConversationMeta(id, created, updated, name, user);
Map<String, String> additionalInfos = (Map<String, String>)docFields.get(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD);
return new ConversationMeta(id, created, updated, name, user, additionalInfos);
}

/**
Expand All @@ -87,7 +93,13 @@ public static ConversationMeta fromStream(StreamInput in) throws IOException {
Instant updated = in.readInstant();
String name = in.readString();
String user = in.readOptionalString();
return new ConversationMeta(id, created, updated, name, user);
Map<String, String> additionalInfos = null;
if (in.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) {
if (in.readBoolean()) {
additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString);
}
}
return new ConversationMeta(id, created, updated, name, user, additionalInfos);
}

@Override
Expand All @@ -97,6 +109,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInstant(updatedTime);
out.writeString(name);
out.writeOptionalString(user);
if(out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) {
if (additionalInfos == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeMap(additionalInfos, StreamOutput::writeString, StreamOutput::writeString);
}
}
}

@Override
Expand All @@ -119,6 +139,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
if(this.user != null) {
builder.field(ConversationalIndexConstants.USER_FIELD, this.user);
}
if (this.additionalInfos != null) {
builder.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, this.additionalInfos);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
*/
public class ConversationalIndexConstants {
/** Version of the meta index schema */
public final static Integer META_INDEX_SCHEMA_VERSION = 1;
public final static Integer META_INDEX_SCHEMA_VERSION = 2;
/** Name of the conversational metadata index */
public final static String META_INDEX_NAME = ".plugins-ml-memory-meta";
/** Name of the metadata field for initial timestamp */
Expand All @@ -37,6 +37,9 @@ public class ConversationalIndexConstants {
public final static String USER_FIELD = "user";
/** Name of the application that created this conversation */
public final static String APPLICATION_TYPE_FIELD = "application_type";
/** Name of the additional information for this memory */
public final static String META_ADDITIONAL_INFO_FIELD = "additional_info";

/** Mappings for the conversational metadata index */
public final static String META_MAPPING = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -57,7 +60,10 @@ public class ConversationalIndexConstants {
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ APPLICATION_TYPE_FIELD
+ "\": {\"type\": \"keyword\"}\n"
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ META_ADDITIONAL_INFO_FIELD
+ "\": {\"type\": \"flat_object\"}\n"
+ " }\n"
+ "}";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

public class ConversationMetaTests {
Expand All @@ -30,7 +31,7 @@ public class ConversationMetaTests {
@Before
public void setUp() {
time = Instant.now();
conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin");
conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin", null);
}

@Test
Expand All @@ -41,6 +42,7 @@ public void test_fromSearchHit() throws IOException {
content.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, time);
content.field(ConversationalIndexConstants.META_NAME_FIELD, "meta name");
content.field(ConversationalIndexConstants.USER_FIELD, "admin");
content.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, Map.of("test_key", "test_value"));
content.endObject();

SearchHit[] hits = new SearchHit[1];
Expand All @@ -50,6 +52,7 @@ public void test_fromSearchHit() throws IOException {
assertEquals(conversationMeta.getId(), "cId");
assertEquals(conversationMeta.getName(), "meta name");
assertEquals(conversationMeta.getUser(), "admin");
assertEquals(conversationMeta.getAdditionalInfos().get("test_key"), "test_value");
}

@Test
Expand Down Expand Up @@ -85,7 +88,7 @@ public void test_fromStream() throws IOException {

@Test
public void test_ToXContent() throws IOException {
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null);
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
conversationMeta.toXContent(builder, EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Expand All @@ -94,13 +97,13 @@ public void test_ToXContent() throws IOException {

@Test
public void test_toString() {
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null);
assertEquals("{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", conversationMeta.toString());
}

@Test
public void test_equal() {
ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null);
assertEquals(meta.equals(conversationMeta), false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ public interface ConversationalMemoryHandler {
*/
public void createConversation(String name, String applicationType, ActionListener<String> listener);

/**
* Create a new conversation
* @param name the name of the new conversation
* @param applicationType the application that creates this conversation
* @param additionalInfos additional information associated with this conversation
* @param listener listener to wait for this op to finish, gets unique id of new conversation
*/
public void createConversation(
String name,
String applicationType,
Map<String, String> additionalInfos,
ActionListener<String> listener
);

/**
* Adds an interaction to the conversation indicated, updating the conversational metadata
* @param conversationId the conversation to add the interaction to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD;

import java.io.IOException;
import java.util.Map;

import org.opensearch.OpenSearchParseException;
import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

Expand All @@ -36,10 +40,14 @@
* Action Request for creating a conversation
*/
public class CreateConversationRequest extends ActionRequest {
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO = CommonValue.VERSION_2_17_0;

@Getter
private String name = null;
@Getter
private String applicationType = null;
@Getter
private Map<String, String> additionalInfos = null;

/**
* Constructor
Expand All @@ -50,6 +58,11 @@ public CreateConversationRequest(StreamInput in) throws IOException {
super(in);
this.name = in.readOptionalString();
this.applicationType = in.readOptionalString();
if (in.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) {
if (in.readBoolean()) {
this.additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString);
}
}
}

/**
Expand All @@ -71,6 +84,19 @@ public CreateConversationRequest(String name, String applicationType) {
this.applicationType = applicationType;
}

/**
* Constructor
* @param name name of the conversation
* @param applicationType of the conversation
* @param additionalInfos information of the conversation
*/
public CreateConversationRequest(String name, String applicationType, Map<String, String> additionalInfos) {
super();
this.name = name;
this.applicationType = applicationType;
this.additionalInfos = additionalInfos;
}

/**
* Constructor
* name will be null
Expand All @@ -82,6 +108,14 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(name);
out.writeOptionalString(applicationType);
if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) {
if (additionalInfos == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeMap(additionalInfos, StreamOutput::writeString, StreamOutput::writeString);
}
}
}

@Override
Expand All @@ -101,12 +135,13 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest)
if (!restRequest.hasContent()) {
return new CreateConversationRequest();
}
try {
Map<String, String> body = restRequest.contentParser().mapStrings();
if (body.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) {
try (XContentParser parser = restRequest.contentParser()) {
Map<String, Object> body = parser.map();
if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) {
return new CreateConversationRequest(
body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD)
(String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD),
body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD)
);
} else {
return new CreateConversationRequest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE;

import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand Down Expand Up @@ -79,6 +81,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
}
String name = request.getName();
String applicationType = request.getApplicationType();
Map<String, String> additionalInfos = request.getAdditionalInfos();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateConversationResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> {
Expand All @@ -89,7 +92,7 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
if (name == null) {
cmHandler.createConversation(al);
} else {
cmHandler.createConversation(name, applicationType, al);
cmHandler.createConversation(name, applicationType, additionalInfos, al);
}
} catch (Exception e) {
log.error("Failed to create new memory with name " + request.getName(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD;

import java.io.ByteArrayInputStream;
Expand Down Expand Up @@ -35,7 +36,7 @@ public class UpdateConversationRequest extends ActionRequest {
private String conversationId;
private Map<String, Object> updateContent;

private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD));
private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD, META_ADDITIONAL_INFO_FIELD));

@Builder
public UpdateConversationRequest(String conversationId, Map<String, Object> updateContent) {
Expand Down
Loading

0 comments on commit 920685d

Please sign in to comment.