From 63553acf944c0f478a8f9559bd6cb88db39bf954 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Fri, 30 Aug 2024 18:04:15 +0800 Subject: [PATCH] Implement hyper-transition analysis --- src/compiler/abepi.rs | 9 +- src/compiler/compiler.rs | 17 +- src/compiler/semantic/analyser.rs | 85 ++- src/compiler/semantic/mod.rs | 23 + src/compiler/semantic/rules.rs | 852 ++++++++++++++++++------------ src/compiler/setup_inter.rs | 6 +- src/interpreter/mod.rs | 2 +- src/parser/ast/statement.rs | 22 +- src/parser/build.rs | 25 +- src/parser/chiquito.lalrpop | 2 +- 10 files changed, 652 insertions(+), 391 deletions(-) diff --git a/src/compiler/abepi.rs b/src/compiler/abepi.rs index 03b488ff..559d3e57 100644 --- a/src/compiler/abepi.rs +++ b/src/compiler/abepi.rs @@ -65,8 +65,8 @@ impl + TryInto + Clone + Debug, V: Clone + Debug> CompilationU Statement::Transition(dsym, id, stmt) => { self.compiler_statement_transition(dsym, id, *stmt) } - Statement::HyperTransition(dsym, ids, call, state) => { - self.compiler_statement_hyper_transition(dsym, ids, call, state) + Statement::HyperTransition(dsym, assign_call, state_transition) => { + self.compiler_statement_hyper_transition(dsym, *assign_call, *state_transition) } _ => vec![], } @@ -427,9 +427,8 @@ impl + TryInto + Clone + Debug, V: Clone + Debug> CompilationU fn compiler_statement_hyper_transition( &self, _dsym: DebugSymRef, - _ids: Vec, - _call: Expression, - _state: V, + _assign_call: Statement, + _state_transition: Statement, ) -> Vec> { todo!("Compile expressions? Needs specs") } diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index e717a993..a950412e 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -570,10 +570,12 @@ mod test { // Source code containing two machines let circuit = " machine caller (signal n) (signal b: field) { - signal b_1: field; - b_1' <== fibo(n) -> final; - } - machine fibo (signal n) (signal b: field) { + state initial { + signal b_1: field; + b_1' <== fibo(n) -> final; + } + } + machine fibo (signal n) (signal b: field) { // n and be are created automatically as shared // signals signal a: field, i; @@ -809,7 +811,8 @@ mod test { fn test_parse_hyper_transition() { let circuit = " machine caller (signal n) (signal b: field) { - a', b, c' <== fibo(d, e, f + g) -> final; + var some_var; + a', b', some_var' <== fibo(d, e, f + g) -> final; } "; @@ -821,7 +824,7 @@ mod test { let circuit = " machine caller (signal n) (signal b: field) { -> final { - a', b, c' <== fibo(d, e, f + g); + a', b', c' <== fibo(d, e, f + g); } } "; @@ -834,7 +837,7 @@ mod test { // TODO should no-arg calls be allowed? Needs more specs for function/machine calls let circuit = " machine caller (signal n) (signal b: field) { - smth <== a() -> final; + smth' <== a() -> final; } "; diff --git a/src/compiler/semantic/analyser.rs b/src/compiler/semantic/analyser.rs index e5291ea4..593bf9b0 100644 --- a/src/compiler/semantic/analyser.rs +++ b/src/compiler/semantic/analyser.rs @@ -1,3 +1,4 @@ +use itertools::Itertools; use num_bigint::BigInt; use crate::{ @@ -48,46 +49,86 @@ impl Default for Analyser { impl Analyser { /// Analyse chiquito AST. fn analyse(&mut self, program: &[TLDecl]) { + program + .iter() + .for_each(|decl: &TLDecl| self.collect_tl_decl(decl)); program.iter().for_each(|decl| self.analyse_tldcl(decl)); } - /// Analyse top level declaration. - fn analyse_tldcl(&mut self, decl: &TLDecl) { + /// Collect a top level declaration to later perform necessary checks (e.g., validate the + /// calls of machines from other machines). + fn collect_tl_decl(&mut self, decl: &TLDecl) { match decl.clone() { TLDecl::MachineDecl { dsym: _, id, input_params, output_params, - block, + block: _, } => { - let sym = SymTableEntry::new( + let sym = SymTableEntry::new_machine( id.name(), id.debug_sym_ref(), SymbolCategory::Machine, None, + input_params + .iter() + .map(|param| match param { + Statement::SignalDecl(_, decls) | Statement::WGVarDecl(_, decls) => { + if decls.len() != 1 { + unreachable!("Each input should be a single identifier"); + } + decls[0].id.name() + } + _ => unreachable!("Inputs should be signals or vars"), + }) + .collect_vec(), + output_params + .iter() + .map(|param| match param { + Statement::SignalDecl(_, decls) | Statement::WGVarDecl(_, decls) => { + if decls.len() != 1 { + unreachable!("Each output should be a single identifier"); + } + decls[0].id.name() + } + _ => unreachable!("Outputs should be signals or vars"), + }) + .collect_vec(), ); RULES.apply_new_symbol_tldecl(self, decl, &id, &sym); self.symbols.add_symbol(&self.cur_scope, id.name(), sym); - self.analyse_machine(id, input_params, output_params, block); + self.enter_new_scope(id.name()); + + self.analyse_machine_input_params(input_params); + + self.analyse_machine_output_params(output_params); + + self.exit_scope(); } } } - fn analyse_machine( - &mut self, - id: Identifier, - input_params: Vec>, - output_params: Vec>, - block: Statement, - ) { - self.enter_new_scope(id.name()); + /// Analyse top level declaration. + fn analyse_tldcl(&mut self, decl: &TLDecl) { + match decl.clone() { + TLDecl::MachineDecl { + dsym: _, + id, + input_params: _, + output_params: _, + block, + } => { + self.analyse_machine(id, block); + } + } + } - self.analyse_machine_input_params(input_params); - self.analyse_machine_output_params(output_params); + fn analyse_machine(&mut self, id: Identifier, block: Statement) { + self.enter_new_scope(id.name()); self.add_state_decls(&block); @@ -193,9 +234,6 @@ impl Analyser { fn analyse_statement(&mut self, stmt: Statement) { self.statement_add_symbols(&stmt); - - RULES.apply_statement(self, &stmt); - self.analyse_statement_recursive(stmt); } @@ -233,6 +271,8 @@ impl Analyser { } fn analyse_statement_recursive(&mut self, stmt: Statement) { + RULES.apply_statement(self, &stmt); + match stmt { Statement::Assert(_, expr) => self.analyse_expression(expr), Statement::SignalAssignment(_, ids, exprs) => { @@ -272,10 +312,9 @@ impl Analyser { Statement::SignalDecl(_, _) => {} Statement::WGVarDecl(_, _) => {} - Statement::HyperTransition(_, ids, call, state) => { - self.analyse_expression(call); - self.collect_id_usages(&[state]); - self.collect_id_usages(&ids); + Statement::HyperTransition(_, ref assign_call, ref state_transition) => { + self.analyse_statement_recursive(*assign_call.clone()); + self.analyse_statement_recursive(*state_transition.clone()); } } } @@ -398,7 +437,7 @@ mod test { assert_eq!( format!("{:?}", result), - r#"AnalysisResult { symbols: /: ScopeTable { symbols: "fibo: SymTableEntry { id: \"fibo\", definition_ref: nofile:2:17, usages: [], category: Machine, ty: None }", scope: Global },//fibo: ScopeTable { symbols: "a: SymTableEntry { id: \"a\", definition_ref: nofile:5:20, usages: [nofile:13:17, nofile:16:15, nofile:23:20, nofile:31:20], category: Signal, ty: Some(\"field\") },b: SymTableEntry { id: \"b\", definition_ref: nofile:2:40, usages: [nofile:13:20, nofile:16:30, nofile:16:19, nofile:23:24, nofile:27:20, nofile:31:42, nofile:31:24], category: OutputSignal, ty: Some(\"field\") },i: SymTableEntry { id: \"i\", definition_ref: nofile:5:30, usages: [nofile:13:14, nofile:25:17, nofile:27:31, nofile:27:16, nofile:31:35, nofile:31:16], category: Signal, ty: None },initial: SymTableEntry { id: \"initial\", definition_ref: nofile:10:19, usages: [], category: State, ty: None },middle: SymTableEntry { id: \"middle\", definition_ref: nofile:20:19, usages: [nofile:15:17, nofile:30:18], category: State, ty: None },n: SymTableEntry { id: \"n\", definition_ref: nofile:2:29, usages: [nofile:16:36, nofile:16:23, nofile:25:26, nofile:27:41, nofile:27:24, nofile:31:48, nofile:31:28], category: InputSignal, ty: None }", scope: Machine },//fibo/initial: ScopeTable { symbols: "c: SymTableEntry { id: \"c\", definition_ref: nofile:11:21, usages: [nofile:13:23, nofile:16:33], category: Signal, ty: None }", scope: State },//fibo/middle: ScopeTable { symbols: "c: SymTableEntry { id: \"c\", definition_ref: nofile:21:21, usages: [nofile:23:14, nofile:27:38, nofile:31:45], category: Signal, ty: None }", scope: State }, messages: [] }"# + r#"AnalysisResult { symbols: /: ScopeTable { symbols: "fibo: SymTableEntry { id: \"fibo\", definition_ref: nofile:2:17, usages: [], category: Machine, ty: None, ins: Some([\"n\"]), outs: Some([\"b\"]) }", scope: Global },//fibo: ScopeTable { symbols: "a: SymTableEntry { id: \"a\", definition_ref: nofile:5:20, usages: [nofile:13:17, nofile:16:15, nofile:23:20, nofile:31:20], category: Signal, ty: Some(\"field\"), ins: None, outs: None },b: SymTableEntry { id: \"b\", definition_ref: nofile:2:40, usages: [nofile:13:20, nofile:16:30, nofile:16:19, nofile:23:24, nofile:27:20, nofile:31:42, nofile:31:24], category: OutputSignal, ty: Some(\"field\"), ins: None, outs: None },i: SymTableEntry { id: \"i\", definition_ref: nofile:5:30, usages: [nofile:13:14, nofile:25:17, nofile:27:31, nofile:27:16, nofile:31:35, nofile:31:16], category: Signal, ty: None, ins: None, outs: None },initial: SymTableEntry { id: \"initial\", definition_ref: nofile:10:19, usages: [], category: State, ty: None, ins: None, outs: None },middle: SymTableEntry { id: \"middle\", definition_ref: nofile:20:19, usages: [nofile:15:17, nofile:30:18], category: State, ty: None, ins: None, outs: None },n: SymTableEntry { id: \"n\", definition_ref: nofile:2:29, usages: [nofile:16:36, nofile:16:23, nofile:25:26, nofile:27:41, nofile:27:24, nofile:31:48, nofile:31:28], category: InputSignal, ty: None, ins: None, outs: None }", scope: Machine },//fibo/initial: ScopeTable { symbols: "c: SymTableEntry { id: \"c\", definition_ref: nofile:11:21, usages: [nofile:13:23, nofile:16:33], category: Signal, ty: None, ins: None, outs: None }", scope: State },//fibo/middle: ScopeTable { symbols: "c: SymTableEntry { id: \"c\", definition_ref: nofile:21:21, usages: [nofile:23:14, nofile:27:38, nofile:31:45], category: Signal, ty: None, ins: None, outs: None }", scope: State }, messages: [] }"# ); } } diff --git a/src/compiler/semantic/mod.rs b/src/compiler/semantic/mod.rs index 47d08aaa..403e930a 100644 --- a/src/compiler/semantic/mod.rs +++ b/src/compiler/semantic/mod.rs @@ -55,6 +55,8 @@ pub struct SymTableEntry { pub category: SymbolCategory, /// Type pub ty: Option, + pub ins: Option>, + pub outs: Option>, } impl SymTableEntry { @@ -70,6 +72,27 @@ impl SymTableEntry { usages: Vec::new(), category, ty, + ins: None, + outs: None, + } + } + + pub fn new_machine( + id: String, + definition_ref: DebugSymRef, + category: SymbolCategory, + ty: Option, + input_params: Vec, + output_params: Vec, + ) -> Self { + SymTableEntry { + id, + definition_ref, + usages: Vec::new(), + category, + ty, + ins: Some(input_params), + outs: Some(output_params), } } diff --git a/src/compiler/semantic/rules.rs b/src/compiler/semantic/rules.rs index 818162c1..5af356bf 100644 --- a/src/compiler/semantic/rules.rs +++ b/src/compiler/semantic/rules.rs @@ -107,15 +107,67 @@ fn state_decl(analyser: &mut Analyser, expr: &Statement) { }); } +// Cannot transition to a non-existing state. +fn state_rule(analyser: &mut Analyser, expr: &Statement) { + if let Statement::Transition(_, state, _) = expr { + // State "final" is implicit and may not always be present in the code. + if state.name() == "final" { + return; + } + let found_symbol = &analyser + .symbols + .find_symbol(&analyser.cur_scope, state.name()); + + if found_symbol.is_none() + || found_symbol.as_ref().unwrap().symbol.category != SymbolCategory::State + { + analyser.error( + format!("Cannot transition to non-existing state `{}`", state.name()), + &expr.get_dsym(), + ); + } + } +} + // Should only allow to assign `<--` or assign and assert `<==` signals (and not wg vars). // Left hand side should only have signals. fn assignment_rule(analyser: &mut Analyser, expr: &Statement) { - let ids = match expr { - Statement::SignalAssignment(_, id, _) => id, - Statement::SignalAssignmentAssert(_, id, _) => id, + let (ids, rhs) = match expr { + Statement::SignalAssignment(_, id, rhs) => (id, rhs), + Statement::SignalAssignmentAssert(_, id, rhs) => (id, rhs), _ => return, }; + if let Expression::Call(_, machine, _) = &rhs[0] { + let machine_scope = vec!["/".to_string(), machine.name()]; + let found_machine = &analyser.symbols.find_symbol(&machine_scope, machine.name()); + + if found_machine.is_some() + && found_machine.as_ref().unwrap().symbol.category == SymbolCategory::Machine + { + let outs = &found_machine.as_ref().unwrap().symbol.outs; + if outs.is_some() { + let outs = &outs.clone().unwrap(); + if outs.len() != ids.len() { + analyser.error( + format!( + "Machine `{}` has {} output(s), but left hand side has {} identifier(s)", + machine.name(), + outs.len(), + ids.len() + ), + &machine.debug_sym_ref(), + ) + } + } + } + } else if ids.len() != rhs.len() { + analyser.error( + "Number of identifiers and expressions in assignment should be equal".to_string(), + &expr.get_dsym(), + ) + } + ids.iter().for_each(|id| { if let Some(symbol) = analyser.symbols.find_symbol(&analyser.cur_scope, id.name()) { let is_signal = matches!( @@ -412,11 +464,198 @@ fn wg_assignment_rule(analyser: &mut Analyser, expr: &Statement) { + if let Statement::HyperTransition(_, assign_call, _) = expr { + match *assign_call.to_owned() { + Statement::SignalAssignmentAssert(_, ids, _) => { + ids.iter().for_each(|id| { + if id.1 != 1 { + analyser.error( + format!( + "All assigned identifiers in the hyper-transition must have a forward rotation ('), but `{}` is missing it.", + id.name(), + ), + &id.debug_sym_ref(), + ) + } + }); + } + _ => analyser.error( + "Hyper transition must include an assignment with assertion (<==).".to_string(), + &expr.get_dsym(), + ), + } + } + if let Statement::Block(_, stmts) = expr { + // There can only be a hyper-transition after a hyper-transition in a block + stmts.iter().enumerate().for_each(|(idx, stmt)| { + if idx < stmts.len() - 1 + && let Statement::HyperTransition(_, _, _) = stmt + { + let next_stmt = &stmts[idx + 1]; + analyser.error( + "Hyper-transition should be the last statement in a block".to_string(), + &next_stmt.get_dsym(), + ) + } + }); + } +} + +fn call_rules(analyser: &mut Analyser, expr: &Expression) { + if let Expression::Call(_, machine, exprs) = expr { + // Argument expressions in a call statement should not have nonzero rotation. + exprs + .iter() + .for_each(|expr| detect_nonzero_rotation(expr, analyser)); + + let machine_scope = vec!["/".to_string(), machine.name()]; + let found_machine = &analyser.symbols.find_symbol(&machine_scope, machine.name()); + if found_machine.is_none() + || found_machine.as_ref().unwrap().symbol.category != SymbolCategory::Machine + { + analyser.error( + format!( + "Call statement must call a valid machine, but `{}` is not a machine.", + machine.name() + ), + &machine.debug_sym_ref(), + ) + } else if found_machine.as_ref().unwrap().symbol.category == SymbolCategory::Machine { + let ins = &found_machine.as_ref().unwrap().symbol.ins; + if ins.is_some() { + let ins = &ins.clone().unwrap(); + if ins.len() != exprs.len() { + analyser.error( + format!( + "Expected {} argument(s) for `{}`, but got {}.", + ins.len(), + machine.name(), + exprs.len() + ), + &machine.debug_sym_ref(), + ) + } + for (input, arg) in ins.iter().zip(exprs.iter()) { + let input = analyser + .symbols + .find_symbol(&machine_scope, input.to_string()); + if input.is_none() { + unreachable!("Machine input should be added to the symbol table") + } else { + let input = input.unwrap(); + let arg_is_signal = is_signal_recursive(analyser, arg); + if input.symbol.is_signal() != arg_is_signal { + analyser.error( + format!( + "Cannot assign {} `{:?}` to input {} `{}`", + if arg_is_signal { "signal" } else { "variable" }, + arg, + if input.symbol.is_signal() { + "signal" + } else { + "variable" + }, + input.symbol.id, + ), + expr.get_dsym(), + ) + } + } + } + } + } + + let mut current_machine_scope = analyser.cur_scope.clone(); + current_machine_scope.truncate(2); + if machine_scope == *current_machine_scope { + analyser.error( + "A machine should not call itself.".to_string(), + &machine.debug_sym_ref(), + ) + } + } +} + +/// Check if the expression result is a signal. For each of the Queries in the expression, check if +/// the identifier is a signal. +fn is_signal_recursive(analyser: &Analyser, expr: &Expression) -> bool { + let mut is_signal = true; + match expr { + Expression::Query(_, id) => { + if let Some(symbol) = analyser.symbols.find_symbol(&analyser.cur_scope, id.name()) { + is_signal = is_signal && symbol.symbol.is_signal(); + } + } + Expression::BinOp { lhs, rhs, .. } => { + is_signal = is_signal + && is_signal_recursive(analyser, lhs) + && is_signal_recursive(analyser, rhs); + } + Expression::UnaryOp { sub, .. } => { + is_signal = is_signal && is_signal_recursive(analyser, sub); + } + Expression::Select { + cond, + when_true, + when_false, + .. + } => { + is_signal = is_signal + && is_signal_recursive(analyser, cond) + && is_signal_recursive(analyser, when_true) + && is_signal_recursive(analyser, when_false); + } + _ => (), + } + is_signal +} + +fn detect_nonzero_rotation(expr: &Expression, analyser: &mut Analyser) { + match expr { + Expression::Query(_, id) => { + if id.1 != 0 { + analyser.error( + "Non-zero rotation is not allowed in a call.".to_string(), + &id.debug_sym_ref(), + ) + } + } + Expression::BinOp { + dsym: _, + op: _, + lhs, + rhs, + } => { + detect_nonzero_rotation(lhs, analyser); + detect_nonzero_rotation(rhs, analyser); + } + Expression::UnaryOp { + dsym: _, + op: _, + sub, + } => { + detect_nonzero_rotation(sub, analyser); + } + Expression::Select { + dsym: _, + cond, + when_true, + when_false, + } => { + detect_nonzero_rotation(cond, analyser); + detect_nonzero_rotation(when_true, analyser); + detect_nonzero_rotation(when_false, analyser); + } + _ => (), + } +} + lazy_static! { /// Global semantic analyser rules. pub(super) static ref RULES: RuleSet = RuleSet { - expression: vec![undeclared_rule, true_false_rule], - statement: vec![state_decl, assignment_rule, assert_rule, if_condition_rule, wg_assignment_rule], + expression: vec![undeclared_rule, true_false_rule, call_rules], + statement: vec![state_decl, assignment_rule, assert_rule, if_condition_rule, wg_assignment_rule, state_rule, hyper_transition_rule], new_symbol: vec![rotation_decl, redeclare_rule, types_rule], new_tl_symbol: vec![rotation_decl_tl, machine_decl_tl, types_rule_tl], }; @@ -431,15 +670,11 @@ mod test { #[test] fn test_analyser_undeclared() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal i; // a is undeclared - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) state initial { signal c; @@ -465,38 +700,19 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "use of undeclared variable a", dsym: nofile:23:20 }]"# + ", + r#"[SemErr { msg: "use of undeclared variable a", dsym: nofile:18:20 }]"#, ) } #[test] fn test_analyser_rotation_decl() { - let circuit = " + do_test( + " machine fibo'(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) + state initial { signal c; @@ -522,35 +738,16 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "There cannot be rotation in identifier declaration of fibo", dsym: nofile:2:9 }]"# + ", + r#"[SemErr { msg: "There cannot be rotation in identifier declaration of fibo", dsym: nofile:2:9 }]"#, ); - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) + state initial' { signal c; @@ -576,34 +773,16 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "There cannot be rotation in identifier declaration of initial", dsym: nofile:10:12 }]"# + ", + r#"[SemErr { msg: "There cannot be rotation in identifier declaration of initial", dsym: nofile:5:13 }]"#, ); - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) + state initial { signal c'; @@ -629,37 +808,19 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "There cannot be rotation in identifier declaration of c", dsym: nofile:11:13 }]"# + ", + r#"[SemErr { msg: "There cannot be rotation in identifier declaration of c", dsym: nofile:6:14 }]"#, ) } #[test] fn test_analyser_state_decl() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) + state initial { signal c; @@ -689,35 +850,16 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare state nested here", dsym: nofile:13:17 }]"# + ", + r#"[SemErr { msg: "Cannot declare state nested here", dsym: nofile:8:17 }]"#, ); - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) + state initial { signal c; @@ -747,37 +889,19 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare state nested here", dsym: nofile:18:1 }]"# + ", + r#"[SemErr { msg: "Cannot declare state nested here", dsym: nofile:13:15 }]"#, ); } #[test] fn test_assignment_rule() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a: field, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) + state initial { signal c; var wrong; @@ -804,38 +928,18 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot assign with <-- or <== to variable wrong with category WGVar, you can only assign to signals. Use = instead.", dsym: nofile:14:14 }]"# + ", + r#"[SemErr { msg: "Cannot assign with <-- or <== to variable wrong with category WGVar, you can only assign to signals. Use = instead.", dsym: nofile:9:14 }]"#, ); } #[test] fn test_assert_rule() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a: field, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) state initial { signal c; @@ -864,40 +968,21 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot use wgvar wrong in statement assert wrong == 3;", dsym: nofile:24:14 }, SemErr { msg: "Cannot use wgvar wrong in statement [c] <== [(a + b) + wrong];", dsym: nofile:26:14 }]"# + ", + r#"[SemErr { msg: "Cannot use wgvar wrong in statement assert wrong == 3;", dsym: nofile:18:14 }, SemErr { msg: "Cannot use wgvar wrong in statement [c] <== [(a + b) + wrong];", dsym: nofile:20:14 }]"#, ) } #[test] fn test_machine_decl_rule() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a: field, i; i, a, b, c <== 1, 1, 1, 2; // this cannot be here - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) state initial { signal c; @@ -929,38 +1014,18 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare [i, a, b, c] <== [1, 1, 1, 2]; in the machine, only states, wgvars and signals are allowed", dsym: nofile:2:9 }, SemErr { msg: "Cannot declare if (i + 1) == n { [a] <-- [3]; } else { [b] <== [3]; } in the machine, only states, wgvars and signals are allowed", dsym: nofile:2:9 }]"# + ", + r#"[SemErr { msg: "Cannot declare [i, a, b, c] <== [1, 1, 1, 2]; in the machine, only states, wgvars and signals are allowed", dsym: nofile:2:9 }, SemErr { msg: "Cannot declare if (i + 1) == n { [a] <-- [3]; } else { [b] <== [3]; } in the machine, only states, wgvars and signals are allowed", dsym: nofile:2:9 }]"#, ); } #[test] fn test_redeclare_rule() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a: field, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) state initial { signal c; @@ -995,38 +1060,18 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot redeclare middle in the same scope [\"/\", \"fibo\"]", dsym: nofile:28:13 }, SemErr { msg: "Cannot redeclare n in the same scope [\"/\", \"fibo\"]", dsym: nofile:20:13 }, SemErr { msg: "Cannot redeclare c in the same scope [\"/\", \"fibo\", \"middle\"]", dsym: nofile:30:14 }]"# + ", + r#"[SemErr { msg: "Cannot redeclare middle in the same scope [\"/\", \"fibo\"]", dsym: nofile:22:13 }, SemErr { msg: "Cannot redeclare n in the same scope [\"/\", \"fibo\"]", dsym: nofile:14:13 }, SemErr { msg: "Cannot redeclare c in the same scope [\"/\", \"fibo\", \"middle\"]", dsym: nofile:24:14 }]"#, ); } #[test] fn test_types_rule() { - let circuit = " + do_test( + " machine fibo(signal n: uint) (signal b: field) { - // n and be are created automatically as shared - // signals signal a: field, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) state initial { signal c; @@ -1052,38 +1097,18 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot declare n with type uint, only field and bool are allowed.", dsym: nofile:2:22 }, SemErr { msg: "Cannot declare c with type int, only field and bool are allowed.", dsym: nofile:21:14 }]"# + ", + r#"[SemErr { msg: "Cannot declare n with type uint, only field and bool are allowed.", dsym: nofile:2:22 }, SemErr { msg: "Cannot declare c with type int, only field and bool are allowed.", dsym: nofile:15:14 }]"#, ); } #[test] fn test_true_false_rule() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a: field, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) state initial { signal c; var is_true; @@ -1118,38 +1143,18 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot use true in expression 2 + true", dsym: nofile:15:42 }, SemErr { msg: "Cannot use true in expression 1 * true", dsym: nofile:32:24 }, SemErr { msg: "Cannot use false in expression false - 123", dsym: nofile:32:31 }, SemErr { msg: "Cannot use false in expression false * false", dsym: nofile:32:50 }, SemErr { msg: "Cannot use false in expression false * false", dsym: nofile:32:58 }]"# + ", + r#"[SemErr { msg: "Cannot use true in expression 2 + true", dsym: nofile:9:42 }, SemErr { msg: "Cannot use true in expression 1 * true", dsym: nofile:26:24 }, SemErr { msg: "Cannot use false in expression false - 123", dsym: nofile:26:31 }, SemErr { msg: "Cannot use false in expression false * false", dsym: nofile:26:50 }, SemErr { msg: "Cannot use false in expression false * false", dsym: nofile:26:58 }]"#, ); } #[test] fn test_if_expression_rule() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a: field, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) state initial { signal c; @@ -1195,39 +1200,19 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; - - let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let decls = lang::TLDeclsParser::new() - .parse(&debug_sym_ref_factory, circuit) - .unwrap(); - - let result = analyse(&decls); - - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Condition i + 1 in if statement must be a logic expression", dsym: nofile:36:14 }, SemErr { msg: "Signal c in if statement condition must be bool", dsym: nofile:37:17 }, SemErr { msg: "Condition 4 in if statement must be a logic expression", dsym: nofile:43:17 }]"# + ", + r#"[SemErr { msg: "Condition i + 1 in if statement must be a logic expression", dsym: nofile:30:14 }, SemErr { msg: "Signal c in if statement condition must be bool", dsym: nofile:31:17 }, SemErr { msg: "Condition 4 in if statement must be a logic expression", dsym: nofile:37:17 }]"#, ); } #[test] fn test_wg_assignment_rule() { - let circuit = " + do_test( + " machine fibo(signal n) (signal b: field) { - // n and be are created automatically as shared - // signals signal a: field, i; var wgvar; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) state initial { signal c; @@ -1256,24 +1241,231 @@ mod test { } } } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. } - "; + ", + r#"[SemErr { msg: "Cannot assign with = to Signal i, you can only assign to WGVars. Use <-- or <== instead.", dsym: nofile:9:14 }]"#, + ); + } + + #[test] + fn test_assignment_args_count() { + // The number of identifiers and expressions in the assignment should be equal: + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + signal d, e: field; + state initial { + c' <-- d, e; + } + } + ", + r#"[SemErr { msg: "Number of identifiers and expressions in assignment should be equal", dsym: nofile:6:17 }]"#, + ); + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + signal d, e: field; + state initial { + c' <== d, e; + } + } + ", + r#"[SemErr { msg: "Number of identifiers and expressions in assignment should be equal", dsym: nofile:6:17 }]"#, + ); + } + + #[test] + fn test_transition_to_valid_state() { + // The transition should be to a valid state. Trying to transition to a state that does not + // exist: + do_test( + " + machine caller (signal n) (signal b: field) { + state initial { + -> no_state; + } + } + ", + r#"[SemErr { msg: "Cannot transition to non-existing state `no_state`", dsym: nofile:4:17 }]"#, + ); + } + + #[test] + fn test_call_assignments_rotation() { + // Testing the mandatory identifier rotation syntax + do_test( + " + machine caller (signal n) (signal b: field) { + state initial { + b <== other(n) -> final; + } + } + machine other (signal n) (signal b: field) { + state initial { + b' <== n; + } + } + ", + r#"[SemErr { msg: "All assigned identifiers in the hyper-transition must have a forward rotation ('), but `b` is missing it.", dsym: nofile:4:17 }]"#, + ); + } + + #[test] + fn test_call_no_expression_rotation() { + // Testing the absence of rotation in expressions + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + signal d, e; + state initial { + c' <== other(d + e') -> final; + } + } + machine other (signal n) (signal b: field) { + state initial { + b' <== n; + } + } + ", + r#"[SemErr { msg: "Non-zero rotation is not allowed in a call.", dsym: nofile:6:34 }]"#, + ); + } + + #[test] + fn test_hyper_transition_valid_machine() { + // The callee should be a valid machine. Trying to call a machine that does not exist: + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + signal d, e; + state initial { + c' <== other(d + e) -> final; + } + } + ", + r#"[SemErr { msg: "Call statement must call a valid machine, but `other` is not a machine.", dsym: nofile:6:24 }]"#, + ); + } + + #[test] + fn test_hyper_transition_call_itself() { + // The callee should be a valid machine. Trying to call itself: + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + signal d, e; + state initial { + c' <== caller(d + e) -> final; + } + } + ", + r#"[SemErr { msg: "A machine should not call itself.", dsym: nofile:6:24 }]"#, + ); + } + + #[test] + fn test_hyper_transition_valid_state() { + // The transition should be to a valid state. Trying to transition to a state that does not + // exist: + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + signal d, e; + state initial { + c' <== other(d + e) -> no_state; + } + } + machine other (signal n) (signal b: field) { + state initial { + b' <== n; + } + } + ", + r#"[SemErr { msg: "Cannot transition to non-existing state `no_state`", dsym: nofile:6:37 }]"#, + ); + } + + #[test] + fn test_hyper_transition_is_the_last() { + // The hyper-transition should be the last statement in a block + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + signal d, e; + state initial { + c' <== other(d + e) -> final; + c' <== 1; + } + } + machine other (signal n) (signal b: field) { + state initial { + b' <== n; + } + } + ", + r#"[SemErr { msg: "Hyper-transition should be the last statement in a block", dsym: nofile:7:17 }]"#, + ); + } + + #[test] + fn test_hyper_transition_call_arg_count() { + // Cannot call a machine with the wrong number of arguments + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + signal d, e: field; + state initial { + c' <== other(d, e) -> final; + } + } + machine other (signal n) (signal b: field) { + state initial { + b' <== n; + } + } + ", + r#"[SemErr { msg: "Expected 1 argument(s) for `other`, but got 2.", dsym: nofile:6:24 }]"#, + ); + } + + #[test] + fn test_hyper_transition_assignment_arg_count() { + // Cannot assign call result with the wrong number of arguments + do_test( + " + machine caller (signal n) (signal b: field) { + signal c; + state initial { + c' <== other(n) -> final; + } + } + machine other (signal n) (signal b: field, signal c: field) { + state initial { + b' <== n; + c' <== n; + } + } + ", + r#"[SemErr { msg: "Machine `other` has 2 output(s), but left hand side has 1 identifier(s)", dsym: nofile:5:24 }]"#, + ); + } + + fn do_test(circuit: &str, expected: &str) { let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); let decls = lang::TLDeclsParser::new() .parse(&debug_sym_ref_factory, circuit) .unwrap(); - let result = analyse(&decls); - assert_eq!( - format!("{:?}", result.messages), - r#"[SemErr { msg: "Cannot assign with = to Signal i, you can only assign to WGVars. Use <-- or <== instead.", dsym: nofile:15:14 }]"# - ); + assert_eq!(format!("{:?}", result.messages), expected); } } diff --git a/src/compiler/setup_inter.rs b/src/compiler/setup_inter.rs index 30130eaa..f9ec58a6 100644 --- a/src/compiler/setup_inter.rs +++ b/src/compiler/setup_inter.rs @@ -212,6 +212,9 @@ impl SetupInterpreter { Statement::StateDecl(dsym, id, stmt) => self.interpret_state_decl(dsym, id, stmt), Statement::SignalDecl(_, _) => {} Statement::WGVarDecl(_, _) => {} + Statement::HyperTransition(_, _, _) => { + // TODO interpret hyper transition? Needs specs + } _ => unreachable!("semantic analyser should prevent this"), } @@ -250,7 +253,8 @@ impl SetupInterpreter { SignalAssignment(_, _, _) | WGAssignment(_, _, _) => vec![], SignalDecl(_, _) | WGVarDecl(_, _) => vec![], - HyperTransition(_, _, _, _) => todo!("Implement compilation for hyper transitions"), + // TODO Implement compilation for hyper transitions + HyperTransition(_, _, _) => vec![], }; self.add_poly_constraints(result.into_iter().map(|cr| cr.anti_booly).collect()); diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index fc71787c..cab6f2b6 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -206,7 +206,7 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> { Block(_, stmts) => self.exec_step_block(stmts), Assert(_, _) => Ok(None), StateDecl(_, _, _) => Ok(None), - HyperTransition(_, _, _, _) => todo!("Needs specs"), + HyperTransition(_, _, _) => todo!("Needs specs"), } } diff --git a/src/parser/ast/statement.rs b/src/parser/ast/statement.rs index 896cd50e..95ea34a3 100644 --- a/src/parser/ast/statement.rs +++ b/src/parser/ast/statement.rs @@ -45,10 +45,9 @@ pub enum Statement { /// state. /// Tuple values: /// - debug symbol reference; - /// - assigned signal IDs; - /// - call expression; - /// - next state ID; - HyperTransition(DebugSymRef, Vec, Expression, V), + /// - Statement::SignalAssignmentAssert; + /// - Statement::Transition; + HyperTransition(DebugSymRef, Box>, Box>), } impl Debug for Statement { @@ -98,17 +97,8 @@ impl Debug for Statement { .join(" ") ) } - Statement::HyperTransition(_, ids, call, state) => { - write!( - f, - "{:?} <== {:?} -> {:?};", - ids.iter() - .map(|id| id.name()) - .collect::>() - .join(", "), - call, - state - ) + Statement::HyperTransition(_, assign_call, state_transition) => { + write!(f, "{:?} -> {{ {:?} }};", state_transition, assign_call,) } } } @@ -128,7 +118,7 @@ impl Statement { Statement::StateDecl(dsym, _, _) => dsym.clone(), Statement::Transition(dsym, _, _) => dsym.clone(), Statement::Block(dsym, _) => dsym.clone(), - Statement::HyperTransition(dsym, _, _, _) => dsym.clone(), + Statement::HyperTransition(dsym, _, _) => dsym.clone(), } } } diff --git a/src/parser/build.rs b/src/parser/build.rs index e179f875..f49399c4 100644 --- a/src/parser/build.rs +++ b/src/parser/build.rs @@ -62,14 +62,25 @@ pub fn build_transition( Statement::Transition(dsym, id, Box::new(block)) } -pub fn build_hyper_transition( +pub fn build_inline_hyper_transition( dsym: DebugSymRef, - ids: Vec, - call: Expression, - state: Identifier, + assign_call: Statement, + state_transition: Statement, ) -> Statement { - match call { - Expression::Call(_, _, _) => Statement::HyperTransition(dsym, ids, call, state), - _ => unreachable!("Hyper transition must include a call statement"), + match &assign_call { + Statement::SignalAssignmentAssert(_, _, call) => { + if call.len() != 1 { + unreachable!("Inline hyper-transition should have a single call statement") + } + match call[0] { + Expression::Call(_, _, _) => Statement::HyperTransition( + dsym, + Box::new(assign_call), + Box::new(state_transition), + ), + _ => unreachable!("Inline hyper-transition should have a call statement"), + } + } + _ => unreachable!("Hyper transition must include a SignalAssignmentAssert"), } } diff --git a/src/parser/chiquito.lalrpop b/src/parser/chiquito.lalrpop index 36b50a3c..72b15a06 100644 --- a/src/parser/chiquito.lalrpop +++ b/src/parser/chiquito.lalrpop @@ -113,7 +113,7 @@ ParseTransition: Statement = { } HyperTransition: Statement = { - "<==" "->" => build_hyper_transition(dsym_factory.create(l,r), ids, call, st), + => build_inline_hyper_transition(dsym_factory.create(l,r), assign_call, st), } ParseSignalDecl: Statement = {