diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index f18f75639..ca97abd7e 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -108,6 +108,8 @@ public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, Map public static final String CORRELATION_RULES_BASE_URI = PLUGINS_BASE_URI + "/correlation/rules"; public static final String CUSTOM_LOG_TYPE_URI = PLUGINS_BASE_URI + "/logtype"; + + public static final String CORRELATIONS_ALERTS_BASE_URI = PLUGINS_BASE_URI + "/correlationAlerts"; public static final String JOB_INDEX_NAME = ".opensearch-sap--job"; public static final Map TIF_JOB_INDEX_SETTING = Map.of(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1, IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-all", IndexMetadata.SETTING_INDEX_HIDDEN, true); @@ -215,7 +217,8 @@ public List getRestHandlers(Settings settings, new RestSearchCorrelationRuleAction(), new RestIndexCustomLogTypeAction(), new RestSearchCustomLogTypeAction(), - new RestDeleteCustomLogTypeAction() + new RestDeleteCustomLogTypeAction(), + new RestGetCorrelationsAlertsAction() ); } @@ -336,7 +339,8 @@ public List> getSettings() { new ActionHandler<>(IndexCustomLogTypeAction.INSTANCE, TransportIndexCustomLogTypeAction.class), new ActionHandler<>(SearchCustomLogTypeAction.INSTANCE, TransportSearchCustomLogTypeAction.class), new ActionHandler<>(DeleteCustomLogTypeAction.INSTANCE, TransportDeleteCustomLogTypeAction.class), - new ActionHandler<>(PutTIFJobAction.INSTANCE, TransportPutTIFJobAction.class) + new ActionHandler<>(PutTIFJobAction.INSTANCE, TransportPutTIFJobAction.class), + new ActionPlugin.ActionHandler<>(GetCorrelationAlertsAction.INSTANCE, TransportGetCorrelationAlertsAction.class) ); } diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsAction.java new file mode 100644 index 000000000..336dad080 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsAction.java @@ -0,0 +1,13 @@ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionType; + +public class GetCorrelationAlertsAction extends ActionType { + + public static final GetCorrelationAlertsAction INSTANCE = new GetCorrelationAlertsAction(); + public static final String NAME = "cluster:admin/opensearch/securityanalytics/correlationAlerts/get"; + + public GetCorrelationAlertsAction() { + super(NAME, GetCorrelationAlertsResponse::new); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java new file mode 100644 index 000000000..77811f6d1 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java @@ -0,0 +1,108 @@ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.commons.alerting.model.Table; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.time.Instant; +import java.util.Locale; + +import static org.opensearch.action.ValidateActions.addValidationError; + +public class GetCorrelationAlertsRequest extends ActionRequest { + private String correlationRuleId; + private String correlationRuleName; + private Table table; + private String severityLevel; + private String alertState; + + private Instant startTime; + + private Instant endTime; + + public static final String CORRELATION_RULE_ID = "correlation_rule_id"; + + public GetCorrelationAlertsRequest( + String correlationRuleId, + String correlationRuleName, + Table table, + String severityLevel, + String alertState, + Instant startTime, + Instant endTime + ) { + super(); + this.correlationRuleId = correlationRuleId; + this.correlationRuleName = correlationRuleName; + this.table = table; + this.severityLevel = severityLevel; + this.alertState = alertState; + this.startTime = startTime; + this.endTime = endTime; + } + public GetCorrelationAlertsRequest(StreamInput sin) throws IOException { + this( + sin.readOptionalString(), + sin.readOptionalString(), + Table.readFrom(sin), + sin.readString(), + sin.readString(), + sin.readOptionalInstant(), + sin.readOptionalInstant() + ); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if ((correlationRuleId == null || correlationRuleId.length() == 0)) { + validationException = addValidationError(String.format(Locale.getDefault(), + "At least one of correlation rule id needs to be passed", CORRELATION_RULE_ID), + validationException); + } + return validationException; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(correlationRuleId); + out.writeOptionalString(correlationRuleName); + table.writeTo(out); + out.writeString(severityLevel); + out.writeString(alertState); + out.writeOptionalInstant(startTime); + out.writeOptionalInstant(endTime); + } + + public String getCorrelationRuleId() { + return correlationRuleId; + } + + public Table getTable() { + return table; + } + + public String getSeverityLevel() { + return severityLevel; + } + + public String getAlertState() { + return alertState; + } + + public String getCorrelationRuleName() { + return correlationRuleName; + } + + public Instant getStartTime() { + return startTime; + } + + public Instant getEndTime() { + return endTime; + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java new file mode 100644 index 000000000..52c4ebc96 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java @@ -0,0 +1,56 @@ +package org.opensearch.securityanalytics.action; + +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class GetCorrelationAlertsResponse extends ActionResponse implements ToXContentObject { + + private static final String CORRELATION_ALERTS_FIELD = "correlationAlerts"; + private static final String TOTAL_ALERTS_FIELD = "total_alerts"; + + private List alerts; + private Integer totalAlerts; + + public GetCorrelationAlertsResponse(List alerts, Integer totalAlerts) { + super(); + this.alerts = alerts; + this.totalAlerts = totalAlerts; + } + + public GetCorrelationAlertsResponse(StreamInput sin) throws IOException { + this( + Collections.unmodifiableList(sin.readList(CorrelationAlert::new)), + sin.readInt() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(this.alerts); + out.writeInt(this.totalAlerts); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject() + .field(CORRELATION_ALERTS_FIELD, alerts) + .field(TOTAL_ALERTS_FIELD, totalAlerts); + return builder.endObject(); + } + + public List getAlerts() { + return this.alerts; + } + + public Integer getTotalAlerts() { + return this.totalAlerts; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java index f7aeb4e4d..bc76176a0 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java @@ -14,6 +14,7 @@ import org.opensearch.common.lucene.uid.Versions; import org.opensearch.commons.alerting.model.ActionExecutionResult; import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.commons.alerting.model.Table; import org.opensearch.commons.authuser.User; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -29,6 +30,11 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortBuilder; +import org.opensearch.search.sort.SortBuilders; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.securityanalytics.action.GetCorrelationAlertsResponse; import org.opensearch.securityanalytics.util.CorrelationIndices; import java.io.IOException; import java.time.Instant; @@ -252,6 +258,47 @@ public static CorrelationAlert parse(XContentParser xcp, String id, long version actionExecutionResults ); } + public void getAlertsByRuleId(String ruleId, Table tableProp, ActionListener listener) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("correlation_rule_id", ruleId)); + + FieldSortBuilder sortBuilder = SortBuilders + .fieldSort(tableProp.getSortString()) + .order(SortOrder.fromString(tableProp.getSortOrder())); + if (!tableProp.getMissing().isEmpty()) { + sortBuilder.missing(tableProp.getMissing()); + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .version(true) + .seqNoAndPrimaryTerm(true) + .query(queryBuilder) + .sort(sortBuilder) + .size(tableProp.getSize()) + .from(tableProp.getStartIndex()); + + SearchRequest searchRequest = new SearchRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getTotalHits().equals(0)) { + listener.onResponse(new GetCorrelationAlertsResponse(Collections.emptyList(), 0)); + } else { + listener.onResponse(new GetCorrelationAlertsResponse( + parseCorrelationAlerts(searchResponse), + searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ? + (int) searchResponse.getHits().getTotalHits().value : 0) + ); + } + }, + e -> { + log.error("Search request to fetch correlation alerts failed", e); + listener.onFailure(e); + } + )); + } + } diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetCorrelationsAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetCorrelationsAlertsAction.java new file mode 100644 index 000000000..a371a562a --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetCorrelationsAlertsAction.java @@ -0,0 +1,92 @@ +package org.opensearch.securityanalytics.resthandler; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.model.Table; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.action.GetCorrelationAlertsAction; +import org.opensearch.securityanalytics.action.GetCorrelationAlertsRequest; + +import java.io.IOException; +import java.time.DateTimeException; +import java.time.Instant; +import java.util.List; + +import static java.util.Collections.singletonList; +import static org.opensearch.rest.RestRequest.Method.GET; + +public class RestGetCorrelationsAlertsAction extends BaseRestHandler { + + @Override + public String getName() { + return "get_correlation_alerts_action_sa"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + String correlationRuleId = request.param("correlation_rule_id", null); + String correlationRuleName = request.param("correlation_rule_name", null); + String severityLevel = request.param("severityLevel", "ALL"); + String alertState = request.param("alertState", "ALL"); + // Table params + String sortString = request.param("sortString", "start_time"); + String sortOrder = request.param("sortOrder", "asc"); + String missing = request.param("missing"); + int size = request.paramAsInt("size", 20); + int startIndex = request.paramAsInt("startIndex", 0); + String searchString = request.param("searchString", ""); + + Instant startTime = null; + String startTimeParam = request.param("startTime"); + if (startTimeParam != null && !startTimeParam.isEmpty()) { + try { + startTime = Instant.ofEpochMilli(Long.parseLong(startTimeParam)); + } catch (NumberFormatException | NullPointerException | DateTimeException e) { + startTime = Instant.now(); + } + } + + Instant endTime = null; + String endTimeParam = request.param("endTime"); + if (endTimeParam != null && !endTimeParam.isEmpty()) { + try { + endTime = Instant.ofEpochMilli(Long.parseLong(endTimeParam)); + } catch (NumberFormatException | NullPointerException | DateTimeException e) { + endTime = Instant.now(); + } + } + + Table table = new Table( + sortOrder, + sortString, + missing, + size, + startIndex, + searchString + ); + + GetCorrelationAlertsRequest req = new GetCorrelationAlertsRequest( + correlationRuleId, + correlationRuleName, + table, + severityLevel, + alertState, + startTime, + endTime + ); + + return channel -> client.execute( + GetCorrelationAlertsAction.INSTANCE, + req, + new RestToXContentListener<>(channel) + ); + } + + @Override + public List routes() { + return singletonList(new Route(GET, SecurityAnalyticsPlugin.CORRELATIONS_ALERTS_BASE_URI)); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java new file mode 100644 index 000000000..2408dbd84 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java @@ -0,0 +1,75 @@ +package org.opensearch.securityanalytics.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.securityanalytics.action.*; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class TransportGetCorrelationAlertsAction extends HandledTransportAction implements SecureTransportAction { + + private final NamedXContentRegistry xContentRegistry; + + private final ClusterService clusterService; + + private final Settings settings; + + private final ThreadPool threadPool; + + private final CorrelationAlertService correlationAlertService; + + private volatile Boolean filterByEnabled; + + private static final Logger log = LogManager.getLogger(TransportGetCorrelationAlertsAction.class); + + + @Inject + public TransportGetCorrelationAlertsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, GetCorrelationAlertsAction getCorrelationAlertsAction, ThreadPool threadPool, Settings settings, NamedXContentRegistry xContentRegistry, Client client) { + super(getCorrelationAlertsAction.NAME, transportService, actionFilters, GetCorrelationAlertsRequest::new); + this.xContentRegistry = xContentRegistry; + this.correlationAlertService = new CorrelationAlertService(client, xContentRegistry); + this.clusterService = clusterService; + this.threadPool = threadPool; + this.settings = settings; + this.filterByEnabled = SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES.get(this.settings); + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES, this::setFilterByEnabled); + } + + @Override + protected void doExecute(Task task, GetCorrelationAlertsRequest request, ActionListener actionListener) { + + User user = readUserFromThreadContext(this.threadPool); + + String validateBackendRoleMessage = validateUserBackendRoles(user, this.filterByEnabled); + if (!"".equals(validateBackendRoleMessage)) { + actionListener.onFailure(new OpenSearchStatusException("Do not have permissions to resource", RestStatus.FORBIDDEN)); + return; + } + + if (request.getCorrelationRuleId() != null) { + correlationAlertService.getAlertsByRuleId( + request.getCorrelationRuleId(), + request.getTable(), + actionListener + ); + } + } + + private void setFilterByEnabled(boolean filterByEnabled) { + this.filterByEnabled = filterByEnabled; + } +} \ No newline at end of file