From d09374c3feb277a514d641dbc7bf91dce19ec303 Mon Sep 17 00:00:00 2001 From: Rithin Pullela Date: Tue, 24 Dec 2024 11:55:19 -0800 Subject: [PATCH] Add application_type to ConversationMeta; update tests (#3282) Modify getMemory(Conversation) to return the application_type parameter. Include application_type in the ConversationMeta data model. Update existing tests to validate the new parameter. Signed-off-by: rithin-pullela-aws --- .../common/conversation/ConversationMeta.java | 22 +++++++++++++++---- .../conversation/ConversationMetaTests.java | 12 +++++++--- .../GetConversationResponseTests.java | 6 ++--- .../GetConversationTransportActionTests.java | 2 +- .../GetConversationsResponseTests.java | 6 ++--- .../GetConversationsTransportActionTests.java | 10 ++++----- ...earchConversationalMemoryHandlerTests.java | 2 +- 7 files changed, 40 insertions(+), 20 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java index 21ed608654..5d847fbcd7 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java @@ -51,6 +51,8 @@ public class ConversationMeta implements Writeable, ToXContentObject { @Getter private String user; @Getter + private String applicationType; + @Getter private Map additionalInfos; /** @@ -74,8 +76,9 @@ public static ConversationMeta fromMap(String id, Map docFields) Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_TIME_FIELD)); String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD); String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD); + String applicationType = (String) docFields.get(ConversationalIndexConstants.APPLICATION_TYPE_FIELD); Map additionalInfos = (Map) docFields.get(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD); - return new ConversationMeta(id, created, updated, name, user, additionalInfos); + return new ConversationMeta(id, created, updated, name, user, applicationType, additionalInfos); } /** @@ -91,13 +94,14 @@ public static ConversationMeta fromStream(StreamInput in) throws IOException { Instant updated = in.readInstant(); String name = in.readString(); String user = in.readOptionalString(); + String applicationType = in.readOptionalString(); Map additionalInfos = null; if (in.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { if (in.readBoolean()) { additionalInfos = in.readMap(StreamInput::readString, StreamInput::readString); } } - return new ConversationMeta(id, created, updated, name, user, additionalInfos); + return new ConversationMeta(id, created, updated, name, user, applicationType, additionalInfos); } @Override @@ -107,6 +111,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInstant(updatedTime); out.writeString(name); out.writeOptionalString(user); + out.writeOptionalString(applicationType); if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { if (additionalInfos == null) { out.writeBoolean(false); @@ -129,6 +134,10 @@ public String toString() { + updatedTime.toString() + ", user=" + user + + ", applicationType=" + + applicationType + + ", additionalInfos=" + + additionalInfos + "}"; } @@ -142,7 +151,10 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para if (this.user != null) { builder.field(ConversationalIndexConstants.USER_FIELD, this.user); } - if (this.additionalInfos != null) { + if (this.applicationType != null && !this.applicationType.trim().isEmpty()) { + builder.field(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, this.applicationType); + } + if (this.additionalInfos != null && !additionalInfos.isEmpty()) { builder.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, this.additionalInfos); } builder.endObject(); @@ -159,7 +171,9 @@ public boolean equals(Object other) { && Objects.equals(this.user, otherConversation.user) && Objects.equals(this.createdTime, otherConversation.createdTime) && Objects.equals(this.updatedTime, otherConversation.updatedTime) - && Objects.equals(this.name, otherConversation.name); + && Objects.equals(this.name, otherConversation.name) + && Objects.equals(this.applicationType, otherConversation.applicationType) + && Objects.equals(this.additionalInfos, otherConversation.additionalInfos); } } diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java index aaa52ffcff..7666ab3bf2 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java @@ -30,7 +30,7 @@ public class ConversationMetaTests { @Before public void setUp() { time = Instant.now(); - conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin", null); + conversationMeta = new ConversationMeta("test_id", time, time, "test_name", "admin", "conversational-search", null); } @Test @@ -41,6 +41,7 @@ public void test_fromSearchHit() throws IOException { content.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, time); content.field(ConversationalIndexConstants.META_NAME_FIELD, "meta name"); content.field(ConversationalIndexConstants.USER_FIELD, "admin"); + content.field(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, "conversational-search"); content.field(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, Map.of("test_key", "test_value")); content.endObject(); @@ -51,6 +52,7 @@ public void test_fromSearchHit() throws IOException { assertEquals(conversationMeta.getId(), "cId"); assertEquals(conversationMeta.getName(), "meta name"); assertEquals(conversationMeta.getUser(), "admin"); + assertEquals(conversationMeta.getApplicationType(), "conversational-search"); assertEquals(conversationMeta.getAdditionalInfos().get("test_key"), "test_value"); } @@ -83,6 +85,7 @@ public void test_fromStream() throws IOException { assertEquals(meta.getId(), conversationMeta.getId()); assertEquals(meta.getName(), conversationMeta.getName()); assertEquals(meta.getUser(), conversationMeta.getUser()); + assertEquals(meta.getApplicationType(), conversationMeta.getApplicationType()); } @Test @@ -93,6 +96,7 @@ public void test_ToXContent() throws IOException { Instant.ofEpochMilli(123), "test meta", "admin", + "neural-search", null ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); @@ -100,7 +104,7 @@ public void test_ToXContent() throws IOException { String content = TestHelper.xContentBuilderToString(builder); assertEquals( content, - "{\"memory_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}" + "{\"memory_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\",\"application_type\":\"neural-search\"}" ); } @@ -112,10 +116,11 @@ public void test_toString() { Instant.ofEpochMilli(123), "test meta", "admin", + "conversational-search", null ); assertEquals( - "{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", + "{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin, applicationType=conversational-search, additionalInfos=null}", conversationMeta.toString() ); } @@ -128,6 +133,7 @@ public void test_equal() { Instant.ofEpochMilli(123), "test meta", "admin", + "conversational-search", null ); assertEquals(meta.equals(conversationMeta), false); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java index 0b39d546f8..08a285ec90 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -38,7 +38,7 @@ public class GetConversationResponseTests extends OpenSearchTestCase { public void testGetConversationResponseStreaming() throws IOException { - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null, null); GetConversationResponse response = new GetConversationResponse(convo); assert (response.getConversation().equals(convo)); @@ -51,7 +51,7 @@ public void testGetConversationResponseStreaming() throws IOException { } public void testToXContent() throws IOException { - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null, null); GetConversationResponse response = new GetConversationResponse(convo); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -68,7 +68,7 @@ public void testToXContent() throws IOException { public void testToXContent_withAdditionalInfo() throws IOException { Map additionalInfos = Map.of("key1", "value1"); - ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, additionalInfos); + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), Instant.now(), "name", null, null, additionalInfos); GetConversationResponse response = new GetConversationResponse(convo); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java index 558ecd9b65..aa85a8507d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java @@ -107,7 +107,7 @@ public void setup() throws IOException { } public void testGetConversation() { - ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), Instant.now(), "name", null, null); + ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), Instant.now(), "name", null, null, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(result); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java index b28ed26d0f..71982e6a76 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java @@ -46,9 +46,9 @@ public class GetConversationsResponseTests extends OpenSearchTestCase { public void setup() { conversations = List .of( - new ConversationMeta("0", Instant.now(), Instant.now(), "name0", "user0", null), - new ConversationMeta("1", Instant.now(), Instant.now(), "name1", "user0", null), - new ConversationMeta("2", Instant.now(), Instant.now(), "name2", "user2", null) + new ConversationMeta("0", Instant.now(), Instant.now(), "name0", "user0", null, null), + new ConversationMeta("1", Instant.now(), Instant.now(), "name1", "user0", null, null), + new ConversationMeta("2", Instant.now(), Instant.now(), "name2", "user2", null, null) ); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java index a866167d37..257c74f1bb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java @@ -114,8 +114,8 @@ public void testGetConversations() { log.info("testing get conversations transport"); List testResult = List .of( - new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null), - new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null) + new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null, null), + new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null, null) ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); @@ -132,9 +132,9 @@ public void testGetConversations() { public void testPagination() { List testResult = List .of( - new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null), - new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null), - new ConversationMeta("testcid3", Instant.now(), Instant.now(), "testname", null, null) + new ConversationMeta("testcid1", Instant.now(), Instant.now(), "", null, null, null), + new ConversationMeta("testcid2", Instant.now(), Instant.now(), "testname", null, null, null), + new ConversationMeta("testcid3", Instant.now(), Instant.now(), "testname", null, null, null) ); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(2); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index 903be08338..fc63811d2c 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -315,7 +315,7 @@ public void testSearchInteractions_Future() { } public void testGetAConversation_Future() { - ConversationMeta response = new ConversationMeta("cid", Instant.now(), Instant.now(), "boring name", null, null); + ConversationMeta response = new ConversationMeta("cid", Instant.now(), Instant.now(), "boring name", null, null, null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(response);