Skip to content

Commit

Permalink
addressing the comments
Browse files Browse the repository at this point in the history
Signed-off-by: Riya Saxena <[email protected]>
  • Loading branch information
riysaxen-amzn committed Jun 6, 2024
1 parent 29ba163 commit 2a4cf26
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.alerting.action.PublishFindingsRequest;
import org.opensearch.commons.alerting.model.Finding;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
Expand Down Expand Up @@ -80,9 +81,11 @@ public class JoinEngine {

private static final Logger log = LogManager.getLogger(JoinEngine.class);

private final User user;

public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry,
long corrTimeWindow, TimeValue indexTimeout, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction,
LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService) {
LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService, User user) {
this.client = client;
this.request = request;
this.xContentRegistry = xContentRegistry;
Expand All @@ -93,6 +96,7 @@ public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRe
this.enableAutoCorrelations = enableAutoCorrelations;
this.correlationAlertService = correlationAlertService;
this.notificationService = notificationService;
this.user = user;
}

public void onSearchDetectorResponse(Detector detector, Finding finding) {
Expand Down Expand Up @@ -555,7 +559,7 @@ private void getCorrelatedFindings(String detectorType, Map<String, List<String>

if (!correlatedFindings.isEmpty()) {
CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService, notificationService);
correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout);
correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout, user);
correlationRuleScheduler.shutdown();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.common.lucene.uid.Versions;
import org.opensearch.commons.alerting.model.ActionExecutionResult;
import org.opensearch.commons.alerting.model.Alert;
import org.opensearch.commons.authuser.User;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentFactory;
Expand All @@ -19,6 +23,7 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
Expand All @@ -37,6 +42,24 @@ public class CorrelationAlertService {
private final NamedXContentRegistry xContentRegistry;
private final Client client;

protected static final String CORRELATED_FINDING_IDS = "correlated_finding_ids";
protected static final String CORRELATION_RULE_ID = "correlation_rule_id";
protected static final String CORRELATION_RULE_NAME = "correlation_rule_name";
protected static final String ALERT_ID_FIELD = "id";
protected static final String SCHEMA_VERSION_FIELD = "schema_version";
protected static final String ALERT_VERSION_FIELD = "version";
protected static final String USER_FIELD = "user";
protected static final String TRIGGER_NAME_FIELD = "trigger_name";
protected static final String STATE_FIELD = "state";
protected static final String START_TIME_FIELD = "start_time";
protected static final String END_TIME_FIELD = "end_time";
protected static final String ACKNOWLEDGED_TIME_FIELD = "acknowledged_time";
protected static final String ERROR_MESSAGE_FIELD = "error_message";
protected static final String SEVERITY_FIELD = "severity";
protected static final String ACTION_EXECUTION_RESULTS_FIELD = "action_execution_results";
protected static final String NO_ID = "";
protected static final long NO_VERSION = Versions.NOT_FOUND;

public CorrelationAlertService(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client;
this.xContentRegistry = xContentRegistry;
Expand Down Expand Up @@ -72,7 +95,7 @@ public void getActiveAlerts(String ruleId, long currentTime, ActionListener<Corr
listener.onResponse(new CorrelationAlertsList(Collections.emptyList(), 0));
} else {
listener.onResponse(new CorrelationAlertsList(
Collections.emptyList(),
parseCorrelationAlerts(searchResponse),
searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ?
(int) searchResponse.getHits().getTotalHits().value : 0)
);
Expand Down Expand Up @@ -125,12 +148,112 @@ public List<CorrelationAlert> parseCorrelationAlerts(final SearchResponse respon
hit.getSourceAsString()
);
xcp.nextToken();
CorrelationAlert correlationAlert = CorrelationAlert.parse(xcp, hit.getId(), hit.getVersion());
CorrelationAlert correlationAlert = parse(xcp, hit.getId(), hit.getVersion());
alerts.add(correlationAlert);
}
return alerts;
}

// logic will be moved to common-utils, once the parsing logic in common-utils is fixed
public static CorrelationAlert parse(XContentParser xcp, String id, long version) throws IOException {
// Parse additional CorrelationAlert-specific fields
List<String> correlatedFindingIds = new ArrayList<>();
String correlationRuleId = null;
String correlationRuleName = null;
User user = null;
int schemaVersion = 0;
String triggerName = null;
Alert.State state = null;
String errorMessage = null;
String severity = null;
List<ActionExecutionResult> actionExecutionResults = new ArrayList<>();
Instant startTime = null;
Instant endTime = null;
Instant acknowledgedTime = null;

while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = xcp.currentName();
xcp.nextToken();
switch (fieldName) {
case CORRELATED_FINDING_IDS:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
correlatedFindingIds.add(xcp.text());
}
break;
case CORRELATION_RULE_ID:
correlationRuleId = xcp.text();
break;
case CORRELATION_RULE_NAME:
correlationRuleName = xcp.text();
break;
case USER_FIELD:
user = (xcp.currentToken() == XContentParser.Token.VALUE_NULL) ? null : User.parse(xcp);
break;
case ALERT_ID_FIELD:
id = xcp.text();
break;
case ALERT_VERSION_FIELD:
version = xcp.longValue();
break;
case SCHEMA_VERSION_FIELD:
schemaVersion = xcp.intValue();
break;
case TRIGGER_NAME_FIELD:
triggerName = xcp.text();
break;
case STATE_FIELD:
state = Alert.State.valueOf(xcp.text());
break;
case ERROR_MESSAGE_FIELD:
errorMessage = xcp.textOrNull();
break;
case SEVERITY_FIELD:
severity = xcp.text();
break;
case ACTION_EXECUTION_RESULTS_FIELD:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
actionExecutionResults.add(ActionExecutionResult.parse(xcp));
}
break;
case START_TIME_FIELD:
startTime = Instant.parse(xcp.text());
break;
case END_TIME_FIELD:
endTime = Instant.parse(xcp.text());
break;
case ACKNOWLEDGED_TIME_FIELD:
if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) {
acknowledgedTime = null;
} else {
acknowledgedTime = Instant.parse(xcp.text());
}
break;
}
}

// Create and return CorrelationAlert object
return new CorrelationAlert(
correlatedFindingIds,
correlationRuleId,
correlationRuleName,
id,
version,
schemaVersion,
user,
triggerName,
state,
startTime,
endTime,
acknowledgedTime,
errorMessage,
severity,
actionExecutionResults
);
}
}




Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.commons.alerting.model.Alert;
import org.opensearch.commons.alerting.model.CorrelationAlert;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.securityanalytics.model.CorrelationQuery;
import org.opensearch.securityanalytics.model.CorrelationRule;
import org.opensearch.securityanalytics.model.CorrelationRuleTrigger;
import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService;
import org.opensearch.securityanalytics.correlation.alert.notifications.CorrelationAlertContext;
import org.opensearch.client.node.NodeClient;
import org.opensearch.commons.alerting.model.action.Action;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.securityanalytics.util.SecurityAnalyticsException;
Expand All @@ -24,7 +24,6 @@
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.opensearch.script.ScriptService;

public class CorrelationRuleScheduler {

Expand All @@ -33,17 +32,15 @@ public class CorrelationRuleScheduler {
private final CorrelationAlertService correlationAlertService;
private final NotificationService notificationService;
private final ExecutorService executorService;
private static ScriptService scriptService;

public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService, NotificationService notificationService) {
this.client = client;
this.scriptService = scriptService;
this.correlationAlertService = correlationAlertService;
this.notificationService = notificationService;
this.executorService = Executors.newCachedThreadPool();
}

public void schedule(List<CorrelationRule> correlationRules, Map<String, List<String>> correlatedFindings, String sourceFinding, TimeValue indexTimeout) {
public void schedule(List<CorrelationRule> correlationRules, Map<String, List<String>> correlatedFindings, String sourceFinding, TimeValue indexTimeout, User user) {
for (CorrelationRule rule : correlationRules) {
CorrelationRuleTrigger trigger = rule.getCorrelationTrigger();
if (trigger != null) {
Expand All @@ -54,7 +51,7 @@ public void schedule(List<CorrelationRule> correlationRules, Map<String, List<St
findingIds.addAll(categoryFindingIds);
}
}
scheduleRule(rule, findingIds, indexTimeout, sourceFinding);
scheduleRule(rule, findingIds, indexTimeout, sourceFinding, user);
}
}
}
Expand All @@ -63,10 +60,10 @@ public void shutdown() {
executorService.shutdown();
}

private void scheduleRule(CorrelationRule correlationRule, List<String> findingIds, TimeValue indexTimeout, String sourceFindingId) {
private void scheduleRule(CorrelationRule correlationRule, List<String> findingIds, TimeValue indexTimeout, String sourceFindingId, User user) {
long startTime = Instant.now().toEpochMilli();
long endTime = startTime + correlationRule.getCorrTimeWindow();
RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId);
RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId, user);
executorService.submit(ruleTask);
}

Expand All @@ -79,8 +76,9 @@ private class RuleTask implements Runnable {
private final NotificationService notificationService;
private final TimeValue indexTimeout;
private final String sourceFindingId;
private final User user;

public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId) {
public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId, User user) {
this.correlationRule = correlationRule;
this.correlatedFindingIds = correlatedFindingIds;
this.startTime = startTime;
Expand All @@ -89,6 +87,7 @@ public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingI
this.notificationService = notificationService;
this.indexTimeout = indexTimeout;
this.sourceFindingId = sourceFindingId;
this.user = user;
}

@Override
Expand All @@ -103,13 +102,14 @@ public void onResponse(CorrelationAlertsList correlationAlertsList) {
addCorrelationAlertIntoIndex();
List<Action> actions = correlationRule.getCorrelationTrigger().getActions();
for (Action action : actions) {
String configId = action.getDestinationId();
CorrelationAlertContext ctx = new CorrelationAlertContext(correlatedFindingIds, correlationRule.getName(), correlationRule.getCorrTimeWindow(), sourceFindingId);
String transfomedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate());
String transformedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate());
String transformedMessage = notificationService.compileTemplate(ctx, action.getMessageTemplate());
try {
notificationService.sendNotification(action.getDestinationId(), correlationRule.getCorrelationTrigger().getSeverity(), transfomedSubject, transformedMessage);
notificationService.sendNotification(configId, correlationRule.getCorrelationTrigger().getSeverity(), transformedSubject, transformedMessage);
} catch (Exception e) {
log.error("Failed while sending a notification: " + e.toString());
log.error("Failed while sending a notification with " + configId + "for correlationRule id " + correlationRule.getId(), e);
new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e);
}

Expand Down Expand Up @@ -142,7 +142,7 @@ private void addCorrelationAlertIntoIndex() {
UUID.randomUUID().toString(),
1L,
1,
null,
user,
correlationRule.getCorrelationTrigger().getName(),
Alert.State.ACTIVE,
Instant.ofEpochMilli(startTime),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.commons.alerting.action.PublishFindingsRequest;
import org.opensearch.commons.alerting.action.SubscribeFindingsResponse;
import org.opensearch.commons.alerting.action.AlertingActions;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -140,7 +141,7 @@ public TransportCorrelateFindingAction(TransportService transportService,
protected void doExecute(Task task, ActionRequest request, ActionListener<SubscribeFindingsResponse> actionListener) {
try {
PublishFindingsRequest transformedRequest = transformRequest(request);
AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener);
AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, readUserFromThreadContext(this.threadPool), actionListener);

if (!this.correlationIndices.correlationIndexExists()) {
try {
Expand Down Expand Up @@ -213,14 +214,12 @@ public class AsyncCorrelateFindingAction {
private final AtomicBoolean counter = new AtomicBoolean();
private final Task task;

AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, ActionListener<SubscribeFindingsResponse> listener) {
AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, User user, ActionListener<SubscribeFindingsResponse> listener) {
this.task = task;
this.request = request;
this.listener = listener;

this.response =new AtomicReference<>();

this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService, notificationService);
this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService, notificationService, user);
this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this);
}

Expand Down

0 comments on commit 2a4cf26

Please sign in to comment.