From 75009b3414c5b163d88b15e8f03646b48e2b40b6 Mon Sep 17 00:00:00 2001 From: Davide Fissore Date: Tue, 3 Dec 2024 19:46:00 +0100 Subject: [PATCH] [union-find] version with persistent map instead of "persistent" Hashtbl - simple union find with Map and primitives union, find, create, merge - track id removed by merge of typeAssignment - remove ids from functional_preds accordingly - update small test on union find accordingly - update parallel test compilation --- src/compiler/compiler.ml | 84 +++---- src/compiler/determinacy_checker.ml | 68 +++--- src/compiler/determinacy_checker.mli | 2 +- src/compiler/test_union_find.ml | 40 ++-- src/compiler/union_find.ml | 328 ++++++++++++++------------- src/compiler/union_find.mli | 13 +- tests/sources/dune | 3 +- tests/sources/sepcomp_tyid.ml | 22 +- tests/sources/sepcomp_tyid2.ml | 25 ++ 9 files changed, 299 insertions(+), 286 deletions(-) create mode 100644 tests/sources/sepcomp_tyid2.ml diff --git a/src/compiler/compiler.ml b/src/compiler/compiler.ml index 4315f85a..2792e412 100644 --- a/src/compiler/compiler.ml +++ b/src/compiler/compiler.ml @@ -211,7 +211,7 @@ module C = Constants open Compiler_data -module Union_find = Union_find.UF2(IdPos.Map) +module Union_find = Union_find.Make(IdPos.Map) type macro_declaration = (ScopedTerm.t * Loc.t) F.Map.t [@@ deriving show, ord] @@ -1123,7 +1123,7 @@ module Flatten : sig Union_find.t -> TypeAssignment.overloaded_skema_with_id F.Map.t -> TypeAssignment.overloaded_skema_with_id F.Map.t -> - Union_find.t * TypeAssignment.overloaded_skema_with_id F.Map.t + IdPos.t list * Union_find.t * TypeAssignment.overloaded_skema_with_id F.Map.t val merge_type_abbrevs : (F.t * ScopedTypeExpression.t) list -> (F.t * ScopedTypeExpression.t) list -> @@ -1136,7 +1136,7 @@ module Flatten : sig Union_find.t -> ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t -> ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t -> - Union_find.t * ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t + IdPos.t list * Union_find.t * ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t val merge_toplevel_macros : (ScopedTerm.t * Loc.t) F.Map.t -> @@ -1254,13 +1254,33 @@ module Flatten : sig List.map (fun (k, v) -> subst_global s k, ScopedTypeExpression.smart_map (subst_global s) v) l let merge_type_assignments uf t1 t2 = - let res = ref uf in + let uf = ref uf in + let to_remove = ref [] in (* We give precedence to recent type declarations over old ones *) let t = F.Map.union (fun f l1 l2 -> let to_union, ta = TypeAssignment.merge_skema l2 l1 in - List.iter (fun (k,v) -> res := Union_find.union uf k v |> snd) to_union; + List.iter (fun (id1,id2) -> + let rem, uf1 = Union_find.union !uf id1 id2 in + uf := uf1; + Option.iter (fun x -> to_remove := x :: !to_remove) rem; + ) to_union; Some ta) t1 t2 in - !res, t + !to_remove, !uf, t + + let merge_checked_type_abbrevs uf m1 m2 = + let uf = ref uf in + let to_remove = ref [] in + let m = F.Map.union (fun k ((id1,sk),otherloc as x) ((id2,ty),loc) -> + if TypeAssignment.compare_skema sk ty <> 0 then + error ~loc + ("Duplicate type abbreviation for " ^ F.show k ^ + ". Previous declaration: " ^ Loc.show otherloc) + else + let rem, uf1 = Union_find.union !uf id1 id2 in + uf := uf1; + Option.iter (fun x -> to_remove := x :: !to_remove) rem; + Some x) m1 m2 in + !to_remove, !uf, m let merge_types t1 t2 = F.Map.union (fun _ l1 l2 -> Some (TypeList.merge l1 l2)) t1 t2 @@ -1280,18 +1300,6 @@ module Flatten : sig let merge_type_abbrevs m1 m2 = m1 @ m2 - let merge_checked_type_abbrevs uf m1 m2 = - let uf = ref uf in - let m = F.Map.union (fun k ((id1,sk),otherloc as x) ((id2,ty),loc) -> - if TypeAssignment.compare_skema sk ty <> 0 then - error ~loc - ("Duplicate type abbreviation for " ^ F.show k ^ - ". Previous declaration: " ^ Loc.show otherloc) - else - uf := Union_find.union !uf id1 id2 |> snd; - Some x) m1 m2 in - !uf, m - let merge_toplevel_macros otlm toplevel_macros = F.Map.union (fun k (m1,l1) (m2,l2) -> if ScopedTerm.equal ~types:false m1 m2 then Some (m1,l1) else @@ -1365,9 +1373,6 @@ end = struct let { Flat.modes; kinds; types; type_abbrevs; toplevel_macros; type_uf } = signature in - let type_uf = ref type_uf in - let otuf = ref otuf in - let all_kinds = Flatten.merge_kinds ok kinds in let func_setter_object = new Determinacy_checker.merger ofp in let check_k_begin = Unix.gettimeofday () in @@ -1384,8 +1389,6 @@ end = struct ". Previous declaration: " ^ Loc.show otherloc) end else func_setter_object#add_ty_abbr id scoped_ty; - otuf := Union_find.add !otuf id; - type_uf := Union_find.add !type_uf id; F.Map.add name ((id, ty),loc) all_type_abbrevs, F.Map.add name ((id,ty),loc) type_abbrevs) (ota,F.Map.empty) type_abbrevs in let check_k_end = Unix.gettimeofday () in @@ -1400,8 +1403,6 @@ end = struct let types = F.Map.mapi (fun name e -> let tys = Type_checker.check_types ~type_abbrevs:all_type_abbrevs ~kinds:all_kinds e in let ids = get_ids tys in - List.iter (fun e -> otuf := Union_find.add !otuf e) ids; - List.iter (fun e -> type_uf := Union_find.add !type_uf e) ids; func_setter_object#add_func_ty_list e ids; tys) types in @@ -1415,13 +1416,15 @@ end = struct let check_t_end = Unix.gettimeofday () in - let otuf, all_types = Flatten.merge_type_assignments !otuf ot types in + let all_type_uf = Union_find.merge otuf type_uf in + let to_remove, all_type_uf, all_types = Flatten.merge_type_assignments all_type_uf ot types in let all_toplevel_macros = Flatten.merge_toplevel_macros otlm toplevel_macros in let all_modes = Flatten.merge_modes om modes in let all_functional_preds = func_setter_object#merge in + let all_functional_preds = List.fold_left Determinacy_checker.remove all_functional_preds to_remove in - { Assembled.modes; functional_preds = func_setter_object#get_local_func; kinds; types; type_abbrevs; toplevel_macros; type_uf= !type_uf }, - { Assembled.modes = all_modes; functional_preds = all_functional_preds; kinds = all_kinds; types = all_types; type_abbrevs = all_type_abbrevs; toplevel_macros = all_toplevel_macros; type_uf = otuf }, + { Assembled.modes; functional_preds = func_setter_object#get_local_func; kinds; types; type_abbrevs; toplevel_macros; type_uf }, + { Assembled.modes = all_modes; functional_preds = all_functional_preds; kinds = all_kinds; types = all_types; type_abbrevs = all_type_abbrevs; toplevel_macros = all_toplevel_macros; type_uf = all_type_uf }, check_t_end -. check_t_begin +. check_k_end -. check_k_begin, types_indexing @@ -1434,13 +1437,11 @@ end = struct let check_begin = Unix.gettimeofday () in - Format.eprintf "Type uf is %a@." Union_find.pp type_uf; - let unknown, clauses = List.fold_left (fun (unknown,clauses) ({ Ast.Clause.body; loc; needs_spilling; attributes = { Ast.Structured.typecheck } } as clause) -> let unknown = if typecheck then Type_checker.check ~is_rule:true ~unknown ~type_abbrevs ~kinds ~types body ~exp:(Val Prop) else unknown in - if String.starts_with ~prefix:"File \"<" (Loc.show loc) then Format.eprintf "The clause is %a@." ScopedTerm.pp body; + (* if String.starts_with ~prefix:"File \"<" (Loc.show loc) then Format.eprintf "The clause is %a@." ScopedTerm.pp body; *) let spilled = {clause with body = if needs_spilling then Spilling.main body else body; needs_spilling = false} in if typecheck then Determinacy_checker.check_clause ~uf:type_uf ~loc ~env:functional_preds spilled.body ~modes; unknown, spilled :: clauses) (F.Map.empty,[]) clauses in @@ -1460,13 +1461,15 @@ end = struct ) builtins; let more_types = Type_checker.check_undeclared ~unknown in - let type_uf, u_types = Flatten.merge_type_assignments type_uf signature.Assembled.types more_types in - let type_uf, types = Flatten.merge_type_assignments type_uf types more_types in + let _, _, u_types = Flatten.merge_type_assignments type_uf signature.types more_types in + let _, _, types = Flatten.merge_type_assignments type_uf types more_types in + + (* TODO: forall i in toremove @@ toremove1, remove i from functional preds *) let check_end = Unix.gettimeofday () in - let signature = { signature with Assembled.types = u_types; type_uf } in - let precomputed_signature = { precomputed_signature with Assembled.types; type_uf } in + let signature = { signature with types = u_types; type_uf } in + let precomputed_signature = { precomputed_signature with types; type_uf } in let checked_code = { CheckedFlat.signature; clauses; chr; builtins; types_indexing } in @@ -1777,11 +1780,16 @@ let extend1_signature base_signature (signature : checked_compilation_unit_signa let { Assembled.toplevel_macros; kinds; types; type_abbrevs; modes; functional_preds; type_uf } = signature in let kinds = Flatten.merge_kinds ok kinds in let type_uf = Union_find.merge otyuf type_uf in - let type_uf, type_abbrevs = Flatten.merge_checked_type_abbrevs type_uf ota type_abbrevs in - let type_uf, types = Flatten.merge_type_assignments type_uf ot types in + let to_remove, type_uf, type_abbrevs = Flatten.merge_checked_type_abbrevs type_uf ota type_abbrevs in + let to_remove1, type_uf, types = Flatten.merge_type_assignments type_uf ot types in let modes = Flatten.merge_modes om modes in let toplevel_macros = Flatten.merge_toplevel_macros otlm toplevel_macros in - Format.eprintf "Merged type uf is %a@." Union_find.pp type_uf; + let functional_preds = + let fp = Determinacy_checker.merge ofp functional_preds in + let fp = List.fold_left Determinacy_checker.remove fp to_remove in + let fp = List.fold_left Determinacy_checker.remove fp to_remove1 in + fp + in { Assembled.kinds; types; type_abbrevs; functional_preds; modes; toplevel_macros; type_uf } diff --git a/src/compiler/determinacy_checker.ml b/src/compiler/determinacy_checker.ml index 7cdfea15..99b5b26a 100644 --- a/src/compiler/determinacy_checker.ml +++ b/src/compiler/determinacy_checker.ml @@ -5,7 +5,7 @@ open Elpi_util.Util open Elpi_parser.Ast open Compiler_data module C = Constants -module Union_find = Union_find.UF2(IdPos.Map) +module Union_find = Union_find.Make (IdPos.Map) let to_print f = if false then f () @@ -17,7 +17,6 @@ type functionality = | NoProp of functionality list (** -> for kinds like list, int, string *) | BoundVar of F.t (** -> in predicates like: std.exists or in parametric type abbreviations. *) | AssumedFunctional (** -> variadic predicates: never backtrack *) - (* pred p i:int *) (* Arrow (NoProp[]) (Relation) : NoProp -> Relation *) @@ -198,21 +197,22 @@ module Compilation = struct end let merge = Compilation.merge -let remove t k = {t with cmap = IdPos.Map.remove k t.cmap} +let remove t k = { t with cmap = IdPos.Map.remove k t.cmap } -let get_functionality ~uf ?tyag ~loc ~env k' = - if k' = Scope.dummy_type_decl_id then Any +let get_functionality ~uf ?tyag ~loc ~env id = + if id = Scope.dummy_type_decl_id then Any else - let k = Union_find.find uf k' in - if k' <> k then Format.eprintf "Found a father in uf.\n child%a\n father:%a@." IdPos.pp k IdPos.pp k'; - match IdPos.Map.find_opt k env.cmap with + let id' = Union_find.find uf id in + if id <> id' then assert (not (IdPos.Map.mem id env.cmap)); + (* Sanity check *) + match IdPos.Map.find_opt id' env.cmap with | None -> ( (* TODO: this is temporary: waiting for unknown type to be added in the type db After that change, the match becomes useless and ~tyag can be removed from the parameters of get_functionality *) match tyag with - | None -> error ~loc (Format.asprintf "Cannot find functionality of id %a\n%!" IdPos.pp k) + | None -> error ~loc (Format.asprintf "Cannot find functionality of id %a\n%!" IdPos.pp id') | Some (ty, ag) -> Compilation.TypeAssignment.type_ass_2func ~ag ~loc env ty) | Some (name, func) -> if F.equal F.pif name || F.equal F.sigmaf name then functionality_pi_sigma else func @@ -227,21 +227,22 @@ let rec all_relational = function let ( <<= ) ~loc a b = let rec aux ~loc a b = - match (a, b) with - | BoundVar _, _ | _, BoundVar _ -> true (* TODO: this is not correct... -> use ref with uvar to make unification *) - | NoProp _, x -> aux Any x ~loc - | x, NoProp _ -> aux x Any ~loc - | _, Any -> true - | Any, _ -> all_relational b - | _, Relational -> true - | Relational, _ -> false - | Functional, Functional -> true - | AssumedFunctional, _ -> true - | _, AssumedFunctional -> error ~loc (Format.asprintf "Cannot compare AssumedFunctional with different functionality") - | Arrow (l1, r1), Arrow (l2, r2) -> aux l2 l1 ~loc && aux r1 r2 ~loc - | Arrow _, _ | _, Arrow _ -> - error ~loc (Format.asprintf "Type error1 in comparing %a and %a" pp_functionality a pp_functionality b) -(* | NoProp _, _ | _, NoProp _ -> error ~loc "Type error2" *) + match (a, b) with + | BoundVar _, _ | _, BoundVar _ -> true (* TODO: this is not correct... -> use ref with uvar to make unification *) + | NoProp _, x -> aux Any x ~loc + | x, NoProp _ -> aux x Any ~loc + | _, Any -> true + | Any, _ -> all_relational b + | _, Relational -> true + | Relational, _ -> false + | Functional, Functional -> true + | AssumedFunctional, _ -> true + | _, AssumedFunctional -> + error ~loc (Format.asprintf "Cannot compare AssumedFunctional with different functionality") + | Arrow (l1, r1), Arrow (l2, r2) -> aux l2 l1 ~loc && aux r1 r2 ~loc + | Arrow _, _ | _, Arrow _ -> + error ~loc (Format.asprintf "Type error1 in comparing %a and %a" pp_functionality a pp_functionality b) + (* | NoProp _, _ | _, NoProp _ -> error ~loc "Type error2" *) in let res = aux a b ~loc in to_print (fun () -> Format.eprintf "%a <= %a = %b@." pp_functionality a pp_functionality b res); @@ -277,7 +278,6 @@ let cmp ~loc f1 f2 = | false, false -> error ~loc (Format.asprintf "Functionality %a and %a are not comparable" pp_functionality f1 pp_functionality f2) - (* R e adesso so che è F -> F *) (* p X, q X *) @@ -381,7 +381,9 @@ module Checker_clause = struct Some (Compilation.TypeAssignment.type_ass_2func ~loc global ty)) | Const (Global { decl_id }, _) -> Some - (match get_functionality ~uf ~loc ~tyag:(ty, []) ~env:global decl_id with Relational -> Functional | e -> e) + (match get_functionality ~uf ~loc ~tyag:(ty, []) ~env:global decl_id with + | Relational -> Functional + | e -> e) | App (Global { decl_id }, _, x, xs) -> Some (get_functionality ~uf ~loc ~tyag:(ty, x :: xs) ~env:global decl_id) | App (Bound scope, f, _, _) | Const (Bound scope, f) -> Ctx.get ctx (f, scope) | CData _ -> Some (NoProp []) @@ -582,7 +584,7 @@ module Checker_clause = struct let v = get_funct_of_term ctx t |> Option.get in add_env ~loc:t.loc n ~v; r - | _ -> if ((infer ctx t) <<= l) ~loc:t.loc then r else Any) + | _ -> if (infer ctx t <<= l) ~loc:t.loc then r else Any) and infer_outputs_fail ctx = fold_on_modes (fun _ _ r -> r) @@ -786,8 +788,7 @@ module Checker_clause = struct to_print (fun () -> Format.eprintf "The contex_head is %a@." pp_env ()); - Format.eprintf "Getting the functionality of %a and func hd is %a@." ScopedTerm.pp hd pp_functionality (get_func_hd ctx hd); - + (* Format.eprintf "Getting the functionality of %a and func hd is %a@." ScopedTerm.pp hd pp_functionality (get_func_hd ctx hd); *) Option.iter (fun body -> check_body ctx body (get_func_hd ctx hd)) body; if body <> None then check_head_output ctx hd @@ -799,8 +800,8 @@ end let to_check_clause ScopedTerm.{ it; loc } = let n = get_namef it in - (not (F.equal n F.mainf)) - (* && Re.Str.string_match (Re.Str.regexp ".*test.*") (Loc.show loc) 0 *) + not (F.equal n F.mainf) +(* && Re.Str.string_match (Re.Str.regexp ".*test.*") (Loc.show loc) 0 *) (* && Re.Str.string_match (Re.Str.regexp ".*test.*functionality.*") (Loc.show loc) 0 *) let check_clause ?uf ~loc ~env ~modes t = @@ -827,10 +828,7 @@ class merger (all_func : env) = method get_all_func = all_func method get_local_func = local_func method add_ty_abbr = self#add_func true - - method add_func_ty_list ty (id_list : IdPos.t list) = - List.iter2 (self#add_func false) id_list ty - + method add_func_ty_list ty id_list = List.iter2 (self#add_func false) id_list ty method merge : env = merge all_func local_func end diff --git a/src/compiler/determinacy_checker.mli b/src/compiler/determinacy_checker.mli index 68537177..2488e068 100644 --- a/src/compiler/determinacy_checker.mli +++ b/src/compiler/determinacy_checker.mli @@ -3,7 +3,7 @@ (* ------------------------------------------------------------------------- *) open Compiler_data open Elpi_util.Util -module Union_find : Union_find.UF2S with type key = IdPos.t and type t = IdPos.t IdPos.Map.t +module Union_find : Union_find.S with type key = IdPos.t and type t = IdPos.t IdPos.Map.t type t [@@deriving show, ord] diff --git a/src/compiler/test_union_find.ml b/src/compiler/test_union_find.ml index c6abf3dc..cba9dd07 100644 --- a/src/compiler/test_union_find.ml +++ b/src/compiler/test_union_find.ml @@ -1,36 +1,28 @@ open Elpi_compiler.Union_find -module M = Make (struct - include Int - - let hash x = x - let pp _ _ = () -end) +module M = Make (Elpi_util.Util.IntMap) open M let _ = - (* Partition avec 9 classes (qui sont des singletons) *) - let uf = create () in - - for i = 1 to 8 do - add uf i - done; + (* From https://fr.wikipedia.org/wiki/Union-find#/media/Fichier:Dsu_disjoint_sets_final.svg *) + let uf = ref empty in + let update uf (_,act) = uf := act in + let union uf a b = update uf (union !uf a b) in (* Partition avec 4 classes disjointes obtenue après Union(1, 2), Union(3, 4), Union(2, 5), Union(1, 6) et Union(2, 8). *) - union uf (1, 2); - union uf (3, 4); - union uf (2, 5); + union uf 1 2; + union uf 3 4; + union uf 2 5; - let uf1 = do_open (do_close uf) in + let uf1 = ref !uf in - union uf (1, 6); - union uf (2, 8); - assert (roots uf |> List.length = 3); + union uf 1 6; + union uf 3 1; + assert (roots !uf |> List.length = 1); (* The cloned table is not impacted by the modification in uf *) - (* uf should be: https://fr.wikipedia.org/wiki/Union-find#/media/Fichier:Dsu_disjoint_sets_final.svg *) - assert (roots uf1 |> List.length = 5); - union uf1 (1, 6); - assert (roots uf1 |> List.length = 4); - assert (roots uf |> List.length = 3) + assert (roots !uf1 |> List.length = 2); + union uf1 7 8; + assert (roots !uf1 |> List.length = 3); + assert (roots !uf |> List.length = 1) diff --git a/src/compiler/union_find.ml b/src/compiler/union_find.ml index 25f0c279..1d0af4cb 100644 --- a/src/compiler/union_find.ml +++ b/src/compiler/union_find.ml @@ -1,191 +1,199 @@ open Elpi_util -module type M = sig - include Hashtbl.HashedType - - val pp : Format.formatter -> t -> unit - val compare : t -> t -> int -end - -module Make (M : M) = struct - module HT = struct - module Hashtbl = Hashtbl.Make (M) - - type uf = { mutable rank : int; id : M.t; mutable parent : uf } - type opened = uf Hashtbl.t - type closed = opened - - let is_root t = t == t.parent - let roots tbl = Hashtbl.fold (fun _ e acc -> if is_root e then e.id :: acc else acc) tbl [] - let rec to_list t = if is_root t then [ t.id ] else t.id :: to_list t.parent - let create () = Hashtbl.create 2024 - - let add tbl id = - if not (Hashtbl.mem tbl id) then - let rec cell = { rank = 0; id; parent = cell } in - Hashtbl.add tbl id cell - - let find tbl key = - let t = Hashtbl.find tbl key in - let rec aux t = - if is_root t then t - else ( - t.parent <- t.parent.parent; - aux t.parent) - in - aux t - - let union tbl (v1, v2) = - let x = find tbl v1 in - let y = find tbl v2 in - if x == y then () - else - let x, y = if x.rank < y.rank then (y, x) else (x, y) in - y.parent <- x; - if x.rank = y.rank then x.rank <- x.rank + 1 - - let find tbl key = (find tbl key).id - - let rec clone uf = - let cell = { rank = uf.rank; parent = uf.parent; id = uf.id } in - let parent = if is_root uf then cell else clone uf.parent in - cell.parent <- parent; - cell - - let do_open tbl = - Hashtbl.fold - (fun k v acc -> - Hashtbl.replace acc k (clone v); - acc) - tbl (create ()) - - let merge tbl1 tbl2 = - let tbl1 = do_open tbl1 in - Hashtbl.iter (fun k v -> Hashtbl.replace tbl1 k (clone v)) tbl2; - tbl1 - end - - module Map = struct - module Hashtbl = Map.Make (M) - - type uf = { mutable rank : int; id : M.t; mutable parent : uf } - type opened = uf Hashtbl.t - type closed = opened - - let is_root t = t == t.parent - let roots tbl = Hashtbl.fold (fun _ e acc -> if is_root e then e.id :: acc else acc) tbl [] - let rec to_list t = if is_root t then [ t.id ] else t.id :: to_list t.parent - let create () = Hashtbl.empty - - let add tbl id = - if not (Hashtbl.mem id tbl) then - let rec cell = { rank = 0; id; parent = cell } in - Hashtbl.add id cell tbl - else tbl - - let find tbl key = - let t = Hashtbl.find tbl key in - let rec aux t = - if is_root t then t - else ( - t.parent <- t.parent.parent; - aux t.parent) - in - aux t - - let union tbl (v1, v2) = - let x = find tbl v1 in - let y = find tbl v2 in - if x == y then () - else - let x, y = if x.rank < y.rank then (y, x) else (x, y) in - y.parent <- x; - if x.rank = y.rank then x.rank <- x.rank + 1 - - let find tbl key = (find tbl key).id - - let rec clone uf = - let cell = { rank = uf.rank; parent = uf.parent; id = uf.id } in - let parent = if is_root uf then cell else clone uf.parent in - cell.parent <- parent; - cell - - let do_open tbl = Hashtbl.fold (fun k v acc -> Hashtbl.add k (clone v) acc) tbl (create ()) - - let merge tbl1 tbl2 = - let tbl1 = do_open tbl1 in - Hashtbl.fold (fun k v acc -> Hashtbl.add k (clone v) acc) tbl2 tbl1 - end - - include HT - - let do_close a = a - - (* printers *) - let pp_uf fmt t = Format.fprintf fmt "%a" (Util.pplist M.pp ",") (to_list t) - - let pp fmt tbl = - let keys = Hashtbl.to_seq tbl in - let l = List.of_seq keys in - let sorted_keys = List.sort (fun a b -> compare (fst a) (fst b)) l in - let pp_elt fmt (k, v) = - if is_root v then Format.fprintf fmt "@[@[%a@] -> root;" M.pp k - else Format.fprintf fmt "@[@[%a@] -> @[%a@]@];" M.pp k pp_uf v.parent - in - Format.fprintf fmt "{{ %a }}" (Util.pplist pp_elt ",") sorted_keys - - let show t = Format.asprintf "%a" pp t - let pp_closed = pp - let show_closed = show - let pp_opened = pp - let show_opened = show - let create_close = create - let union_c = union -end - -module type UF2S = sig +(* module type M = sig + include Hashtbl.HashedType + + val pp : Format.formatter -> t -> unit + val compare : t -> t -> int + end + + module Make (M : M) = struct + module HT = struct + module Hashtbl = Hashtbl.Make (M) + + type uf = { mutable rank : int; id : M.t; mutable parent : uf } + type opened = uf Hashtbl.t + type closed = opened + + let is_root t = t == t.parent + let roots tbl = Hashtbl.fold (fun _ e acc -> if is_root e then e.id :: acc else acc) tbl [] + let rec to_list t = if is_root t then [ t.id ] else t.id :: to_list t.parent + let create () = Hashtbl.create 2024 + + let add tbl id = + if not (Hashtbl.mem tbl id) then + let rec cell = { rank = 0; id; parent = cell } in + Hashtbl.add tbl id cell + + let find tbl key = + let t = Hashtbl.find tbl key in + let rec aux t = + if is_root t then t + else ( + t.parent <- t.parent.parent; + aux t.parent) + in + aux t + + let union tbl (v1, v2) = + let x = find tbl v1 in + let y = find tbl v2 in + if x == y then () + else + let x, y = if x.rank < y.rank then (y, x) else (x, y) in + y.parent <- x; + if x.rank = y.rank then x.rank <- x.rank + 1 + + let find tbl key = (find tbl key).id + + let rec clone uf = + let cell = { rank = uf.rank; parent = uf.parent; id = uf.id } in + let parent = if is_root uf then cell else clone uf.parent in + cell.parent <- parent; + cell + + let do_open tbl = + Hashtbl.fold + (fun k v acc -> + Hashtbl.replace acc k (clone v); + acc) + tbl (create ()) + + let merge tbl1 tbl2 = + let tbl1 = do_open tbl1 in + Hashtbl.iter (fun k v -> Hashtbl.replace tbl1 k (clone v)) tbl2; + tbl1 + end + + module Map = struct + module Hashtbl = Map.Make (M) + + type uf = { mutable rank : int; id : M.t; mutable parent : uf } + type opened = uf Hashtbl.t + type closed = opened + + let is_root t = t == t.parent + let roots tbl = Hashtbl.fold (fun _ e acc -> if is_root e then e.id :: acc else acc) tbl [] + let rec to_list t = if is_root t then [ t.id ] else t.id :: to_list t.parent + let create () = Hashtbl.empty + + let add tbl id = + if not (Hashtbl.mem id tbl) then + let rec cell = { rank = 0; id; parent = cell } in + Hashtbl.add id cell tbl + else tbl + + let find tbl key = + let t = Hashtbl.find tbl key in + let rec aux t = + if is_root t then t + else ( + t.parent <- t.parent.parent; + aux t.parent) + in + aux t + + let union tbl (v1, v2) = + let x = find tbl v1 in + let y = find tbl v2 in + if x == y then () + else + let x, y = if x.rank < y.rank then (y, x) else (x, y) in + y.parent <- x; + if x.rank = y.rank then x.rank <- x.rank + 1 + + let find tbl key = (find tbl key).id + + let rec clone uf = + let cell = { rank = uf.rank; parent = uf.parent; id = uf.id } in + let parent = if is_root uf then cell else clone uf.parent in + cell.parent <- parent; + cell + + let do_open tbl = Hashtbl.fold (fun k v acc -> Hashtbl.add k (clone v) acc) tbl (create ()) + + let merge tbl1 tbl2 = + let tbl1 = do_open tbl1 in + Hashtbl.fold (fun k v acc -> Hashtbl.add k (clone v) acc) tbl2 tbl1 + end + + include HT + + let do_close a = a + + (* printers *) + let pp_uf fmt t = Format.fprintf fmt "%a" (Util.pplist M.pp ",") (to_list t) + + let pp fmt tbl = + let keys = Hashtbl.to_seq tbl in + let l = List.of_seq keys in + let sorted_keys = List.sort (fun a b -> compare (fst a) (fst b)) l in + let pp_elt fmt (k, v) = + if is_root v then Format.fprintf fmt "@[@[%a@] -> root;" M.pp k + else Format.fprintf fmt "@[@[%a@] -> @[%a@]@];" M.pp k pp_uf v.parent + in + Format.fprintf fmt "{{ %a }}" (Util.pplist pp_elt ",") sorted_keys + + let show t = Format.asprintf "%a" pp t + let pp_closed = pp + let show_closed = show + let pp_opened = pp + let show_opened = show + let create_close = create + let union_c = union + end *) + +module type S = sig include Util.Show include Util.ShowKey val empty : t - val add : t -> key -> t + val is_empty : t -> bool val find : t -> key -> key val union : t -> key -> key -> key option * t val merge : t -> t -> t + val roots : t -> key list end - -module UF2 (M : Elpi_util.Util.Map.S) = struct +module Make (M : Elpi_util.Util.Map.S) : S with type t = M.key M.t and type key = M.key = struct type key = M.key [@@deriving show] type t = key M.t [@@deriving show] let empty = M.empty - let add m v1 = M.add v1 v1 m - let find m v = M.find_opt v m |> Option.value ~default:v - - let union m j i = + let is_empty = ( = ) M.empty + let rec find m v = + match M.find_opt v m with + | None -> v + | Some e -> find m e + + let union m i j = + assert ( i <> j ); let ri = find m i in let rj = find m j in (* r1 is put in the same disjoint set of rj and can be removed from other data structures *) - if ri <> rj then (Some ri, M.add ri rj m) else None, m + if ri <> rj then (Some ri, M.add ri rj m) else (None, m) let merge u1 u2 = (* all disjoint-set in u1 and u2 should be pairwise disjoint *) M.union (fun _ a _ -> Some a) u1 u2 - (* M.fold (fun k father acc -> - let acc = if M.mem father acc then assert false else add acc father in - union acc k father - ) u1 u2 *) + (* M.fold (fun k father acc -> + let acc = if M.mem father acc then assert false else add acc father in + union acc k father + ) u1 u2 *) let is_root acc k = find acc k = k + let roots d = + let roots = ref [] in + let add e = if not (List.mem e !roots) then roots := e :: !roots in + M.iter (fun k v -> add (find d k)) d; + !roots + let pp fmt v = Format.fprintf fmt "{{\n"; - M.iter (fun k v -> - if k <> v then - Format.fprintf fmt "@[%a -> %a@]\n" M.pp_key k M.pp_key v - ) v; + M.iter (fun k v -> if k <> v then Format.fprintf fmt "@[%a -> %a@]\n" M.pp_key k M.pp_key v) v; Format.fprintf fmt "}}@." + let pp_key = M.pp_key end diff --git a/src/compiler/union_find.mli b/src/compiler/union_find.mli index 9e80797d..15441f2a 100644 --- a/src/compiler/union_find.mli +++ b/src/compiler/union_find.mli @@ -1,13 +1,13 @@ open Elpi_util.Util -module type M = sig +(* module type M = sig include Hashtbl.HashedType val pp : Format.formatter -> t -> unit val compare : t -> t -> int end -module Make : functor (M : M) -> sig +module Make1 : functor (M : M) -> sig type opened [@@deriving show] type closed [@@deriving show] @@ -32,17 +32,18 @@ module Make : functor (M : M) -> sig val union_c : closed -> M.t * M.t -> unit val pp : Format.formatter -> closed -> unit -end +end *) -module type UF2S = sig +module type S = sig include Show include ShowKey val empty : t - val add : t -> key -> t + val is_empty : t -> bool val find : t -> key -> key val union : t -> key -> key -> key option * t val merge : t -> t -> t + val roots : t -> key list end -module UF2 : functor (M : Elpi_util.Util.Map.S) -> UF2S with type key = M.key and type t = M.key M.t +module Make (M : Elpi_util.Util.Map.S) : S with type key = M.key and type t = M.key M.t diff --git a/tests/sources/dune b/tests/sources/dune index 18d25315..0aba2917 100644 --- a/tests/sources/dune +++ b/tests/sources/dune @@ -13,4 +13,5 @@ (executable (name sepcomp_perf3) (modules sepcomp_perf3) (libraries sepcomp)) (executable (name sepcomp_perf4) (modules sepcomp_perf4) (libraries sepcomp)) (executable (name sepcomp_perf5) (modules sepcomp_perf5) (libraries sepcomp)) -(executable (name sepcomp_tyid) (modules sepcomp_tyid) (libraries sepcomp)) \ No newline at end of file +(executable (name sepcomp_tyid) (modules sepcomp_tyid) (libraries sepcomp)) +(executable (name sepcomp_tyid2) (modules sepcomp_tyid2) (libraries sepcomp)) \ No newline at end of file diff --git a/tests/sources/sepcomp_tyid.ml b/tests/sources/sepcomp_tyid.ml index a165bc09..01c4d788 100644 --- a/tests/sources/sepcomp_tyid.ml +++ b/tests/sources/sepcomp_tyid.ml @@ -17,31 +17,11 @@ let () = let elpi = init () in let flags = Compile.default_flags in let pmain,_ = cc ~elpi ~flags ~base:(Compile.empty_base ~elpi) 1 maine in - let _,(unit1:Compile.compilation_unit) = cc ~elpi ~flags ~base:(Compile.empty_base ~elpi) 2 p1 in + let _,unit1 = cc ~elpi ~flags ~base:(Compile.empty_base ~elpi) 2 p1 in let _,unit2 = cc ~elpi ~flags ~base:(Compile.empty_base ~elpi) 3 p2 in let cp = Compile.extend ~base:pmain unit1 in let cp = Compile.extend ~base:cp unit2 in - (* Format.eprintf "The program is @[%a@]@." Compile.pp_program cp; *) - let q = Compile.query cp (Parse.goal_from ~elpi ~loc:(Ast.Loc.initial "g") (Lexing.from_string "main")) in - (* TODO: check that in the determinacy checker map the pred p appears once *) - exec q - -let () = - let open Sepcomp.Sepcomp_template in - let elpi = init () in - let flags = Compile.default_flags in - let pmain,_ = cc ~elpi ~flags ~base:(Compile.empty_base ~elpi) 1 maine in - let pmain,_ = cc ~elpi ~flags ~base:pmain 2 p1 in - let pmain,_ = cc ~elpi ~flags ~base:pmain 3 p2 in - - (* let cp = Compile.extend ~base:pmain unit1 in - let cp = Compile.extend ~base:cp unit2 in *) - - (* Format.eprintf "The program is @[%a@]@." Compile.pp_program cp; *) - - let q = Compile.query pmain (Parse.goal_from ~elpi ~loc:(Ast.Loc.initial "g") (Lexing.from_string "main")) in - (* TODO: check that in the determinacy checker map the pred p appears once *) exec q diff --git a/tests/sources/sepcomp_tyid2.ml b/tests/sources/sepcomp_tyid2.ml new file mode 100644 index 00000000..e4a5afa6 --- /dev/null +++ b/tests/sources/sepcomp_tyid2.ml @@ -0,0 +1,25 @@ +let p1 = {| + pred p o:int. + p 1. +|} + +let p2 = {| + pred p o:int. + p 2. +|} + +let maine = "pred p o:int. main :- std.findall (p _) L, print L." + +open Elpi.API + +let () = + let open Sepcomp.Sepcomp_template in + let elpi = init () in + let flags = Compile.default_flags in + let pmain,_ = cc ~elpi ~flags ~base:(Compile.empty_base ~elpi) 1 maine in + let pmain,_ = cc ~elpi ~flags ~base:pmain 2 p1 in + let pmain,_ = cc ~elpi ~flags ~base:pmain 3 p2 in + + let q = Compile.query pmain (Parse.goal_from ~elpi ~loc:(Ast.Loc.initial "g") (Lexing.from_string "main")) in + + exec q