Skip to content

Commit

Permalink
comments updated
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Dec 15, 2023
1 parent 219aeac commit d4a4232
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class ActionConstants {
public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations";
/** name of list on interactions in all responses */
public final static String RESPONSE_INTERACTION_LIST_FIELD = "interactions";
/** name of list on interactions in all responses */
/** name of list on traces in all responses */
public final static String RESPONSE_TRACES_LIST_FIELD = "traces";
/** name of interaction Id field in all responses */
public final static String RESPONSE_INTERACTION_ID_FIELD = "interaction_id";
Expand Down Expand Up @@ -65,12 +65,11 @@ public class ActionConstants {
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list";
/** path for update conversations */
public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_update";
/** path for put interaction */
/** path for create interaction */
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
/** path for get interactions */
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list";
/** path for get interactions */
/** path for get traces */
public final static String GET_TRACES_REST_PATH = "/_plugins/_ml/memory/trace" + "/{interaction_id}/_list";
/** path for delete conversation */
public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class StringUtils {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
Expand All @@ -33,15 +32,13 @@ public class GetTracesTransportAction extends HandledTransportAction<GetTracesRe
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetTracesTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client,
ClusterService clusterService
Client client
) {
super(GetTracesAction.NAME, transportService, actionFilters, GetTracesRequest::new);
this.client = client;
Expand All @@ -53,6 +50,7 @@ public void doExecute(Task task, GetTracesRequest request, ActionListener<GetTra
int maxResults = request.getMaxResults();
int from = request.getFrom();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
// TODO: check this newStoredContext() method and remove it if it's redundant
ActionListener<GetTracesResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
ActionListener<List<Interaction>> al = ActionListener.wrap(tracesList -> {
internalListener.onResponse(new GetTracesResponse(tracesList, from + maxResults, tracesList.size() == maxResults));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class UpdateConversationRequest extends ActionRequest {
private String conversationId;
private Map<String, Object> updateContent;

private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD, APPLICATION_TYPE_FIELD));
private static final Set<String> allowedList = new HashSet<>(Arrays.asList(META_NAME_FIELD));

@Builder
public UpdateConversationRequest(String conversationId, Map<String, Object> updateContent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Set;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -26,8 +25,6 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
Expand All @@ -48,9 +45,6 @@ public class GetTracesTransportActionTests extends OpenSearchTestCase {
@Mock
Client client;

@Mock
ClusterService clusterService;

@Mock
TransportService transportService;

Expand Down Expand Up @@ -82,11 +76,8 @@ public void setup() throws IOException {
this.threadContext = new ThreadContext(settings);
when(this.client.threadPool()).thenReturn(this.threadPool);
when(this.threadPool.getThreadContext()).thenReturn(this.threadContext);
when(this.clusterService.getSettings()).thenReturn(settings);
when(this.clusterService.getClusterSettings())
.thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED)));

this.action = spy(new GetTracesTransportAction(transportService, actionFilters, cmHandler, client, clusterService));
this.action = spy(new GetTracesTransportAction(transportService, actionFilters, cmHandler, client));
}

public void testGetTraces_noMorePages() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ public void testParse_Success() throws IOException {
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.parse(parser, "conversationId");
assertEquals(updateConversationRequest.getConversationId(), "conversationId");
assertEquals("new name", updateConversationRequest.getUpdateContent().get("name"));
assertEquals("new type", updateConversationRequest.getUpdateContent().get("application_type"));
}

@Test
Expand Down

0 comments on commit d4a4232

Please sign in to comment.