Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Message and Memory API Validation and storage #3283

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,18 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId);
builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id);
builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime);
builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
if (input != null && !input.trim().isEmpty()) {
builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input);
}
if (promptTemplate != null && !promptTemplate.trim().isEmpty()) {
builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate);
}
if (response != null && !response.trim().isEmpty()) {
builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response);
}
if (origin != null && !origin.trim().isEmpty()) {
builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin);
}
if (additionalInfo != null) {
builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,17 @@ public void test_ToXContent() throws IOException {
.builder()
.conversationId("conversation id")
.origin("amazon bedrock")
.promptTemplate(" ")
.parentInteractionId("parant id")
.additionalInfo(Collections.singletonMap("suggestion", "new suggestion"))
.response("sample response")
.traceNum(1)
.build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
interaction.toXContent(builder, EMPTY_PARAMS);
String interactionContent = TestHelper.xContentBuilderToString(builder);
assertEquals(
"{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}",
"{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"response\":\"sample response\",\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}",
rithin-pullela-aws marked this conversation as resolved.
Show resolved Hide resolved
interactionContent
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,28 @@ public static CreateConversationRequest fromRestRequest(RestRequest restRequest)
}
try (XContentParser parser = restRequest.contentParser()) {
Map<String, Object> body = parser.map();
String name = null;
String applicationType = null;
Map<String, String> additionalInfo = null;

for (String key : body.keySet()) {
rithin-pullela-aws marked this conversation as resolved.
Show resolved Hide resolved
switch (key) {
case ActionConstants.REQUEST_CONVERSATION_NAME_FIELD:
name = (String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD);
break;
case APPLICATION_TYPE_FIELD:
applicationType = (String) body.get(APPLICATION_TYPE_FIELD);
break;
case META_ADDITIONAL_INFO_FIELD:
additionalInfo = (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD);
break;
default:
parser.skipChildren();
break;
}
}
if (body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) != null) {
return new CreateConversationRequest(
(String) body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD),
body.get(APPLICATION_TYPE_FIELD) == null ? null : (String) body.get(APPLICATION_TYPE_FIELD),
body.get(META_ADDITIONAL_INFO_FIELD) == null ? null : (Map<String, String>) body.get(META_ADDITIONAL_INFO_FIELD)
);
return new CreateConversationRequest(name, applicationType, additionalInfo);
} else {
return new CreateConversationRequest();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ public static CreateInteractionRequest fromRestRequest(RestRequest request) thro
}
}

boolean allFieldsEmpty = (input == null || input.trim().isEmpty())
&& (prompt == null || prompt.trim().isEmpty())
&& (response == null || response.trim().isEmpty())
&& (origin == null || origin.trim().isEmpty())
&& (addinf == null || addinf.isEmpty());
if (allFieldsEmpty) {
throw new IllegalArgumentException(
"At least one of the following parameters must be non-empty: " + "input, prompt_template, response, origin, additional_info"
rithin-pullela-aws marked this conversation as resolved.
Show resolved Hide resolved
);
}
return new CreateInteractionRequest(cid, input, prompt, response, origin, addinf, parintid, tracenum);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -139,24 +140,24 @@ public void createConversation(
) {
initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> {
if (indexExists) {
String userstr = getUserStrFromThreadContext();
String userStr = getUserStrFromThreadContext();
Instant now = Instant.now();
IndexRequest request = Requests
.indexRequest(META_INDEX_NAME)
.source(
ConversationalIndexConstants.META_CREATED_TIME_FIELD,
now,
ConversationalIndexConstants.META_UPDATED_TIME_FIELD,
now,
ConversationalIndexConstants.META_NAME_FIELD,
name,
ConversationalIndexConstants.USER_FIELD,
userstr == null ? null : User.parse(userstr).getName(),
ConversationalIndexConstants.APPLICATION_TYPE_FIELD,
applicationType,
ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD,
additionalInfos == null ? Map.of() : additionalInfos
);
Map<String, Object> sourceMap = new HashMap<>();
sourceMap.put(ConversationalIndexConstants.META_CREATED_TIME_FIELD, now);
sourceMap.put(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, now);
if (name != null && !name.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.META_NAME_FIELD, name);
}
if (userStr != null && !userStr.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName());
}
if (applicationType != null && !applicationType.trim().isEmpty()) {
sourceMap.put(ConversationalIndexConstants.APPLICATION_TYPE_FIELD, applicationType);
}
if (additionalInfos != null && !additionalInfos.isEmpty()) {
sourceMap.put(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD, additionalInfos);
}
IndexRequest request = Requests.indexRequest(META_INDEX_NAME).source(sourceMap);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<IndexResponse> al = ActionListener.wrap(resp -> {
Expand Down Expand Up @@ -210,12 +211,12 @@ public void getConversations(int from, int maxResults, ActionListener<List<Conve
return;
}
SearchRequest request = Requests.searchRequest(META_INDEX_NAME);
String userstr = getUserStrFromThreadContext();
String userStr = getUserStrFromThreadContext();
QueryBuilder queryBuilder;
if (userstr == null)
if (userStr == null)
queryBuilder = new MatchAllQueryBuilder();
else
queryBuilder = new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, User.parse(userstr).getName());
queryBuilder = new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, User.parse(userStr).getName());
request.source().query(queryBuilder);
request.source().from(from).size(maxResults);
request.source().sort(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, SortOrder.DESC);
Expand Down Expand Up @@ -264,8 +265,8 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
return;
}
DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId);
String userstr = getUserStrFromThreadContext();
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
String userStr = getUserStrFromThreadContext();
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
this.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
Expand Down Expand Up @@ -308,7 +309,7 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
listener.onResponse(true);
return;
}
String userstr = getUserStrFromThreadContext();
String userStr = getUserStrFromThreadContext();
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId);
Expand All @@ -318,12 +319,12 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
throw new ResourceNotFoundException("Memory [" + conversationId + "] not found");
}
// If security is off - User doesn't exist - you have permission
if (userstr == null || User.parse(userstr) == null) {
if (userStr == null || User.parse(userStr) == null) {
internalListener.onResponse(true);
return;
}
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
String user = User.parse(userstr).getName();
String user = User.parse(userStr).getName();
// If you're not the owner of this conversation, you do not have permission
if (!user.equals(conversation.getUser())) {
internalListener.onResponse(false);
Expand Down Expand Up @@ -353,9 +354,9 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
QueryBuilder originalQuery = request.source().query();
BoolQueryBuilder newQuery = new BoolQueryBuilder();
newQuery.must(originalQuery);
String userstr = getUserStrFromThreadContext();
if (userstr != null) {
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
String userStr = getUserStrFromThreadContext();
if (userStr != null) {
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, user));
}
request.source().query(newQuery);
Expand Down Expand Up @@ -388,11 +389,11 @@ public void updateConversation(String conversationId, UpdateRequest updateReques
if (access) {
innerUpdateConversation(updateRequest, listener);
} else {
String userstr = client
String userStr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
throw new OpenSearchStatusException(
"User [" + user + "] does not have access to memory " + conversationId,
RestStatus.UNAUTHORIZED
Expand Down Expand Up @@ -421,7 +422,7 @@ public void getConversation(String conversationId, ActionListener<ConversationMe
listener.onFailure(new IndexNotFoundException("cannot get memory since the memory index does not exist", META_INDEX_NAME));
return;
}
String userstr = getUserStrFromThreadContext();
String userStr = getUserStrFromThreadContext();
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<ConversationMeta> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId);
Expand All @@ -432,12 +433,12 @@ public void getConversation(String conversationId, ActionListener<ConversationMe
}
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
// If no security, return conversation
if (userstr == null || User.parse(userstr) == null) {
if (userStr == null || User.parse(userStr) == null) {
internalListener.onResponse(conversation);
return;
}
// If security and correct user, return conversation
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
String user = User.parse(userStr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userStr).getName();
if (user.equals(conversation.getUser())) {
internalListener.onResponse(conversation);
log.info("Successfully get the memory for {}", conversationId);
Expand Down
Loading
Loading