Skip to content

Commit

Permalink
perf: reimplement FastVM with a faster algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
plusvic committed Aug 31, 2023
1 parent 98e164c commit 812ec2d
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 127 deletions.
2 changes: 1 addition & 1 deletion yara-x/src/compiler/atoms/quality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ mod test {
assert!(q_01xx03 < q_010203);
assert!(q_010x0x > q_01);
assert!(q_010x0x < q_010203);
assert!(q_01020000 < q_0102xx04);
assert_eq!(q_01020000, q_0102xx04);
assert!(q_01020102 > q_01020000);
assert!(q_01020102 > q_01010101);
assert!(q_01020304 > q_01020102);
Expand Down
236 changes: 110 additions & 126 deletions yara-x/src/re/fast/fastvm.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::cmp;
use std::collections::BTreeSet;
use std::ops::RangeInclusive;
use std::{cmp, mem};

use itertools::izip;
use memx::memeq;
Expand All @@ -16,17 +17,21 @@ use crate::re::{Action, CodeLoc, DEFAULT_SCAN_LIMIT};
pub(crate) struct FastVM<'r> {
/// The code for the VM. Produced by [`crate::re::fast::Compiler`].
code: &'r [u8],
/// The list of currently active positions.
threads: Vec<(usize, usize)>,
/// Maximum number of bytes to scan. The VM will abort after ingesting
/// this number of bytes from the input.
scan_limit: usize,
/// A set with all the positions currently tracked.
positions: BTreeSet<usize>,
}

impl<'r> FastVM<'r> {
/// Creates a new [`FastVM`].
pub fn new(code: &'r [u8]) -> Self {
Self { code, threads: Vec::new(), scan_limit: DEFAULT_SCAN_LIMIT }
Self {
code,
positions: BTreeSet::new(),
scan_limit: DEFAULT_SCAN_LIMIT,
}
}

/// Specifies the maximum number of bytes that will be scanned by the
Expand Down Expand Up @@ -55,26 +60,35 @@ impl<'r> FastVM<'r> {
C: CodeLoc,
{
let backwards = start.backwards();
self.threads.push((start.location(), 0));
let mut ip = start.location();

let input = if backwards {
&input[input.len().saturating_sub(self.scan_limit)..]
} else {
&input[..cmp::min(input.len(), self.scan_limit)]
};

let input = &input[..cmp::min(input.len(), self.scan_limit)];
self.positions.insert(0);

while let Some((mut ip, mut position)) = self.threads.pop() {
while position <= input.len() {
let (instr, instr_size) =
InstrParser::decode_instr(&self.code[ip..]);
while !self.positions.is_empty() {
let (instr, instr_size) =
InstrParser::decode_instr(&self.code[ip..]);

ip += instr_size;
ip += instr_size;

match instr {
Instr::Match => match f(position) {
Action::Stop => {
self.threads.clear();
return;
match instr {
Instr::Match => {
for position in mem::take(&mut self.positions) {
match f(position) {
Action::Stop => {
return;
}
Action::Continue => {}
}
Action::Continue => break,
},
Instr::Literal(literal) => {
}
}
Instr::Literal(literal) => {
for position in mem::take(&mut self.positions) {
let is_match = if backwards {
self.try_match_literal_bck(
&input[..input.len() - position],
Expand All @@ -86,12 +100,13 @@ impl<'r> FastVM<'r> {
literal,
)
};
if !is_match {
break;
if is_match {
self.positions.insert(position + literal.len());
}
position += literal.len();
}
Instr::MaskedLiteral(literal, mask) => {
}
Instr::MaskedLiteral(literal, mask) => {
for position in mem::take(&mut self.positions) {
let is_match = if backwards {
self.try_match_masked_literal_bck(
&input[..input.len() - position],
Expand All @@ -105,80 +120,68 @@ impl<'r> FastVM<'r> {
mask,
)
};
if !is_match {
break;
if is_match {
self.positions.insert(position + literal.len());
}
position += literal.len();
}
Instr::Jump(jump) => {
position += jump as usize;
}
Instr::Jump(jump) => {
for position in mem::take(&mut self.positions) {
self.positions.insert(position + jump as usize);
}
Instr::JumpRange(range) => {
match InstrParser::decode_instr(&self.code[ip..]) {
(Instr::Literal(literal), _) if backwards => {
if let Some(new_position) = self.jump_bck(
}
Instr::JumpRange(range) => {
match InstrParser::decode_instr(&self.code[ip..]) {
(Instr::Literal(literal), _) if backwards => {
for position in mem::take(&mut self.positions) {
self.jump_bck(
&input[..input.len() - position],
*literal.last().unwrap(),
range,
ip,
literal,
&range,
position,
) {
position = new_position;
} else {
break;
}
);
}
(Instr::Literal(literal), _) if !backwards => {
if let Some(new_position) = self.jump_fwd(
}
(Instr::Literal(literal), _) if !backwards => {
for position in mem::take(&mut self.positions) {
self.jump_fwd(
&input[position..],
*literal.first().unwrap(),
range,
ip,
literal,
&range,
position,
) {
position = new_position;
} else {
break;
}
)
}
(Instr::MaskedLiteral(literal, mask), _)
if backwards && mask.last() == Some(&0xff) =>
{
if let Some(new_position) = self.jump_bck(
}
(Instr::MaskedLiteral(literal, mask), _)
if backwards && mask.last() == Some(&0xff) =>
{
for position in mem::take(&mut self.positions) {
self.jump_bck(
&input[..input.len() - position],
*literal.last().unwrap(),
range,
ip,
literal,
&range,
position,
) {
position = new_position;
} else {
break;
}
);
}
(Instr::MaskedLiteral(literal, mask), _)
if !backwards
&& mask.first() == Some(&0xff) =>
{
if let Some(new_position) = self.jump_fwd(
}
(Instr::MaskedLiteral(literal, mask), _)
if !backwards && mask.first() == Some(&0xff) =>
{
for position in mem::take(&mut self.positions) {
self.jump_fwd(
&input[position..],
*literal.first().unwrap(),
range,
ip,
literal,
&range,
position,
) {
position = new_position;
} else {
break;
}
);
}
_ => {
let min = *range.start() as usize;
for i in range.skip(1).rev() {
self.threads
.push((ip, position + i as usize))
}
_ => {
for position in mem::take(&mut self.positions) {
for i in range.clone() {
self.positions
.insert(position + i as usize);
}
position += min;
}
}
}
Expand All @@ -190,7 +193,7 @@ impl<'r> FastVM<'r> {

impl FastVM<'_> {
#[inline]
fn try_match_literal_fwd(&mut self, input: &[u8], literal: &[u8]) -> bool {
fn try_match_literal_fwd(&self, input: &[u8], literal: &[u8]) -> bool {
if input.len() < literal.len() {
return false;
}
Expand All @@ -207,7 +210,7 @@ impl FastVM<'_> {

#[inline]
fn try_match_masked_literal_fwd(
&mut self,
&self,
input: &[u8],
literal: &[u8],
mask: &[u8],
Expand All @@ -229,7 +232,7 @@ impl FastVM<'_> {

#[inline]
fn try_match_masked_literal_bck(
&mut self,
&self,
input: &[u8],
literal: &[u8],
mask: &[u8],
Expand All @@ -255,65 +258,46 @@ impl FastVM<'_> {
fn jump_fwd(
&mut self,
input: &[u8],
lit: u8,
range: RangeInclusive<u16>,
ip: usize,
mut position: usize,
) -> Option<usize> {
literal: &[u8],
range: &RangeInclusive<u16>,
position: usize,
) {
let jmp_min = *range.start() as usize;
let jmp_max = std::cmp::min(input.len(), *range.end() as usize + 1);
let jmp_max = cmp::min(input.len(), *range.end() as usize + 1);
let jmp_range = jmp_min..jmp_max;

if jmp_range.start >= jmp_range.end {
return None;
return;
}

let mut offsets = memchr::memrchr_iter(lit, input.get(jmp_range)?);

let last = offsets.next_back();
for offset in offsets {
self.threads.push((ip, position + jmp_min + offset))
}

if let Some(offset) = last {
position += jmp_min + offset;
} else {
return None;
if let Some(jmp_range) = input.get(jmp_range) {
let lit = *literal.first().unwrap();
for offset in memchr::memchr_iter(lit, jmp_range) {
self.positions.insert(position + jmp_min + offset);
}
}

Some(position)
}

#[inline]
fn jump_bck(
&mut self,
input: &[u8],
lit: u8,
range: RangeInclusive<u16>,
ip: usize,
mut position: usize,
) -> Option<usize> {
let jump_range = input.len().saturating_sub(*range.end() as usize + 1)
literal: &[u8],
range: &RangeInclusive<u16>,
position: usize,
) {
let jmp_range = input.len().saturating_sub(*range.end() as usize + 1)
..input.len().saturating_sub(*range.start() as usize);

if jump_range.start >= jump_range.end {
return None;
}

let jump_range = input.get(jump_range)?;
let mut offsets = memchr::memchr_iter(lit, jump_range);

let last = offsets.next_back();
for offset in offsets {
self.threads.push((ip, position + jump_range.len() - offset - 1))
if jmp_range.start >= jmp_range.end {
return;
}

if let Some(offset) = last {
position += jump_range.len() - offset - 1;
} else {
return None;
if let Some(jmp_range) = input.get(jmp_range) {
let lit = *literal.last().unwrap();
for offset in memchr::memrchr_iter(lit, jmp_range) {
self.positions.insert(position + jmp_range.len() - offset - 1);
}
}

Some(position)
}
}

0 comments on commit 812ec2d

Please sign in to comment.