From fdf3ec1270b1ac58a7add1b5e232944448352c8d Mon Sep 17 00:00:00 2001 From: Davide Fissore Date: Tue, 3 Dec 2024 16:22:38 +0100 Subject: [PATCH] [union-find] version with persistent map instead of "persistent" Hashtbl --- src/API.ml | 1 + src/API.mli | 6 +++ src/compiler/compiler.ml | 76 +++++++++++++++------------- src/compiler/determinacy_checker.ml | 21 +++++--- src/compiler/determinacy_checker.mli | 23 ++++----- src/compiler/union_find.ml | 53 +++++++++++++++++-- src/compiler/union_find.mli | 15 ++++++ src/utils/util.ml | 9 +++- src/utils/util.mli | 6 +++ tests/sources/sepcomp_tyid.ml | 19 ++++++- 10 files changed, 168 insertions(+), 61 deletions(-) diff --git a/src/API.ml b/src/API.ml index 526319937..27cccbbfc 100644 --- a/src/API.ml +++ b/src/API.ml @@ -1352,6 +1352,7 @@ module Utils = struct let map_acc = BuiltInData.map_acc module type Show = Util.Show + module type ShowKey = Util.ShowKey module type Show1 = Util.Show1 module Map = Util.Map module Set = Util.Set diff --git a/src/API.mli b/src/API.mli index 1c1bbd5bb..0cad4c0ad 100644 --- a/src/API.mli +++ b/src/API.mli @@ -1330,6 +1330,12 @@ module Utils : sig val show : t -> string end + module type ShowKey = sig + type key + val pp_key : Format.formatter -> key -> unit + val show_key : key -> string + end + module type Show1 = sig type 'a t val pp : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit diff --git a/src/compiler/compiler.ml b/src/compiler/compiler.ml index 584bba025..4315f85a3 100644 --- a/src/compiler/compiler.ml +++ b/src/compiler/compiler.ml @@ -211,7 +211,7 @@ module C = Constants open Compiler_data -module Union_find = Union_find.Make(IdPos) +module Union_find = Union_find.UF2(IdPos.Map) type macro_declaration = (ScopedTerm.t * Loc.t) F.Map.t [@@ deriving show, ord] @@ -251,6 +251,7 @@ type unchecked_signature = { types : TypeList.t F.Map.t; type_abbrevs : (F.t * ScopedTypeExpression.t) list; modes : (mode * Loc.t) F.Map.t; + type_uf : Union_find.t } [@@deriving show] 
@@ -274,7 +275,7 @@ module Assembled = struct type_abbrevs : (TypeAssignment.skema_w_id * Loc.t) F.Map.t; modes : (mode * Loc.t) F.Map.t; functional_preds: Determinacy_checker.t; - type_uf : Union_find.closed + type_uf : Union_find.t } [@@deriving show] @@ -304,7 +305,7 @@ module Assembled = struct types = F.Map.empty; type_abbrevs = F.Map.empty; modes = F.Map.empty; functional_preds = Determinacy_checker.empty_fmap; toplevel_macros = F.Map.empty; - type_uf = Union_find.create_close () + type_uf = Union_find.empty } let empty () = { clauses = []; @@ -1119,10 +1120,10 @@ module Flatten : sig Arity.t F.Map.t -> Arity.t F.Map.t val merge_type_assignments : - Union_find.closed -> + Union_find.t -> TypeAssignment.overloaded_skema_with_id F.Map.t -> TypeAssignment.overloaded_skema_with_id F.Map.t -> - TypeAssignment.overloaded_skema_with_id F.Map.t + Union_find.t * TypeAssignment.overloaded_skema_with_id F.Map.t val merge_type_abbrevs : (F.t * ScopedTypeExpression.t) list -> (F.t * ScopedTypeExpression.t) list -> @@ -1132,10 +1133,10 @@ module Flatten : sig (F.t * ScopedTypeExpression.t) list -> (F.t * ScopedTypeExpression.t) list val merge_checked_type_abbrevs : - Union_find.closed -> + Union_find.t -> ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t -> ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t -> - ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t + Union_find.t * ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t val merge_toplevel_macros : (ScopedTerm.t * Loc.t) F.Map.t -> @@ -1253,11 +1254,13 @@ module Flatten : sig List.map (fun (k, v) -> subst_global s k, ScopedTypeExpression.smart_map (subst_global s) v) l let merge_type_assignments uf t1 t2 = + let res = ref uf in (* We give precedence to recent type declarations over old ones *) - F.Map.union (fun f l1 l2 -> + let t = F.Map.union (fun f l1 l2 -> let to_union, ta = TypeAssignment.merge_skema l2 l1 in - List.iter (Union_find.union_c uf) to_union; - Some ta) t1 t2 + List.iter (fun (k,v) -> res := 
Union_find.union !res k v |> snd) to_union; (* union into the accumulated !res, not the initial uf: otherwise only the last union survives *) + Some ta) t1 t2 in + !res, t let merge_types t1 t2 = F.Map.union (fun _ l1 l2 -> Some (TypeList.merge l1 l2)) t1 t2 @@ -1278,14 +1281,16 @@ module Flatten : sig let merge_type_abbrevs m1 m2 = m1 @ m2 let merge_checked_type_abbrevs uf m1 m2 = - F.Map.union (fun k ((id1,sk),otherloc as x) ((id2,ty),loc) -> + let uf = ref uf in + let m = F.Map.union (fun k ((id1,sk),otherloc as x) ((id2,ty),loc) -> if TypeAssignment.compare_skema sk ty <> 0 then error ~loc ("Duplicate type abbreviation for " ^ F.show k ^ ". Previous declaration: " ^ Loc.show otherloc) else - Union_find.union_c uf (id1, id2); - Some x) m1 m2 + uf := Union_find.union !uf id1 id2 |> snd; + Some x) m1 m2 in + !uf, m let merge_toplevel_macros otlm toplevel_macros = F.Map.union (fun k (m1,l1) (m2,l2) -> @@ -1341,7 +1346,7 @@ module Flatten : sig let run state { Scoped.pbody; toplevel_macros } = let kinds, types, type_abbrevs, modes, clauses_rev, chr_rev = compile_body pbody in - let signature = { Flat.kinds; types; type_abbrevs; modes; toplevel_macros } in + let signature = { Flat.kinds; types; type_abbrevs; modes; toplevel_macros; type_uf = Union_find.empty } in { Flat.clauses = List.(flatten (rev clauses_rev)); chr = List.rev chr_rev; builtins = []; signature } (* TODO builtins can be in a unit *) @@ -1358,10 +1363,11 @@ end = struct let check_signature builtins symbols (base_signature : Assembled.signature) (signature : Flat.unchecked_signature) : Assembled.signature * Assembled.signature * float * _= let { Assembled.modes = om; functional_preds = ofp; kinds = ok; types = ot; type_abbrevs = ota; toplevel_macros = otlm; type_uf = otuf } = base_signature in - let all_tyuf_opened = Union_find.do_open otuf in - let local_tyuf_opened = Union_find.do_open otuf in - - let { Flat.modes; kinds; types; type_abbrevs; toplevel_macros } = signature in + let { Flat.modes; kinds; types; type_abbrevs; toplevel_macros; type_uf } = signature in + + let type_uf = ref type_uf in + 
let otuf = ref otuf in + let all_kinds = Flatten.merge_kinds ok kinds in let func_setter_object = new Determinacy_checker.merger ofp in let check_k_begin = Unix.gettimeofday () in @@ -1378,8 +1384,8 @@ end = struct ". Previous declaration: " ^ Loc.show otherloc) end else func_setter_object#add_ty_abbr id scoped_ty; - Union_find.add all_tyuf_opened id; - Union_find.add local_tyuf_opened id; + otuf := Union_find.add !otuf id; + type_uf := Union_find.add !type_uf id; F.Map.add name ((id, ty),loc) all_type_abbrevs, F.Map.add name ((id,ty),loc) type_abbrevs) (ota,F.Map.empty) type_abbrevs in let check_k_end = Unix.gettimeofday () in @@ -1394,8 +1400,8 @@ end = struct let types = F.Map.mapi (fun name e -> let tys = Type_checker.check_types ~type_abbrevs:all_type_abbrevs ~kinds:all_kinds e in let ids = get_ids tys in - List.iter (Union_find.add all_tyuf_opened) ids; - List.iter (Union_find.add local_tyuf_opened) ids; + List.iter (fun e -> otuf := Union_find.add !otuf e) ids; + List.iter (fun e -> type_uf := Union_find.add !type_uf e) ids; func_setter_object#add_func_ty_list e ids; tys) types in @@ -1409,15 +1415,13 @@ end = struct let check_t_end = Unix.gettimeofday () in - let all_type_uf = Union_find.do_close all_tyuf_opened in - let all_types = Flatten.merge_type_assignments all_type_uf ot types in + let otuf, all_types = Flatten.merge_type_assignments !otuf ot types in let all_toplevel_macros = Flatten.merge_toplevel_macros otlm toplevel_macros in let all_modes = Flatten.merge_modes om modes in let all_functional_preds = func_setter_object#merge in - let type_uf = Union_find.do_close local_tyuf_opened in - { Assembled.modes; functional_preds = func_setter_object#get_local_func; kinds; types; type_abbrevs; toplevel_macros; type_uf }, - { Assembled.modes = all_modes; functional_preds = all_functional_preds; kinds = all_kinds; types = all_types; type_abbrevs = all_type_abbrevs; toplevel_macros = all_toplevel_macros; type_uf = all_type_uf }, + { Assembled.modes; 
functional_preds = func_setter_object#get_local_func; kinds; types; type_abbrevs; toplevel_macros; type_uf= !type_uf }, + { Assembled.modes = all_modes; functional_preds = all_functional_preds; kinds = all_kinds; types = all_types; type_abbrevs = all_type_abbrevs; toplevel_macros = all_toplevel_macros; type_uf = otuf }, check_t_end -. check_t_begin +. check_k_end -. check_k_begin, types_indexing @@ -1430,12 +1434,15 @@ end = struct let check_begin = Unix.gettimeofday () in + Format.eprintf "Type uf is %a@." Union_find.pp type_uf; + let unknown, clauses = List.fold_left (fun (unknown,clauses) ({ Ast.Clause.body; loc; needs_spilling; attributes = { Ast.Structured.typecheck } } as clause) -> let unknown = if typecheck then Type_checker.check ~is_rule:true ~unknown ~type_abbrevs ~kinds ~types body ~exp:(Val Prop) else unknown in + if String.starts_with ~prefix:"File \"<" (Loc.show loc) then Format.eprintf "The clause is %a@." ScopedTerm.pp body; let spilled = {clause with body = if needs_spilling then Spilling.main body else body; needs_spilling = false} in - if typecheck then Determinacy_checker.check_clause ~loc ~env:functional_preds spilled.body ~modes; + if typecheck then Determinacy_checker.check_clause ~uf:type_uf ~loc ~env:functional_preds spilled.body ~modes; unknown, spilled :: clauses) (F.Map.empty,[]) clauses in let clauses = List.rev clauses in @@ -1453,13 +1460,13 @@ end = struct ) builtins; let more_types = Type_checker.check_undeclared ~unknown in - let u_types = Flatten.merge_type_assignments type_uf signature.Assembled.types more_types in - let types = Flatten.merge_type_assignments type_uf types more_types in + let type_uf, u_types = Flatten.merge_type_assignments type_uf signature.Assembled.types more_types in + let type_uf, types = Flatten.merge_type_assignments type_uf types more_types in let check_end = Unix.gettimeofday () in - let signature = { signature with Assembled.types = u_types } in - let precomputed_signature = { precomputed_signature 
with Assembled.types } in + let signature = { signature with Assembled.types = u_types; type_uf } in + let precomputed_signature = { precomputed_signature with Assembled.types; type_uf } in let checked_code = { CheckedFlat.signature; clauses; chr; builtins; types_indexing } in @@ -1770,10 +1777,11 @@ let extend1_signature base_signature (signature : checked_compilation_unit_signa let { Assembled.toplevel_macros; kinds; types; type_abbrevs; modes; functional_preds; type_uf } = signature in let kinds = Flatten.merge_kinds ok kinds in let type_uf = Union_find.merge otyuf type_uf in - let type_abbrevs = Flatten.merge_checked_type_abbrevs type_uf ota type_abbrevs in - let types = Flatten.merge_type_assignments type_uf ot types in + let type_uf, type_abbrevs = Flatten.merge_checked_type_abbrevs type_uf ota type_abbrevs in + let type_uf, types = Flatten.merge_type_assignments type_uf ot types in let modes = Flatten.merge_modes om modes in let toplevel_macros = Flatten.merge_toplevel_macros otlm toplevel_macros in + Format.eprintf "Merged type uf is %a@." 
Union_find.pp type_uf; { Assembled.kinds; types; type_abbrevs; functional_preds; modes; toplevel_macros; type_uf } diff --git a/src/compiler/determinacy_checker.ml b/src/compiler/determinacy_checker.ml index 258923628..7cdfea15f 100644 --- a/src/compiler/determinacy_checker.ml +++ b/src/compiler/determinacy_checker.ml @@ -5,6 +5,7 @@ open Elpi_util.Util open Elpi_parser.Ast open Compiler_data module C = Constants +module Union_find = Union_find.UF2(IdPos.Map) let to_print f = if false then f () @@ -197,10 +198,13 @@ module Compilation = struct end let merge = Compilation.merge +let remove t k = {t with cmap = IdPos.Map.remove k t.cmap} -let get_functionality ?tyag ~loc ~env k = - if k = Scope.dummy_type_decl_id then Any +let get_functionality ~uf ?tyag ~loc ~env k' = + if k' = Scope.dummy_type_decl_id then Any else + let k = Union_find.find uf k' in (* k is what the union-find maps the queried id k' to; the env lookup below uses k *) + if k' <> k then to_print (fun () -> Format.eprintf "Found a father in uf.\n child%a\n father:%a@." IdPos.pp k' IdPos.pp k); match IdPos.Map.find_opt k env.cmap with | None -> ( (* TODO: this is temporary: waiting for unknown type to be added in the @@ -323,7 +327,7 @@ let not_functional_call_error ~loc t = error ~loc (Format.asprintf "Non functional premise call %a\n" ScopedTerm.pretty_ t) module Checker_clause = struct - let check ~modes ~(global : env) tm = + let check ?(uf = Union_find.empty) ~modes ~(global : env) tm = let env = ref Env.empty in let pp_env fmt () : unit = Format.fprintf fmt "Env : %a" Env.pp !env in (* let pp_ctx fmt ctx : unit = Format.fprintf fmt "Ctx : %a" Ctx.pp ctx in *) @@ -377,8 +381,8 @@ module Checker_clause = struct Some (Compilation.TypeAssignment.type_ass_2func ~loc global ty)) | Const (Global { decl_id }, _) -> Some - (match get_functionality ~loc ~tyag:(ty, []) ~env:global decl_id with Relational -> Functional | e -> e) - | App (Global { decl_id }, _, x, xs) -> Some (get_functionality ~loc ~tyag:(ty, x :: xs) ~env:global decl_id) + (match get_functionality ~uf ~loc ~tyag:(ty, []) ~env:global decl_id with 
Relational -> Functional | e -> e) + | App (Global { decl_id }, _, x, xs) -> Some (get_functionality ~uf ~loc ~tyag:(ty, x :: xs) ~env:global decl_id) | App (Bound scope, f, _, _) | Const (Bound scope, f) -> Ctx.get ctx (f, scope) | CData _ -> Some (NoProp []) | Spill _ -> error ~loc "get_funct_of_term of spilling: " @@ -781,6 +785,9 @@ module Checker_clause = struct to_print (fun () -> Format.eprintf "END HEAD CHECK@."); to_print (fun () -> Format.eprintf "The contex_head is %a@." pp_env ()); + + to_print (fun () -> Format.eprintf "Getting the functionality of %a and func hd is %a@." ScopedTerm.pp hd pp_functionality (get_func_hd ctx hd)); (* gated like the other debug traces above *) + Option.iter (fun body -> check_body ctx body (get_func_hd ctx hd)) body; if body <> None then check_head_output ctx hd @@ -796,14 +803,14 @@ let to_check_clause ScopedTerm.{ it; loc } = (* && Re.Str.string_match (Re.Str.regexp ".*test.*") (Loc.show loc) 0 *) (* && Re.Str.string_match (Re.Str.regexp ".*test.*functionality.*") (Loc.show loc) 0 *) -let check_clause ~loc ~env ~modes t = +let check_clause ?uf ~loc ~env ~modes t = if to_check_clause t then ( to_print (fun () -> (* check_clause ~loc ~env F.Map.empty t |> ignore *) Format.eprintf "============== STARTING mode checking %a@." Loc.pp t.loc (* Format.eprintf "Modes are [%a]" (F.Map.pp (fun fmt ((e:mode_aux list),_) -> Format.fprintf fmt "%a" pp_mode e)) (modes); *) (* Format.eprintf "Functional preds are %a@." 
pp env; *)); - Checker_clause.check ~modes ~global:env t) + Checker_clause.check ?uf ~modes ~global:env t) class merger (all_func : env) = object (self) diff --git a/src/compiler/determinacy_checker.mli b/src/compiler/determinacy_checker.mli index 75cb7eda1..685371770 100644 --- a/src/compiler/determinacy_checker.mli +++ b/src/compiler/determinacy_checker.mli @@ -3,20 +3,19 @@ (* ------------------------------------------------------------------------- *) open Compiler_data open Elpi_util.Util +module Union_find : Union_find.UF2S with type key = IdPos.t and type t = IdPos.t IdPos.Map.t -type t [@@ deriving show, ord] +type t [@@deriving show, ord] val empty_fmap : t - -val check_clause : loc:Loc.t -> env:t -> modes:(mode*Loc.t) F.Map.t -> ScopedTerm.t -> unit - +val check_clause : ?uf:Union_find.t -> loc:Loc.t -> env:t -> modes:(mode * Loc.t) F.Map.t -> ScopedTerm.t -> unit val merge : t -> t -> t +val remove : t -> IdPos.t -> t -class merger : t -> - object - method get_all_func : t - method get_local_func : t - method add_ty_abbr : Scope.type_decl_id -> ScopedTypeExpression.t -> unit - method add_func_ty_list : TypeList.t -> IdPos.t list -> unit - method merge : t - end +class merger : t -> object + method get_all_func : t + method get_local_func : t + method add_ty_abbr : Scope.type_decl_id -> ScopedTypeExpression.t -> unit + method add_func_ty_list : TypeList.t -> IdPos.t list -> unit + method merge : t +end diff --git a/src/compiler/union_find.ml b/src/compiler/union_find.ml index 941a4f09f..25f0c279d 100644 --- a/src/compiler/union_find.ml +++ b/src/compiler/union_find.ml @@ -110,11 +110,7 @@ module Make (M : M) = struct cell.parent <- parent; cell - let do_open tbl = - Hashtbl.fold - (fun k v acc -> - Hashtbl.add k (clone v) acc) - tbl (create ()) + let do_open tbl = Hashtbl.fold (fun k v acc -> Hashtbl.add k (clone v) acc) tbl (create ()) let merge tbl1 tbl2 = let tbl1 = do_open tbl1 in @@ -146,3 +142,50 @@ module Make (M : M) = struct let create_close = 
create let union_c = union end + +module type UF2S = sig + include Elpi_util.Util.Show + include Elpi_util.Util.ShowKey + + val empty : t + val add : t -> key -> t + val find : t -> key -> key + val union : t -> key -> key -> key option * t + val merge : t -> t -> t +end + + +module UF2 (M : Elpi_util.Util.Map.S) = struct + type key = M.key [@@deriving show] + type t = key M.t [@@deriving show] + + let empty = M.empty + let add m v1 = M.add v1 v1 m + let rec find m v = match M.find_opt v m with Some p when p <> v -> find m p | _ -> v (* follow parent links up to the root; a key that is absent or maps to itself is a root *) + + let union m j i = + let ri = find m i in + let rj = find m j in + (* ri is put in the same disjoint set of rj and can be removed from other + data structures; returns (Some ri, new map), or (None, m) when i and j already share a root *) + if ri <> rj then (Some ri, M.add ri rj m) else (None, m) + + let merge u1 u2 = + (* all disjoint-set in u1 and u2 should be pairwise disjoint; on a shared key the binding of u1 wins *) + M.union (fun _ a _ -> Some a) u1 u2 + (* M.fold (fun k father acc -> + let acc = if M.mem father acc then assert false else add acc father in + union acc k father + ) u1 u2 *) + + let is_root acc k = find acc k = k + + let pp fmt v = + Format.fprintf fmt "{{\n"; + M.iter (fun k v -> + if k <> v then + Format.fprintf fmt "@[%a -> %a@]\n" M.pp_key k M.pp_key v + ) v; + Format.fprintf fmt "}}@." 
+ let pp_key = M.pp_key +end diff --git a/src/compiler/union_find.mli b/src/compiler/union_find.mli index 161db388f..9e80797d1 100644 --- a/src/compiler/union_find.mli +++ b/src/compiler/union_find.mli @@ -1,3 +1,5 @@ +open Elpi_util.Util + module type M = sig include Hashtbl.HashedType @@ -31,3 +33,16 @@ module Make : functor (M : M) -> sig val union_c : closed -> M.t * M.t -> unit val pp : Format.formatter -> closed -> unit end + +module type UF2S = sig + include Show + include ShowKey + + val empty : t + val add : t -> key -> t + val find : t -> key -> key + val union : t -> key -> key -> key option * t + val merge : t -> t -> t +end + +module UF2 : functor (M : Elpi_util.Util.Map.S) -> UF2S with type key = M.key and type t = M.key M.t diff --git a/src/utils/util.ml b/src/utils/util.ml index be00a6c90..fccb187e3 100644 --- a/src/utils/util.ml +++ b/src/utils/util.ml @@ -8,6 +8,12 @@ module type Show = sig val show : t -> string end +module type ShowKey = sig + type key + val pp_key : Format.formatter -> key -> unit + val show_key : key -> string +end + module type Show1 = sig type 'a t val pp : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit @@ -25,8 +31,7 @@ module Map = struct module type S = sig include Map.S include Show1 with type 'a t := 'a t - val pp_key : Format.formatter -> key -> unit - val show_key : key -> string + include ShowKey with type key := key end module type OrderedType = sig diff --git a/src/utils/util.mli b/src/utils/util.mli index 2d11687c7..3f9ddbac4 100644 --- a/src/utils/util.mli +++ b/src/utils/util.mli @@ -8,6 +8,12 @@ module type Show = sig val show : t -> string end +module type ShowKey = sig + type key + val pp_key : Format.formatter -> key -> unit + val show_key : key -> string +end + module type Show1 = sig type 'a t val pp : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit diff --git a/tests/sources/sepcomp_tyid.ml b/tests/sources/sepcomp_tyid.ml index 13f474289..a165bc092 100644 --- 
a/tests/sources/sepcomp_tyid.ml +++ b/tests/sources/sepcomp_tyid.ml @@ -27,4 +27,21 @@ let () = let q = Compile.query cp (Parse.goal_from ~elpi ~loc:(Ast.Loc.initial "g") (Lexing.from_string "main")) in (* TODO: check that in the determinacy checker map the pred p appears once *) - exec q \ No newline at end of file + exec q + +let () = + let open Sepcomp.Sepcomp_template in + let elpi = init () in + let flags = Compile.default_flags in + let pmain,_ = cc ~elpi ~flags ~base:(Compile.empty_base ~elpi) 1 maine in + let pmain,_ = cc ~elpi ~flags ~base:pmain 2 p1 in + let pmain,_ = cc ~elpi ~flags ~base:pmain 3 p2 in + + (* let cp = Compile.extend ~base:pmain unit1 in + let cp = Compile.extend ~base:cp unit2 in *) + + (* Format.eprintf "The program is @[%a@]@." Compile.pp_program cp; *) + + let q = Compile.query pmain (Parse.goal_from ~elpi ~loc:(Ast.Loc.initial "g") (Lexing.from_string "main")) in + (* TODO: check that in the determinacy checker map the pred p appears once *) + exec q