Skip to content

Commit

Permalink
Refactor map_node_over_list function
Browse files Browse the repository at this point in the history
  • Loading branch information
guill committed Apr 22, 2024
1 parent fa48ad3 commit afa4c7b
Showing 1 changed file with 16 additions and 34 deletions.
50 changes: 16 additions & 34 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,59 +128,41 @@ def mark_missing():

def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
input_is_list = obj.INPUT_IS_LIST
input_is_list = getattr(obj, "INPUT_IS_LIST", False)

if len(input_data_all) == 0:
max_len_input = 0
else:
max_len_input = max([len(x) for x in input_data_all.values()])
max_len_input = max(len(x) for x in input_data_all.values())

# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
return {k: v[i if len(v) > i else -1] for k, v in d.items()}

results = []
if input_is_list:
def process_inputs(inputs, index=None):
if allow_interrupt:
nodes.before_node_execution()
execution_block = None
for k, v in input_data_all.items():
for input in v:
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
break

for k, v in inputs.items():
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb else v
break
if execution_block is None:
if pre_execute_cb is not None:
pre_execute_cb(0)
results.append(getattr(obj, func)(**input_data_all))
if pre_execute_cb is not None and index is not None:
pre_execute_cb(index)
results.append(getattr(obj, func)(**inputs))
else:
results.append(execution_block)

if input_is_list:
process_inputs(input_data_all, 0)
elif max_len_input == 0:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)())
process_inputs({})
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
input_dict = slice_dict(input_data_all, i)
execution_block = None
for k, v in input_dict.items():
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
break
if execution_block is None:
if pre_execute_cb is not None:
pre_execute_cb(i)
results.append(getattr(obj, func)(**input_dict))
else:
results.append(execution_block)
process_inputs(input_dict, i)
return results

def merge_result_data(results, obj):
Expand Down

0 comments on commit afa4c7b

Please sign in to comment.