diff --git a/yara-x/src/compiler/context.rs b/yara-x/src/compiler/context.rs index e75d0a8c1..76b855355 100644 --- a/yara-x/src/compiler/context.rs +++ b/yara-x/src/compiler/context.rs @@ -46,9 +46,6 @@ pub(in crate::compiler) struct Context<'a, 'src, 'sym> { /// Information about the rules compiled so far. pub rules: &'a Vec, - /// Rule that is being compiled. - pub current_rule: &'a RuleInfo, - /// A vector that contains the IR for the patterns declared in the current /// rule, accompanied by their corresponding [`PatternId`]. pub current_rule_patterns: @@ -94,6 +91,13 @@ impl<'a, 'src, 'sym> Context<'a, 'src, 'sym> { self.rules.get(rule_id.0 as usize).unwrap() } + /// Returns the [`RuleInfo`] structure corresponding to the rule currently + /// being compiled. + #[inline] + pub fn get_current_rule(&self) -> &RuleInfo { + self.rules.last().unwrap() + } + /// Given a pattern identifier (e.g. `$a`, `#a`, `@a`) search for it in /// the current rule and return its [`PatternID`]. /// diff --git a/yara-x/src/compiler/ir/ast2ir.rs b/yara-x/src/compiler/ir/ast2ir.rs index e3733958b..30810300d 100644 --- a/yara-x/src/compiler/ir/ast2ir.rs +++ b/yara-x/src/compiler/ir/ast2ir.rs @@ -355,17 +355,18 @@ pub(in crate::compiler) fn expr_from_ast( // A global rule can depend on another global rule. And non-global // rules can depend both on global rules and non-global ones. 
if let SymbolKind::Rule(rule_id) = symbol.kind() { + let current_rule = ctx.get_current_rule(); let used_rule = ctx.get_rule(*rule_id); - if ctx.current_rule.is_global && !used_rule.is_global { + if current_rule.is_global && !used_rule.is_global { return Err(CompileError::from( CompileErrorInfo::wrong_rule_dependency( ctx.report_builder, ctx.ident_pool - .get(ctx.current_rule.ident_id) + .get(current_rule.ident_id) .unwrap() .to_string(), ident.name.to_string(), - ctx.current_rule.ident_span, + current_rule.ident_span, used_rule.ident_span, ident.span, ), diff --git a/yara-x/src/compiler/mod.rs b/yara-x/src/compiler/mod.rs index 0a0eae16a..a0e287f62 100644 --- a/yara-x/src/compiler/mod.rs +++ b/yara-x/src/compiler/mod.rs @@ -332,13 +332,13 @@ impl<'a> Compiler<'a> { // actually exist, and raise warnings in case of duplicated // imports within the same source file. For each module add a // symbol to the current namespace. - self.process_imports(&ast.imports)?; + self.c_imports(&ast.imports)?; // Iterate over the list of declared rules and verify that their // conditions are semantically valid. For each rule add a symbol // to the current namespace. for rule in &ast.rules { - self.process_rule(rule)?; + self.c_rule(rule)?; } // Transfer the warnings generated by the parser to the compiler @@ -518,6 +518,35 @@ impl<'a> Compiler<'a> { } impl<'a> Compiler<'a> { + fn add_sub_pattern( + &mut self, + sub_pattern: SubPattern, + atoms: I, + f: F, + ) -> SubPatternId + where + I: Iterator, + F: Fn(SubPatternId, A) -> SubPatternAtom, + { + let sub_pattern_id = SubPatternId(self.sub_patterns.len() as u32); + + // Sub-patterns that are anchored at some fixed offset are not added to + // the Aho-Corasick automata. Instead their IDs are added to the + // `anchored_sub_patterns` list; the offset itself stays in the + // sub-pattern's `anchored_at` field. + if let SubPattern::Literal { anchored_at: Some(_), ..
} = sub_pattern { + self.anchored_sub_patterns.push(sub_pattern_id); + } else { + for atom in atoms { + self.atoms.push(f(sub_pattern_id, atom)) + } + } + + self.sub_patterns.push((self.current_pattern_id, sub_pattern)); + + sub_pattern_id + } + /// Check if another rule, module or variable has the given identifier and /// return an error in that case. fn check_for_existing_identifier( @@ -561,11 +590,51 @@ impl<'a> Compiler<'a> { self.lit_pool.get_or_intern(literal_bytes) } - fn process_rule(&mut self, rule: &ast::Rule) -> Result<(), CompileError> { + /// Takes a snapshot of the compiler's state at this moment. + /// + /// The returned [`Snapshot`] can be passed to [`Compiler::restore_snapshot`] + /// for restoring the compiler to the state it was in when the snapshot was + /// taken. + /// + /// This is useful when the compilation of a rule fails, for restoring the + /// compiler to the state it had before starting compiling the failed rule, + /// which avoids leaving junk in the compiler's internal structures. + fn take_snapshot(&self) -> Snapshot { + Snapshot { + next_pattern_id: self.next_pattern_id, + rules_len: self.rules.len(), + atoms_len: self.atoms.len(), + re_code_len: self.re_code.len(), + sub_patterns_len: self.sub_patterns.len(), + } + } + + /// Restores the compiler to a previous state. + /// + /// Use [`Compiler::take_snapshot`] for taking a snapshot of the compiler's + /// state. + fn restore_snapshot(&mut self, snapshot: Snapshot) { + self.next_pattern_id = snapshot.next_pattern_id; + self.rules.truncate(snapshot.rules_len); + self.sub_patterns.truncate(snapshot.sub_patterns_len); + self.re_code.truncate(snapshot.re_code_len); + self.atoms.truncate(snapshot.atoms_len); + } +} + +impl<'a> Compiler<'a> { + fn c_rule(&mut self, rule: &ast::Rule) -> Result<(), CompileError> { // Check if another rule, module or variable has the same identifier // and return an error in that case.
self.check_for_existing_identifier(&rule.identifier)?; + // Take snapshot of the current compiler state. In case of error + // compiling the current rule this snapshot allows restoring the + // compiler to the state it had before starting compiling the rule. + // This way we don't leave too much junk, like atoms, or sub-patterns + // corresponding to failed rules. + let snapshot = self.take_snapshot(); + // Convert the patterns from AST to IR. let patterns_in_rule = patterns_from_ast(&self.report_builder, rule.patterns.as_ref())?; @@ -576,9 +645,7 @@ impl<'a> Compiler<'a> { // Create vector with pairs (PatternId, Pattern). let mut patterns_with_ids = Vec::with_capacity(patterns_in_rule.len()); - let mut pending_patterns = HashSet::new(); - let mut next_pattern_ids = self.next_pattern_id.successors(); for pattern in patterns_in_rule { // Check if this pattern has been declared before, in this rule or @@ -591,7 +658,8 @@ impl<'a> Compiler<'a> { Entry::Occupied(entry) => *entry.get(), // The pattern didn't exist. 
Entry::Vacant(entry) => { - let pattern_id = next_pattern_ids.next().unwrap(); + let pattern_id = self.next_pattern_id; + self.next_pattern_id.incr(1); pending_patterns.insert(pattern_id); entry.insert(pattern_id); pattern_id @@ -646,7 +714,6 @@ impl<'a> Compiler<'a> { regexp_pool: &mut self.regexp_pool, report_builder: &self.report_builder, rules: &self.rules, - current_rule: self.rules.last().unwrap(), current_rule_patterns: &mut patterns_with_ids, wasm_symbols: &self.wasm_symbols, wasm_exports: &self.wasm_exports, @@ -657,7 +724,14 @@ impl<'a> Compiler<'a> { vars: VarStack::new(), }; - let mut condition = expr_from_ast(&mut ctx, &rule.condition)?; + let mut condition = match expr_from_ast(&mut ctx, &rule.condition) { + Ok(expr) => expr, + Err(err) => { + drop(ctx); + self.restore_snapshot(snapshot); + return Err(err); + } + }; warn_if_not_bool(&mut ctx, condition.ty(), rule.condition.span()); @@ -687,19 +761,19 @@ impl<'a> Compiler<'a> { let anchored_at = pattern.anchored_at(); match pattern.into_pattern() { Pattern::Literal(pattern) => { - self.process_literal_pattern(pattern, anchored_at); + self.c_literal_pattern(pattern, anchored_at); } Pattern::Regexp(pattern) => { - self.process_regexp_pattern( - pattern, - anchored_at, - span, - )?; + if let Err(err) = + self.c_regexp_pattern(pattern, anchored_at, span) + { + self.restore_snapshot(snapshot); + return Err(err); + } } }; if pending { pending_patterns.remove(&pattern_id); - self.next_pattern_id.incr(1); } } } @@ -707,7 +781,7 @@ impl<'a> Compiler<'a> { Ok(()) } - fn process_literal_pattern( + fn c_literal_pattern( &mut self, pattern: LiteralPattern, anchored_at: Option, @@ -879,7 +953,7 @@ impl<'a> Compiler<'a> { } } - fn process_regexp_pattern( + fn c_regexp_pattern( &mut self, pattern: RegexpPattern, anchored_at: Option, @@ -895,7 +969,7 @@ impl<'a> Compiler<'a> { if !tail.is_empty() { // The pattern was split into multiple chained regexps. 
- return self.process_chain(&head, &tail, pattern.flags, span); + return self.c_chain(&head, &tail, pattern.flags, span); } if head.is_alternation_literal() { @@ -905,7 +979,7 @@ impl<'a> Compiler<'a> { // /foo|bar|baz/ // { 01 02 03 } // { (01 02 03 | 04 05 06 ) } - self.process_alternation_literal(head, anchored_at, pattern.flags); + self.c_alternation_literal(head, anchored_at, pattern.flags); return Ok(()); } @@ -927,7 +1001,7 @@ impl<'a> Compiler<'a> { flags.set(SubPatternFlags::GreedyRegexp); } - let (atoms, is_fast_regexp) = self.compile_regexp(&head, span)?; + let (atoms, is_fast_regexp) = self.c_regexp(&head, span)?; if is_fast_regexp { flags.set(SubPatternFlags::FastRegexp); @@ -952,7 +1026,7 @@ impl<'a> Compiler<'a> { Ok(()) } - fn process_alternation_literal( + fn c_alternation_literal( &mut self, hir: re::hir::Hir, anchored_at: Option, @@ -1041,7 +1115,7 @@ impl<'a> Compiler<'a> { } } - fn process_chain( + fn c_chain( &mut self, leading: &re::hir::Hir, trailing: &[ChainedPattern], @@ -1075,11 +1149,11 @@ impl<'a> Compiler<'a> { if ascii { prev_sub_pattern_ascii = - self.process_literal_chain_head(literal, flags); + self.c_literal_chain_head(literal, flags); } if wide { - prev_sub_pattern_wide = self.process_literal_chain_head( + prev_sub_pattern_wide = self.c_literal_chain_head( literal, flags | SubPatternFlags::Wide, ); @@ -1087,8 +1161,7 @@ impl<'a> Compiler<'a> { } else { let mut flags = common_flags; - let (atoms, is_fast_regexp) = - self.compile_regexp(leading, span)?; + let (atoms, is_fast_regexp) = self.c_regexp(leading, span)?; if is_fast_regexp { flags.set(SubPatternFlags::FastRegexp); @@ -1133,7 +1206,7 @@ impl<'a> Compiler<'a> { if let hir::HirKind::Literal(literal) = p.hir.kind() { if wide { - prev_sub_pattern_wide = self.process_literal_chain_tail( + prev_sub_pattern_wide = self.c_literal_chain_tail( literal, prev_sub_pattern_wide, p.gap.clone(), @@ -1141,7 +1214,7 @@ impl<'a> Compiler<'a> { ); }; if ascii { - prev_sub_pattern_ascii = 
self.process_literal_chain_tail( + prev_sub_pattern_ascii = self.c_literal_chain_tail( literal, prev_sub_pattern_ascii, p.gap.clone(), @@ -1153,8 +1226,7 @@ impl<'a> Compiler<'a> { flags.set(SubPatternFlags::GreedyRegexp); } - let (atoms, is_fast_regexp) = - self.compile_regexp(&p.hir, span)?; + let (atoms, is_fast_regexp) = self.c_regexp(&p.hir, span)?; if is_fast_regexp { flags.set(SubPatternFlags::FastRegexp); @@ -1189,7 +1261,7 @@ impl<'a> Compiler<'a> { Ok(()) } - fn compile_regexp( + fn c_regexp( &mut self, hir: &re::hir::Hir, span: Span, @@ -1242,7 +1314,7 @@ impl<'a> Compiler<'a> { Ok((atoms, is_fast_regexp)) } - fn process_literal_chain_head( + fn c_literal_chain_head( &mut self, literal: &hir::Literal, flags: SubPatternFlagSet, @@ -1261,7 +1333,7 @@ impl<'a> Compiler<'a> { ) } - fn process_literal_chain_tail( + fn c_literal_chain_tail( &mut self, literal: &hir::Literal, chained_to: SubPatternId, @@ -1287,7 +1359,7 @@ impl<'a> Compiler<'a> { ) } - fn process_imports( + fn c_imports( &mut self, imports: &[ast::Import], ) -> Result<(), CompileError> { @@ -1407,35 +1479,6 @@ impl<'a> Compiler<'a> { Ok(()) } - - fn add_sub_pattern( - &mut self, - sub_pattern: SubPattern, - atoms: I, - f: F, - ) -> SubPatternId - where - I: Iterator, - F: Fn(SubPatternId, A) -> SubPatternAtom, - { - let sub_pattern_id = SubPatternId(self.sub_patterns.len() as u32); - - // Sub-patterns that are anchored at some fixed offset are not added to - // the Aho-Corasick automata. Instead their IDs are added to the - // sub_patterns_anchored_at_0 list, together with the offset they are - // anchored to. - if let SubPattern::Literal { anchored_at: Some(_), .. 
} = sub_pattern { - self.anchored_sub_patterns.push(sub_pattern_id); - } else { - for atom in atoms { - self.atoms.push(f(sub_pattern_id, atom)) - } - } - - self.sub_patterns.push((self.current_pattern_id, sub_pattern)); - - sub_pattern_id - } } impl fmt::Debug for Compiler<'_> { @@ -1597,10 +1640,6 @@ impl PatternId { fn incr(&mut self, amount: usize) { self.0 += amount as i32; } - - fn successors(&self) -> impl Iterator { - iter::successors(Some(self.0), |n| Some(n + 1)).map(PatternId) - } } impl From for PatternId { @@ -1766,3 +1805,13 @@ impl SubPattern { } } } + +/// A snapshot that represents the state of the compiler at a particular moment. +#[derive(Debug, PartialEq, Eq)] +struct Snapshot { + next_pattern_id: PatternId, + rules_len: usize, + atoms_len: usize, + re_code_len: usize, + sub_patterns_len: usize, +} diff --git a/yara-x/src/compiler/tests/mod.rs b/yara-x/src/compiler/tests/mod.rs index 47eafdd10..d79571bfa 100644 --- a/yara-x/src/compiler/tests/mod.rs +++ b/yara-x/src/compiler/tests/mod.rs @@ -100,6 +100,22 @@ fn var_stack() { assert_eq!(stack.used, 0); } +#[test] +fn snapshots() { + let mut compiler = Compiler::new(); + + compiler + .add_source(r#"rule test { strings: $a = "foo" condition: $a }"#) + .unwrap(); + let snapshot = compiler.take_snapshot(); + + compiler + .add_source(r#"rule test { strings: $a = /{}/ condition: $a }"#) + .expect_err("compilation should fail"); + + assert_eq!(compiler.take_snapshot(), snapshot); +} + #[test] fn globals() { let mut compiler = Compiler::new();