Skip to content

Commit

Permalink
improve the overall logic and fix several bugs based on comments
Browse files Browse the repository at this point in the history
Signed-off-by: Chenyang Ji <[email protected]>
  • Loading branch information
ansjcy committed Jun 4, 2024
1 parent deab414 commit ff0eb39
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public void onRequestEnd(final SearchPhaseContext context, final SearchRequestCo
attributes.put(Attribute.TOTAL_SHARDS, context.getNumShards());
attributes.put(Attribute.INDICES, request.indices());
attributes.put(Attribute.PHASE_LATENCY_MAP, searchRequestContext.phaseTookMap());
attributes.put(Attribute.USER_NAME, request.source().labels().get(DefaultUserInfoLabelingRule.USER_NAME));
attributes.put(Attribute.LABELS, request.source().labels());
SearchQueryRecord record = new SearchQueryRecord(request.getOrCreateAbsoluteStartMillis(), measurements, attributes);
queryInsightsService.addRecord(record);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,9 @@ public enum Attribute {
*/
NODE_ID,
/**
* User associated with this request
* Custom labels
*/
USER_NAME,
/**
* Custom tenant tags
*/
CUSTOMIZED_TAG;
LABELS;

/**
* Read an Attribute from a StreamInput
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ protected SearchRequestOperationsListener(final boolean enabled) {
this.enabled = enabled;
}

protected abstract void onPhaseStart(SearchPhaseContext context);
protected void onPhaseStart(SearchPhaseContext context) {};

protected abstract void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext);
protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {};

protected abstract void onPhaseFailure(SearchPhaseContext context, Throwable cause);
protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {};

protected void onRequestStart(SearchRequestContext searchRequestContext) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1141,14 +1141,6 @@ public Map<String, Object> labels() {
return labels;
}

/**
* Define labels within this search request.
*/
public SearchSourceBuilder labels(Map<String, Object> labels) {
this.labels = labels;
return this;
}

/**
* Add labels within this search request.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ public void applyAllRules(final ThreadContext threadContext, final SearchRequest
.map(rule -> rule.evaluate(threadContext, searchRequest))
.flatMap(m -> m.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
// Handling potential spoofing by checking if any conflicts exist between user-supplied labels and the computed labels
for (String key : searchRequest.source().labels().keySet()) {
if (labels.containsKey(key)) {
throw new IllegalArgumentException("Unexpected label found: " + key);
}
}
searchRequest.source().addLabels(labels);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,6 @@ public SearchRequestLabelingListener(final ThreadPool threadPool, final RuleBase
this.ruleBasedLabelingService = ruleBasedLabelingService;
}

@Override
protected void onPhaseStart(SearchPhaseContext context) {}

@Override
protected void onPhaseEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {}

@Override
protected void onPhaseFailure(SearchPhaseContext context, Throwable cause) {}

@Override
public void onRequestStart(SearchRequestContext searchRequestContext) {
// add tags to search request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import java.util.Map;

/**
* Rules to get user info labels, specifically, the info is injected by the security plugin.
* Rules to get user info labels, specifically, the info that is injected by the security plugin.
*/
public class DefaultUserInfoLabelingRule implements Rule {
/**
Expand Down Expand Up @@ -56,16 +56,16 @@ private Map<String, Object> getUserInfoFromThreadContext(ThreadContext threadCon
if (threadContext == null) {
return userInfoMap;
}
Object userInfoObj = threadContext.getTransient(REQUEST_HEADER_USER_INFO);
if (userInfoObj == null) {
return userInfoMap;
}
String userInfoStr = userInfoObj.toString();
Object remoteAddressObj = threadContext.getTransient(REQUEST_HEADER_REMOTE_ADDRESS);
if (remoteAddressObj != null) {
userInfoMap.put(REMOTE_ADDRESS, remoteAddressObj.toString());
}

Object userInfoObj = threadContext.getTransient(REQUEST_HEADER_USER_INFO);
if (userInfoObj == null) {
return userInfoMap;
}
String userInfoStr = userInfoObj.toString();
String[] userInfo = userInfoStr.split("\\|");
if ((userInfo.length == 0) || (Strings.isNullOrEmpty(userInfo[0]))) {
return userInfoMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ public void testGetUserInfoFromThreadContext() {
assertEquals(expectedUserInfoMap, actualUserInfoMap);
}

public void testGetPartialInfoFromThreadContext() {
threadContext.putTransient("_opendistro_security_remote_address", "127.0.0.1");
Map<String, Object> expectedUserInfoMap = new HashMap<>();
expectedUserInfoMap.put(DefaultUserInfoLabelingRule.REMOTE_ADDRESS, "127.0.0.1");
Map<String, Object> actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest);
assertEquals(expectedUserInfoMap, actualUserInfoMap);
}

public void testGetUserInfoFromThreadContext_EmptyUserInfo() {
Map<String, Object> actualUserInfoMap = defaultUserInfoLabelingRule.evaluate(threadContext, searchRequest);
assertTrue(actualUserInfoMap.isEmpty());
Expand Down

0 comments on commit ff0eb39

Please sign in to comment.