From b68da48fecccc4a6b911b2d19405ba1dc65d6612 Mon Sep 17 00:00:00 2001 From: fukusuket <41001169+fukusuket@users.noreply.github.com> Date: Mon, 24 Jun 2024 00:44:26 +0900 Subject: [PATCH] feat: ignore referenced rules in correlation --- src/detections/rule/correlation_parser.rs | 135 ++++++++++++++-------- src/detections/rule/mod.rs | 3 +- src/main.rs | 7 +- 3 files changed, 94 insertions(+), 51 deletions(-) diff --git a/src/detections/rule/correlation_parser.rs b/src/detections/rule/correlation_parser.rs index e45678774..37d010678 100644 --- a/src/detections/rule/correlation_parser.rs +++ b/src/detections/rule/correlation_parser.rs @@ -1,5 +1,8 @@ use std::error::Error; +use std::sync::Arc; +use hashbrown::HashMap; +use yaml_rust::yaml::Hash; use yaml_rust::Yaml; use crate::detections::configs::StoredStatic; @@ -8,10 +11,12 @@ use crate::detections::rule::aggregation_parser::{ AggregationConditionToken, AggregationParseInfo, }; use crate::detections::rule::count::TimeFrameInfo; -use crate::detections::rule::selectionnodes::OrSelectionNode; +use crate::detections::rule::selectionnodes::{OrSelectionNode, SelectionNode}; use crate::detections::rule::{DetectionNode, RuleNode}; -fn is_related_rule(rule_node: &RuleNode, id_or_title: &str) -> bool { +type Name2Selection = HashMap>>; + +fn is_referenced_rule(rule_node: &RuleNode, id_or_title: &str) -> bool { if let Some(hash) = rule_node.yaml.as_hash() { if let Some(id) = hash.get(&Yaml::String("id".to_string())) { if id.as_str() == Some(id_or_title) { @@ -163,26 +168,29 @@ fn parse_tframe(value: String) -> Result> { } fn create_related_rule_nodes( - related_rules_ids: Vec, + related_rules_ids: &Vec, other_rules: &[RuleNode], stored_static: &StoredStatic, -) -> Vec { +) -> (Vec, Name2Selection) { let mut related_rule_nodes: Vec = Vec::new(); + let mut name_to_selection: Name2Selection = HashMap::new(); for id in related_rules_ids { for other_rule in other_rules { - if is_related_rule(other_rule, &id) { + if is_referenced_rule(other_rule, id) { let mut node = RuleNode::new(other_rule.rulepath.clone(), other_rule.yaml.clone()); let _ = node.init(stored_static); + name_to_selection.extend(node.detection.name_to_selection.clone()); related_rule_nodes.push(node); } } } - related_rule_nodes + (related_rule_nodes, name_to_selection) } fn create_detection( rule_node: &RuleNode, related_rule_nodes: Vec, + name_to_selection: HashMap>>, ) -> Result> { let condition = parse_condition(&rule_node.yaml["correlation"])?; let group_by = get_group_by_from_yaml(&rule_node.yaml)?; @@ -199,6 +207,7 @@ fn create_detection( _cmp_num: condition.1, }; Ok(DetectionNode::new_with_data( + name_to_selection, Some(Box::new(nodes)), Some(agg_info), Some(time_frame), @@ -229,56 +238,90 @@ fn error_log( *parseerror_count += 1; } +fn merge_referenced_rule( + rule: RuleNode, + other_rules: &mut Vec, + stored_static: &StoredStatic, + parse_error_count: &mut u128, +) -> RuleNode { + let rule_type = rule.yaml["correlation"]["type"].as_str(); + if rule_type != Some("event_count") && rule_type != Some("value_count") { + let m = "The type of correlation rule only supports event_count/value_count."; + error_log(&rule.rulepath, m, stored_static, parse_error_count); + return rule; + } + let referenced_ids = match get_related_rules_id(&rule.yaml) { + Ok(related_rules_ids) => related_rules_ids, + Err(_) => { + let m = "Referenced rule not found."; + error_log(&rule.rulepath, m, stored_static, parse_error_count); + return rule; + } + }; + if referenced_ids.is_empty() { + let m = "Referenced rule not found."; + error_log(&rule.rulepath, m, stored_static, parse_error_count); + return rule; + } + let (referenced_rules, name_to_selection) = + create_related_rule_nodes(&referenced_ids, other_rules, stored_static); + let is_not_referenced_rule = |rule_node: &RuleNode| { + let id = rule_node.yaml["id"].as_str().unwrap_or_default(); + let title = rule_node.yaml["title"].as_str().unwrap_or_default(); + let name = rule_node.yaml["name"].as_str().unwrap_or_default(); + !referenced_ids.contains(&id.to_string()) + && !referenced_ids.contains(&title.to_string()) + && !referenced_ids.contains(&name.to_string()) + }; + if !rule.yaml["correlation"]["generate"] + .as_bool() + .unwrap_or_default() + { + other_rules.retain(is_not_referenced_rule); + } + let referenced_hashes: Vec = referenced_rules + .iter() + .filter_map(|rule_node| rule_node.yaml.as_hash().cloned()) + .collect(); + let detection = match create_detection(&rule, referenced_rules, name_to_selection) { + Ok(detection) => detection, + Err(e) => { + error_log( + &rule.rulepath, + e.to_string().as_str(), + stored_static, + parse_error_count, + ); + return rule; + } + }; + let referenced_yaml: Yaml = + Yaml::Array(referenced_hashes.into_iter().map(Yaml::Hash).collect()); + let mut merged_yaml = rule.yaml.as_hash().unwrap().clone(); + merged_yaml.insert(Yaml::String("detection".to_string()), referenced_yaml); + RuleNode::new_with_detection(rule.rulepath, Yaml::Hash(merged_yaml), detection) +} + pub fn parse_correlation_rules( rule_nodes: Vec, stored_static: &StoredStatic, - parseerror_count: &mut u128, + parse_error_count: &mut u128, ) -> Vec { - let (correlation_rules, other_rules): (Vec, Vec) = rule_nodes + let (correlation_rules, mut not_correlation_rules): (Vec, Vec) = rule_nodes .into_iter() .partition(|rule_node| !rule_node.yaml["correlation"].is_badvalue()); let mut parsed_rules: Vec = correlation_rules .into_iter() - .map(|rule_node| { - let rule_type = rule_node.yaml["correlation"]["type"].as_str(); - if rule_type != Some("event_count") && rule_type != Some("value_count") { - let m = "The type of correlations rule only supports event_count/value_count."; - error_log(&rule_node.rulepath, m, stored_static, parseerror_count); - return rule_node; - } - let related_rules_ids = get_related_rules_id(&rule_node.yaml); - let related_rules_ids = match related_rules_ids { - Ok(related_rules_ids) => related_rules_ids, - Err(_) => { - let m = "Related rule not found."; - error_log(&rule_node.rulepath, m, stored_static, parseerror_count); - return rule_node; - } - }; - if related_rules_ids.is_empty() { - let m = "Related rule not found."; - error_log(&rule_node.rulepath, m, stored_static, parseerror_count); - return rule_node; - } - let related_rules = - create_related_rule_nodes(related_rules_ids, &other_rules, stored_static); - let detection = create_detection(&rule_node, related_rules); - let detection = match detection { - Ok(detection) => detection, - Err(e) => { - error_log( - &rule_node.rulepath, - e.to_string().as_str(), - stored_static, - parseerror_count, - ); - return rule_node; - } - }; - RuleNode::new_with_detection(rule_node.rulepath, rule_node.yaml, detection) + .map(|correlation_rule_node| { + merge_referenced_rule( + correlation_rule_node, + &mut not_correlation_rules, + stored_static, + parse_error_count, + ) }) .collect(); - parsed_rules.extend(other_rules); + parsed_rules.extend(not_correlation_rules); parsed_rules } diff --git a/src/detections/rule/mod.rs b/src/detections/rule/mod.rs index d392c7a3a..a154488dc 100644 --- a/src/detections/rule/mod.rs +++ b/src/detections/rule/mod.rs @@ -176,12 +176,13 @@ impl DetectionNode { } pub fn new_with_data( + name_to_selection: HashMap>>, condition: Option>, aggregation_condition: Option, timeframe: Option, ) -> DetectionNode { DetectionNode { - name_to_selection: HashMap::new(), + name_to_selection, condition, aggregation_condition, timeframe, diff --git a/src/main.rs b/src/main.rs index 69bc3a2de..1bed3e2d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,7 @@ use libmimalloc_sys::mi_stats_print_out; use mimalloc::MiMalloc; use nested::Nested; use num_format::{Locale, ToFormattedString}; +use rust_embed::Embed; use serde_json::{Map, Value}; use termcolor::{BufferWriter, Color, ColorChoice}; use tokio::runtime::Runtime; @@ -65,7 +66,8 @@ use hayabusa::timeline::computer_metrics::countup_event_by_computer; use hayabusa::{detections::configs, timeline::timelines::Timeline}; use hayabusa::{detections::utils::write_color_buffer, filter}; use hayabusa::{options, yaml}; -use rust_embed::Embed; +#[cfg(target_os = "windows")] +use is_elevated::is_elevated; #[derive(Embed)] #[folder = "art/"] @@ -76,9 +78,6 @@ struct Arts; #[include = "contributors.txt"] struct Contributors; -#[cfg(target_os = "windows")] -use is_elevated::is_elevated; - #[global_allocator] static GLOBAL: MiMalloc = MiMalloc;