Skip to content

Commit

Permalink
Merge pull request #1520 from devitocodes/gpu-from-clusters-final
Browse files Browse the repository at this point in the history
GPU data streaming
  • Loading branch information
FabioLuporini authored Jan 6, 2021
2 parents aaf3598 + dbe8392 commit 24f2962
Show file tree
Hide file tree
Showing 62 changed files with 3,783 additions and 749 deletions.
8 changes: 2 additions & 6 deletions devito/core/arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ def _specialize_iet(cls, graph, **kwargs):
hoist_prodders(graph)

# Symbol definitions
data_manager = DataManager(sregistry)
data_manager.place_definitions(graph)
data_manager.place_casts(graph)
DataManager(sregistry).process(graph)

return graph

Expand Down Expand Up @@ -69,8 +67,6 @@ def _specialize_iet(cls, graph, **kwargs):
hoist_prodders(graph)

# Symbol definitions
data_manager = DataManager(sregistry)
data_manager.place_definitions(graph)
data_manager.place_casts(graph)
DataManager(sregistry).process(graph)

return graph
92 changes: 69 additions & 23 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from devito.core.operator import OperatorCore
from devito.exceptions import InvalidOperator
from devito.passes.equations import collect_derivatives
from devito.logger import warning
from devito.passes.equations import buffering, collect_derivatives
from devito.passes.clusters import (Blocking, Lift, cire, cse, eliminate_arrays,
extract_increments, factorize, fuse, optimize_pows)
from devito.passes.iet import (DataManager, Ompizer, avoid_denormals, mpiize,
Expand Down Expand Up @@ -82,6 +83,9 @@ def _normalize_kwargs(cls, **kwargs):
o['openmp'] = oo.pop('openmp')
o['mpi'] = oo.pop('mpi')

# Buffering
o['buf-async-degree'] = oo.pop('buf-async-degree', None)

# Blocking
o['blockinner'] = oo.pop('blockinner', False)
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
Expand All @@ -108,7 +112,9 @@ def _normalize_kwargs(cls, **kwargs):
o['par-dynamic-work'] = oo.pop('par-dynamic-work', cls.PAR_DYNAMIC_WORK)
o['par-nested'] = oo.pop('par-nested', cls.PAR_NESTED)

o['gpu-direct'] = oo.pop('gpu-direct', False)
# Recognised but unused by the CPU backend
oo.pop('gpu-direct', None)
oo.pop('gpu-fit', None)

if oo:
raise InvalidOperator("Unrecognized optimization options: [%s]"
Expand All @@ -134,18 +140,16 @@ def _specialize_iet(cls, graph, **kwargs):
ompizer.make_parallel(graph)

# Symbol definitions
data_manager = DataManager(sregistry)
data_manager.place_definitions(graph)
data_manager.place_casts(graph)
DataManager(sregistry).process(graph)

return graph


class CPU64Operator(CPU64NoopOperator):

@classmethod
@timed_pass(name='specializing.DSL')
def _specialize_dsl(cls, expressions, **kwargs):
    """
    Specialize the input expressions at the DSL level.

    The expressions are run through `collect_derivatives` and the result is
    returned. `kwargs` is accepted for signature compatibility but is not
    consulted here.
    """
    expressions = collect_derivatives(expressions)

    return expressions
Expand Down Expand Up @@ -209,9 +213,7 @@ def _specialize_iet(cls, graph, **kwargs):
hoist_prodders(graph)

# Symbol definitions
data_manager = DataManager(sregistry)
data_manager.place_definitions(graph)
data_manager.place_casts(graph)
DataManager(sregistry).process(graph)

return graph

Expand Down Expand Up @@ -247,23 +249,33 @@ def _specialize_iet(cls, graph, **kwargs):
hoist_prodders(graph)

# Symbol definitions
data_manager = DataManager(sregistry)
data_manager.place_definitions(graph)
data_manager.place_casts(graph)
DataManager(sregistry).process(graph)

return graph


class CustomOperator(CPU64Operator):

_known_passes = ('blocking', 'denormals', 'optcomms', 'openmp', 'mpi',
'simd', 'prodders', 'topofuse', 'fuse', 'factorize',
'cire-sops', 'cse', 'lift', 'opt-pows', 'collect-derivs')
@classmethod
def _make_dsl_passes_mapper(cls, **kwargs):
    """
    Build the table of DSL-level passes, keyed by pass name.

    `kwargs` is accepted but not consulted.
    """
    mapper = {}
    mapper['collect-derivs'] = collect_derivatives
    return mapper

@classmethod
def _make_exprs_passes_mapper(cls, **kwargs):
    """
    Build the table of expression-level passes, keyed by pass name.

    ``kwargs['options']`` is read and forwarded to `buffering`.
    """
    opts = kwargs['options']

    # Stand-in for `is_on_device`, which the device backends use. `buffering`
    # calls it to decide which Functions to replace with buffers: here, the
    # TimeFunctions with `save!=None`
    def callback(f):
        is_saved = f.is_TimeFunction and f.save is not None
        return [f.time_dim] if is_saved else None

    return {
        'collect-derivs': collect_derivatives,
        'buffering': lambda exprs: buffering(exprs, callback, opts)
    }

@classmethod
Expand Down Expand Up @@ -302,14 +314,50 @@ def _make_iet_passes_mapper(cls, **kwargs):
'prodders': hoist_prodders
}

# Names of the passes that may be requested explicitly, grouped by the
# compilation stage they belong to
_known_passes = (
    # DSL
    'collect-derivs',
    # Expressions
    'buffering',
    # Clusters
    'blocking',
    'topofuse',
    'fuse',
    'factorize',
    'cire-sops',
    'cse',
    'lift',
    'opt-pows',
    # IET
    'denormals',
    'optcomms',
    'openmp',
    'mpi',
    'simd',
    'prodders',
)
# Names that are recognised but not supported by this type of Operator
_known_passes_disabled = ('tasking', 'streaming', 'gpu-direct', 'openacc')
# A pass name must never appear in both tuples
assert set(_known_passes).isdisjoint(_known_passes_disabled)

@classmethod
def _build(cls, expressions, **kwargs):
    """
    Validate the explicitly requested passes, then defer to the parent
    class' `_build`.

    Every entry of ``kwargs['mode']`` is checked by name:

    * names in `_known_passes` are accepted;
    * names in `_known_passes_disabled` trigger a warning, not an error;
    * any other name is rejected.

    Raises
    ------
    InvalidOperator
        If a requested pass appears in neither `_known_passes` nor
        `_known_passes_disabled`.
    """
    # Sanity check. Passes are examined one at a time: a blanket
    # "any unknown pass" test up front would raise before the disabled
    # passes could ever reach their warning
    for i in as_tuple(kwargs['mode']):
        if i in cls._known_passes:
            continue
        if i not in cls._known_passes_disabled:
            raise InvalidOperator("Unknown pass `%s`" % i)
        warning("Got explicit pass `%s`, but it's unsupported on an "
                "Operator of type `%s`" % (i, str(cls)))

    return super()._build(expressions, **kwargs)
@classmethod
@timed_pass(name='specializing.DSL')
def _specialize_dsl(cls, expressions, **kwargs):
    """
    Apply the requested DSL-level passes to `expressions`.

    The pass names are taken from ``kwargs['mode']`` and visited in that
    order. A name with no entry in the mapper returned by
    `_make_dsl_passes_mapper` is skipped. The expressions produced by the
    last applied pass are returned.
    """
    passes = as_tuple(kwargs['mode'])

    # Fetch passes to be called
    passes_mapper = cls._make_dsl_passes_mapper(**kwargs)

    # Call passes
    for i in passes:
        # Only the lookup is guarded. Were the call inside the `try` as
        # well, a KeyError raised by the pass itself would be mistaken for
        # "no such pass" and silently discarded
        try:
            func = passes_mapper[i]
        except KeyError:
            continue
        expressions = func(expressions)

    return expressions

@classmethod
@timed_pass(name='specializing.Expressions')
Expand Down Expand Up @@ -371,8 +419,6 @@ def _specialize_iet(cls, graph, **kwargs):
passes_mapper['openmp'](graph)

# Symbol definitions
data_manager = DataManager(sregistry)
data_manager.place_definitions(graph)
data_manager.place_casts(graph)
DataManager(sregistry).process(graph)

return graph
Loading

0 comments on commit 24f2962

Please sign in to comment.