Skip to content

Commit

Permalink
Merge pull request #25 from aspiwack/refactor-cleanup
Browse files Browse the repository at this point in the history
refactor: pull the various types in their own modules
  • Loading branch information
mergify[bot] authored Aug 2, 2019
2 parents 7c2ba10 + 66b1392 commit e37cf34
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 133 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
_build
.merlin
*.install
177 changes: 44 additions & 133 deletions src/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,100 +16,19 @@ Steps:
- Implement example sampler
- Parse programs
*)
type variable = string

type location = variable
type item = variable
type indexed_item = item * int
(* Future plan for item: there can be more than one of an item in the
pool. In which case it will be translated to n variables (where n
is the number of occurrences of that item in the pool), the
corresponding range_constraints will be translated to formulas
which ensure that the n variables are ordered (to avoid generating
many semantically equal shuffles) *)
type range_constraint = {
scrutinee: item;
range: location list }

type 'i atom =
| Reach of location
| Have of 'i * int
| Assign of location * 'i
(* Assign doesn't come from parsing as it is equivalent to a
singleton range_constraint. It appears in translations.*)
let map_atom f = function
| Reach l -> Reach l
| Have (i, n) -> Have (f i, n)
| Assign (l, i) -> Assign (l, f i)
type clause = {
goal: item atom;
requires: item atom list
}

module StringMap = Map.Make(CCString)

type program = {
locations: location list;
pool: item list;
range_constraints: range_constraint list;
range_definitions: location list StringMap.t;
logic: clause list;
goal: item atom
}

let hash_location = CCHash.string
let hash_item = CCHash.string
let hash_indexed_item (i,n) = CCHash.(combine2 (string i) (int n))
let hash_atom hash_item = function
| Reach l -> CCHash.(combine2 (int 0) (hash_location l))
| Have (i, n) -> CCHash.(combine3 (int 1) (hash_item i) (int n))
| Assign (l,i) -> CCHash.(combine3 (int 2) (hash_location l) (hash_item i))

let print_item i = i
let print_indexed_item (i, n) = i ^ "_" ^ (string_of_int n)
let print_atom print_item = function
| Reach l -> "reach: " ^ l
| Have (i,1) -> "have: " ^ print_item i
| Have (i,n) -> "have: " ^ print_item i ^ " *" ^ (string_of_int n)
| Assign(l,i) -> print_item i ^ "" ^ l
let print_timed_atom print_item = let open Provable in function
| Selection a -> print_atom print_item a
| At(a,i) -> print_atom print_item a ^ " @ " ^ string_of_int i
| Action (n, i) -> n ^ " @ " ^ string_of_int i

let pp_clause fmt {goal;requires} =
let pp_atom fmt a = CCString.pp fmt (print_atom print_item a) in
Format.fprintf fmt "%a :- @[<hov>%a@]" pp_atom goal (CCList.pp ~sep:"," pp_atom) requires
let pp_range fmt {scrutinee;range} =
Format.fprintf fmt "%s ∈ {%a}" scrutinee (CCList.pp CCString.pp) range
let pp_program fmt prog =
(* XXX: I'm not printing range_definitions *)
let pp_locations = CCList.pp (fun fmt l -> Format.fprintf fmt "@[<h>%s@]" l) in
let pp_pool = CCList.pp (fun fmt i -> Format.fprintf fmt "@[<h>%s@]" i) in
let pp_ranges = CCList.pp pp_range in
let pp_logic = CCList.pp pp_clause in
let pp_goal fmt g = Format.fprintf fmt "%s." (print_atom print_item g) in
Format.fprintf fmt "@[<v>@[<v 2>[Locations]@,%a@]@,@[<v 2>[Pool]@,%a@]@,@[<v 2>[Ranges]@,%a@]@,@[<v 2>[Logic]@,%a@]@,@[<v 2>[Goal]@,%a@]@]@." pp_locations prog.locations pp_pool prog.pool pp_ranges prog.range_constraints pp_logic prog.logic pp_goal prog.goal


open Types
let (&&&) = MLBDD.dand
let (|||) = MLBDD.dor
let anot = MLBDD.dnot
let (-->) = MLBDD.imply

(* XXX: do I still need the AtomSet? *)
module AtomSet = Set.Make (struct type t = item atom let compare = compare end)
module TimedAtomSet = Set.Make (struct type t = item atom Provable.timed let compare = compare end)
module TimedAtomMap = Map.Make (struct type t = item atom Provable.timed let compare = compare end)
(* XXX: remove?*)
module AtomMap = Map.Make (struct type t = item atom let compare = compare end)
type formula = Item.t Atom.t Provable.timed Formula.t

type formula = item atom Provable.timed Formula.t

module TimedLiteral = Sat.Literal(struct type t = item atom Provable.timed let equal = (=) let hash = Provable.hash (hash_atom hash_item) let pp fmt a = Format.fprintf fmt "%s" (print_timed_atom print_item a) end)
module TimedIndexedLiteral = Sat.Literal(struct type t = indexed_item atom Provable.timed let equal = (=) let hash = Provable.hash (hash_atom hash_indexed_item) let pp fmt a = Format.fprintf fmt "%s" (print_timed_atom print_indexed_item a) end)
module TimedLiteral = Sat.Literal(struct type t = Item.t Atom.t Provable.timed let equal = (=) let hash = Provable.hash (Atom.hash Item.hash) let pp = pp_timed_atom Item.pp end)
module TimedIndexedLiteral = Sat.Literal(struct type t = IndexedItem.t Atom.t Provable.timed let equal = (=) let hash = Provable.hash (Atom.hash IndexedItem.hash) let pp = pp_timed_atom IndexedItem.pp end)
(* XXX: Literal, should provide its comparison function *)
module TimedLiteralMap = Map.Make(struct type t = TimedLiteral.t let compare = compare end)

module LiteralsWithMults = struct
include TimedIndexedLiteral
type u = TimedLiteral.t
Expand All @@ -118,7 +37,7 @@ module LiteralsWithMults = struct
let norm_u = TimedLiteral.norm

let decomp (n,l) =
([n,Provable.map_timed (map_atom (fun i -> (i,0))) l], 1)
([n,Provable.map_timed (Atom.map (fun i -> (i,0))) l], 1)
end
module Mult = Multiplicity.Make(LiteralsWithMults)

Expand Down Expand Up @@ -146,11 +65,11 @@ let compile_formula man var_index (f : Mult.L.t Formula.t) : MLBDD.t =
compile f

let collect_program_atoms (p : program) : AtomSet.t =
let atoms : item atom Seq.t =
let atoms : Item.t Atom.t Seq.t =
List.to_seq p.pool |>
Seq.flat_map (fun i ->
List.to_seq p.locations |>
Seq.flat_map (fun l -> List.to_seq [Assign(l,i)]))
Seq.flat_map (fun l -> List.to_seq [Atom.Assign(l,i)]))
in
AtomSet.of_seq atoms

Expand All @@ -168,20 +87,19 @@ module Solver = Sat.Solver(Mult.L)(Mult.L.Map)
do! If I use the function twice, it will destroy the previous
bdd. *)
let compile_to_bdd (p : program) : (MLBDD.t * Mult.L.t array) =
let bdd_vars =
collect_program_atoms p |> AtomSet.to_seq |> Array.of_seq
in
let assign (i : item) (l : location) : formula =
Formula.var (Provable.Selection (Assign(l,i)))
let assign (i : Item.t) (l : Location.t) : formula =
Formula.var (Provable.Selection (Atom.Assign(l,i)))
in
let clause (c : clause) : item atom Provable.clause =
let open Provable in {
let clause (c : Clause.t) : Item.t Atom.t Provable.clause =
let open Provable in
let open Clause in
{
hyps=c.requires;
concl=c.goal;
name=gen_rule_name ()
}
in
let range (p : program) (r : range_constraint) : formula =
let range (p : program) (r : RangeConstraint.t) : formula =
let at_least = Formula.disj_map (fun l -> assign r.scrutinee l) r.range in
(* XXX: consider generating at_most/only constraint independently
from the range, on _all_ locations. *)
Expand All @@ -205,7 +123,7 @@ let compile_to_bdd (p : program) : (MLBDD.t * Mult.L.t array) =
Formula.(at_least && at_most && only)
in
let ranges_formula = Seq.map (range p) (List.to_seq p.range_constraints) in
let capacity (p : program) (l : location) : formula =
let capacity (p : program) (l : Location.t) : formula =
let distinct_pairs =
List.to_seq p.pool |>
Seq.flat_map (fun i -> Seq.map (fun i' -> (i, i')) (List.to_seq p.pool)) |>
Expand All @@ -224,8 +142,8 @@ let compile_to_bdd (p : program) : (MLBDD.t * Mult.L.t array) =
List.to_seq p.locations |>
Seq.map begin fun l ->
let open Provable in {
hyps = [ Reach l; Assign(l,i) ];
concl = Have (i, 1);
hyps = [ Atom.Reach l; Atom.Assign(l,i) ];
concl = Atom.Have (i, 1);
name = gen_rule_name ()
}
end
Expand Down Expand Up @@ -254,7 +172,7 @@ let compile_to_bdd (p : program) : (MLBDD.t * Mult.L.t array) =
in
let observable =
CCArray.filter (fun a -> match a with Provable.Selection _ -> true | _ -> false) atoms
|> Array.map (Provable.map_timed (map_atom (fun i -> (i, 0))))
|> Array.map (Provable.map_timed (Atom.map (fun i -> (i, 0))))
|> Array.map TimedLiteral.of_atom
|> Array.map (fun a -> Mult.Individual a)
in
Expand All @@ -269,6 +187,7 @@ let compile_to_bdd (p : program) : (MLBDD.t * Mult.L.t array) =
, observable

let femto_example =
let open Clause in
(* A bit of early Alltp logic *)
(* items *)
let sword = "Sword"
Expand All @@ -278,9 +197,9 @@ let femto_example =
in
let pool = [sword] in
let locations = [well; eastern_boss; ] in
let goal = Reach eastern_boss in
let goal = Atom.Reach eastern_boss in
let range_constraints = [
{scrutinee=sword; range=locations};
{RangeConstraint.scrutinee=sword; range=locations};
] in
let logic = [
{goal=Reach eastern_boss; requires=[Have (sword, 1)]};
Expand All @@ -289,6 +208,7 @@ let femto_example =
{ locations; pool; range_constraints; range_definitions=StringMap.empty; logic; goal }

let micro_example =
let open Clause in
(* A bit of early Alltp logic *)
(* items *)
let sword = "Sword"
Expand All @@ -301,10 +221,10 @@ let micro_example =
in
let pool = [sword; bow] in
let locations = [well; hideout; eastern_chest; eastern_boss; ] in
let goal = Reach eastern_boss in
let goal = Atom.Reach eastern_boss in
let range_constraints = [
{scrutinee=sword; range=locations};
{scrutinee=bow; range=locations};
{RangeConstraint.scrutinee=sword; range=locations};
{RangeConstraint.scrutinee=bow; range=locations};
] in
let logic = [
{goal=Reach eastern_boss; requires=[Have (bow, 1); Have (sword, 1)]};
Expand All @@ -315,6 +235,7 @@ let micro_example =
{ locations; pool; range_constraints; range_definitions=StringMap.empty; logic; goal }

let mini_example =
let open Clause in
(* A bit of early Alltp logic *)
(* items *)
let sword = "Sword"
Expand All @@ -330,11 +251,11 @@ let mini_example =
let pool = [sword; bow; eastern_big] in
let locations = [well; hideout; eastern_chest; eastern_big_chest; eastern_boss; ] in
let eastern_palace = [eastern_chest; eastern_big_chest; eastern_boss] in
let goal = Reach eastern_boss in
let goal = Atom.Reach eastern_boss in
let range_constraints = [
{scrutinee=sword; range=locations};
{scrutinee=bow; range=locations};
{scrutinee=eastern_big; range=eastern_palace};
{RangeConstraint.scrutinee=sword; range=locations};
{RangeConstraint.scrutinee=bow; range=locations};
{RangeConstraint.scrutinee=eastern_big; range=eastern_palace};
] in
let logic = [
{goal=Reach eastern_boss; requires=[Have (bow, 1); Have (sword, 1); Have (eastern_big, 1)]};
Expand All @@ -346,6 +267,7 @@ let mini_example =
{ locations; pool; range_constraints; range_definitions=StringMap.empty; logic; goal }

let example =
let open Clause in
(* A bit of early Alltp logic *)
(* items *)
let sword = "Sword"
Expand All @@ -368,14 +290,14 @@ let example =
let locations = [well; hideout; eastern_chest; eastern_big_chest; eastern_boss; desert_torch; desert_big_chest; desert_boss] in
let eastern_palace = [eastern_chest; eastern_big_chest; eastern_boss] in
let desert_palace = [desert_torch; desert_big_chest; desert_boss] in
let goal = Reach desert_boss in
let goal = Atom.Reach desert_boss in
let range_constraints = [
{scrutinee=sword; range=locations};
{scrutinee=bow; range=locations};
{scrutinee=boots; range=locations};
{scrutinee=glove; range=locations};
{scrutinee=eastern_big; range=eastern_palace};
{scrutinee=desert_big; range=desert_palace};
{RangeConstraint.scrutinee=sword; range=locations};
{RangeConstraint.scrutinee=bow; range=locations};
{RangeConstraint.scrutinee=boots; range=locations};
{RangeConstraint.scrutinee=glove; range=locations};
{RangeConstraint.scrutinee=eastern_big; range=eastern_palace};
{RangeConstraint.scrutinee=desert_big; range=desert_palace};
] in
let logic = [
{goal=Reach eastern_boss; requires=[Have (bow, 1); Have (sword, 1); Have (eastern_big, 1)]};
Expand All @@ -396,11 +318,6 @@ let print_to_dot legend b ~file =
let c = open_out file in
let fmt = formatter_of_out_channel c in
fprintf fmt "digraph bdd {@\n";
(* let ranks = Hashtbl.create 17 in (\* var -> set of nodes *\)
* let add_rank v b =
* try Hashtbl.replace ranks v (S.add b (Hashtbl.find ranks v))
* with Not_found -> Hashtbl.add ranks v (S.singleton b)
* in *)
let visited = H1.create 1024 in
let rec visit b =
if not (H1.mem visited b) then begin
Expand All @@ -412,19 +329,12 @@ let print_to_dot legend b ~file =
fprintf fmt "%d [shape=box label=\"1\"];" (MLBDD.id b)
| BIf (l, v, h) ->
(* add_rank v b; *)
fprintf fmt "%d [label=\"x%s\"];" (MLBDD.id b) (print_timed_atom print_item (legend.(v-1)));
fprintf fmt "%d [label=\"x%a\"];" (MLBDD.id b) (pp_timed_atom Item.pp) (legend.(v-1));
fprintf fmt "%d -> %d;@\n" (MLBDD.id b) (MLBDD.id h);
fprintf fmt "%d -> %d [style=\"dashed\"];@\n" (MLBDD.id b) (MLBDD.id l);
visit h; visit l
end
in
(* Hashtbl.iter
* (fun v s ->
* fprintf fmt "{rank=same; ";
* S.iter (fun x -> fprintf fmt "%d " x.tag) s;
* fprintf fmt ";}@\n"
* )
* ranks; *)
visit b;
fprintf fmt "}@.";
close_out c
Expand All @@ -445,7 +355,7 @@ let parse_file (filename : String.t) =
let accumulate_range prog = function
| Grammar.RangeDecl (item, range_expr) ->
let locs = interp_range_expression prog range_expr in
let range = {scrutinee=item; range=locs} in
let range = {RangeConstraint.scrutinee=item; range=locs} in
{ prog with
pool = item :: prog.pool;
locations = locs @ prog.locations;
Expand All @@ -458,11 +368,12 @@ let parse_file (filename : String.t) =
}
in
let convert_atom = function
| Grammar.Have item -> Have(item, 1)
| Grammar.Reach loc -> Reach loc
| Grammar.Have item -> Atom.Have(item, 1)
| Grammar.Reach loc -> Atom.Reach loc
in
let accumulate_rule prog (goal, requires) =
let clause =
let open Clause in
{ goal = convert_atom goal;
requires = List.map convert_atom requires;
}
Expand Down
Loading

0 comments on commit e37cf34

Please sign in to comment.