Skip to content

Commit

Permalink
Tile sharding handles size > 1472
Browse files Browse the repository at this point in the history
  • Loading branch information
paolot-gc committed Oct 6, 2023
1 parent a695c5e commit a052fb8
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions tessellate_ipu/linalg/tile_linalg_hessenberg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import math
import os
from typing import Any, Tuple

Expand Down Expand Up @@ -79,10 +80,18 @@ def ipu_hessenberg_shard_inputs(x: Array, xsdiag: Array) -> Tuple[TileShardedArr
"""
assert x.shape[0] == x.shape[1]
N = x.shape[0]
n_tiles = 1472
# Sharding R and Q
Q_tiles = tuple(range(0, N))
# R_tiles = tuple(range(N, 2 * N))
R_tiles = tuple(range(0, N))

n_per_tile = math.ceil(N / float(n_tiles))
full_tiles = N % n_tiles
if full_tiles == 0:
full_tiles = n_tiles

Q_tiles = [i for i in range(full_tiles) for _ in range(n_per_tile)] + [
i for i in range(full_tiles, n_tiles) for _ in range(n_per_tile - 1)
]
R_tiles = Q_tiles

# TODO: on-device construction of identity
Q = tile_put_sharded(np.identity(N, dtype=x.dtype), Q_tiles)
Expand Down

0 comments on commit a052fb8

Please sign in to comment.