From b9c031aefb7ba73f16829ff4248ea598f67366e5 Mon Sep 17 00:00:00 2001
From: ksagiyam <k.sagiyama@imperial.ac.uk>
Date: Fri, 14 Jun 2024 15:12:15 +0100
Subject: [PATCH 1/3] codegen: refactor Map.indexed()

---
 pyop2/codegen/builder.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py
index 89cf31fcf..9de773a9f 100644
--- a/pyop2/codegen/builder.py
+++ b/pyop2/codegen/builder.py
@@ -75,7 +75,11 @@ def shape(self):
     def dtype(self):
         return self.values.dtype
 
-    def indexed(self, multiindex, layer=None, permute=lambda x: x):
+    @property
+    def _permute(self):
+        return lambda x: x
+
+    def indexed(self, multiindex, layer=None):
         n, i, f = multiindex
         if layer is not None and self.offset is not None:
             # For extruded mesh, prefetch the indirections for each map, so that they don't
@@ -84,7 +88,7 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
             base_key = None
             if base_key not in self.prefetch:
                 j = Index()
-                base = Indexed(self.values, (n, permute(j)))
+                base = Indexed(self.values, (n, self._permute(j)))
                 self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j))
 
             base = self.prefetch[base_key]
@@ -122,17 +126,17 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
             return Indexed(self.prefetch[key], (f, i)), (f, i)
         else:
             assert f.extent == 1 or f.extent is None
-            base = Indexed(self.values, (n, permute(i)))
+            base = Indexed(self.values, (n, self._permute(i)))
             return base, (f, i)
 
-    def indexed_vector(self, n, shape, layer=None, permute=lambda x: x):
+    def indexed_vector(self, n, shape, layer=None):
         shape = self.shape[1:] + shape
         if self.interior_horizontal:
             shape = (2, ) + shape
         else:
             shape = (1, ) + shape
         f, i, j = (Index(e) for e in shape)
-        base, (f, i) = self.indexed((n, i, f), layer=layer, permute=permute)
+        base, (f, i) = self.indexed((n, i, f), layer=layer)
         init = Sum(Product(base, Literal(numpy.int32(j.extent))), j)
         pack = Materialise(PackInst(), init, MultiIndex(f, i, j))
         multiindex = tuple(Index(e) for e in pack.shape)
@@ -168,13 +172,9 @@ def __init__(self, map_, permutation):
         self.offset_quotient = map_.offset_quotient
         self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}")
 
-    def indexed(self, multiindex, layer=None):
-        permute = lambda x: Indexed(self.permutation, (x,))
-        return super().indexed(multiindex, layer=layer, permute=permute)
-
-    def indexed_vector(self, n, shape, layer=None):
-        permute = lambda x: Indexed(self.permutation, (x,))
-        return super().indexed_vector(n, shape, layer=layer, permute=permute)
+    @property
+    def _permute(self):
+        return lambda x: Indexed(self.permutation, (x,))
 
 
 class CMap(Map):

From 941007b4be080afc8c06995ad074969deb768613 Mon Sep 17 00:00:00 2001
From: ksagiyam <46749170+ksagiyam@users.noreply.github.com>
Date: Fri, 14 Jun 2024 20:59:59 +0100
Subject: [PATCH 2/3] Update pyop2/codegen/builder.py

Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
---
 pyop2/codegen/builder.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py
index 9de773a9f..382325ba3 100644
--- a/pyop2/codegen/builder.py
+++ b/pyop2/codegen/builder.py
@@ -75,9 +75,8 @@ def shape(self):
     def dtype(self):
         return self.values.dtype
 
-    @property
-    def _permute(self):
-        return lambda x: x
+    def _permute(self, x):
+        return x
 
     def indexed(self, multiindex, layer=None):
         n, i, f = multiindex

From e55af6ea9875790b3b0aa4d66ff01d02c83669b8 Mon Sep 17 00:00:00 2001
From: ksagiyam <46749170+ksagiyam@users.noreply.github.com>
Date: Fri, 14 Jun 2024 21:00:33 +0100
Subject: [PATCH 3/3] Update pyop2/codegen/builder.py

Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
---
 pyop2/codegen/builder.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py
index 382325ba3..505dc5d2b 100644
--- a/pyop2/codegen/builder.py
+++ b/pyop2/codegen/builder.py
@@ -171,9 +171,8 @@ def __init__(self, map_, permutation):
         self.offset_quotient = map_.offset_quotient
         self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}")
 
-    @property
-    def _permute(self):
-        return lambda x: Indexed(self.permutation, (x,))
+    def _permute(self, x):
+        return Indexed(self.permutation, (x,))
 
 
 class CMap(Map):