From 599243633281e6827f0f4f095fb12d313e0125fa Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 17 Apr 2024 14:16:37 -0400 Subject: [PATCH] feat: Add comptime Interpreter (#4821) # Description ## Problem\* Resolves #4587 ## Summary\* Implements an interpreter to interpret Noir functions at compile-time. This is intended to be used in the future for evaluation of `comptime` items for metaprogramming, although there is currently no way to trigger this interpreter in user code. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: Maxim Vezenov --- Cargo.lock | 1 + Cargo.toml | 1 + compiler/noirc_evaluator/Cargo.toml | 4 +- compiler/noirc_frontend/Cargo.toml | 1 + compiler/noirc_frontend/src/ast/expression.rs | 5 + .../src/hir/comptime/interpreter.rs | 1282 +++++++++++++++++ .../noirc_frontend/src/hir/comptime/mod.rs | 2 + .../noirc_frontend/src/hir/comptime/tests.rs | 166 +++ .../src/hir/def_collector/dc_mod.rs | 1 + .../src/hir/resolution/resolver.rs | 1 + .../noirc_frontend/src/hir/type_check/mod.rs | 53 +- compiler/noirc_frontend/src/hir_def/expr.rs | 4 +- .../noirc_frontend/src/hir_def/function.rs | 4 +- compiler/noirc_frontend/src/hir_def/stmt.rs | 2 +- .../src/monomorphization/mod.rs | 2 +- compiler/noirc_frontend/src/node_interner.rs | 4 + .../src/parser/parser/function.rs | 1 + 17 files changed, 1503 insertions(+), 31 deletions(-) create mode 100644 compiler/noirc_frontend/src/hir/comptime/interpreter.rs create mode 100644 compiler/noirc_frontend/src/hir/comptime/tests.rs diff --git a/Cargo.lock b/Cargo.lock index e83c10a1932..b01c22ed75b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3163,6 +3163,7 @@ dependencies = [ "base64 0.21.2", "chumsky", "fm", + "im", "iter-extended", "lalrpop", "lalrpop-util", diff --git a/Cargo.toml b/Cargo.toml index 6a939878f9f..cdbb40f630a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,6 +132,7 @@ tempfile = "3.6.0" jsonrpc = { version = "0.16.0", features = ["minreq_http"] } flate2 = "1.0.24" +im = { version = "15.1", features = ["serde"] } tracing = "0.1.40" tracing-web = "0.1.3" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/compiler/noirc_evaluator/Cargo.toml b/compiler/noirc_evaluator/Cargo.toml index fad7c3c309e..fb2f003aa56 100644 --- a/compiler/noirc_evaluator/Cargo.toml +++ b/compiler/noirc_evaluator/Cargo.toml @@ -15,7 +15,7 @@ fxhash.workspace = true iter-extended.workspace = true thiserror.workspace = true num-bigint = "0.4" -im = { version = "15.1", features = ["serde"] } +im.workspace = true serde.workspace = true tracing.workspace = true -chrono = "0.4.37" \ No newline at end of file +chrono = "0.4.37" diff --git a/compiler/noirc_frontend/Cargo.toml b/compiler/noirc_frontend/Cargo.toml index e39ab2fe106..7a23585bd23 100644 --- a/compiler/noirc_frontend/Cargo.toml +++ b/compiler/noirc_frontend/Cargo.toml @@ -17,6 +17,7 @@ iter-extended.workspace = true chumsky.workspace = true thiserror.workspace = true smol_str.workspace = true +im.workspace = true serde_json.workspace = true serde.workspace = true rustc-hash = "1.1.0" diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 0e5919bf7db..755739af8fe 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -387,6 +387,9 @@ pub struct FunctionDefinition { /// True if this function was defined with the 'unconstrained' keyword pub is_unconstrained: bool, + /// True if this function was defined with the 'comptime' keyword + pub is_comptime: bool, + /// Indicate if this function was defined with the 'pub' keyword pub visibility: ItemVisibility, @@ -679,10 +682,12 @@ impl FunctionDefinition { span: ident.span().merge(unresolved_type.span.unwrap()), }) .collect(); + FunctionDefinition { name: name.clone(), attributes: Attributes::empty(), is_unconstrained: false, + is_comptime: false, visibility: ItemVisibility::Private, generics: generics.clone(), parameters: p, diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs new file mode 100644 index 00000000000..81050073008 --- /dev/null +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -0,0 +1,1282 @@ +use std::{borrow::Cow, collections::hash_map::Entry, rc::Rc}; + +use acvm::FieldElement; +use im::Vector; +use iter_extended::{try_vecmap, vecmap}; +use noirc_errors::Location; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; + +use crate::{ + hir_def::{ + expr::{ + HirArrayLiteral, HirBlockExpression, HirCallExpression, HirCastExpression, + HirConstructorExpression, HirIdent, HirIfExpression, HirIndexExpression, + HirInfixExpression, HirLambda, HirMemberAccess, HirMethodCallExpression, + HirPrefixExpression, + }, + stmt::{ + HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement, + HirPattern, + }, + }, + macros_api::{HirExpression, HirLiteral, HirStatement, NodeInterner}, + node_interner::{DefinitionId, DefinitionKind, ExprId, FuncId, StmtId}, + BinaryOpKind, BlockExpression, FunctionKind, IntegerBitSize, Shared, Signedness, Type, + TypeBinding, TypeBindings, TypeVariableKind, +}; + +#[allow(unused)] +pub(crate) struct Interpreter<'interner> { + /// To expand macros the Interpreter may mutate hir nodes within the NodeInterner + interner: &'interner mut NodeInterner, + + /// Each value currently in scope in the interpreter. + /// Each element of the Vec represents a scope with every scope together making + /// up all currently visible definitions. + scopes: Vec>, + + /// True if we've expanded any macros into any functions and will need + /// to redo name resolution & type checking for that function. + changed_functions: HashSet, + + /// True if we've expanded any macros into global scope and will need + /// to redo name resolution & type checking for everything. + changed_globally: bool, + + in_loop: bool, +} + +#[allow(unused)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Value { + Unit, + Bool(bool), + Field(FieldElement), + I8(i8), + I32(i32), + I64(i64), + U8(u8), + U32(u32), + U64(u64), + String(Rc), + Function(FuncId, Type), + Closure(HirLambda, Vec, Type), + Tuple(Vec), + Struct(HashMap, Value>, Type), + Pointer(Shared), + Array(Vector, Type), + Slice(Vector, Type), + Code(Rc), +} + +/// The possible errors that can halt the interpreter. +#[allow(unused)] +#[derive(Debug)] +pub(crate) enum InterpreterError { + ArgumentCountMismatch { expected: usize, actual: usize, call_location: Location }, + TypeMismatch { expected: Type, value: Value, location: Location }, + NoValueForId { id: DefinitionId, location: Location }, + IntegerOutOfRangeForType { value: FieldElement, typ: Type, location: Location }, + ErrorNodeEncountered { location: Location }, + NonFunctionCalled { value: Value, location: Location }, + NonBoolUsedInIf { value: Value, location: Location }, + NonBoolUsedInConstrain { value: Value, location: Location }, + FailingConstraint { message: Option, location: Location }, + NoMethodFound { object: Value, typ: Type, location: Location }, + NonIntegerUsedInLoop { value: Value, location: Location }, + NonPointerDereferenced { value: Value, location: Location }, + NonTupleOrStructInMemberAccess { value: Value, location: Location }, + NonArrayIndexed { value: Value, location: Location }, + NonIntegerUsedAsIndex { value: Value, location: Location }, + NonIntegerIntegerLiteral { typ: Type, location: Location }, + NonIntegerArrayLength { typ: Type, location: Location }, + NonNumericCasted { value: Value, location: Location }, + IndexOutOfBounds { index: usize, length: usize, location: Location }, + ExpectedStructToHaveField { value: Value, field_name: String, location: Location }, + TypeUnsupported { typ: Type, location: Location }, + InvalidValueForUnary { value: Value, operator: &'static str, location: Location }, + InvalidValuesForBinary { lhs: Value, rhs: Value, operator: &'static str, location: Location }, + CastToNonNumericType { typ: Type, location: Location }, + + // Perhaps this should be unreachable! due to type checking also preventing this error? + // Currently it and the Continue variant are the only interpreter errors without a Location field + BreakNotInLoop, + ContinueNotInLoop, + + // These cases are not errors but prevent us from running more code + // until the loop can be resumed properly. + Break, + Continue, +} + +#[allow(unused)] +type IResult = std::result::Result; + +#[allow(unused)] +impl<'a> Interpreter<'a> { + pub(crate) fn new(interner: &'a mut NodeInterner) -> Self { + Self { + interner, + scopes: vec![HashMap::default()], + changed_functions: HashSet::default(), + changed_globally: false, + in_loop: false, + } + } + + pub(crate) fn call_function( + &mut self, + function: FuncId, + arguments: Vec<(Value, Location)>, + call_location: Location, + ) -> IResult { + let previous_state = self.enter_function(); + + let meta = self.interner.function_meta(&function); + if meta.kind != FunctionKind::Normal { + todo!("Evaluation for {:?} is unimplemented", meta.kind); + } + + if meta.parameters.len() != arguments.len() { + return Err(InterpreterError::ArgumentCountMismatch { + expected: meta.parameters.len(), + actual: arguments.len(), + call_location, + }); + } + + let parameters = meta.parameters.0.clone(); + for ((parameter, typ, _), (argument, arg_location)) in parameters.iter().zip(arguments) { + self.define_pattern(parameter, typ, argument, arg_location)?; + } + + let function_body = self.interner.function(&function).as_expr(); + let result = self.evaluate(function_body)?; + + self.exit_function(previous_state); + Ok(result) + } + + fn call_closure( + &mut self, + closure: HirLambda, + // TODO: How to define environment here? + _environment: Vec, + arguments: Vec<(Value, Location)>, + call_location: Location, + ) -> IResult { + let previous_state = self.enter_function(); + + if closure.parameters.len() != arguments.len() { + return Err(InterpreterError::ArgumentCountMismatch { + expected: closure.parameters.len(), + actual: arguments.len(), + call_location, + }); + } + + let parameters = closure.parameters.iter().zip(arguments); + for ((parameter, typ), (argument, arg_location)) in parameters { + self.define_pattern(parameter, typ, argument, arg_location)?; + } + + let result = self.evaluate(closure.body)?; + + self.exit_function(previous_state); + Ok(result) + } + + /// Enters a function, pushing a new scope and resetting any required state. + /// Returns the previous values of the internal state, to be reset when + /// `exit_function` is called. + fn enter_function(&mut self) -> (bool, Vec>) { + // Drain every scope except the global scope + let scope = self.scopes.drain(1..).collect(); + self.push_scope(); + (std::mem::take(&mut self.in_loop), scope) + } + + fn exit_function(&mut self, mut state: (bool, Vec>)) { + self.in_loop = state.0; + + // Keep only the global scope + self.scopes.truncate(1); + self.scopes.append(&mut state.1); + } + + fn push_scope(&mut self) { + self.scopes.push(HashMap::default()); + } + + fn pop_scope(&mut self) { + self.scopes.pop(); + } + + fn current_scope_mut(&mut self) -> &mut HashMap { + // the global scope is always at index zero, so this is always Some + self.scopes.last_mut().unwrap() + } + + fn define_pattern( + &mut self, + pattern: &HirPattern, + typ: &Type, + argument: Value, + location: Location, + ) -> IResult<()> { + match pattern { + HirPattern::Identifier(identifier) => { + self.define(identifier.id, typ, argument, location) + } + HirPattern::Mutable(pattern, _) => { + self.define_pattern(pattern, typ, argument, location) + } + HirPattern::Tuple(pattern_fields, _) => match (argument, typ) { + (Value::Tuple(fields), Type::Tuple(type_fields)) + if fields.len() == pattern_fields.len() => + { + for ((pattern, typ), argument) in + pattern_fields.iter().zip(type_fields).zip(fields) + { + self.define_pattern(pattern, typ, argument, location)?; + } + Ok(()) + } + (value, _) => { + Err(InterpreterError::TypeMismatch { expected: typ.clone(), value, location }) + } + }, + HirPattern::Struct(struct_type, pattern_fields, _) => { + self.type_check(typ, &argument, location)?; + self.type_check(struct_type, &argument, location)?; + + match argument { + Value::Struct(fields, struct_type) if fields.len() == pattern_fields.len() => { + for (field_name, field_pattern) in pattern_fields { + let field = fields.get(&field_name.0.contents).ok_or_else(|| { + InterpreterError::ExpectedStructToHaveField { + value: Value::Struct(fields.clone(), struct_type.clone()), + field_name: field_name.0.contents.clone(), + location, + } + })?; + + let field_type = field.get_type().into_owned(); + self.define_pattern( + field_pattern, + &field_type, + field.clone(), + location, + )?; + } + Ok(()) + } + value => Err(InterpreterError::TypeMismatch { + expected: typ.clone(), + value, + location, + }), + } + } + } + } + + /// Define a new variable in the current scope + fn define( + &mut self, + id: DefinitionId, + typ: &Type, + argument: Value, + location: Location, + ) -> IResult<()> { + self.type_check(typ, &argument, location)?; + self.current_scope_mut().insert(id, argument); + Ok(()) + } + + /// Mutate an existing variable, potentially from a prior scope. + /// Also type checks the value being assigned + fn checked_mutate( + &mut self, + id: DefinitionId, + typ: &Type, + argument: Value, + location: Location, + ) -> IResult<()> { + self.type_check(typ, &argument, location)?; + for scope in self.scopes.iter_mut().rev() { + if let Entry::Occupied(mut entry) = scope.entry(id) { + entry.insert(argument); + return Ok(()); + } + } + Err(InterpreterError::NoValueForId { id, location }) + } + + /// Mutate an existing variable, potentially from a prior scope + fn mutate(&mut self, id: DefinitionId, argument: Value, location: Location) -> IResult<()> { + for scope in self.scopes.iter_mut().rev() { + if let Entry::Occupied(mut entry) = scope.entry(id) { + entry.insert(argument); + return Ok(()); + } + } + Err(InterpreterError::NoValueForId { id, location }) + } + + fn lookup(&self, ident: &HirIdent) -> IResult { + for scope in self.scopes.iter().rev() { + if let Some(value) = scope.get(&ident.id) { + return Ok(value.clone()); + } + } + + Err(InterpreterError::NoValueForId { id: ident.id, location: ident.location }) + } + + fn lookup_id(&self, id: DefinitionId, location: Location) -> IResult { + for scope in self.scopes.iter().rev() { + if let Some(value) = scope.get(&id) { + return Ok(value.clone()); + } + } + + Err(InterpreterError::NoValueForId { id, location }) + } + + fn type_check(&self, typ: &Type, value: &Value, location: Location) -> IResult<()> { + let typ = typ.follow_bindings(); + let value_type = value.get_type(); + + typ.try_unify(&value_type, &mut TypeBindings::new()).map_err(|_| { + InterpreterError::TypeMismatch { expected: typ, value: value.clone(), location } + }) + } + + /// Evaluate an expression and return the result + fn evaluate(&mut self, id: ExprId) -> IResult { + match self.interner.expression(&id) { + HirExpression::Ident(ident) => self.evaluate_ident(ident, id), + HirExpression::Literal(literal) => self.evaluate_literal(literal, id), + HirExpression::Block(block) => self.evaluate_block(block), + HirExpression::Prefix(prefix) => self.evaluate_prefix(prefix, id), + HirExpression::Infix(infix) => self.evaluate_infix(infix, id), + HirExpression::Index(index) => self.evaluate_index(index, id), + HirExpression::Constructor(constructor) => self.evaluate_constructor(constructor, id), + HirExpression::MemberAccess(access) => self.evaluate_access(access, id), + HirExpression::Call(call) => self.evaluate_call(call, id), + HirExpression::MethodCall(call) => self.evaluate_method_call(call, id), + HirExpression::Cast(cast) => self.evaluate_cast(cast, id), + HirExpression::If(if_) => self.evaluate_if(if_, id), + HirExpression::Tuple(tuple) => self.evaluate_tuple(tuple), + HirExpression::Lambda(lambda) => self.evaluate_lambda(lambda, id), + HirExpression::Quote(block) => Ok(Value::Code(Rc::new(block))), + HirExpression::Error => { + let location = self.interner.expr_location(&id); + Err(InterpreterError::ErrorNodeEncountered { location }) + } + } + } + + fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult { + let definition = self.interner.definition(ident.id); + + match &definition.kind { + DefinitionKind::Function(function_id) => { + let typ = self.interner.id_type(id); + Ok(Value::Function(*function_id, typ)) + } + DefinitionKind::Local(_) => dbg!(self.lookup(&ident)), + DefinitionKind::Global(global_id) => { + let let_ = self.interner.get_global_let_statement(*global_id).unwrap(); + self.evaluate_let(let_)?; + self.lookup(&ident) + } + DefinitionKind::GenericType(type_variable) => { + let value = match &*type_variable.borrow() { + TypeBinding::Unbound(_) => None, + TypeBinding::Bound(binding) => binding.evaluate_to_u64(), + }; + + if let Some(value) = value { + let typ = self.interner.id_type(id); + self.evaluate_integer((value as u128).into(), false, id) + } else { + let location = self.interner.expr_location(&id); + let typ = Type::TypeVariable(type_variable.clone(), TypeVariableKind::Normal); + Err(InterpreterError::NonIntegerArrayLength { typ, location }) + } + } + } + } + + fn evaluate_literal(&mut self, literal: HirLiteral, id: ExprId) -> IResult { + match literal { + HirLiteral::Unit => Ok(Value::Unit), + HirLiteral::Bool(value) => Ok(Value::Bool(value)), + HirLiteral::Integer(value, is_negative) => { + self.evaluate_integer(value, is_negative, id) + } + HirLiteral::Str(string) => Ok(Value::String(Rc::new(string))), + HirLiteral::FmtStr(_, _) => todo!("Evaluate format strings"), + HirLiteral::Array(array) => self.evaluate_array(array, id), + HirLiteral::Slice(array) => self.evaluate_slice(array, id), + } + } + + fn evaluate_integer( + &self, + value: FieldElement, + is_negative: bool, + id: ExprId, + ) -> IResult { + let typ = self.interner.id_type(id).follow_bindings(); + let location = self.interner.expr_location(&id); + + if let Type::FieldElement = &typ { + Ok(Value::Field(value)) + } else if let Type::Integer(sign, bit_size) = &typ { + match (sign, bit_size) { + (Signedness::Unsigned, IntegerBitSize::One) => { + return Err(InterpreterError::TypeUnsupported { typ, location }); + } + (Signedness::Unsigned, IntegerBitSize::Eight) => { + let value: u8 = + value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( + InterpreterError::IntegerOutOfRangeForType { value, typ, location }, + )?; + let value = if is_negative { 0u8.wrapping_sub(value) } else { value }; + Ok(Value::U8(value)) + } + (Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => { + let value: u32 = + value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( + InterpreterError::IntegerOutOfRangeForType { value, typ, location }, + )?; + let value = if is_negative { 0u32.wrapping_sub(value) } else { value }; + Ok(Value::U32(value)) + } + (Signedness::Unsigned, IntegerBitSize::SixtyFour) => { + let value: u64 = + value.try_to_u64().ok_or(InterpreterError::IntegerOutOfRangeForType { + value, + typ, + location, + })?; + let value = if is_negative { 0u64.wrapping_sub(value) } else { value }; + Ok(Value::U64(value)) + } + (Signedness::Signed, IntegerBitSize::One) => { + return Err(InterpreterError::TypeUnsupported { typ, location }); + } + (Signedness::Signed, IntegerBitSize::Eight) => { + let value: i8 = + value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( + InterpreterError::IntegerOutOfRangeForType { value, typ, location }, + )?; + let value = if is_negative { -value } else { value }; + Ok(Value::I8(value)) + } + (Signedness::Signed, IntegerBitSize::ThirtyTwo) => { + let value: i32 = + value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( + InterpreterError::IntegerOutOfRangeForType { value, typ, location }, + )?; + let value = if is_negative { -value } else { value }; + Ok(Value::I32(value)) + } + (Signedness::Signed, IntegerBitSize::SixtyFour) => { + let value: i64 = + value.try_to_u64().and_then(|value| value.try_into().ok()).ok_or( + InterpreterError::IntegerOutOfRangeForType { value, typ, location }, + )?; + let value = if is_negative { -value } else { value }; + Ok(Value::I64(value)) + } + } + } else { + Err(InterpreterError::NonIntegerIntegerLiteral { typ, location }) + } + } + + fn evaluate_block(&mut self, mut block: HirBlockExpression) -> IResult { + let last_statement = block.statements.pop(); + self.push_scope(); + + for statement in block.statements { + self.evaluate_statement(statement)?; + } + + let result = if let Some(statement) = last_statement { + self.evaluate_statement(statement) + } else { + Ok(Value::Unit) + }; + + self.pop_scope(); + result + } + + fn evaluate_array(&mut self, array: HirArrayLiteral, id: ExprId) -> IResult { + let typ = self.interner.id_type(id); + + match array { + HirArrayLiteral::Standard(elements) => { + let elements = elements + .into_iter() + .map(|id| self.evaluate(id)) + .collect::>>()?; + + Ok(Value::Array(elements, typ)) + } + HirArrayLiteral::Repeated { repeated_element, length } => { + let element = self.evaluate(repeated_element)?; + + if let Some(length) = length.evaluate_to_u64() { + let elements = (0..length).map(|_| element.clone()).collect(); + Ok(Value::Array(elements, typ)) + } else { + let location = self.interner.expr_location(&id); + Err(InterpreterError::NonIntegerArrayLength { typ: length, location }) + } + } + } + } + + fn evaluate_slice(&mut self, array: HirArrayLiteral, id: ExprId) -> IResult { + self.evaluate_array(array, id).map(|value| match value { + Value::Array(array, typ) => Value::Slice(array, typ), + other => unreachable!("Non-array value returned from evaluate array: {other:?}"), + }) + } + + fn evaluate_prefix(&mut self, prefix: HirPrefixExpression, id: ExprId) -> IResult { + let rhs = self.evaluate(prefix.rhs)?; + match prefix.operator { + crate::UnaryOp::Minus => match rhs { + Value::Field(value) => Ok(Value::Field(FieldElement::zero() - value)), + Value::I8(value) => Ok(Value::I8(-value)), + Value::I32(value) => Ok(Value::I32(-value)), + Value::I64(value) => Ok(Value::I64(-value)), + Value::U8(value) => Ok(Value::U8(0 - value)), + Value::U32(value) => Ok(Value::U32(0 - value)), + Value::U64(value) => Ok(Value::U64(0 - value)), + value => { + let location = self.interner.expr_location(&id); + Err(InterpreterError::InvalidValueForUnary { + value, + location, + operator: "minus", + }) + } + }, + crate::UnaryOp::Not => match rhs { + Value::Bool(value) => Ok(Value::Bool(!value)), + Value::I8(value) => Ok(Value::I8(!value)), + Value::I32(value) => Ok(Value::I32(!value)), + Value::I64(value) => Ok(Value::I64(!value)), + Value::U8(value) => Ok(Value::U8(!value)), + Value::U32(value) => Ok(Value::U32(!value)), + Value::U64(value) => Ok(Value::U64(!value)), + value => { + let location = self.interner.expr_location(&id); + Err(InterpreterError::InvalidValueForUnary { value, location, operator: "not" }) + } + }, + crate::UnaryOp::MutableReference => Ok(Value::Pointer(Shared::new(rhs))), + crate::UnaryOp::Dereference { implicitly_added: _ } => match rhs { + Value::Pointer(element) => Ok(element.borrow().clone()), + value => { + let location = self.interner.expr_location(&id); + Err(InterpreterError::NonPointerDereferenced { value, location }) + } + }, + } + } + + fn evaluate_infix(&mut self, infix: HirInfixExpression, id: ExprId) -> IResult { + let lhs = self.evaluate(infix.lhs)?; + let rhs = self.evaluate(infix.rhs)?; + + // TODO: Need to account for operator overloading + assert!( + self.interner.get_selected_impl_for_expression(id).is_none(), + "Operator overloading is unimplemented in the interpreter" + ); + + use InterpreterError::InvalidValuesForBinary; + match infix.operator.kind { + BinaryOpKind::Add => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs + rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs + rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs + rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs + rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs + rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs + rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs + rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "+" }) + } + }, + BinaryOpKind::Subtract => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs - rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs - rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs - rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs - rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs - rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs - rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs - rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "-" }) + } + }, + BinaryOpKind::Multiply => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs * rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs * rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs * rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs * rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs * rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs * rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs * rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "*" }) + } + }, + BinaryOpKind::Divide => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Field(lhs / rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs / rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs / rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs / rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs / rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs / rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs / rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "/" }) + } + }, + BinaryOpKind::Equal => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs == rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs == rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs == rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs == rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs == rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs == rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs == rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "==" }) + } + }, + BinaryOpKind::NotEqual => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs != rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs != rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs != rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs != rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs != rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs != rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs != rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "!=" }) + } + }, + BinaryOpKind::Less => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs < rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs < rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs < rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs < rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs < rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs < rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs < rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<" }) + } + }, + BinaryOpKind::LessEqual => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs <= rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<=" }) + } + }, + BinaryOpKind::Greater => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs > rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs > rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs > rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs > rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs > rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs > rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs > rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">" }) + } + }, + BinaryOpKind::GreaterEqual => match (lhs, rhs) { + (Value::Field(lhs), Value::Field(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::Bool(lhs >= rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">=" }) + } + }, + BinaryOpKind::And => match (lhs, rhs) { + (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs & rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs & rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs & rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs & rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs & rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs & rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs & rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "&" }) + } + }, + BinaryOpKind::Or => match (lhs, rhs) { + (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs | rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs | rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs | rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs | rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs | rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs | rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs | rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "|" }) + } + }, + BinaryOpKind::Xor => match (lhs, rhs) { + (Value::Bool(lhs), Value::Bool(rhs)) => Ok(Value::Bool(lhs ^ rhs)), + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs ^ rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs ^ rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs ^ rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs ^ rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs ^ rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs ^ rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "^" }) + } + }, + BinaryOpKind::ShiftRight => match (lhs, rhs) { + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs >> rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs >> rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs >> rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs >> rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs >> rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs >> rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: ">>" }) + } + }, + BinaryOpKind::ShiftLeft => match (lhs, rhs) { + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs << rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs << rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs << rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs << rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs << rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs << rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "<<" }) + } + }, + BinaryOpKind::Modulo => match (lhs, rhs) { + (Value::I8(lhs), Value::I8(rhs)) => Ok(Value::I8(lhs % rhs)), + (Value::I32(lhs), Value::I32(rhs)) => Ok(Value::I32(lhs % rhs)), + (Value::I64(lhs), Value::I64(rhs)) => Ok(Value::I64(lhs % rhs)), + (Value::U8(lhs), Value::U8(rhs)) => Ok(Value::U8(lhs % rhs)), + (Value::U32(lhs), Value::U32(rhs)) => Ok(Value::U32(lhs % rhs)), + (Value::U64(lhs), Value::U64(rhs)) => Ok(Value::U64(lhs % rhs)), + (lhs, rhs) => { + let location = self.interner.expr_location(&id); + Err(InvalidValuesForBinary { lhs, rhs, location, operator: "%" }) + } + }, + } + } + + fn evaluate_index(&mut self, index: HirIndexExpression, id: ExprId) -> IResult { + let array = self.evaluate(index.collection)?; + let index = self.evaluate(index.index)?; + + let location = self.interner.expr_location(&id); + let (array, index) = self.bounds_check(array, index, location)?; + + Ok(array[index].clone()) + } + + /// Bounds check the given array and index pair. + /// This will also ensure the given arguments are in fact an array and integer. + fn bounds_check( + &self, + array: Value, + index: Value, + location: Location, + ) -> IResult<(Vector, usize)> { + let collection = match array { + Value::Array(array, _) => array, + Value::Slice(array, _) => array, + value => { + return Err(InterpreterError::NonArrayIndexed { value, location }); + } + }; + + let index = match index { + Value::Field(value) => { + value.try_to_u64().expect("index could not fit into u64") as usize + } + Value::I8(value) => value as usize, + Value::I32(value) => value as usize, + Value::I64(value) => value as usize, + Value::U8(value) => value as usize, + Value::U32(value) => value as usize, + Value::U64(value) => value as usize, + value => { + return Err(InterpreterError::NonIntegerUsedAsIndex { value, location }); + } + }; + + if index >= collection.len() { + use InterpreterError::IndexOutOfBounds; + return Err(IndexOutOfBounds { index, location, length: collection.len() }); + } + + Ok((collection, index)) + } + + fn evaluate_constructor( + &mut self, + constructor: HirConstructorExpression, + id: ExprId, + ) -> IResult { + let fields = constructor + .fields + .into_iter() + .map(|(name, expr)| { + let field_value = self.evaluate(expr)?; + Ok((Rc::new(name.0.contents), field_value)) + }) + .collect::>()?; + + let typ = self.interner.id_type(id); + Ok(Value::Struct(fields, typ)) + } + + fn evaluate_access(&mut self, access: HirMemberAccess, id: ExprId) -> IResult { + let (fields, struct_type) = match self.evaluate(access.lhs)? { + Value::Struct(fields, typ) => (fields, typ), + value => { + let location = self.interner.expr_location(&id); + return Err(InterpreterError::NonTupleOrStructInMemberAccess { value, location }); + } + }; + + fields.get(&access.rhs.0.contents).cloned().ok_or_else(|| { + let location = self.interner.expr_location(&id); + let value = Value::Struct(fields, struct_type); + let field_name = access.rhs.0.contents; + InterpreterError::ExpectedStructToHaveField { value, field_name, location } + }) + } + + fn evaluate_call(&mut self, call: HirCallExpression, id: ExprId) -> IResult { + let function = self.evaluate(call.func)?; + let arguments = try_vecmap(call.arguments, |arg| { + Ok((self.evaluate(arg)?, self.interner.expr_location(&arg))) + })?; + let location = self.interner.expr_location(&id); + + match function { + Value::Function(function_id, _) => self.call_function(function_id, arguments, location), + Value::Closure(closure, env, _) => self.call_closure(closure, env, arguments, location), + value => Err(InterpreterError::NonFunctionCalled { value, location }), + } + } + + fn evaluate_method_call( + &mut self, + call: HirMethodCallExpression, + id: ExprId, + ) -> IResult { + let object = self.evaluate(call.object)?; + let arguments = try_vecmap(call.arguments, |arg| { + Ok((self.evaluate(arg)?, self.interner.expr_location(&arg))) + })?; + let location = self.interner.expr_location(&id); + + let typ = object.get_type().follow_bindings(); + let method_name = &call.method.0.contents; + + // TODO: Traits + let method = match &typ { + Type::Struct(struct_def, _) => { + self.interner.lookup_method(&typ, struct_def.borrow().id, method_name, false) + } + _ => self.interner.lookup_primitive_method(&typ, method_name), + }; + + if let Some(method) = method { + self.call_function(method, arguments, location) + } else { + Err(InterpreterError::NoMethodFound { object, typ, location }) + } + } + + fn evaluate_cast(&mut self, cast: HirCastExpression, id: ExprId) -> IResult { + macro_rules! signed_int_to_field { + ($x:expr) => {{ + // Need to convert the signed integer to an i128 before + // we negate it to preserve the MIN value. + let mut value = $x as i128; + let is_negative = value < 0; + if is_negative { + value = -value; + } + ((value as u128).into(), is_negative) + }}; + } + + let (mut lhs, lhs_is_negative) = match self.evaluate(cast.lhs)? { + Value::Field(value) => (value, false), + Value::U8(value) => ((value as u128).into(), false), + Value::U32(value) => ((value as u128).into(), false), + Value::U64(value) => ((value as u128).into(), false), + Value::I8(value) => signed_int_to_field!(value), + Value::I32(value) => signed_int_to_field!(value), + Value::I64(value) => signed_int_to_field!(value), + Value::Bool(value) => { + (if value { FieldElement::one() } else { FieldElement::zero() }, false) + } + value => { + let location = self.interner.expr_location(&id); + return Err(InterpreterError::NonNumericCasted { value, location }); + } + }; + + macro_rules! cast_to_int { + ($x:expr, $method:ident, $typ:ty, $f:ident) => {{ + let mut value = $x.$method() as $typ; + if lhs_is_negative { + value = 0 - value; + } + Ok(Value::$f(value)) + }}; + } + + // Now actually cast the lhs, bit casting and wrapping as necessary + match cast.r#type.follow_bindings() { + Type::FieldElement => { + if lhs_is_negative { + lhs = FieldElement::zero() - lhs; + } + Ok(Value::Field(lhs)) + } + Type::Integer(sign, bit_size) => match (sign, bit_size) { + (Signedness::Unsigned, IntegerBitSize::One) => { + let location = self.interner.expr_location(&id); + Err(InterpreterError::TypeUnsupported { typ: cast.r#type, location }) + } + (Signedness::Unsigned, IntegerBitSize::Eight) => cast_to_int!(lhs, to_u128, u8, U8), + (Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => { + cast_to_int!(lhs, to_u128, u32, U32) + } + (Signedness::Unsigned, IntegerBitSize::SixtyFour) => { + cast_to_int!(lhs, to_u128, u64, U64) + } + (Signedness::Signed, IntegerBitSize::One) => { + let location = self.interner.expr_location(&id); + Err(InterpreterError::TypeUnsupported { typ: cast.r#type, location }) + } + (Signedness::Signed, IntegerBitSize::Eight) => cast_to_int!(lhs, to_i128, i8, I8), + (Signedness::Signed, IntegerBitSize::ThirtyTwo) => { + cast_to_int!(lhs, to_i128, i32, I32) + } + (Signedness::Signed, IntegerBitSize::SixtyFour) => { + cast_to_int!(lhs, to_i128, i64, I64) + } + }, + Type::Bool => Ok(Value::Bool(!lhs.is_zero() || lhs_is_negative)), + typ => { + let location = self.interner.expr_location(&id); + Err(InterpreterError::CastToNonNumericType { typ, location }) + } + } + } + + fn evaluate_if(&mut self, if_: HirIfExpression, id: ExprId) -> IResult { + let condition = match self.evaluate(if_.condition)? { + Value::Bool(value) => value, + value => { + let location = self.interner.expr_location(&id); + return Err(InterpreterError::NonBoolUsedInIf { value, location }); + } + }; + + self.push_scope(); + + let result = if condition { + if if_.alternative.is_some() { + self.evaluate(if_.consequence) + } else { + self.evaluate(if_.consequence)?; + Ok(Value::Unit) + } + } else { + match if_.alternative { + Some(alternative) => self.evaluate(alternative), + None => Ok(Value::Unit), + } + }; + + self.pop_scope(); + result + } + + fn evaluate_tuple(&mut self, tuple: Vec) -> IResult { + let fields = try_vecmap(tuple, |field| self.evaluate(field))?; + Ok(Value::Tuple(fields)) + } + + fn evaluate_lambda(&mut self, lambda: HirLambda, id: ExprId) -> IResult { + let location = self.interner.expr_location(&id); + let environment = + try_vecmap(&lambda.captures, |capture| self.lookup_id(capture.ident.id, location))?; + + let typ = self.interner.id_type(id); + Ok(Value::Closure(lambda, environment, typ)) + } + + fn evaluate_statement(&mut self, statement: StmtId) -> IResult { + match self.interner.statement(&statement) { + HirStatement::Let(let_) => self.evaluate_let(let_), + HirStatement::Constrain(constrain) => self.evaluate_constrain(constrain), + HirStatement::Assign(assign) => self.evaluate_assign(assign), + HirStatement::For(for_) => self.evaluate_for(for_), + HirStatement::Break => self.evaluate_break(), + HirStatement::Continue => self.evaluate_continue(), + HirStatement::Expression(expression) => self.evaluate(expression), + HirStatement::Semi(expression) => { + self.evaluate(expression)?; + Ok(Value::Unit) + } + HirStatement::Error => { + let location = self.interner.id_location(statement); + Err(InterpreterError::ErrorNodeEncountered { location }) + } + } + } + + fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult { + let rhs = self.evaluate(let_.expression)?; + let location = self.interner.expr_location(&let_.expression); + self.define_pattern(&let_.pattern, &let_.r#type, rhs, location)?; + Ok(Value::Unit) + } + + fn evaluate_constrain(&mut self, constrain: HirConstrainStatement) -> IResult { + match self.evaluate(constrain.0)? { + Value::Bool(true) => Ok(Value::Unit), + Value::Bool(false) => { + let location = self.interner.expr_location(&constrain.0); + let message = constrain.2.and_then(|expr| self.evaluate(expr).ok()); + Err(InterpreterError::FailingConstraint { location, message }) + } + value => { + let location = self.interner.expr_location(&constrain.0); + Err(InterpreterError::NonBoolUsedInConstrain { value, location }) + } + } + } + + fn evaluate_assign(&mut self, assign: HirAssignStatement) -> IResult { + let rhs = self.evaluate(assign.expression)?; + self.store_lvalue(assign.lvalue, rhs)?; + Ok(Value::Unit) + } + + fn store_lvalue(&mut self, lvalue: HirLValue, rhs: Value) -> IResult<()> { + match lvalue { + HirLValue::Ident(ident, typ) => { + self.checked_mutate(ident.id, &typ, rhs, ident.location) + } + HirLValue::Dereference { lvalue, element_type: _, location } => { + match self.evaluate_lvalue(&lvalue)? { + Value::Pointer(value) => { + *value.borrow_mut() = rhs; + Ok(()) + } + value => Err(InterpreterError::NonPointerDereferenced { value, location }), + } + } + HirLValue::MemberAccess { object, field_name, field_index, typ: _, location } => { + let index = field_index.expect("The field index should be set after type checking"); + match self.evaluate_lvalue(&object)? { + Value::Tuple(mut fields) => { + fields[index] = rhs; + self.store_lvalue(*object, Value::Tuple(fields)) + } + Value::Struct(mut fields, typ) => { + fields.insert(Rc::new(field_name.0.contents), rhs); + self.store_lvalue(*object, Value::Struct(fields, typ)) + } + value => { + Err(InterpreterError::NonTupleOrStructInMemberAccess { value, location }) + } + } + } + HirLValue::Index { array, index, typ: _, location } => { + let array_value = self.evaluate_lvalue(&array)?; + let index = self.evaluate(index)?; + + let constructor = match &array_value { + Value::Array(..) => Value::Array, + _ => Value::Slice, + }; + + let typ = array_value.get_type().into_owned(); + let (elements, index) = self.bounds_check(array_value, index, location)?; + + let new_array = constructor(elements.update(index, rhs), typ); + self.store_lvalue(*array, new_array) + } + } + } + + fn evaluate_lvalue(&mut self, lvalue: &HirLValue) -> IResult { + match lvalue { + HirLValue::Ident(ident, _) => self.lookup(ident), + HirLValue::Dereference { lvalue, element_type: _, location } => { + match self.evaluate_lvalue(lvalue)? { + Value::Pointer(value) => Ok(value.borrow().clone()), + value => { + Err(InterpreterError::NonPointerDereferenced { value, location: *location }) + } + } + } + HirLValue::MemberAccess { object, field_name, field_index, typ: _, location } => { + let index = field_index.expect("The field index should be set after type checking"); + + match self.evaluate_lvalue(object)? { + Value::Tuple(mut values) => Ok(values.swap_remove(index)), + Value::Struct(fields, _) => Ok(fields[&field_name.0.contents].clone()), + value => Err(InterpreterError::NonTupleOrStructInMemberAccess { + value, + location: *location, + }), + } + } + HirLValue::Index { array, index, typ: _, location } => { + let array = self.evaluate_lvalue(array)?; + let index = self.evaluate(*index)?; + let (elements, index) = self.bounds_check(array, index, *location)?; + Ok(elements[index].clone()) + } + } + } + + fn evaluate_for(&mut self, for_: HirForStatement) -> IResult { + // i128 can store all values from i8 - u64 + let get_index = |this: &mut Self, expr| -> IResult<(_, fn(_) -> _)> { + match this.evaluate(expr)? { + Value::I8(value) => Ok((value as i128, |i| Value::I8(i as i8))), + Value::I32(value) => Ok((value as i128, |i| Value::I32(i as i32))), + Value::I64(value) => Ok((value as i128, |i| Value::I64(i as i64))), + Value::U8(value) => Ok((value as i128, |i| Value::U8(i as u8))), + Value::U32(value) => Ok((value as i128, |i| Value::U32(i as u32))), + Value::U64(value) => Ok((value as i128, |i| Value::U64(i as u64))), + value => { + let location = this.interner.expr_location(&expr); + Err(InterpreterError::NonIntegerUsedInLoop { value, location }) + } + } + }; + + let (start, make_value) = get_index(self, for_.start_range)?; + let (end, _) = get_index(self, for_.end_range)?; + let was_in_loop = std::mem::replace(&mut self.in_loop, true); + + for i in start..end { + self.push_scope(); + self.current_scope_mut().insert(for_.identifier.id, make_value(i)); + + match self.evaluate(for_.block) { + Ok(_) => (), + Err(InterpreterError::Break) => break, + Err(InterpreterError::Continue) => continue, + Err(other) => return Err(other), + } + self.pop_scope(); + } + + self.in_loop = was_in_loop; + Ok(Value::Unit) + } + + fn evaluate_break(&mut self) -> IResult { + if self.in_loop { + Err(InterpreterError::Break) + } else { + Err(InterpreterError::BreakNotInLoop) + } + } + + fn evaluate_continue(&mut self) -> IResult { + if self.in_loop { + Err(InterpreterError::Continue) + } else { + Err(InterpreterError::ContinueNotInLoop) + } + } +} + +impl Value { + fn get_type(&self) -> Cow { + Cow::Owned(match self { + Value::Unit => Type::Unit, + Value::Bool(_) => Type::Bool, + Value::Field(_) => Type::FieldElement, + Value::I8(_) => Type::Integer(Signedness::Signed, IntegerBitSize::Eight), + Value::I32(_) => Type::Integer(Signedness::Signed, IntegerBitSize::ThirtyTwo), + Value::I64(_) => Type::Integer(Signedness::Signed, IntegerBitSize::SixtyFour), + Value::U8(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight), + Value::U32(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo), + Value::U64(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour), + Value::String(value) => { + let length = Type::Constant(value.len() as u64); + Type::String(Box::new(length)) + } + Value::Function(_, typ) => return Cow::Borrowed(typ), + Value::Closure(_, _, typ) => return Cow::Borrowed(typ), + Value::Tuple(fields) => { + Type::Tuple(vecmap(fields, |field| field.get_type().into_owned())) + } + Value::Struct(_, typ) => return Cow::Borrowed(typ), + Value::Array(_, typ) => return Cow::Borrowed(typ), + Value::Slice(_, typ) => return Cow::Borrowed(typ), + Value::Code(_) => Type::Code, + Value::Pointer(element) => { + let element = element.borrow().get_type().into_owned(); + Type::MutableReference(Box::new(element)) + } + }) + } +} diff --git a/compiler/noirc_frontend/src/hir/comptime/mod.rs b/compiler/noirc_frontend/src/hir/comptime/mod.rs index 91621c857cf..83aaddaa405 100644 --- a/compiler/noirc_frontend/src/hir/comptime/mod.rs +++ b/compiler/noirc_frontend/src/hir/comptime/mod.rs @@ -1 +1,3 @@ mod hir_to_ast; +mod interpreter; +mod tests; diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs new file mode 100644 index 00000000000..016e7079886 --- /dev/null +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -0,0 +1,166 @@ +#![cfg(test)] + +use noirc_errors::Location; + +use super::interpreter::{Interpreter, InterpreterError, Value}; +use crate::hir::type_check::test::type_check_src_code; + +fn interpret_helper(src: &str, func_namespace: Vec) -> Result { + let (mut interner, main_id) = type_check_src_code(src, func_namespace); + let mut interpreter = Interpreter::new(&mut interner); + + let no_location = Location::dummy(); + interpreter.call_function(main_id, Vec::new(), no_location) +} + +fn interpret(src: &str, func_namespace: Vec) -> Value { + interpret_helper(src, func_namespace).unwrap_or_else(|error| { + panic!("Expected interpreter to exit successfully, but found {error:?}") + }) +} + +fn interpret_expect_error(src: &str, func_namespace: Vec) -> InterpreterError { + interpret_helper(src, func_namespace).expect_err("Expected interpreter to error") +} + +#[test] +fn interpreter_works() { + let program = "fn main() -> pub Field { 3 }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::Field(3u128.into())); +} + +#[test] +fn mutation_works() { + let program = "fn main() -> pub i8 { + let mut x = 3; + x = 4; + x + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::I8(4)); +} + +#[test] +fn mutating_references() { + let program = "fn main() -> pub i32 { + let x = &mut 3; + *x = 4; + *x + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::I32(4)); +} + +#[test] +fn mutating_mutable_references() { + let program = "fn main() -> pub i64 { + let mut x = &mut 3; + *x = 4; + *x + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::I64(4)); +} + +#[test] +fn mutating_arrays() { + let program = "fn main() -> pub u8 { + let mut a1 = [1, 2, 3, 4]; + a1[1] = 22; + a1[1] + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::U8(22)); +} + +#[test] +fn for_loop() { + let program = "fn main() -> pub u8 { + let mut x = 0; + for i in 0 .. 6 { + x += i; + } + x + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::U8(15)); +} + +#[test] +fn for_loop_with_break() { + let program = "unconstrained fn main() -> pub u32 { + let mut x = 0; + for i in 0 .. 6 { + if i == 4 { + break; + } + x += i; + } + x + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::U32(6)); +} + +#[test] +fn for_loop_with_continue() { + let program = "unconstrained fn main() -> pub u64 { + let mut x = 0; + for i in 0 .. 6 { + if i == 4 { + continue; + } + x += i; + } + x + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::U64(11)); +} + +#[test] +fn assert() { + let program = "fn main() { + assert(1 == 1); + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::Unit); +} + +#[test] +fn assert_fail() { + let program = "fn main() { + assert(1 == 2); + }"; + let result = interpret_expect_error(program, vec!["main".into()]); + assert!(matches!(result, InterpreterError::FailingConstraint { .. })); +} + +#[test] +fn lambda() { + let program = "fn main() -> pub u8 { + let f = |x: u8| x + 1; + f(1) + }"; + let result = interpret(program, vec!["main".into()]); + assert!(matches!(result, Value::U8(2))); +} + +#[test] +fn non_deterministic_recursion() { + let program = " + fn main() -> pub u64 { + fib(10) + } + + fn fib(x: u64) -> u64 { + if x <= 1 { + x + } else { + fib(x - 1) + fib(x - 2) + } + }"; + let result = interpret(program, vec!["main".into(), "fib".into()]); + assert_eq!(result, Value::U64(55)); +} diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 6fbb3b67546..e3c79e39d31 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -416,6 +416,7 @@ impl<'a> ModCollector<'a> { // TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629 attributes: crate::token::Attributes::empty(), is_unconstrained: false, + is_comptime: false, }; let location = Location::new(name.span(), self.file_id); diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 9180201fe17..479f357126a 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -246,6 +246,7 @@ impl<'a> Resolver<'a> { name: name.clone(), attributes: Attributes::empty(), is_unconstrained: false, + is_comptime: false, visibility: ItemVisibility::Public, // Trait functions are always public generics: generics.clone(), parameters: vecmap(parameters, |(name, typ)| Param { diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index cdfc19b3a33..f5323cd07de 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -51,8 +51,7 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec TypeChecker<'interner> { // XXX: These tests are all manual currently. /// We can either build a test apparatus or pass raw code through the resolver #[cfg(test)] -mod test { +pub mod test { use std::collections::{BTreeMap, HashMap}; use std::vec; use fm::FileId; - use iter_extended::vecmap; + use iter_extended::btree_map; use noirc_errors::{Location, Span}; use crate::graph::CrateId; @@ -601,7 +600,7 @@ mod test { "#; - type_check_src_code(src, vec![String::from("main"), String::from("foo")]); + type_check_src_code(src, vec![String::from("main")]); } #[test] fn basic_closure() { @@ -612,7 +611,7 @@ mod test { } "#; - type_check_src_code(src, vec![String::from("main"), String::from("foo")]); + type_check_src_code(src, vec![String::from("main")]); } #[test] @@ -672,8 +671,8 @@ mod test { } } - fn type_check_src_code(src: &str, func_namespace: Vec) { - type_check_src_code_errors_expected(src, func_namespace, 0); + pub fn type_check_src_code(src: &str, func_namespace: Vec) -> (NodeInterner, FuncId) { + type_check_src_code_errors_expected(src, func_namespace, 0) } // This function assumes that there is only one function and this is the @@ -682,7 +681,7 @@ mod test { src: &str, func_namespace: Vec, expected_num_type_check_errs: usize, - ) { + ) -> (NodeInterner, FuncId) { let (program, errors) = parse_program(src); let mut interner = NodeInterner::default(); interner.populate_dummy_operator_traits(); @@ -695,14 +694,16 @@ mod test { errors ); - let main_id = interner.push_test_function_definition("main".into()); + let func_ids = btree_map(&func_namespace, |name| { + (name.to_string(), interner.push_test_function_definition(name.into())) + }); - let func_ids = - vecmap(&func_namespace, |name| interner.push_test_function_definition(name.into())); + let main_id = + *func_ids.get("main").unwrap_or_else(|| func_ids.first_key_value().unwrap().1); let mut path_resolver = TestPathResolver(HashMap::new()); - for (name, id) in func_namespace.into_iter().zip(func_ids.clone()) { - path_resolver.insert_func(name.to_owned(), id); + for (name, id) in func_ids.iter() { + path_resolver.insert_func(name.to_owned(), *id); } let mut def_maps = BTreeMap::new(); @@ -722,20 +723,24 @@ mod test { }, ); - let func_meta = vecmap(program.into_sorted().functions, |nf| { + for nf in program.into_sorted().functions { let resolver = Resolver::new(&mut interner, &path_resolver, &def_maps, file); - let (hir_func, func_meta, resolver_errors) = resolver.resolve_function(nf, main_id); - assert_eq!(resolver_errors, vec![]); - (hir_func, func_meta) - }); - for ((hir_func, meta), func_id) in func_meta.into_iter().zip(func_ids.clone()) { - interner.update_fn(func_id, hir_func); - interner.push_fn_meta(meta, func_id); + let function_id = *func_ids.get(nf.name()).unwrap(); + let (hir_func, func_meta, resolver_errors) = resolver.resolve_function(nf, function_id); + + interner.push_fn_meta(func_meta, function_id); + interner.update_fn(function_id, hir_func); + assert_eq!(resolver_errors, vec![]); } // Type check section - let errors = super::type_check_func(&mut interner, func_ids.first().cloned().unwrap()); + let mut errors = Vec::new(); + + for function in func_ids.values() { + errors.extend(super::type_check_func(&mut interner, *function)); + } + assert_eq!( errors.len(), expected_num_type_check_errs, @@ -744,5 +749,7 @@ mod test { errors.len(), errors ); + + (interner, main_id) } } diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index c2f6031bf6d..eb4ebf3f913 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -260,7 +260,7 @@ impl HirBlockExpression { } /// A variable captured inside a closure -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct HirCapturedVar { pub ident: HirIdent, @@ -274,7 +274,7 @@ pub struct HirCapturedVar { pub transitive_capture_index: Option, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct HirLambda { pub parameters: Vec<(HirPattern, Type)>, pub return_type: Type, diff --git a/compiler/noirc_frontend/src/hir_def/function.rs b/compiler/noirc_frontend/src/hir_def/function.rs index a3bbc9445a8..97f4b6a1616 100644 --- a/compiler/noirc_frontend/src/hir_def/function.rs +++ b/compiler/noirc_frontend/src/hir_def/function.rs @@ -24,8 +24,8 @@ impl HirFunction { HirFunction(expr_id) } - pub const fn as_expr(&self) -> &ExprId { - &self.0 + pub const fn as_expr(&self) -> ExprId { + self.0 } pub fn block(&self, interner: &NodeInterner) -> HirBlockExpression { diff --git a/compiler/noirc_frontend/src/hir_def/stmt.rs b/compiler/noirc_frontend/src/hir_def/stmt.rs index 4c9a33d3dc0..37e3651a9b2 100644 --- a/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -61,7 +61,7 @@ pub struct HirAssignStatement { #[derive(Debug, Clone)] pub struct HirConstrainStatement(pub ExprId, pub FileId, pub Option); -#[derive(Debug, Clone, Hash)] +#[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum HirPattern { Identifier(HirIdent), Mutable(Box, Location), diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 4e779244d30..2cccc18fb09 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -288,7 +288,7 @@ impl<'interner> Monomorphizer<'interner> { let modifiers = self.interner.function_modifiers(&f); let name = self.interner.function_name(&f).to_owned(); - let body_expr_id = *self.interner.function(&f).as_expr(); + let body_expr_id = self.interner.function(&f).as_expr(); let body_return_type = self.interner.id_type(body_expr_id); let return_type = match meta.return_type() { Type::TraitAsType(..) => &body_return_type, diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 153c7e45d4a..5b375be8d56 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -242,6 +242,8 @@ pub struct FunctionModifiers { pub attributes: Attributes, pub is_unconstrained: bool, + + pub is_comptime: bool, } impl FunctionModifiers { @@ -254,6 +256,7 @@ impl FunctionModifiers { visibility: ItemVisibility::Public, attributes: Attributes::empty(), is_unconstrained: false, + is_comptime: false, } } } @@ -759,6 +762,7 @@ impl NodeInterner { visibility: function.visibility, attributes: function.attributes.clone(), is_unconstrained: function.is_unconstrained, + is_comptime: function.is_comptime, }; self.push_function_definition(id, modifiers, module, location) } diff --git a/compiler/noirc_frontend/src/parser/parser/function.rs b/compiler/noirc_frontend/src/parser/parser/function.rs index 06e1a958eb1..18f17065038 100644 --- a/compiler/noirc_frontend/src/parser/parser/function.rs +++ b/compiler/noirc_frontend/src/parser/parser/function.rs @@ -36,6 +36,7 @@ pub(super) fn function_definition(allow_self: bool) -> impl NoirParser