Implement MultipleInputFeatureRequirement workflow requirement.
Provide an alternative implementation of the "collection info" object to tool execution when a non-collection map-over needs to happen. This is still a bit sloppy - I need to:

- Rename "collection_info" everywhere - maybe to "map_over_info".
- Build an interface that the tool execution environment can consume.
- Implement a blended approach that allows mapping over both collections and individual inputs.
- Rename the "implicit_inputs" property on these to "implicit_input_collections".
jmchilton committed Jul 20, 2017
1 parent 2b20e2b commit 56e3866
Showing 8 changed files with 213 additions and 91 deletions.
4 changes: 4 additions & 0 deletions lib/galaxy/dataset_collections/matching.py
@@ -69,6 +69,10 @@ def structure( self ):
effective_structure = effective_structure.multiply( linked_structure )
return None if effective_structure.is_leaf else effective_structure

@property
def implicit_inputs( self ):
return list( self.collection_info.collections.items() )

@staticmethod
def for_collections( collections_to_match, collection_type_descriptions ):
if not collections_to_match.has_collections():
6 changes: 4 additions & 2 deletions lib/galaxy/managers/collections.py
@@ -71,8 +71,10 @@ def create( self, trans, parent, name, collection_type, element_identifiers=None
name=name,
)
if implicit_collection_info:
for input_name, input_collection in implicit_collection_info[ "implicit_inputs" ]:
dataset_collection_instance.add_implicit_input_collection( input_name, input_collection )
implicit_inputs = implicit_collection_info[ "implicit_inputs" ]
if implicit_inputs:
for input_name, input_collection in implicit_inputs:
dataset_collection_instance.add_implicit_input_collection( input_name, input_collection )
for output_dataset in implicit_collection_info.get( "outputs" ):
if output_dataset not in trans.sa_session:
output_dataset = trans.sa_session.query( type( output_dataset ) ).get( output_dataset.id )
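The new guard matters because the two "collection info" flavors disagree about implicit_inputs: MatchingCollections returns (name, collection) pairs, while the ScatterOverCollector added below in modules.py returns None. A toy illustration with stand-in values, not Galaxy API calls:

# "hdca1"/"hdca2" stand in for history dataset collection associations.
hdca1, hdca2 = object(), object()

mapping_case = [("input1", hdca1), ("input2", hdca2)]  # MatchingCollections.implicit_inputs
scatter_case = None                                    # ScatterOverCollector.implicit_inputs

for implicit_inputs in (mapping_case, scatter_case):
    # The guard: only record source collections when there are any.
    if implicit_inputs:
        for input_name, input_collection in implicit_inputs:
            print("record implicit input", input_name)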
72 changes: 44 additions & 28 deletions lib/galaxy/tools/cwl/parser.py
@@ -15,7 +15,7 @@
import six

from galaxy.tools.hash import build_tool_hash
from galaxy.util import safe_makedirs
from galaxy.util import listify, safe_makedirs
from galaxy.util.bunch import Bunch
from galaxy.util.odict import odict

@@ -41,6 +41,7 @@
"InlineJavascriptRequirement",
"ShellCommandRequirement",
"ScatterFeatureRequirement",
"MultipleInputFeatureRequirement",
]


@@ -484,19 +485,24 @@ def input_connections_by_step(self, step_proxies):
for cwl_input in cwl_inputs:
cwl_input_id = cwl_input["id"]
cwl_source_id = cwl_input["source"]
step_name, input_name = split_step_reference(cwl_input_id)
output_step_name, output_name = split_step_reference(cwl_source_id)
output_step_id = self.cwl_id + "#" + output_step_name
if output_step_id not in cwl_ids_to_index:
template = "Output [%s] does not appear in ID-to-index map [%s]."
msg = template % (output_step_id, cwl_ids_to_index)
raise AssertionError(msg)

input_connections_step[input_name] = {
"id": cwl_ids_to_index[output_step_id],
"output_name": output_name,
"input_type": "dataset"
}
step_name, input_name = split_step_references(cwl_input_id, multiple=False)
# Consider only allowing multiple sources if MultipleInputFeatureRequirement is enabled.
for (output_step_name, output_name) in split_step_references(cwl_source_id):
output_step_id = self.cwl_id + "#" + output_step_name
if output_step_id not in cwl_ids_to_index:
template = "Output [%s] does not appear in ID-to-index map [%s]."
msg = template % (output_step_id, cwl_ids_to_index)
raise AssertionError(msg)

if input_name not in input_connections_step:
input_connections_step[input_name] = []

input_connections_step[input_name].append({
"id": cwl_ids_to_index[output_step_id],
"output_name": output_name,
"input_type": "dataset"
})

input_connections_by_step.append(input_connections_step)

return input_connections_by_step
@@ -551,24 +557,34 @@ def cwl_object_to_annotation(self, cwl_obj):
return cwl_obj.get("doc", None)


def split_step_reference(step_reference):
def split_step_references(step_references, multiple=True):
"""Split a CWL step input or output reference into step id and name."""
# Trim off the workflow id part of the reference.
assert "#" in step_reference
cwl_workflow_id, step_reference = step_reference.split("#", 1)
step_references = listify(step_references)
split_references = []

for step_reference in step_references:
assert "#" in step_reference
cwl_workflow_id, step_reference = step_reference.split("#", 1)

# Now just grab the step name and input/output name.
assert "#" not in step_reference
if "/" in step_reference:
step_name, io_name = step_reference.split("/", 1)
# Now just grab the step name and input/output name.
assert "#" not in step_reference
if "/" in step_reference:
step_name, io_name = step_reference.split("/", 1)
else:
# Referencing an input, not a step.
# In Galaxy workflows input steps have an implicit output named
# "output" for consistency with tools - in cwl land
# just the input name is referenced.
step_name = step_reference
io_name = "output"
split_references.append((step_name, io_name))

if multiple:
return split_references
else:
# Referencing an input, not a step.
# In Galaxy workflows input steps have an implicit output named
# "output" for consistency with tools - in cwl land
# just the input name is referenced.
step_name = step_reference
io_name = "output"
return (step_name, io_name)
assert len(split_references) == 1
return split_references[0]


class StepProxy(object):
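Under MultipleInputFeatureRequirement a step input's "source" may be a list, so the old single-reference helper becomes split_step_references, returning a list of (step_name, io_name) pairs; multiple=False preserves the old one-reference contract. Expected behavior on hypothetical references ("main", "step1", etc. are made up for illustration):

# Hypothetical references - behavior follows from the code above.
split_step_references("main#step1/out1")
# -> [("step1", "out1")]

split_step_references("main#an_input")
# -> [("an_input", "output")]  (inputs get the implicit "output" name)

split_step_references(["main#step1/out1", "main#step2/out1"])
# -> [("step1", "out1"), ("step2", "out1")]

split_step_references("main#step1/out1", multiple=False)
# -> ("step1", "out1")  (asserts there is exactly one reference)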
21 changes: 12 additions & 9 deletions lib/galaxy/tools/execute.py
@@ -138,18 +138,21 @@ def create_output_collections( self, trans, history, params ):

structure = self.collection_info.structure

# params is just one sample tool param execution with parallelized
# collection replaced with a specific dataset. Need to replace this
# with the collection and wrap everything up so can evaluate output
# label.
params.update( self.collection_info.collections ) # Replace datasets with source collections for labelling outputs.

collection_names = ["collection %d" % c.hid for c in self.collection_info.collections.values()]
on_text = on_text_for_names( collection_names )
if hasattr( self.collection_info, "collections" ):
# params is just one sample tool param execution with parallelized
# collection replaced with a specific dataset. Need to replace this
# with the collection and wrap everything up so can evaluate output
# label.
params.update( self.collection_info.collections ) # Replace datasets with source collections for labelling outputs.

collection_names = ["collection %d" % c.hid for c in self.collection_info.collections.values()]
on_text = on_text_for_names( collection_names )
else:
on_text = "implicitly create collection for inputs"

collections = {}

implicit_inputs = list(self.collection_info.collections.items())
implicit_inputs = self.collection_info.implicit_inputs
for output_name, outputs in self.outputs_by_output_name.items():
if not len( structure ) == len( outputs ):
# Output does not have the same structure, if all jobs were
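The hasattr( collection_info, "collections" ) test is duck typing between the two "collection info" implementations; the shared interface promised in the commit message TODOs would eventually replace it. A toy demonstration with stand-in classes, not the Galaxy ones:

class MatchingLike(object):  # stands in for MatchingCollections
    def __init__(self):
        self.collections = {"input1": "a collection"}

    @property
    def implicit_inputs(self):
        return list(self.collections.items())

class ScatterLike(object):  # stands in for ScatterOverCollector
    @property
    def implicit_inputs(self):
        return None

for collection_info in (MatchingLike(), ScatterLike()):
    if hasattr(collection_info, "collections"):
        on_text = "named after the source collections"
    else:
        on_text = "implicitly create collection for inputs"
    print(on_text, collection_info.implicit_inputs)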
78 changes: 76 additions & 2 deletions lib/galaxy/workflow/modules.py
@@ -14,6 +14,7 @@
web
)
from galaxy.dataset_collections import matching
from galaxy.dataset_collections.structure import leaf, Tree
from galaxy.exceptions import ToolMissingException
from galaxy.jobs.actions.post import ActionBox
from galaxy.model import PostJobAction
@@ -58,6 +59,7 @@ class WorkflowModule( object ):

def __init__( self, trans, content_id=None, **kwds ):
self.trans = trans
self.app = trans.app
self.content_id = content_id
self.state = DefaultToolState()

@@ -859,6 +861,25 @@ def decode_runtime_state( self, runtime_state ):
else:
raise ToolMissingException( "Tool %s missing. Cannot recover runtime state." % self.tool_id )

def _check_for_scatters( self, step, tool, progress, tool_state ):
scatter_collector = ScatterOverCollector(
self.app
)

def callback( input, prefixed_name, **kwargs ):
replacement = progress.replacement_for_tool_input( step, input, prefixed_name )
log.info("replacement for %s is %s" % (prefixed_name, replacement))
if replacement:
if isinstance(replacement, ScatterOver):
scatter_collector.add_scatter(replacement)

return NO_REPLACEMENT

visit_input_values( tool.inputs, tool_state, callback, no_replacement_value=NO_REPLACEMENT )

# TODO: using num_slices == 0 as the sentinel is bad - what about empty arrays?
return None if scatter_collector.num_slices == 0 else scatter_collector

def execute( self, trans, progress, invocation, step ):
tool = trans.app.toolbox.get_tool( step.tool_id, tool_version=step.tool_version, tool_hash=step.tool_hash )
tool_state = step.state
@@ -869,10 +890,10 @@ def execute( self, trans, progress, invocation, step ):
collections_to_match = self._find_collections_to_match( tool, progress, step )
# Have implicit collections...
if collections_to_match.has_collections():
# collection_info is a MatchingCollections instance.
collection_info = self.trans.app.dataset_collections_service.match_collections( collections_to_match )
else:
collection_info = None

collection_info = self._check_for_scatters( step, tool, progress, make_dict_copy( tool_state.inputs ) )
param_combinations = []
if collection_info:
iteration_elements_iter = collection_info.slice_collections()
@@ -1120,6 +1141,59 @@ def load_module_sections( trans ):
return module_sections


class ScatterOverCollector(object):

def __init__(self, app):
self.inputs_per_name = {}
self.num_slices = 0
self.app = app

def add_scatter(self, scatter_over):
inputs = scatter_over.inputs
self.inputs_per_name[scatter_over.prefixed_name] = inputs
if self.num_slices > 0:
assert len(inputs) == self.num_slices
else:
self.num_slices = len(inputs)

def slice_collections(self):
slices = []
for i in range(self.num_slices):
this_slice = {}
for prefixed_name, inputs in self.inputs_per_name.items():
this_slice[prefixed_name] = SliceElement(inputs[i], str(i))
slices.append(this_slice)
return slices

@property
def structure(self):
collection_type_descriptions = self.app.dataset_collections_service.collection_type_descriptions
collection_type_description = collection_type_descriptions.for_collection_type("list")
children = []
for input in list(self.inputs_per_name.values())[0]:  # list() so indexing also works on Python 3
children.append((input.element_identifier, leaf))

return Tree(children, collection_type_description)

@property
def implicit_inputs(self):
return None


class SliceElement(object):

def __init__(self, dataset_instance, element_identifier):
self.dataset_instance = dataset_instance
self.element_identifier = element_identifier


class ScatterOver(object):

def __init__(self, prefixed_name, inputs):
self.prefixed_name = prefixed_name
self.inputs = inputs


class DelayedWorkflowEvaluation(Exception):

def __init__(self, why=None):
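How the new scatter classes fit together, as a usage sketch - it assumes the classes above are importable and uses plain strings as stand-ins for dataset instances:

from galaxy.workflow.modules import ScatterOver, ScatterOverCollector

collector = ScatterOverCollector(app=None)  # app is only needed by .structure
collector.add_scatter(ScatterOver("param_a", ["ds_a1", "ds_a2"]))
collector.add_scatter(ScatterOver("param_b", ["ds_b1", "ds_b2"]))  # lengths must match

for index, params_slice in enumerate(collector.slice_collections()):
    # Each slice maps a prefixed parameter name to a SliceElement carrying
    # one dataset and its element identifier ("0", "1", ...).
    print(index, dict((name, element.dataset_instance) for name, element in params_slice.items()))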
16 changes: 15 additions & 1 deletion lib/galaxy/workflow/run.py
@@ -295,7 +295,21 @@ def replacement_for_tool_input( self, step, input, prefixed_name ):
replacement = replacement[ 0 ]
else:
is_data = input.type in ["data", "data_collection"]
replacement = self.replacement_for_connection( connection[ 0 ], is_data=is_data )
if len( connection ) == 1:
replacement = self.replacement_for_connection( connection[ 0 ], is_data=is_data )
else:
# We've mapped multiple individual inputs to a single parameter,
# promote output to a collection.
inputs = []
for c in connection:
input_from_connection = self.replacement_for_connection( c, is_data=is_data )
inputs.append(input_from_connection)

replacement = modules.ScatterOver(
prefixed_name,
inputs,
)

return replacement

def replacement_for_connection( self, connection, is_data=True ):
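Condensed, the new branch reads as below; the promotion to ScatterOver is what ToolModule._check_for_scatters() later detects when visiting the tool's inputs (names come from this commit, surrounding context elided):

if len(connection) == 1:
    replacement = self.replacement_for_connection(connection[0], is_data=is_data)
else:
    # Several connections feed one non-collection tool input - per the
    # commit message, a "non-collection map-over": resolve each connection
    # and promote the set to a ScatterOver.
    inputs = [self.replacement_for_connection(c, is_data=is_data) for c in connection]
    replacement = modules.ScatterOver(prefixed_name, inputs)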