diff --git a/paynt/family/family.py b/paynt/family/family.py index 6d274246..03c4b1c9 100644 --- a/paynt/family/family.py +++ b/paynt/family/family.py @@ -22,7 +22,6 @@ def __init__(self): self.selected_choices = None self.constraint_indices = None self.refinement_depth = None - self.splitter = None class Family: @@ -44,7 +43,6 @@ def __init__(self, other=None): self.selected_choices = None self.mdp = None self.analysis_result = None - self.splitter = None self.encoding = None def add_parent_info(self, parent_info): @@ -108,12 +106,27 @@ def __str__(self): def copy(self): return Family(self) + def assume_hole_options_copy(self, hole, options): + ''' + Create a copy and assume suboptions for a given hole. + @note this does not check whether @options are actually suboptions of this hole. + ''' + subfamily = self.copy() + subfamily.hole_set_options(hole,options) + return subfamily + def assume_options_copy(self, hole_options): - ''' Create a copy and assume suboptions for each hole. ''' - holes_copy = self.copy() + ''' + Create a copy and assume suboptions for each hole. + @note this does not check whether suboptions are actually suboptions of any given hole. + ''' + subfamily = self.copy() for hole,options in enumerate(hole_options): - holes_copy.hole_set_options(hole,options) - return holes_copy + subfamily.hole_set_options(hole,options) + return subfamily + + def split(self, splitter, suboptions): + return [self.assume_hole_options_copy(splitter,options) for options in suboptions] def pick_any(self): hole_options = [[self.hole_options(hole)[0]] for hole in range(self.num_holes)] @@ -140,24 +153,12 @@ def construct_assignment(self, combination): assignment = self.assume_options_copy(suboptions) return assignment - def subholes(self, hole_index, options): - ''' - Construct a semi-shallow copy of self with only one modified hole - @hole_index having selected @options - :note this is a performance/memory optimization associated with creating - subfamilies wrt one splitter having restricted options - ''' - shallow_copy = self.copy() - shallow_copy.hole_set_options(hole_index,options) - return shallow_copy - def collect_parent_info(self, specification): pi = ParentInfo() pi.selected_choices = self.selected_choices pi.refinement_depth = self.refinement_depth cr = self.analysis_result.constraints_result pi.constraint_indices = cr.undecided_constraints if cr is not None else [] - pi.splitter = self.splitter return pi def encode(self, smt_solver): diff --git a/paynt/quotient/mdp.py b/paynt/quotient/mdp.py index 53f4b24f..ca919208 100644 --- a/paynt/quotient/mdp.py +++ b/paynt/quotient/mdp.py @@ -475,23 +475,17 @@ def split(self, family): for options in core_suboptions: assert len(options) > 0 other_suboptions = [] - new_family = mdp.family.copy() if len(other_suboptions) == 0: suboptions = core_suboptions else: suboptions = [other_suboptions] + core_suboptions # DFS solves core first # construct corresponding subfamilies - subfamilies = [] - family.splitter = splitter parent_info = family.collect_parent_info(self.specification) parent_info.analysis_result = family.analysis_result parent_info.scheduler_choices = family.scheduler_choices parent_info.unsat_core_hint = self.coloring.unsat_core.copy() - for suboption in suboptions: - subfamily = new_family.subholes(splitter, suboption) + subfamilies = family.split(splitter,suboptions) + for subfamily in subfamilies: subfamily.add_parent_info(parent_info) - subfamily.hole_set_options(splitter, suboption) - subfamilies.append(subfamily) - return subfamilies diff --git a/paynt/quotient/quotient.py b/paynt/quotient/quotient.py index 0b669030..a648a5cd 100644 --- a/paynt/quotient/quotient.py +++ b/paynt/quotient/quotient.py @@ -292,25 +292,16 @@ def split(self, family): other_suboptions = [] # print(mdp.family[splitter], core_suboptions, other_suboptions) - new_family = mdp.family.copy() if len(other_suboptions) == 0: suboptions = core_suboptions else: suboptions = [other_suboptions] + core_suboptions # DFS solves core first - # construct corresponding design subspaces - design_subspaces = [] - # construct corresponding subfamilies - subfamilies = [] - family.splitter = splitter parent_info = family.collect_parent_info(self.specification) - for suboption in suboptions: - subfamily = new_family.subholes(splitter, suboption) + subfamilies = family.split(splitter,suboptions) + for subfamily in subfamilies: subfamily.add_parent_info(parent_info) - subfamily.hole_set_options(splitter, suboption) - subfamilies.append(subfamily) - return subfamilies diff --git a/paynt/synthesizer/policy_tree.py b/paynt/synthesizer/policy_tree.py index d8235527..83e2a30c 100644 --- a/paynt/synthesizer/policy_tree.py +++ b/paynt/synthesizer/policy_tree.py @@ -660,15 +660,9 @@ def split(self, family, prop, hole_selection, splitter, policy): half = len(options) // 2 suboptions = [options[:half], options[half:]] - # construct corresponding design subspaces - subfamilies = [] - family.splitter = splitter - new_family = family.copy() - for suboption in suboptions: - subfamily = new_family.subholes(splitter, suboption) - subfamily.hole_set_options(splitter, suboption) + subfamilies = family.split(splitter,suboptions) + for subfamily in subfamilies: subfamily.candidate_policy = None - subfamilies.append(subfamily) if not SynthesizerPolicyTree.discard_unreachable_choices: self.assign_candidate_policy(subfamilies, hole_selection, splitter, policy)