Skip to content

Commit

Permalink
add backend thread for port sequence detector
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyYe committed Mar 1, 2024
1 parent 7493721 commit 2939da4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 37 deletions.
1 change: 1 addition & 0 deletions src/sequence/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub use port_sequence::PortSequenceDetector;
mod port_sequence;

pub trait SequenceDetector {
fn start(&mut self);
fn add_sequence(&mut self, client_ip: String, sequence: i32);
fn match_sequence(&mut self, client_ip: &str) -> bool;
}
117 changes: 80 additions & 37 deletions src/sequence/port_sequence.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{SystemTime, UNIX_EPOCH};

use crate::config::Config;
Expand All @@ -10,8 +12,8 @@ pub struct PortSequenceDetector {
timeout: u64,
sequence_set: HashSet<i32>,
sequence_rules: Vec<Vec<i32>>,
client_sequences: HashMap<String, Vec<i32>>,
client_timeout: HashMap<String, u64>,
client_sequences: Arc<Mutex<HashMap<String, Vec<i32>>>>,
client_timeout: Arc<Mutex<HashMap<String, u64>>>,
}

impl PortSequenceDetector {
Expand All @@ -33,8 +35,8 @@ impl PortSequenceDetector {
timeout: config.timeout,
sequence_set,
sequence_rules,
client_sequences: HashMap::new(),
client_timeout: HashMap::new(),
client_sequences: Arc::new(Mutex::new(HashMap::new())),
client_timeout: Arc::new(Mutex::new(HashMap::new())),
}
}
}
Expand All @@ -51,26 +53,30 @@ impl SequenceDetector for PortSequenceDetector {
client_ip, sequence
);

let client_sequence = self
.client_sequences
.entry(client_ip.clone())
.or_insert(Vec::new());
client_sequence.push(sequence);

// get the current time stamp
self.client_timeout.entry(client_ip.clone()).or_insert(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
);
{
let mut client_sequence = self.client_sequences.lock().unwrap();
let client_sequence = client_sequence
.entry(client_ip.clone())
.or_insert(Vec::new());
client_sequence.push(sequence);

// get the current time stamp
let mut client_timeout = self.client_timeout.lock().unwrap();
client_timeout.entry(client_ip.clone()).or_insert(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
);
}

self.match_sequence(&client_ip);
}

fn match_sequence(&mut self, client_ip: &str) -> bool {
// Check if the current sequence matches any of the rules
let client_sequence = self.client_sequences.get_mut(client_ip);
let mut client_sequence = self.client_sequences.lock().unwrap();
let client_sequence = client_sequence.get_mut(client_ip);
if let Some(sequence) = client_sequence {
for rule in &self.sequence_rules {
if sequence.ends_with(rule) {
Expand All @@ -80,35 +86,58 @@ impl SequenceDetector for PortSequenceDetector {
return true;
}
}
}

// check if the sequence has expired
let timeout_entry = self.client_timeout.get(client_ip);
if let Some(timeout) = timeout_entry {
let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
false
}

if current_time - timeout > self.timeout {
println!("Sequence timeout for: {}", client_ip);
sequence.clear();
self.client_timeout.remove(client_ip);
}
fn start(&mut self) {
let client_sequences = Arc::clone(&self.client_sequences);
let client_timeout = Arc::clone(&self.client_timeout);
let timeout = self.timeout;

thread::spawn(move || loop {
thread::sleep(std::time::Duration::from_millis(200));

let mut client_sequences = client_sequences.lock().unwrap();
let mut client_timeout = client_timeout.lock().unwrap();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();

let clients_to_remove: Vec<_> = client_timeout
.iter()
.filter_map(|(client_ip, _)| {
let last_time = client_timeout.get(client_ip).unwrap();
if now - last_time > timeout {
return Some(client_ip.clone());
}
None
})
.collect();

for client_ip in clients_to_remove {
println!("Removing client: {} due to timeout...", client_ip);
client_sequences.remove(&client_ip);
client_timeout.remove(&client_ip);
}
}
});

false
println!("Port sequence detector thread started...");
}
}

#[cfg(test)]
mod tests {
use std::{thread, time::Duration};

use super::*;

fn create_config() -> Config {
Config {
interface: "enp3s0".to_string(),
timeout: 5,
timeout: 2,
rules: vec![
crate::config::config::Rule {
name: "enable ssh".to_string(),
Expand All @@ -130,23 +159,36 @@ mod tests {
let detector = PortSequenceDetector::new(config);
assert_eq!(detector.sequence_set.len(), 5);
assert_eq!(detector.sequence_rules.len(), 2);
assert_eq!(detector.timeout, 5);
assert_eq!(detector.timeout, 2);
}

#[test]
fn test_add_sequence() {
let config = create_config();
let mut detector = PortSequenceDetector::new(config);
detector.add_sequence("127.0.0.1".to_owned(), 3);
assert_eq!(detector.client_sequences.get("127.0.0.1"), Some(&vec![3]));
let client_sequences = detector.client_sequences.lock().unwrap();
assert_eq!(client_sequences.get("127.0.0.1"), Some(&vec![3]));
}

#[test]
fn test_add_sequence_with_timeout() {
let config = create_config();
let mut detector = PortSequenceDetector::new(config);
detector.start();
detector.add_sequence("127.0.0.1".to_owned(), 3);
thread::sleep(Duration::from_secs(4));
let client_sequences = detector.client_sequences.lock().unwrap();
assert_eq!(client_sequences.get("127.0.0.1"), None);
}

#[test]
fn test_add_none_existing_sequence() {
let config = create_config();
let mut detector = PortSequenceDetector::new(config);
detector.add_sequence("127.0.0.1".to_owned(), 9);
assert_eq!(detector.client_sequences.get("127.0.0.1"), None);
let client_sequences = detector.client_sequences.lock().unwrap();
assert_eq!(client_sequences.get("127.0.0.1"), None);
}

#[test]
Expand All @@ -158,6 +200,7 @@ mod tests {
detector.add_sequence("127.0.0.1".to_owned(), 5);
detector.add_sequence("127.0.0.1".to_owned(), 6);
assert_eq!(detector.match_sequence("127.0.0.1"), false);
assert_eq!(detector.client_sequences.get("127.0.0.1").unwrap().len(), 0);
let client_sequences = detector.client_sequences.lock().unwrap();
assert_eq!(client_sequences.get("127.0.0.1").unwrap().len(), 0);
}
}

0 comments on commit 2939da4

Please sign in to comment.