Skip to content

Commit

Permalink
Get tests working
Browse files Browse the repository at this point in the history
  • Loading branch information
Negabinary committed May 2, 2024
1 parent 4fe9d9c commit d62f220
Show file tree
Hide file tree
Showing 7 changed files with 456 additions and 435 deletions.
127 changes: 0 additions & 127 deletions src/haz3lcore/dynamics/DHExp.re
Original file line number Diff line number Diff line change
Expand Up @@ -85,133 +85,6 @@ let rec strip_casts =
_,
);

let rec fast_equal =
({term: d1, _} as d1exp: t, {term: d2, _} as d2exp: t): bool => {
switch (d1, d2) {
/* Primitive forms: regular structural equality */
| (Var(_), _)
/* TODO: Not sure if this is right... */
| (Bool(_), _)
| (Int(_), _)
| (Float(_), _)
| (Deferral(_), _)
| (Constructor(_), _) => d1 == d2
| (String(s1), String(s2)) => String.equal(s1, s2)
| (String(_), _) => false

| (Parens(x), _) => fast_equal(x, d2exp)
| (_, Parens(x)) => fast_equal(d1exp, x)

/* Non-hole forms: recurse */
| (Test(d1), Test(d2)) => fast_equal(d1, d2)
| (Seq(d11, d21), Seq(d12, d22)) =>
fast_equal(d11, d12) && fast_equal(d21, d22)
| (Filter(f1, d1), Filter(f2, d2)) =>
filter_fast_equal(f1, f2) && fast_equal(d1, d2)
| (Let(dp1, d11, d21), Let(dp2, d12, d22)) =>
dp1 == dp2 && fast_equal(d11, d12) && fast_equal(d21, d22)
| (FixF(f1, d1, sigma1), FixF(f2, d2, sigma2)) =>
f1 == f2
&& fast_equal(d1, d2)
&& Option.equal(ClosureEnvironment.id_equal, sigma1, sigma2)
| (Fun(dp1, d1, None, s1), Fun(dp2, d2, None, s2)) =>
dp1 == dp2 && fast_equal(d1, d2) && s1 == s2
| (Fun(dp1, d1, Some(env1), s1), Fun(dp2, d2, Some(env2), s2)) =>
dp1 == dp2
&& fast_equal(d1, d2)
&& ClosureEnvironment.id_equal(env1, env2)
&& s1 == s2
| (TypFun(_tpat1, d1, s1), TypFun(_tpat2, d2, s2)) =>
_tpat1 == _tpat2 && fast_equal(d1, d2) && s1 == s2
| (TypAp(d1, ty1), TypAp(d2, ty2)) => fast_equal(d1, d2) && ty1 == ty2
| (Ap(dir1, d11, d21), Ap(dir2, d12, d22)) =>
dir1 == dir2 && fast_equal(d11, d12) && fast_equal(d21, d22)
| (DeferredAp(d1, ds1), DeferredAp(d2, ds2)) =>
fast_equal(d1, d2)
&& List.length(ds1) == List.length(ds2)
&& List.for_all2(fast_equal, ds1, ds2)
| (Cons(d11, d21), Cons(d12, d22)) =>
fast_equal(d11, d12) && fast_equal(d21, d22)
| (ListConcat(d11, d21), ListConcat(d12, d22)) =>
fast_equal(d11, d12) && fast_equal(d21, d22)
| (Tuple(ds1), Tuple(ds2)) =>
List.length(ds1) == List.length(ds2)
&& List.for_all2(fast_equal, ds1, ds2)
| (BuiltinFun(f1), BuiltinFun(f2)) => f1 == f2
| (ListLit(ds1), ListLit(ds2)) =>
List.length(ds1) == List.length(ds2)
&& List.for_all2(fast_equal, ds1, ds2)
| (UnOp(op1, d1), UnOp(op2, d2)) => op1 == op2 && fast_equal(d1, d2)
| (BinOp(op1, d11, d21), BinOp(op2, d12, d22)) =>
op1 == op2 && fast_equal(d11, d12) && fast_equal(d21, d22)
| (TyAlias(tp1, ut1, d1), TyAlias(tp2, ut2, d2)) =>
tp1 == tp2 && ut1 == ut2 && fast_equal(d1, d2)
| (Cast(d1, ty11, ty21), Cast(d2, ty12, ty22))
| (FailedCast(d1, ty11, ty21), FailedCast(d2, ty12, ty22)) =>
fast_equal(d1, d2) && ty11 == ty12 && ty21 == ty22
| (DynamicErrorHole(d1, reason1), DynamicErrorHole(d2, reason2)) =>
fast_equal(d1, d2) && reason1 == reason2
| (Match(s1, rs1), Match(s2, rs2)) =>
fast_equal(s1, s2)
&& List.length(rs2) == List.length(rs2)
&& List.for_all2(
((k1, v1), (k2, v2)) => k1 == k2 && fast_equal(v1, v2),
rs1,
rs2,
)
| (If(d11, d12, d13), If(d21, d22, d23)) =>
fast_equal(d11, d21) && fast_equal(d12, d22) && fast_equal(d13, d23)
/* We can group these all into a `_ => false` clause; separating
these so that we get exhaustiveness checking. */
| (Seq(_), _)
| (Filter(_), _)
| (Let(_), _)
| (FixF(_), _)
| (Fun(_), _)
| (Test(_), _)
| (Ap(_), _)
| (BuiltinFun(_), _)
| (Cons(_), _)
| (ListConcat(_), _)
| (ListLit(_), _)
| (Tuple(_), _)
| (UnOp(_), _)
| (BinOp(_), _)
| (Cast(_), _)
| (FailedCast(_), _)
| (TyAlias(_), _)
| (TypFun(_), _)
| (TypAp(_), _)
| (DynamicErrorHole(_), _)
| (DeferredAp(_), _)
| (If(_), _)
| (Match(_), _) => false

/* Hole forms: when checking environments, only check that
environment ID's are equal, don't check structural equality.
(This resolves a performance issue with many nested holes.) */
| (EmptyHole, EmptyHole) => true
| (MultiHole(_), MultiHole(_)) => rep_id(d1exp) == rep_id(d2exp)
| (Invalid(text1), Invalid(text2)) => text1 == text2
| (Closure(sigma1, d1), Closure(sigma2, d2)) =>
ClosureEnvironment.id_equal(sigma1, sigma2) && fast_equal(d1, d2)
| (EmptyHole, _)
| (MultiHole(_), _)
| (Invalid(_), _)
| (Closure(_), _) => false
};
}
and filter_fast_equal = (f1, f2) => {
switch (f1, f2) {
| (Filter(f1), Filter(f2)) =>
fast_equal(f1.pat, f2.pat) && f1.act == f2.act
| (Residue(idx1, act1), Residue(idx2, act2)) =>
idx1 == idx2 && act1 == act2
| _ => false
};
};

let assign_name_if_none = (t, name) => {
let (term, rewrap) = unwrap(t);
switch (term) {
Expand Down
6 changes: 2 additions & 4 deletions src/haz3lcore/dynamics/Elaborator.re
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,9 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
};
}
);
// TODO: is elaborated_type the right type to use here??
if (!Statics.is_recursive(ctx, p, def, elaborated_type)) {
let (p, ty1) = elaborate_pattern(m, p);
if (!Statics.is_recursive(ctx, p, def, ty1)) {
let def = add_name(Pat.get_var(p), def);
let (p, ty1) = elaborate_pattern(m, p);
let (def, ty2) = elaborate(m, def);
let (body, ty) = elaborate(m, body);
Exp.Let(p, fresh_cast(def, ty2, ty1), body)
Expand All @@ -278,7 +277,6 @@ let rec elaborate = (m: Statics.Map.t, uexp: UExp.t): (DHExp.t, Typ.t) => {
// TODO: Add names to mutually recursive functions
// TODO: Don't add fixpoint if there already is one
let def = add_name(Option.map(s => s ++ "+", Pat.get_var(p)), def);
let (p, ty1) = elaborate_pattern(m, p);
let (def, ty2) = elaborate(m, def);
let (body, ty) = elaborate(m, body);
let fixf = FixF(p, fresh_cast(def, ty2, ty1), None) |> DHExp.fresh;
Expand Down
2 changes: 1 addition & 1 deletion src/haz3lcore/dynamics/FilterMatcher.re
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ let rec matches_exp =
| (Deferral(_), _) => false

| (Filter(df, dd), Filter(ff, fd)) =>
DHExp.filter_fast_equal(df, ff) && matches_exp(env, dd, fd)
TermBase.StepperFilterKind.fast_equal(df, ff) && matches_exp(env, dd, fd)
| (Filter(_), _) => false

| (Bool(dv), Bool(fv)) => dv == fv
Expand Down
6 changes: 0 additions & 6 deletions src/haz3lcore/lang/term/TPat.re
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,3 @@ let show_cls: cls => string =
| MultiHole => "Broken type alias"
| EmptyHole => "Empty type alias hole"
| Var => "Type alias";

let tyvar_of_utpat = ({term, _}: t) =>
switch (term) {
| Var(x) => Some(x)
| _ => None
};
72 changes: 1 addition & 71 deletions src/haz3lcore/lang/term/Typ.re
Original file line number Diff line number Diff line change
Expand Up @@ -244,37 +244,6 @@ let fresh_var = (var_name: string) => {
var_name ++ "_α" ++ string_of_int(x);
};

let rec subst = (s: t, x: TPat.t, ty: t) => {
switch (TPat.tyvar_of_utpat(x)) {
| Some(str) =>
let (term, rewrap) = unwrap(ty);
switch (term) {
| Int => Int |> rewrap
| Float => Float |> rewrap
| Bool => Bool |> rewrap
| String => String |> rewrap
| Unknown(prov) => Unknown(prov) |> rewrap
| Arrow(ty1, ty2) =>
Arrow(subst(s, x, ty1), subst(s, x, ty2)) |> rewrap
| Prod(tys) => Prod(List.map(subst(s, x), tys)) |> rewrap
| Sum(sm) =>
Sum(ConstructorMap.map(Option.map(subst(s, x)), sm)) |> rewrap
| Forall(tp2, ty)
when TPat.tyvar_of_utpat(x) == TPat.tyvar_of_utpat(tp2) =>
Forall(tp2, ty) |> rewrap
| Forall(tp2, ty) => Forall(tp2, subst(s, x, ty)) |> rewrap
| Rec(tp2, ty) when TPat.tyvar_of_utpat(x) == TPat.tyvar_of_utpat(tp2) =>
Rec(tp2, ty) |> rewrap
| Rec(tp2, ty) => Rec(tp2, subst(s, x, ty)) |> rewrap
| List(ty) => List(subst(s, x, ty)) |> rewrap
| Var(y) => str == y ? s : Var(y) |> rewrap
| Parens(ty) => Parens(subst(s, x, ty)) |> rewrap
| Ap(t1, t2) => Ap(subst(s, x, t1), subst(s, x, t2)) |> rewrap
};
| None => ty
};
};

let unroll = (ty: t): t =>
switch (term_of(ty)) {
| Rec(tp, ty_body) => subst(ty, tp, ty_body)
Expand All @@ -283,46 +252,7 @@ let unroll = (ty: t): t =>

/* Type Equality: This coincides with alpha equivalence for normalized types.
Other types may be equivalent but this will not detect so if they are not normalized. */
let rec eq_internal = (n: int, t1: t, t2: t) => {
switch (term_of(t1), term_of(t2)) {
| (Parens(t1), _) => eq_internal(n, t1, t2)
| (_, Parens(t2)) => eq_internal(n, t1, t2)
| (Rec(x1, t1), Rec(x2, t2))
| (Forall(x1, t1), Forall(x2, t2)) =>
let alpha_subst = subst(Var("=" ++ string_of_int(n)) |> mk_fast);
eq_internal(n + 1, alpha_subst(x1, t1), alpha_subst(x2, t2));
| (Rec(_), _) => false
| (Forall(_), _) => false
| (Int, Int) => true
| (Int, _) => false
| (Float, Float) => true
| (Float, _) => false
| (Bool, Bool) => true
| (Bool, _) => false
| (String, String) => true
| (String, _) => false
| (Ap(t1, t2), Ap(t1', t2')) =>
eq_internal(n, t1, t1') && eq_internal(n, t2, t2')
| (Ap(_), _) => false
| (Unknown(_), Unknown(_)) => true
| (Unknown(_), _) => false
| (Arrow(t1, t2), Arrow(t1', t2')) =>
eq_internal(n, t1, t1') && eq_internal(n, t2, t2')
| (Arrow(_), _) => false
| (Prod(tys1), Prod(tys2)) => List.equal(eq_internal(n), tys1, tys2)
| (Prod(_), _) => false
| (List(t1), List(t2)) => eq_internal(n, t1, t2)
| (List(_), _) => false
| (Sum(sm1), Sum(sm2)) =>
/* Does not normalize the types. */
ConstructorMap.equal(eq_internal(n), sm1, sm2)
| (Sum(_), _) => false
| (Var(n1), Var(n2)) => n1 == n2
| (Var(_), _) => false
};
};

let eq = (t1: t, t2: t): bool => eq_internal(0, t1, t2);
let eq = (t1: t, t2: t): bool => fast_equal(t1, t2);

/* Lattice join on types. This is a LUB join in the hazel2
sense in that any type dominates Unknown. The optional
Expand Down
Loading

0 comments on commit d62f220

Please sign in to comment.