Skip to content

Commit

Permalink
implement the sequence detector
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyYe committed Feb 29, 2024
1 parent 097361b commit 51eb3b4
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate serde;

use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rule {
pub name: String,
pub sequence: Vec<i32>,
Expand Down
14 changes: 8 additions & 6 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
mod config;

use config::Config;
use std::fs::File;
use std::io::Read;

pub fn load_config() -> Result<Config, Box<dyn std::error::Error>> {
let mut file = File::open("config.yaml")?;
pub use config::Config;
pub use config::Rule;

mod config;

pub fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
let mut file = File::open(path)?;
let mut content = String::new();

file.read_to_string(&mut content)?;
Expand All @@ -21,7 +23,7 @@ mod tests {

#[test]
fn test_load_config() {
let config = load_config().unwrap();
let config = load_config("config.yaml").unwrap();
assert_eq!(config.interface, "enp3s0");
assert_eq!(config.rules.len(), 2);
}
Expand Down
10 changes: 4 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
mod server;
mod config;

use server::Server;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = config::load_config()?;
println!("{:?}", config);
mod config;
mod sequence;
mod server;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let server = Server::new("enp3s0".to_string());
server.start();

Expand Down
6 changes: 6 additions & 0 deletions src/sequence/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
mod port_sequence;

pub trait SequenceDetector {
fn add_sequence(&mut self, client_ip: String, sequence: i32);
fn match_sequence(&self, client_ip: &str) -> bool;
}
120 changes: 120 additions & 0 deletions src/sequence/port_sequence.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use std::collections::{HashMap, HashSet};

use crate::config::Config;
use crate::sequence::SequenceDetector;

#[derive(Debug)]
pub struct PortSequenceDetector {
sequence_set: HashSet<i32>,
sequence_rules: Vec<Vec<i32>>,
client_sequences: HashMap<String, Vec<i32>>,
}

impl PortSequenceDetector {
pub fn new(config: Config) -> PortSequenceDetector {
let mut sequence_rules = Vec::new();
for rule in config.rules.clone() {
sequence_rules.push(rule.sequence);
}

let mut sequence_set = HashSet::new();
for rule in config.rules {
for sequence in rule.sequence {
sequence_set.insert(sequence);
}
}

PortSequenceDetector {
sequence_set,
sequence_rules,
client_sequences: HashMap::new(),
}
}
}

impl SequenceDetector for PortSequenceDetector {
fn add_sequence(&mut self, client_ip: String, sequence: i32) {
// check if the sequence is in the set
if !self.sequence_set.contains(&sequence) {
return;
}

let client_sequence = self.client_sequences.entry(client_ip).or_insert(Vec::new());
client_sequence.push(sequence);
}

fn match_sequence(&self, client_ip: &str) -> bool {
// Check if the current sequence matches any of the rules
let client_sequence = self.client_sequences.get(client_ip);
if let Some(sequence) = client_sequence {
for rule in &self.sequence_rules {
if sequence.ends_with(rule) {
println!("Matched sequence: {:?}", rule);
return true;
}
}
}

false
}
}

#[cfg(test)]
mod tests {
use super::*;

fn create_config() -> Config {
Config {
interface: "enp3s0".to_string(),
rules: vec![
crate::config::Rule {
name: "enable ssh".to_string(),
sequence: vec![1, 2, 3],
timeout: 5,
command: "ls -lh".to_string(),
},
crate::config::Rule {
name: "disable ssh".to_string(),
sequence: vec![3, 5, 6],
timeout: 5,
command: "du -sh *".to_string(),
},
],
}
}

#[test]
fn test_new() {
let config = create_config();
let detector = PortSequenceDetector::new(config);
assert_eq!(detector.sequence_set.len(), 5);
assert_eq!(detector.sequence_rules.len(), 2);
}

#[test]
fn test_add_sequence() {
let config = create_config();
let mut detector = PortSequenceDetector::new(config);
detector.add_sequence("127.0.0.1".to_owned(), 3);
assert_eq!(detector.client_sequences.get("127.0.0.1"), Some(&vec![3]));
}

#[test]
fn test_add_none_existing_sequence() {
let config = create_config();
let mut detector = PortSequenceDetector::new(config);
detector.add_sequence("127.0.0.1".to_owned(), 9);
assert_eq!(detector.client_sequences.get("127.0.0.1"), None);
}

#[test]
fn test_match_sequence() {
let config = create_config();
let mut detector = PortSequenceDetector::new(config);
detector.add_sequence("127.0.0.1".to_owned(), 1);
detector.add_sequence("127.0.0.1".to_owned(), 3);
detector.add_sequence("127.0.0.1".to_owned(), 5);
detector.add_sequence("127.0.0.1".to_owned(), 6);
assert_eq!(detector.match_sequence("127.0.0.1"), true);
}
}

0 comments on commit 51eb3b4

Please sign in to comment.