From fc9c1253a42dd8e95439aa5b7aa0e4e7b75953c4 Mon Sep 17 00:00:00 2001 From: "Victor M. Alvarez" Date: Mon, 4 Sep 2023 17:20:48 +0200 Subject: [PATCH] feat: implement wide mode scanning in FastVM --- yara-x/src/re/fast/fastvm.rs | 212 +++++++++++++++++++++++++------ yara-x/src/re/thompson/pikevm.rs | 56 +++++++- yara-x/src/scanner/context.rs | 49 ++----- yara-x/src/tests/mod.rs | 29 +++++ 4 files changed, 265 insertions(+), 81 deletions(-) diff --git a/yara-x/src/re/fast/fastvm.rs b/yara-x/src/re/fast/fastvm.rs index 52f6f465d..a49c77804 100644 --- a/yara-x/src/re/fast/fastvm.rs +++ b/yara-x/src/re/fast/fastvm.rs @@ -63,6 +63,7 @@ impl<'r> FastVM<'r> { &mut self, start: C, input: &[u8], + wide: bool, mut f: impl FnMut(usize) -> Action, ) where C: CodeLoc, @@ -76,6 +77,7 @@ impl<'r> FastVM<'r> { &input[..cmp::min(input.len(), self.scan_limit)] }; + let step = if wide { 2 } else { 1 }; let mut next_positions = IndexSet::default(); self.positions.insert(0); @@ -107,15 +109,18 @@ impl<'r> FastVM<'r> { self.try_match_literal_bck( &input[..input.len() - position], literal, + wide, ) } else { self.try_match_literal_fwd( &input[*position..], literal, + wide, ) }; if is_match { - next_positions.insert(position + literal.len()); + next_positions + .insert(position + step * literal.len()); } } } @@ -129,16 +134,19 @@ impl<'r> FastVM<'r> { &input[..input.len() - position], literal, mask, + wide, ) } else { self.try_match_masked_literal_fwd( &input[*position..], literal, mask, + wide, ) }; if is_match { - next_positions.insert(position + literal.len()); + next_positions + .insert(position + step * literal.len()); } } } @@ -154,16 +162,19 @@ impl<'r> FastVM<'r> { self.try_match_literal_bck( &input[..input.len() - position], literal, + wide, ) } else { self.try_match_literal_fwd( &input[*position..], literal, + wide, ) }; if is_match { - next_positions - .insert(position + literal.len()); + next_positions.insert( + position + step * literal.len(), + ); } } Instr::MaskedLiteral(literal, mask) => { @@ -172,17 +183,20 @@ impl<'r> FastVM<'r> { &input[..input.len() - position], literal, mask, + wide, ) } else { self.try_match_masked_literal_fwd( &input[*position..], literal, mask, + wide, ) }; if is_match { - next_positions - .insert(position + literal.len()); + next_positions.insert( + position + step * literal.len(), + ); } } // The only valid instructions in alternatives @@ -199,7 +213,7 @@ impl<'r> FastVM<'r> { } Instr::Jump(jump) => { for position in &self.positions { - next_positions.insert(position + jump as usize); + next_positions.insert(position + step * jump as usize); } } Instr::JumpRange(range) => { @@ -212,6 +226,7 @@ impl<'r> FastVM<'r> { self.jump_bck( &input[..input.len() - position], literal, + wide, &range, *position, &mut next_positions, @@ -226,6 +241,7 @@ impl<'r> FastVM<'r> { self.jump_fwd( &input[*position..], literal, + wide, &range, *position, &mut next_positions, @@ -242,6 +258,7 @@ impl<'r> FastVM<'r> { self.jump_bck( &input[..input.len() - position], literal, + wide, &range, *position, &mut next_positions, @@ -258,6 +275,7 @@ impl<'r> FastVM<'r> { self.jump_fwd( &input[*position..], literal, + wide, &range, *position, &mut next_positions, @@ -268,7 +286,7 @@ impl<'r> FastVM<'r> { for position in mem::take(&mut self.positions) { for i in range.clone() { next_positions - .insert(position + i as usize); + .insert(position + step * i as usize); } } } @@ -284,19 +302,48 @@ impl<'r> FastVM<'r> { impl FastVM<'_> { #[inline] - fn try_match_literal_fwd(&self, input: &[u8], literal: &[u8]) -> bool { - if input.len() < literal.len() { - return false; + fn try_match_literal_fwd( + &self, + input: &[u8], + literal: &[u8], + wide: bool, + ) -> bool { + if wide { + if input.len() < literal.len() * 2 { + return false; + } + input.iter().step_by(2).eq(literal.iter()) + } else { + if input.len() < literal.len() { + return false; + } + memeq(&input[..literal.len()], literal) } - memeq(&input[..literal.len()], literal) } #[inline] - fn try_match_literal_bck(&self, input: &[u8], literal: &[u8]) -> bool { - if input.len() < literal.len() { - return false; + fn try_match_literal_bck( + &self, + input: &[u8], + literal: &[u8], + wide: bool, + ) -> bool { + if wide { + if input.len() < literal.len() * 2 { + return false; + } + input + .iter() // iterate input + .rev() // in reverse order + .skip(1) // skipping the last byte that should be 0 + .step_by(2) // two bytes at a time + .eq(literal.iter().rev()) + } else { + if input.len() < literal.len() { + return false; + } + memeq(&input[input.len() - literal.len()..], literal) } - memeq(&input[input.len() - literal.len()..], literal) } #[inline] @@ -305,17 +352,30 @@ impl FastVM<'_> { input: &[u8], literal: &[u8], mask: &[u8], + wide: bool, ) -> bool { debug_assert_eq!(literal.len(), mask.len()); - if input.len() < literal.len() { - return false; - } - - for (input, byte, mask) in izip!(input, literal, mask) { - if *input & *mask != *byte & *mask { + if wide { + if input.len() < literal.len() * 2 { + return false; + } + for (input, byte, mask) in + izip!(input.iter().step_by(2), literal, mask) + { + if *input & *mask != *byte & *mask { + return false; + } + } + } else { + if input.len() < literal.len() { return false; } + for (input, byte, mask) in izip!(input, literal, mask) { + if *input & *mask != *byte & *mask { + return false; + } + } } true @@ -327,19 +387,36 @@ impl FastVM<'_> { input: &[u8], literal: &[u8], mask: &[u8], + wide: bool, ) -> bool { debug_assert_eq!(literal.len(), mask.len()); - if input.len() < literal.len() { - return false; - } - - for (input, byte, mask) in - izip!(input.iter().rev(), literal.iter().rev(), mask.iter().rev()) - { - if *input & *mask != *byte & *mask { + if wide { + if input.len() < literal.len() * 2 { + return false; + } + for (input, byte, mask) in izip!( + input.iter().rev().step_by(2), + literal.iter().rev(), + mask.iter().rev() + ) { + if *input & *mask != *byte & *mask { + return false; + } + } + } else { + if input.len() < literal.len() { return false; } + for (input, byte, mask) in izip!( + input.iter().rev(), + literal.iter().rev(), + mask.iter().rev() + ) { + if *input & *mask != *byte & *mask { + return false; + } + } } true @@ -350,22 +427,36 @@ impl FastVM<'_> { &self, input: &[u8], literal: &[u8], + wide: bool, range: &RangeInclusive, position: usize, next_positions: &mut IndexSet, ) { - let jmp_min = *range.start() as usize; - let jmp_max = cmp::min(input.len(), *range.end() as usize + 1); - let jmp_range = jmp_min..jmp_max; + let step = if wide { 2 } else { 1 }; - if jmp_range.start >= jmp_range.end { + let n = *range.start() as usize * step; + let m = *range.end() as usize * step; + + let range_min = n; + let range_max = cmp::min(input.len(), m + step); + + if range_min >= range_max { return; } - if let Some(jmp_range) = input.get(jmp_range) { + if let Some(jmp_range) = input.get(range_min..range_max) { let lit = *literal.first().unwrap(); for offset in memchr::memchr_iter(lit, jmp_range) { - next_positions.insert(position + jmp_min + offset); + if wide { + // In wide mode we are only interested in bytes found + // at even offsets. At odd offsets the input should + // have only zeroes and they are not potential matches. + if offset % 2 == 0 { + next_positions.insert(position + n + offset); + } + } else { + next_positions.insert(position + n + offset); + } } } } @@ -375,21 +466,60 @@ impl FastVM<'_> { &self, input: &[u8], literal: &[u8], + wide: bool, range: &RangeInclusive, position: usize, next_positions: &mut IndexSet, ) { - let jmp_range = input.len().saturating_sub(*range.end() as usize + 1) - ..input.len().saturating_sub(*range.start() as usize); + let step = if wide { 2 } else { 1 }; + + let n = *range.start() as usize * step; + let m = *range.end() as usize * step; + + // Let's explain the what this function does using the following + // pattern as an example: + // + // { 01 02 03 [n-m] 04 05 06 07 } + // + // The scheme below resumes what's happening. The atom found by + // Aho-Corasick is `04 05 06 07`, and this function is about to + // process the jump [n-m]. The input received is the data that ends + // where the atom starts. What we want to do is scanning the range + // `range_min..range_max` from right to left looking for all instances + // of `03`, which are positions where `01 02 03` could match while + // scanning backwards. + // + // |--------------- input ---------------| + // | |--------- M ----------| + // | | |---- N ----| + // | ... 01 02 03 | .................... | 04 05 06 07 + // ^ ^ + // range_min range_max + // + let range_min = input.len().saturating_sub(m + step); + let range_max = input.len().saturating_sub(n); - if jmp_range.start >= jmp_range.end { + if range_min >= range_max { return; } - if let Some(jmp_range) = input.get(jmp_range) { + if let Some(jmp_range) = input.get(range_min..range_max) { let lit = *literal.last().unwrap(); for offset in memchr::memrchr_iter(lit, jmp_range) { - next_positions.insert(position + jmp_range.len() - offset - 1); + if wide { + // In wide mode we are only interested in bytes found + // at even offsets. At odd offsets the input should + // have only zeroes and they are not potential matches. + if offset % 2 == 0 { + next_positions.insert( + position + n + jmp_range.len() - offset - step, + ); + } + } else { + next_positions.insert( + position + n + jmp_range.len() - offset - step, + ); + } } } } diff --git a/yara-x/src/re/thompson/pikevm.rs b/yara-x/src/re/thompson/pikevm.rs index fec670e48..d772a8e7b 100644 --- a/yara-x/src/re/thompson/pikevm.rs +++ b/yara-x/src/re/thompson/pikevm.rs @@ -56,6 +56,60 @@ impl<'r> PikeVM<'r> { self } + /// Executes VM code starting at the `start` location and calls `f` for + /// each match found. The `right` slice contains the bytes at the right + /// of the starting point (i.e: from the starting point until the end of + /// the input), while the `right` slice contains the bytes at the left of + /// the starting point (i.e: from the start of the input until the starting + /// point. + /// + /// ```text + /// <-- left --> | <-- right --> + /// a b c d e f | g h i j k l k + /// | + /// starting point + /// ``` + /// + /// The `f` function must return either [`Action::Continue`] or + /// [`Action::Stop`], the former will cause the VM to keep trying to find + /// longer matches, while the latter will stop the scanning. + #[inline] + pub(crate) fn try_match( + &mut self, + start: C, + right: &[u8], + left: &[u8], + wide: bool, + mut f: impl FnMut(usize) -> Action, + ) where + C: CodeLoc, + { + match (start.backwards(), wide) { + // Going forward, not wide. + (false, false) => { + self.try_match_impl(start, right.iter(), left.iter().rev(), f) + } + // Going forward, wide. + (false, true) => self.try_match_impl( + start, + right.iter().step_by(2), + left.iter().rev().skip(1).step_by(2), + |match_len| f(match_len * 2), + ), + // Going backward, not wide. + (true, false) => { + self.try_match_impl(start, left.iter().rev(), right.iter(), f) + } + // Going backward, wide. + (true, true) => self.try_match_impl( + start, + left.iter().rev().skip(1).step_by(2), + right.iter().step_by(2), + |match_len| f(match_len * 2), + ), + } + } + /// Executes VM code starting at the `start` location and calls `f` for /// each match found. Input bytes are read from the `fwd_input` iterator /// until no more bytes are available or the scan limit is reached. When @@ -83,7 +137,7 @@ impl<'r> PikeVM<'r> { /// that appear right before the start of `fwd_input` for matching some /// look-around assertions that need information about the surrounding /// bytes. - pub(crate) fn try_match<'a, C, F, B>( + fn try_match_impl<'a, C, F, B>( &mut self, start: C, mut fwd_input: F, diff --git a/yara-x/src/scanner/context.rs b/yara-x/src/scanner/context.rs index d4d8c4a1b..1263e4202 100644 --- a/yara-x/src/scanner/context.rs +++ b/yara-x/src/scanner/context.rs @@ -831,24 +831,11 @@ fn verify_regexp_match( // faster and less general FastVM, or for the slower but more general // PikeVM. if let Some(fwd_code) = atom.fwd_code() { - if flags.contains(SubPatternFlags::Wide) { - if flags.contains(SubPatternFlags::FastRegexp) { - todo!() - } else { - vm.pike_vm.try_match( - fwd_code, - scanned_data[atom_pos..].iter().step_by(2), - scanned_data[..atom_pos].iter().rev().skip(1).step_by(2), - |match_len| { - fwd_match_len = Some(match_len * 2); - Action::Stop - }, - ); - } - } else if flags.contains(SubPatternFlags::FastRegexp) { + if flags.contains(SubPatternFlags::FastRegexp) { vm.fast_vm.try_match( fwd_code, &scanned_data[atom_pos..], + flags.contains(SubPatternFlags::Wide), |match_len| { fwd_match_len = Some(match_len); Action::Stop @@ -857,8 +844,9 @@ fn verify_regexp_match( } else { vm.pike_vm.try_match( fwd_code, - scanned_data[atom_pos..].iter(), - scanned_data[..atom_pos].iter().rev(), + &scanned_data[atom_pos..], + &scanned_data[..atom_pos], + flags.contains(SubPatternFlags::Wide), |match_len| { fwd_match_len = Some(match_len); Action::Stop @@ -875,29 +863,11 @@ fn verify_regexp_match( }; if let Some(bck_code) = atom.bck_code() { - if flags.contains(SubPatternFlags::Wide) { - if flags.contains(SubPatternFlags::FastRegexp) { - todo!() - } else { - vm.pike_vm.try_match( - bck_code, - scanned_data[..atom_pos].iter().rev().skip(1).step_by(2), - scanned_data[atom_pos..].iter().step_by(2), - |bck_match_len| { - let range = atom_pos - bck_match_len * 2 - ..atom_pos + fwd_match_len; - if verify_full_word(scanned_data, &range, flags, None) - { - f(Match { range, xor_key: None }); - } - Action::Continue - }, - ); - } - } else if flags.contains(SubPatternFlags::FastRegexp) { + if flags.contains(SubPatternFlags::FastRegexp) { vm.fast_vm.try_match( bck_code, &scanned_data[..atom_pos], + flags.contains(SubPatternFlags::Wide), |bck_match_len| { let range = atom_pos - bck_match_len..atom_pos + fwd_match_len; @@ -910,8 +880,9 @@ fn verify_regexp_match( } else { vm.pike_vm.try_match( bck_code, - scanned_data[..atom_pos].iter().rev(), - scanned_data[atom_pos..].iter(), + &scanned_data[atom_pos..], + &scanned_data[..atom_pos], + flags.contains(SubPatternFlags::Wide), |bck_match_len| { let range = atom_pos - bck_match_len..atom_pos + fwd_match_len; diff --git a/yara-x/src/tests/mod.rs b/yara-x/src/tests/mod.rs index a47536edc..1073426ad 100644 --- a/yara-x/src/tests/mod.rs +++ b/yara-x/src/tests/mod.rs @@ -689,6 +689,11 @@ fn hex_patterns() { pattern_false!(r#"{ 03 04 [-] 01 02 }"#, &[0x01, 0x02, 0x03, 0x04]); pattern_false!(r#"{ 01 02 [2-] 03 04 }"#, &[0x01, 0x02, 0xFF, 0x03, 0x04]); + pattern_false!( + r#"{ 01 03 03 [1-3] 03 04 05 06 07 }"#, + &[0x01, 0x03, 0x03, 0x03, 0x04, 0x05, 0x06, 0x07] + ); + pattern_match!( r#"{ 01 02 ~03 04 05 }"#, &[0x01, 0x02, 0xFF, 0x04, 0x05], @@ -1218,6 +1223,30 @@ fn regexp_wide() { b"f\x00o\x00o\x00b\x00a\x00r\x00", b"f\x00o\x00o\x00b\x00a\x00r\x00" ); + + pattern_match!( + r#"/foo.{1,3}?bar/s wide"#, + b"f\x00o\x00o\x00X\x00b\x00a\x00r\x00", + b"f\x00o\x00o\x00X\x00b\x00a\x00r\x00" + ); + + pattern_match!( + r#"/fo.{0,3}?bar/s wide"#, + b"f\x00o\x00b\x00a\x00r\x00", + b"f\x00o\x00b\x00a\x00r\x00" + ); + + pattern_match!( + r#"/fo.{0,3}?bar/s wide"#, + b"f\x00o\x00o\x00b\x00a\x00r\x00", + b"f\x00o\x00o\x00b\x00a\x00r\x00" + ); + + pattern_match!( + r#"/fo.{0,3}?bar/s wide"#, + b"f\x00o\x00o\x00o\x00b\x00a\x00r\x00", + b"f\x00o\x00o\x00o\x00b\x00a\x00r\x00" + ); } #[test]