Commit

restore randomized abstraction
Roman Andriushchenko committed Jan 17, 2024
1 parent 6a831b0 commit 9e53703
Showing 3 changed files with 64 additions and 17 deletions.
11 changes: 8 additions & 3 deletions paynt/cli.py
@@ -119,7 +119,11 @@ def setup_logger(log_path = None):
@click.option("--mdp-split-wrt-mdp", is_flag=True, default=False,
help="# if set, MDP abstraction scheduler will be used for splitting, otherwise game abstraction scheduler will be used")
@click.option("--mdp-discard-unreachable-actions", is_flag=True, default=False,
help="# if set, unreachable choices will be discarded from game abstraction scheduler")
help="# if set, unreachable choices will be discarded from the splitting scheduler")
@click.option("--mdp-use-randomized-abstraction", is_flag=True, default=False,
help="# if set, randomized abstraction guess-and-verify will be used instead of game abstraction;" +
" MDP abstraction scheduler will be used for splitting"
)

@click.option(
"--ce-generator", type=click.Choice(["dtmc", "mdp"]), default="dtmc", show_default=True,
@@ -138,7 +142,7 @@ def paynt_run(
use_storm_cutoffs, unfold_strategy_storm,
export_fsc_storm, export_fsc_paynt, export_evaluation,
all_in_one,
mdp_split_wrt_mdp, mdp_discard_unreachable_actions,
mdp_split_wrt_mdp, mdp_discard_unreachable_actions, mdp_use_randomized_abstraction,
ce_generator,
profiling
):
@@ -157,7 +161,8 @@
paynt.quotient.pomdp.PomdpQuotient.posterior_aware = posterior_aware

paynt.synthesizer.policy_tree.SynthesizerPolicyTree.split_wrt_mdp_scheduler = mdp_split_wrt_mdp
paynt.synthesizer.policy_tree.SynthesizerPolicyTree.discard_unreachable_actions_in_game_scheduler = mdp_discard_unreachable_actions
paynt.synthesizer.policy_tree.SynthesizerPolicyTree.discard_unreachable_actions = mdp_discard_unreachable_actions
paynt.synthesizer.policy_tree.SynthesizerPolicyTree.use_randomized_abstraction = mdp_use_randomized_abstraction

storm_control = None
if storm_pomdp:
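For reference, the new option only toggles a class attribute on the synthesizer, so the same behaviour can presumably be enabled when driving PAYNT from Python rather than from the command line. A minimal sketch based on the assignments above; the surrounding synthesis setup is assumed:

    import paynt.synthesizer.policy_tree

    # Equivalent of passing --mdp-use-randomized-abstraction on the command line;
    # the MDP abstraction scheduler is then used for splitting, as the help text notes.
    paynt.synthesizer.policy_tree.SynthesizerPolicyTree.use_randomized_abstraction = True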
69 changes: 56 additions & 13 deletions paynt/synthesizer/policy_tree.py
@@ -484,8 +484,10 @@ class SynthesizerPolicyTree(paynt.synthesizer.synthesizer.Synthesizer):
double_check_policy_tree_leaves = False
# if True, MDP abstraction scheduler will be used for splitting, otherwise game abstraction scheduler will be used
split_wrt_mdp_scheduler = False
# if True, unreachable choices will be discarded from game abstraction scheduler
discard_unreachable_actions_in_game_scheduler = False
# if True, unreachable choices will be discarded from the splitting scheduler
discard_unreachable_actions = False
# if True, randomized abstraction guess-and-verify will be used instead of game abstraction
use_randomized_abstraction = False

@property
def method_name(self):
@@ -543,24 +545,57 @@ def solve_game_abstraction(self, family, prop, game_solver):
game_policy = game_policy_fixed
return game_policy,game_sat

def parse_game_scheduler(self, game_solver):
state_values = game_solver.solution_state_values
state_to_choice = game_solver.solution_state_to_quotient_choice.copy()
def try_randomized_abstraction(self, family, prop):
# build randomized abstraction
choice_to_action = []
for choice in range(family.mdp.choices):
action = self.quotient.choice_to_action[family.mdp.quotient_choice_map[choice]]
choice_to_action.append(action)
state_action_choices = self.quotient.map_state_action_to_choices(family.mdp.model,self.quotient.num_actions,choice_to_action)
model,choice_to_action = payntbind.synthesis.randomize_action_variant(family.mdp.model, state_action_choices)

# model check
result = stormpy.model_checking(model, prop.formula, extract_scheduler=True, environment=Property.environment)
self.stat.iteration(model)
value = result.at(model.initial_states[0])
policy_sat = prop.satisfies_threshold(value) # does this value matter?

# extract policy for the quotient
scheduler = result.scheduler
policy = self.quotient.empty_policy()
for state in range(model.nr_states):
state_choice = scheduler.get_choice(state).get_deterministic_choice()
choice = model.transition_matrix.get_row_group_start(state) + state_choice
action = choice_to_action[choice]
quotient_state = family.mdp.quotient_state_map[state]
policy[quotient_state] = action

# apply policy and check if it is SAT for all MDPs in the family
policy_sat = self.verify_policy(family, prop, policy)

return policy,policy_sat

def state_to_choice_to_scheduler(self, state_to_choice):
# uncomment this to use only reachable choices of the game scheduler
if SynthesizerPolicyTree.discard_unreachable_actions_in_game_scheduler:
if SynthesizerPolicyTree.discard_unreachable_actions:
state_to_choice = self.quotient.keep_reachable_choices_of_scheduler(state_to_choice)
scheduler_choices = self.quotient.state_to_choice_to_choices(state_to_choice)
hole_selection = self.quotient.coloring.collectHoleOptions(scheduler_choices)
return scheduler_choices,hole_selection

def parse_game_scheduler(self, game_solver):
state_values = game_solver.solution_state_values
state_to_choice = game_solver.solution_state_to_quotient_choice.copy()
scheduler_choices,hole_selection = self.state_to_choice_to_scheduler(state_to_choice)
return scheduler_choices,state_values,hole_selection

def parse_mdp_scheduler(self, family, mdp_result):
state_to_choice = self.quotient.scheduler_to_state_to_choice(family.mdp, mdp_result.result.scheduler)
scheduler_choices = self.quotient.state_to_choice_to_choices(state_to_choice)
scheduler_choices,hole_selection = self.state_to_choice_to_scheduler(state_to_choice)
state_values = [0] * self.quotient.quotient_mdp.nr_states
for state in range(family.mdp.states):
quotient_state = family.mdp.quotient_state_map[state]
state_values[quotient_state] = mdp_result.result.at(state)
hole_selection = self.quotient.coloring.collectHoleOptions(scheduler_choices)
return scheduler_choices,state_values,hole_selection


@@ -573,10 +608,18 @@ def verify_family(self, family, game_solver, prop):
mdp_family_result.policy = self.solve_singleton(family,prop)
return mdp_family_result

if family.candidate_policy is None:
game_policy,game_sat = self.solve_game_abstraction(family,prop,game_solver)
if not SynthesizerPolicyTree.use_randomized_abstraction:
if family.candidate_policy is None:
game_policy,game_sat = self.solve_game_abstraction(family,prop,game_solver)
else:
game_policy = family.candidate_policy
game_sat = False
else:
game_policy = family.candidate_policy
randomization_policy,policy_sat = self.try_randomized_abstraction(family,prop)
if policy_sat:
mdp_family_result.policy = randomization_policy
return mdp_family_result
game_policy = None
game_sat = False

mdp_family_result.game_policy = game_policy
@@ -594,7 +637,7 @@ def verify_family(self, family, game_solver, prop):
return mdp_family_result

# undecided: choose scheduler choices to be used for splitting
if not SynthesizerPolicyTree.split_wrt_mdp_scheduler:
if not (SynthesizerPolicyTree.use_randomized_abstraction or SynthesizerPolicyTree.split_wrt_mdp_scheduler):
scheduler_choices,state_values,hole_selection = self.parse_game_scheduler(game_solver)
else:
scheduler_choices,state_values,hole_selection = self.parse_mdp_scheduler(family, mdp_result)
@@ -680,7 +723,7 @@ def split(self, family, prop, hole_selection, splitter, policy):
subfamily.candidate_policy = None
subfamilies.append(subfamily)

if not SynthesizerPolicyTree.split_wrt_mdp_scheduler and not SynthesizerPolicyTree.discard_unreachable_actions_in_game_scheduler:
if not (SynthesizerPolicyTree.use_randomized_abstraction or SynthesizerPolicyTree.split_wrt_mdp_scheduler) and not SynthesizerPolicyTree.discard_unreachable_actions:
self.assign_candidate_policy(subfamilies, hole_selection, splitter, policy)

return suboptions,subfamilies
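A note on try_randomized_abstraction above: the randomized action variant itself is built by payntbind.synthesis.randomize_action_variant (a native binding), so the sketch below only illustrates the presumed idea: for every state, the family's choices that carry the same action are merged into a single choice whose transition distribution is their uniform average, yielding an MDP with one choice per state-action pair that can be model checked directly.

    def randomized_action_variant(transition_rows, state_action_choices):
        """Illustrative sketch only, not the payntbind implementation.

        transition_rows:      per choice, a dict {successor state: probability}
        state_action_choices: per state, per action, the list of choices of this
                              family carrying that action (cf. map_state_action_to_choices)
        Returns, per state and per available action, the uniform average of the
        transition rows of the choices carrying that action.
        """
        merged = []
        for per_action in state_action_choices:
            merged_state = {}
            for action, choices in enumerate(per_action):
                if not choices:
                    continue  # action not available in this state
                weight = 1.0 / len(choices)
                row = {}
                for choice in choices:
                    for successor, probability in transition_rows[choice].items():
                        row[successor] = row.get(successor, 0.0) + weight * probability
                merged_state[action] = row
            merged.append(merged_state)
        return merged

Model checking such an averaged MDP yields one action per state (the guess); try_randomized_abstraction then maps the extracted scheduler back through choice_to_action and quotient_state_map and lets verify_policy confirm the resulting policy on every MDP of the family (the verify).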
1 change: 0 additions & 1 deletion paynt/synthesizer/statistic.py
@@ -256,7 +256,6 @@ def print_mdp_family_table_entries(self):
sat_by_total_percentage = round(self.num_mdps_sat/self.num_mdps_total*100,2)
print(sat_by_total_percentage)


headers = [
"time","nodes","nodes (merged)","leaves","leaves (merged)","leaves (merged) / MDPs %",
"policies","policies (merged)","policies (merged) / SAT %","pp time","pp time %",
