Skip to content

Commit

Permalink
refactor code based on comments
Browse files Browse the repository at this point in the history
Signed-off-by: Chenyang Ji <[email protected]>
  • Loading branch information
ansjcy committed Jun 6, 2024
1 parent b34314c commit c135c5c
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.action.search.SearchRequestOperationsListener;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.plugin.insights.core.service.QueryInsightsService;
import org.opensearch.plugin.insights.rules.model.Attribute;
Expand Down Expand Up @@ -153,18 +152,16 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo
// Get internal computed and user provided labels
Map<String, Object> labels = new HashMap<>();
// Retrieve user provided label if exists
ThreadContext threadContext = threadPool.getThreadContext();
String userProvidedLabel = threadContext.getRequestHeadersOnly().get(Task.X_OPAQUE_ID);
String userProvidedLabel = RequestLabelingService.getUserProvidedTag(threadPool);
if (userProvidedLabel != null) {
labels.put(Task.X_OPAQUE_ID, userProvidedLabel);
}
// Retrieve computed labels if exists
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS);
Map<String, Object> computedLabels = RequestLabelingService.getRuleBasedLabels(threadPool);
if (computedLabels != null) {
labels.putAll(computedLabels);
}
attributes.put(Attribute.LABELS, labels);

// construct SearchQueryRecord from attributes and measurements
SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes);
queryInsightsService.addRecord(record);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.plugin.insights.core.service.QueryInsightsService;
import org.opensearch.plugin.insights.core.service.TopQueriesService;
import org.opensearch.plugin.insights.rules.model.Attribute;
import org.opensearch.plugin.insights.rules.model.MetricType;
import org.opensearch.plugin.insights.rules.model.SearchQueryRecord;
import org.opensearch.plugin.insights.settings.QueryInsightsSettings;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.support.ValueType;
Expand All @@ -35,10 +37,13 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Phaser;

import org.mockito.ArgumentCaptor;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand Down Expand Up @@ -70,11 +75,12 @@ public void setup() {
when(queryInsightsService.getTopQueriesService(MetricType.LATENCY)).thenReturn(topQueriesService);

ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "test"), new HashMap<>()));
threadContext.putTransient(RequestLabelingService.COMPUTED_LABELS, Map.of("a", "b"));
threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "userLabel"), new HashMap<>()));
threadContext.putTransient(RequestLabelingService.RULE_BASED_LABELS, Map.of("labelKey", "labelValue"));
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

@SuppressWarnings("unchecked")
public void testOnRequestEnd() throws InterruptedException {
Long timestamp = System.currentTimeMillis() - 100L;
SearchType searchType = SearchType.QUERY_THEN_FETCH;
Expand All @@ -101,10 +107,19 @@ public void testOnRequestEnd() throws InterruptedException {
when(searchRequestContext.phaseTookMap()).thenReturn(phaseLatencyMap);
when(searchPhaseContext.getRequest()).thenReturn(searchRequest);
when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards);
ArgumentCaptor<SearchQueryRecord> captor = ArgumentCaptor.forClass(SearchQueryRecord.class);

queryInsightsListener.onRequestEnd(searchPhaseContext, searchRequestContext);

verify(queryInsightsService, times(1)).addRecord(any());
verify(queryInsightsService, times(1)).addRecord(captor.capture());
SearchQueryRecord generatedRecord = captor.getValue();
assertEquals(timestamp.longValue(), generatedRecord.getTimestamp());
assertEquals(numberOfShards, generatedRecord.getAttributes().get(Attribute.TOTAL_SHARDS));
assertEquals(searchType.toString().toLowerCase(Locale.ROOT), generatedRecord.getAttributes().get(Attribute.SEARCH_TYPE));
assertEquals(searchSourceBuilder.toString(), generatedRecord.getAttributes().get(Attribute.SOURCE));
Map<String, String> labels = (Map<String, String>) generatedRecord.getAttributes().get(Attribute.LABELS);
assertEquals("labelValue", labels.get("labelKey"));
assertEquals("userLabel", labels.get(Task.X_OPAQUE_ID));
}

public void testConcurrentOnRequestEnd() throws InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.threadpool.ThreadPool;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

Expand All @@ -25,7 +26,7 @@ public class RequestLabelingService {
/**
* Field name for computed labels
*/
public static final String COMPUTED_LABELS = "computed_labels";
public static final String RULE_BASED_LABELS = "rule_based_labels";
private final ThreadPool threadPool;
private final List<Rule> rules;

Expand All @@ -35,42 +36,34 @@ public RequestLabelingService(final ThreadPool threadPool, final List<Rule> rule
}

/**
* Get all the existing rules
*
* @return list of existing rules
*/
public List<Rule> getRules() {
return rules;
}

/**
* Add a labeling rule to the service
* Evaluate all labeling rules and store the computed rules into thread context
*
* @param rule {@link Rule}
* @param searchRequest {@link SearchRequest}
*/
public void addRule(final Rule rule) {
this.rules.add(rule);
public void applyAllRules(final SearchRequest searchRequest) {
Map<String, Object> labels = rules.stream()
.map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest))
.flatMap(m -> m.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement));
String userProvidedTag = getUserProvidedTag(threadPool);
if (labels.containsKey(Task.X_OPAQUE_ID) && userProvidedTag.equals(labels.get(Task.X_OPAQUE_ID))) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Unexpected label %s found: %s", Task.X_OPAQUE_ID, userProvidedTag)
);
}
threadPool.getThreadContext().putTransient(RULE_BASED_LABELS, labels);
}

/**
* Get the user provided tag from the X-Opaque-Id header
*
* @return user provided tag
*/
public String getUserProvidedTag() {
public static String getUserProvidedTag(ThreadPool threadPool) {
return threadPool.getThreadContext().getRequestHeadersOnly().getOrDefault(Task.X_OPAQUE_ID, null);
}

/**
* Evaluate all labeling rules and store the computed rules into thread context
*
* @param searchRequest {@link SearchRequest}
*/
public void applyAllRules(final SearchRequest searchRequest) {
Map<String, Object> labels = rules.stream()
.map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest))
.flatMap(m -> m.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement));
threadPool.getThreadContext().putTransient(COMPUTED_LABELS, labels);
public static Map<String, Object> getRuleBasedLabels(ThreadPool threadPool) {
return threadPool.getThreadContext().getTransient(RequestLabelingService.RULE_BASED_LABELS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

package org.opensearch.search.labels;

import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchRequestContext;
import org.opensearch.action.search.SearchRequestOperationsListener;

Expand All @@ -29,7 +28,4 @@ public void onRequestStart(SearchRequestContext searchRequestContext) {
// add tags to search request
requestLabelingService.applyAllRules(searchRequestContext.getRequest());
}

@Override
public void onRequestEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,10 @@ public void setUpVariables() {
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testAddRule() {
Rule mockRule = mock(Rule.class);
requestLabelingService.addRule(mockRule);
List<Rule> rules = requestLabelingService.getRules();
assertEquals(1, rules.size());
assertEquals(mockRule, rules.get(0));
}

public void testGetUserProvidedTag() {
String expectedTag = "test-tag";
threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, expectedTag), new HashMap<>()));
String actualTag = requestLabelingService.getUserProvidedTag();
String actualTag = RequestLabelingService.getUserProvidedTag(threadPool);
assertEquals(expectedTag, actualTag);
}

Expand All @@ -63,7 +55,7 @@ public void testBasicApplyAllRules() {
when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap);
rules.add(mockRule1);
requestLabelingService.applyAllRules(mockSearchRequest);
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS);
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS);
assertEquals(1, computedLabels.size());
assertEquals("value1", computedLabels.get("label1"));
}
Expand All @@ -77,7 +69,7 @@ public void testApplyAllRulesWithConflict() {
rules.add(mockRule1);
rules.add(mockRule2);
requestLabelingService.applyAllRules(mockSearchRequest);
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS);
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS);
assertEquals(1, computedLabels.size());
assertEquals("value2", computedLabels.get("conflictingLabel"));
}
Expand All @@ -91,7 +83,7 @@ public void testApplyAllRulesWithoutConflict() {
rules.add(mockRule1);
rules.add(mockRule2);
requestLabelingService.applyAllRules(mockSearchRequest);
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS);
Map<String, Object> computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS);
assertEquals(2, computedLabels.size());
assertEquals("value1", computedLabels.get("label1"));
assertEquals("value2", computedLabels.get("label2"));
Expand Down

0 comments on commit c135c5c

Please sign in to comment.