Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
BarrySlyDelgado committed Oct 23, 2024
1 parent 750255c commit c5b3fa3
Showing 1 changed file with 66 additions and 69 deletions.
135 changes: 66 additions & 69 deletions taskvine/src/bindings/python3/ndcctools/taskvine/stem.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import sys
import copy
import os
import select
import time
import pickle
import cloudpickle
import uuid
import ndcctools.taskvine as vine
from multiprocessing import Pipe
from multiprocessing.connection import wait


class StemObject():
def __init__(self):
def __init__(self):
self._item_id = str(uuid.uuid1())
self._link_id = None
self._chain_id = None
Expand All @@ -22,7 +21,6 @@ def __init__(self):
self._range = "independent"
self.full_map = False


def map(self, from_domain="all", to_range="all"):
if from_domain == "all" or to_range == "all":
self._domain = "all"
Expand All @@ -33,15 +31,16 @@ def map(self, from_domain="all", to_range="all"):
self._range = to_range
return self


class Chain(StemObject):
"""
A chain is a group of tasks or groups that are executed sequentially.
By default, running a singular seed eventually executes chain(group(seed)).
_chain - list of Seed and Group stem objects to execute sequentially
_current_link - current StemObject (Seed or Group) currently being executed
_current_items - mapping of current item_ids to item objects (Not yet sent, to a manager)
_waiting_items - mapping of waiting item_ids to item objects (sent to manager, waiting for result)
_current_items - mapping of current item_ids to item objects (Not yet sent, to a manager)
_waiting_items - mapping of waiting item_ids to item objects (sent to manager, waiting for result)
_pending_chains - mapping of pending chains waiting for sub items to be complete
_pending_chain_items - mapping from chain_ids to the chains pending items in the queue
_manager_links - list of read connection fds in which the Stem checks for messages
Expand All @@ -54,9 +53,8 @@ class Chain(StemObject):
_item_mapping - mapping of items within a group to their original index
_previous_results - list of results generated from the previous group executed
_current_results - list of current results generated from the actively running group.
"""
def __init__ (self, *args):
def __init__(self, *args):
self._chain = []
self._current_link = None
self._current_items = {}
Expand All @@ -70,13 +68,12 @@ def __init__ (self, *args):
self._item_mapping = {}
self._previous_results = []
self._current_results = []

# remove item from parent chain's waiting items
super().__init__()
self._chain_mapping = {self._item_id:self}

self._chain_mapping = {self._item_id: self}

# Add Stem Objects to the Chain.
# Add Stem Objects to the Chain.
# NOTE: the order of the objects determine execution order.
# NOTE: Chains cannot be added to Chains, i.e. Chain(Chain()) is invalid. However, Chain(Group(Chain())) is valid.
count = 0
Expand All @@ -94,7 +91,7 @@ def __init__ (self, *args):
count += 1
else:
raise TypeError
except:
except Exception:
if isinstance(arg, Group):
self._chain.append(arg)
arg._link_id = count
Expand All @@ -105,25 +102,25 @@ def __init__ (self, *args):
count += 1
else:
raise TypeError

# When deleting a master chain we send messages to kill all managers.
#def __del__(self):
# for manager in self._managers:
# self._managers[manager]["write"].send("kill")
def __del__(self):
for manager in self._managers:
self._managers[manager]["write"].send("kill")

# Advance _current_link to the next available link.
# Returns the link, or None if there are no more links.
def pop_link(self):
    """Remove the first link from the chain and make it the current link.

    Returns:
        The link that was popped (now also stored in ``_current_link``),
        or ``None`` when ``_chain`` is empty. In the empty case
        ``_current_link`` is left untouched, so it still refers to the
        last link that was popped.
    """
    if not self._chain:
        return None
    self._current_link = self._chain.pop(0)
    return self._current_link

# Execute Stem objects within a chain in order
def run(self):
def run(self):
# When a Chain is called with run() it becomes the master chain.
# The master chain maintains mappings of results from the previous link that has been executed
# Additionally, the current links results are kept. This is used when mapping outputs to inputs between links
# Additionally, the current links results are kept. This is used when mapping outputs to inputs between links
while self.pop_link():
link = self._current_link
# Execution of a Seed object: Convert seed to group and queue at top.
Expand All @@ -147,7 +144,7 @@ def exec_group(self, group):
# Queue initial items and create mapping for value
self.set_group(group)
# Execute current link as a group.
while self._current_items or self._waiting_items:
while self._current_items or self._waiting_items:
# Queue current tasks
for item in list(self._current_items.values()):
if isinstance(item, Seed):
Expand All @@ -157,7 +154,7 @@ def exec_group(self, group):
else:
del self._current_items[item._item_id]
# check for results from managers
self.check_results()
self.check_results()

self._previous_results = []
# expand results to a continuous list and move to previous results
Expand All @@ -180,25 +177,25 @@ def exec_sub_seed(self, seed):
manager = seed._manager
if manager not in self._managers:
read, write = run_manager(manager)
self._managers[manager] = {"read":read, "write":write}
self._managers[manager] = {"read": read, "write": write}
self._manager_links.append(read)
self._managers[manager]["write"].send(seed)
self._waiting_items[seed._item_id] = seed
del self._current_items[seed._item_id]

def exec_sub_chain(self, chain):
# get next available link from chain
chain_link = chain.pop_link()
if chain._item_id not in self._chain_mapping:
self._chain_mapping[chain._item_id] = chain
if isinstance(chain_link, Seed):
if isinstance(chain_link, Seed):
seed = chain_link
grouped_seed = Group(seed)
grouped_seed.map(seed._domain, seed._range)
grouped_seed._link_id = seed._link_id
grouped_seed._link_id = seed._link_id
chain._chain.insert(0, grouped_seed)
elif isinstance(chain_link, Group):
# set chains results for mapping
# set chains results for mapping
self.set_chain_results(chain)
group = chain_link
chain.set_group(group)
Expand All @@ -216,7 +213,7 @@ def exec_sub_chain(self, chain):
elif chain_link is None:
# TODO: deep copy probably
chain._previous_results = []
chain.expand_results(chain._current_results)
chain.expand_results(chain._current_results)
chain._parent_chain._item_mapping[chain._item_id].append(chain._previous_results)
del self._current_items[chain._item_id]
del chain._parent_chain._waiting_items[chain._item_id]
Expand All @@ -233,7 +230,7 @@ def set_group(self, group):
self._item_index = count
item._chain_id = self._item_id
item._parent_chain = self
count += 1
count += 1
if group._link_id == 0:
pass
elif group._domain != "idependent" and group._range != "independent":
Expand All @@ -252,22 +249,22 @@ def set_chain_results(self, chain):
if chain._parent_chain._full_map:
chain._previous_results = copy.deepcopy(chain._parent_chain._previous_results)
else:
domain_index = chain._parent_chain._item_index//chain._range
start_index = domain_index*chain._domain
stop_index - domain_index*chain._domain+chain._domain
chain._previous_results = copy.deepcopy(chain._parent_chain._previous_results[start_index:stop_index])
domain_index = chain._parent_chain._item_index // chain._range
start_index = domain_index * chain._domain
stop_index = domain_index * chain._domain + chain._domain
chain._previous_results = copy.deepcopy(chain._parent_chain._previous_results[start_index:stop_index])

elif chain._current_link is not None:
chain._current_results = []

def map_chain_seed(self, seed):
    """Feed a seed its arguments from its parent chain's previous results.

    When the parent chain is fully mapped, the seed receives every previous
    result. Otherwise it receives a window of ``seed._domain`` results; the
    window number is the chain's ``_item_index`` integer-divided by
    ``seed._range``.

    Args:
        seed: object exposing ``_parent_chain``, ``_range``, ``_domain`` and
            ``update_args``.
    """
    chain = seed._parent_chain
    # The visible StemObject initializer sets `full_map` (no underscore), so
    # `_full_map` may never have been assigned on this chain; treat a missing
    # flag as "not fully mapped" instead of raising AttributeError.
    if getattr(chain, "_full_map", False):
        seed.update_args(chain._previous_results)
        return

    domain_index = chain._item_index // seed._range
    start_index = domain_index * seed._domain
    stop_index = start_index + seed._domain
    seed.update_args(chain._previous_results[start_index:stop_index])

def map_frontier_seed(self, seed):
Expand All @@ -276,9 +273,9 @@ def map_frontier_seed(self, seed):
seed.update_args(self._previous_results)
else:
# Map specific arguments to the seed arguments
domain_index = self._item_index//seed._range
start_index = domain_index*seed._domain
stop_index = domain_index*seed._domain+seed._domain
domain_index = self._item_index // seed._range
start_index = domain_index * seed._domain
stop_index = domain_index * seed._domain + seed._domain
seed.update_args(self._previous_results[start_index:stop_index])

def check_results(self):
Expand Down Expand Up @@ -316,16 +313,16 @@ def unlink_from_chain(self, item):
# add chain to master's current items if no pending items and chain is not the master chain
if not chain._current_items and not chain._waiting_items and chain._item_id != self._item_id:
chain._previous_results = []
chain.expand_results(chain._current_results)
chain.expand_results(chain._current_results)
self._current_items[chain._item_id] = chain

def expand_results(self, results):
    """Flatten ``results`` into ``self._previous_results``.

    Nested ``list`` instances are expanded depth-first, left to right; any
    other value (including tuples) is appended as-is. Values are appended to
    the existing ``_previous_results`` list, which is not cleared here.

    The traversal uses an explicit stack of iterators rather than recursion,
    so deeply nested results cannot hit the interpreter's recursion limit.

    Args:
        results: iterable of results, possibly containing nested lists.
    """
    flat = self._previous_results
    pending = [iter(results)]
    while pending:
        for result in pending[-1]:
            if isinstance(result, list):
                # Descend into the nested list; the outer iterator stays on
                # the stack and resumes once the nested one is exhausted.
                pending.append(iter(result))
                break
            flat.append(result)
        else:
            # Innermost iterator is exhausted: go back to its parent.
            pending.pop()

def handle_bloom(self, item, bloomed_item):
chain = self._chain_mapping[item._chain_id]
del chain._waiting_items[item._item_id]
Expand All @@ -348,22 +345,23 @@ def handle_bloom(self, item, bloomed_item):
group_item._chain_id = chain._item_id
group_item._parent_chain = chain

chain._item_mapping[item._item_id].append([])
chain._item_mapping[group_item._item_id] = chain._item_mapping[item._item_id][count]
chain._item_mapping[item._item_id].append([])
chain._item_mapping[group_item._item_id] = chain._item_mapping[item._item_id][count]
self._current_items[group_item._item_id] = group_item

if chain._item_id != self._item_id:
chain._waiting_items[item._item_id] = group_item

count += 1
else:
raise TypeError



class Group(StemObject):
"""
A group is a collection of seeds and chains that can execute in parallel:
"""
def __init__ (self, *args):
def __init__(self, *args):
self._group = []
super().__init__()
for arg in iter(args):
Expand All @@ -376,22 +374,22 @@ def __init__ (self, *args):
self._group.append(iarg)
else:
raise TypeError
except:
except Exception:
if isinstance(arg, Chain):
self._group.append(arg)
elif isinstance(arg, Seed):
self._group.append(arg)
else:
raise TypeError

def run(self):
    """Execute this group by wrapping it in a single-link Chain and running it."""
    Chain(self).run()


class Seed(StemObject):
"""
A Seed object reflects the base object to be executed.
A Seed object reflects the base object to be executed.
"""
def __init__(self, func, *args, **kwargs):
# TODO: This keeps the connection object from complaining when sending certain objects via the communication Pipe
Expand All @@ -410,17 +408,17 @@ def set_manager(self, manager):
def run(self):
    """Execute this seed by wrapping it in a Group and running that group."""
    Group(self).run()

def update_args(self, new_args):
    """Replace the positional arguments stored in the serialized payload.

    ``self._srl`` is deserialized into ``(func, args, kwargs)``; the old
    ``args`` are discarded in favour of ``tuple(new_args)`` and the triple is
    serialized back into ``self._srl``. ``func`` and ``kwargs`` are kept.
    """
    func, _discarded_args, kwargs = cloudpickle.loads(self._srl)
    self._srl = cloudpickle.dumps((func, tuple(new_args), kwargs))

def set_result(self, result):
self._result = result

def set(self, attr, *args, **kwargs):
    """Record a deferred call named ``attr`` with its arguments.

    The positional and keyword arguments are stored in ``self._attr_list``
    under ``attr``, replacing any earlier entry for the same name.

    Returns:
        ``self``, so calls can be chained.
    """
    recorded_call = dict(args=args, kwargs=kwargs)
    self._attr_list.update({attr: recorded_call})
    return self

def print(self):
Expand All @@ -430,7 +428,7 @@ def print(self):
class Bloom():
"""
A Bloom is an object that is returned from a task that contains a Seed, Group, or Chain
that replaces a seed
that replaces a seed
"""
def __init__(self, item):
if isinstance(item, Seed) or isinstance(item, Group):
Expand All @@ -439,16 +437,16 @@ def __init__(self, item):
print("Error: A Bloom object must contain a Seed or Group!", file=sys.stderr)
raise TypeError


def run_manager(name):

p_read, c_write = Pipe()
c_read, p_write = Pipe()
c_read, p_write = Pipe()
pid = os.fork()

# Stem
if pid:
return p_read, p_write


# Manager
else:
Expand All @@ -458,29 +456,29 @@ def exec_func(srl, options):

read = c_read
write = c_write
time.sleep(1)
time.sleep(1)
tasks = {}

m = vine.Manager(port=[9123,9143], name=name)
while(True):
m = vine.Manager(port=[9123, 9143], name=name)
while True:
while read.poll():
try:
item = read.recv()
item = read.recv()
if isinstance(item, Seed):
task = vine.PythonTask(exec_func, item._srl, None)
try:
for attr in item._attr_list:
func = getattr(task, attr)
func(*item._attr_list[attr]["args"], **item._attr_list[attr]["kwargs"])
# TODO error handling
except:
except Exception:
pass
task.set_cores(1)
m.submit(task)
tasks[task.id] = item
elif isinstance(item, str) and item == "kill":
exit(1)
except:
except Exception:
raise RuntimeError
exit(1)

Expand All @@ -492,4 +490,3 @@ def exec_func(srl, options):
item.set_result(task.output)
write.send(item)
del tasks[task.id]

0 comments on commit c5b3fa3

Please sign in to comment.