Skip to content

Commit

Permalink
compute Q-values for POMDPxFSC
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Jul 19, 2024
1 parent ced7920 commit 7c65972
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
43 changes: 43 additions & 0 deletions paynt/quotient/pomdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,3 +737,46 @@ def assignment_to_fsc(self, assignment):
fsc.check(observation_to_actions)

return fsc


def compute_qvalues(self, assignment):
'''
Given an MDP obtained after applying an FSC to a POMDP, compute for each state s, (reachable) memory node n
the Q-value Q(s,n).
:param assignment hole assignment encoding an FSC; it is assumed the assignment is the one obtained
for the current unfolding
:note Q(s,n) may be None if (s,n) exists in the unfolded POMDP but is not reachable in the induced DTMC
'''
# model check
submdp = self.build_assignment(assignment)
prop = self.get_property()
result = submdp.model_check_property(prop)
state_submdp_to_value = result.result.get_values()

# map states of a sub-MDP to the states of the quotient MDP to the state-memory pairs of the POMDPxFSC
import collections
state_memory_value = collections.defaultdict(lambda: None)
for submdp_state,value in enumerate(state_submdp_to_value):
mdp_state = submdp.quotient_state_map[submdp_state]
pomdp_state = self.pomdp_manager.state_prototype[mdp_state]
memory_node = self.pomdp_manager.state_memory[mdp_state]
state_memory_value[ (pomdp_state,memory_node) ] = value

# make this mapping total
memory_size = 1 + max([memory for state,memory in state_memory_value.keys()])
state_memory_value_total = [[None for memory in range(memory_size)] for state in range(self.pomdp.nr_states)]
for state in range(self.pomdp.nr_states):
for memory in range(memory_size):
value = state_memory_value[(state,memory)]
if value is None:
obs = self.pomdp.observations[state]
if memory < self.observation_memory_size[obs]:
# case 1: (s,n) exists but is not reachable in the induced DTMC
value = None
else:
# case 2: (s,n) does not exist because n memory was not allocated for s
# i.e. (s,n) has the same value as (s,0)
value = state_memory_value[(state,0)]
state_memory_value_total[state][memory] = value

return state_memory_value_total
2 changes: 1 addition & 1 deletion paynt/synthesizer/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def choose_synthesizer(quotient, method, fsc_synthesis, storm_control):
exit(0)
# FSC synthesis for POMDPs
if isinstance(quotient, paynt.quotient.pomdp.PomdpQuotient) and fsc_synthesis:
return paynt.synthesizer.synthesizer_pomdp.SynthesizerPOMDP(quotient, method, storm_control)
return paynt.synthesizer.synthesizer_pomdp.SynthesizerPomdp(quotient, method, storm_control)
# FSC synthesis for Dec-POMDPs
if isinstance(quotient, paynt.quotient.decpomdp.DecPomdpQuotient) and fsc_synthesis:
return paynt.synthesizer.synthesizer_decpomdp.SynthesizerDecPomdp(quotient)
Expand Down
6 changes: 4 additions & 2 deletions paynt/synthesizer/synthesizer_pomdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)


class SynthesizerPOMDP:
class SynthesizerPomdp:

# If true explore only the main family
incomplete_exploration = False
Expand Down Expand Up @@ -53,7 +53,9 @@ def __init__(self, quotient, method, storm_control):
self.synthesizer.saynt_timer = self.saynt_timer
self.storm_control.saynt_timer = self.saynt_timer

def synthesize(self, family, print_stats=True):
def synthesize(self, family=None, print_stats=True):
if family is None:
family = self.quotient.design_space
synthesizer = self.synthesizer(self.quotient)
family.constraint_indices = self.quotient.design_space.constraint_indices
assignment = synthesizer.synthesize(family, keep_optimum=True, print_stats=print_stats)
Expand Down

0 comments on commit 7c65972

Please sign in to comment.