Fix gather and scatter bug in IPU hardware
The IPU model does not fully replicate IPU hardware for gather and
scatter vertices: `splitSingleRegion` appears to be ignored on the
IPU model. Reverting to `splitSingleRegion=False` solves the issue.

It remains to be investigated which configuration of these vertices
performs best.
balancap committed Sep 25, 2023
1 parent 49dea55 commit ec37ad7
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tessellate_ipu/lax/tile_lax_gather.py
@@ -93,7 +93,7 @@ def ipu_gather_primitive_translation(
numBaseElements=operand.size, # Number of elements in input.
maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)),
regionSize=1, # TODO: understand?
-splitSingleRegion=True, # Split regions between threads?
+splitSingleRegion=False, # Split regions between threads? TODO: understand!
)
# TODO: should we use `split offsets` between threads?
# For now: need to do it manually at the Python `tile_map` level.
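
The `maxElementsPerWorker` argument in the hunk above is a ceiling division of the element count by the number of worker contexts. The following is a minimal sketch of that arithmetic, together with one way indices could be split into per-worker chunks by hand. `split_indices_per_worker` is a hypothetical helper written for illustration, not a function from `tessellate_ipu`, and the worker count of 6 used in the note below is an assumption; the real value comes from `num_context_workers`.

```python
import math
from typing import List, Sequence


def max_elements_per_worker(num_elements: int, num_context_workers: int) -> int:
    """Ceiling division, equivalent to int(np.ceil(num_elements / num_context_workers))."""
    return math.ceil(num_elements / num_context_workers)


def split_indices_per_worker(
    indices: Sequence[int], num_context_workers: int
) -> List[List[int]]:
    """Hypothetical helper: contiguous chunks of at most `max_elements_per_worker` indices."""
    chunk = max_elements_per_worker(len(indices), num_context_workers)
    if chunk == 0:
        # No indices at all: nothing to distribute.
        return []
    return [list(indices[i : i + chunk]) for i in range(0, len(indices), chunk)]
```

With 13 indices and an assumed 6 workers, each worker handles at most 3 indices, so the split yields chunks of sizes 3, 3, 3, 3 and 1, and one worker stays idle.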
2 changes: 1 addition & 1 deletion tessellate_ipu/lax/tile_lax_scatter.py
@@ -192,7 +192,7 @@ def ipu_scatter_primitive_translation(
maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)),
regionSize=1, # TODO: understand?
indicesAreSorted=False,
-splitSingleRegion=True,
+splitSingleRegion=False, # Split regions between threads? TODO: understand!
)
# For now: need to do it manually at the Python `tile_map` level.
ipu_prim_info = IpuTileMapEquation(
2 changes: 2 additions & 0 deletions tests/lax/test_tile_lax_gather.py
@@ -5,12 +5,14 @@
import jax
import numpy as np
import numpy.testing as npt
+import pytest
from absl.testing import parameterized

from tessellate_ipu import tile_map, tile_put_replicated
from tessellate_ipu.lax import gather_p


+@pytest.mark.ipu_hardware
class IpuTilePrimitivesLaxGather(chex.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
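
The test hunk above tags the gather test class with `@pytest.mark.ipu_hardware`. The diff does not show how the repository registers or selects this marker, so the following is only a sketch of one common pytest pattern: a `conftest.py` that registers the marker and skips marked tests unless the user opts in. The environment variable `RUN_IPU_HARDWARE_TESTS` and the helper `ipu_hardware_requested` are hypothetical names introduced for illustration.

```python
# conftest.py (hypothetical sketch, not taken from the repository)
import os

import pytest

# Hypothetical opt-in switch; the real project may select hardware tests differently.
IPU_HARDWARE_ENV_VAR = "RUN_IPU_HARDWARE_TESTS"


def ipu_hardware_requested(environ=os.environ) -> bool:
    """Return True when the hypothetical opt-in variable is set to "1"."""
    return environ.get(IPU_HARDWARE_ENV_VAR, "0") == "1"


def pytest_configure(config):
    # Register the marker so that `pytest --strict-markers` accepts it.
    config.addinivalue_line("markers", "ipu_hardware: test requires real IPU hardware")


def pytest_collection_modifyitems(config, items):
    # Leave the collection untouched when the user opted in to hardware tests.
    if ipu_hardware_requested():
        return
    skip_marker = pytest.mark.skip(
        reason=f"set {IPU_HARDWARE_ENV_VAR}=1 to run IPU hardware tests"
    )
    for item in items:
        # A class-level mark is visible from every test method in that class.
        if item.get_closest_marker("ipu_hardware") is not None:
            item.add_marker(skip_marker)
```

Independently of any such hook, pytest's own `-m` option can select or deselect marked tests, for example `pytest -m ipu_hardware` or `pytest -m "not ipu_hardware"`.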
