diff --git a/taskvine/src/bindings/python3/ndcctools/taskvine/stem.py b/taskvine/src/bindings/python3/ndcctools/taskvine/stem.py index 6702d2abbf..a3c0226da4 100644 --- a/taskvine/src/bindings/python3/ndcctools/taskvine/stem.py +++ b/taskvine/src/bindings/python3/ndcctools/taskvine/stem.py @@ -1,17 +1,16 @@ import sys import copy import os -import select import time -import pickle import cloudpickle import uuid import ndcctools.taskvine as vine from multiprocessing import Pipe from multiprocessing.connection import wait + class StemObject(): - def __init__(self): + def __init__(self): self._item_id = str(uuid.uuid1()) self._link_id = None self._chain_id = None @@ -22,7 +21,6 @@ def __init__(self): self._range = "independent" self.full_map = False - def map(self, from_domain="all", to_range="all"): if from_domain == "all" or to_range == "all": self._domain = "all" @@ -33,15 +31,16 @@ def map(self, from_domain="all", to_range="all"): self._range = to_range return self + class Chain(StemObject): """ A chain is group of tasks or groups that are executed to be executed sequentially. by default, running a singular seed eventually executes chain(group(seed)) - + _chain - list of Seed and Group stem objects to execute sequentially _current_link - current StemObject (Seed or Group) currently being executed - _current_items - mapping of current item_ids to item objects (Not yet sent, to a manager) - _waiting_items - mapping of waiting item_ids to item objects (sent to manager, waiting for result) + _current_items - mapping of current item_ids to item objects (Not yet sent, to a manager) + _waiting_items - mapping of waiting item_ids to item objects (sent to manager, waiting for result) _pending_chains - mapping of pending chains waiting for sub items to be complete _pending_chain_items - mapping from chain_ids to the chains pending items in the queue _manager_links - list of read connection fds in which the Stem checks for messages @@ -54,9 +53,8 @@ class Chain(StemObject): _item_mapping - mapping of items within a group to their original indicie _previous_results - list of results generated from previous group executed _current_results - list of current results generated from activelty running group. - """ - def __init__ (self, *args): + def __init__(self, *args): self._chain = [] self._current_link = None self._current_items = {} @@ -70,13 +68,12 @@ def __init__ (self, *args): self._item_mapping = {} self._previous_results = [] self._current_results = [] - + # remove item from parent chain's waiting items super().__init__() - self._chain_mapping = {self._item_id:self} - + self._chain_mapping = {self._item_id: self} - # Add Stem Objects to the Chain. + # Add Stem Objects to the Chain. # NOTE: the order of the objects determine execution order. # NOTE: Chains can not bee added to Chains. i.e. (Chain(Chain())) is invalid. However, Chain(Group(Chain)) is valid. count = 0 @@ -94,7 +91,7 @@ def __init__ (self, *args): count += 1 else: raise TypeError - except: + except Exception: if isinstance(arg, Group): self._chain.append(arg) arg._link_id = count @@ -105,25 +102,25 @@ def __init__ (self, *args): count += 1 else: raise TypeError - + # When deleting a master chain we send messages to kill all managers. - #def __del__(self): - # for manager in self._managers: - # self._managers[manager]["write"].send("kill") + def __del__(self): + for manager in self._managers: + self._managers[manager]["write"].send("kill") - # Set _currrent_link to the next available link. Returns False if there are no more links + # Set _currrent_link to the next available link. Returns False if there are no more links def pop_link(self): if self._chain: self._current_link = self._chain.pop(0) return self._current_link else: return None - + # Execute Stem objects within a chain in order - def run(self): + def run(self): # When a Chain is called wirh run() it becomes the master chain. # The master chain maintains mappings of results from the previous link that has been executed - # Additionally, the current links results are kept. This is used when mapping outputs to inputs between links + # Additionally, the current links results are kept. This is used when mapping outputs to inputs between links while self.pop_link(): link = self._current_link # Execution of a Seed object: Convert seed to group and queue at top. @@ -147,7 +144,7 @@ def exec_group(self, group): # Queue inital items and create mapping for value self.set_group(group) # Exceute current link as a group. - while self._current_items or self._waiting_items: + while self._current_items or self._waiting_items: # Queue current tasks for item in list(self._current_items.values()): if isinstance(item, Seed): @@ -157,7 +154,7 @@ def exec_group(self, group): else: del self._current_items[item._item_id] # check for results from managers - self.check_results() + self.check_results() self._previous_results = [] # expand results to a continous list and move to previous results @@ -180,25 +177,25 @@ def exec_sub_seed(self, seed): manager = seed._manager if manager not in self._managers: read, write = run_manager(manager) - self._managers[manager] = {"read":read, "write":write} + self._managers[manager] = {"read": read, "write": write} self._manager_links.append(read) self._managers[manager]["write"].send(seed) self._waiting_items[seed._item_id] = seed del self._current_items[seed._item_id] - + def exec_sub_chain(self, chain): # get next avialable link from chain chain_link = chain.pop_link() if chain._item_id not in self._chain_mapping: self._chain_mapping[chain._item_id] = chain - if isinstance(chain_link, Seed): + if isinstance(chain_link, Seed): seed = chain_link grouped_seed = Group(seed) grouped_seed.map(seed._domain, seed._range) - grouped_seed._link_id = seed._link_id + grouped_seed._link_id = seed._link_id chain._chain.insert(0, grouped_seed) elif isinstance(chain_link, Group): - # set chains results for mapping + # set chains results for mapping self.set_chain_results(chain) group = chain_link chain.set_group(group) @@ -216,7 +213,7 @@ def exec_sub_chain(self, chain): elif chain_link is None: # TODO: deep copy probably chain._previous_results = [] - chain.expand_results(chain._current_results) + chain.expand_results(chain._current_results) chain._parent_chain._item_mapping[chain._item_id].append(chain._previous_results) del self._current_items[chain._item_id] del chain._parent_chain._waiting_items[chain._item_id] @@ -233,7 +230,7 @@ def set_group(self, group): self._item_index = count item._chain_id = self._item_id item._parent_chain = self - count += 1 + count += 1 if group._link_id == 0: pass elif group._domain != "idependent" and group._range != "independent": @@ -252,22 +249,22 @@ def set_chain_results(self, chain): if chain._parent_chain._full_map: chain._previous_results = copy.deepcopy(chain._parent_chain._previous_results) else: - domain_index = chain._parent_chain._item_index//chain._range - start_index = domain_index*chain._domain - stop_index - domain_index*chain._domain+chain._domain - chain._previous_results = copy.deepcopy(chain._parent_chain._previous_results[start_index:stop_index]) + domain_index = chain._parent_chain._item_index // chain._range + start_index = domain_index * chain._domain + stop_index = domain_index * chain._domain + chain._domain + chain._previous_results = copy.deepcopy(chain._parent_chain._previous_results[start_index:stop_index]) elif chain._current_link is not None: chain._current_results = [] - + def map_chain_seed(self, seed): chain = seed._parent_chain if chain._full_map: seed.update_args(chain._previous_results) else: - domain_index = chain._item_index//seed._range - start_index = domain_index*seed._domain - stop_index = domain_index*seed._domain+seed._domain + domain_index = chain._item_index // seed._range + start_index = domain_index * seed._domain + stop_index = domain_index * seed._domain + seed._domain seed.update_args(chain._previous_results[start_index:stop_index]) def map_frontier_seed(self, seed): @@ -276,9 +273,9 @@ def map_frontier_seed(self, seed): seed.update_args(self._previous_results) else: # Map specific arguments to the to the seed arguments - domain_index = self._item_index//seed._range - start_index = domain_index*seed._domain - stop_index = domain_index*seed._domain+seed._domain + domain_index = self._item_index // seed._range + start_index = domain_index * seed._domain + stop_index = domain_index * seed._domain + seed._domain seed.update_args(self._previous_results[start_index:stop_index]) def check_results(self): @@ -316,16 +313,16 @@ def unlink_from_chain(self, item): # add chain to master's current items if no pending items and chain is not the master chain if not chain._current_items and not chain._waiting_items and chain._item_id != self._item_id: chain._previous_results = [] - chain.expand_results(chain._current_results) + chain.expand_results(chain._current_results) self._current_items[chain._item_id] = chain - + def expand_results(self, results): for result in results: if isinstance(result, list): self.expand_results(result) else: self._previous_results.append(result) - + def handle_bloom(self, item, bloomed_item): chain = self._chain_mapping[item._chain_id] del chain._waiting_items[item._item_id] @@ -348,22 +345,23 @@ def handle_bloom(self, item, bloomed_item): group_item._chain_id = chain._item_id group_item._parent_chain = chain - chain._item_mapping[item._item_id].append([]) - chain._item_mapping[group_item._item_id] = chain._item_mapping[item._item_id][count] + chain._item_mapping[item._item_id].append([]) + chain._item_mapping[group_item._item_id] = chain._item_mapping[item._item_id][count] self._current_items[group_item._item_id] = group_item - + if chain._item_id != self._item_id: chain._waiting_items[item._item_id] = group_item count += 1 else: raise TypeError - + + class Group(StemObject): """ A group is a collection of seeds and chains that can execute in parallel: """ - def __init__ (self, *args): + def __init__(self, *args): self._group = [] super().__init__() for arg in iter(args): @@ -376,22 +374,22 @@ def __init__ (self, *args): self._group.append(iarg) else: raise TypeError - except: + except Exception: if isinstance(arg, Chain): self._group.append(arg) elif isinstance(arg, Seed): self._group.append(arg) else: raise TypeError - + def run(self): chain = Chain(self) chain.run() - + class Seed(StemObject): """ - A Seed onject reflects the base object to be executed. + A Seed onject reflects the base object to be executed. """ def __init__(self, func, *args, **kwargs): # TODO: This keeps The connection object from complaining when sending certain objectis via the Commuincation Pipe @@ -410,17 +408,17 @@ def set_manager(self, manager): def run(self): group = Group(self) group.run() - + def update_args(self, new_args): func, args, kwargs = cloudpickle.loads(self._srl) args = tuple(new_args) self._srl = cloudpickle.dumps((func, args, kwargs)) - + def set_result(self, result): self._result = result def set(self, attr, *args, **kwargs): - self._attr_list[attr] = {"args":args, "kwargs":kwargs} + self._attr_list[attr] = {"args": args, "kwargs": kwargs} return self def print(self): @@ -430,7 +428,7 @@ def print(self): class Bloom(): """ A Bloom is a oject that is returned from a task that contains a Seed, Group, or Chain - that replaces a seed + that replaces a seed """ def __init__(self, item): if isinstance(item, Seed) or isinstance(item, Group): @@ -439,16 +437,16 @@ def __init__(self, item): print("Error: A Bloom object must contain a Seed or Group!", file=sys.stderr) raise TypeError + def run_manager(name): - + p_read, c_write = Pipe() - c_read, p_write = Pipe() + c_read, p_write = Pipe() pid = os.fork() - + # Stem if pid: return p_read, p_write - # Manager else: @@ -458,14 +456,14 @@ def exec_func(srl, options): read = c_read write = c_write - time.sleep(1) + time.sleep(1) tasks = {} - m = vine.Manager(port=[9123,9143], name=name) - while(True): + m = vine.Manager(port=[9123, 9143], name=name) + while True: while read.poll(): try: - item = read.recv() + item = read.recv() if isinstance(item, Seed): task = vine.PythonTask(exec_func, item._srl, None) try: @@ -473,14 +471,14 @@ def exec_func(srl, options): func = getattr(task, attr) func(*item._attr_list[attr]["args"], **item._attr_list[attr]["kwargs"]) # TODO error handling - except: + except Exception: pass task.set_cores(1) m.submit(task) tasks[task.id] = item elif isinstance(item, str) and item == "kill": exit(1) - except: + except Exception: raise RuntimeError exit(1) @@ -492,4 +490,3 @@ def exec_func(srl, options): item.set_result(task.output) write.send(item) del tasks[task.id] -