Skip to content

Commit

Permalink
working code
Browse files Browse the repository at this point in the history
Signed-off-by: Riya Saxena <[email protected]>
  • Loading branch information
riysaxen-amzn committed May 24, 2024
1 parent a890e73 commit 425831d
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 149 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.routing.Preference;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.commons.alerting.model.DocLevelQuery;
import org.opensearch.commons.notifications.model.ChannelMessage;
import org.opensearch.commons.notifications.model.EventSource;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.search.MultiSearchRequest;
import org.opensearch.action.search.MultiSearchResponse;
Expand All @@ -35,16 +32,17 @@
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig;
import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService;
import org.opensearch.securityanalytics.correlation.alert.CorrelationRuleScheduler;
import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService;
import org.opensearch.securityanalytics.logtype.LogTypeService;
import org.opensearch.securityanalytics.model.CorrelationQuery;
import org.opensearch.securityanalytics.model.CorrelationRule;
import org.opensearch.securityanalytics.model.CorrelationRuleTrigger;
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction;
import org.opensearch.securityanalytics.util.AutoCorrelationsRepo;
import org.opensearch.securityanalytics.util.NotificationApiUtils;
import org.opensearch.securityanalytics.util.NotificationApiHelper;
import org.opensearch.commons.alerting.model.action.Action;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -353,10 +351,9 @@ private void getValidDocuments(String detectorType, List<String> indices, List<C
}
categoryToQueriesMap.put(query.getCategory(), correlationQueries);
}

}
searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap,
filteredCorrelationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()),
filteredCorrelationRules.stream().map(it -> it.correlationRule).collect(Collectors.toList()),
autoCorrelations
);
}, this::onFailure));
Expand All @@ -369,7 +366,7 @@ private void getValidDocuments(String detectorType, List<String> indices, List<C
* this method searches for parent findings given the log category & correlation time window & collects all related docs
* for them.
*/
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, Map<String, Long> categoryToTimeWindowMap, List<FilteredCorrelationRule> correlationRules, Map<String, List<String>> autoCorrelations) {
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, Map<String, Long> categoryToTimeWindowMap, List<CorrelationRule> correlationRules, Map<String, List<String>> autoCorrelations) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<Pair<String, List<CorrelationQuery>>> categoryToQueriesPairs = new ArrayList<>();
Expand Down Expand Up @@ -425,14 +422,14 @@ private void searchFindingsByTimestamp(String detectorType, Map<String, List<Cor
searchDocsWithFilterKeys(detectorType, relatedDocsMap, categoryToTimeWindowMap, correlationRules, autoCorrelations);
}, this::onFailure));
} else {
getTimestampFeature(detectorType, correlationRules, autoCorrelations);
getTimestampFeature(detectorType, correlationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()) , autoCorrelations);
}
}

/**
* Given the related docs from parent findings, this method filters only those related docs which match parent join criteria.
*/
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, Map<String, Long> categoryToTimeWindowMap, List<FilteredCorrelationRule> correlationRules, Map<String, List<String>> autoCorrelations) {
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, Map<String, Long> categoryToTimeWindowMap, List<CorrelationRule> correlationRules, Map<String, List<String>> autoCorrelations) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();

Expand Down Expand Up @@ -483,15 +480,15 @@ private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearch
getCorrelatedFindings(detectorType, filteredRelatedDocIds, categoryToTimeWindowMap, correlationRules, autoCorrelations);
}, this::onFailure));
} else {
getTimestampFeature(detectorType, correlationRules, autoCorrelations);
getTimestampFeature(detectorType, correlationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()), autoCorrelations);
}
}

/**
* Given the filtered related docs of the parent findings, this method gets the actual filtered parent findings for
* the finding to be correlated.
*/
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, Map<String, Long> categoryToTimeWindowMap, List<FilteredCorrelationRule> correlationRules, Map<String, List<String>> autoCorrelations) {
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, Map<String, Long> categoryToTimeWindowMap, List<CorrelationRule> correlationRules, Map<String, List<String>> autoCorrelations) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();
Expand Down Expand Up @@ -540,38 +537,20 @@ private void getCorrelatedFindings(String detectorType, Map<String, List<String>
for (SearchHit hit : hits) {
findings.add(hit.getId());
}
for (FilteredCorrelationRule corrRule: correlationRules) {
List<CorrelationRuleTrigger> triggers = corrRule.correlationRule.getCorrelationTriggers();
List<String> list=new ArrayList<>();

log.info("triggers are: {}", triggers.toString());
for(CorrelationRuleTrigger trigger: triggers) {
String severity = trigger.getSeverity();
List<Action> actions = trigger.getActions();
log.info("trigger Actions are: {}", actions.toString());

for(Action action: actions) {
if (action.getDestinationId() != null && !action.getDestinationId().isEmpty()) {
log.info("Destination id is: {}", action.getDestinationId());
list.add(action.getDestinationId());
try {
log.info("Reaching here and calling: {}", action.getDestinationId());
NotificationApiUtils.sendNotification((NodeClient) client, action.getDestinationId(), "", list);
} catch (IOException e) {
log.info("Exception is: {}", e);
}
}
}
}

}

if (!findings.isEmpty()) {
correlatedFindings.put(categories.get(idx), findings);
}
++idx;
}

CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler();
correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId());
log.info("Source correlated findings: {}", request.getFinding().getId());
log.info("Get correlated findings: {}", correlatedFindings);
log.info("Source correlated findings: {}", request.getFinding().getId());
log.info("Index correlated findings: {}", idx);

for (Map.Entry<String, List<String>> autoCorrelation: autoCorrelations.entrySet()) {
if (correlatedFindings.containsKey(autoCorrelation.getKey())) {
Set<String> alreadyCorrelatedFindings = new HashSet<>(correlatedFindings.get(autoCorrelation.getKey()));
Expand All @@ -581,10 +560,10 @@ private void getCorrelatedFindings(String detectorType, Map<String, List<String>
correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue());
}
}
correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules);
correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()));
}, this::onFailure));
} else {
getTimestampFeature(detectorType, correlationRules, autoCorrelations);
getTimestampFeature(detectorType, correlationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()), autoCorrelations);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package org.opensearch.securityanalytics.correlation.alert;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService;
import org.opensearch.securityanalytics.model.CorrelationQuery;
import org.opensearch.securityanalytics.model.CorrelationRule;
import org.opensearch.securityanalytics.model.CorrelationRuleTrigger;

import java.time.Instant;
import java.util.*;
import java.util.concurrent.TimeUnit;

public class CorrelationRuleScheduler {

private static final Logger log = LogManager.getLogger(CorrelationRuleScheduler.class);

public void schedule(List<CorrelationRule> correlationRules, Map<String, List<String>> correlatedFindings, String sourceFinding) {
// Create a map of correlation rule to list of finding IDs
Map<CorrelationRule, List<String>> correlationRuleToFindingIds = new HashMap<>();
for (CorrelationRule rule : correlationRules) {
CorrelationRuleTrigger trigger = rule.getCorrelationTrigger();
if (trigger != null) {
List<String> findingIds = new ArrayList<>();
for (CorrelationQuery query : rule.getCorrelationQueries()) {
List<String> categoryFindingIds = correlatedFindings.get(query.getCategory());
if (categoryFindingIds != null) {
findingIds.addAll(categoryFindingIds);
}
}
correlationRuleToFindingIds.put(rule, findingIds);
// Simulate generating matched correlation rule IDs on rolling time window basis
scheduleRule(rule, findingIds);
}
}
}
public void scheduleRule(CorrelationRule correlationRule, List<String> findingIds) {
Timer timer = new Timer();
long startTime = Instant.now().toEpochMilli();
long endTime = startTime + TimeUnit.MINUTES.toMillis(correlationRule.getCorrTimeWindow()); // Assuming time window is based on ruleId
// timer.schedule(new RuleTask(this.correlationAlertService, this.notificationService, correlationRule, findingIds, startTime, endTime), 0, 60000); // Check every minute
}

static class RuleTask extends TimerTask {
private final CorrelationAlertService alertService;
private final NotificationService notificationService;
private final CorrelationRule correlationRule;
private final long startTime;
private final long endTime;
private final List<String> correlatedFindingIds;


public RuleTask(CorrelationAlertService alertService, NotificationService notificationService, CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime) {
this.alertService = alertService;
this.notificationService = notificationService;
this.startTime = startTime;
this.endTime = endTime;
this.correlatedFindingIds = correlatedFindingIds;
this.correlationRule = correlationRule;
}

@Override
public void run() {
long currentTime = Instant.now().toEpochMilli();
// if (currentTime >= startTime && currentTime <= endTime) { // Within time window
// try {
// List<String> activeAlertIds = alertService.getActiveAlertsList(correlationRule.getId(), startTime, endTime);
// if (activeAlertIds.isEmpty()) {
// Map<String, Object> correlationAlert = Map.of(
// "start_time", startTime,
// "end_time", endTime,
// "correlation_rule_id", correlationRule.getId(),
// "severity", correlationRule.getCorrelationTrigger().getSeverity()
// // add more fields;
// );
// alertService.indexAlert(correlationAlert);
// //notificationService.sendNotification(alert);
// } else {
// alertService.updateActiveAlerts(activeAlertIds);
// }
// } catch (IOException e) {
// throw new RuntimeException(e);
// }
// }
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.opensearch.securityanalytics.correlation.alert.notifications;

public class NotificationService {
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand All @@ -29,7 +30,7 @@ public class CorrelationRule implements Writeable, ToXContentObject {
public static final Long NO_VERSION = 1L;
private static final String CORRELATION_QUERIES = "correlate";
private static final String CORRELATION_TIME_WINDOW = "time_window";
private static final String TRIGGERS_FIELD = "triggers";
private static final String TRIGGER_FIELD = "trigger";

private String id;

Expand All @@ -41,19 +42,19 @@ public class CorrelationRule implements Writeable, ToXContentObject {

private Long corrTimeWindow;

private List<CorrelationRuleTrigger> triggers;
private CorrelationRuleTrigger trigger;

public CorrelationRule(String id, Long version, String name, List<CorrelationQuery> correlationQueries, Long corrTimeWindow, List<CorrelationRuleTrigger> triggers) {
public CorrelationRule(String id, Long version, String name, List<CorrelationQuery> correlationQueries, Long corrTimeWindow, CorrelationRuleTrigger trigger) {
this.id = id != null ? id : NO_ID;
this.version = version != null ? version : NO_VERSION;
this.name = name;
this.correlationQueries = correlationQueries;
this.corrTimeWindow = corrTimeWindow != null? corrTimeWindow: 300000L;
this.triggers = triggers;
this.trigger = trigger;
}

public CorrelationRule(StreamInput sin) throws IOException {
this(sin.readString(), sin.readLong(), sin.readString(), sin.readList(CorrelationQuery::readFrom), sin.readLong(), sin.readList(CorrelationRuleTrigger::readFrom));
this(sin.readString(), sin.readLong(), sin.readString(), sin.readList(CorrelationQuery::readFrom), sin.readLong(), sin.readBoolean() ? new CorrelationRuleTrigger(sin) : null);
}

@Override
Expand All @@ -64,11 +65,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

CorrelationQuery[] correlationQueries = new CorrelationQuery[] {};
correlationQueries = this.correlationQueries.toArray(correlationQueries);
CorrelationRuleTrigger[] correlationRuleTriggers = new CorrelationRuleTrigger[] {};
correlationRuleTriggers = this.triggers.toArray(correlationRuleTriggers);
builder.field(CORRELATION_QUERIES, correlationQueries);
builder.field(CORRELATION_TIME_WINDOW, corrTimeWindow);
builder.field(TRIGGERS_FIELD, correlationRuleTriggers);
builder.field(TRIGGER_FIELD, trigger);
return builder.endObject();
}

Expand All @@ -82,7 +81,8 @@ public void writeTo(StreamOutput out) throws IOException {
query.writeTo(out);
}

for (CorrelationRuleTrigger trigger : triggers) {
out.writeBoolean(trigger != null);
if (trigger != null) {
trigger.writeTo(out);
}
out.writeLong(corrTimeWindow);
Expand All @@ -99,8 +99,7 @@ public static CorrelationRule parse(XContentParser xcp, String id, Long version)
String name = null;
List<CorrelationQuery> correlationQueries = new ArrayList<>();
Long corrTimeWindow = null;
List<CorrelationRuleTrigger> triggers = new ArrayList<>();

CorrelationRuleTrigger trigger = null;
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = xcp.currentName();
Expand All @@ -120,18 +119,18 @@ public static CorrelationRule parse(XContentParser xcp, String id, Long version)
case CORRELATION_TIME_WINDOW:
corrTimeWindow = xcp.longValue();
break;
case TRIGGERS_FIELD:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
CorrelationRuleTrigger trigger = CorrelationRuleTrigger.parse(xcp);
triggers.add(trigger);
case TRIGGER_FIELD:
if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) {
trigger = null;
} else {
trigger = CorrelationRuleTrigger.parse(xcp);
}
break;
default:
xcp.skipChildren();
}
}
return new CorrelationRule(id, version, name, correlationQueries, corrTimeWindow, triggers);
return new CorrelationRule(id, version, name, correlationQueries, corrTimeWindow, trigger);
}

public static CorrelationRule readFrom(StreamInput sin) throws IOException {
Expand Down Expand Up @@ -170,8 +169,8 @@ public Long getCorrTimeWindow() {
return corrTimeWindow;
}

public List<CorrelationRuleTrigger> getCorrelationTriggers() {
return triggers;
public CorrelationRuleTrigger getCorrelationTrigger() {
return trigger;
}

@Override
Expand All @@ -183,7 +182,7 @@ public boolean equals(Object o) {
&& version.equals(that.version)
&& name.equals(that.name)
&& correlationQueries.equals(that.correlationQueries)
&& triggers.equals(that.triggers);
&& trigger.equals(that.trigger);
}

@Override
Expand Down
Loading

0 comments on commit 425831d

Please sign in to comment.