Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support field projectors for recursive structs in Lean backend #194

Merged
merged 9 commits into from
May 24, 2024
71 changes: 55 additions & 16 deletions compiler/ExtractTypes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1666,14 +1666,15 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)

(** Auxiliary function.

Generate field projectors in Coq.
Generate field projectors for Lean and Coq.

Sometimes we extract records as inductives in Coq: when this happens we
have to define the field projectors afterwards.
Recursive structs are defined as inductives in Lean and Coq.
Field projectors allow to retrieve the facilities provided by
Lean structures.
*)
let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
(fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit =
sanity_check __FILE__ __LINE__ (!backend = Coq) decl.span;
sanity_check __FILE__ __LINE__ (!backend = Coq || !backend = Lean) decl.span;
match decl.kind with
| Opaque | Enum _ -> ()
| Struct fields ->
Expand All @@ -1685,29 +1686,60 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
ctx_add_generic_params decl.span decl.llbc_name decl.llbc_generics
decl.generics ctx
in
(* Record_var will be the ADT argument to the projector *)
let ctx, record_var = ctx_add_var decl.span "x" (VarId.of_int 0) ctx in
(* Field_var will be the variable in the constructor that is returned by the projector *)
let ctx, field_var = ctx_add_var decl.span "x" (VarId.of_int 1) ctx in
(* Name of the ADT *)
let def_name = ctx_get_local_type decl.span decl.def_id ctx in
(* Name of the ADT constructor. As we are in the struct case, we only have
one constructor *)
let cons_name = ctx_get_struct decl.span (TAdtId decl.def_id) ctx in

let extract_field_proj (field_id : FieldId.id) (_ : field) : unit =
F.pp_print_space fmt ();
(* Outer box for the projector definition *)
F.pp_open_hvbox fmt 0;
(* Inner box for the projector definition *)
F.pp_open_hvbox fmt ctx.indent_incr;
(* Open a box for the [Definition PROJ ... :=] *)

(* For Lean: add some attributes *)
if !backend = Lean then (
(* Box for the attributes *)
F.pp_open_vbox fmt 0;
(* Annotate the projectors with both simp and reducible.
The first one allows to automatically unfold when calling simp in proofs.
The second ensures that projectors will interact well with the unifier *)
F.pp_print_string fmt "@[simp, reducible]";
F.pp_print_break fmt 0 0;
(* Close box for the attributes *)
F.pp_close_box fmt ());

(* Box for the [def ADT.proj ... :=] *)
F.pp_open_hovbox fmt ctx.indent_incr;
F.pp_print_string fmt "Definition";
(match !backend with
| Lean -> F.pp_print_string fmt "def"
| Coq -> F.pp_print_string fmt "Definition"
| _ -> internal_error __FILE__ __LINE__ decl.span);
F.pp_print_space fmt ();

(* Print the function name. In Lean, the syntax ADT.proj will
allow us to call x.proj for any x of type ADT. In Coq,
we will have to introduce a notation for the projector. *)
let field_name =
ctx_get_field decl.span (TAdtId decl.def_id) field_id ctx
in
if !backend = Lean then (
F.pp_print_string fmt def_name;
F.pp_print_string fmt ".");
F.pp_print_string fmt field_name;

(* Print the generics *)
let as_implicits = true in
extract_generic_params decl.span ctx fmt TypeDeclId.Set.empty
~as_implicits decl.generics type_params cg_params trait_clauses;
(* Print the record parameter *)

(* Print the record parameter as "(x : ADT)" *)
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
F.pp_print_string fmt record_var;
Expand All @@ -1721,14 +1753,17 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
F.pp_print_string fmt p)
type_params;
F.pp_print_string fmt ")";
(* *)

F.pp_print_space fmt ();
F.pp_print_string fmt ":=";
(* Close the box for the [Definition PROJ ... :=] *)

(* Close the box for the [def ADT.proj ... :=] *)
F.pp_close_box fmt ();
F.pp_print_space fmt ();

(* Open a box for the whole match *)
F.pp_open_hvbox fmt 0;

(* Open a box for the [match ... with] *)
F.pp_open_hovbox fmt ctx.indent_incr;
F.pp_print_string fmt "match";
Expand Down Expand Up @@ -1758,9 +1793,12 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
F.pp_print_string fmt field_var;
(* Close the box for the branch *)
F.pp_close_box fmt ();

(* Print the [end] *)
F.pp_print_space fmt ();
F.pp_print_string fmt "end";
if !backend = Coq then (
F.pp_print_space fmt ();
F.pp_print_string fmt "end");

(* Close the box for the whole match *)
F.pp_close_box fmt ();
(* Close the inner box projector *)
Expand All @@ -1769,12 +1807,13 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
if !backend = Coq then (
F.pp_print_cut fmt ();
F.pp_print_string fmt ".");
(* Close the outer box projector *)
(* Close the outer box for projector definition *)
F.pp_close_box fmt ();
(* Add breaks to insert new lines between definitions *)
F.pp_print_break fmt 0 0
in

(* Only for Coq: we need to define a notation for the projector *)
let extract_proj_notation (field_id : FieldId.id) (_ : field) : unit =
F.pp_print_space fmt ();
(* Outer box for the projector definition *)
Expand Down Expand Up @@ -1815,7 +1854,7 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
let extract_field_proj_and_notation (field_id : FieldId.id)
(field : field) : unit =
extract_field_proj field_id field;
extract_proj_notation field_id field
if !backend = Coq then extract_proj_notation field_id field
in

FieldId.iteri extract_field_proj_and_notation fields
Expand All @@ -1828,14 +1867,14 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (decl : type_decl) : unit =
match !backend with
| FStar | Lean | HOL4 -> ()
| Coq ->
| FStar | HOL4 -> ()
| Lean | Coq ->
if
not
(TypesUtils.type_decl_from_decl_id_is_tuple_struct
ctx.trans_ctx.type_ctx.type_infos decl.def_id)
then (
extract_type_decl_coq_arguments ctx fmt kind decl;
if !backend = Coq then extract_type_decl_coq_arguments ctx fmt kind decl;
extract_type_decl_record_field_projectors ctx fmt kind decl)

(** Extract the state type declaration. *)
Expand Down
5 changes: 0 additions & 5 deletions compiler/SymbolicToPure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2903,14 +2903,9 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
- if the ADT is an enumeration (which must have exactly one branch)
- if we forbid using field projectors.
*)
let is_rec_def =
T.TypeDeclId.Set.mem adt_id ctx.type_ctx.recursive_decls
in
let use_let_with_cons =
is_enum
|| !Config.dont_use_field_projectors
(* TODO: for now, we don't have field projectors over recursive ADTs in Lean. *)
|| (!Config.backend = Lean && is_rec_def)
(* Also, there is a special case when the ADT is to be extracted as
a tuple, because it is a structure with unnamed fields. Some backends
like Lean have projectors for tuples (like so: `x.3`), but others
Expand Down
81 changes: 34 additions & 47 deletions tests/lean/BetreeMain/Funs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,15 @@ mutual divergent def betree.Internal.lookup_in_children
(self : betree.Internal) (key : U64) (st : State) :
Result (State × ((Option U64) × betree.Internal))
:=
let ⟨ i, i1, n, n1 ⟩ := self
if key < i1
if key < self.pivot
then
do
let (st1, (o, n2)) ← betree.Node.lookup n key st
Result.ok (st1, (o, betree.Internal.mk i i1 n2 n1))
let (st1, (o, n)) ← betree.Node.lookup self.left key st
Result.ok (st1, (o, betree.Internal.mk self.id self.pivot n self.right))
else
do
let (st1, (o, n2)) ← betree.Node.lookup n1 key st
Result.ok (st1, (o, betree.Internal.mk i i1 n n2))
let (st1, (o, n)) ← betree.Node.lookup self.right key st
Result.ok (st1, (o, betree.Internal.mk self.id self.pivot self.left n))

/- [betree_main::betree::{betree_main::betree::Node#5}::lookup]:
Source: 'src/betree.rs', lines 709:4-709:58 -/
Expand All @@ -270,8 +269,7 @@ divergent def betree.Node.lookup
match self with
| betree.Node.Internal node =>
do
let ⟨ i, i1, n, n1 ⟩ := node
let (st1, msgs) ← betree.load_internal_node i st
let (st1, msgs) ← betree.load_internal_node node.id st
let (pending, lookup_first_message_for_key_back) ←
betree.Node.lookup_first_message_for_key key msgs
match pending with
Expand All @@ -281,8 +279,7 @@ divergent def betree.Node.lookup
then
do
let (st2, (o, node1)) ←
betree.Internal.lookup_in_children (betree.Internal.mk i i1 n n1) key
st1
betree.Internal.lookup_in_children node key st1
let _ ←
lookup_first_message_for_key_back (betree.List.Cons (k, msg) l)
Result.ok (st2, (o, betree.Node.Internal node1))
Expand All @@ -293,33 +290,26 @@ divergent def betree.Node.lookup
let _ ←
lookup_first_message_for_key_back (betree.List.Cons (k,
betree.Message.Insert v) l)
Result.ok (st1, (some v, betree.Node.Internal (betree.Internal.mk i
i1 n n1)))
Result.ok (st1, (some v, betree.Node.Internal node))
| betree.Message.Delete =>
do
let _ ←
lookup_first_message_for_key_back (betree.List.Cons (k,
betree.Message.Delete) l)
Result.ok (st1, (none, betree.Node.Internal (betree.Internal.mk i i1
n n1)))
Result.ok (st1, (none, betree.Node.Internal node))
| betree.Message.Upsert ufs =>
do
let (st2, (v, node1)) ←
betree.Internal.lookup_in_children (betree.Internal.mk i i1 n n1)
key st1
betree.Internal.lookup_in_children node key st1
let (v1, pending1) ←
betree.Node.apply_upserts (betree.List.Cons (k,
betree.Message.Upsert ufs) l) v key
let ⟨ i2, i3, n2, n3 ⟩ := node1
let msgs1 ← lookup_first_message_for_key_back pending1
let (st3, _) ← betree.store_internal_node i2 msgs1 st2
Result.ok (st3, (some v1, betree.Node.Internal (betree.Internal.mk i2
i3 n2 n3)))
let (st3, _) ← betree.store_internal_node node1.id msgs1 st2
Result.ok (st3, (some v1, betree.Node.Internal node1))
| betree.List.Nil =>
do
let (st2, (o, node1)) ←
betree.Internal.lookup_in_children (betree.Internal.mk i i1 n n1) key
st1
let (st2, (o, node1)) ← betree.Internal.lookup_in_children node key st1
let _ ← lookup_first_message_for_key_back betree.List.Nil
Result.ok (st2, (o, betree.Node.Internal node1))
| betree.Node.Leaf node =>
Expand Down Expand Up @@ -541,34 +531,36 @@ mutual divergent def betree.Internal.flush
× betree.NodeIdCounter)))
:=
do
let ⟨ i, i1, n, n1 ⟩ := self
let p ← betree.ListPairU64T.partition_at_pivot betree.Message content i1
let p ←
betree.ListPairU64T.partition_at_pivot betree.Message content self.pivot
let (msgs_left, msgs_right) := p
let len_left ← betree.List.len (U64 × betree.Message) msgs_left
if len_left >= params.min_flush_size
then
do
let (st1, p1) ←
betree.Node.apply_messages n params node_id_cnt msgs_left st
let (n2, node_id_cnt1) := p1
betree.Node.apply_messages self.left params node_id_cnt msgs_left st
let (n, node_id_cnt1) := p1
let len_right ← betree.List.len (U64 × betree.Message) msgs_right
if len_right >= params.min_flush_size
then
do
let (st2, p2) ←
betree.Node.apply_messages n1 params node_id_cnt1 msgs_right st1
let (n3, node_id_cnt2) := p2
Result.ok (st2, (betree.List.Nil, (betree.Internal.mk i i1 n2 n3,
node_id_cnt2)))
betree.Node.apply_messages self.right params node_id_cnt1 msgs_right
st1
let (n1, node_id_cnt2) := p2
Result.ok (st2, (betree.List.Nil, (betree.Internal.mk self.id self.pivot
n n1, node_id_cnt2)))
else
Result.ok (st1, (msgs_right, (betree.Internal.mk i i1 n2 n1,
node_id_cnt1)))
Result.ok (st1, (msgs_right, (betree.Internal.mk self.id self.pivot n
self.right, node_id_cnt1)))
else
do
let (st1, p1) ←
betree.Node.apply_messages n1 params node_id_cnt msgs_right st
let (n2, node_id_cnt1) := p1
Result.ok (st1, (msgs_left, (betree.Internal.mk i i1 n n2, node_id_cnt1)))
betree.Node.apply_messages self.right params node_id_cnt msgs_right st
let (n, node_id_cnt1) := p1
Result.ok (st1, (msgs_left, (betree.Internal.mk self.id self.pivot
self.left n, node_id_cnt1)))

/- [betree_main::betree::{betree_main::betree::Node#5}::apply_messages]:
Source: 'src/betree.rs', lines 588:4-593:5 -/
Expand All @@ -581,26 +573,21 @@ divergent def betree.Node.apply_messages
match self with
| betree.Node.Internal node =>
do
let ⟨ i, i1, n, n1 ⟩ := node
let (st1, content) ← betree.load_internal_node i st
let (st1, content) ← betree.load_internal_node node.id st
let content1 ← betree.Node.apply_messages_to_internal content msgs
let num_msgs ← betree.List.len (U64 × betree.Message) content1
if num_msgs >= params.min_flush_size
then
do
let (st2, (content2, p)) ←
betree.Internal.flush (betree.Internal.mk i i1 n n1) params node_id_cnt
content1 st1
betree.Internal.flush node params node_id_cnt content1 st1
let (node1, node_id_cnt1) := p
let ⟨ i2, i3, n2, n3 ⟩ := node1
let (st3, _) ← betree.store_internal_node i2 content2 st2
Result.ok (st3, (betree.Node.Internal (betree.Internal.mk i2 i3 n2 n3),
node_id_cnt1))
let (st3, _) ← betree.store_internal_node node1.id content2 st2
Result.ok (st3, (betree.Node.Internal node1, node_id_cnt1))
else
do
let (st2, _) ← betree.store_internal_node i content1 st1
Result.ok (st2, (betree.Node.Internal (betree.Internal.mk i i1 n n1),
node_id_cnt))
let (st2, _) ← betree.store_internal_node node.id content1 st1
Result.ok (st2, (betree.Node.Internal node, node_id_cnt))
| betree.Node.Leaf node =>
do
let (st1, content) ← betree.load_leaf_node node.id st
Expand Down
16 changes: 16 additions & 0 deletions tests/lean/BetreeMain/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ inductive betree.Node :=

end

@[simp, reducible]
def betree.Internal.id (x : betree.Internal) :=
match x with | betree.Internal.mk x1 _ _ _ => x1

@[simp, reducible]
def betree.Internal.pivot (x : betree.Internal) :=
match x with | betree.Internal.mk _ x1 _ _ => x1

@[simp, reducible]
def betree.Internal.left (x : betree.Internal) :=
match x with | betree.Internal.mk _ _ x1 _ => x1

@[simp, reducible]
def betree.Internal.right (x : betree.Internal) :=
match x with | betree.Internal.mk _ _ _ x1 => x1

/- [betree_main::betree::Params]
Source: 'src/betree.rs', lines 187:0-187:13 -/
structure betree.Params where
Expand Down
2 changes: 1 addition & 1 deletion tests/test_runner/aeneas_test_runner.opam
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ homepage: "https://github.com/AeneasVerif/aeneas"
bug-reports: "https://github.com/AeneasVerif/aeneas/issues"
depends: [
"ocaml"
"dune" {>= "3.12"}
"dune" {>= "3.7"}
"odoc" {with-doc}
]
build: [
Expand Down
Loading