Skip to content

Commit

Permalink
feat: implement wide mode scanning in FastVM
Browse files Browse the repository at this point in the history
  • Loading branch information
plusvic committed Sep 4, 2023
1 parent f38f25e commit fc9c125
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 81 deletions.
212 changes: 171 additions & 41 deletions yara-x/src/re/fast/fastvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ impl<'r> FastVM<'r> {
&mut self,
start: C,
input: &[u8],
wide: bool,
mut f: impl FnMut(usize) -> Action,
) where
C: CodeLoc,
Expand All @@ -76,6 +77,7 @@ impl<'r> FastVM<'r> {
&input[..cmp::min(input.len(), self.scan_limit)]
};

let step = if wide { 2 } else { 1 };
let mut next_positions = IndexSet::default();

self.positions.insert(0);
Expand Down Expand Up @@ -107,15 +109,18 @@ impl<'r> FastVM<'r> {
self.try_match_literal_bck(
&input[..input.len() - position],
literal,
wide,
)
} else {
self.try_match_literal_fwd(
&input[*position..],
literal,
wide,
)
};
if is_match {
next_positions.insert(position + literal.len());
next_positions
.insert(position + step * literal.len());
}
}
}
Expand All @@ -129,16 +134,19 @@ impl<'r> FastVM<'r> {
&input[..input.len() - position],
literal,
mask,
wide,
)
} else {
self.try_match_masked_literal_fwd(
&input[*position..],
literal,
mask,
wide,
)
};
if is_match {
next_positions.insert(position + literal.len());
next_positions
.insert(position + step * literal.len());
}
}
}
Expand All @@ -154,16 +162,19 @@ impl<'r> FastVM<'r> {
self.try_match_literal_bck(
&input[..input.len() - position],
literal,
wide,
)
} else {
self.try_match_literal_fwd(
&input[*position..],
literal,
wide,
)
};
if is_match {
next_positions
.insert(position + literal.len());
next_positions.insert(
position + step * literal.len(),
);
}
}
Instr::MaskedLiteral(literal, mask) => {
Expand All @@ -172,17 +183,20 @@ impl<'r> FastVM<'r> {
&input[..input.len() - position],
literal,
mask,
wide,
)
} else {
self.try_match_masked_literal_fwd(
&input[*position..],
literal,
mask,
wide,
)
};
if is_match {
next_positions
.insert(position + literal.len());
next_positions.insert(
position + step * literal.len(),
);
}
}
// The only valid instructions in alternatives
Expand All @@ -199,7 +213,7 @@ impl<'r> FastVM<'r> {
}
Instr::Jump(jump) => {
for position in &self.positions {
next_positions.insert(position + jump as usize);
next_positions.insert(position + step * jump as usize);
}
}
Instr::JumpRange(range) => {
Expand All @@ -212,6 +226,7 @@ impl<'r> FastVM<'r> {
self.jump_bck(
&input[..input.len() - position],
literal,
wide,
&range,
*position,
&mut next_positions,
Expand All @@ -226,6 +241,7 @@ impl<'r> FastVM<'r> {
self.jump_fwd(
&input[*position..],
literal,
wide,
&range,
*position,
&mut next_positions,
Expand All @@ -242,6 +258,7 @@ impl<'r> FastVM<'r> {
self.jump_bck(
&input[..input.len() - position],
literal,
wide,
&range,
*position,
&mut next_positions,
Expand All @@ -258,6 +275,7 @@ impl<'r> FastVM<'r> {
self.jump_fwd(
&input[*position..],
literal,
wide,
&range,
*position,
&mut next_positions,
Expand All @@ -268,7 +286,7 @@ impl<'r> FastVM<'r> {
for position in mem::take(&mut self.positions) {
for i in range.clone() {
next_positions
.insert(position + i as usize);
.insert(position + step * i as usize);
}
}
}
Expand All @@ -284,19 +302,48 @@ impl<'r> FastVM<'r> {

impl FastVM<'_> {
/// Checks whether `literal` matches at the beginning of `input`.
///
/// With `wide` set to `false`, `input` must hold at least `literal.len()`
/// bytes and its first `literal.len()` bytes are compared with `memeq`.
///
/// With `wide` set to `true`, `input` must hold at least
/// `2 * literal.len()` bytes, and the bytes at even offsets of `input`
/// (0, 2, 4, ...) are compared against `literal`. The bytes at odd offsets
/// are not inspected by this function.
///
/// Returns `true` on a match, `false` otherwise (including when `input` is
/// too short).
#[inline]
// NOTE(review): this span comes from a diff view without +/- markers. The
// one-line signature, the un-nested length check and the trailing `memeq`
// call look like the removed (pre-`wide`) lines, while the multi-line
// signature and the `if wide { .. } else { .. }` body look like the added
// ones — confirm against the repository before editing.
fn try_match_literal_fwd(&self, input: &[u8], literal: &[u8]) -> bool {
if input.len() < literal.len() {
return false;
fn try_match_literal_fwd(
&self,
input: &[u8],
literal: &[u8],
wide: bool,
) -> bool {
if wide {
// Each literal byte takes two input bytes in wide mode.
if input.len() < literal.len() * 2 {
return false;
}
// NOTE(review): `Iterator::eq` also requires both iterators to have the
// same length, and the stepped iterator walks ALL of `input`, not just
// its first `2 * literal.len()` bytes. This looks like it returns `false`
// whenever `input` is longer than `2 * literal.len()`, even if the prefix
// matches — presumably a `.take(literal.len())` (or slicing `input`
// first) is intended; verify against the callers.
input.iter().step_by(2).eq(literal.iter())
} else {
if input.len() < literal.len() {
return false;
}
// Only the prefix of `input` with the literal's length is compared.
memeq(&input[..literal.len()], literal)
}
memeq(&input[..literal.len()], literal)
}

/// Checks whether `literal` matches at the end of `input`, comparing from the
/// last byte backwards.
///
/// With `wide` set to `false`, `input` must hold at least `literal.len()`
/// bytes and its last `literal.len()` bytes are compared with `memeq`.
///
/// With `wide` set to `true`, `input` must hold at least
/// `2 * literal.len()` bytes. The input is walked in reverse, the very last
/// byte is skipped, and every second byte from there on is compared against
/// `literal` in reverse order. The skipped bytes are not inspected by this
/// function.
///
/// Returns `true` on a match, `false` otherwise (including when `input` is
/// too short).
#[inline]
// NOTE(review): this span comes from a diff view without +/- markers. The
// one-line signature, the un-nested length check and the trailing `memeq`
// call look like the removed (pre-`wide`) lines, while the multi-line
// signature and the `if wide { .. } else { .. }` body look like the added
// ones — confirm against the repository before editing.
fn try_match_literal_bck(&self, input: &[u8], literal: &[u8]) -> bool {
if input.len() < literal.len() {
return false;
fn try_match_literal_bck(
&self,
input: &[u8],
literal: &[u8],
wide: bool,
) -> bool {
if wide {
// Each literal byte takes two input bytes in wide mode.
if input.len() < literal.len() * 2 {
return false;
}
// NOTE(review): `Iterator::eq` also requires both iterators to have the
// same length, and the reversed iterator walks ALL of `input`, not just
// its last `2 * literal.len()` bytes. This looks like it returns `false`
// whenever `input` is noticeably longer than twice the literal, even if
// the suffix matches — presumably a `.take(literal.len())` is intended;
// verify against the callers.
input
.iter() // iterate input
.rev() // in reverse order
.skip(1) // skipping the last byte that should be 0
.step_by(2) // two bytes at a time
.eq(literal.iter().rev())
} else {
if input.len() < literal.len() {
return false;
}
// Only the suffix of `input` with the literal's length is compared.
memeq(&input[input.len() - literal.len()..], literal)
}
memeq(&input[input.len() - literal.len()..], literal)
}

#[inline]
Expand All @@ -305,17 +352,30 @@ impl FastVM<'_> {
input: &[u8],
literal: &[u8],
mask: &[u8],
wide: bool,
) -> bool {
debug_assert_eq!(literal.len(), mask.len());

if input.len() < literal.len() {
return false;
}

for (input, byte, mask) in izip!(input, literal, mask) {
if *input & *mask != *byte & *mask {
if wide {
if input.len() < literal.len() * 2 {
return false;
}
for (input, byte, mask) in
izip!(input.iter().step_by(2), literal, mask)
{
if *input & *mask != *byte & *mask {
return false;
}
}
} else {
if input.len() < literal.len() {
return false;
}
for (input, byte, mask) in izip!(input, literal, mask) {
if *input & *mask != *byte & *mask {
return false;
}
}
}

true
Expand All @@ -327,19 +387,36 @@ impl FastVM<'_> {
input: &[u8],
literal: &[u8],
mask: &[u8],
wide: bool,
) -> bool {
debug_assert_eq!(literal.len(), mask.len());

if input.len() < literal.len() {
return false;
}

for (input, byte, mask) in
izip!(input.iter().rev(), literal.iter().rev(), mask.iter().rev())
{
if *input & *mask != *byte & *mask {
if wide {
if input.len() < literal.len() * 2 {
return false;
}
for (input, byte, mask) in izip!(
input.iter().rev().step_by(2),
literal.iter().rev(),
mask.iter().rev()
) {
if *input & *mask != *byte & *mask {
return false;
}
}
} else {
if input.len() < literal.len() {
return false;
}
for (input, byte, mask) in izip!(
input.iter().rev(),
literal.iter().rev(),
mask.iter().rev()
) {
if *input & *mask != *byte & *mask {
return false;
}
}
}

true
Expand All @@ -350,22 +427,36 @@ impl FastVM<'_> {
&self,
input: &[u8],
literal: &[u8],
wide: bool,
range: &RangeInclusive<u16>,
position: usize,
next_positions: &mut IndexSet,
) {
let jmp_min = *range.start() as usize;
let jmp_max = cmp::min(input.len(), *range.end() as usize + 1);
let jmp_range = jmp_min..jmp_max;
let step = if wide { 2 } else { 1 };

if jmp_range.start >= jmp_range.end {
let n = *range.start() as usize * step;
let m = *range.end() as usize * step;

let range_min = n;
let range_max = cmp::min(input.len(), m + step);

if range_min >= range_max {
return;
}

if let Some(jmp_range) = input.get(jmp_range) {
if let Some(jmp_range) = input.get(range_min..range_max) {
let lit = *literal.first().unwrap();
for offset in memchr::memchr_iter(lit, jmp_range) {
next_positions.insert(position + jmp_min + offset);
if wide {
// In wide mode we are only interested in bytes found
// at even offsets. At odd offsets the input should
// have only zeroes and they are not potential matches.
if offset % 2 == 0 {
next_positions.insert(position + n + offset);
}
} else {
next_positions.insert(position + n + offset);
}
}
}
}
Expand All @@ -375,21 +466,60 @@ impl FastVM<'_> {
&self,
input: &[u8],
literal: &[u8],
wide: bool,
range: &RangeInclusive<u16>,
position: usize,
next_positions: &mut IndexSet,
) {
let jmp_range = input.len().saturating_sub(*range.end() as usize + 1)
..input.len().saturating_sub(*range.start() as usize);
let step = if wide { 2 } else { 1 };

let n = *range.start() as usize * step;
let m = *range.end() as usize * step;

// Let's explain what this function does using the following
// pattern as an example:
//
// { 01 02 03 [n-m] 04 05 06 07 }
//
// The scheme below summarizes what's happening. The atom found by
// Aho-Corasick is `04 05 06 07`, and this function is about to
// process the jump [n-m]. The input received is the data that ends
// where the atom starts. What we want to do is scanning the range
// `range_min..range_max` from right to left looking for all instances
// of `03`, which are positions where `01 02 03` could match while
// scanning backwards.
//
// |--------------- input ---------------|
// | |--------- M ----------|
// | | |---- N ----|
// | ... 01 02 03 | .................... | 04 05 06 07
// ^ ^
// range_min range_max
//
let range_min = input.len().saturating_sub(m + step);
let range_max = input.len().saturating_sub(n);

if jmp_range.start >= jmp_range.end {
if range_min >= range_max {
return;
}

if let Some(jmp_range) = input.get(jmp_range) {
if let Some(jmp_range) = input.get(range_min..range_max) {
let lit = *literal.last().unwrap();
for offset in memchr::memrchr_iter(lit, jmp_range) {
next_positions.insert(position + jmp_range.len() - offset - 1);
if wide {
// In wide mode we are only interested in bytes found
// at even offsets. At odd offsets the input should
// have only zeroes and they are not potential matches.
if offset % 2 == 0 {
next_positions.insert(
position + n + jmp_range.len() - offset - step,
);
}
} else {
next_positions.insert(
position + n + jmp_range.len() - offset - step,
);
}
}
}
}
Expand Down
Loading

0 comments on commit fc9c125

Please sign in to comment.