diff --git a/src/compiler/abepi.rs b/src/compiler/abepi.rs index ceac8789..03b488ff 100644 --- a/src/compiler/abepi.rs +++ b/src/compiler/abepi.rs @@ -65,6 +65,9 @@ impl + TryInto + Clone + Debug, V: Clone + Debug> CompilationU Statement::Transition(dsym, id, stmt) => { self.compiler_statement_transition(dsym, id, *stmt) } + Statement::HyperTransition(dsym, ids, call, state) => { + self.compiler_statement_hyper_transition(dsym, ids, call, state) + } _ => vec![], } } @@ -420,6 +423,16 @@ impl + TryInto + Clone + Debug, V: Clone + Debug> CompilationU result } + + fn compiler_statement_hyper_transition( + &self, + _dsym: DebugSymRef, + _ids: Vec, + _call: Expression, + _state: V, + ) -> Vec> { + todo!("Compile expressions? Needs specs") + } } fn flatten_bin_op( diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 9a3c5584..e717a993 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -558,57 +558,22 @@ mod test { use crate::{ compiler::{compile, compile_file, compile_legacy}, - parser::ast::debug_sym_factory::DebugSymRefFactory, + parser::{ast::debug_sym_factory::DebugSymRefFactory, lang::TLDeclsParser}, wit_gen::TraceGenerator, }; use super::Config; - // TODO rewrite the test after machines are able to call other machines + // TODO improve the test for HyperTransition #[test] fn test_compiler_fibo_multiple_machines() { // Source code containing two machines let circuit = " - machine fibo1 (signal n) (signal b: field) { - // n and be are created automatically as shared - // signals - signal a: field, i; - - // there is always a state called initial - // input signals get bound to the signal - // in the initial state (first instance) - state initial { - signal c; - - i, a, b, c <== 1, 1, 1, 2; - - -> middle { - i', a', b', n' <== i + 1, b, c, n; - } - } - - state middle { - signal c; - - c <== a + b; - - if i + 1 == n { - -> final { - i', b', n' <== i + 1, c, n; - } - } else { - -> middle { - i', a', b', n' <== i + 1, b, c, n; - } - } - } - - // There is always a state called final. - // Output signals get automatically bound to the signals - // with the same name in the final step (last instance). - // This state can be implicit if there are no constraints in it. + machine caller (signal n) (signal b: field) { + signal b_1: field; + b_1' <== fibo(n) -> final; } - machine fibo2 (signal n) (signal b: field) { + machine fibo (signal n) (signal b: field) { // n and be are created automatically as shared // signals signal a: field, i; @@ -839,4 +804,43 @@ mod test { r#"[SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:24:39 }, SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:28:46 }]"# ) } + + #[test] + fn test_parse_hyper_transition() { + let circuit = " + machine caller (signal n) (signal b: field) { + a', b, c' <== fibo(d, e, f + g) -> final; + } + "; + + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); + let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit); + + assert!(result.is_ok()); + + let circuit = " + machine caller (signal n) (signal b: field) { + -> final { + a', b, c' <== fibo(d, e, f + g); + } + } + "; + + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); + let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit); + + assert!(result.is_ok()); + + // TODO should no-arg calls be allowed? Needs more specs for function/machine calls + let circuit = " + machine caller (signal n) (signal b: field) { + smth <== a() -> final; + } + "; + + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); + let result = TLDeclsParser::new().parse(&debug_sym_ref_factory, circuit); + + assert!(result.is_ok()); + } } diff --git a/src/compiler/semantic/analyser.rs b/src/compiler/semantic/analyser.rs index 3664608d..e5291ea4 100644 --- a/src/compiler/semantic/analyser.rs +++ b/src/compiler/semantic/analyser.rs @@ -272,6 +272,11 @@ impl Analyser { Statement::SignalDecl(_, _) => {} Statement::WGVarDecl(_, _) => {} + Statement::HyperTransition(_, ids, call, state) => { + self.analyse_expression(call); + self.collect_id_usages(&[state]); + self.collect_id_usages(&ids); + } } } @@ -308,6 +313,12 @@ impl Analyser { } => { self.extract_usages_expression(&sub); } + Expression::Call(_, fun, exprs) => { + self.collect_id_usages(&[fun]); + exprs + .into_iter() + .for_each(|expr| self.extract_usages_expression(&expr)); + } _ => {} } } diff --git a/src/compiler/semantic/rules.rs b/src/compiler/semantic/rules.rs index 5bf846e8..818162c1 100644 --- a/src/compiler/semantic/rules.rs +++ b/src/compiler/semantic/rules.rs @@ -43,6 +43,9 @@ fn undeclared_rule(analyser: &mut Analyser, expr: &Expression {} + Expression::Call(_, _, args) => { + args.iter().for_each(|arg| undeclared_rule(analyser, arg)); + } } } diff --git a/src/compiler/setup_inter.rs b/src/compiler/setup_inter.rs index 9b2b3e24..30130eaa 100644 --- a/src/compiler/setup_inter.rs +++ b/src/compiler/setup_inter.rs @@ -250,6 +250,7 @@ impl SetupInterpreter { SignalAssignment(_, _, _) | WGAssignment(_, _, _) => vec![], SignalDecl(_, _) | WGVarDecl(_, _) => vec![], + HyperTransition(_, _, _, _) => todo!("Implement compilation for hyper transitions"), }; self.add_poly_constraints(result.into_iter().map(|cr| cr.anti_booly).collect()); diff --git a/src/interpreter/expr.rs b/src/interpreter/expr.rs index afc85830..d0eb3f9b 100644 --- a/src/interpreter/expr.rs +++ b/src/interpreter/expr.rs @@ -82,6 +82,9 @@ pub(crate) fn eval_expr( Const(_, v) => Ok(Value::Field(F::from_big_int(v))), True(_) => Ok(Value::Bool(true)), False(_) => Ok(Value::Bool(false)), + Call(_, _, _) => { + todo!("Needs specs. Evaluate the argument expressions, evaluate the function output?") + } } .map_err(|msg| Message::RuntimeErr { msg, diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index 7cd05010..fc71787c 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -206,6 +206,7 @@ impl<'a, F: Field + Hash> Interpreter<'a, F> { Block(_, stmts) => self.exec_step_block(stmts), Assert(_, _) => Ok(None), StateDecl(_, _, _) => Ok(None), + HyperTransition(_, _, _, _) => todo!("Needs specs"), } } diff --git a/src/parser/ast/expression.rs b/src/parser/ast/expression.rs index 61556d11..941641ad 100644 --- a/src/parser/ast/expression.rs +++ b/src/parser/ast/expression.rs @@ -193,6 +193,12 @@ pub enum Expression { Const(DebugSymRef, F), True(DebugSymRef), False(DebugSymRef), + /// Function or machine call. + /// Tuple values: + /// - debug symbol reference; + /// - function/machine ID; + /// - call argument expressions vector. + Call(DebugSymRef, V, Vec>), } // Shorthand for BigInt expression @@ -217,6 +223,9 @@ impl Expression { Const(_, _) => true, True(_) => false, False(_) => false, + Call(_, _, _) => { + todo!("Needs specs. For a function call, depends on the function return type?") + } } } @@ -234,6 +243,9 @@ impl Expression { when_true.is_logic() } + Expression::Call { .. } => { + todo!("Needs specs. For a function call, depends on the function return type?") + } _ => false, } } @@ -247,6 +259,7 @@ impl Expression { Expression::Const(dsym, _) => dsym, Expression::True(dsym) => dsym, Expression::False(dsym) => dsym, + Expression::Call(dsym, _, _) => dsym, } } @@ -260,6 +273,7 @@ impl Expression { Expression::Query(_, _) => false, Expression::True(_) => false, Expression::False(_) => false, + Expression::Call(_, _, _) => false, } } } @@ -315,6 +329,7 @@ impl Debug for Expression { Expression::True(_) => write!(f, "true"), Expression::False(_) => write!(f, "false"), + Expression::Call(_, fun, exprs) => write!(f, "{:?}({:?})", fun, exprs), } } } diff --git a/src/parser/ast/statement.rs b/src/parser/ast/statement.rs index 1d7e9b7f..896cd50e 100644 --- a/src/parser/ast/statement.rs +++ b/src/parser/ast/statement.rs @@ -13,28 +13,42 @@ pub struct TypedIdDecl { #[derive(Clone)] pub enum Statement { - Assert(DebugSymRef, Expression), // assert x; - - SignalAssignment(DebugSymRef, Vec, Vec>), // x <-- y; - SignalAssignmentAssert(DebugSymRef, Vec, Vec>), // x <== y; - WGAssignment(DebugSymRef, Vec, Vec>), // x = y; - - IfThen(DebugSymRef, Box>, Box>), // if x { y } + /// assert x; + Assert(DebugSymRef, Expression), + /// x <-- y; + SignalAssignment(DebugSymRef, Vec, Vec>), + /// x <== y; + SignalAssignmentAssert(DebugSymRef, Vec, Vec>), + /// x = y; + WGAssignment(DebugSymRef, Vec, Vec>), + /// if x { y } + IfThen(DebugSymRef, Box>, Box>), + /// if x { y } else { z } IfThenElse( DebugSymRef, Box>, Box>, Box>, - ), // if x { y } else { z } - - SignalDecl(DebugSymRef, Vec>), // signal x; - WGVarDecl(DebugSymRef, Vec>), // var x; - - StateDecl(DebugSymRef, V, Box>), // state x { y } - - Transition(DebugSymRef, V, Box>), // -> x { y } - - Block(DebugSymRef, Vec>), // { x } + ), + /// signal x; + SignalDecl(DebugSymRef, Vec>), + /// var x; + WGVarDecl(DebugSymRef, Vec>), + /// state x { y } + StateDecl(DebugSymRef, V, Box>), + /// Transition to another state. + /// -> x { y } + Transition(DebugSymRef, V, Box>), + /// { x } + Block(DebugSymRef, Vec>), + /// Call into another machine with assertion and subsequent transition to another + /// state. + /// Tuple values: + /// - debug symbol reference; + /// - assigned signal IDs; + /// - call expression; + /// - next state ID; + HyperTransition(DebugSymRef, Vec, Expression, V), } impl Debug for Statement { @@ -84,6 +98,18 @@ impl Debug for Statement { .join(" ") ) } + Statement::HyperTransition(_, ids, call, state) => { + write!( + f, + "{:?} <== {:?} -> {:?};", + ids.iter() + .map(|id| id.name()) + .collect::>() + .join(", "), + call, + state + ) + } } } } @@ -102,6 +128,7 @@ impl Statement { Statement::StateDecl(dsym, _, _) => dsym.clone(), Statement::Transition(dsym, _, _) => dsym.clone(), Statement::Block(dsym, _) => dsym.clone(), + Statement::HyperTransition(dsym, _, _, _) => dsym.clone(), } } } diff --git a/src/parser/build.rs b/src/parser/build.rs index d832991d..e179f875 100644 --- a/src/parser/build.rs +++ b/src/parser/build.rs @@ -1,5 +1,3 @@ -use num_bigint::BigInt; - use super::ast::{expression::Expression, statement::Statement, DebugSymRef, Identifier}; pub fn build_bin_op, F, V>( @@ -64,25 +62,14 @@ pub fn build_transition( Statement::Transition(dsym, id, Box::new(block)) } -pub fn add_dsym( +pub fn build_hyper_transition( dsym: DebugSymRef, - stmt: Statement, -) -> Statement { - match stmt { - Statement::Assert(_, expr) => Statement::Assert(dsym, expr), - Statement::SignalAssignment(_, ids, exprs) => Statement::SignalAssignment(dsym, ids, exprs), - Statement::SignalAssignmentAssert(_, ids, exprs) => { - Statement::SignalAssignmentAssert(dsym, ids, exprs) - } - Statement::WGAssignment(_, ids, exprs) => Statement::WGAssignment(dsym, ids, exprs), - Statement::StateDecl(_, id, block) => Statement::StateDecl(dsym, id, block), - Statement::IfThen(_, cond, then_block) => Statement::IfThen(dsym, cond, then_block), - Statement::IfThenElse(_, cond, then_block, else_block) => { - Statement::IfThenElse(dsym, cond, then_block, else_block) - } - Statement::Block(_, stmts) => Statement::Block(dsym, stmts), - Statement::SignalDecl(_, ids) => Statement::SignalDecl(dsym, ids), - Statement::WGVarDecl(_, ids) => Statement::WGVarDecl(dsym, ids), - Statement::Transition(_, id, stmt) => Statement::Transition(dsym, id, stmt), + ids: Vec, + call: Expression, + state: Identifier, +) -> Statement { + match call { + Expression::Call(_, _, _) => Statement::HyperTransition(dsym, ids, call, state), + _ => unreachable!("Hyper transition must include a call statement"), } } diff --git a/src/parser/chiquito.lalrpop b/src/parser/chiquito.lalrpop index b8fb582c..36b50a3c 100644 --- a/src/parser/chiquito.lalrpop +++ b/src/parser/chiquito.lalrpop @@ -37,9 +37,9 @@ pub Statements: Vec> = { } ParseStatement: Statement = { - ";" => add_dsym(dsym_factory.create(l,r), s), - => add_dsym(dsym_factory.create(l,r), s), - => add_dsym(dsym_factory.create(l,r), s), + ";" => s, + => s, + => s, } StatementType: Statement = { @@ -53,6 +53,7 @@ StatementType: Statement = { ParseWGVarDecl, ParseTransitionSimple, ParseTransition, + HyperTransition, } AssertEq: Statement = { @@ -111,6 +112,10 @@ ParseTransition: Statement = { "->" => build_transition(dsym_factory.create(l,r), id, block), } +HyperTransition: Statement = { + "<==" "->" => build_hyper_transition(dsym_factory.create(l,r), ids, call, st), +} + ParseSignalDecl: Statement = { "signal" => Statement::SignalDecl(dsym_factory.create(l,r), ids), } @@ -196,12 +201,17 @@ ParsePrefix: Expr = { ExpressionTerm } +Call: Expr = { + "(" ")" => Expression::Call(dsym_factory.create(l,r), fun, es), +} + ExpressionTerm: Expr = { => build_query(id, dsym_factory.create(l,r)), => lit, "true" => Expression::True(dsym_factory.create(l,r)), "false" => Expression::False(dsym_factory.create(l,r)), "(" ")" => e, + Call } ParseBinOp: Expr = {