From 812ec2dc553afbba921ddc0e2aa9742bd6232156 Mon Sep 17 00:00:00 2001 From: "Victor M. Alvarez" Date: Thu, 31 Aug 2023 17:41:21 +0200 Subject: [PATCH] perf: reimplement FastVM with a faster algorithm --- yara-x/src/compiler/atoms/quality.rs | 2 +- yara-x/src/re/fast/fastvm.rs | 236 +++++++++++++-------------- 2 files changed, 111 insertions(+), 127 deletions(-) diff --git a/yara-x/src/compiler/atoms/quality.rs b/yara-x/src/compiler/atoms/quality.rs index 4f19adc6a..73406d0c3 100644 --- a/yara-x/src/compiler/atoms/quality.rs +++ b/yara-x/src/compiler/atoms/quality.rs @@ -401,7 +401,7 @@ mod test { assert!(q_01xx03 < q_010203); assert!(q_010x0x > q_01); assert!(q_010x0x < q_010203); - assert!(q_01020000 < q_0102xx04); + assert_eq!(q_01020000, q_0102xx04); assert!(q_01020102 > q_01020000); assert!(q_01020102 > q_01010101); assert!(q_01020304 > q_01020102); diff --git a/yara-x/src/re/fast/fastvm.rs b/yara-x/src/re/fast/fastvm.rs index 96256b7fb..85882703a 100644 --- a/yara-x/src/re/fast/fastvm.rs +++ b/yara-x/src/re/fast/fastvm.rs @@ -1,5 +1,6 @@ -use std::cmp; +use std::collections::BTreeSet; use std::ops::RangeInclusive; +use std::{cmp, mem}; use itertools::izip; use memx::memeq; @@ -16,17 +17,21 @@ use crate::re::{Action, CodeLoc, DEFAULT_SCAN_LIMIT}; pub(crate) struct FastVM<'r> { /// The code for the VM. Produced by [`crate::re::fast::Compiler`]. code: &'r [u8], - /// The list of currently active positions. - threads: Vec<(usize, usize)>, /// Maximum number of bytes to scan. The VM will abort after ingesting /// this number of bytes from the input. scan_limit: usize, + /// A set with all the positions currently tracked. + positions: BTreeSet, } impl<'r> FastVM<'r> { /// Creates a new [`FastVM`]. pub fn new(code: &'r [u8]) -> Self { - Self { code, threads: Vec::new(), scan_limit: DEFAULT_SCAN_LIMIT } + Self { + code, + positions: BTreeSet::new(), + scan_limit: DEFAULT_SCAN_LIMIT, + } } /// Specifies the maximum number of bytes that will be scanned by the @@ -55,26 +60,35 @@ impl<'r> FastVM<'r> { C: CodeLoc, { let backwards = start.backwards(); - self.threads.push((start.location(), 0)); + let mut ip = start.location(); + + let input = if backwards { + &input[input.len().saturating_sub(self.scan_limit)..] + } else { + &input[..cmp::min(input.len(), self.scan_limit)] + }; - let input = &input[..cmp::min(input.len(), self.scan_limit)]; + self.positions.insert(0); - while let Some((mut ip, mut position)) = self.threads.pop() { - while position <= input.len() { - let (instr, instr_size) = - InstrParser::decode_instr(&self.code[ip..]); + while !self.positions.is_empty() { + let (instr, instr_size) = + InstrParser::decode_instr(&self.code[ip..]); - ip += instr_size; + ip += instr_size; - match instr { - Instr::Match => match f(position) { - Action::Stop => { - self.threads.clear(); - return; + match instr { + Instr::Match => { + for position in mem::take(&mut self.positions) { + match f(position) { + Action::Stop => { + return; + } + Action::Continue => {} } - Action::Continue => break, - }, - Instr::Literal(literal) => { + } + } + Instr::Literal(literal) => { + for position in mem::take(&mut self.positions) { let is_match = if backwards { self.try_match_literal_bck( &input[..input.len() - position], @@ -86,12 +100,13 @@ impl<'r> FastVM<'r> { literal, ) }; - if !is_match { - break; + if is_match { + self.positions.insert(position + literal.len()); } - position += literal.len(); } - Instr::MaskedLiteral(literal, mask) => { + } + Instr::MaskedLiteral(literal, mask) => { + for position in mem::take(&mut self.positions) { let is_match = if backwards { self.try_match_masked_literal_bck( &input[..input.len() - position], @@ -105,80 +120,68 @@ impl<'r> FastVM<'r> { mask, ) }; - if !is_match { - break; + if is_match { + self.positions.insert(position + literal.len()); } - position += literal.len(); } - Instr::Jump(jump) => { - position += jump as usize; + } + Instr::Jump(jump) => { + for position in mem::take(&mut self.positions) { + self.positions.insert(position + jump as usize); } - Instr::JumpRange(range) => { - match InstrParser::decode_instr(&self.code[ip..]) { - (Instr::Literal(literal), _) if backwards => { - if let Some(new_position) = self.jump_bck( + } + Instr::JumpRange(range) => { + match InstrParser::decode_instr(&self.code[ip..]) { + (Instr::Literal(literal), _) if backwards => { + for position in mem::take(&mut self.positions) { + self.jump_bck( &input[..input.len() - position], - *literal.last().unwrap(), - range, - ip, + literal, + &range, position, - ) { - position = new_position; - } else { - break; - } + ); } - (Instr::Literal(literal), _) if !backwards => { - if let Some(new_position) = self.jump_fwd( + } + (Instr::Literal(literal), _) if !backwards => { + for position in mem::take(&mut self.positions) { + self.jump_fwd( &input[position..], - *literal.first().unwrap(), - range, - ip, + literal, + &range, position, - ) { - position = new_position; - } else { - break; - } + ) } - (Instr::MaskedLiteral(literal, mask), _) - if backwards && mask.last() == Some(&0xff) => - { - if let Some(new_position) = self.jump_bck( + } + (Instr::MaskedLiteral(literal, mask), _) + if backwards && mask.last() == Some(&0xff) => + { + for position in mem::take(&mut self.positions) { + self.jump_bck( &input[..input.len() - position], - *literal.last().unwrap(), - range, - ip, + literal, + &range, position, - ) { - position = new_position; - } else { - break; - } + ); } - (Instr::MaskedLiteral(literal, mask), _) - if !backwards - && mask.first() == Some(&0xff) => - { - if let Some(new_position) = self.jump_fwd( + } + (Instr::MaskedLiteral(literal, mask), _) + if !backwards && mask.first() == Some(&0xff) => + { + for position in mem::take(&mut self.positions) { + self.jump_fwd( &input[position..], - *literal.first().unwrap(), - range, - ip, + literal, + &range, position, - ) { - position = new_position; - } else { - break; - } + ); } - _ => { - let min = *range.start() as usize; - for i in range.skip(1).rev() { - self.threads - .push((ip, position + i as usize)) + } + _ => { + for position in mem::take(&mut self.positions) { + for i in range.clone() { + self.positions + .insert(position + i as usize); } - position += min; } } } @@ -190,7 +193,7 @@ impl<'r> FastVM<'r> { impl FastVM<'_> { #[inline] - fn try_match_literal_fwd(&mut self, input: &[u8], literal: &[u8]) -> bool { + fn try_match_literal_fwd(&self, input: &[u8], literal: &[u8]) -> bool { if input.len() < literal.len() { return false; } @@ -207,7 +210,7 @@ impl FastVM<'_> { #[inline] fn try_match_masked_literal_fwd( - &mut self, + &self, input: &[u8], literal: &[u8], mask: &[u8], @@ -229,7 +232,7 @@ impl FastVM<'_> { #[inline] fn try_match_masked_literal_bck( - &mut self, + &self, input: &[u8], literal: &[u8], mask: &[u8], @@ -255,65 +258,46 @@ impl FastVM<'_> { fn jump_fwd( &mut self, input: &[u8], - lit: u8, - range: RangeInclusive, - ip: usize, - mut position: usize, - ) -> Option { + literal: &[u8], + range: &RangeInclusive, + position: usize, + ) { let jmp_min = *range.start() as usize; - let jmp_max = std::cmp::min(input.len(), *range.end() as usize + 1); + let jmp_max = cmp::min(input.len(), *range.end() as usize + 1); let jmp_range = jmp_min..jmp_max; if jmp_range.start >= jmp_range.end { - return None; + return; } - let mut offsets = memchr::memrchr_iter(lit, input.get(jmp_range)?); - - let last = offsets.next_back(); - for offset in offsets { - self.threads.push((ip, position + jmp_min + offset)) - } - - if let Some(offset) = last { - position += jmp_min + offset; - } else { - return None; + if let Some(jmp_range) = input.get(jmp_range) { + let lit = *literal.first().unwrap(); + for offset in memchr::memchr_iter(lit, jmp_range) { + self.positions.insert(position + jmp_min + offset); + } } - - Some(position) } #[inline] fn jump_bck( &mut self, input: &[u8], - lit: u8, - range: RangeInclusive, - ip: usize, - mut position: usize, - ) -> Option { - let jump_range = input.len().saturating_sub(*range.end() as usize + 1) + literal: &[u8], + range: &RangeInclusive, + position: usize, + ) { + let jmp_range = input.len().saturating_sub(*range.end() as usize + 1) ..input.len().saturating_sub(*range.start() as usize); - if jump_range.start >= jump_range.end { - return None; - } - - let jump_range = input.get(jump_range)?; - let mut offsets = memchr::memchr_iter(lit, jump_range); - - let last = offsets.next_back(); - for offset in offsets { - self.threads.push((ip, position + jump_range.len() - offset - 1)) + if jmp_range.start >= jmp_range.end { + return; } - if let Some(offset) = last { - position += jump_range.len() - offset - 1; - } else { - return None; + if let Some(jmp_range) = input.get(jmp_range) { + let lit = *literal.last().unwrap(); + for offset in memchr::memrchr_iter(lit, jmp_range) { + self.positions.insert(position + jmp_range.len() - offset - 1); + } } - - Some(position) } }