From b9c031aefb7ba73f16829ff4248ea598f67366e5 Mon Sep 17 00:00:00 2001 From: ksagiyam <k.sagiyama@imperial.ac.uk> Date: Fri, 14 Jun 2024 15:12:15 +0100 Subject: [PATCH 1/3] codegen: refactor Map.indexed() --- pyop2/codegen/builder.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py index 89cf31fcf..9de773a9f 100644 --- a/pyop2/codegen/builder.py +++ b/pyop2/codegen/builder.py @@ -75,7 +75,11 @@ def shape(self): def dtype(self): return self.values.dtype - def indexed(self, multiindex, layer=None, permute=lambda x: x): + @property + def _permute(self): + return lambda x: x + + def indexed(self, multiindex, layer=None): n, i, f = multiindex if layer is not None and self.offset is not None: # For extruded mesh, prefetch the indirections for each map, so that they don't @@ -84,7 +88,7 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x): base_key = None if base_key not in self.prefetch: j = Index() - base = Indexed(self.values, (n, permute(j))) + base = Indexed(self.values, (n, self._permute(j))) self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j)) base = self.prefetch[base_key] @@ -122,17 +126,17 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x): return Indexed(self.prefetch[key], (f, i)), (f, i) else: assert f.extent == 1 or f.extent is None - base = Indexed(self.values, (n, permute(i))) + base = Indexed(self.values, (n, self._permute(i))) return base, (f, i) - def indexed_vector(self, n, shape, layer=None, permute=lambda x: x): + def indexed_vector(self, n, shape, layer=None): shape = self.shape[1:] + shape if self.interior_horizontal: shape = (2, ) + shape else: shape = (1, ) + shape f, i, j = (Index(e) for e in shape) - base, (f, i) = self.indexed((n, i, f), layer=layer, permute=permute) + base, (f, i) = self.indexed((n, i, f), layer=layer) init = Sum(Product(base, Literal(numpy.int32(j.extent))), j) pack = Materialise(PackInst(), init, MultiIndex(f, i, j)) multiindex = tuple(Index(e) for e in pack.shape) @@ -168,13 +172,9 @@ def __init__(self, map_, permutation): self.offset_quotient = map_.offset_quotient self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}") - def indexed(self, multiindex, layer=None): - permute = lambda x: Indexed(self.permutation, (x,)) - return super().indexed(multiindex, layer=layer, permute=permute) - - def indexed_vector(self, n, shape, layer=None): - permute = lambda x: Indexed(self.permutation, (x,)) - return super().indexed_vector(n, shape, layer=layer, permute=permute) + @property + def _permute(self): + return lambda x: Indexed(self.permutation, (x,)) class CMap(Map): From 941007b4be080afc8c06995ad074969deb768613 Mon Sep 17 00:00:00 2001 From: ksagiyam <46749170+ksagiyam@users.noreply.github.com> Date: Fri, 14 Jun 2024 20:59:59 +0100 Subject: [PATCH 2/3] Update pyop2/codegen/builder.py Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk> --- pyop2/codegen/builder.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py index 9de773a9f..382325ba3 100644 --- a/pyop2/codegen/builder.py +++ b/pyop2/codegen/builder.py @@ -75,9 +75,8 @@ def shape(self): def dtype(self): return self.values.dtype - @property - def _permute(self): - return lambda x: x + def _permute(self, x): + return x def indexed(self, multiindex, layer=None): n, i, f = multiindex From e55af6ea9875790b3b0aa4d66ff01d02c83669b8 Mon Sep 17 00:00:00 2001 From: ksagiyam <46749170+ksagiyam@users.noreply.github.com> Date: Fri, 14 Jun 2024 21:00:33 +0100 Subject: [PATCH 3/3] Update pyop2/codegen/builder.py Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk> --- pyop2/codegen/builder.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py index 382325ba3..505dc5d2b 100644 --- a/pyop2/codegen/builder.py +++ b/pyop2/codegen/builder.py @@ -171,9 +171,8 @@ def __init__(self, map_, permutation): self.offset_quotient = map_.offset_quotient self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}") - @property - def _permute(self): - return lambda x: Indexed(self.permutation, (x,)) + def _permute(self, x): + return Indexed(self.permutation, (x,)) class CMap(Map):