Skip to content

Commit

Permalink
Add RegionsHierarchy.ml
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmarcho committed Nov 13, 2023
1 parent 746239e commit 6c88d30
Show file tree
Hide file tree
Showing 16 changed files with 453 additions and 147 deletions.
5 changes: 3 additions & 2 deletions compiler/AssociatedTypes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -521,10 +521,11 @@ let ctx_subst_norm_signature (ctx : C.eval_ctx)
(ty_subst : T.TypeVarId.id -> T.ty)
(cg_subst : T.ConstGenericVarId.id -> T.const_generic)
(tr_subst : T.TraitClauseId.id -> T.trait_instance_id)
(tr_self : T.trait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
(tr_self : T.trait_instance_id) (sg : A.fun_sig)
(regions_hierarchy : T.region_groups) : A.inst_fun_sig =
let sg =
Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self
sg
sg regions_hierarchy
in
let { A.regions_hierarchy; inputs; output; trait_type_constraints } = sg in
let inputs = List.map (ctx_normalize_ty ctx) inputs in
Expand Down
18 changes: 6 additions & 12 deletions compiler/Assumed.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ module Sig = struct
let mk_slice_ty (ty : T.ty) : T.ty =
TAdt (TAssumed TSlice, mk_generic_args [] [ ty ] [])

let mk_sig generics regions_hierarchy inputs output : A.fun_sig =
let mk_sig generics inputs output : A.fun_sig =
let preds : T.predicates =
{ regions_outlive = []; types_outlive = []; trait_type_constraints = [] }
in
Expand All @@ -88,26 +88,23 @@ module Sig = struct
generics;
preds;
parent_params_info = None;
regions_hierarchy;
inputs;
output;
}

(** [fn<T>(T) -> Box<T>] *)
let box_new_sig : A.fun_sig =
let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
let regions_hierarchy = [] in
let inputs = [ tvar_0 (* T *) ] in
let output = mk_box_ty tvar_0 (* Box<T> *) in
mk_sig generics regions_hierarchy inputs output
mk_sig generics inputs output

(** [fn<T>(Box<T>) -> ()] *)
let box_free_sig : A.fun_sig =
let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
let regions_hierarchy = [] in
let inputs = [ mk_box_ty tvar_0 (* Box<T> *) ] in
let output = mk_unit_ty (* () *) in
mk_sig generics regions_hierarchy inputs output
mk_sig generics inputs output

(** Array/slice functions *)

Expand All @@ -129,7 +126,6 @@ module Sig = struct
let generics =
mk_generic_params [ region_param_0 ] [ type_param_0 ] cgs (* <'a, T> *)
in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
let inputs =
[
mk_ref_ty rvar_0
Expand All @@ -145,7 +141,7 @@ module Sig = struct
(output_ty type_param_0.index)
is_mut (* &'a (mut) output_ty<T> *)
in
mk_sig generics regions_hierarchy inputs output
mk_sig generics inputs output

let mk_array_slice_index_sig (is_array : bool) (is_mut : bool) : A.fun_sig =
(* Array<T, N> *)
Expand Down Expand Up @@ -176,13 +172,12 @@ module Sig = struct
(* <T, N> *)
mk_generic_params [] [ type_param_0 ] [ cg_param_0 ]
in
let regions_hierarchy = [] (* <> *) in
let inputs = [ tvar_0 (* T *) ] in
let output =
(* [T; N] *)
mk_array_ty tvar_0 cgvar_0
in
mk_sig generics regions_hierarchy inputs output
mk_sig generics inputs output

(** Helper:
[fn<T>(&'a [T]) -> usize]
Expand All @@ -191,12 +186,11 @@ module Sig = struct
let generics =
mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
let inputs =
[ mk_ref_ty rvar_0 (mk_slice_ty tvar_0) false (* &'a [T] *) ]
in
let output = mk_usize_ty (* usize *) in
mk_sig generics regions_hierarchy inputs output
mk_sig generics inputs output
end

type raw_assumed_fun_info =
Expand Down
2 changes: 2 additions & 0 deletions compiler/Contexts.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ open Types
open Expressions
open Values
open LlbcAst
open LlbcAstUtils
module V = Values
open ValuesUtils
open Identifiers
Expand Down Expand Up @@ -190,6 +191,7 @@ type type_context = {
type fun_context = {
fun_decls : fun_decl FunDeclId.Map.t;
fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t;
regions_hierarchies : T.region_groups FunIdMap.t;
}
[@@deriving show]

Expand Down
8 changes: 6 additions & 2 deletions compiler/ExtractBase.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1188,14 +1188,18 @@ let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl)
let def_id = def.def_id in
let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in
let sg = llbc_def.signature in
let num_rgs = List.length sg.regions_hierarchy in
let regions_hierarchy =
LlbcAstUtils.FunIdMap.find (FRegular def_id)
ctx.trans_ctx.fun_ctx.regions_hierarchies
in
let num_rgs = List.length regions_hierarchy in
let { keep_fwd; fwd = _; backs } = trans_group in
let num_backs = List.length backs in
let rg_info =
match def.back_id with
| None -> None
| Some rg_id ->
let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in
let rg = T.RegionGroupId.nth regions_hierarchy rg_id in
let region_names =
List.map
(fun rid -> (T.RegionId.nth sg.generics.regions rid).name)
Expand Down
28 changes: 20 additions & 8 deletions compiler/Interpreter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ let compute_contexts (m : A.crate) : C.decls_ctx =
let fun_infos =
FunsAnalysis.analyze_module m fun_decls global_decls !Config.use_state
in
let fun_ctx = { C.fun_decls; fun_infos } in
let regions_hierarchies =
RegionsHierarchy.compute_regions_hierarchies type_decls fun_decls
in
let fun_ctx = { C.fun_decls; fun_infos; regions_hierarchies } in
let global_ctx = { C.global_decls } in
let trait_decls_ctx = { C.trait_decls } in
let trait_impls_ctx = { C.trait_impls } in
Expand Down Expand Up @@ -124,8 +127,8 @@ let symbolic_instantiate_fun_sig (ctx : C.eval_ctx) (sg : A.fun_sig)
List.fold_left_map
(fun tr_map (c : T.trait_clause) ->
let subst = mk_subst tr_map in
let { T.trait_id = trait_decl_id; generics; _ } = c in
let generics = Subst.generic_args_substitute subst generics in
let { T.trait_id = trait_decl_id; clause_generics; _ } = c in
let generics = Subst.generic_args_substitute subst clause_generics in
let trait_decl_ref = { T.trait_decl_id; decl_generics = generics } in
(* Note that because we directly refer to the clause, we give it
empty generics *)
Expand Down Expand Up @@ -183,8 +186,11 @@ let initialize_symbolic_context_for_fun (ctx : C.decls_ctx) (fdef : A.fun_decl)
* *)
let sg = fdef.signature in
(* Create the context *)
let regions_hierarchy =
FunIdMap.find (FRegular fdef.def_id) ctx.fun_ctx.regions_hierarchies
in
let region_groups =
List.map (fun (g : T.region_group) -> g.id) sg.regions_hierarchy
List.map (fun (g : T.region_group) -> g.id) regions_hierarchy
in
let ctx =
initialize_eval_context ctx region_groups sg.generics.types
Expand Down Expand Up @@ -269,7 +275,6 @@ let evaluate_function_symbolic_synthesize_backward_from_return
* the return type. Note that it is important to re-generate
* an instantiation of the signature, so that we use fresh
* region ids for the return abstractions. *)
let sg = fdef.signature in
let _, ret_inst_sg =
symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind
in
Expand All @@ -282,7 +287,10 @@ let evaluate_function_symbolic_synthesize_backward_from_return
* will end - this will allow us to, first, mark the other return
* regions as non-endable, and, second, end those parent regions in
* proper order. *)
let parent_rgs = list_ancestor_region_groups sg back_id in
let regions_hierarchy =
FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies
in
let parent_rgs = list_ancestor_region_groups regions_hierarchy back_id in
let parent_input_abs_ids =
T.RegionGroupId.mapi
(fun rg_id rg ->
Expand Down Expand Up @@ -455,6 +463,10 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx)
(* Create the evaluation context *)
let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in

let regions_hierarchy =
FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies
in

(* Create the continuation to finish the evaluation *)
let config = C.mk_config C.SymbolicMode in
let cf_finish res ctx =
Expand Down Expand Up @@ -511,7 +523,7 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx)
let back_el =
T.RegionGroupId.mapi
(fun gid _ -> (gid, finish_back_eval gid))
fdef.signature.regions_hierarchy
regions_hierarchy
in
let back_el = T.RegionGroupId.Map.of_list back_el in
(* Put everything together *)
Expand Down Expand Up @@ -555,7 +567,7 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx)
let back_el =
T.RegionGroupId.mapi
(fun gid _ -> (gid, finish_back_eval gid))
fdef.signature.regions_hierarchy
regions_hierarchy
in
let back_el = T.RegionGroupId.Map.of_list back_el in
(* Put everything together *)
Expand Down
2 changes: 1 addition & 1 deletion compiler/InterpreterStatements.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,7 @@ and eval_assumed_function_call_symbolic (config : C.config)
let inst_sig =
match fid with
| BoxFree ->
(* should have been treated above *)
(* Should have been treated above *)
raise (Failure "Unreachable")
| _ ->
(* There shouldn't be any reference to Self *)
Expand Down
9 changes: 7 additions & 2 deletions compiler/InterpreterUtils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -474,14 +474,19 @@ let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.generic_args)
(* Erase the regions in the generics we use for the instantiation *)
let generics = Subst.generic_args_erase_regions generics in
let tr_self = Subst.trait_instance_id_erase_regions tr_self in
(* Compute the regions hierarchy *)
let regions_hierarchy =
RegionsHierarchy.compute_regions_hierarchy_for_sig
ctx.type_context.type_decls sg
in
(* Generate fresh abstraction ids and create a substitution from region
* group ids to abstraction ids *)
let rg_abs_ids_bindings =
List.map
(fun rg ->
let abs_id = C.fresh_abstraction_id () in
(rg.T.id, abs_id))
sg.regions_hierarchy
regions_hierarchy
in
let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t =
List.fold_left
Expand Down Expand Up @@ -512,7 +517,7 @@ let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.generic_args)
(* Substitute the signature *)
let inst_sig =
AssociatedTypes.ctx_subst_norm_signature ctx asubst rsubst tsubst cgsubst
tr_subst tr_self sg
tr_subst tr_self sg regions_hierarchy
in
(* Return *)
inst_sig
13 changes: 13 additions & 0 deletions compiler/LlbcAstUtils.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
open LlbcAst
include Charon.LlbcAstUtils
open Collections

module FunIdOrderedType : OrderedType with type t = fun_id = struct
type t = fun_id

let compare = compare_fun_id
let to_string = show_fun_id
let pp_t = pp_fun_id
let show_t = show_fun_id
end

module FunIdMap = Collections.MakeMap (FunIdOrderedType)
module FunIdSet = Collections.MakeSet (FunIdOrderedType)

let lookup_fun_sig (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) :
fun_sig =
Expand Down
9 changes: 5 additions & 4 deletions compiler/PureMicroPasses.ml
Original file line number Diff line number Diff line change
Expand Up @@ -786,18 +786,19 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
if rg_id0 = rg_id1 then true
else
(* We need to use the regions hierarchy *)
(* First, lookup the signature of the LLBC function *)
let sg =
let regions_hierarchy =
let id0 =
match id0 with
| FunId fun_id -> fun_id
| TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id
in
LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls
LlbcAstUtils.FunIdMap.find id0
ctx.fun_ctx.regions_hierarchies
in
(* Compute the set of ancestors of the function in call1 *)
let call1_ancestors =
LlbcAstUtils.list_ancestor_region_groups sg rg_id1
LlbcAstUtils.list_ancestor_region_groups regions_hierarchy
rg_id1
in
(* Check if the function used in call0 is inside *)
T.RegionGroupId.Set.mem rg_id0 call1_ancestors
Expand Down
Loading

0 comments on commit 6c88d30

Please sign in to comment.