diff --git a/models/mdp/maze/sketch.templ b/models/mdp/maze/sketch.templ index 97ef72f4..ee55b658 100644 --- a/models/mdp/maze/sketch.templ +++ b/models/mdp/maze/sketch.templ @@ -78,8 +78,6 @@ endmodule // rewards -label "what" = mod(x,2)=0; - rewards "steps" clk=1: 1; endrewards diff --git a/models/mdp/obstacles/sketch.props b/models/mdp/obstacles/sketch.props new file mode 100755 index 00000000..54d63b18 --- /dev/null +++ b/models/mdp/obstacles/sketch.props @@ -0,0 +1,2 @@ +//Pmax=? [ "notbad" U "goal" ] +R{"steps"}min=? [ F "goal" ] \ No newline at end of file diff --git a/models/mdp/obstacles/sketch.templ b/models/mdp/obstacles/sketch.templ new file mode 100755 index 00000000..a0985de1 --- /dev/null +++ b/models/mdp/obstacles/sketch.templ @@ -0,0 +1,65 @@ +mdp + +const int N = 10; +const int gMIN = 1; +const int gMAX = N; + +const int o1x = 2; +const int o1y = 2; + +const int o2x = 2; +const int o2y = 6; + +const int o3x = 4; +const int o3y = 3; + +const int o4x = 3; +const int o4y = 6; + +const int o5x = 5; +const int o5y = 5; + +const int o6x = 8; +const int o6y = 8; + +formula at1 = (x = o1x & y = o1y); +formula at2 = (x = o2x & y = o2y); +formula at3 = (x = o3x & y = o3y); +formula at4 = (x = o4x & y = o4y); +formula at5 = (x = o5x & y = o5y); +formula at6 = (x = o6x & y = o6y); + +formula crash = at1 | at2 | at3 | at4 | at5 | at6; +formula goal = (x=gMAX & y=gMAX); + +label "notbad" = !crash; +label "goal" = goal; + + +const double slip = 0.2; + +formula al = min(max(x-1,gMIN),gMAX); +formula all = min(max(x-2,gMIN),gMAX); +formula ar = min(max(x+1,gMIN),gMAX); +formula arr = min(max(x+2,gMIN),gMAX); +formula au = min(max(y-1,gMIN),gMAX); +formula auu = min(max(y-2,gMIN),gMAX); +formula ad = min(max(y+1,gMIN),gMAX); +formula add = min(max(y+2,gMIN),gMAX); + +module agent + x : [gMIN..gMAX] init gMIN; + y : [gMIN..gMAX] init gMIN; + + [le] !crash -> 1-slip : (x'=al) + slip : (x'=all); + [ri] !crash -> 1-slip : (x'=ar) + slip : (x'=arr); + [up] !crash -> 1-slip : (y'=au) + slip : (y'=auu); + [do] !crash -> 1-slip : (y'=ad) + slip : (y'=add); +endmodule + +rewards "steps" + [le] true: 1; + [ri] true: 1; + [up] true: 1; + [do] true: 1; +endrewards \ No newline at end of file diff --git a/models/mdp/zeroconf/sketch.props b/models/mdp/zeroconf/sketch.props new file mode 100644 index 00000000..26954192 --- /dev/null +++ b/models/mdp/zeroconf/sketch.props @@ -0,0 +1,4 @@ +// Maximum probability of configuring correctly +Pmax=? [ F (l=4 & ip=1) ]; +// Minimum probability of configuring correctly +//Pmin=? [ F (l=4 & ip=1) ]; diff --git a/models/mdp/zeroconf/sketch.templ b/models/mdp/zeroconf/sketch.templ new file mode 100644 index 00000000..383efdd4 --- /dev/null +++ b/models/mdp/zeroconf/sketch.templ @@ -0,0 +1,258 @@ +// IPv4: PTA model with digitial clocks +// one concrete host attempting to choose an ip address +// when a number of (abstract) hosts have already got ip addresses +// gxn/dxp/jzs 02/05/03 + +//------------------------------------------------------------- + +// we suppose that +// - the abstract hosts have already picked their addresses +// and always defend their addresses +// - the concrete host never picks the same ip address twice +// (this can happen only with a verys small probability) + +// under these assumptions we do not need message types because: +// 1) since messages to the concrete host will never be a probe, +// this host will react to all messages in the same way +// 2) since the abstract hosts always defend their addresses, +// all messages from the host will get an arp reply if the ip matches + +// following from the above assumptions we require only three abstract IP addresses +// (0,1 and 2) which correspond to the following sets of IP addresses: + +// 0 - the IP addresses of the abstract hosts which the concrete host +// previously tried to configure +// 1 - an IP address of an abstract host which the concrete host is +// currently trying to configure +// 2 - a fresh IP address which the concrete host is currently trying to configure + +// if the host picks an address that is being used it may end up picking another ip address +// in which case there may still be messages corresponding to the old ip address +// to be sent both from and to the host which the host should now disregard +// (since it will never pick the same ip address) + +// to deal with this situation: when a host picks a new ip address we reconfigure the +// messages that are still be be sent or are being sent by changing the ip address to 0 +// (an old ip address of the host) + +// all the messages from the abstract hosts for the 'old' address (in fact the +// set of old addresses since it may have started again more than once) +// can arrive in any order since they are equivalent to the host - it ignores then all + +// also the messages for the old and new address will come from different hosts +// (the ones with that ip address) which we model by allowing them to arrive in any order +// i.e. not neccessarily in the order they where sent + +//------------------------------------------------------------- +// model is an mdp +mdp + +//------------------------------------------------------------- +// VARIABLES +// reset or noreset model +const bool reset; + +const int N; // number of abstract hosts +const int K; // number of probes to send +const double loss = 0.1; // probability of message loss + +// PROBABILITIES +const double old = N/65024; // probability pick an ip address being used +const double new = (1-old); // probability pick a new ip address + +// TIMING CONSTANTS +const int CONSEC = 2; // time interval between sending consecutive probles +const int TRANSTIME = 1; // upper bound on transmission time delay +const int LONGWAIT = 60; // minimum time delay after a high number of address collisions +const int DEFEND = 10; + +const int TIME_MAX_X = 60; // max value of clock x +const int TIME_MAX_Y = 10; // max value of clock y +const int TIME_MAX_Z = 1; // max value of clock z + +// OTHER CONSTANTS +const int MAXCOLL = 10; // maximum number of collisions before long wait +// size of buffers for other hosts +const int B0 = 20; // buffer size for one abstract host +const int B1 = 8; // buffer sizes for all abstract hosts + +//------------------------------------------------------------- +// ENVIRONMENT - models: medium, output buffer of concrete host and all other hosts +module environment + + // buffer of concrete host + b_ip7 : [0..2]; // ip address of message in buffer position 8 + b_ip6 : [0..2]; // ip address of message in buffer position 7 + b_ip5 : [0..2]; // ip address of message in buffer position 6 + b_ip4 : [0..2]; // ip address of message in buffer position 5 + b_ip3 : [0..2]; // ip address of message in buffer position 4 + b_ip2 : [0..2]; // ip address of message in buffer position 3 + b_ip1 : [0..2]; // ip address of message in buffer position 2 + b_ip0 : [0..2]; // ip address of message in buffer position 1 + n : [0..8]; // number of places in the buffer used (from host) + + // messages to be sent from abstract hosts to concrete host + n0 : [0..B0]; // number of messages which do not have the host's current ip address + n1 : [0..B1]; // number of messages which have the host's current ip address + + b : [0..2]; // local state + // 0 - idle + // 1 - sending message from concrete host + // 2 - sending message from abstract host + + z : [0..1]; // clock of environment (needed for the time to send a message) + + ip_mess : [0..2]; // ip in the current message being sent + // 0 - different from concrete host + // 1 - same as the concrete host and in use + // 2 - same as the concrete host and not in use + + // RESET/RECONFIG: when host is about to choose new ip address + // suppose that the host cannot choose the same ip address + // (since happens with very small probability). + // Therefore all messages will have a different ip address, + // i.e. all n1 messages become n0 ones. + // Note this include any message currently being sent (ip is set to zero 0) + [reset] true -> (n1'=0) & (n0'=min(B0,n0+n1)) // abstract buffers + & (ip_mess'=0) // message being set + & (n'=(reset)?0:n) // concrete buffer (remove this update to get NO_RESET model) + & (b_ip7'=0) + & (b_ip6'=0) + & (b_ip5'=0) + & (b_ip4'=0) + & (b_ip3'=0) + & (b_ip2'=0) + & (b_ip1'=0) + & (b_ip0'=0); + // note: prevent anything else from happening when reconfiguration needs to take place + + // time passage (only if no messages to send or sending a message) + [time] l>0 & b=0 & n=0 & n0=0 & n1=0 -> (b'=b); // cannot send a message + [time] l>0 & b>0 & z<1 -> (z'=min(z+1,TIME_MAX_Z)); // sending a message + + // get messages to be sent (so message has same ip address as host) + [send] l>0 & n=0 -> (b_ip0'=ip) & (n'=n+1); + [send] l>0 & n=1 -> (b_ip1'=ip) & (n'=n+1); + [send] l>0 & n=2 -> (b_ip2'=ip) & (n'=n+1); + [send] l>0 & n=3 -> (b_ip3'=ip) & (n'=n+1); + [send] l>0 & n=4 -> (b_ip4'=ip) & (n'=n+1); + [send] l>0 & n=5 -> (b_ip5'=ip) & (n'=n+1); + [send] l>0 & n=6 -> (b_ip6'=ip) & (n'=n+1); + [send] l>0 & n=7 -> (b_ip7'=ip) & (n'=n+1); + [send] l>0 & n=8 -> (n'=n); // buffer full so lose message + + // start sending message from host + [] l>0 & b=0 & n>0 -> (1-loss) : (b'=1) & (ip_mess'=b_ip0) + & (n'=n-1) + & (b_ip7'=0) + & (b_ip6'=b_ip7) + & (b_ip5'=b_ip6) + & (b_ip4'=b_ip5) + & (b_ip3'=b_ip4) + & (b_ip2'=b_ip3) + & (b_ip1'=b_ip2) + & (b_ip0'=b_ip1) // send message + + loss : (n'=n-1) + & (b_ip7'=0) + & (b_ip6'=b_ip7) + & (b_ip5'=b_ip6) + & (b_ip4'=b_ip5) + & (b_ip3'=b_ip4) + & (b_ip2'=b_ip3) + & (b_ip1'=b_ip2) + & (b_ip0'=b_ip1); // lose message + + // start sending message to host + [] l>0 & b=0 & n0>0 -> (1-loss) : (b'=2) & (ip_mess'=0) & (n0'=n0-1) + loss : (n0'=n0-1); // different ip + [] l>0 & b=0 & n1>0 -> (1-loss) : (b'=2) & (ip_mess'=1) & (n1'=n1-1) + loss : (n1'=n1-1); // same ip + + // finish sending message from host + [] l>0 & b=1 & ip_mess=0 -> (b'=0) & (z'=0) & (n0'=min(n0+1,B0)) & (ip_mess'=0); + [] l>0 & b=1 & ip_mess=1 -> (b'=0) & (z'=0) & (n1'=min(n1+1,B1)) & (ip_mess'=0); + [] l>0 & b=1 & ip_mess=2 -> (b'=0) & (z'=0) & (ip_mess'=0); + + // finish sending message to host + [rec] l>0 & b=2 -> (b'=0) & (z'=0) & (ip_mess'=0); + +endmodule + +//------------------------------------------------------------- +// CONCRETE HOST +module host0 + + x : [0..TIME_MAX_X]; // first clock of the host + y : [0..TIME_MAX_Y]; // second clock of the host + + coll : [0..MAXCOLL]; // number of address collisions + probes : [0..K]; // counter (number of probes sent) + mess : [0..1]; // need to send a message or not + defend : [0..1]; // defend (if =1, try to defend IP address) + + ip : [1..2]; // ip address (1 - in use & 2 - fresh) + + l : [0..4] init 1; // location + // 0 : RECONFIGURE + // 1 : RANDOM + // 2 : WAITSP + // 3 : WAITSG + // 4 : USE + + // RECONFIGURE + [reset] l=0 -> (l'=1); + + // RANDOM (choose IP address) + [rec] (l=1) -> true; // get message (ignore since have no ip address) + // small number of collisions (choose straight away) + [] l=1 & coll 1/3*old : (l'=2) & (ip'=1) & (x'=0) + + 1/3*old : (l'=2) & (ip'=1) & (x'=1) + + 1/3*old : (l'=2) & (ip'=1) & (x'=2) + + 1/3*new : (l'=2) & (ip'=2) & (x'=0) + + 1/3*new : (l'=2) & (ip'=2) & (x'=1) + + 1/3*new : (l'=2) & (ip'=2) & (x'=2); + // large number of collisions: (wait for LONGWAIT) + [time] l=1 & coll=MAXCOLL & x (x'=min(x+1,TIME_MAX_X)); + [] l=1 & coll=MAXCOLL & x=LONGWAIT -> 1/3*old : (l'=2) & (ip'=1) & (x'=0) + + 1/3*old : (l'=2) & (ip'=1) & (x'=1) + + 1/3*old : (l'=2) & (ip'=1) & (x'=2) + + 1/3*new : (l'=2) & (ip'=2) & (x'=0) + + 1/3*new : (l'=2) & (ip'=2) & (x'=1) + + 1/3*new : (l'=2) & (ip'=2) & (x'=2); + + // WAITSP + // let time pass + [time] l=2 & x<2 -> (x'=min(x+1,2)); + // send probe + [send] l=2 & x=2 & probes (x'=0) & (probes'=probes+1); + // sent K probes and waited 2 seconds + [] l=2 & x=2 & probes=K -> (l'=3) & (probes'=0) & (coll'=0) & (x'=0); + // get message and ip does not match: ignore + [rec] l=2 & ip_mess!=ip -> (l'=l); + // get a message with matching ip: reconfigure + [rec] l=2 & ip_mess=ip -> (l'=0) & (coll'=min(coll+1,MAXCOLL)) & (x'=0) & (probes'=0); + + // WAITSG (sends two gratuitious arp probes) + // time passage + [time] l=3 & mess=0 & defend=0 & x (x'=min(x+1,TIME_MAX_X)); + [time] l=3 & mess=0 & defend=1 & x (x'=min(x+1,TIME_MAX_X)) & (y'=min(y+1,DEFEND)); + + // receive message and same ip: defend + [rec] l=3 & mess=0 & ip_mess=ip & (defend=0 | y>=DEFEND) -> (defend'=1) & (mess'=1) & (y'=0); + // receive message and same ip: defer + [rec] l=3 & mess=0 & ip_mess=ip & (defend=0 | y (l'=0) & (probes'=0) & (defend'=0) & (x'=0) & (y'=0); + // receive message and different ip + [rec] l=3 & mess=0 & ip_mess!=ip -> (l'=l); + + + // send probe reply or message for defence + [send] l=3 & mess=1 -> (mess'=0); + // send first gratuitous arp message + [send] l=3 & mess=0 & x=CONSEC & probes<1 -> (x'=0) & (probes'=probes+1); + // send second gratuitous arp message (move to use) + [send] l=3 & mess=0 & x=CONSEC & probes=1 -> (l'=4) & (x'=0) & (y'=0) & (probes'=0); + + // USE (only interested in reaching this state so do not need to add anything here) + [] l=4 -> true; + +endmodule + + diff --git a/paynt/quotient/mdp.py b/paynt/quotient/mdp.py index ae8f14d2..6196bb91 100644 --- a/paynt/quotient/mdp.py +++ b/paynt/quotient/mdp.py @@ -2,6 +2,7 @@ import stormpy import payntbind +import json import logging logger = logging.getLogger(__name__) @@ -9,40 +10,28 @@ class Variable: - def __init__(self, variable, model): - assert variable.has_boolean_type() or variable.has_integer_type(), \ - f"variable {variable.name} is neither integer nor boolean" - self.variable = variable + def __init__(self, name, model): + self.name = name assert model.has_state_valuations(), "model has no state valuations" - if self.has_integer_type: - value_getter = model.state_valuations.get_integer_value - else: - value_getter = model.state_valuations.get_boolean_value domain = set() for state in range(model.nr_states): - value = value_getter(state,self.variable) + valuation = json.loads(str(model.state_valuations.get_json(state))) + value = valuation[name] domain.add(value) domain = list(domain) # conversion of boolean variables to integers - if self.has_boolean_type: - domain = [1 if value else 0 for value in domain] + domain_new = [] + for value in domain: + if value is True: + value = 1 + elif value is False: + value = 0 + domain_new.append(value) + domain = domain_new domain = sorted(domain) self.domain = domain - - @property - def name(self): - return self.variable.name - - @property - def has_integer_type(self): - return self.variable.has_integer_type() - - @property - def has_boolean_type(self): - return self.variable.has_boolean_type() - @property def domain_min(self): return self.domain[0] @@ -64,8 +53,10 @@ def __str__(self): return f"{self.name}:{domain}" @classmethod - def from_model(cls, model, program_variables): - variables = [Variable(v,model) for v in program_variables] + def from_model(cls, model): + assert model.has_state_valuations(), "model has no state valuations" + valuation = json.loads(str(model.state_valuations.get_json(0))) + variables = [Variable(var,model) for var in valuation] variables = [v for v in variables if len(v.domain) > 1] return variables @@ -98,8 +89,7 @@ def set_variable(self, variable_index:int): self.child_true = DecisionTreeNode(self) def set_variable_by_name(self, variable_name:str, decision_tree): - name_to_variable_index = {var.name:index for index,var in enumerate(decision_tree.variables)} - variable_index = name_to_variable_index[variable_name] + variable_index = [variable.name for variable in decision_tree.variables].index(variable_name) self.set_variable(variable_index) def create_hole(self, family, action_labels, variables): @@ -113,7 +103,6 @@ def create_hole(self, family, action_labels, variables): prefix = "A" option_labels = action_labels #+ ["__dont_care__"] else: - var = variables[self.variable_index] prefix = variables[self.variable_index].name option_labels = variables[self.variable_index].hole_domain hole_name = f"{prefix}_{self.hole}" @@ -137,10 +126,10 @@ def collect_bounds(self): class DecisionTree: - def __init__(self, model, program_variables): + def __init__(self, model): self.model = model - self.variables = Variable.from_model(model,program_variables) + self.variables = Variable.from_model(model) logger.debug(f"found the following variables: {[str(v) for v in self.variables]}") self.num_nodes = 0 self.root = DecisionTreeNode(None) @@ -173,29 +162,41 @@ def create_family(self, action_labels): node.create_hole(family, action_labels, self.variables) return family - -def custom_decision_tree(mdp, program_variables): - dt = DecisionTree(mdp, program_variables) - +def custom_decision_tree(mdp): + dt = DecisionTree(mdp) decide = lambda node,var_name : node.set_variable_by_name(var_name,dt) - decide(dt.root,"clk") - main = dt.root.child_false - - decide(dt.root.child_false, "y") - decide(dt.root.child_false.child_true, "x") - decide(dt.root.child_false.child_true.child_true, "x") - - # decide(main,"y") - # decide(main.child_false,"x") - # decide(main.child_true,"x") - # decide(main.child_true.child_true,"x") - - # decide(main, "y") - # decide(main.child_false, "x") - # decide(main.child_false.child_true, "x") - # decide(main.child_true, "x") - # decide(main.child_true.child_true, "x") + # model = "maze" + model = "obstacles" + + if model == "maze": + decide(dt.root,"clk") + main = dt.root.child_false + + # decide(decide, "y") + # decide(decide.child_true, "x") + # decide(decide.child_true.child_true, "x") + + # decide(main,"y") + # decide(main.child_false,"x") + # decide(main.child_true,"x") + # decide(main.child_true.child_true,"x") + + decide(main, "y") + decide(main.child_false, "x") + decide(main.child_false.child_true, "x") + decide(main.child_true, "x") + decide(main.child_true.child_true, "x") + + if model == "obstacles": + decide(dt.root, "x") + # decide(dt.root.child_true, "x") + # decide(dt.root.child_true.child_true, "y") + # decide(dt.root.child_true.child_false, "y") + decide(dt.root.child_false, "x") + # decide(dt.root.child_false.child_true, "y") + # decide(dt.root.child_false.child_false, "y") + return dt @@ -205,17 +206,14 @@ class MdpQuotient(paynt.quotient.quotient.Quotient): def __init__(self, mdp, specification): super().__init__(specification=specification) - # get variables before choice origins are lost - assert mdp.has_choice_origins(), "model has no choice origins" - program_variables = mdp.choice_origins.program.variables - - target_states = self.identify_target_states(mdp,self.get_property()) - mdp = payntbind.synthesis.restoreActionsInTargetStates(mdp,target_states) + mdp = payntbind.synthesis.restoreActionsInAbsorbingStates(mdp) self.quotient_mdp = mdp + paynt_mdp = paynt.models.models.Mdp(mdp) + logger.info(f"optimal scheduler has value: {paynt_mdp.model_check_property(self.get_property())}") self.choice_destinations = payntbind.synthesis.computeChoiceDestinations(self.quotient_mdp) self.action_labels,self.choice_to_action = payntbind.synthesis.extractActionLabels(mdp) - decision_tree = custom_decision_tree(mdp, program_variables) + decision_tree = custom_decision_tree(mdp) family = decision_tree.create_family(self.action_labels) print("family = ", family) @@ -224,21 +222,20 @@ def __init__(self, mdp, specification): hole_bounds[node.hole] = node.collect_bounds() # print("hole bounds = ", hole_bounds) - hole_variable = [len(decision_tree.variables) for _ in range(family.num_holes)] + sv = mdp.state_valuations + hole_variable = ["" for _ in range(family.num_holes)] hole_domain = [[] for h in range(family.num_holes)] for node in decision_tree.collect_nonterminals(): - hole_variable[node.hole] = node.variable_index + hole_variable[node.hole] = decision_tree.variables[node.variable_index].name hole_domain[node.hole] = family.hole_to_option_labels[node.hole] # print("hole variables = ", hole_variable) # print("hole domain = ", hole_domain) - stormpy_variables = [v.variable for v in decision_tree.variables] self.decision_tree = decision_tree - self.hole_variable = hole_variable - self.is_action_hole = [var == len(self.decision_tree.variables) for var in self.hole_variable] + self.is_action_hole = [var == "" for var in hole_variable] self.coloring = payntbind.synthesis.ColoringSmt( mdp.nondeterministic_choice_indices, self.choice_to_action, - mdp.state_valuations, stormpy_variables, + mdp.state_valuations, hole_variable, hole_bounds, family.family, hole_domain ) @@ -281,8 +278,7 @@ def scheduler_is_consistent(self, mdp, prop, result): state_to_choice = self.scheduler_to_state_to_choice(mdp, scheduler) choices = self.state_to_choice_to_choices(state_to_choice) consistent,hole_selection = self.areChoicesConsistent(choices, mdp) - if mdp.is_deterministic: - assert consistent, "obtained a DTMC, but the scheduler is not consistent" + # print(consistent, hole_selection) # convert selection to actual hole options for hole,values in enumerate(hole_selection): @@ -351,8 +347,9 @@ def split(self, family, incomplete_search): assert len(hole_assignments[splitter]) == 1 splitter_option = hole_assignments[splitter][0] index = family.hole_options(splitter).index(splitter_option) + assert index < family.hole_num_options(splitter)-1 options = mdp.design_space.hole_options(splitter) - core_suboptions = [options[:index], options[index:]] + core_suboptions = [options[:index+1], options[index+1:]] other_suboptions = [] new_design_space, suboptions = self.discard(mdp, hole_assignments, core_suboptions, other_suboptions, incomplete_search) diff --git a/paynt/synthesizer/synthesizer_ar.py b/paynt/synthesizer/synthesizer_ar.py index e708fa55..af92b830 100644 --- a/paynt/synthesizer/synthesizer_ar.py +++ b/paynt/synthesizer/synthesizer_ar.py @@ -12,7 +12,7 @@ def method_name(self): return "AR" def verify_family(self, family): - self.stat.iteration_smt() + # self.stat.iteration_smt() self.quotient.build(family) if family.mdp is None: return diff --git a/payntbind/src/synthesis/quotient/ColoringSmt.cpp b/payntbind/src/synthesis/quotient/ColoringSmt.cpp index cd94d676..4826667d 100644 --- a/payntbind/src/synthesis/quotient/ColoringSmt.cpp +++ b/payntbind/src/synthesis/quotient/ColoringSmt.cpp @@ -11,24 +11,39 @@ ColoringSmt::ColoringSmt( std::vector const& row_groups, std::vector const& choice_to_action, storm::storage::sparse::StateValuations const& state_valuations, - std::vector const& variables, - std::vector hole_variable, + std::vector const& hole_to_variable_name, std::vector,std::vector>> hole_bounds, synthesis::Family const& family, std::vector> hole_domain ) : choice_to_action(choice_to_action), row_groups(row_groups), family(family), hole_domain(hole_domain), solver(context) { num_actions = 1 + *max_element(choice_to_action.begin(),choice_to_action.end()); + std::vector variables; + auto const& valuation = state_valuations.at(0); + for(auto x = valuation.begin(); x != valuation.end(); ++x) { + variables.push_back(x.getVariable()); + } + std::vector hole_variable(family.numHoles(),variables.size()); + hole_corresponds_to_program_variable = storm::storage::BitVector(family.numHoles()); + // create solver variables for each hole for(uint64_t hole = 0; hole < family.numHoles(); ++hole) { - uint64_t var = hole_variable[hole]; - bool corresponds_to_program_variable = (var < variables.size()); + std::string const& var_name = hole_to_variable_name[hole]; + bool corresponds_to_program_variable = (var_name != ""); hole_corresponds_to_program_variable.set(hole,corresponds_to_program_variable); std::string name; if(corresponds_to_program_variable) { - name = variables[var].getName() + "_" + std::to_string(hole); + name = var_name + "_" + std::to_string(hole); + uint64_t var_index; + for(var_index = 0; var_index < variables.size(); ++var_index) { + if(variables[var_index].getName() == var_name) { + hole_variable[hole] = var_index; + break; + } + } + STORM_LOG_THROW(var_index < variables.size(), storm::exceptions::InvalidArgumentException, "unexpected variable name"); } else { name = "A_" + std::to_string(hole); } diff --git a/payntbind/src/synthesis/quotient/ColoringSmt.h b/payntbind/src/synthesis/quotient/ColoringSmt.h index f4e22ad8..66c8fff5 100644 --- a/payntbind/src/synthesis/quotient/ColoringSmt.h +++ b/payntbind/src/synthesis/quotient/ColoringSmt.h @@ -24,8 +24,7 @@ class ColoringSmt { std::vector const& row_groups, std::vector const& choice_to_action, storm::storage::sparse::StateValuations const& state_valuations, - std::vector const& variables, - std::vector hole_variable, + std::vector const& hole_to_variable_name, std::vector,std::vector>> hole_bounds, Family const& family, std::vector> hole_domain diff --git a/payntbind/src/synthesis/quotient/bindings.cpp b/payntbind/src/synthesis/quotient/bindings.cpp index 5d388ac7..b8e371ff 100644 --- a/payntbind/src/synthesis/quotient/bindings.cpp +++ b/payntbind/src/synthesis/quotient/bindings.cpp @@ -301,8 +301,7 @@ void bindings_coloring(py::module& m) { std::vector const&, std::vector const&, storm::storage::sparse::StateValuations const&, - std::vector const&, - std::vector, + std::vector const&, std::vector,std::vector>>, synthesis::Family const&, std::vector> diff --git a/payntbind/src/synthesis/translation/bindings.cpp b/payntbind/src/synthesis/translation/bindings.cpp index 0f5df6ec..6254650c 100644 --- a/payntbind/src/synthesis/translation/bindings.cpp +++ b/payntbind/src/synthesis/translation/bindings.cpp @@ -167,18 +167,37 @@ std::shared_ptr> removeAction( } /** - * Given an MDP, for any state in the set \p target_states, mark any unlabeled action, explicitly add all availabled - * actions and subsequently removed unlabeled actions. + * Given an MDP, for any absorbing state with an unlabeled action, explicitly add all availabled actions and + * subsequently remove these unlabeled actions. */ template -std::shared_ptr> restoreActionsInTargetStates( - storm::models::sparse::Model const& model, - storm::storage::BitVector const& target_states +std::shared_ptr> restoreActionsInAbsorbingStates( + storm::models::sparse::Model const& model ) { auto model_canonic = synthesis::addMissingChoiceLabels(model); - auto model_target_enabled = synthesis::enableAllActions(*model_canonic, target_states); const std::string NO_ACTION_LABEL = "__no_label__"; - auto model_target_fixed = synthesis::removeAction(*model_target_enabled, NO_ACTION_LABEL, target_states); + storm::storage::BitVector const& no_action_label_choices = model_canonic->getChoiceLabeling().getChoices(NO_ACTION_LABEL); + storm::storage::BitVector absorbing_states(model.getNumberOfStates(),true); + for(uint64_t state = 0; state < model.getNumberOfStates(); ++state) { + bool state_is_absorbing = true; + for(uint64_t choice: model_canonic->getTransitionMatrix().getRowGroupIndices(state)) { + if(not no_action_label_choices[choice]) { + absorbing_states.set(state,false); + break; + } + for(auto const& entry: model_canonic->getTransitionMatrix().getRow(choice)) { + if(entry.getColumn() != state) { + absorbing_states.set(state,false); + break; + } + } + if(not absorbing_states[state]) { + break; + } + } + } + auto model_target_enabled = synthesis::enableAllActions(*model_canonic, absorbing_states); + auto model_target_fixed = synthesis::removeAction(*model_target_enabled, NO_ACTION_LABEL, absorbing_states); return model_target_fixed; } @@ -189,7 +208,7 @@ void bindings_translation(py::module& m) { m.def("addMissingChoiceLabels", &synthesis::addMissingChoiceLabels); m.def("extractActionLabels", &synthesis::extractActionLabels); m.def("enableAllActions", &synthesis::enableAllActions); - m.def("restoreActionsInTargetStates", &synthesis::restoreActionsInTargetStates); + m.def("restoreActionsInAbsorbingStates", &synthesis::restoreActionsInAbsorbingStates); py::class_, std::shared_ptr>>(m, "SubPomdpBuilder") .def(py::init const&>())