diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java index 0e6e768781970..263c4d3a6f78d 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java @@ -16,7 +16,6 @@ import org.opensearch.action.search.SearchRequestOperationsListener; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugin.insights.core.service.QueryInsightsService; import org.opensearch.plugin.insights.rules.model.Attribute; @@ -153,18 +152,16 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo // Get internal computed and user provided labels Map labels = new HashMap<>(); // Retrieve user provided label if exists - ThreadContext threadContext = threadPool.getThreadContext(); - String userProvidedLabel = threadContext.getRequestHeadersOnly().get(Task.X_OPAQUE_ID); + String userProvidedLabel = RequestLabelingService.getUserProvidedTag(threadPool); if (userProvidedLabel != null) { labels.put(Task.X_OPAQUE_ID, userProvidedLabel); } // Retrieve computed labels if exists - Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + Map computedLabels = RequestLabelingService.getRuleBasedLabels(threadPool); if (computedLabels != null) { labels.putAll(computedLabels); } attributes.put(Attribute.LABELS, labels); - // construct SearchQueryRecord from attributes and measurements SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes); queryInsightsService.addRecord(record); diff --git a/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java b/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java index a4d4ca5736af0..d944ed46778f6 100644 --- a/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java +++ b/plugins/query-insights/src/test/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListenerTests.java @@ -19,7 +19,9 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.plugin.insights.core.service.QueryInsightsService; import org.opensearch.plugin.insights.core.service.TopQueriesService; +import org.opensearch.plugin.insights.rules.model.Attribute; import org.opensearch.plugin.insights.rules.model.MetricType; +import org.opensearch.plugin.insights.rules.model.SearchQueryRecord; import org.opensearch.plugin.insights.settings.QueryInsightsSettings; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.support.ValueType; @@ -35,10 +37,13 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Phaser; +import org.mockito.ArgumentCaptor; + import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -70,11 +75,12 @@ public void setup() { when(queryInsightsService.getTopQueriesService(MetricType.LATENCY)).thenReturn(topQueriesService); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "test"), new HashMap<>())); - threadContext.putTransient(RequestLabelingService.COMPUTED_LABELS, Map.of("a", "b")); + threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, "userLabel"), new HashMap<>())); + threadContext.putTransient(RequestLabelingService.RULE_BASED_LABELS, Map.of("labelKey", "labelValue")); when(threadPool.getThreadContext()).thenReturn(threadContext); } + @SuppressWarnings("unchecked") public void testOnRequestEnd() throws InterruptedException { Long timestamp = System.currentTimeMillis() - 100L; SearchType searchType = SearchType.QUERY_THEN_FETCH; @@ -101,10 +107,19 @@ public void testOnRequestEnd() throws InterruptedException { when(searchRequestContext.phaseTookMap()).thenReturn(phaseLatencyMap); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchQueryRecord.class); queryInsightsListener.onRequestEnd(searchPhaseContext, searchRequestContext); - verify(queryInsightsService, times(1)).addRecord(any()); + verify(queryInsightsService, times(1)).addRecord(captor.capture()); + SearchQueryRecord generatedRecord = captor.getValue(); + assertEquals(timestamp.longValue(), generatedRecord.getTimestamp()); + assertEquals(numberOfShards, generatedRecord.getAttributes().get(Attribute.TOTAL_SHARDS)); + assertEquals(searchType.toString().toLowerCase(Locale.ROOT), generatedRecord.getAttributes().get(Attribute.SEARCH_TYPE)); + assertEquals(searchSourceBuilder.toString(), generatedRecord.getAttributes().get(Attribute.SOURCE)); + Map labels = (Map) generatedRecord.getAttributes().get(Attribute.LABELS); + assertEquals("labelValue", labels.get("labelKey")); + assertEquals("userLabel", labels.get(Task.X_OPAQUE_ID)); } public void testConcurrentOnRequestEnd() throws InterruptedException { diff --git a/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java b/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java index 6e0f9dfc14355..8a37a322428d5 100644 --- a/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java +++ b/server/src/main/java/org/opensearch/search/labels/RequestLabelingService.java @@ -14,6 +14,7 @@ import org.opensearch.threadpool.ThreadPool; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; @@ -25,7 +26,7 @@ public class RequestLabelingService { /** * Field name for computed labels */ - public static final String COMPUTED_LABELS = "computed_labels"; + public static final String RULE_BASED_LABELS = "rule_based_labels"; private final ThreadPool threadPool; private final List rules; @@ -35,21 +36,22 @@ public RequestLabelingService(final ThreadPool threadPool, final List rule } /** - * Get all the existing rules - * - * @return list of existing rules - */ - public List getRules() { - return rules; - } - - /** - * Add a labeling rule to the service + * Evaluate all labeling rules and store the computed rules into thread context * - * @param rule {@link Rule} + * @param searchRequest {@link SearchRequest} */ - public void addRule(final Rule rule) { - this.rules.add(rule); + public void applyAllRules(final SearchRequest searchRequest) { + Map labels = rules.stream() + .map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest)) + .flatMap(m -> m.entrySet().stream()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement)); + String userProvidedTag = getUserProvidedTag(threadPool); + if (labels.containsKey(Task.X_OPAQUE_ID) && userProvidedTag.equals(labels.get(Task.X_OPAQUE_ID))) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unexpected label %s found: %s", Task.X_OPAQUE_ID, userProvidedTag) + ); + } + threadPool.getThreadContext().putTransient(RULE_BASED_LABELS, labels); } /** @@ -57,20 +59,11 @@ public void addRule(final Rule rule) { * * @return user provided tag */ - public String getUserProvidedTag() { + public static String getUserProvidedTag(ThreadPool threadPool) { return threadPool.getThreadContext().getRequestHeadersOnly().getOrDefault(Task.X_OPAQUE_ID, null); } - /** - * Evaluate all labeling rules and store the computed rules into thread context - * - * @param searchRequest {@link SearchRequest} - */ - public void applyAllRules(final SearchRequest searchRequest) { - Map labels = rules.stream() - .map(rule -> rule.evaluate(threadPool.getThreadContext(), searchRequest)) - .flatMap(m -> m.entrySet().stream()) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existing, replacement) -> replacement)); - threadPool.getThreadContext().putTransient(COMPUTED_LABELS, labels); + public static Map getRuleBasedLabels(ThreadPool threadPool) { + return threadPool.getThreadContext().getTransient(RequestLabelingService.RULE_BASED_LABELS); } } diff --git a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java index d672bb199404f..2c191aa491b32 100644 --- a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java +++ b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java @@ -8,7 +8,6 @@ package org.opensearch.search.labels; -import org.opensearch.action.search.SearchPhaseContext; import org.opensearch.action.search.SearchRequestContext; import org.opensearch.action.search.SearchRequestOperationsListener; @@ -29,7 +28,4 @@ public void onRequestStart(SearchRequestContext searchRequestContext) { // add tags to search request requestLabelingService.applyAllRules(searchRequestContext.getRequest()); } - - @Override - public void onRequestEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {} } diff --git a/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java b/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java index fe7f899d9c45e..2225002a3e6db 100644 --- a/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java +++ b/server/src/test/java/org/opensearch/search/labels/RequestLabelingServiceTests.java @@ -42,18 +42,10 @@ public void setUpVariables() { when(threadPool.getThreadContext()).thenReturn(threadContext); } - public void testAddRule() { - Rule mockRule = mock(Rule.class); - requestLabelingService.addRule(mockRule); - List rules = requestLabelingService.getRules(); - assertEquals(1, rules.size()); - assertEquals(mockRule, rules.get(0)); - } - public void testGetUserProvidedTag() { String expectedTag = "test-tag"; threadContext.setHeaders(new Tuple<>(Collections.singletonMap(Task.X_OPAQUE_ID, expectedTag), new HashMap<>())); - String actualTag = requestLabelingService.getUserProvidedTag(); + String actualTag = RequestLabelingService.getUserProvidedTag(threadPool); assertEquals(expectedTag, actualTag); } @@ -63,7 +55,7 @@ public void testBasicApplyAllRules() { when(mockRule1.evaluate(threadContext, mockSearchRequest)).thenReturn(mockLabelMap); rules.add(mockRule1); requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); assertEquals(1, computedLabels.size()); assertEquals("value1", computedLabels.get("label1")); } @@ -77,7 +69,7 @@ public void testApplyAllRulesWithConflict() { rules.add(mockRule1); rules.add(mockRule2); requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); assertEquals(1, computedLabels.size()); assertEquals("value2", computedLabels.get("conflictingLabel")); } @@ -91,7 +83,7 @@ public void testApplyAllRulesWithoutConflict() { rules.add(mockRule1); rules.add(mockRule2); requestLabelingService.applyAllRules(mockSearchRequest); - Map computedLabels = threadContext.getTransient(RequestLabelingService.COMPUTED_LABELS); + Map computedLabels = threadContext.getTransient(RequestLabelingService.RULE_BASED_LABELS); assertEquals(2, computedLabels.size()); assertEquals("value1", computedLabels.get("label1")); assertEquals("value2", computedLabels.get("label2"));