Skip to content

Commit

Permalink
Merge Records into ADT
Browse files Browse the repository at this point in the history
  • Loading branch information
Halbaroth committed Aug 6, 2024
1 parent 6424e77 commit 1d30760
Show file tree
Hide file tree
Showing 22 changed files with 194 additions and 669 deletions.
2 changes: 1 addition & 1 deletion src/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
Ac Arith Arrays_rel Bitv Ccx Shostak Relation
Fun_sat Fun_sat_frontend Inequalities Bitv_rel Th_util Adt Adt_rel
Instances IntervalCalculus Intervals_intf Intervals_core Intervals
Ite_rel Matching Matching_types Polynome Records Records_rel
Ite_rel Matching Matching_types Polynome
Satml_frontend_hybrid Satml_frontend Satml Sat_solver Sat_solver_sig
Sig Sig_rel Theory Uf Use Domains Domains_intf Rel_utils Bitlist
; structures
Expand Down
16 changes: 11 additions & 5 deletions src/lib/frontend/cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,18 @@ let rec make_term quant_basename t =
E.mk_term (Sy.Op Sy.Concat)
[mk_term t1; mk_term t2] ty

| TTdot (t, s) ->
E.mk_term (Sy.Op (Sy.Access (Uid.of_hstring s))) [mk_term t] ty

| TTrecord lbs ->
| TTrecord (ty, lbs) ->
let lbs = List.map (fun (_, t) -> mk_term t) lbs in
E.mk_record lbs ty
let cstr =
match ty with
| Tadt (name, params, true) ->
begin match Ty.type_body name params with
| [{ constr; _ }] -> Uid.show constr
| _ -> assert false
end
| _ -> assert false
in
E.mk_constr (Uid.of_string cstr) lbs ty

| TTlet (binders, t2) ->
let binders =
Expand Down
160 changes: 11 additions & 149 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -515,87 +515,17 @@ let rec dty_to_ty ?(update = false) ?(is_var = false) dty =
| _ -> unsupported "Type %a" DE.Ty.print dty

and handle_ty_app ?(update = false) ty_c l =
(* Applies the substitutions in [tysubsts] to each encountered type
variable. *)
let rec apply_ty_substs tysubsts ty =
match ty with
| Ty.Tvar { v; _ } ->
Ty.M.find v tysubsts

| Text (tyl, hs) ->
Ty.Text (List.map (apply_ty_substs tysubsts) tyl, hs)

| Tfarray (ti, tv) ->
Tfarray (
apply_ty_substs tysubsts ti,
apply_ty_substs tysubsts tv
)

| Tadt (hs, tyl) ->
Tadt (hs, List.map (apply_ty_substs tysubsts) tyl)

| Trecord ({ args; lbs; _ } as rcrd) ->
Trecord {
rcrd with
args = List.map (apply_ty_substs tysubsts) args;
lbs = List.map (
fun (hs, t) ->
hs, apply_ty_substs tysubsts t
) lbs;
}

| _ -> ty
in
let tyl = List.map (dty_to_ty ~update) l in
(* Recover the initial versions of the types and apply them on the provided
type arguments stored in [tyl]. *)
match Cache.find_ty ty_c with
| Tadt (hs, _) -> Tadt (hs, tyl)

| Trecord { args; _ } as ty ->
let tysubsts =
List.fold_left2 (
fun acc tv ty ->
match tv with
| Ty.Tvar { v; _ } -> Ty.M.add v ty acc
| _ -> assert false
) Ty.M.empty args tyl
in
apply_ty_substs tysubsts ty

| Tadt (hs, _, record) -> Tadt (hs, tyl, record)
| Text (_, s) -> Text (tyl, s)
| _ -> assert false

(** Handles a simple type declaration. *)
let mk_ty_decl (ty_c: DE.ty_cst) =
match DT.definition ty_c with
| Some (
(Adt
{ cases = [| { cstr = { id_ty; _ } as cstr; dstrs; _ } |]; _ } as adt)
) ->
(* Records and adts that only have one case are treated in the same way,
and considered as records. *)
Nest.attach_orders [adt];
let uid = Uid.of_ty_cst ty_c in
let tyvl = Cache.store_ty_vars_ret id_ty in
let lbs =
Array.fold_right (
fun c acc ->
match c with
| Some (DE.{ id_ty; _ } as id) ->
let pty = dty_to_ty id_ty in
(Uid.of_term_cst id, pty) :: acc
| _ ->
Fmt.failwith
"Unexpected null label for some field of the record type %a"
DE.Ty.Const.print ty_c

) dstrs []
in
let record_constr = Uid.of_term_cst cstr in
let ty = Ty.trecord ~record_constr tyvl uid lbs in
Cache.store_ty ty_c ty

| Some (Adt { cases; _ } as adt) ->
Nest.attach_orders [adt];
let uid = Uid.of_ty_cst ty_c in
Expand Down Expand Up @@ -653,29 +583,7 @@ let mk_term_decl ({ id_ty; path; tags; _ } as tcst: DE.term_cst) =
let mk_mr_ty_decls (tdl: DE.ty_cst list) =
let handle_ty_decl (ty: Ty.t) (tdef: DE.Ty.def option) =
match ty, tdef with
| Trecord { args; name; record_constr; _ },
Some (
Adt { cases = [| { dstrs; _ } |]; ty = ty_c; _ }
) ->
let lbs =
Array.fold_right (
fun c acc ->
match c with
| Some (DE.{ id_ty; _ } as id) ->
let pty = dty_to_ty id_ty in
(Uid.of_term_cst id, pty) :: acc
| _ ->
Fmt.failwith
"Unexpected null label for some field of the record type %a"
DE.Ty.Const.print ty_c
) dstrs []
in
let ty =
Ty.trecord ~record_constr args name lbs
in
Cache.store_ty ty_c ty

| Tadt (hs, tyl), Some (Adt { cases; ty = ty_c; _ }) ->
| Tadt (hs, tyl, _), Some (Adt { cases; ty = ty_c; _ }) ->
let cs =
Array.fold_right (
fun DE.{ cstr; dstrs; _ } accl ->
Expand All @@ -696,38 +604,16 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =

| _ -> assert false
in
(* If there are adts in the list of type declarations then records are
converted to adts, because that's how it's done in the legacy typechecker.
But it might be more efficient not to do that. *)
let rev_tdefs, contains_adts =
List.fold_left (
fun (acc, ca) ty_c ->
match DT.definition ty_c with
| Some (Adt { record; cases; _ } as df)
when not record && Array.length cases > 1 ->
df :: acc, true
| Some (Adt _ as df) ->
df :: acc, ca
| Some Abstract | None ->
assert false
) ([], false) tdl
in
let rev_tdefs = List.rev_map (fun td -> Option.get @@ DT.definition td) tdl in
Nest.attach_orders rev_tdefs;
let rev_l =
List.fold_left (
fun acc tdef ->
match tdef with
| DE.Adt { cases; record; ty = ty_c; } as adt ->
| DE.Adt { cases; ty = ty_c; _ } as adt ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let uid = Uid.of_ty_cst ty_c in
let record_constr = Uid.of_term_cst cases.(0).cstr in
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr tyvl uid []
else
Ty.t_adt uid tyvl
in
let ty = Ty.t_adt uid tyvl in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc

Expand Down Expand Up @@ -954,14 +840,8 @@ let rec mk_expr
E.mk_term sy [] ty

| B.Constructor _ ->
begin match dty_to_ty term_ty with
| Trecord _ as ty ->
E.mk_record [] ty
| Tadt _ as ty ->
E.mk_constr (Uid.of_term_cst tcst) [] ty
| ty ->
Fmt.failwith "unexpected type %a@." Ty.print ty
end
let ty = dty_to_ty term_ty in
E.mk_constr (Uid.of_term_cst tcst) [] ty

| _ -> unsupported "Constant term %a" DE.Term.print term
end
Expand Down Expand Up @@ -1002,10 +882,7 @@ let rec mk_expr
let e = aux_mk_expr x in
let sy =
match Cache.find_ty adt with
| Trecord _ ->
Sy.Op (Sy.Access (Uid.of_term_cst destr))
| Tadt _ ->
Sy.destruct (Uid.of_term_cst destr)
| Tadt _ -> Sy.destruct (Uid.of_term_cst destr)
| _ -> assert false
in
E.mk_term sy [e] ty
Expand Down Expand Up @@ -1037,11 +914,6 @@ let rec mk_expr
| Ty.Tadt _ ->
E.mk_builtin ~is_pos:true builtin [aux_mk_expr x]

| Ty.Trecord _ ->
(* The typechecker allows only testers whose the
two arguments have the same type. Thus, we can always
replace the tester of a record by the true literal. *)
E.vrai
| _ -> assert false
end

Expand Down Expand Up @@ -1308,19 +1180,9 @@ let rec mk_expr

| B.Constructor _, _ ->
let ty = dty_to_ty term_ty in
begin match ty with
| Ty.Tadt _ ->
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_constr (Uid.of_term_cst tcst) l ty
| Ty.Trecord _ ->
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_record l ty
| _ ->
Fmt.failwith
"Constructor error: %a does not belong to a record nor an\
algebraic data type"
DE.Term.print app_term
end
let sy = Sy.constr @@ Uid.of_term_cst tcst in
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_term sy l ty

| B.Coercion, [ x ] ->
begin match DT.view (DE.Term.ty x), DT.view term_ty with
Expand Down
22 changes: 1 addition & 21 deletions src/lib/frontend/models.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ module Pp_smtlib_term = struct
asprintf "%a" Ty.pp_smtlib t

let rec print fmt t =
let {Expr.f;xs;ty; _} = Expr.term_view t in
let {Expr.f;xs; _} = Expr.term_view t in
match f, xs with

| Sy.Lit lit, xs ->
Expand Down Expand Up @@ -150,26 +150,6 @@ module Pp_smtlib_term = struct
| Sy.Op Sy.Extract (i, j), [e] ->
fprintf fmt "%a^{%d,%d}" print e i j

| Sy.Op (Sy.Access field), [e] ->
if Options.get_output_smtlib () then
fprintf fmt "(%a %a)" Uid.pp field print e
else
fprintf fmt "%a.%a" print e Uid.pp field

| Sy.Op (Sy.Record), _ ->
begin match ty with
| Ty.Trecord { Ty.lbs = lbs; _ } ->
assert (List.length xs = List.length lbs);
fprintf fmt "{";
ignore (List.fold_left2 (fun first (field,_) e ->
fprintf fmt "%s%a = %a" (if first then "" else "; ")
Uid.pp field print e;
false
) true lbs xs);
fprintf fmt "}";
| _ -> assert false
end

(* TODO: introduce PrefixOp in the future to simplify this ? *)
| Sy.Op op, [e1; e2] when op == Sy.Pow || op == Sy.Integer_round ||
op == Sy.Max_real || op == Sy.Max_int ||
Expand Down
Loading

0 comments on commit 1d30760

Please sign in to comment.