Skip to content

Commit

Permalink
Merge pull request #517 from OP2/remove-flatten
Browse files Browse the repository at this point in the history
Remove the flatten option of op2.Arg
  • Loading branch information
miklos1 authored Dec 13, 2016
2 parents 6ff1148 + 404f587 commit f09cefb
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 261 deletions.
31 changes: 8 additions & 23 deletions pyop2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class Arg(object):
Instead, use the call syntax on the :class:`DataCarrier`.
"""

def __init__(self, data=None, map=None, idx=None, access=None, flatten=False):
def __init__(self, data=None, map=None, idx=None, access=None):
"""
:param data: A data-carrying object, either :class:`Dat` or class:`Mat`
:param map: A :class:`Map` to access this :class:`Arg` or the default
Expand All @@ -269,9 +269,6 @@ def __init__(self, data=None, map=None, idx=None, access=None, flatten=False):
given component of the mapping or the default to use all
components of the mapping.
:param access: An access descriptor of type :class:`Access`
:param flatten: Treat the data dimensions of this :class:`Arg` as flat
s.t. the kernel is passed a flat vector of length
``map.arity * data.dataset.cdim``.
Checks that:
Expand All @@ -284,7 +281,6 @@ def __init__(self, data=None, map=None, idx=None, access=None, flatten=False):
self._map = map
self._idx = idx
self._access = access
self._flatten = flatten
self._in_flight = False # some kind of comms in flight for this arg

# Check arguments for consistency
Expand All @@ -300,18 +296,10 @@ def __init__(self, data=None, map=None, idx=None, access=None, flatten=False):
"To set of %s doesn't match the set of %s." % (map, data))

# Determine the iteration space extents, if any
if self._is_mat and flatten:
rdims = tuple(d.cdim for d in data.sparsity.dsets[0])
cdims = tuple(d.cdim for d in data.sparsity.dsets[1])
self._block_shape = tuple(tuple((mr.arity * dr, mc.arity * dc)
for mc, dc in zip(map[1], cdims))
for mr, dr in zip(map[0], rdims))
elif self._is_mat:
if self._is_mat:
self._block_shape = tuple(tuple((mr.arity, mc.arity)
for mc in map[1])
for mr in map[0])
elif self._uses_itspace and flatten:
self._block_shape = tuple(((m.arity * d.cdim,),) for m, d in zip(map, data))
elif self._uses_itspace:
self._block_shape = tuple(((m.arity,),) for m in map)
else:
Expand Down Expand Up @@ -1814,13 +1802,13 @@ def __init__(self, dataset, data=None, dtype=None, name=None,
self._recv_buf = {}

@validate_in(('access', _modes, ModeValueError))
def __call__(self, access, path=None, flatten=False):
def __call__(self, access, path=None):
if isinstance(path, _MapArg):
return _make_object('Arg', data=self, map=path.map, idx=path.idx,
access=access, flatten=flatten)
access=access)
if configuration["type_check"] and path and path.toset != self.dataset.set:
raise MapValueError("To Set of Map does not match Set of Dat.")
return _make_object('Arg', data=self, map=path, access=access, flatten=flatten)
return _make_object('Arg', data=self, map=path, access=access)

def __getitem__(self, idx):
"""Return self if ``idx`` is 0, raise an error otherwise."""
Expand Down Expand Up @@ -2634,10 +2622,7 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None):
Global._globalcount += 1

@validate_in(('access', _modes, ModeValueError))
def __call__(self, access, path=None, flatten=False):
"""Note that the flatten argument is only passed in order to
have the same interface as :class:`Dat`. Its value is
ignored."""
def __call__(self, access, path=None):
return _make_object('Arg', data=self, access=access)

def __iter__(self):
Expand Down Expand Up @@ -3636,14 +3621,14 @@ def __init__(self, sparsity, dtype=None, name=None):
Mat._globalcount += 1

@validate_in(('access', _modes, ModeValueError))
def __call__(self, access, path, flatten=False):
def __call__(self, access, path):
path = as_tuple(path, _MapArg, 2)
path_maps = tuple(arg and arg.map for arg in path)
path_idxs = tuple(arg and arg.idx for arg in path)
if configuration["type_check"] and tuple(path_maps) not in self.sparsity:
raise MapValueError("Path maps not in sparsity maps")
return _make_object('Arg', data=self, map=path_maps, access=access,
idx=path_idxs, flatten=flatten)
idx=path_idxs)

def assemble(self):
"""Finalise this :class:`Mat` ready for use.
Expand Down
5 changes: 2 additions & 3 deletions pyop2/fusion/extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, arg, gather=None, c_index=False):
:arg c_index: if True, will provide the kernel with the iteration index of this
Arg's set. Otherwise, code generation is unaffected.
"""
super(FusionArg, self).__init__(arg.data, arg.map, arg.idx, arg.access, arg._flatten)
super(FusionArg, self).__init__(arg.data, arg.map, arg.idx, arg.access)
self.gather = gather or arg.gather
self.c_index = c_index or arg.c_index

Expand All @@ -83,11 +83,10 @@ def c_map_name(self, i, j, fromvector=False):
def c_vec_dec(self, is_facet=False):
if self.gather == 'onlymap':
facet_mult = 2 if is_facet else 1
cdim = self.data.cdim if self._flatten else 1
return "%(type)s %(vec_name)s[%(arity)s];\n" % \
{'type': self.ctype,
'vec_name': self.c_vec_name(),
'arity': self.map.arity * cdim * facet_mult}
'arity': self.map.arity * facet_mult}
else:
return super(FusionArg, self).c_vec_dec(is_facet)

Expand Down
8 changes: 4 additions & 4 deletions pyop2/petsc_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,23 +778,23 @@ def _init_global_block(self):
mat = _DatMat(self.sparsity)
self.handle = mat

def __call__(self, access, path, flatten=False):
def __call__(self, access, path):
"""Override the parent __call__ method in order to special-case global
blocks in matrices."""
try:
# Usual case
return super(Mat, self).__call__(access, path, flatten)
return super(Mat, self).__call__(access, path)
except TypeError:
# One of the path entries was not an Arg.
if path == (None, None):
return _make_object('Arg',
data=self.handle.getPythonContext().global_,
access=access, flatten=flatten)
access=access)
elif None in path:
thispath = path[0] or path[1]
return _make_object('Arg', data=self.handle.getPythonContext().dat,
map=thispath.map, idx=thispath.idx,
access=access, flatten=flatten)
access=access)
else:
raise

Expand Down
2 changes: 0 additions & 2 deletions pyop2/pyparloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ def _compute(self, part, *arglist):
else:
arg.data._data[arg.map.values_with_halo[idx, arg.idx:arg.idx+1]] = tmp[:]
elif arg._is_mat:
if arg._flatten:
raise NotImplementedError # Need to sort out the permutation.
if arg.access is base.INC:
arg.data.addto_values(arg.map[0].values_with_halo[idx],
arg.map[1].values_with_halo[idx],
Expand Down
Loading

0 comments on commit f09cefb

Please sign in to comment.