diff --git a/src/config/config.rs b/src/config/config.rs index 9169ab3..a5455f5 100644 --- a/src/config/config.rs +++ b/src/config/config.rs @@ -2,7 +2,7 @@ extern crate serde; use serde::{Deserialize, Serialize}; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Rule { pub name: String, pub sequence: Vec, diff --git a/src/config/mod.rs b/src/config/mod.rs index eaa5444..be3c703 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,11 +1,13 @@ -mod config; - -use config::Config; use std::fs::File; use std::io::Read; -pub fn load_config() -> Result> { - let mut file = File::open("config.yaml")?; +pub use config::Config; +pub use config::Rule; + +mod config; + +pub fn load_config(path: &str) -> Result> { + let mut file = File::open(path)?; let mut content = String::new(); file.read_to_string(&mut content)?; @@ -21,7 +23,7 @@ mod tests { #[test] fn test_load_config() { - let config = load_config().unwrap(); + let config = load_config("config.yaml").unwrap(); assert_eq!(config.interface, "enp3s0"); assert_eq!(config.rules.len(), 2); } diff --git a/src/main.rs b/src/main.rs index bd2bf36..16cc4e0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,10 @@ -mod server; -mod config; - use server::Server; -fn main() -> Result<(), Box> { - let config = config::load_config()?; - println!("{:?}", config); +mod config; +mod sequence; +mod server; +fn main() -> Result<(), Box> { let server = Server::new("enp3s0".to_string()); server.start(); diff --git a/src/sequence/mod.rs b/src/sequence/mod.rs new file mode 100644 index 0000000..8b882e5 --- /dev/null +++ b/src/sequence/mod.rs @@ -0,0 +1,6 @@ +mod port_sequence; + +pub trait SequenceDetector { + fn add_sequence(&mut self, client_ip: String, sequence: i32); + fn match_sequence(&self, client_ip: &str) -> bool; +} diff --git a/src/sequence/port_sequence.rs b/src/sequence/port_sequence.rs new file mode 100644 index 0000000..2903514 --- /dev/null +++ b/src/sequence/port_sequence.rs @@ -0,0 +1,120 @@ +use std::collections::{HashMap, HashSet}; + +use crate::config::Config; +use crate::sequence::SequenceDetector; + +#[derive(Debug)] +pub struct PortSequenceDetector { + sequence_set: HashSet, + sequence_rules: Vec>, + client_sequences: HashMap>, +} + +impl PortSequenceDetector { + pub fn new(config: Config) -> PortSequenceDetector { + let mut sequence_rules = Vec::new(); + for rule in config.rules.clone() { + sequence_rules.push(rule.sequence); + } + + let mut sequence_set = HashSet::new(); + for rule in config.rules { + for sequence in rule.sequence { + sequence_set.insert(sequence); + } + } + + PortSequenceDetector { + sequence_set, + sequence_rules, + client_sequences: HashMap::new(), + } + } +} + +impl SequenceDetector for PortSequenceDetector { + fn add_sequence(&mut self, client_ip: String, sequence: i32) { + // check if the sequence is in the set + if !self.sequence_set.contains(&sequence) { + return; + } + + let client_sequence = self.client_sequences.entry(client_ip).or_insert(Vec::new()); + client_sequence.push(sequence); + } + + fn match_sequence(&self, client_ip: &str) -> bool { + // Check if the current sequence matches any of the rules + let client_sequence = self.client_sequences.get(client_ip); + if let Some(sequence) = client_sequence { + for rule in &self.sequence_rules { + if sequence.ends_with(rule) { + println!("Matched sequence: {:?}", rule); + return true; + } + } + } + + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_config() -> Config { + Config { + interface: "enp3s0".to_string(), + rules: vec![ + crate::config::Rule { + name: "enable ssh".to_string(), + sequence: vec![1, 2, 3], + timeout: 5, + command: "ls -lh".to_string(), + }, + crate::config::Rule { + name: "disable ssh".to_string(), + sequence: vec![3, 5, 6], + timeout: 5, + command: "du -sh *".to_string(), + }, + ], + } + } + + #[test] + fn test_new() { + let config = create_config(); + let detector = PortSequenceDetector::new(config); + assert_eq!(detector.sequence_set.len(), 5); + assert_eq!(detector.sequence_rules.len(), 2); + } + + #[test] + fn test_add_sequence() { + let config = create_config(); + let mut detector = PortSequenceDetector::new(config); + detector.add_sequence("127.0.0.1".to_owned(), 3); + assert_eq!(detector.client_sequences.get("127.0.0.1"), Some(&vec![3])); + } + + #[test] + fn test_add_none_existing_sequence() { + let config = create_config(); + let mut detector = PortSequenceDetector::new(config); + detector.add_sequence("127.0.0.1".to_owned(), 9); + assert_eq!(detector.client_sequences.get("127.0.0.1"), None); + } + + #[test] + fn test_match_sequence() { + let config = create_config(); + let mut detector = PortSequenceDetector::new(config); + detector.add_sequence("127.0.0.1".to_owned(), 1); + detector.add_sequence("127.0.0.1".to_owned(), 3); + detector.add_sequence("127.0.0.1".to_owned(), 5); + detector.add_sequence("127.0.0.1".to_owned(), 6); + assert_eq!(detector.match_sequence("127.0.0.1"), true); + } +}