Skip to content

Commit

Permalink
add new data fields in the memory layer and update tests (opensearch-…
Browse files Browse the repository at this point in the history
…project#1730)

* add new data fields in the memory layer and update tests

Signed-off-by: Xun Zhang <[email protected]>

* add more tests coverage and address comments

Signed-off-by: Xun Zhang <[email protected]>

* address comments and more tests

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored and austintlee committed Feb 29, 2024
1 parent 7451b31 commit 85b372d
Show file tree
Hide file tree
Showing 39 changed files with 1,564 additions and 293 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ public class ActionConstants {
public final static String PROMPT_TEMPLATE_FIELD = "prompt_template";
/** name of metadata field in all requests */
public final static String ADDITIONAL_INFO_FIELD = "additional_info";
/** name of metadata field in all requests */
public final static String PARENT_INTERACTION_ID_FIELD = "parent_interaction_id";
/** name of metadata field in all requests */
public final static String TRACE_NUMBER_FIELD = "trace_number";
/** name of success field in all requests */
public final static String SUCCESS_FIELD = "success";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public class ConversationMeta implements Writeable, ToXContentObject {
@Getter
private Instant createdTime;
@Getter
private Instant updatedTime;
@Getter
private String name;
@Getter
private String user;
Expand All @@ -65,10 +67,11 @@ public static ConversationMeta fromSearchHit(SearchHit hit) {
* @return a new conversationMeta object representing the map
*/
public static ConversationMeta fromMap(String id, Map<String, Object> docFields) {
Instant created = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_CREATED_FIELD));
Instant created = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_CREATED_TIME_FIELD));
Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_TIME_FIELD));
String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD);
String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD);
return new ConversationMeta(id, created, name, user);
return new ConversationMeta(id, created, updated, name, user);
}

/**
Expand All @@ -81,38 +84,27 @@ public static ConversationMeta fromMap(String id, Map<String, Object> docFields)
public static ConversationMeta fromStream(StreamInput in) throws IOException {
String id = in.readString();
Instant created = in.readInstant();
Instant updated = in.readInstant();
String name = in.readString();
String user = in.readOptionalString();
return new ConversationMeta(id, created, name, user);
return new ConversationMeta(id, created, updated, name, user);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(id);
out.writeInstant(createdTime);
out.writeInstant(updatedTime);
out.writeString(name);
out.writeOptionalString(user);
}


/**
* Convert this conversationMeta object into an IndexRequest so it can be indexed
* @param index the index to send this conversation to. Should usually be .conversational-meta
* @return the IndexRequest for the client to send
*/
public IndexRequest toIndexRequest(String index) {
IndexRequest request = new IndexRequest(index);
return request.id(this.id).source(
ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime,
ConversationalIndexConstants.META_NAME_FIELD, this.name
);
}

@Override
public String toString() {
return "{id=" + id
+ ", name=" + name
+ ", created=" + createdTime.toString()
+ ", updated=" + updatedTime.toString()
+ ", user=" + user
+ "}";
}
Expand All @@ -121,7 +113,8 @@ public String toString() {
public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException {
builder.startObject();
builder.field(ActionConstants.CONVERSATION_ID_FIELD, this.id);
builder.field(ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime);
builder.field(ConversationalIndexConstants.META_CREATED_TIME_FIELD, this.createdTime);
builder.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, this.updatedTime);
builder.field(ConversationalIndexConstants.META_NAME_FIELD, this.name);
if(this.user != null) {
builder.field(ConversationalIndexConstants.USER_FIELD, this.user);
Expand All @@ -137,9 +130,10 @@ public boolean equals(Object other) {
}
ConversationMeta otherConversation = (ConversationMeta) other;
return Objects.equals(this.id, otherConversation.id) &&
Objects.equals(this.user, otherConversation.user) &&
Objects.equals(this.createdTime, otherConversation.createdTime) &&
Objects.equals(this.name, otherConversation.name);
Objects.equals(this.user, otherConversation.user) &&
Objects.equals(this.createdTime, otherConversation.createdTime) &&
Objects.equals(this.updatedTime, otherConversation.updatedTime) &&
Objects.equals(this.name, otherConversation.name);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ public class ConversationalIndexConstants {
/** Name of the conversational metadata index */
public final static String META_INDEX_NAME = ".plugins-ml-conversation-meta";
/** Name of the metadata field for initial timestamp */
public final static String META_CREATED_FIELD = "create_time";
public final static String META_CREATED_TIME_FIELD = "create_time";
/** Name of the metadata field for updated timestamp */
public final static String META_UPDATED_TIME_FIELD = "updated_time";
/** Name of the metadata field for name of the conversation */
public final static String META_NAME_FIELD = "name";
/** Name of the owning user field in all indices */
public final static String USER_FIELD = "user";
/** Name of the application that created this conversation */
public final static String APPLICATION_TYPE_FIELD = "application_type";
/** Mappings for the conversational metadata index */
public final static String META_MAPPING = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -41,12 +45,18 @@ public class ConversationalIndexConstants {
+ " \"properties\": {\n"
+ " \""
+ META_NAME_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ "\": {\"type\": \"text\"},\n"
+ " \""
+ META_CREATED_FIELD
+ META_CREATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ META_UPDATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ USER_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ APPLICATION_TYPE_FIELD
+ "\": {\"type\": \"keyword\"}\n"
+ " }\n"
+ "}";
Expand All @@ -69,6 +79,10 @@ public class ConversationalIndexConstants {
public final static String INTERACTIONS_ADDITIONAL_INFO_FIELD = "additional_info";
/** Name of the interaction field for the timestamp */
public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time";
/** Name of the interaction id */
public final static String PARENT_INTERACTIONS_ID_FIELD = "parent_interaction_id";
/** The trace number of an interaction */
public final static String INTERACTIONS_TRACE_NUMBER_FIELD = "trace_number";
/** Mappings for the interactions index */
public final static String INTERACTIONS_MAPPINGS = "{\n"
+ " \"_meta\": {\n"
Expand All @@ -95,7 +109,13 @@ public class ConversationalIndexConstants {
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_ADDITIONAL_INFO_FIELD
+ "\": {\"type\": \"text\"}\n"
+ "\": {\"type\": \"flat_object\"},\n"
+ " \""
+ PARENT_INTERACTIONS_ID_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ INTERACTIONS_TRACE_NUMBER_FIELD
+ "\": {\"type\": \"long\"}\n"
+ " }\n"
+ "}";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -54,7 +55,25 @@ public class Interaction implements Writeable, ToXContentObject {
@Getter
private String origin;
@Getter
private String additionalInfo;
private Map<String, String> additionalInfo;
@Getter
private String parentInteractionId;
@Getter
private Integer traceNum;

@Builder(toBuilder = true)
public Interaction(String id, Instant createTime, String conversationId, String input, String promptTemplate, String response, String origin, Map<String, String> additionalInfo) {
this.id = id;
this.createTime = createTime;
this.conversationId = conversationId;
this.input = input;
this.promptTemplate = promptTemplate;
this.response = response;
this.origin = origin;
this.additionalInfo = additionalInfo;
this.parentInteractionId = null;
this.traceNum = null;
}

/**
* Creates an Interaction object from a map of fields in the OS index
Expand All @@ -69,8 +88,10 @@ public static Interaction fromMap(String id, Map<String, Object> fields) {
String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD);
String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD);
String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD);
String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
Map<String,String> additionalInfo = (Map<String,String>) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD);
String parentInteractionId = (String) fields.getOrDefault(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, null);
Integer traceNum = (Integer) fields.getOrDefault(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, null);
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo, parentInteractionId, traceNum);
}

/**
Expand All @@ -97,8 +118,13 @@ public static Interaction fromStream(StreamInput in) throws IOException {
String promptTemplate = in.readString();
String response = in.readString();
String origin = in.readString();
String additionalInfo = in.readOptionalString();
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo);
Map<String, String> additionalInfo = new HashMap<>();
if (in.readBoolean()) {
additionalInfo = in.readMap(s -> s.readString(), s -> s.readString());
}
String parentInteractionId = in.readOptionalString();
Integer traceNum = in.readOptionalInt();
return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo, parentInteractionId, traceNum);
}


Expand All @@ -111,7 +137,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(promptTemplate);
out.writeString(response);
out.writeString(origin);
out.writeOptionalString(additionalInfo);
if (additionalInfo != null) {
out.writeBoolean(true);
out.writeMap(additionalInfo, StreamOutput::writeString, StreamOutput::writeString);
} else {
out.writeBoolean(false);
}
out.writeOptionalString(parentInteractionId);
out.writeOptionalInt(traceNum);
}

@Override
Expand All @@ -127,6 +160,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
if(additionalInfo != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo);
}
if (parentInteractionId != null) {
builder.field(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, parentInteractionId);
}
if (traceNum != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, traceNum);
}
builder.endObject();
return builder;
}
Expand All @@ -143,7 +182,12 @@ public boolean equals(Object other) {
((Interaction) other).response.equals(this.response) &&
((Interaction) other).origin.equals(this.origin) &&
( (((Interaction) other).additionalInfo == null && this.additionalInfo == null) ||
((Interaction) other).additionalInfo.equals(this.additionalInfo))
((Interaction) other).additionalInfo.equals(this.additionalInfo)) &&
( (((Interaction) other).parentInteractionId == null && this.parentInteractionId == null) ||
((Interaction) other).parentInteractionId.equals(this.parentInteractionId)) &&
( (((Interaction) other).traceNum == null && this.traceNum == null) ||
((Interaction) other).traceNum.equals(this.traceNum))

);
}

Expand All @@ -158,8 +202,9 @@ public String toString() {
+ ",promt_template=" + promptTemplate
+ ",response=" + response
+ ",additional_info=" + additionalInfo
+ ",parentInteractionId=" + parentInteractionId
+ ",traceNum=" + traceNum
+ "}";
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.conversation;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.TestHelper;
import org.opensearch.search.SearchHit;

import java.io.IOException;
import java.time.Instant;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

public class ConversationMetaTests {

ConversationMeta conversationMeta;
Instant time;

@Before
public void setUp() {
time = Instant.now();
conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin");
}

@Test
public void test_fromSearchHit() throws IOException {
XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent());
content.startObject();
content.field(ConversationalIndexConstants.META_CREATED_TIME_FIELD, time);
content.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, time);
content.field(ConversationalIndexConstants.META_NAME_FIELD, "meta name");
content.field(ConversationalIndexConstants.USER_FIELD, "admin");
content.endObject();

SearchHit[] hits = new SearchHit[1];
hits[0] = new SearchHit(0, "cId", null, null).sourceRef(BytesReference.bytes(content));

ConversationMeta conversationMeta = ConversationMeta.fromSearchHit(hits[0]);
assertEquals(conversationMeta.getId(), "cId");
assertEquals(conversationMeta.getName(), "meta name");
assertEquals(conversationMeta.getUser(), "admin");
}

@Test
public void test_fromMap() {
Map<String, Object> params = Map
.of(
ConversationalIndexConstants.META_CREATED_TIME_FIELD,
time.toString(),
ConversationalIndexConstants.META_UPDATED_TIME_FIELD,
time.toString(),
ConversationalIndexConstants.META_NAME_FIELD,
"meta name",
ConversationalIndexConstants.USER_FIELD,
"admin"
);
ConversationMeta conversationMeta = ConversationMeta.fromMap("test-conversation-meta", params);
assertEquals(conversationMeta.getId(), "test-conversation-meta");
assertEquals(conversationMeta.getName(), "meta name");
assertEquals(conversationMeta.getUser(), "admin");
}

@Test
public void test_fromStream() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
conversationMeta.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
ConversationMeta meta = ConversationMeta.fromStream(streamInput);
assertEquals(meta.getId(), conversationMeta.getId());
assertEquals(meta.getName(), conversationMeta.getName());
assertEquals(meta.getUser(), conversationMeta.getUser());
}

@Test
public void test_ToXContent() throws IOException {
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
conversationMeta.toXContent(builder, EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
assertEquals(content, "{\"conversation_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}");
}

@Test
public void test_toString() {
ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
assertEquals("{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", conversationMeta.toString());
}

@Test
public void test_equal() {
ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin");
assertEquals(meta.equals(conversationMeta), false);
}
}
Loading

0 comments on commit 85b372d

Please sign in to comment.