Skip to content

Commit

Permalink
fixed value of parameters merged in workflows
Browse files Browse the repository at this point in the history
a temporary value could be erased
  • Loading branch information
denisri committed Jul 28, 2023
1 parent e70f5b7 commit 725741a
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions capsul/execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def executable_requirements(self, executable):


class CapsulWorkflow(Controller):
parameters: DictWithProxy
# parameters: DictWithProxy
jobs: dict

def __init__(self, executable, debug=False):
Expand All @@ -118,6 +118,7 @@ def __init__(self, executable, debug=False):
top_parameters.content.update(job_parameters.content)
self.parameters_values = top_parameters.proxy_values
self.parameters_dict = top_parameters.content
# self.parameters = top_parameters

# Set jobs chronology based on processes chronology
for after_process, before_processes in process_chronology.items():
Expand Down Expand Up @@ -160,31 +161,33 @@ def __init__(self, executable, debug=False):
job['wait_for'] = wait_for
for waited in wait_for:
self.jobs[waited].setdefault('waited_by',[]).append(job_id)

parameters = top_parameters
for index in job['parameters_location']:
if index.isnumeric():
index = int(index)
parameters = parameters[index]
def no_proxy(i):
if DictWithProxy.is_proxy(i):
return no_proxy(i[1])
v = parameters.proxy_values[i]
if DictWithProxy.is_proxy(v):
return no_proxy(v[1])
return i
parameters_index = {}
stack = list((k, v[1]) for k, v in parameters.content.items() if k != 'nodes')
while stack:
k, i = stack.pop()
i = no_proxy(i)
i = self._no_proxy(parameters, i)
v = parameters.proxy_values[i]
if isinstance(v, list) and v and DictWithProxy.is_proxy(v[0]):
parameters_index[k] = [no_proxy(i) for i in v]
parameters_index[k] = [self._no_proxy(parameters, i)
for i in v]
else:
parameters_index[k] = i
job['parameters_index'] = parameters_index


@staticmethod
def _no_proxy(parameters, i):
if DictWithProxy.is_proxy(i):
return CapsulWorkflow._no_proxy(parameters, i[1])
v = parameters.proxy_values[i]
if DictWithProxy.is_proxy(v):
return CapsulWorkflow._no_proxy(parameters, v[1])
return i

def _create_jobs(self,
top_parameters,
Expand Down Expand Up @@ -233,7 +236,13 @@ def _create_jobs(self,
in_sub_pipelines=False):
if dest_node in disabled_nodes:
continue
parameters.content[field.name] = nodes_dict.get(dest_node.name, {}).get(plug_name)
if field.metadata('write', False) \
and field.name in parameters.content:
nodes_dict.get(dest_node.name, {})[plug_name] \
= parameters.content[field.name]
else:
parameters.content[field.name] \
= nodes_dict.get(dest_node.name, {}).get(plug_name)
break
if field.is_output():
for dest_node_name, dest_plug_name, dest_node, dest_plug, is_weak in process.plugs[field.name].links_to:
Expand Down Expand Up @@ -273,7 +282,14 @@ def _create_jobs(self,
first = tmp
first_index = first[1]
second_index = second[1]
v1 = parameters.proxy_values[
self._no_proxy(parameters, first_index)]
v2 = parameters.proxy_values[
self._no_proxy(parameters, second_index)]
parameters.proxy_values[second_index] = first
if v1 is None and v2 is not None:
# move former dest value to source (temporary)
parameters.proxy_values[first_index] = v2
elif isinstance(process, ProcessIteration):
parameters['_iterations'] = []
iteration_index = 0
Expand Down Expand Up @@ -396,9 +412,9 @@ def find_temporary_to_generate(executable):
for node in nodes:
# print('!temporaries! initialize node', node.full_name)
for field in node.user_fields():
if (field.output or
not field.metadata('write', False) or
not node.plugs[field.name].activated):
if (field.output or
not field.metadata('write', False) or
not node.plugs[field.name].activated):
field.generate_temporary = False
else:
field.generate_temporary = True
Expand Down

0 comments on commit 725741a

Please sign in to comment.