PTG: Simplify code using new hash table class
mikaelbdj committed May 23, 2024
1 parent 19179eb commit b797967
Showing 3 changed files with 63 additions and 142 deletions.
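
Both rewritten files open the new DefaultHashTable module; its own diff is not rendered on this page. The following is a minimal sketch of the class, reconstructed from the call sites in this commit; the Hashtbl-backed representation and method bodies are assumptions, not the repository's actual code:

(* Assumed shape of the new class: a hash table whose find never raises.
   On a missing key it builds the default value, caches it, and returns it. *)
class ['a, 'b] defaultHashTable (default : unit -> 'b) = object
  val internal_tbl : ('a, 'b) Hashtbl.t = Hashtbl.create 100
  method replace key value = Hashtbl.replace internal_tbl key value
  method find key =
    try Hashtbl.find internal_tbl key
    with Not_found ->
      let d = default () in
      Hashtbl.replace internal_tbl key d;
      d
  method iter f = Hashtbl.iter f internal_tbl
end

The constructor argument must be a thunk of type unit -> 'b: the call sites below pass LinearConstraint.false_px_nnconvex_constraint and (fun _ -> new edgeSet) directly, so the default value cannot depend on the key (unlike the removed functor's default : key -> elem).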
130 changes: 36 additions & 94 deletions src/lib/AlgoPTG.ml
@@ -26,6 +26,7 @@ open Result
open AlgoGeneric
(* open Statistics *)
open State
open DefaultHashTable

(* Notation and shorthands *)
let (>>) f g x = g(f(x))
@@ -46,32 +47,6 @@ type edge = {state: state_index; action: Automaton.action_index; transition: Sta

type edge_status = BackpropLosing | BackpropWinning | Unexplored
type backtrack_type = Winning | Losing
module type Default = sig
type elem
type key
val str_of_elem : elem -> string
val str_of_key : key -> string
val tbl : (key, elem) Hashtbl.t
val default : key -> elem
val model : AbstractModel.abstract_model option ref
end

module DefaultHashtbl (D : Default) = struct
let model = D.model
(* let to_seq () = Hashtbl.to_seq D.tbl *)
let tbl = D.tbl
let replace = Hashtbl.replace tbl
let find key =
try Hashtbl.find tbl key with
Not_found -> let d = D.default key in replace key d; d
let to_str () = "[" ^
Seq.fold_left
(fun acc (key, elem) -> Printf.sprintf "%s, %s -> %s\n" acc (D.str_of_key key) (D.str_of_elem elem))
("")
(Hashtbl.to_seq tbl)
^ "]"
end


let status_to_string = function
| Unexplored -> "EXPLORE"
@@ -95,29 +70,9 @@ let edge_list_to_str seq model state_space = "[" ^
seq)
^ "]"

module WinningZone = DefaultHashtbl (struct
let model = ref None
type elem = LinearConstraint.px_nnconvex_constraint * AlgoPTGStrategyGenerator.state_strategy ref
type key = state_index
let tbl = Hashtbl.create 100
let default = fun _ -> LinearConstraint.false_px_nnconvex_constraint(), ref []
let str_of_elem (zone, _) = match !model with
| Some model -> LinearConstraint.string_of_px_nnconvex_constraint model.variable_names zone
| None -> "No model provided"
let str_of_key = string_of_int >> (^) ("s")
end)

module LosingZone = DefaultHashtbl (struct
let model = ref None
type elem = LinearConstraint.px_nnconvex_constraint
type key = state_index
let tbl = Hashtbl.create 100
let default = fun _ -> LinearConstraint.false_px_nnconvex_constraint()
let str_of_elem zone = match !model with
| Some model -> LinearConstraint.string_of_px_nnconvex_constraint model.variable_names zone
| None -> "No model provided"
let str_of_key = string_of_int >> (^) ("s")
end)
class unionZoneMap =
[state_index, LinearConstraint.px_nnconvex_constraint] defaultHashTable
LinearConstraint.false_px_nnconvex_constraint

module EdgeSet = Set.Make(struct type t = edge let compare = Stdlib.compare end)

@@ -129,17 +84,10 @@ class edgeSet = object
method to_seq = EdgeSet.to_seq internal_set
end

module Depends = DefaultHashtbl (struct
let model = ref None
type elem = edgeSet
type key = state_index
let tbl = Hashtbl.create 100
let default = fun _ -> new edgeSet
let str_of_elem _ = match !model with
| Some _ -> "TODO"
| None -> "No model provided"
let str_of_key = string_of_int
end)
class dependsMap =
[state_index, edgeSet] defaultHashTable
(fun _ -> new edgeSet)
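
Where the removed DefaultHashtbl functor required a full module per map (table, default, printers, and a model ref), each map is now a one-line class alias that fixes the key/value types and supplies the default thunk. A hypothetical usage sketch, assuming the constructor signature above (42 is an arbitrary state index, not a value from the repository):

  let wz = new unionZoneMap in
  let z = wz#find 42 in   (* miss: returns and caches the false zone *)
  wz#replace 42 z         (* explicit rebinding, as at the call sites in this diff *)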




@@ -317,7 +265,10 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert

val mutable termination_status = Regular_termination


val winningZone = new unionZoneMap
val losingZone = new unionZoneMap
val depends = new dependsMap
val stateStrategy = new AlgoPTGStrategyGenerator.stateStrategyMap

method private constr_of_state_index state = (state_space#get_state state).px_constraint
method private get_global_location state = state_space#get_location (state_space#get_global_location_index state)
@@ -337,13 +288,6 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert

val init_losing_zone_changed = ref false
val init_winning_zone_changed = ref false
val init_winning_zone = fun _ -> WinningZone.find state_space_ptg#state_space#get_initial_state_index

(* Initialize the Winning and Depend tables with our model - only affects printing information in terminal *)
method private initialize_tables () =
WinningZone.model := Some model;
LosingZone.model := Some model;
Depends.model := Some model

(* Edges from a symbolic state *)
method private get_edges state =
@@ -432,15 +376,15 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
let coverage_pruning = ref false in
if self#matches_state_predicate state' then
begin
WinningZone.replace state' @@ ((self#constr_of_state_index >> nn) state', ref []);
winningZone#replace state' @@ (self#constr_of_state_index >> nn) state';
waiting #<- (e, BackpropWinning);
coverage_pruning := true
end;

if self#is_dead_lock state' then
begin
if options#ptg_propagate_losing_states then
(LosingZone.replace state' @@ (self#constr_of_state_index >> nn) state';
(losingZone#replace state' @@ (self#constr_of_state_index >> nn) state';
waiting #<- (e, BackpropLosing));
coverage_pruning := true
end;
@@ -452,7 +396,7 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
| true, _ -> print_PTG (Printf.sprintf "\n\tNot adding successors of state %d due to pruning (cumulative)" state')
| _, true -> print_PTG (Printf.sprintf "\n\tNot adding successors of state %d due to pruning (coverage)" state')
| _ ->
(Depends.find state')#add e;
(depends#find state')#add e;
waiting #<-- (self#get_edge_queue state');
print_PTG ("\n\tAdding successor edges to waiting list. New waiting list: " ^ edge_list_to_str waiting#to_list model state_space)
end;
@@ -469,7 +413,7 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
let winning_move = self#predecessor_nnconvex edge (zone_map state') in

let safe_timed_pred = self#safe_timed_pred winning_move bad_zone state in
let current_winning_zone = WinningZone.find state |> fst in
let current_winning_zone = winningZone#find state in


(* Intersect winning move with safe timed predecessors to remove unsafe parts *)
@@ -490,7 +434,7 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
prioritized_winning_zone = safe_timed_pred
}
in
let strategy = WinningZone.find state |> snd in
let strategy = stateStrategy#find state in
strategy := new_strategy_entry :: !strategy;
true
end
@@ -516,34 +460,34 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
match backtrack_type with
| Winning ->
print_PTG "\tWINNING ZONE PROPAGATION:";
let bad = get_pred_from_edges (bot ()) bad_edges (fun x -> self#negate_zone (fst @@ WinningZone.find x) x) in
let bad = get_pred_from_edges (bot ()) bad_edges (fun x -> self#negate_zone (winningZone#find x) x) in
let winning_zone_changed =
List.fold_left (||) false
(List.map(fun edge -> self#backtrack_single_controllable_edge edge bad (WinningZone.find >> fst)) good_edges) in
(List.map(fun edge -> self#backtrack_single_controllable_edge edge bad winningZone#find) good_edges) in
if winning_zone_changed then
begin
waiting #<-- (self#edge_set_to_queue_with_status (Depends.find state) BackpropWinning);
waiting #<-- (self#edge_set_to_queue_with_status (depends#find state) BackpropWinning);
if state = state_space#get_initial_state_index then init_winning_zone_changed := true
end
| Losing ->
print_PTG "\tLOSING ZONE PROPAGATION:";
let good = get_pred_from_edges (LinearConstraint.px_nnconvex_copy @@ LosingZone.find state) good_edges LosingZone.find in
let bad = get_pred_from_edges (bot ()) bad_edges (fun x -> self#negate_zone (LosingZone.find x) x) in
let good = get_pred_from_edges (LinearConstraint.px_nnconvex_copy @@ losingZone#find state) good_edges losingZone#find in
let bad = get_pred_from_edges (bot ()) bad_edges (fun x -> self#negate_zone (losingZone#find x) x) in
LinearConstraint.px_nnconvex_difference_assign bad good;
let new_zone = self#safe_timed_pred good bad state in
if (LosingZone.find state) #!= new_zone then
if (losingZone#find state) #!= new_zone then
begin
LosingZone.replace state new_zone;
waiting #<-- (self#edge_set_to_queue_with_status (Depends.find state) BackpropLosing);
losingZone#replace state new_zone;
waiting #<-- (self#edge_set_to_queue_with_status (depends#find state) BackpropLosing);
if state = state_space#get_initial_state_index then init_losing_zone_changed := true
end;
end;
(Depends.find state')#add e
(depends#find state')#add e

(* Initial state is lost if initial constraint is included in losing zone *)
method private init_is_lost init =
init_losing_zone_changed := false;
LinearConstraint.px_nnconvex_constraint_is_leq (self#initial_constraint ()) (LosingZone.find init)
LinearConstraint.px_nnconvex_constraint_is_leq (self#initial_constraint ()) (losingZone#find init)

(* Initial state is won if the set of parameter valuations in its winning zone is non-empty *)
method private init_has_winning_witness =
Expand All @@ -555,7 +499,7 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
init_losing_zone_changed := false;
init_winning_zone_changed := false;
let init_zone_nn = nn @@ self#constr_of_state_index init in
let winning_and_losing_zone = LinearConstraint.px_nnconvex_copy @@ fst @@ WinningZone.find init ||| LosingZone.find init in
let winning_and_losing_zone = LinearConstraint.px_nnconvex_copy @@ winningZone#find init ||| losingZone#find init in
LinearConstraint.px_nnconvex_constraint_is_leq init_zone_nn winning_and_losing_zone

(* Returns true if the algorithm should terminate, depending on the criteria for termination *)
Expand All @@ -565,7 +509,7 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
let propagate_losing_states = options#ptg_propagate_losing_states in

if !init_winning_zone_changed then
synthesized_constraint <- project_params (self#initial_constraint () &&& fst @@ WinningZone.find init);
synthesized_constraint <- project_params (self#initial_constraint () &&& winningZone#find init);

let recompute_init_lost = propagate_losing_states && !init_losing_zone_changed in
let recompute_init_has_winning_witness = not complete_synthesis && !init_winning_zone_changed in
@@ -589,7 +533,6 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
(* Computes the parameters for which a winning strategy exists and saves the result in synthesized_constraint *)
method private compute_PTG =
let propagate_losing_states = options#ptg_propagate_losing_states in
self#initialize_tables();

(* === ALGORITHM INITIALIZATION === *)
let init = state_space#get_initial_state_index in
@@ -600,11 +543,11 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert

(* If goal is init then initial winning zone is its own constraint *)
if self#matches_state_predicate init then
WinningZone.replace init @@ ((self#constr_of_state_index >> nn) init, ref []);
winningZone#replace init @@ (self#constr_of_state_index >> nn) init;

(* If init is deadlock then initial losing zone is its own constraint *)
if self#matches_state_predicate init && propagate_losing_states then
(LosingZone.replace init @@ (self#constr_of_state_index >> nn) init; init_losing_zone_changed := true);
(losingZone#replace init @@ (self#constr_of_state_index >> nn) init; init_losing_zone_changed := true);

(* === ALGORITHM MAIN LOOP === *)
while (not @@ self#termination_criteria waiting init) do
@@ -626,12 +569,12 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert
self#backtrack e waiting Losing
done;
print_PTG "After running AlgoPTG I found these winning zones:";
print_PTG (WinningZone.to_str ());
(*print_PTG (winningZone#to_str ()); *)

if propagate_losing_states then
print_PTG (Printf.sprintf "And these losing zones: %s" (LosingZone.to_str()));
(* print_PTG (Printf.sprintf "And these losing zones: %s" (losingZone#to_str())); *)

let winning_parameters = project_params (self#initial_constraint () &&& fst @@ WinningZone.find init) in
let winning_parameters = project_params (self#initial_constraint () &&& winningZone#find init) in
synthesized_constraint <- winning_parameters

(*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*)
Expand All @@ -647,13 +590,12 @@ class algoPTG (model : AbstractModel.abstract_model) (property : AbstractPropert

AlgoPTGStrategyGenerator.print_strategy
model
~strategy:(fun state_index -> !(WinningZone.find state_index |> snd))
~state_indices:(List.of_seq @@ Hashtbl.to_seq_keys WinningZone.tbl)
~strategy:stateStrategy
~state_space:state_space;

(* Compute the strategy *)
(* if options#ptg_controller_mode != AbstractAlgorithm.No_Generation then
AlgoPTGStrategyGenerator.generate_controller model (fun x -> WinningZone.find x |> snd) state_space options; *)
AlgoPTGStrategyGenerator.generate_controller model (fun x -> winningZone#find x |> snd) state_space options; *)

(* Return the result *)
self#compute_result;
61 changes: 17 additions & 44 deletions src/lib/AlgoPTGStrategyGenerator.ml
@@ -4,19 +4,7 @@ open StateSpace
open State
open ImitatorUtilities
open AlgoPTGStrategyGeneratorUtilities

class virtual ['a, 'b] hashTable = object (self)
val mutable internal_tbl : ('a, 'b) Hashtbl.t = Hashtbl.create 100
method replace key value = Hashtbl.replace internal_tbl key value
method find key = try Hashtbl.find internal_tbl key with
Not_found ->
let x = self#bot in Hashtbl.replace internal_tbl key x; x
method iter f = Hashtbl.iter f internal_tbl
method fold : 'c. ('a -> 'b -> 'c -> 'c) -> 'c -> 'c =
fun f init -> Hashtbl.fold f internal_tbl init
method is_empty = Hashtbl.length internal_tbl = 0
method virtual bot : 'b
end
open DefaultHashTable

class ['a] array (ls : 'a list) = object
val internal_array : 'a Array.t = Array.of_list ls
@@ -32,7 +20,9 @@ type strategy_entry = {

type state_strategy = strategy_entry list

type strategy = state_index -> state_strategy
class stateStrategyMap =
[state_index, state_strategy ref] defaultHashTable
(fun _ -> ref [])
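
Note that the strategy ref, previously stored as the second component of WinningZone's pair element, now lives in this dedicated map; AlgoPTG.ml accordingly reads zones with winningZone#find and strategies with stateStrategy#find instead of projecting fst/snd out of a single entry.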

let format_zone_string (string : string) =
let b = Buffer.create 10 in
@@ -54,45 +44,28 @@ let string_of_state_strategy (model : abstract_model) (state_strategy : state_st



let print_strategy (model : abstract_model) ~strategy ~state_indices ~state_space =
let relevant_states = List.filter (fun state_index -> List.length @@ strategy state_index != 0) state_indices in
let state_strategy_strings = List.map (fun state_index -> state_index, string_of_state_strategy model (strategy state_index)) relevant_states in

let print_strategy (model : abstract_model) ~strategy ~state_space =
let get_location_index state_index = Array.get (DiscreteState.get_locations ((state_space#get_state state_index).global_location)) 0 in
let get_location_name state_index = model.location_names 0 (get_location_index state_index) in

print_message Verbose_standard "Printing strategy that ensures controller win:";
List.iter (fun (state_index, str) ->
strategy#iter (fun state_index state_strategy ->
let str = string_of_state_strategy model !state_strategy in
print_message Verbose_standard @@ Printf.sprintf "%s -> \n%s\n" (get_location_name state_index) str
) state_strategy_strings
)

class winningMovesPerAction = object
inherit([action_index, LinearConstraint.px_nnconvex_constraint] hashTable)
method bot = LinearConstraint.false_px_nnconvex_constraint ()
end

class winningMovesPerState = object
inherit ([state_index, winningMovesPerAction] hashTable)
method bot = new winningMovesPerAction
end
class winningMovesPerAction = [action_index, LinearConstraint.px_nnconvex_constraint] defaultHashTable LinearConstraint.false_px_nnconvex_constraint

class transitionsPerAction = object
inherit ([action_index, transition_index list] hashTable)
method bot = []
end
class transitionsPerLocation = object
inherit ([location_index, transitionsPerAction] hashTable)
method bot = new transitionsPerAction
end
class actionsPerLocation = object
inherit ([location_index, stateIndexSet] hashTable)
method bot = new stateIndexSet
end
class winningMovesPerState = [state_index, winningMovesPerAction] defaultHashTable (fun _ -> new winningMovesPerAction)

class locationPerStateIndex = object
inherit ([state_index, location_index option] hashTable)
method bot = None
end
class transitionsPerAction = [action_index, transition_index list] defaultHashTable (fun _ -> [])

class transitionsPerLocation = [location_index, transitionsPerAction] defaultHashTable (fun _ -> new transitionsPerAction)

class actionsPerLocation = [location_index, stateIndexSet] defaultHashTable (fun _ -> new stateIndexSet)

class locationPerStateIndex = [state_index, location_index option] defaultHashTable (fun _ -> None)
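
A brief hypothetical sketch of how the nested defaults compose (some_state and some_action are placeholder indices, not identifiers from the repository):

  let moves = new winningMovesPerState in
  let per_action = moves#find some_state in   (* miss: a fresh winningMovesPerAction *)
  let zone = per_action#find some_action in   (* miss: the false zone *)

Because each level supplies its own default, a chain of lookups never raises Not_found; this replaces the removed pattern of one bot-overriding subclass per nesting level.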

type location_info = {
invariant : invariant;
(The diff for the third changed file did not render.)
