From d62f22099e00247626c4470810eaf8e1faef8264 Mon Sep 17 00:00:00 2001 From: Matt Keenan Date: Thu, 2 May 2024 09:12:37 -0400 Subject: [PATCH] Get tests working --- src/haz3lcore/dynamics/DHExp.re | 127 --------- src/haz3lcore/dynamics/Elaborator.re | 6 +- src/haz3lcore/dynamics/FilterMatcher.re | 2 +- src/haz3lcore/lang/term/TPat.re | 6 - src/haz3lcore/lang/term/Typ.re | 72 +---- src/haz3lcore/statics/TermBase.re | 319 +++++++++++++++++++++ test/Test_Elaboration.re | 359 +++++++++--------------- 7 files changed, 456 insertions(+), 435 deletions(-) diff --git a/src/haz3lcore/dynamics/DHExp.re b/src/haz3lcore/dynamics/DHExp.re index f2829ec89a..d39b9c9874 100644 --- a/src/haz3lcore/dynamics/DHExp.re +++ b/src/haz3lcore/dynamics/DHExp.re @@ -85,133 +85,6 @@ let rec strip_casts = _, ); -let rec fast_equal = - ({term: d1, _} as d1exp: t, {term: d2, _} as d2exp: t): bool => { - switch (d1, d2) { - /* Primitive forms: regular structural equality */ - | (Var(_), _) - /* TODO: Not sure if this is right... */ - | (Bool(_), _) - | (Int(_), _) - | (Float(_), _) - | (Deferral(_), _) - | (Constructor(_), _) => d1 == d2 - | (String(s1), String(s2)) => String.equal(s1, s2) - | (String(_), _) => false - - | (Parens(x), _) => fast_equal(x, d2exp) - | (_, Parens(x)) => fast_equal(d1exp, x) - - /* Non-hole forms: recurse */ - | (Test(d1), Test(d2)) => fast_equal(d1, d2) - | (Seq(d11, d21), Seq(d12, d22)) => - fast_equal(d11, d12) && fast_equal(d21, d22) - | (Filter(f1, d1), Filter(f2, d2)) => - filter_fast_equal(f1, f2) && fast_equal(d1, d2) - | (Let(dp1, d11, d21), Let(dp2, d12, d22)) => - dp1 == dp2 && fast_equal(d11, d12) && fast_equal(d21, d22) - | (FixF(f1, d1, sigma1), FixF(f2, d2, sigma2)) => - f1 == f2 - && fast_equal(d1, d2) - && Option.equal(ClosureEnvironment.id_equal, sigma1, sigma2) - | (Fun(dp1, d1, None, s1), Fun(dp2, d2, None, s2)) => - dp1 == dp2 && fast_equal(d1, d2) && s1 == s2 - | (Fun(dp1, d1, Some(env1), s1), Fun(dp2, d2, Some(env2), s2)) => - dp1 == dp2 - && fast_equal(d1, d2) - && ClosureEnvironment.id_equal(env1, env2) - && s1 == s2 - | (TypFun(_tpat1, d1, s1), TypFun(_tpat2, d2, s2)) => - _tpat1 == _tpat2 && fast_equal(d1, d2) && s1 == s2 - | (TypAp(d1, ty1), TypAp(d2, ty2)) => fast_equal(d1, d2) && ty1 == ty2 - | (Ap(dir1, d11, d21), Ap(dir2, d12, d22)) => - dir1 == dir2 && fast_equal(d11, d12) && fast_equal(d21, d22) - | (DeferredAp(d1, ds1), DeferredAp(d2, ds2)) => - fast_equal(d1, d2) - && List.length(ds1) == List.length(ds2) - && List.for_all2(fast_equal, ds1, ds2) - | (Cons(d11, d21), Cons(d12, d22)) => - fast_equal(d11, d12) && fast_equal(d21, d22) - | (ListConcat(d11, d21), ListConcat(d12, d22)) => - fast_equal(d11, d12) && fast_equal(d21, d22) - | (Tuple(ds1), Tuple(ds2)) => - List.length(ds1) == List.length(ds2) - && List.for_all2(fast_equal, ds1, ds2) - | (BuiltinFun(f1), BuiltinFun(f2)) => f1 == f2 - | (ListLit(ds1), ListLit(ds2)) => - List.length(ds1) == List.length(ds2) - && List.for_all2(fast_equal, ds1, ds2) - | (UnOp(op1, d1), UnOp(op2, d2)) => op1 == op2 && fast_equal(d1, d2) - | (BinOp(op1, d11, d21), BinOp(op2, d12, d22)) => - op1 == op2 && fast_equal(d11, d12) && fast_equal(d21, d22) - | (TyAlias(tp1, ut1, d1), TyAlias(tp2, ut2, d2)) => - tp1 == tp2 && ut1 == ut2 && fast_equal(d1, d2) - | (Cast(d1, ty11, ty21), Cast(d2, ty12, ty22)) - | (FailedCast(d1, ty11, ty21), FailedCast(d2, ty12, ty22)) => - fast_equal(d1, d2) && ty11 == ty12 && ty21 == ty22 - | (DynamicErrorHole(d1, reason1), DynamicErrorHole(d2, reason2)) => - fast_equal(d1, d2) && reason1 == reason2 - | (Match(s1, rs1), Match(s2, rs2)) => - fast_equal(s1, s2) - && List.length(rs2) == List.length(rs2) - && List.for_all2( - ((k1, v1), (k2, v2)) => k1 == k2 && fast_equal(v1, v2), - rs1, - rs2, - ) - | (If(d11, d12, d13), If(d21, d22, d23)) => - fast_equal(d11, d21) && fast_equal(d12, d22) && fast_equal(d13, d23) - /* We can group these all into a `_ => false` clause; separating - these so that we get exhaustiveness checking. */ - | (Seq(_), _) - | (Filter(_), _) - | (Let(_), _) - | (FixF(_), _) - | (Fun(_), _) - | (Test(_), _) - | (Ap(_), _) - | (BuiltinFun(_), _) - | (Cons(_), _) - | (ListConcat(_), _) - | (ListLit(_), _) - | (Tuple(_), _) - | (UnOp(_), _) - | (BinOp(_), _) - | (Cast(_), _) - | (FailedCast(_), _) - | (TyAlias(_), _) - | (TypFun(_), _) - | (TypAp(_), _) - | (DynamicErrorHole(_), _) - | (DeferredAp(_), _) - | (If(_), _) - | (Match(_), _) => false - - /* Hole forms: when checking environments, only check that - environment ID's are equal, don't check structural equality. - - (This resolves a performance issue with many nested holes.) */ - | (EmptyHole, EmptyHole) => true - | (MultiHole(_), MultiHole(_)) => rep_id(d1exp) == rep_id(d2exp) - | (Invalid(text1), Invalid(text2)) => text1 == text2 - | (Closure(sigma1, d1), Closure(sigma2, d2)) => - ClosureEnvironment.id_equal(sigma1, sigma2) && fast_equal(d1, d2) - | (EmptyHole, _) - | (MultiHole(_), _) - | (Invalid(_), _) - | (Closure(_), _) => false - }; -} -and filter_fast_equal = (f1, f2) => { - switch (f1, f2) { - | (Filter(f1), Filter(f2)) => - fast_equal(f1.pat, f2.pat) && f1.act == f2.act - | (Residue(idx1, act1), Residue(idx2, act2)) => - idx1 == idx2 && act1 == act2 - | _ => false - }; -}; - let assign_name_if_none = (t, name) => { let (term, rewrap) = unwrap(t); switch (term) { diff --git a/src/haz3lcore/dynamics/Elaborator.re b/src/haz3lcore/dynamics/Elaborator.re index 9827fc4acd..30d4446e0d 100644 --- a/src/haz3lcore/dynamics/Elaborator.re +++ b/src/haz3lcore/dynamics/Elaborator.re @@ -265,10 +265,9 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => { }; } ); - // TODO: is elaborated_type the right type to use here?? - if (!Statics.is_recursive(ctx, p, def, elaborated_type)) { + let (p, ty1) = elaborate_pattern(m, p); + if (!Statics.is_recursive(ctx, p, def, ty1)) { let def = add_name(Pat.get_var(p), def); - let (p, ty1) = elaborate_pattern(m, p); let (def, ty2) = elaborate(m, def); let (body, ty) = elaborate(m, body); Exp.Let(p, fresh_cast(def, ty2, ty1), body) @@ -278,7 +277,6 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => { // TODO: Add names to mutually recursive functions // TODO: Don't add fixpoint if there already is one let def = add_name(Option.map(s => s ++ "+", Pat.get_var(p)), def); - let (p, ty1) = elaborate_pattern(m, p); let (def, ty2) = elaborate(m, def); let (body, ty) = elaborate(m, body); let fixf = FixF(p, fresh_cast(def, ty2, ty1), None) |> DHExp.fresh; diff --git a/src/haz3lcore/dynamics/FilterMatcher.re b/src/haz3lcore/dynamics/FilterMatcher.re index 1787b271a7..22d2141242 100644 --- a/src/haz3lcore/dynamics/FilterMatcher.re +++ b/src/haz3lcore/dynamics/FilterMatcher.re @@ -41,7 +41,7 @@ let rec matches_exp = | (Deferral(_), _) => false | (Filter(df, dd), Filter(ff, fd)) => - DHExp.filter_fast_equal(df, ff) && matches_exp(env, dd, fd) + TermBase.StepperFilterKind.fast_equal(df, ff) && matches_exp(env, dd, fd) | (Filter(_), _) => false | (Bool(dv), Bool(fv)) => dv == fv diff --git a/src/haz3lcore/lang/term/TPat.re b/src/haz3lcore/lang/term/TPat.re index c3e0671b67..3dade36b54 100644 --- a/src/haz3lcore/lang/term/TPat.re +++ b/src/haz3lcore/lang/term/TPat.re @@ -29,9 +29,3 @@ let show_cls: cls => string = | MultiHole => "Broken type alias" | EmptyHole => "Empty type alias hole" | Var => "Type alias"; - -let tyvar_of_utpat = ({term, _}: t) => - switch (term) { - | Var(x) => Some(x) - | _ => None - }; diff --git a/src/haz3lcore/lang/term/Typ.re b/src/haz3lcore/lang/term/Typ.re index 74940f6db1..2f58e9422b 100644 --- a/src/haz3lcore/lang/term/Typ.re +++ b/src/haz3lcore/lang/term/Typ.re @@ -244,37 +244,6 @@ let fresh_var = (var_name: string) => { var_name ++ "_α" ++ string_of_int(x); }; -let rec subst = (s: t, x: TPat.t, ty: t) => { - switch (TPat.tyvar_of_utpat(x)) { - | Some(str) => - let (term, rewrap) = unwrap(ty); - switch (term) { - | Int => Int |> rewrap - | Float => Float |> rewrap - | Bool => Bool |> rewrap - | String => String |> rewrap - | Unknown(prov) => Unknown(prov) |> rewrap - | Arrow(ty1, ty2) => - Arrow(subst(s, x, ty1), subst(s, x, ty2)) |> rewrap - | Prod(tys) => Prod(List.map(subst(s, x), tys)) |> rewrap - | Sum(sm) => - Sum(ConstructorMap.map(Option.map(subst(s, x)), sm)) |> rewrap - | Forall(tp2, ty) - when TPat.tyvar_of_utpat(x) == TPat.tyvar_of_utpat(tp2) => - Forall(tp2, ty) |> rewrap - | Forall(tp2, ty) => Forall(tp2, subst(s, x, ty)) |> rewrap - | Rec(tp2, ty) when TPat.tyvar_of_utpat(x) == TPat.tyvar_of_utpat(tp2) => - Rec(tp2, ty) |> rewrap - | Rec(tp2, ty) => Rec(tp2, subst(s, x, ty)) |> rewrap - | List(ty) => List(subst(s, x, ty)) |> rewrap - | Var(y) => str == y ? s : Var(y) |> rewrap - | Parens(ty) => Parens(subst(s, x, ty)) |> rewrap - | Ap(t1, t2) => Ap(subst(s, x, t1), subst(s, x, t2)) |> rewrap - }; - | None => ty - }; -}; - let unroll = (ty: t): t => switch (term_of(ty)) { | Rec(tp, ty_body) => subst(ty, tp, ty_body) @@ -283,46 +252,7 @@ let unroll = (ty: t): t => /* Type Equality: This coincides with alpha equivalence for normalized types. Other types may be equivalent but this will not detect so if they are not normalized. */ -let rec eq_internal = (n: int, t1: t, t2: t) => { - switch (term_of(t1), term_of(t2)) { - | (Parens(t1), _) => eq_internal(n, t1, t2) - | (_, Parens(t2)) => eq_internal(n, t1, t2) - | (Rec(x1, t1), Rec(x2, t2)) - | (Forall(x1, t1), Forall(x2, t2)) => - let alpha_subst = subst(Var("=" ++ string_of_int(n)) |> mk_fast); - eq_internal(n + 1, alpha_subst(x1, t1), alpha_subst(x2, t2)); - | (Rec(_), _) => false - | (Forall(_), _) => false - | (Int, Int) => true - | (Int, _) => false - | (Float, Float) => true - | (Float, _) => false - | (Bool, Bool) => true - | (Bool, _) => false - | (String, String) => true - | (String, _) => false - | (Ap(t1, t2), Ap(t1', t2')) => - eq_internal(n, t1, t1') && eq_internal(n, t2, t2') - | (Ap(_), _) => false - | (Unknown(_), Unknown(_)) => true - | (Unknown(_), _) => false - | (Arrow(t1, t2), Arrow(t1', t2')) => - eq_internal(n, t1, t1') && eq_internal(n, t2, t2') - | (Arrow(_), _) => false - | (Prod(tys1), Prod(tys2)) => List.equal(eq_internal(n), tys1, tys2) - | (Prod(_), _) => false - | (List(t1), List(t2)) => eq_internal(n, t1, t2) - | (List(_), _) => false - | (Sum(sm1), Sum(sm2)) => - /* Does not normalize the types. */ - ConstructorMap.equal(eq_internal(n), sm1, sm2) - | (Sum(_), _) => false - | (Var(n1), Var(n2)) => n1 == n2 - | (Var(_), _) => false - }; -}; - -let eq = (t1: t, t2: t): bool => eq_internal(0, t1, t2); +let eq = (t1: t, t2: t): bool => fast_equal(t1, t2); /* Lattice join on types. This is a LUB join in the hazel2 sense in that any type dominates Unknown. The optional diff --git a/src/haz3lcore/statics/TermBase.re b/src/haz3lcore/statics/TermBase.re index 1d9c762f75..4fb9476c23 100644 --- a/src/haz3lcore/statics/TermBase.re +++ b/src/haz3lcore/statics/TermBase.re @@ -27,6 +27,8 @@ module rec Any: { t ) => t; + + let fast_equal: (t, t) => bool; } = { [@deriving (show({with_path: false}), sexp, yojson)] type t = @@ -67,6 +69,24 @@ module rec Any: { }; x |> f_any(rec_call); }; + + let fast_equal = (x, y) => + switch (x, y) { + | (Exp(x), Exp(y)) => Exp.fast_equal(x, y) + | (Pat(x), Pat(y)) => Pat.fast_equal(x, y) + | (Typ(x), Typ(y)) => Typ.fast_equal(x, y) + | (TPat(x), TPat(y)) => TPat.fast_equal(x, y) + | (Rul(x), Rul(y)) => Rul.fast_equal(x, y) + | (Nul (), Nul ()) => true + | (Any (), Any ()) => true + | (Exp(_), _) + | (Pat(_), _) + | (Typ(_), _) + | (TPat(_), _) + | (Rul(_), _) + | (Nul (), _) + | (Any (), _) => false + }; } and Exp: { [@deriving (show({with_path: false}), sexp, yojson)] @@ -132,6 +152,8 @@ and Exp: { t ) => t; + + let fast_equal: (t, t) => bool; } = { [@deriving (show({with_path: false}), sexp, yojson)] type deferral_position = @@ -268,6 +290,118 @@ and Exp: { }; x |> f_exp(rec_call); }; + + let rec fast_equal = (e1, e2) => + switch (e1 |> IdTagged.term_of, e2 |> IdTagged.term_of) { + | (DynamicErrorHole(x, _), _) + | (Parens(x), _) => fast_equal(x, e2) + | (_, DynamicErrorHole(x, _)) + | (_, Parens(x)) => fast_equal(e1, x) + | (EmptyHole, EmptyHole) => true + | (Invalid(s1), Invalid(s2)) => s1 == s2 + | (MultiHole(xs), MultiHole(ys)) when List.length(xs) == List.length(ys) => + List.equal(Any.fast_equal, xs, ys) + | (FailedCast(e1, t1, t2), FailedCast(e2, t3, t4)) => + Exp.fast_equal(e1, e2) + && Typ.fast_equal(t1, t3) + && Typ.fast_equal(t2, t4) + | (Deferral(d1), Deferral(d2)) => d1 == d2 + | (Bool(b1), Bool(b2)) => b1 == b2 + | (Int(i1), Int(i2)) => i1 == i2 + | (Float(f1), Float(f2)) => f1 == f2 + | (String(s1), String(s2)) => s1 == s2 + | (ListLit(xs), ListLit(ys)) => + List.length(xs) == List.length(ys) && List.equal(fast_equal, xs, ys) + | (Constructor(c1), Constructor(c2)) => c1 == c2 + | (Fun(p1, e1, env1, _), Fun(p2, e2, env2, _)) => + Pat.fast_equal(p1, p2) + && fast_equal(e1, e2) + && Option.equal(ClosureEnvironment.id_equal, env1, env2) + | (TypFun(tp1, e1, _), TypFun(tp2, e2, _)) => + TPat.fast_equal(tp1, tp2) && fast_equal(e1, e2) + | (Tuple(xs), Tuple(ys)) => + List.length(xs) == List.length(ys) && List.equal(fast_equal, xs, ys) + | (Var(v1), Var(v2)) => v1 == v2 + | (Let(p1, e1, e2), Let(p2, e3, e4)) => + Pat.fast_equal(p1, p2) && fast_equal(e1, e3) && fast_equal(e2, e4) + | (FixF(p1, e1, c1), FixF(p2, e2, c2)) => + Pat.fast_equal(p1, p2) + && fast_equal(e1, e2) + && Option.equal(ClosureEnvironment.id_equal, c1, c2) + | (TyAlias(tp1, t1, e1), TyAlias(tp2, t2, e2)) => + TPat.fast_equal(tp1, tp2) + && Typ.fast_equal(t1, t2) + && fast_equal(e1, e2) + | (Ap(d1, e1, e2), Ap(d2, e3, e4)) => + d1 == d2 && fast_equal(e1, e3) && fast_equal(e2, e4) + | (TypAp(e1, t1), TypAp(e2, t2)) => + fast_equal(e1, e2) && Typ.fast_equal(t1, t2) + | (DeferredAp(e1, es1), DeferredAp(e2, es2)) => + List.length(es1) == List.length(es2) + && fast_equal(e1, e2) + && List.equal(fast_equal, es1, es2) + | (If(e1, e2, e3), If(e4, e5, e6)) => + fast_equal(e1, e4) && fast_equal(e2, e5) && fast_equal(e3, e6) + | (Seq(e1, e2), Seq(e3, e4)) => + fast_equal(e1, e3) && fast_equal(e2, e4) + | (Test(e1), Test(e2)) => fast_equal(e1, e2) + | (Filter(f1, e1), Filter(f2, e2)) => + StepperFilterKind.fast_equal(f1, f2) && fast_equal(e1, e2) + | (Closure(c1, e1), Closure(c2, e2)) => + ClosureEnvironment.id_equal(c1, c2) && fast_equal(e1, e2) + | (Cons(e1, e2), Cons(e3, e4)) => + fast_equal(e1, e3) && fast_equal(e2, e4) + | (ListConcat(e1, e2), ListConcat(e3, e4)) => + fast_equal(e1, e3) && fast_equal(e2, e4) + | (UnOp(o1, e1), UnOp(o2, e2)) => o1 == o2 && fast_equal(e1, e2) + | (BinOp(o1, e1, e2), BinOp(o2, e3, e4)) => + o1 == o2 && fast_equal(e1, e3) && fast_equal(e2, e4) + | (BuiltinFun(f1), BuiltinFun(f2)) => f1 == f2 + | (Match(e1, rls1), Match(e2, rls2)) => + fast_equal(e1, e2) + && List.length(rls1) == List.length(rls2) + && List.for_all2( + ((p1, e1), (p2, e2)) => + Pat.fast_equal(p1, p2) && fast_equal(e1, e2), + rls1, + rls2, + ) + | (Cast(e1, t1, t2), Cast(e2, t3, t4)) => + fast_equal(e1, e2) && Typ.fast_equal(t1, t3) && Typ.fast_equal(t2, t4) + | (Invalid(_), _) + | (FailedCast(_), _) + | (Deferral(_), _) + | (Bool(_), _) + | (Int(_), _) + | (Float(_), _) + | (String(_), _) + | (ListLit(_), _) + | (Constructor(_), _) + | (Fun(_), _) + | (TypFun(_), _) + | (Tuple(_), _) + | (Var(_), _) + | (Let(_), _) + | (FixF(_), _) + | (TyAlias(_), _) + | (Ap(_), _) + | (TypAp(_), _) + | (DeferredAp(_), _) + | (If(_), _) + | (Seq(_), _) + | (Test(_), _) + | (Filter(_), _) + | (Closure(_), _) + | (Cons(_), _) + | (ListConcat(_), _) + | (UnOp(_), _) + | (BinOp(_), _) + | (BuiltinFun(_), _) + | (Match(_), _) + | (Cast(_), _) + | (MultiHole(_), _) + | (EmptyHole, _) => false + }; } and Pat: { [@deriving (show({with_path: false}), sexp, yojson)] @@ -301,6 +435,8 @@ and Pat: { t ) => t; + + let fast_equal: (t, t) => bool; } = { [@deriving (show({with_path: false}), sexp, yojson)] type term = @@ -363,6 +499,48 @@ and Pat: { }; x |> f_pat(rec_call); }; + + let rec fast_equal = (p1, p2) => + switch (p1 |> IdTagged.term_of, p2 |> IdTagged.term_of) { + | (Parens(x), _) => fast_equal(x, p2) + | (_, Parens(x)) => fast_equal(p1, x) + | (EmptyHole, EmptyHole) => true + | (MultiHole(xs), MultiHole(ys)) => + List.length(xs) == List.length(ys) + && List.equal(Any.fast_equal, xs, ys) + | (Invalid(s1), Invalid(s2)) => s1 == s2 + | (Wild, Wild) => true + | (Bool(b1), Bool(b2)) => b1 == b2 + | (Int(i1), Int(i2)) => i1 == i2 + | (Float(f1), Float(f2)) => f1 == f2 + | (String(s1), String(s2)) => s1 == s2 + | (Constructor(c1), Constructor(c2)) => c1 == c2 + | (Var(v1), Var(v2)) => v1 == v2 + | (ListLit(xs), ListLit(ys)) => + List.length(xs) == List.length(ys) && List.equal(fast_equal, xs, ys) + | (Cons(x1, y1), Cons(x2, y2)) => + fast_equal(x1, x2) && fast_equal(y1, y2) + | (Tuple(xs), Tuple(ys)) => + List.length(xs) == List.length(ys) && List.equal(fast_equal, xs, ys) + | (Ap(x1, y1), Ap(x2, y2)) => fast_equal(x1, x2) && fast_equal(y1, y2) + | (Cast(x1, t1, t2), Cast(x2, u1, u2)) => + fast_equal(x1, x2) && Typ.fast_equal(t1, u1) && Typ.fast_equal(t2, u2) + | (EmptyHole, _) + | (MultiHole(_), _) + | (Invalid(_), _) + | (Wild, _) + | (Bool(_), _) + | (Int(_), _) + | (Float(_), _) + | (String(_), _) + | (ListLit(_), _) + | (Constructor(_), _) + | (Cons(_), _) + | (Var(_), _) + | (Tuple(_), _) + | (Ap(_), _) + | (Cast(_), _) => false + }; } and Typ: { [@deriving (show({with_path: false}), sexp, yojson)] @@ -412,6 +590,10 @@ and Typ: { t ) => t; + + let subst: (t, TPat.t, t) => t; + + let fast_equal: (t, t) => bool; } = { [@deriving (show({with_path: false}), sexp, yojson)] type type_hole = @@ -502,6 +684,86 @@ and Typ: { }; x |> f_typ(rec_call); }; + + let rec subst = (s: t, x: TPat.t, ty: t) => { + switch (TPat.tyvar_of_utpat(x)) { + | Some(str) => + let (term, rewrap) = IdTagged.unwrap(ty); + switch (term) { + | Int => Int |> rewrap + | Float => Float |> rewrap + | Bool => Bool |> rewrap + | String => String |> rewrap + | Unknown(prov) => Unknown(prov) |> rewrap + | Arrow(ty1, ty2) => + Arrow(subst(s, x, ty1), subst(s, x, ty2)) |> rewrap + | Prod(tys) => Prod(List.map(subst(s, x), tys)) |> rewrap + | Sum(sm) => + Sum(ConstructorMap.map(Option.map(subst(s, x)), sm)) |> rewrap + | Forall(tp2, ty) + when TPat.tyvar_of_utpat(x) == TPat.tyvar_of_utpat(tp2) => + Forall(tp2, ty) |> rewrap + | Forall(tp2, ty) => Forall(tp2, subst(s, x, ty)) |> rewrap + | Rec(tp2, ty) when TPat.tyvar_of_utpat(x) == TPat.tyvar_of_utpat(tp2) => + Rec(tp2, ty) |> rewrap + | Rec(tp2, ty) => Rec(tp2, subst(s, x, ty)) |> rewrap + | List(ty) => List(subst(s, x, ty)) |> rewrap + | Var(y) => str == y ? s : Var(y) |> rewrap + | Parens(ty) => Parens(subst(s, x, ty)) |> rewrap + | Ap(t1, t2) => Ap(subst(s, x, t1), subst(s, x, t2)) |> rewrap + }; + | None => ty + }; + }; + + /* Type Equality: This coincides with alpha equivalence for normalized types. + Other types may be equivalent but this will not detect so if they are not normalized. */ + + let rec eq_internal = (n: int, t1: t, t2: t) => { + switch (IdTagged.term_of(t1), IdTagged.term_of(t2)) { + | (Parens(t1), _) => eq_internal(n, t1, t2) + | (_, Parens(t2)) => eq_internal(n, t1, t2) + | (Rec(x1, t1), Rec(x2, t2)) + | (Forall(x1, t1), Forall(x2, t2)) => + let alpha_subst = + subst({ + term: Var("=" ++ string_of_int(n)), + copied: false, + ids: [Id.invalid], + }); + eq_internal(n + 1, alpha_subst(x1, t1), alpha_subst(x2, t2)); + | (Rec(_), _) => false + | (Forall(_), _) => false + | (Int, Int) => true + | (Int, _) => false + | (Float, Float) => true + | (Float, _) => false + | (Bool, Bool) => true + | (Bool, _) => false + | (String, String) => true + | (String, _) => false + | (Ap(t1, t2), Ap(t1', t2')) => + eq_internal(n, t1, t1') && eq_internal(n, t2, t2') + | (Ap(_), _) => false + | (Unknown(_), Unknown(_)) => true + | (Unknown(_), _) => false + | (Arrow(t1, t2), Arrow(t1', t2')) => + eq_internal(n, t1, t1') && eq_internal(n, t2, t2') + | (Arrow(_), _) => false + | (Prod(tys1), Prod(tys2)) => List.equal(eq_internal(n), tys1, tys2) + | (Prod(_), _) => false + | (List(t1), List(t2)) => eq_internal(n, t1, t2) + | (List(_), _) => false + | (Sum(sm1), Sum(sm2)) => + /* Does not normalize the types. */ + ConstructorMap.equal(eq_internal(n), sm1, sm2) + | (Sum(_), _) => false + | (Var(n1), Var(n2)) => n1 == n2 + | (Var(_), _) => false + }; + }; + + let fast_equal = eq_internal(0); } and TPat: { [@deriving (show({with_path: false}), sexp, yojson)] @@ -523,6 +785,10 @@ and TPat: { t ) => t; + + let tyvar_of_utpat: t => option(string); + + let fast_equal: (t, t) => bool; } = { [@deriving (show({with_path: false}), sexp, yojson)] type term = @@ -556,6 +822,26 @@ and TPat: { }; x |> f_tpat(rec_call); }; + + let tyvar_of_utpat = ({term, _}: t) => + switch (term) { + | Var(x) => Some(x) + | _ => None + }; + + let fast_equal = (tp1: t, tp2: t) => + switch (tp1 |> IdTagged.term_of, tp2 |> IdTagged.term_of) { + | (EmptyHole, EmptyHole) => true + | (Invalid(s1), Invalid(s2)) => s1 == s2 + | (MultiHole(xs), MultiHole(ys)) => + List.length(xs) == List.length(ys) + && List.equal(Any.fast_equal, xs, ys) + | (Var(x), Var(y)) => x == y + | (EmptyHole, _) + | (Invalid(_), _) + | (MultiHole(_), _) + | (Var(_), _) => false + }; } and Rul: { [@deriving (show({with_path: false}), sexp, yojson)] @@ -576,6 +862,8 @@ and Rul: { t ) => t; + + let fast_equal: (t, t) => bool; } = { [@deriving (show({with_path: false}), sexp, yojson)] type term = @@ -618,6 +906,26 @@ and Rul: { }; x |> f_rul(rec_call); }; + + let fast_equal = (r1: t, r2: t) => + switch (r1 |> IdTagged.term_of, r2 |> IdTagged.term_of) { + | (Invalid(s1), Invalid(s2)) => s1 == s2 + | (Hole(xs), Hole(ys)) => + List.length(xs) == List.length(ys) + && List.equal(Any.fast_equal, xs, ys) + | (Rules(e1, rls1), Rules(e2, rls2)) => + Exp.fast_equal(e1, e2) + && List.length(rls1) == List.length(rls2) + && List.for_all2( + ((p1, e1), (p2, e2)) => + Pat.fast_equal(p1, p2) && Exp.fast_equal(e1, e2), + rls1, + rls2, + ) + | (Invalid(_), _) + | (Hole(_), _) + | (Rules(_), _) => false + }; } and Environment: { @@ -771,6 +1079,8 @@ and StepperFilterKind: { t; let map: (Exp.t => Exp.t, t) => t; + + let fast_equal: (t, t) => bool; } = { [@deriving (show({with_path: false}), sexp, yojson)] type filter = { @@ -805,4 +1115,13 @@ and StepperFilterKind: { | Filter({pat: e, act}) => Filter({pat: exp_map_term(e), act}) | Residue(i, a) => Residue(i, a); }; + + let fast_equal = (f1, f2) => + switch (f1, f2) { + | (Filter({pat: e1, act: a1}), Filter({pat: e2, act: a2})) => + Exp.fast_equal(e1, e2) && a1 == a2 + | (Residue(i1, a1), Residue(i2, a2)) => i1 == i2 && a1 == a2 + | (Filter(_), _) + | (Residue(_), _) => false + }; }; diff --git a/test/Test_Elaboration.re b/test/Test_Elaboration.re index 98d1db0fdd..a2db9c0de4 100644 --- a/test/Test_Elaboration.re +++ b/test/Test_Elaboration.re @@ -1,279 +1,186 @@ open Alcotest; open Haz3lcore; -let dhexp_eq = (d1: option(DHExp.t), d2: option(DHExp.t)): bool => - switch (d1, d2) { - | (Some(d1), Some(d2)) => DHExp.fast_equal(d1, d2) - | _ => false - }; - -let dhexp_print = (d: option(DHExp.t)): string => - switch (d) { - | None => "None" - | Some(d) => DHExp.show(d) - }; - /*Create a testable type for dhexp which requires an equal function (dhexp_eq) and a print function (dhexp_print) */ -let dhexp_typ = testable(Fmt.using(dhexp_print, Fmt.string), dhexp_eq); +let dhexp_typ = testable(Fmt.using(Exp.show, Fmt.string), DHExp.fast_equal); let ids = List.init(12, _ => Id.mk()); let id_at = x => x |> List.nth(ids); let mk_map = CoreSettings.on |> Interface.Statics.mk_map; -let dhexp_of_uexp = u => Elaborator.dhexp_of_uexp(mk_map(u), u, false); +let dhexp_of_uexp = u => Elaborator.elaborate(mk_map(u), u) |> fst; let alco_check = dhexp_typ |> Alcotest.check; -let u1: Term.UExp.t = {ids: [id_at(0)], term: Int(8)}; +let u1: Exp.t = {ids: [id_at(0)], term: Int(8), copied: false}; let single_integer = () => - alco_check("Integer literal 8", Some(IntLit(8)), dhexp_of_uexp(u1)); + alco_check("Integer literal 8", u1, dhexp_of_uexp(u1)); -let u2: Term.UExp.t = {ids: [id_at(0)], term: EmptyHole}; -let empty_hole = () => - alco_check( - "Empty hole", - Some(EmptyHole(id_at(0), 0)), - dhexp_of_uexp(u2), - ); +let u2: Exp.t = {ids: [id_at(0)], term: EmptyHole, copied: false}; +let empty_hole = () => alco_check("Empty hole", u2, dhexp_of_uexp(u2)); -let u3: Term.UExp.t = { +let u3: Exp.t = { ids: [id_at(0)], - term: Parens({ids: [id_at(1)], term: Var("y")}), + term: Parens({ids: [id_at(1)], term: Var("y"), copied: false}), + copied: false, }; -let d3: DHExp.t = - NonEmptyHole(TypeInconsistent, id_at(1), 0, FreeVar(id_at(1), 0, "y")); -let free_var = () => - alco_check( - "Nonempty hole with free variable", - Some(d3), - dhexp_of_uexp(u3), - ); -let u4: Term.UExp.t = { - ids: [id_at(0)], - term: - Let( - { - ids: [id_at(1)], - term: - Tuple([ - {ids: [id_at(2)], term: Var("a")}, - {ids: [id_at(3)], term: Var("b")}, - ]), - }, - { - ids: [id_at(4)], - term: - Tuple([ - {ids: [id_at(5)], term: Int(4)}, - {ids: [id_at(6)], term: Int(6)}, - ]), - }, - { - ids: [id_at(7)], - term: - BinOp( - Int(Minus), - {ids: [id_at(8)], term: Var("a")}, - {ids: [id_at(9)], term: Var("b")}, - ), - }, - ), -}; -let d4: DHExp.t = +let free_var = () => alco_check("free variable", u3, dhexp_of_uexp(u3)); + +let u4: Exp.t = Let( - Tuple([Var("a"), Var("b")]), - Tuple([IntLit(4), IntLit(6)]), - BinIntOp(Minus, BoundVar("a"), BoundVar("b")), - ); + Tuple([Var("a") |> Pat.fresh, Var("b") |> Pat.fresh]) |> Pat.fresh, + Tuple([Int(4) |> Exp.fresh, Int(6) |> Exp.fresh]) |> Exp.fresh, + BinOp(Int(Minus), Var("a") |> Exp.fresh, Var("b") |> Exp.fresh) + |> Exp.fresh, + ) + |> Exp.fresh; + let let_exp = () => - alco_check( - "Let expression for tuple (a, b)", - Some(d4), - dhexp_of_uexp(u4), - ); + alco_check("Let expression for tuple (a, b)", u4, dhexp_of_uexp(u4)); + +let u5 = + BinOp(Int(Plus), Bool(false) |> Exp.fresh, Var("y") |> Exp.fresh) + |> Exp.fresh; + +let d5 = + BinOp( + Int(Plus), + FailedCast(Bool(false) |> Exp.fresh, Bool |> Typ.fresh, Int |> Typ.fresh) + |> Exp.fresh, + Cast( + Var("y") |> Exp.fresh, + Unknown(Internal) |> Typ.fresh, + Int |> Typ.fresh, + ) + |> Exp.fresh, + ) + |> Exp.fresh; -let u5: Term.UExp.t = { - ids: [id_at(0)], - term: - BinOp( - Int(Plus), - {ids: [id_at(1)], term: Bool(false)}, - {ids: [id_at(2)], term: Var("y")}, - ), -}; -let d5: DHExp.t = - BinIntOp( - Plus, - NonEmptyHole(TypeInconsistent, id_at(1), 0, BoolLit(false)), - NonEmptyHole(TypeInconsistent, id_at(2), 0, FreeVar(id_at(2), 0, "y")), - ); let bin_op = () => alco_check( "Inconsistent binary integer operation (plus)", - Some(d5), + d5, dhexp_of_uexp(u5), ); -let u6: Term.UExp.t = { - ids: [id_at(0)], - term: - If( - {ids: [id_at(1)], term: Bool(false)}, - {ids: [id_at(2)], term: Int(8)}, - {ids: [id_at(3)], term: Int(6)}, - ), -}; -let d6: DHExp.t = - IfThenElse(DH.ConsistentIf, BoolLit(false), IntLit(8), IntLit(6)); +let u6: Exp.t = + If(Bool(false) |> Exp.fresh, Int(8) |> Exp.fresh, Int(6) |> Exp.fresh) + |> Exp.fresh; + let consistent_if = () => alco_check( "Consistent case with rules (BoolLit(true), IntLit(8)) and (BoolLit(false), IntLit(6))", - Some(d6), + u6, dhexp_of_uexp(u6), ); -let u7: Term.UExp.t = { - ids: [id_at(0)], - term: - Ap( - { - ids: [id_at(1)], - term: - Fun( - {ids: [id_at(2)], term: Var("x")}, - { - ids: [id_at(3)], - term: - BinOp( - Int(Plus), - {ids: [id_at(4)], term: Int(4)}, - {ids: [id_at(5)], term: Var("x")}, - ), - }, - ), - }, - {ids: [id_at(6)], term: Var("y")}, - ), -}; -let d7: DHExp.t = +let u7: Exp.t = Ap( + Forward, Fun( - Var("x"), - Unknown(Internal), - BinIntOp( - Plus, - IntLit(4), - Cast(BoundVar("x"), Unknown(Internal), Int), - ), + Var("x") |> Pat.fresh, + BinOp(Int(Plus), Int(4) |> Exp.fresh, Int(5) |> Exp.fresh) + |> Exp.fresh, None, - ), - NonEmptyHole(TypeInconsistent, id_at(6), 0, FreeVar(id_at(6), 0, "y")), - ); + None, + ) + |> Exp.fresh, + Var("y") |> Exp.fresh, + ) + |> Exp.fresh; + let ap_fun = () => - alco_check( - "Application of a function of a free variable wrapped inside a nonempty hole constructor", - Some(d7), - dhexp_of_uexp(u7), - ); + alco_check("Application of a function", u7, dhexp_of_uexp(u7)); + +let u8: Exp.t = + Match( + BinOp(Int(Equals), Int(4) |> Exp.fresh, Int(3) |> Exp.fresh) + |> Exp.fresh, + [ + (Bool(true) |> Pat.fresh, Int(24) |> Exp.fresh), + (Bool(false) |> Pat.fresh, Bool(false) |> Exp.fresh), + ], + ) + |> Exp.fresh; + +let d8: Exp.t = + Match( + BinOp(Int(Equals), Int(4) |> Exp.fresh, Int(3) |> Exp.fresh) + |> Exp.fresh, + [ + ( + Bool(true) |> Pat.fresh, + Cast( + Int(24) |> Exp.fresh, + Int |> Typ.fresh, + Unknown(Internal) |> Typ.fresh, + ) + |> Exp.fresh, + ), + ( + Bool(false) |> Pat.fresh, + Cast( + Bool(false) |> Exp.fresh, + Bool |> Typ.fresh, + Unknown(Internal) |> Typ.fresh, + ) + |> Exp.fresh, + ), + ], + ) + |> Exp.fresh; -let u8: Term.UExp.t = { - ids: [id_at(0)], - term: - Match( - { - ids: [id_at(1)], - term: - BinOp( - Int(Equals), - {ids: [id_at(2)], term: Int(4)}, - {ids: [id_at(3)], term: Int(3)}, - ), - }, - [ - ( - {ids: [id_at(6)], term: Bool(true)}, - {ids: [id_at(4)], term: Int(24)}, - ), - ( - {ids: [id_at(7)], term: Bool(false)}, - {ids: [id_at(5)], term: Bool(false)}, - ), - ], - ), -}; -let d8scrut: DHExp.t = BinIntOp(Equals, IntLit(4), IntLit(3)); -let d8rules = - DHExp.[ - Rule(BoolLit(true), IntLit(24)), - Rule(BoolLit(false), BoolLit(false)), - ]; -let d8a: DHExp.t = - InconsistentBranches(id_at(0), 0, Case(d8scrut, d8rules, 0)); -let d8: DHExp.t = NonEmptyHole(TypeInconsistent, id_at(0), 0, d8a); let inconsistent_case = () => alco_check( "Inconsistent branches where the first branch is an integer and second branch is a boolean", - Some(d8), + d8, dhexp_of_uexp(u8), ); -let u9: Term.UExp.t = { - ids: [id_at(0)], - term: - Let( - { - ids: [id_at(1)], - term: - TypeAnn( - {ids: [id_at(2)], term: Var("f")}, - { - ids: [id_at(3)], - term: - Arrow( - {ids: [id_at(4)], term: Int}, - {ids: [id_at(5)], term: Int}, - ), - }, - ), - }, - { - ids: [id_at(6)], - term: - Fun( - {ids: [id_at(7)], term: Var("x")}, - { - ids: [id_at(8)], - term: - BinOp( - Int(Plus), - {ids: [id_at(9)], term: Int(1)}, - {ids: [id_at(10)], term: Var("x")}, - ), - }, - ), - }, - {ids: [id_at(11)], term: Int(55)}, - ), -}; -let d9: DHExp.t = +let u9: Exp.t = Let( - Var("f"), + Cast( + Var("f") |> Pat.fresh, + Arrow(Int |> Typ.fresh, Int |> Typ.fresh) |> Typ.fresh, + Unknown(Internal) |> Typ.fresh, + ) + |> Pat.fresh, + Fun( + Var("x") |> Pat.fresh, + BinOp(Int(Plus), Int(1) |> Exp.fresh, Var("x") |> Exp.fresh) + |> Exp.fresh, + None, + None, + ) + |> Exp.fresh, + Int(55) |> Exp.fresh, + ) + |> Exp.fresh; + +let d9: Exp.t = + Let( + Var("f") |> Pat.fresh, FixF( - "f", - Arrow(Int, Int), + Var("f") |> Pat.fresh, Fun( - Var("x"), - Int, - BinIntOp(Plus, IntLit(1), BoundVar("x")), + Var("x") |> Pat.fresh, + BinOp(Int(Plus), Int(1) |> Exp.fresh, Var("x") |> Exp.fresh) + |> Exp.fresh, + None, Some("f+"), - ), - ), - IntLit(55), - ); + ) + |> Exp.fresh, + None, + ) + |> Exp.fresh, + Int(55) |> Exp.fresh, + ) + |> Exp.fresh; + let let_fun = () => alco_check( "Let expression for function which wraps a fix point constructor around the function", - Some(d9), + d9, dhexp_of_uexp(u9), );