Skip to content

Commit

Permalink
getCorrelationAlerts API changes
Browse files Browse the repository at this point in the history
Signed-off-by: Riya Saxena <[email protected]>
  • Loading branch information
riysaxen-amzn committed Jun 6, 2024
1 parent e23e969 commit 42f2b46
Show file tree
Hide file tree
Showing 7 changed files with 397 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, Map
public static final String CORRELATION_RULES_BASE_URI = PLUGINS_BASE_URI + "/correlation/rules";

public static final String CUSTOM_LOG_TYPE_URI = PLUGINS_BASE_URI + "/logtype";

public static final String CORRELATIONS_ALERTS_BASE_URI = PLUGINS_BASE_URI + "/correlationAlerts";
public static final String JOB_INDEX_NAME = ".opensearch-sap--job";
public static final Map<String, Object> TIF_JOB_INDEX_SETTING = Map.of(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1, IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "0-all", IndexMetadata.SETTING_INDEX_HIDDEN, true);

Expand Down Expand Up @@ -215,7 +217,8 @@ public List<RestHandler> getRestHandlers(Settings settings,
new RestSearchCorrelationRuleAction(),
new RestIndexCustomLogTypeAction(),
new RestSearchCustomLogTypeAction(),
new RestDeleteCustomLogTypeAction()
new RestDeleteCustomLogTypeAction(),
new RestGetCorrelationsAlertsAction()
);
}

Expand Down Expand Up @@ -336,7 +339,8 @@ public List<Setting<?>> getSettings() {
new ActionHandler<>(IndexCustomLogTypeAction.INSTANCE, TransportIndexCustomLogTypeAction.class),
new ActionHandler<>(SearchCustomLogTypeAction.INSTANCE, TransportSearchCustomLogTypeAction.class),
new ActionHandler<>(DeleteCustomLogTypeAction.INSTANCE, TransportDeleteCustomLogTypeAction.class),
new ActionHandler<>(PutTIFJobAction.INSTANCE, TransportPutTIFJobAction.class)
new ActionHandler<>(PutTIFJobAction.INSTANCE, TransportPutTIFJobAction.class),
new ActionPlugin.ActionHandler<>(GetCorrelationAlertsAction.INSTANCE, TransportGetCorrelationAlertsAction.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.opensearch.securityanalytics.action;

import org.opensearch.action.ActionType;

public class GetCorrelationAlertsAction extends ActionType<GetCorrelationAlertsResponse> {

public static final GetCorrelationAlertsAction INSTANCE = new GetCorrelationAlertsAction();
public static final String NAME = "cluster:admin/opensearch/securityanalytics/correlationAlerts/get";

public GetCorrelationAlertsAction() {
super(NAME, GetCorrelationAlertsResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package org.opensearch.securityanalytics.action;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.commons.alerting.model.Table;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.IOException;
import java.time.Instant;
import java.util.Locale;

import static org.opensearch.action.ValidateActions.addValidationError;

public class GetCorrelationAlertsRequest extends ActionRequest {
private String correlationRuleId;
private String correlationRuleName;
private Table table;
private String severityLevel;
private String alertState;

private Instant startTime;

private Instant endTime;

public static final String CORRELATION_RULE_ID = "correlation_rule_id";

public GetCorrelationAlertsRequest(
String correlationRuleId,
String correlationRuleName,
Table table,
String severityLevel,
String alertState,
Instant startTime,
Instant endTime
) {
super();
this.correlationRuleId = correlationRuleId;
this.correlationRuleName = correlationRuleName;
this.table = table;
this.severityLevel = severityLevel;
this.alertState = alertState;
this.startTime = startTime;
this.endTime = endTime;
}
public GetCorrelationAlertsRequest(StreamInput sin) throws IOException {
this(
sin.readOptionalString(),
sin.readOptionalString(),
Table.readFrom(sin),
sin.readString(),
sin.readString(),
sin.readOptionalInstant(),
sin.readOptionalInstant()
);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException validationException = null;
if ((correlationRuleId == null || correlationRuleId.length() == 0)) {
validationException = addValidationError(String.format(Locale.getDefault(),
"At least one of correlation rule id needs to be passed", CORRELATION_RULE_ID),
validationException);
}
return validationException;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(correlationRuleId);
out.writeOptionalString(correlationRuleName);
table.writeTo(out);
out.writeString(severityLevel);
out.writeString(alertState);
out.writeOptionalInstant(startTime);
out.writeOptionalInstant(endTime);
}

public String getCorrelationRuleId() {
return correlationRuleId;
}

public Table getTable() {
return table;
}

public String getSeverityLevel() {
return severityLevel;
}

public String getAlertState() {
return alertState;
}

public String getCorrelationRuleName() {
return correlationRuleName;
}

public Instant getStartTime() {
return startTime;
}

public Instant getEndTime() {
return endTime;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.opensearch.securityanalytics.action;

import org.opensearch.commons.alerting.model.CorrelationAlert;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

public class GetCorrelationAlertsResponse extends ActionResponse implements ToXContentObject {

private static final String CORRELATION_ALERTS_FIELD = "correlationAlerts";
private static final String TOTAL_ALERTS_FIELD = "total_alerts";

private List<CorrelationAlert> alerts;
private Integer totalAlerts;

public GetCorrelationAlertsResponse(List<CorrelationAlert> alerts, Integer totalAlerts) {
super();
this.alerts = alerts;
this.totalAlerts = totalAlerts;
}

public GetCorrelationAlertsResponse(StreamInput sin) throws IOException {
this(
Collections.unmodifiableList(sin.readList(CorrelationAlert::new)),
sin.readInt()
);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(this.alerts);
out.writeInt(this.totalAlerts);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject()
.field(CORRELATION_ALERTS_FIELD, alerts)
.field(TOTAL_ALERTS_FIELD, totalAlerts);
return builder.endObject();
}

public List<CorrelationAlert> getAlerts() {
return this.alerts;
}

public Integer getTotalAlerts() {
return this.totalAlerts;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.common.lucene.uid.Versions;
import org.opensearch.commons.alerting.model.ActionExecutionResult;
import org.opensearch.commons.alerting.model.Alert;
import org.opensearch.commons.alerting.model.Table;
import org.opensearch.commons.authuser.User;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
Expand All @@ -29,6 +30,11 @@
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.commons.alerting.model.CorrelationAlert;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.securityanalytics.action.GetCorrelationAlertsResponse;
import org.opensearch.securityanalytics.util.CorrelationIndices;
import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -252,6 +258,47 @@ public static CorrelationAlert parse(XContentParser xcp, String id, long version
actionExecutionResults
);
}
public void getAlertsByRuleId(String ruleId, Table tableProp, ActionListener<GetCorrelationAlertsResponse> listener) {
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery()
.must(QueryBuilders.termQuery("correlation_rule_id", ruleId));

FieldSortBuilder sortBuilder = SortBuilders
.fieldSort(tableProp.getSortString())
.order(SortOrder.fromString(tableProp.getSortOrder()));
if (!tableProp.getMissing().isEmpty()) {
sortBuilder.missing(tableProp.getMissing());
}

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.version(true)
.seqNoAndPrimaryTerm(true)
.query(queryBuilder)
.sort(sortBuilder)
.size(tableProp.getSize())
.from(tableProp.getStartIndex());

SearchRequest searchRequest = new SearchRequest(CorrelationIndices.CORRELATION_ALERT_INDEX)
.source(searchSourceBuilder);

client.search(searchRequest, ActionListener.wrap(
searchResponse -> {
if (searchResponse.getHits().getTotalHits().equals(0)) {
listener.onResponse(new GetCorrelationAlertsResponse(Collections.emptyList(), 0));
} else {
listener.onResponse(new GetCorrelationAlertsResponse(
parseCorrelationAlerts(searchResponse),
searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ?
(int) searchResponse.getHits().getTotalHits().value : 0)
);
}
},
e -> {
log.error("Search request to fetch correlation alerts failed", e);
listener.onFailure(e);
}
));
}

}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package org.opensearch.securityanalytics.resthandler;

import org.opensearch.client.node.NodeClient;
import org.opensearch.commons.alerting.model.Table;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
import org.opensearch.securityanalytics.SecurityAnalyticsPlugin;
import org.opensearch.securityanalytics.action.GetCorrelationAlertsAction;
import org.opensearch.securityanalytics.action.GetCorrelationAlertsRequest;

import java.io.IOException;
import java.time.DateTimeException;
import java.time.Instant;
import java.util.List;

import static java.util.Collections.singletonList;
import static org.opensearch.rest.RestRequest.Method.GET;

public class RestGetCorrelationsAlertsAction extends BaseRestHandler {

@Override
public String getName() {
return "get_correlation_alerts_action_sa";
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {

String correlationRuleId = request.param("correlation_rule_id", null);
String correlationRuleName = request.param("correlation_rule_name", null);
String severityLevel = request.param("severityLevel", "ALL");
String alertState = request.param("alertState", "ALL");
// Table params
String sortString = request.param("sortString", "start_time");
String sortOrder = request.param("sortOrder", "asc");
String missing = request.param("missing");
int size = request.paramAsInt("size", 20);
int startIndex = request.paramAsInt("startIndex", 0);
String searchString = request.param("searchString", "");

Instant startTime = null;
String startTimeParam = request.param("startTime");
if (startTimeParam != null && !startTimeParam.isEmpty()) {
try {
startTime = Instant.ofEpochMilli(Long.parseLong(startTimeParam));
} catch (NumberFormatException | NullPointerException | DateTimeException e) {
startTime = Instant.now();
}
}

Instant endTime = null;
String endTimeParam = request.param("endTime");
if (endTimeParam != null && !endTimeParam.isEmpty()) {
try {
endTime = Instant.ofEpochMilli(Long.parseLong(endTimeParam));
} catch (NumberFormatException | NullPointerException | DateTimeException e) {
endTime = Instant.now();
}
}

Table table = new Table(
sortOrder,
sortString,
missing,
size,
startIndex,
searchString
);

GetCorrelationAlertsRequest req = new GetCorrelationAlertsRequest(
correlationRuleId,
correlationRuleName,
table,
severityLevel,
alertState,
startTime,
endTime
);

return channel -> client.execute(
GetCorrelationAlertsAction.INSTANCE,
req,
new RestToXContentListener<>(channel)
);
}

@Override
public List<Route> routes() {
return singletonList(new Route(GET, SecurityAnalyticsPlugin.CORRELATIONS_ALERTS_BASE_URI));
}
}
Loading

0 comments on commit 42f2b46

Please sign in to comment.