From ff0eb398ec7b68b8d150c5283a92f80b8cef33c4 Mon Sep 17 00:00:00 2001 From: Chenyang Ji Date: Tue, 4 Jun 2024 14:03:28 -0700 Subject: [PATCH] improve the overall logic and fix several bugs based on comments Signed-off-by: Chenyang Ji --- .../core/listener/QueryInsightsListener.java | 2 +- .../plugin/insights/rules/model/Attribute.java | 8 ++------ .../search/SearchRequestOperationsListener.java | 6 +++--- .../search/builder/SearchSourceBuilder.java | 8 -------- .../search/labels/RuleBasedLabelingService.java | 6 ++++++ .../search/labels/SearchRequestLabelingListener.java | 9 --------- .../labels/rules/DefaultUserInfoLabelingRule.java | 12 ++++++------ .../labels/DefaultUserInfoLabelingRuleTests.java | 8 ++++++++ 8 files changed, 26 insertions(+), 33 deletions(-) diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java index 2f328c534a8f4..e2ecb76591c5b 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/core/listener/QueryInsightsListener.java @@ -139,7 +139,7 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo attributes.put(Attribute.TOTAL_SHARDS, context.getNumShards()); attributes.put(Attribute.INDICES, request.indices()); attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap()); - attributes.put(Attribute.USER_NAME, request.source().labels().get(DefaultUserInfoLabelingRule.USER_NAME)); + attributes.put(Attribute.LABELS, request.source().labels()); SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes); queryInsightsService.addRecord(record); } catch (Exception e) { diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java index dc000a49e5d36..bd8948305f06d 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java @@ -45,13 +45,9 @@ public enum Attribute { */ NODE_ID, /** - * User associated with this request + * Custom labels */ - USER_NAME, - /** - * Custom tenant tags - */ - CUSTOMIZED_TAG; + LABELS; /** * Read an Attribute from a StreamInput diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java b/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java index 53efade174502..b944572cef122 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java @@ -41,11 +41,11 @@ protected SearchRequestOperationsListener(final boolean enabled) { this.enabled = enabled; } - protected abstract void onPhaseStart(SearchPhaseContext context); + protected void onPhaseStart(SearchPhaseContext context) {}; - protected abstract void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext); + protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {}; - protected abstract void onPhaseFailure(SearchPhaseContext context, Throwable cause); + protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {}; protected void onRequestStart(SearchRequestContext searchRequestContext) {} diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index c75f39b3d8bef..584a1ec7afd1a 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -1141,14 +1141,6 @@ public Map labels() { return labels; } - /** - * Define labels within this search request. - */ - public SearchSourceBuilder labels(Map labels) { - this.labels = labels; - return this; - } - /** * Add labels within this search request. */ diff --git a/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java b/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java index 942faebea2b12..cee801560a72c 100644 --- a/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java +++ b/server/src/main/java/org/opensearch/search/labels/RuleBasedLabelingService.java @@ -47,6 +47,12 @@ public void applyAllRules(final ThreadContext threadContext, final SearchRequest .map(rule -> rule.evaluate(threadContext, searchRequest)) .flatMap(m -> m.entrySet().stream()) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + // Handling potential spoofing by checking if any conflicts exist between user-supplied labels and the computed labels + for (String key : searchRequest.source().labels().keySet()) { + if (labels.containsKey(key)) { + throw new IllegalArgumentException("Unexpected label found: " + key); + } + } searchRequest.source().addLabels(labels); } } diff --git a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java index cb00fcba6accd..fa6ed0f04880c 100644 --- a/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java +++ b/server/src/main/java/org/opensearch/search/labels/SearchRequestLabelingListener.java @@ -27,15 +27,6 @@ public SearchRequestLabelingListener(final ThreadPool threadPool, final RuleBase this.ruleBasedLabelingService = ruleBasedLabelingService; } - @Override - protected void onPhaseStart(SearchPhaseContext context) {} - - @Override - protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {} - - @Override - protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {} - @Override public void onRequestStart(SearchRequestContext searchRequestContext) { // add tags to search request diff --git a/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java b/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java index 079f377439292..63fd95ea1d855 100644 --- a/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java +++ b/server/src/main/java/org/opensearch/search/labels/rules/DefaultUserInfoLabelingRule.java @@ -17,7 +17,7 @@ import java.util.Map; /** - * Rules to get user info labels, specifically, the info is injected by the security plugin. + * Rules to get user info labels, specifically, the info that is injected by the security plugin. */ public class DefaultUserInfoLabelingRule implements Rule { /** @@ -56,16 +56,16 @@ private Map getUserInfoFromThreadContext(ThreadContext threadCon if (threadContext == null) { return userInfoMap; } - Object userInfoObj = threadContext.getTransient(REQUEST_HEADER_USER_INFO); - if (userInfoObj == null) { - return userInfoMap; - } - String userInfoStr = userInfoObj.toString(); Object remoteAddressObj = threadContext.getTransient(REQUEST_HEADER_REMOTE_ADDRESS); if (remoteAddressObj != null) { userInfoMap.put(REMOTE_ADDRESS, remoteAddressObj.toString()); } + Object userInfoObj = threadContext.getTransient(REQUEST_HEADER_USER_INFO); + if (userInfoObj == null) { + return userInfoMap; + } + String userInfoStr = userInfoObj.toString(); String[] userInfo = userInfoStr.split("\\|"); if ((userInfo.length == 0) || (Strings.isNullOrEmpty(userInfo[0]))) { return userInfoMap; diff --git a/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java b/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java index cbb91332760e4..dd220eae4f5a7 100644 --- a/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java +++ b/server/src/test/java/org/opensearch/search/labels/DefaultUserInfoLabelingRuleTests.java @@ -44,6 +44,14 @@ public void testGetUserInfoFromThreadContext() { assertEquals(expectedUserInfoMap, actualUserInfoMap); } + public void testGetPartialInfoFromThreadContext() { + threadContext.putTransient("_opendistro_security_remote_address", "127.0.0.1"); + Map expectedUserInfoMap = new HashMap<>(); + expectedUserInfoMap.put(DefaultUserInfoLabelingRule.REMOTE_ADDRESS, "127.0.0.1"); + Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); + assertEquals(expectedUserInfoMap, actualUserInfoMap); + } + public void testGetUserInfoFromThreadContext_EmptyUserInfo() { Map actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest); assertTrue(actualUserInfoMap.isEmpty());