Skip to content

Commit

Permalink
if/begin fix plus optimizations (#189)
Browse files Browse the repository at this point in the history
* `if`/`begin` fix plus optimizations

* chore: Add eval tests

* One less column

---------

Co-authored-by: wwared <[email protected]>
  • Loading branch information
gabriel-barrett and wwared authored Aug 6, 2024
1 parent 268b432 commit a772954
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 56 deletions.
158 changes: 102 additions & 56 deletions src/lurk/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub fn build_lurk_toplevel() -> (Toplevel<BabyBear, LurkChip>, ZStore<BabyBear,
eval_unop(&builtins),
eval_binop_num(&builtins),
eval_binop_misc(&builtins),
eval_begin(&builtins),
equal(&builtins),
equal_inner(),
car_cdr(),
Expand Down Expand Up @@ -550,7 +551,7 @@ pub fn eval<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
match head_tag {
Tag::Builtin => {
match head [|sym| builtins.index(sym).to_field()] {
"let", "letrec", "lambda", "+", "-", "*", "/", "%", "<", ">", "<=", ">=", "cons", "strcons" => {
"let", "letrec", "lambda", "cons", "strcons" => {
let rest_not_cons = sub(rest_tag, cons_tag);
if rest_not_cons {
return (err_tag, invalid_form)
Expand Down Expand Up @@ -588,16 +589,30 @@ pub fn eval<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
let res = store(fst_tag, fst, snd_tag, snd, env);
return (res_tag, res)
}
"+", "-", "*", "/", "%", "<", ">", "<=", ">=" => {
let (res_tag, res) = call(eval_binop_num, head, fst_tag, fst, snd_tag, snd, env);
return (res_tag, res)
}
"cons", "strcons" => {
let (res_tag, res) = call(eval_binop_misc, head, fst_tag, fst, snd_tag, snd, env);
return (res_tag, res)
}
}
}
"+", "-", "*", "/", "%", "<", ">", "<=", ">=" => {
let rest_not_cons = sub(rest_tag, cons_tag);
if rest_not_cons {
return (err_tag, invalid_form)
}
let (fst_tag, fst, rest_tag, rest) = load(rest);
let rest_not_cons = sub(rest_tag, cons_tag);
if rest_not_cons {
return (err_tag, invalid_form)
}
let (snd_tag, snd, rest_tag, _rest) = load(rest);
let rest_not_nil = sub(rest_tag, nil_tag);
if rest_not_nil {
return (err_tag, invalid_form)
}
let (res_tag, res) = call(eval_binop_num, head, fst_tag, fst, snd_tag, snd, env);
return (res_tag, res)
}
"eval" => {
let rest_not_cons = sub(rest_tag, cons_tag);
if rest_not_cons {
Expand Down Expand Up @@ -659,48 +674,27 @@ pub fn eval<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
return (expr_tag, expr)
}
"begin" => {
let rest_not_cons = sub(rest_tag, cons_tag);
if rest_not_cons {
return (err_tag, invalid_form)
}
let (expr_tag, expr, rest_tag, rest) = load(rest);
let (val_tag, val) = call(eval, expr_tag, expr, env);
match val_tag {
Tag::Err => {
return (val_tag, val)
}
};
match rest_tag {
Tag::Nil => {
return (val_tag, val)
}
Tag::Cons => {
let smaller_expr = store(head_tag, head, rest_tag, rest);
let (val_tag, val) = call(eval, cons_tag, smaller_expr, env);
return (val_tag, val)
}
};
return (err_tag, invalid_form)
let (expr_tag, expr) = call(eval_begin, rest_tag, rest, env);
return (expr_tag, expr)
}
"empty-env" => {
"current-env", "empty-env" => {
let rest_not_nil = sub(rest_tag, nil_tag);
if rest_not_nil {
return (err_tag, invalid_form)
}
let env_tag = Tag::Env;
let env = 0;
return (env_tag, env)
}
"current-env" => {
let rest_not_nil = sub(rest_tag, nil_tag);
if rest_not_nil {
return (err_tag, invalid_form)
match head [|sym| builtins.index(sym).to_field()] {
"current-env" => {
return (env_tag, env)
}
"empty-env" => {
let env = 0;
return (env_tag, env)
}
}
let env_tag = Tag::Env;
return (env_tag, env)
}
"if" => {
// An if expression is a list of 4 elements
// An if expression is a list of 3 or 4 elements
let rest_not_cons = sub(rest_tag, cons_tag);
if rest_not_cons {
return (err_tag, invalid_form)
Expand All @@ -711,28 +705,38 @@ pub fn eval<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
return (err_tag, invalid_form)
}
let (t_branch_tag, t_branch, rest_tag, rest) = load(rest);
let rest_not_cons = sub(rest_tag, cons_tag);
if rest_not_cons {
return (err_tag, invalid_form)
}
let (f_branch_tag, f_branch, rest_tag, _rest) = load(rest);
let rest_not_nil = sub(rest_tag, nil_tag);
if rest_not_nil {
return (err_tag, invalid_form)
}

let (val_tag, val) = call(eval, expr_tag, expr, env);
match val_tag {
match rest_tag {
Tag::Nil => {
let (res_tag, res) = call(eval, f_branch_tag, f_branch, env);
let (val_tag, val) = call(eval, expr_tag, expr, env);
match val_tag {
Tag::Nil, Tag::Err => {
return (val_tag, val)
}
};
let (res_tag, res) = call(eval, t_branch_tag, t_branch, env);
return (res_tag, res)
}
Tag::Err => {
return (val_tag, val)
Tag::Cons => {
let (f_branch_tag, f_branch, rest_tag, _rest) = load(rest);
let rest_not_nil = sub(rest_tag, nil_tag);
if rest_not_nil {
return (err_tag, invalid_form)
}
let (val_tag, val) = call(eval, expr_tag, expr, env);
match val_tag {
Tag::Nil => {
let (res_tag, res) = call(eval, f_branch_tag, f_branch, env);
return (res_tag, res)
}
Tag::Err => {
return (val_tag, val)
}
};
let (res_tag, res) = call(eval, t_branch_tag, t_branch, env);
return (res_tag, res)
}
};
let (res_tag, res) = call(eval, t_branch_tag, t_branch, env);
return (res_tag, res)
return (err_tag, invalid_form)
}
"eq", "=" => {
let res: [2] = call(equal, rest_tag, rest, env);
Expand Down Expand Up @@ -940,6 +944,46 @@ pub fn equal_inner<F: AbstractField + Ord>() -> FuncE<F> {
)
}

pub fn eval_begin<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
func!(
fn eval_begin(rest_tag, rest, env): [2] {
let err_tag = Tag::Err;
let cons_tag = Tag::Cons;
let invalid_form = EvalErr::InvalidForm;
match rest_tag {
Tag::Cons => {
let (expr_tag, expr, rest_tag, rest) = load(rest);
let (val_tag, val) = call(eval, expr_tag, expr, env);
match val_tag {
Tag::Err => {
return (val_tag, val)
}
};
match rest_tag {
Tag::Nil => {
return (val_tag, val)
}
Tag::Cons => {
let builtin_tag = Tag::Builtin;
let begin = builtins.index("begin");
let smaller_expr = store(builtin_tag, begin, rest_tag, rest);
let (val_tag, val) = call(eval, cons_tag, smaller_expr, env);
return (val_tag, val)
}
};
return (err_tag, invalid_form)
}
Tag::Nil => {
let nil_tag = Tag::Nil;
let nil = 0;
return (nil_tag, nil)
}
};
return (err_tag, invalid_form)
}
)
}

pub fn eval_binop_num<F: AbstractField + Ord>(builtins: &BuiltinMemo<'_, F>) -> FuncE<F> {
func!(
fn eval_binop_num(head, exp1_tag, exp1, exp2_tag, exp2, env): [2] {
Expand Down Expand Up @@ -1529,6 +1573,7 @@ mod test {
let eval_unop = FuncChip::from_name("eval_unop", toplevel);
let eval_binop_num = FuncChip::from_name("eval_binop_num", toplevel);
let eval_binop_misc = FuncChip::from_name("eval_binop_misc", toplevel);
let eval_begin = FuncChip::from_name("eval_begin", toplevel);
let eval_let = FuncChip::from_name("eval_let", toplevel);
let eval_letrec = FuncChip::from_name("eval_letrec", toplevel);
let equal = FuncChip::from_name("equal", toplevel);
Expand All @@ -1554,12 +1599,13 @@ mod test {
expected.assert_eq(&computed.to_string());
};
expect_eq(lurk_main.width(), expect!["52"]);
expect_eq(eval.width(), expect!["101"]);
expect_eq(eval.width(), expect!["94"]);
expect_eq(eval_comm_unop.width(), expect!["71"]);
expect_eq(eval_hide.width(), expect!["76"]);
expect_eq(eval_unop.width(), expect!["33"]);
expect_eq(eval_binop_num.width(), expect!["54"]);
expect_eq(eval_binop_misc.width(), expect!["32"]);
expect_eq(eval_begin.width(), expect!["34"]);
expect_eq(eval_let.width(), expect!["54"]);
expect_eq(eval_letrec.width(), expect!["58"]);
expect_eq(equal.width(), expect!["44"]);
Expand Down
3 changes: 3 additions & 0 deletions src/lurk/eval_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ test!(test_app3, "((lambda (x) (lambda (y) x)) 1 2)", |_| {
// builtins
test!(test_if, "(if 1 1 0)", |_| uint(1));
test!(test_if2, "(if nil 1 0)", |_| uint(0));
test!(test_if3, "(if 1 1)", |_| uint(1));
test!(test_if4, "(if nil 1)", |z| z.intern_nil());
test!(test_let, "(let ((x 0) (y 1)) x)", |_| uint(0));
test!(test_let2, "(let ((x 0) (y 1)) y)", |_| uint(1));
test!(test_add, "(+ 1 2)", |_| uint(3));
Expand All @@ -215,6 +217,7 @@ test!(test_div, "(/ 6 3)", |_| uint(2));
test!(test_arith, "(+ (* 2 2) (* 2 3))", |_| uint(10));
test!(test_num_eq, "(= 0 1)", |z| z.intern_nil());
test!(test_num_eq2, "(= 1 1)", |z| z.intern_symbol(&lurk_sym("t")));
test!(test_begin_empty, "(begin)", |z| z.intern_nil());
test!(test_begin, "(begin 1 2 3)", |_| uint(3));
test!(test_quote, "'(x 1 :foo)", |z| {
let x = z.intern_symbol(&user_sym("x"));
Expand Down

0 comments on commit a772954

Please sign in to comment.