Skip to content

Commit

Permalink
tile mapping for N<736 and clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
paolot-gc committed Oct 9, 2023
1 parent a052fb8 commit f010c8b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 128 deletions.
30 changes: 0 additions & 30 deletions examples/hessenberg_example.py

This file was deleted.

54 changes: 3 additions & 51 deletions tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,57 +12,9 @@ using namespace poplar;
*/
static constexpr size_t MIN_ALIGN = 8;

/**
 * @brief Multi-worker vertex computing per-worker partial sums of the
 * element-wise product of two float vectors.
 *
 * Worker `wid` processes the float2 elements of `x` and `y` in the range
 * [worker_offsets[wid], worker_offsets[wid + 1]) and stores its float2
 * partial sum in `partials` at float2 index `wid`. The final reduction of
 * the partials into a single scalar is left to the caller.
 *
 * NOTE: `start_idx` is part of the vertex interface but is not read by
 * `compute`.
 */
class [[poplar::constraint("elem(*x) != elem(*y)")]] DotProduct1dIndexedVertex
    : public MultiVertex {
 public:
  using T = float;
  using T2 = float2;
  // Using `uint16` seems to be generating more efficient loops?
  using IndexType = unsigned short;

  Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      x;  // (N,) x vector
  Input<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      y;  // (N,) y vector
  // Kept so that the vertex interface is unchanged; not read by `compute`.
  Input<Vector<int, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      start_idx;

  Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>>
      worker_offsets;  // (7,) number threads + 1.
  // One float2 partial sum per worker thread.
  Output<Vector<T, poplar::VectorLayout::ONE_PTR, MIN_ALIGN>> partials;

  /**
   * @brief Accumulate the partial dot-product assigned to one worker.
   * @param wid Worker thread index, used to index `worker_offsets` and
   *            `partials`.
   * @return Always true.
   */
  bool compute(unsigned wid) {
    // Always assuming size % 2 == 0
    // Worker range, expressed in float2 units.
    const IndexType wstart = worker_offsets[wid];
    const IndexType wend = worker_offsets[wid + 1];
    const IndexType wsize = wend - wstart;

    // Output slot of this worker (one float2 per worker).
    T2* ptr_tmp_partials_f2 = reinterpret_cast<T2*>(partials.data()) + wid;
    // Nothing to do in this worker thread.
    if (wstart == wend) {
      ipu::store_postinc(&ptr_tmp_partials_f2, T2{0, 0}, 1);
      return true;
    }
    // X and Y input pointers.
    const T2* ptr_inxdata_f2 = reinterpret_cast<const T2*>(x.data()) + wstart;
    const T2* ptr_inydata_f2 = reinterpret_cast<const T2*>(y.data()) + wstart;
    T2 partial = T2{0, 0};

    for (IndexType idx = 0; idx != wsize; ++idx) {
      // TODO: use ld2x64pace + tapack instructions?
      const T2 xin = ipu::load_postinc(&ptr_inxdata_f2, 1);
      const T2 yin = ipu::load_postinc(&ptr_inydata_f2, 1);
      // popc seems to recognize this pattern and optimize it.
      // Using directly ipu::fma intrinsics leads to poor performance!?
      partial += xin * yin;
    }
    ipu::store_postinc(&ptr_tmp_partials_f2, partial, 1);
    return true;
  }
};
/*
The code here is just a minor modification of tile_qr_vertex.cpp
*/

/**
* @brief Vertex computing the correction vector in the Hessenberg algorithm.
Expand Down
75 changes: 29 additions & 46 deletions tessellate_ipu/linalg/tile_linalg_hessenberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Tuple

import jax.lax
import numpy as np # used for np.float32, shouldn't we use jax types?
import numpy as np
from jax.core import ShapedArray

from tessellate_ipu import (
Expand All @@ -21,26 +21,13 @@

Array = Any

# The code here is heavily based on tile_linalg_qr.py


def get_hessenberg_vertex_gp_filename() -> str:
    """Return the path of the Hessenberg vertex C++ source file.

    The path is built relative to the directory containing this module.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, "../core", "vertex", "tile_hessenberg_vertex.cpp")


# Tile primitive "dot_product1d_indexed", bound to the `DotProduct1dIndexedVertex`
# C++ vertex from the Hessenberg vertex source file.
# Inputs are `x`, `y` and `start_idx`; the single output `partials` holds 12 float32
# values (presumably one float2 per worker for 6 workers; the final reduction is
# not done here).
# The `worker_offsets` constant is derived from the size of the first input, using
# a vector size of 2, 6 workers and uint16 offsets.
dot_product1d_indexed_p = create_ipu_tile_primitive(
"dot_product1d_indexed",
"DotProduct1dIndexedVertex",
inputs=["x", "y", "start_idx"],
outputs={"partials": ShapedArray((12,), dtype=np.float32)},
constants={
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(
inavals[0].size, vector_size=2, num_workers=6, wdtype=np.uint16
)
},
# tmp_space=ShapedArray((12,), dtype=np.float32),
gp_filename=get_hessenberg_vertex_gp_filename(),
perf_estimate=1000,
)

"""Vertex computing Hessenberg correction vector.
"""
hessenberg_correction_vector_p = create_ipu_tile_primitive(
Expand Down Expand Up @@ -81,17 +68,21 @@ def ipu_hessenberg_shard_inputs(x: Array, xsdiag: Array) -> Tuple[TileShardedArr
assert x.shape[0] == x.shape[1]
N = x.shape[0]
n_tiles = 1472
# Sharding R and Q

n_per_tile = math.ceil(N / float(n_tiles))
full_tiles = N % n_tiles
if full_tiles == 0:
full_tiles = n_tiles

Q_tiles = [i for i in range(full_tiles) for _ in range(n_per_tile)] + [
i for i in range(full_tiles, n_tiles) for _ in range(n_per_tile - 1)
]
R_tiles = Q_tiles
# Sharding R and Q
if N <= 736:
Q_tiles = list(range(N))
R_tiles = list(range(N, 2 * N))
else:
n_per_tile = math.ceil(N / float(n_tiles))
full_tiles = N % n_tiles
if full_tiles == 0:
full_tiles = n_tiles

Q_tiles = [i for i in range(full_tiles) for _ in range(n_per_tile)] + [
i for i in range(full_tiles, n_tiles) for _ in range(n_per_tile - 1)
]
R_tiles = Q_tiles

# TODO: on-device construction of identity
Q = tile_put_sharded(np.identity(N, dtype=x.dtype), Q_tiles)
Expand All @@ -101,20 +92,20 @@ def ipu_hessenberg_shard_inputs(x: Array, xsdiag: Array) -> Tuple[TileShardedArr
return Q, R, sdiag_full


# Heavily based on ipu_qr_iterations in tile_linalg_qr.py
# The body of the for-loop computes
# v = Householder(R[i]) # v is chosen to annihilate the elements below the first lower diagonal
# R = R - 2 * v.reshape(-1, 1) @ (v.reshape(1, -1) @ R)
# R = R - 2 * (R @ v.reshape(-1, 1)) @ v.reshape(1, -1) # Not present in QR algorithm
# Q = Q - 2 * (Q @ v.reshape(-1, 1)) @ v.reshape(1, -1)


def ipu_hessenberg_body(
i: int, carry: Tuple[TileShardedArray, TileShardedArray, TileShardedArray]
) -> Tuple[TileShardedArray, TileShardedArray, TileShardedArray]:

Q, R, sdiag_full = carry

# Extract the i-th col of R and the i-th element of sdiag_full
# Using the gather_p primitive avoids inefficient general-case processing

dim_numbers = jax.lax.GatherDimensionNumbers(offset_dims=tuple(), collapsed_slice_dims=(0,), start_index_map=(0,))

i_rep = tile_put_replicated(jax.numpy.array([[i]], dtype=np.uint32), R.tiles)
Expand All @@ -131,6 +122,8 @@ def ipu_hessenberg_body(
fill_value=None,
) # => TileShardedArray() (Num_tiles, 1)

# This also determines where the computation of v (the Householder correction vector) takes place.
# For now, the tile is picked arbitrarily. Are there better choices, e.g. R.tiles[0]?
Rcol_replicated = tile_put_replicated(Rcol.array, tiles=[736]) # type:ignore

sdiag = tile_map(
Expand All @@ -147,24 +140,17 @@ def ipu_hessenberg_body(

sdiag_rep = tile_put_replicated(sdiag.array, Rcol_replicated.tiles) # type:ignore

# Smart-indexing
# start_idx = (i // 2) * 2
start_idx = 0

start_idxQ = tile_put_replicated(start_idx, Q.tiles)
start_idxR = tile_put_replicated(start_idx, R.tiles)

# Alternative: we pass the whole RT and sdiag; then we extract the result from the i-th tile

# Correction vector. Computed
# Correction vector. Computed on the tile where Rcol is located
v, vrescale = tile_map(
hessenberg_correction_vector_p, Rcol_replicated, sdiag_rep, tile_put_replicated(i + 1, Rcol_replicated.tiles)
) # type:ignore
# v, vrescale = tile_map(
# hessenberg_correction_vector_p, RT, sdiag_full, tile_put_replicated(i + 1, RT.tiles)
# ) # type:ignore

# This compiles
# vi = tile_gather(v.array[i], [0], [0]

# Replicate to all Q and R tiles.
vQ = tile_put_replicated(v.array, Q.tiles) # 0
Expand All @@ -173,13 +159,7 @@ def ipu_hessenberg_body(
vrescaleQ = tile_put_replicated(vrescale.array, Q.tiles) # 0
vrescaleR = tile_put_replicated(vrescale.array, R.tiles) # 0

# Alternative using tile_gather
# vQ = tile_gather(v, [i]*len(Q.tiles), list(Q.tiles), copy=False) # 0
# vR = tile_gather(v, [i]*len(RT.tiles), RT.tiles) # 0
# # v normalization factor to pass to householder update.
# vrescaleQ = tile_gather(vrescale, [i]*len(Q.tiles), Q.tiles) # 0
# vrescaleR = tile_gather(vrescale, [i]*len(RT.tiles), RT.tiles) #

# Transpose R so that we can use hessenberg_householder_row_update_p() to compute R @ ...
RT = tile_put_sharded(R.array.T, R.tiles)

# w = R^T @ v
Expand All @@ -195,6 +175,9 @@ def ipu_hessenberg_body(
hessenberg_householder_row_update_p, RT, vR, w, vrescaleR, start_idxR # type:ignore
)

# We compute the Q updates.
# This is done here, followed by tile_data_barrier(), because it induces Poplar
# to schedule it in parallel with the RT updates when RT and Q are mapped on disjoint tiles.
# w = Q @ v
# w = tile_map(dot_product1d_indexed_p, vQ, Q, start_idxQ)
w = tile_map(dot_product1d_p, vQ, Q)
Expand All @@ -205,7 +188,7 @@ def ipu_hessenberg_body(
)
RT, Q = tile_data_barrier(RT, Q)

# Transpose the RT matrix so that we can do the right product
# Transpose the RT matrix so that we can use hessenberg_householder_row_update_p() to compute ... @ R
R = tile_put_sharded(RT.array.T, RT.tiles)

# w = R^T @ v
Expand Down
2 changes: 1 addition & 1 deletion tests/linalg/test_tile_linalg_hessenberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ def hessenberg_decomposition_fn(x, xsdiag):

start, end = np.asarray(start)[0], np.asarray(end)[0]
hessenberg_cycle_count = end[0] - start[0]
assert hessenberg_cycle_count <= 105000
assert hessenberg_cycle_count <= 150000

0 comments on commit f010c8b

Please sign in to comment.