diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index cfff7da26..9284889e9 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -5,13 +5,15 @@ package org.opensearch.securityanalytics.correlation; import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.Triple; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.join.ScoreMode; import org.opensearch.OpenSearchStatusException; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.routing.Preference; import org.opensearch.commons.alerting.model.DocLevelQuery; +import org.opensearch.commons.notifications.model.ChannelMessage; +import org.opensearch.commons.notifications.model.EventSource; import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.MultiSearchResponse; @@ -36,10 +38,13 @@ import org.opensearch.securityanalytics.logtype.LogTypeService; import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; +import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction; import org.opensearch.securityanalytics.util.AutoCorrelationsRepo; - +import org.opensearch.securityanalytics.util.NotificationApiUtils; +import org.opensearch.securityanalytics.util.NotificationApiHelper; +import org.opensearch.commons.alerting.model.action.Action; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -331,7 +336,6 @@ public void onResponse(MultiSearchResponse items) { for (FilteredCorrelationRule rule: filteredCorrelationRules) { List queries = rule.correlationRule.getCorrelationQueries(); Long timeWindow = rule.correlationRule.getCorrTimeWindow(); - for (CorrelationQuery query: queries) { List correlationQueries; if (categoryToQueriesMap.containsKey(query.getCategory())) { @@ -367,8 +371,9 @@ public void onResponse(MultiSearchResponse items) { categoryToQueriesMap.put(query.getCategory(), correlationQueries); } } + searchFindingsByTimestamp(detectorType, categoryToQueriesMap, categoryToTimeWindowMap, - filteredCorrelationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()), + filteredCorrelationRules, autoCorrelations ); } @@ -391,7 +396,7 @@ public void onFailure(Exception e) { * this method searches for parent findings given the log category & correlation time window & collects all related docs * for them. */ - private void searchFindingsByTimestamp(String detectorType, Map> categoryToQueriesMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void searchFindingsByTimestamp(String detectorType, Map> categoryToQueriesMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List>> categoryToQueriesPairs = new ArrayList<>(); @@ -457,7 +462,8 @@ public void onFailure(Exception e) { if (!autoCorrelations.isEmpty()) { correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); + List correlationRuleIds = correlationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()); + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRuleIds); } } } @@ -465,7 +471,7 @@ public void onFailure(Exception e) { /** * Given the related docs from parent findings, this method filters only those related docs which match parent join criteria. */ - private void searchDocsWithFilterKeys(String detectorType, Map relatedDocsMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void searchDocsWithFilterKeys(String detectorType, Map relatedDocsMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List categories = new ArrayList<>(); @@ -526,7 +532,8 @@ public void onFailure(Exception e) { if (!autoCorrelations.isEmpty()) { correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); + List correlationRuleIds = correlationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()); + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRuleIds); } } } @@ -535,7 +542,7 @@ public void onFailure(Exception e) { * Given the filtered related docs of the parent findings, this method gets the actual filtered parent findings for * the finding to be correlated. */ - private void getCorrelatedFindings(String detectorType, Map> filteredRelatedDocIds, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void getCorrelatedFindings(String detectorType, Map> filteredRelatedDocIds, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List categories = new ArrayList<>(); @@ -591,6 +598,31 @@ public void onResponse(MultiSearchResponse items) { } ++idx; } + for (FilteredCorrelationRule corrRule: correlationRules) { + List triggers = corrRule.correlationRule.getCorrelationTriggers(); + List list=new ArrayList<>(); + + log.info("triggers are: {}", triggers.toString()); + for(CorrelationRuleTrigger trigger: triggers) { + String severity = trigger.getSeverity(); + List actions = trigger.getActions(); + log.info("trigger Actions are: {}", actions.toString()); + + for(Action action: actions) { + if (action.getDestinationId() != null && !action.getDestinationId().isEmpty()) { + log.info("Destination id is: {}", action.getDestinationId()); + list.add(action.getDestinationId()); + try { + log.info("Reaching here and calling: {}", action.getDestinationId()); + NotificationApiUtils.sendNotification((NodeClient) client, action.getDestinationId(), "", list); + } catch (IOException e) { + log.info("Exception is: {}", e); + } + } + } + } + + } for (Map.Entry> autoCorrelation: autoCorrelations.entrySet()) { if (correlatedFindings.containsKey(autoCorrelation.getKey())) { @@ -601,7 +633,8 @@ public void onResponse(MultiSearchResponse items) { correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue()); } } - correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules); + List correlationRuleIds = correlationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()); + correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRuleIds); } @Override @@ -613,7 +646,8 @@ public void onFailure(Exception e) { if (!autoCorrelations.isEmpty()) { correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of()); } else { - correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules); + List correlationRuleIds = correlationRules.stream().map(it -> it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()); + correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRuleIds); } } } diff --git a/src/main/java/org/opensearch/securityanalytics/model/CorrelationRule.java b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRule.java index b7f5a4f70..9f50594c9 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/CorrelationRule.java +++ b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRule.java @@ -29,6 +29,7 @@ public class CorrelationRule implements Writeable, ToXContentObject { public static final Long NO_VERSION = 1L; private static final String CORRELATION_QUERIES = "correlate"; private static final String CORRELATION_TIME_WINDOW = "time_window"; + private static final String TRIGGERS_FIELD = "triggers"; private String id; @@ -40,16 +41,19 @@ public class CorrelationRule implements Writeable, ToXContentObject { private Long corrTimeWindow; - public CorrelationRule(String id, Long version, String name, List correlationQueries, Long corrTimeWindow) { + private List triggers; + + public CorrelationRule(String id, Long version, String name, List correlationQueries, Long corrTimeWindow, List triggers) { this.id = id != null ? id : NO_ID; this.version = version != null ? version : NO_VERSION; this.name = name; this.correlationQueries = correlationQueries; this.corrTimeWindow = corrTimeWindow != null? corrTimeWindow: 300000L; + this.triggers = triggers; } public CorrelationRule(StreamInput sin) throws IOException { - this(sin.readString(), sin.readLong(), sin.readString(), sin.readList(CorrelationQuery::readFrom), sin.readLong()); + this(sin.readString(), sin.readLong(), sin.readString(), sin.readList(CorrelationQuery::readFrom), sin.readLong(), sin.readList(CorrelationRuleTrigger::readFrom)); } @Override @@ -60,8 +64,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws CorrelationQuery[] correlationQueries = new CorrelationQuery[] {}; correlationQueries = this.correlationQueries.toArray(correlationQueries); + CorrelationRuleTrigger[] correlationRuleTriggers = new CorrelationRuleTrigger[] {}; + correlationRuleTriggers = this.triggers.toArray(correlationRuleTriggers); builder.field(CORRELATION_QUERIES, correlationQueries); builder.field(CORRELATION_TIME_WINDOW, corrTimeWindow); + builder.field(TRIGGERS_FIELD, correlationRuleTriggers); return builder.endObject(); } @@ -74,6 +81,10 @@ public void writeTo(StreamOutput out) throws IOException { for (CorrelationQuery query : correlationQueries) { query.writeTo(out); } + + for (CorrelationRuleTrigger trigger : triggers) { + trigger.writeTo(out); + } out.writeLong(corrTimeWindow); } @@ -88,6 +99,7 @@ public static CorrelationRule parse(XContentParser xcp, String id, Long version) String name = null; List correlationQueries = new ArrayList<>(); Long corrTimeWindow = null; + List triggers = new ArrayList<>(); XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { @@ -108,11 +120,18 @@ public static CorrelationRule parse(XContentParser xcp, String id, Long version) case CORRELATION_TIME_WINDOW: corrTimeWindow = xcp.longValue(); break; + case TRIGGERS_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + CorrelationRuleTrigger trigger = CorrelationRuleTrigger.parse(xcp); + triggers.add(trigger); + } + break; default: xcp.skipChildren(); } } - return new CorrelationRule(id, version, name, correlationQueries, corrTimeWindow); + return new CorrelationRule(id, version, name, correlationQueries, corrTimeWindow, triggers); } public static CorrelationRule readFrom(StreamInput sin) throws IOException { @@ -151,6 +170,10 @@ public Long getCorrTimeWindow() { return corrTimeWindow; } + public List getCorrelationTriggers() { + return triggers; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -159,7 +182,8 @@ public boolean equals(Object o) { return id.equals(that.id) && version.equals(that.version) && name.equals(that.name) - && correlationQueries.equals(that.correlationQueries); + && correlationQueries.equals(that.correlationQueries) + && triggers.equals(that.triggers); } @Override diff --git a/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java new file mode 100644 index 000000000..dd3811546 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.model; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.UUIDs; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.commons.alerting.model.action.Action; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class CorrelationRuleTrigger implements Writeable, ToXContentObject { + + private static final Logger log = LogManager.getLogger(DetectorTrigger.class); + + private String id; + + private String name; + + private String severity; + + private List actions; + + private static final String ID_FIELD = "id"; + + private static final String SEVERITY_FIELD = "severity"; + private static final String ACTIONS_FIELD = "actions"; + + private static final String NAME_FIELD = "name"; + + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + CorrelationRuleTrigger.class, + new ParseField(ID_FIELD), + CorrelationRuleTrigger::parse + ); + + public CorrelationRuleTrigger(String id, + String name, + String severity, + List actions) { + this.id = id == null ? UUIDs.base64UUID() : id; + this.name = name; + this.severity = severity; + this.actions = actions; + } + + public CorrelationRuleTrigger(StreamInput sin) throws IOException { + this( + sin.readString(), + sin.readString(), + sin.readString(), + sin.readList(Action::readFrom) + ); + } + + public Map asTemplateArg() { + return Map.of( + ACTIONS_FIELD, actions.stream().map(Action::asTemplateArg) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeString(name); + out.writeString(severity); + out.writeCollection(actions); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + + Action[] actionArray = new Action[]{}; + actionArray = actions.toArray(actionArray); + + return builder.startObject() + .field(ID_FIELD, id) + .field(NAME_FIELD, name) + .field(SEVERITY_FIELD, severity) + .field(ACTIONS_FIELD, actionArray) + .endObject(); + } + + public static CorrelationRuleTrigger parse(XContentParser xcp) throws IOException { + String id = null; + String name = null; + String severity = null; + List actions = new ArrayList<>(); + + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + + switch (fieldName) { + case ID_FIELD: + id = xcp.text(); + break; + case NAME_FIELD: + name = xcp.text(); + break; + case SEVERITY_FIELD: + severity = xcp.text(); + break; + case ACTIONS_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + Action action = Action.parse(xcp); + actions.add(action); + } + break; + default: + xcp.skipChildren(); + } + } + return new CorrelationRuleTrigger(id, name, severity, actions); + } + + public static CorrelationRuleTrigger readFrom(StreamInput sin) throws IOException { + return new CorrelationRuleTrigger(sin); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CorrelationRuleTrigger that = (CorrelationRuleTrigger) o; + return Objects.equals(id, that.id) && Objects.equals(name, that.name) && Objects.equals(severity, that.severity) && Objects.equals(actions, that.actions); + } + + @Override + public int hashCode() { + return Objects.hash(id, name, severity, actions); + } + + public String getId() { + return id; + } + + public String getName() { + return name; + } + + public String getSeverity() { + return severity; + } + + public List getActions() { + List transformedActions = new ArrayList<>(); + + if (actions != null) { + for (Action action : actions) { + String subjectTemplate = action.getSubjectTemplate() != null ? action.getSubjectTemplate().getIdOrCode() : ""; + subjectTemplate = subjectTemplate.replace("{{ctx.detector", "{{ctx.monitor"); + + action.getMessageTemplate(); + String messageTemplate = action.getMessageTemplate().getIdOrCode(); + messageTemplate = messageTemplate.replace("{{ctx.detector", "{{ctx.monitor"); + + Action transformedAction = new Action(action.getName(), action.getDestinationId(), + new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, subjectTemplate, Collections.emptyMap()), + new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, messageTemplate, Collections.emptyMap()), + action.getThrottleEnabled(), action.getThrottle(), + action.getId(), action.getActionExecutionPolicy()); + + transformedActions.add(transformedAction); + } + } + return transformedActions; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchCorrelationRuleAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchCorrelationRuleAction.java index d027d26de..f0db3c6d6 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchCorrelationRuleAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportSearchCorrelationRuleAction.java @@ -19,7 +19,6 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.commons.notifications.action.SendNotificationRequest; import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; diff --git a/src/main/java/org/opensearch/securityanalytics/util/NotificationApiHelper.java b/src/main/java/org/opensearch/securityanalytics/util/NotificationApiHelper.java new file mode 100644 index 000000000..d1499a135 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/util/NotificationApiHelper.java @@ -0,0 +1,41 @@ +package org.opensearch.securityanalytics.util; + +import org.opensearch.commons.notifications.model.ChannelMessage; +import org.opensearch.commons.notifications.model.EventSource; +import org.opensearch.commons.notifications.model.SeverityType; + +import java.util.List; +/** + * Helper class for sending test notifications. + */ +public class NotificationApiHelper { + + public static ChannelMessage generateMessage(String configId) { + return new ChannelMessage( + getMessageTextDescription(configId), + getMessageHtmlDescription(configId), + null + ); + } + + public static EventSource generateEventSource(String configId, String severity, List tags) { + return new EventSource( + getMessageTitle(configId), + configId, + SeverityType.INFO, + tags + ); + } + + private static String getMessageTitle(String configId) { + return "Test Message Title-" + configId; // TODO: change as per spec + } + + private static String getMessageTextDescription(String configId) { + return "Test message content body for config id " + configId; // TODO: change as per spec + } + + private static String getMessageHtmlDescription(String configId) { + return "
Test Message

Test Message for config id " + configId + "

"; // TODO: change as per spec + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/util/NotificationApiUtils.java b/src/main/java/org/opensearch/securityanalytics/util/NotificationApiUtils.java new file mode 100644 index 000000000..618de8527 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/util/NotificationApiUtils.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.util; + +import com.google.protobuf.BoolValue; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.bulk.BackoffPolicy; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.notifications.NotificationsPluginInterface; +import org.opensearch.commons.notifications.action.*; +import org.opensearch.commons.notifications.model.ChannelMessage; +import org.opensearch.commons.notifications.model.EventSource; +import org.opensearch.commons.notifications.model.NotificationConfigInfo; +import org.opensearch.commons.notifications.model.SeverityType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.rest.RestStatus; + +import java.io.IOException; +import java.util.List; +import java.util.Set; + +public class NotificationApiUtils { + + private static final Logger logger = LogManager.getLogger(NotificationApiUtils.class); + /** + * Extension function for publishing a notification to a channel in the Notification plugin. + */ + public static void sendNotification(NodeClient client, String configId, String severity, List channelIds) throws IOException { + ChannelMessage message = NotificationApiHelper.generateMessage(configId); + NotificationsPluginInterface.INSTANCE.sendNotification(client, new EventSource(configId, configId, SeverityType.CRITICAL, channelIds), message, channelIds, new ActionListener() { + @Override + public void onResponse(SendNotificationResponse sendNotificationResponse) { + if(sendNotificationResponse.getStatus() == RestStatus.OK) { + logger.info("Successfully sent a notification, Notification Event: " + sendNotificationResponse.getNotificationEvent()); + } + else { + logger.error("Successfully sent a notification, Notification Event: " + sendNotificationResponse.getNotificationEvent()); + } + + } + @Override + public void onFailure(Exception e) { + logger.error("Failed while sending a notification: " + e.toString()); + new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e); + } + }); + } + + +} \ No newline at end of file