SAYNT refactoring #54

Merged 1 commit on Dec 10, 2024
84 changes: 29 additions & 55 deletions paynt/quotient/storm_pomdp_control.py
@@ -94,6 +94,13 @@ def get_storm_result(self):
if self.s_queue is not None:
self.s_queue.put((self.result_dict, self.storm_bounds))

def store_storm_result(self, result):
self.latest_storm_result = result
if self.quotient.specification.optimality.minimizing:
self.storm_bounds = self.latest_storm_result.upper_bound
else:
self.storm_bounds = self.latest_storm_result.lower_bound

# run Storm POMDP analysis for given model and specification
# TODO: discuss Storm options
def run_storm_analysis(self):
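
The hunks below replace three copies of the bound-selection snippet with calls to the new store_storm_result helper. As a rough standalone sketch (the result object and the minimizing flag are assumptions here, not taken from this diff), the extracted logic amounts to:

class StormResultStore:
    """Minimal stand-in for the controller state touched by store_storm_result."""
    def __init__(self, minimizing):
        self.minimizing = minimizing
        self.latest_storm_result = None
        self.storm_bounds = None

    def store_storm_result(self, result):
        # a minimizing objective keeps the upper bound, a maximizing one the lower bound
        self.latest_storm_result = result
        self.storm_bounds = result.upper_bound if self.minimizing else result.lower_bound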
@@ -155,11 +162,7 @@ def run_storm_analysis(self):
print(f'-----------Storm----------- \
\nValue = {value} | Time elapsed = {round(storm_timer.read(),1)}s | FSC size = {size}\nFSC (dot) = {result.induced_mc_from_scheduler.to_dot()}\n', flush=True)

self.latest_storm_result = result
if self.quotient.specification.optimality.minimizing:
self.storm_bounds = self.latest_storm_result.upper_bound
else:
self.storm_bounds = self.latest_storm_result.lower_bound
self.store_storm_result(result)

# setup interactive Storm belief model checker
def interactive_storm_setup(self):
@@ -217,11 +220,7 @@ def interactive_run(self, belmc):
print(result.induced_mc_from_scheduler.to_dot(), file=text_file)
text_file.close()

self.latest_storm_result = result
if self.quotient.specification.optimality.minimizing:
self.storm_bounds = self.latest_storm_result.upper_bound
else:
self.storm_bounds = self.latest_storm_result.lower_bound
self.store_storm_result(result)
self.parse_results(self.quotient)
self.update_data()

@@ -270,11 +269,7 @@ def interactive_control(self, belmc, start, storm_timeout):
print(result.induced_mc_from_scheduler.to_dot(), file=text_file)
text_file.close()

self.latest_storm_result = result
if self.quotient.specification.optimality.minimizing:
self.storm_bounds = self.latest_storm_result.upper_bound
else:
self.storm_bounds = self.latest_storm_result.lower_bound
self.store_storm_result(result)
self.parse_results(self.quotient)
self.update_data()

@@ -380,7 +375,7 @@ def parse_storm_result(self, quotient):
# to make the code cleaner
get_choice_label = self.latest_storm_result.induced_mc_from_scheduler.choice_labeling.get_labels_of_choice

cutoff_epxloration = [x for x in range(len(self.latest_storm_result.cutoff_schedulers))]
cutoff_epxloration = list(range(len(self.latest_storm_result.cutoff_schedulers)))
finite_mem = False

result = {x:[] for x in range(quotient.observations)}
@@ -394,46 +389,31 @@ def parse_storm_result(self, quotient):
# parse non cut-off states
if 'cutoff' not in state.labels and 'clipping' not in state.labels:
for label in state.labels:
# observation based on prism observables
observation = None
if '[' in label:
# observation based on prism observables
observation = self.quotient.observation_labels.index(label)

index = -1

choice_label = list(get_choice_label(state.id))[0]
for i in range(len(quotient.action_labels_at_observation[int(observation)])):
if choice_label == quotient.action_labels_at_observation[int(observation)][i]:
index = i
break

if index >= 0 and index not in result[int(observation)]:
result[int(observation)].append(index)

if index >= 0 and index not in result_no_cutoffs[int(observation)]:
result_no_cutoffs[int(observation)].append(index)
# explicit observation index
elif 'obs_' in label:
_, observation = label.split('_')

index = -1
# explicit observation index
_,observation = label.split('_')
if observation is not None:
observation = int(observation)
choice_label = list(get_choice_label(state.id))[0]
for i in range(len(quotient.action_labels_at_observation[int(observation)])):
if choice_label == quotient.action_labels_at_observation[int(observation)][i]:
index = i
for index,action_label in enumerate(quotient.action_labels_at_observation[observation]):
if choice_label == action_label:
if index not in result[observation]:
result[observation].append(index)
if index not in result_no_cutoffs[observation]:
result_no_cutoffs[observation].append(index)
break

if index >= 0 and index not in result[int(observation)]:
result[int(observation)].append(index)

if index >= 0 and index not in result_no_cutoffs[int(observation)]:
result_no_cutoffs[int(observation)].append(index)


# parse cut-off states
else:
if 'finite_mem' in state.labels and not finite_mem:
finite_mem = True
self.parse_paynt_result(self.quotient)
for obs, actions in self.result_dict_paynt.items():
for obs,actions in self.result_dict_paynt.items():
for action in actions:
if action not in result_no_cutoffs[obs]:
result_no_cutoffs[obs].append(action)
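
For illustration, the enumerate-based matching above behaves like this self-contained sketch (the observation and label data here are made up; the real values come from the quotient):

action_labels_at_observation = {2: ["north", "south", "east"]}  # hypothetical data
result = {2: []}
result_no_cutoffs = {2: []}

observation = 2
choice_label = "south"
for index, action_label in enumerate(action_labels_at_observation[observation]):
    if choice_label == action_label:
        if index not in result[observation]:
            result[observation].append(index)
        if index not in result_no_cutoffs[observation]:
            result_no_cutoffs[observation].append(index)
        break

assert result == {2: [1]} and result_no_cutoffs == {2: [1]}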
@@ -462,7 +442,6 @@ def parse_storm_result(self, quotient):
for action in actions:
if action not in result[observation]:
result[observation].append(action)

cutoff_epxloration.remove(int(scheduler_index))

# removing unrestricted observations
@@ -484,19 +463,14 @@ def parse_choice_string(self, choice_string, probability_bound=0):
chars = '}{]['
for c in chars:
choice_string = choice_string.replace(c, '')

choice_string = choice_string.strip(', ')

choices = choice_string.split(',')

result = []

for choice in choices:
probability, action = choice.split(':')
# probability bound

action = int(action.strip())

result.append(action)

return result
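
A standalone version of the cleaned-up parser, for reference (the input format is inferred from the code; the probability values are still read but ignored, as in the diff):

def parse_choice_string(choice_string):
    # strip the bracket characters, then read the action index from each "prob:action" pair
    for c in '}{][':
        choice_string = choice_string.replace(c, '')
    choice_string = choice_string.strip(', ')
    result = []
    for choice in choice_string.split(','):
        probability, action = choice.split(':')
        result.append(int(action.strip()))
    return result

assert parse_choice_string("{0.7:2, 0.3:5}") == [2, 5]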
@@ -593,14 +567,14 @@ def get_subfamilies(self, restrictions, family):

subfamilies = []

for i in range(len(restrictions)):
for i,restriction in enumerate(restrictions):
restricted_family = family.copy()

actions = [action for action in family.hole_options(restrictions[i]["hole"]) if action not in restrictions[i]["restriction"]]
actions = [action for action in family.hole_options(restriction["hole"]) if action not in restriction["restriction"]]
if len(actions) == 0:
actions = [family.hole_options(restrictions[i]["hole"])[0]]
actions = [family.hole_options(restriction["hole"])[0]]

restricted_family.hole_set_options(restrictions[i]['hole'],actions)
restricted_family.hole_set_options(restriction['hole'],actions)

for j in range(i):
restricted_family.hole_set_options(restrictions[j]['hole'],restrictions[j]["restriction"])
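
The per-restriction filtering step that enumerate now drives can be pictured with plain dicts and lists standing in for the family/hole API (everything named here is illustrative, not the real interface):

hole_options = {"hole_0": [0, 1, 2, 3]}                      # hypothetical hole with four options
restriction = {"hole": "hole_0", "restriction": [1, 3]}

actions = [a for a in hole_options[restriction["hole"]] if a not in restriction["restriction"]]
if len(actions) == 0:
    # never restrict a hole to an empty set; fall back to its first option
    actions = [hole_options[restriction["hole"]][0]]

assert actions == [0, 2]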
13 changes: 4 additions & 9 deletions paynt/synthesizer/synthesizer_ar_storm.py
@@ -41,21 +41,16 @@ def storm_split(self, families):
# split each family in the current buffer to main family and corresponding subfamilies
for family in families:
if self.storm_control.use_cutoffs:
main_p = self.storm_control.get_main_restricted_family(family, self.storm_control.result_dict)
result_dict = self.storm_control.result_dict
else:
main_p = self.storm_control.get_main_restricted_family(family, self.storm_control.result_dict_no_cutoffs)

result_dict = self.storm_control.result_dict_no_cutoffs
main_p = self.storm_control.get_main_restricted_family(family, result_dict)
if main_p is None:
subfamilies.append(family)
continue

main_families.append(main_p)

if self.storm_control.use_cutoffs:
subfamily_restrictions = self.storm_control.get_subfamilies_restrictions(family, self.storm_control.result_dict)
else:
subfamily_restrictions = self.storm_control.get_subfamilies_restrictions(family, self.storm_control.result_dict_no_cutoffs)

subfamily_restrictions = self.storm_control.get_subfamilies_restrictions(family, result_dict)
subfamilies_p = self.storm_control.get_subfamilies(subfamily_restrictions, family)
subfamilies.extend(subfamilies_p)
