Skip to content

Commit

Permalink
Finish unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Aug 21, 2023
1 parent 54fbb2b commit 4c96312
Show file tree
Hide file tree
Showing 60 changed files with 1,683 additions and 1,092 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ public String toString() {
+ ",cid=" + conversationId
+ ",timestamp=" + timestamp
+ ",agent=" + agent
+ ",prompt=" + prompt
+ ",input=" + input
+ ",response=" + response
+ ",metadata=" + metadata
+ "}";
}

Expand Down
10 changes: 10 additions & 0 deletions conversational-memory/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ plugins {
id 'java'
id 'jacoco'
id "io.freefair.lombok"
id 'com.diffplug.spotless' version '6.18.0'
}

dependencies {
Expand Down Expand Up @@ -75,4 +76,13 @@ jacocoTestCoverageVerification {
}
}
dependsOn jacocoTestReport
}

spotless {
java {
removeUnusedImports()
importOrder 'java', 'javax', 'org', 'com'

eclipse().configFile rootProject.file('.eclipseformat.xml')
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public interface ConversationalMemoryHandler {
* @param listener gets the ID of the new interaction
*/
public void createInteraction(
String conversationId,
String conversationId,
String input,
String prompt,
String response,
Expand All @@ -87,7 +87,7 @@ public void createInteraction(
* @return ActionFuture for the interactionId of the new interaction
*/
public ActionFuture<String> createInteraction(
String conversationId,
String conversationId,
String input,
String prompt,
String response,
Expand Down Expand Up @@ -171,4 +171,4 @@ public ActionFuture<String> createInteraction(
*/
public ActionFuture<Boolean> deleteConversation(String conversationId);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ public class CreateConversationAction extends ActionType<CreateConversationRespo
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/conversational/conversation/create";

private CreateConversationAction() { super(NAME, CreateConversationResponse::new); }
}
private CreateConversationAction() {
super(NAME, CreateConversationResponse::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public CreateConversationRequest(String name) {
super();
this.name = name;
}

/**
* Constructor
* name will be null
Expand All @@ -79,11 +80,11 @@ public ActionRequestValidationException validate() {
* @throws IOException if something breaks
*/
public static CreateConversationRequest fromRestRequest(RestRequest restRequest) throws IOException {
if(restRequest.hasParam(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) {
if (restRequest.hasParam(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) {
return new CreateConversationRequest(restRequest.param(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD));
} else {
return new CreateConversationRequest();
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
return builder;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ public class CreateConversationRestAction extends BaseRestHandler {

@Override
public List<Route> routes() {
return List.of(
new Route(RestRequest.Method.POST, ActionConstants.CREATE_CONVERSATION_PATH)
);
return List.of(new Route(RestRequest.Method.POST, ActionConstants.CREATE_CONVERSATION_PATH));
}

@Override
Expand All @@ -50,4 +48,4 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
return channel -> client.execute(CreateConversationAction.INSTANCE, ccRequest, new RestToXContentListener<>(channel));
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
*/
package org.opensearch.ml.conversational.action.memory.conversation;

import org.opensearch.core.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.conversational.ConversationalMemoryHandler;
import org.opensearch.ml.conversational.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand Down Expand Up @@ -50,7 +50,7 @@ public class CreateConversationTransportAction extends HandledTransportAction<Cr
public CreateConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
) {
super(CreateConversationAction.NAME, transportService, actionFilters, CreateConversationRequest::new);
Expand All @@ -63,23 +63,20 @@ protected void doExecute(Task task, CreateConversationRequest request, ActionLis
String name = request.getName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<CreateConversationResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<String> al = ActionListener.wrap(r -> {
internalListener.onResponse(new CreateConversationResponse(r));
}, e -> {
ActionListener<String> al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> {
log.error(e.toString());
internalListener.onFailure(e);
});

if(name == null) {
if (name == null) {
cmHandler.createConversation(al);
} else {
cmHandler.createConversation(name, al);
}
} catch(Exception e) {
} catch (Exception e) {
log.error("Failed to create new conversation with name " + request.getName(), e);
actionListener.onFailure(e);
}
}


}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ public class DeleteConversationAction extends ActionType<DeleteConversationRespo
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/conversational/conversation/delete";

private DeleteConversationAction() {super(NAME, DeleteConversationResponse::new);}
}
private DeleteConversationAction() {
super(NAME, DeleteConversationResponse::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.opensearch.ml.conversational.action.memory.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;

import org.opensearch.action.ActionRequest;
Expand All @@ -28,8 +30,6 @@

import lombok.AllArgsConstructor;

import static org.opensearch.action.ValidateActions.addValidationError;

/**
* Action Request for Delete Conversation
*/
Expand Down Expand Up @@ -81,4 +81,4 @@ public static DeleteConversationRequest fromRestRequest(RestRequest request) thr
return new DeleteConversationRequest(cid);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,4 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
return builder;
}


}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ public class DeleteConversationRestAction extends BaseRestHandler {

@Override
public List<Route> routes() {
return List.of(
new Route(RestRequest.Method.DELETE, ActionConstants.DELETE_CONVERSATION_PATH)
);
return List.of(new Route(RestRequest.Method.DELETE, ActionConstants.DELETE_CONVERSATION_PATH));
}

@Override
Expand All @@ -49,4 +47,4 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
DeleteConversationRequest dcRequest = DeleteConversationRequest.fromRestRequest(request);
return channel -> client.execute(DeleteConversationAction.INSTANCE, dcRequest, new RestToXContentListener<>(channel));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
*/
package org.opensearch.ml.conversational.action.memory.conversation;

import org.opensearch.core.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.conversational.ConversationalMemoryHandler;
import org.opensearch.ml.conversational.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
Expand Down Expand Up @@ -50,7 +50,7 @@ public class DeleteConversationTransportAction extends HandledTransportAction<De
public DeleteConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
OpenSearchConversationalMemoryHandler cmHandler,
Client client
) {
super(DeleteConversationAction.NAME, transportService, actionFilters, DeleteConversationRequest::new);
Expand All @@ -61,18 +61,16 @@ public DeleteConversationTransportAction(
@Override
public void doExecute(Task task, DeleteConversationRequest request, ActionListener<DeleteConversationResponse> listener) {
String conversationId = request.getId();
try(ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<DeleteConversationResponse> internalListener = ActionListener.runBefore(listener, () -> context.restore());
ActionListener<Boolean> al = ActionListener.wrap(success -> {
DeleteConversationResponse response = new DeleteConversationResponse(success);
internalListener.onResponse(response);
}, e -> {
internalListener.onFailure(e);
});
}, e -> { internalListener.onFailure(e); });
cmHandler.deleteConversation(conversationId, al);
} catch (Exception e) {
log.error(e.toString());
listener.onFailure(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ public class GetConversationsAction extends ActionType<GetConversationsResponse>
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/conversational/conversation/list";

private GetConversationsAction() { super(NAME, GetConversationsResponse::new); }
private GetConversationsAction() {
super(NAME, GetConversationsResponse::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.opensearch.ml.conversational.action.memory.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;

import org.opensearch.action.ActionRequest;
Expand All @@ -28,8 +30,6 @@

import lombok.Getter;

import static org.opensearch.action.ValidateActions.addValidationError;

/**
* ActionRequest for list conversations action
*/
Expand Down Expand Up @@ -87,7 +87,7 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if(this.maxResults <= 0) {
if (this.maxResults <= 0) {
exception = addValidationError("Can't list 0 or negative conversations", exception);
}
return exception;
Expand All @@ -100,20 +100,24 @@ public ActionRequestValidationException validate() {
* @throws IOException if something breaks
*/
public static GetConversationsRequest fromRestRequest(RestRequest request) throws IOException {
if(request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) {
if(request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
return new GetConversationsRequest(Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)),
Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)));
if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) {
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
return new GetConversationsRequest(
Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)),
Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD))
);
} else {
return new GetConversationsRequest(ActionConstants.DEFAULT_MAX_RESULTS,
Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)));
return new GetConversationsRequest(
ActionConstants.DEFAULT_MAX_RESULTS,
Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD))
);
}
} else {
if(request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) {
return new GetConversationsRequest(Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)));
} else {
return new GetConversationsRequest();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversational.ActionConstants;
import org.opensearch.ml.common.conversational.ConversationMeta;

import lombok.AllArgsConstructor;
import lombok.Getter;

import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

/**
* Action Response for CreateConversation
*/
Expand Down Expand Up @@ -75,15 +74,15 @@ public boolean hasMorePages() {
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.startArray(ActionConstants.RESPONSE_CONVERSATION_LIST_FIELD);
for(ConversationMeta conversation : conversations) {
for (ConversationMeta conversation : conversations) {
conversation.toXContent(builder, params);
}
builder.endArray();
if(hasMoreTokens) {
if (hasMoreTokens) {
builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken);
}
builder.endObject();
return builder;
}

}
}
Loading

0 comments on commit 4c96312

Please sign in to comment.