Skip to content

Commit

Permalink
[union-find] version with persistent map instead of "persistent" Hashtbl
Browse files Browse the repository at this point in the history
- simple union find with Map and primitives union, find, create, merge
- track id removed by merge of typeAssignment
- remove ids from functional_preds accordingly
- update small test on union find accordingly
- update parallel test compilation
  • Loading branch information
FissoreD committed Dec 3, 2024
1 parent fdf3ec1 commit 75009b3
Show file tree
Hide file tree
Showing 9 changed files with 299 additions and 286 deletions.
84 changes: 46 additions & 38 deletions src/compiler/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ module C = Constants

open Compiler_data

module Union_find = Union_find.UF2(IdPos.Map)
module Union_find = Union_find.Make(IdPos.Map)

type macro_declaration = (ScopedTerm.t * Loc.t) F.Map.t
[@@ deriving show, ord]
Expand Down Expand Up @@ -1123,7 +1123,7 @@ module Flatten : sig
Union_find.t ->
TypeAssignment.overloaded_skema_with_id F.Map.t ->
TypeAssignment.overloaded_skema_with_id F.Map.t ->
Union_find.t * TypeAssignment.overloaded_skema_with_id F.Map.t
IdPos.t list * Union_find.t * TypeAssignment.overloaded_skema_with_id F.Map.t
val merge_type_abbrevs :
(F.t * ScopedTypeExpression.t) list ->
(F.t * ScopedTypeExpression.t) list ->
Expand All @@ -1136,7 +1136,7 @@ module Flatten : sig
Union_find.t ->
((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t ->
((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t ->
Union_find.t * ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t
IdPos.t list * Union_find.t * ((IdPos.t *TypeAssignment.skema) * Loc.t) F.Map.t

val merge_toplevel_macros :
(ScopedTerm.t * Loc.t) F.Map.t ->
Expand Down Expand Up @@ -1254,13 +1254,33 @@ module Flatten : sig
List.map (fun (k, v) -> subst_global s k, ScopedTypeExpression.smart_map (subst_global s) v) l

let merge_type_assignments uf t1 t2 =
let res = ref uf in
let uf = ref uf in
let to_remove = ref [] in
(* We give precedence to recent type declarations over old ones *)
let t = F.Map.union (fun f l1 l2 ->
let to_union, ta = TypeAssignment.merge_skema l2 l1 in
List.iter (fun (k,v) -> res := Union_find.union uf k v |> snd) to_union;
List.iter (fun (id1,id2) ->
let rem, uf1 = Union_find.union !uf id1 id2 in
uf := uf1;
Option.iter (fun x -> to_remove := x :: !to_remove) rem;
) to_union;
Some ta) t1 t2 in
!res, t
!to_remove, !uf, t

let merge_checked_type_abbrevs uf m1 m2 =
let uf = ref uf in
let to_remove = ref [] in
let m = F.Map.union (fun k ((id1,sk),otherloc as x) ((id2,ty),loc) ->
if TypeAssignment.compare_skema sk ty <> 0 then
error ~loc
("Duplicate type abbreviation for " ^ F.show k ^
". Previous declaration: " ^ Loc.show otherloc)
else
let rem, uf1 = Union_find.union !uf id1 id2 in
uf := uf1;
Option.iter (fun x -> to_remove := x :: !to_remove) rem;
Some x) m1 m2 in
!to_remove, !uf, m

let merge_types t1 t2 =
F.Map.union (fun _ l1 l2 -> Some (TypeList.merge l1 l2)) t1 t2
Expand All @@ -1280,18 +1300,6 @@ module Flatten : sig

let merge_type_abbrevs m1 m2 = m1 @ m2

let merge_checked_type_abbrevs uf m1 m2 =
let uf = ref uf in
let m = F.Map.union (fun k ((id1,sk),otherloc as x) ((id2,ty),loc) ->
if TypeAssignment.compare_skema sk ty <> 0 then
error ~loc
("Duplicate type abbreviation for " ^ F.show k ^
". Previous declaration: " ^ Loc.show otherloc)
else
uf := Union_find.union !uf id1 id2 |> snd;
Some x) m1 m2 in
!uf, m

let merge_toplevel_macros otlm toplevel_macros =
F.Map.union (fun k (m1,l1) (m2,l2) ->
if ScopedTerm.equal ~types:false m1 m2 then Some (m1,l1) else
Expand Down Expand Up @@ -1365,9 +1373,6 @@ end = struct

let { Flat.modes; kinds; types; type_abbrevs; toplevel_macros; type_uf } = signature in

let type_uf = ref type_uf in
let otuf = ref otuf in

let all_kinds = Flatten.merge_kinds ok kinds in
let func_setter_object = new Determinacy_checker.merger ofp in
let check_k_begin = Unix.gettimeofday () in
Expand All @@ -1384,8 +1389,6 @@ end = struct
". Previous declaration: " ^ Loc.show otherloc)
end
else func_setter_object#add_ty_abbr id scoped_ty;
otuf := Union_find.add !otuf id;
type_uf := Union_find.add !type_uf id;
F.Map.add name ((id, ty),loc) all_type_abbrevs, F.Map.add name ((id,ty),loc) type_abbrevs)
(ota,F.Map.empty) type_abbrevs in
let check_k_end = Unix.gettimeofday () in
Expand All @@ -1400,8 +1403,6 @@ end = struct
let types = F.Map.mapi (fun name e ->
let tys = Type_checker.check_types ~type_abbrevs:all_type_abbrevs ~kinds:all_kinds e in
let ids = get_ids tys in
List.iter (fun e -> otuf := Union_find.add !otuf e) ids;
List.iter (fun e -> type_uf := Union_find.add !type_uf e) ids;
func_setter_object#add_func_ty_list e ids;
tys) types in

Expand All @@ -1415,13 +1416,15 @@ end = struct

let check_t_end = Unix.gettimeofday () in

let otuf, all_types = Flatten.merge_type_assignments !otuf ot types in
let all_type_uf = Union_find.merge otuf type_uf in
let to_remove, all_type_uf, all_types = Flatten.merge_type_assignments all_type_uf ot types in
let all_toplevel_macros = Flatten.merge_toplevel_macros otlm toplevel_macros in
let all_modes = Flatten.merge_modes om modes in
let all_functional_preds = func_setter_object#merge in
let all_functional_preds = List.fold_left Determinacy_checker.remove all_functional_preds to_remove in

{ Assembled.modes; functional_preds = func_setter_object#get_local_func; kinds; types; type_abbrevs; toplevel_macros; type_uf= !type_uf },
{ Assembled.modes = all_modes; functional_preds = all_functional_preds; kinds = all_kinds; types = all_types; type_abbrevs = all_type_abbrevs; toplevel_macros = all_toplevel_macros; type_uf = otuf },
{ Assembled.modes; functional_preds = func_setter_object#get_local_func; kinds; types; type_abbrevs; toplevel_macros; type_uf },
{ Assembled.modes = all_modes; functional_preds = all_functional_preds; kinds = all_kinds; types = all_types; type_abbrevs = all_type_abbrevs; toplevel_macros = all_toplevel_macros; type_uf = all_type_uf },
check_t_end -. check_t_begin +. check_k_end -. check_k_begin,
types_indexing

Expand All @@ -1434,13 +1437,11 @@ end = struct

let check_begin = Unix.gettimeofday () in

Format.eprintf "Type uf is %a@." Union_find.pp type_uf;

let unknown, clauses = List.fold_left (fun (unknown,clauses) ({ Ast.Clause.body; loc; needs_spilling; attributes = { Ast.Structured.typecheck } } as clause) ->
let unknown =
if typecheck then Type_checker.check ~is_rule:true ~unknown ~type_abbrevs ~kinds ~types body ~exp:(Val Prop)
else unknown in
if String.starts_with ~prefix:"File \"<" (Loc.show loc) then Format.eprintf "The clause is %a@." ScopedTerm.pp body;
(* if String.starts_with ~prefix:"File \"<" (Loc.show loc) then Format.eprintf "The clause is %a@." ScopedTerm.pp body; *)
let spilled = {clause with body = if needs_spilling then Spilling.main body else body; needs_spilling = false} in
if typecheck then Determinacy_checker.check_clause ~uf:type_uf ~loc ~env:functional_preds spilled.body ~modes;
unknown, spilled :: clauses) (F.Map.empty,[]) clauses in
Expand All @@ -1460,13 +1461,15 @@ end = struct
) builtins;

let more_types = Type_checker.check_undeclared ~unknown in
let type_uf, u_types = Flatten.merge_type_assignments type_uf signature.Assembled.types more_types in
let type_uf, types = Flatten.merge_type_assignments type_uf types more_types in
let _, _, u_types = Flatten.merge_type_assignments type_uf signature.types more_types in
let _, _, types = Flatten.merge_type_assignments type_uf types more_types in

(* TODO: forall i in toremove @@ toremove1, remove i from functional preds *)

let check_end = Unix.gettimeofday () in

let signature = { signature with Assembled.types = u_types; type_uf } in
let precomputed_signature = { precomputed_signature with Assembled.types; type_uf } in
let signature = { signature with types = u_types; type_uf } in
let precomputed_signature = { precomputed_signature with types; type_uf } in

let checked_code = { CheckedFlat.signature; clauses; chr; builtins; types_indexing } in

Expand Down Expand Up @@ -1777,11 +1780,16 @@ let extend1_signature base_signature (signature : checked_compilation_unit_signa
let { Assembled.toplevel_macros; kinds; types; type_abbrevs; modes; functional_preds; type_uf } = signature in
let kinds = Flatten.merge_kinds ok kinds in
let type_uf = Union_find.merge otyuf type_uf in
let type_uf, type_abbrevs = Flatten.merge_checked_type_abbrevs type_uf ota type_abbrevs in
let type_uf, types = Flatten.merge_type_assignments type_uf ot types in
let to_remove, type_uf, type_abbrevs = Flatten.merge_checked_type_abbrevs type_uf ota type_abbrevs in
let to_remove1, type_uf, types = Flatten.merge_type_assignments type_uf ot types in
let modes = Flatten.merge_modes om modes in
let toplevel_macros = Flatten.merge_toplevel_macros otlm toplevel_macros in
Format.eprintf "Merged type uf is %a@." Union_find.pp type_uf;
let functional_preds =
let fp = Determinacy_checker.merge ofp functional_preds in
let fp = List.fold_left Determinacy_checker.remove fp to_remove in
let fp = List.fold_left Determinacy_checker.remove fp to_remove1 in
fp
in

{ Assembled.kinds; types; type_abbrevs; functional_preds; modes; toplevel_macros; type_uf }

Expand Down
68 changes: 33 additions & 35 deletions src/compiler/determinacy_checker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ open Elpi_util.Util
open Elpi_parser.Ast
open Compiler_data
module C = Constants
module Union_find = Union_find.UF2(IdPos.Map)
module Union_find = Union_find.Make (IdPos.Map)

let to_print f = if false then f ()

Expand All @@ -17,7 +17,6 @@ type functionality =
| NoProp of functionality list (** -> for kinds like list, int, string *)
| BoundVar of F.t (** -> in predicates like: std.exists or in parametric type abbreviations. *)
| AssumedFunctional (** -> variadic predicates: never backtrack *)

(* pred p i:int *)
(* Arrow (NoProp[]) (Relation) : NoProp -> Relation *)

Expand Down Expand Up @@ -198,21 +197,22 @@ module Compilation = struct
end

let merge = Compilation.merge
let remove t k = {t with cmap = IdPos.Map.remove k t.cmap}
let remove t k = { t with cmap = IdPos.Map.remove k t.cmap }

let get_functionality ~uf ?tyag ~loc ~env k' =
if k' = Scope.dummy_type_decl_id then Any
let get_functionality ~uf ?tyag ~loc ~env id =
if id = Scope.dummy_type_decl_id then Any
else
let k = Union_find.find uf k' in
if k' <> k then Format.eprintf "Found a father in uf.\n child%a\n father:%a@." IdPos.pp k IdPos.pp k';
match IdPos.Map.find_opt k env.cmap with
let id' = Union_find.find uf id in
if id <> id' then assert (not (IdPos.Map.mem id env.cmap));
(* Sanity check *)
match IdPos.Map.find_opt id' env.cmap with
| None -> (
(* TODO: this is temporary: waiting for unknown type to be added in the
type db After that change, the match becomes useless and ~tyag can
be removed from the parameters of get_functionality
*)
match tyag with
| None -> error ~loc (Format.asprintf "Cannot find functionality of id %a\n%!" IdPos.pp k)
| None -> error ~loc (Format.asprintf "Cannot find functionality of id %a\n%!" IdPos.pp id')
| Some (ty, ag) -> Compilation.TypeAssignment.type_ass_2func ~ag ~loc env ty)
| Some (name, func) -> if F.equal F.pif name || F.equal F.sigmaf name then functionality_pi_sigma else func

Expand All @@ -227,21 +227,22 @@ let rec all_relational = function

let ( <<= ) ~loc a b =
let rec aux ~loc a b =
match (a, b) with
| BoundVar _, _ | _, BoundVar _ -> true (* TODO: this is not correct... -> use ref with uvar to make unification *)
| NoProp _, x -> aux Any x ~loc
| x, NoProp _ -> aux x Any ~loc
| _, Any -> true
| Any, _ -> all_relational b
| _, Relational -> true
| Relational, _ -> false
| Functional, Functional -> true
| AssumedFunctional, _ -> true
| _, AssumedFunctional -> error ~loc (Format.asprintf "Cannot compare AssumedFunctional with different functionality")
| Arrow (l1, r1), Arrow (l2, r2) -> aux l2 l1 ~loc && aux r1 r2 ~loc
| Arrow _, _ | _, Arrow _ ->
error ~loc (Format.asprintf "Type error1 in comparing %a and %a" pp_functionality a pp_functionality b)
(* | NoProp _, _ | _, NoProp _ -> error ~loc "Type error2" *)
match (a, b) with
| BoundVar _, _ | _, BoundVar _ -> true (* TODO: this is not correct... -> use ref with uvar to make unification *)
| NoProp _, x -> aux Any x ~loc
| x, NoProp _ -> aux x Any ~loc
| _, Any -> true
| Any, _ -> all_relational b
| _, Relational -> true
| Relational, _ -> false
| Functional, Functional -> true
| AssumedFunctional, _ -> true
| _, AssumedFunctional ->
error ~loc (Format.asprintf "Cannot compare AssumedFunctional with different functionality")
| Arrow (l1, r1), Arrow (l2, r2) -> aux l2 l1 ~loc && aux r1 r2 ~loc
| Arrow _, _ | _, Arrow _ ->
error ~loc (Format.asprintf "Type error1 in comparing %a and %a" pp_functionality a pp_functionality b)
(* | NoProp _, _ | _, NoProp _ -> error ~loc "Type error2" *)
in
let res = aux a b ~loc in
to_print (fun () -> Format.eprintf "%a <= %a = %b@." pp_functionality a pp_functionality b res);
Expand Down Expand Up @@ -277,7 +278,6 @@ let cmp ~loc f1 f2 =
| false, false ->
error ~loc (Format.asprintf "Functionality %a and %a are not comparable" pp_functionality f1 pp_functionality f2)


(* R e adesso so che è F -> F *)
(* p X, q X *)

Expand Down Expand Up @@ -381,7 +381,9 @@ module Checker_clause = struct
Some (Compilation.TypeAssignment.type_ass_2func ~loc global ty))
| Const (Global { decl_id }, _) ->
Some
(match get_functionality ~uf ~loc ~tyag:(ty, []) ~env:global decl_id with Relational -> Functional | e -> e)
(match get_functionality ~uf ~loc ~tyag:(ty, []) ~env:global decl_id with
| Relational -> Functional
| e -> e)
| App (Global { decl_id }, _, x, xs) -> Some (get_functionality ~uf ~loc ~tyag:(ty, x :: xs) ~env:global decl_id)
| App (Bound scope, f, _, _) | Const (Bound scope, f) -> Ctx.get ctx (f, scope)
| CData _ -> Some (NoProp [])
Expand Down Expand Up @@ -582,7 +584,7 @@ module Checker_clause = struct
let v = get_funct_of_term ctx t |> Option.get in
add_env ~loc:t.loc n ~v;
r
| _ -> if ((infer ctx t) <<= l) ~loc:t.loc then r else Any)
| _ -> if (infer ctx t <<= l) ~loc:t.loc then r else Any)
and infer_outputs_fail ctx =
fold_on_modes
(fun _ _ r -> r)
Expand Down Expand Up @@ -786,8 +788,7 @@ module Checker_clause = struct

to_print (fun () -> Format.eprintf "The contex_head is %a@." pp_env ());

Format.eprintf "Getting the functionality of %a and func hd is %a@." ScopedTerm.pp hd pp_functionality (get_func_hd ctx hd);

(* Format.eprintf "Getting the functionality of %a and func hd is %a@." ScopedTerm.pp hd pp_functionality (get_func_hd ctx hd); *)
Option.iter (fun body -> check_body ctx body (get_func_hd ctx hd)) body;

if body <> None then check_head_output ctx hd
Expand All @@ -799,8 +800,8 @@ end

let to_check_clause ScopedTerm.{ it; loc } =
let n = get_namef it in
(not (F.equal n F.mainf))
(* && Re.Str.string_match (Re.Str.regexp ".*test.*") (Loc.show loc) 0 *)
not (F.equal n F.mainf)
(* && Re.Str.string_match (Re.Str.regexp ".*test.*") (Loc.show loc) 0 *)
(* && Re.Str.string_match (Re.Str.regexp ".*test.*functionality.*") (Loc.show loc) 0 *)

let check_clause ?uf ~loc ~env ~modes t =
Expand All @@ -827,10 +828,7 @@ class merger (all_func : env) =
method get_all_func = all_func
method get_local_func = local_func
method add_ty_abbr = self#add_func true

method add_func_ty_list ty (id_list : IdPos.t list) =
List.iter2 (self#add_func false) id_list ty

method add_func_ty_list ty id_list = List.iter2 (self#add_func false) id_list ty
method merge : env = merge all_func local_func
end

Expand Down
2 changes: 1 addition & 1 deletion src/compiler/determinacy_checker.mli
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
(* ------------------------------------------------------------------------- *)
open Compiler_data
open Elpi_util.Util
module Union_find : Union_find.UF2S with type key = IdPos.t and type t = IdPos.t IdPos.Map.t
module Union_find : Union_find.S with type key = IdPos.t and type t = IdPos.t IdPos.Map.t

type t [@@deriving show, ord]

Expand Down
40 changes: 16 additions & 24 deletions src/compiler/test_union_find.ml
Original file line number Diff line number Diff line change
@@ -1,36 +1,28 @@
open Elpi_compiler.Union_find

module M = Make (struct
include Int

let hash x = x
let pp _ _ = ()
end)
module M = Make (Elpi_util.Util.IntMap)

open M

let _ =
(* Partition avec 9 classes (qui sont des singletons) *)
let uf = create () in

for i = 1 to 8 do
add uf i
done;
(* From https://fr.wikipedia.org/wiki/Union-find#/media/Fichier:Dsu_disjoint_sets_final.svg *)
let uf = ref empty in
let update uf (_,act) = uf := act in
let union uf a b = update uf (union !uf a b) in

(* Partition avec 4 classes disjointes obtenue après
Union(1, 2), Union(3, 4), Union(2, 5), Union(1, 6) et Union(2, 8). *)
union uf (1, 2);
union uf (3, 4);
union uf (2, 5);
union uf 1 2;
union uf 3 4;
union uf 2 5;

let uf1 = do_open (do_close uf) in
let uf1 = ref !uf in

union uf (1, 6);
union uf (2, 8);
assert (roots uf |> List.length = 3);
union uf 1 6;
union uf 3 1;
assert (roots !uf |> List.length = 1);
(* The cloned table is not impacted by the modification in uf *)
(* uf should be: https://fr.wikipedia.org/wiki/Union-find#/media/Fichier:Dsu_disjoint_sets_final.svg *)
assert (roots uf1 |> List.length = 5);
union uf1 (1, 6);
assert (roots uf1 |> List.length = 4);
assert (roots uf |> List.length = 3)
assert (roots !uf1 |> List.length = 2);
union uf1 7 8;
assert (roots !uf1 |> List.length = 3);
assert (roots !uf |> List.length = 1)
Loading

0 comments on commit 75009b3

Please sign in to comment.