diff --git a/tessellate_ipu/lax/tile_lax_array.py b/tessellate_ipu/lax/tile_lax_array.py index 01a7987..b0cc561 100644 --- a/tessellate_ipu/lax/tile_lax_array.py +++ b/tessellate_ipu/lax/tile_lax_array.py @@ -216,8 +216,10 @@ def tile_sharded_identity(dtype: DTypeLike, tiles: Tuple[int, ...]) -> TileShard # Build zero matrix + update diagonal entries. arr = tile_fill((N,), 0, dtype=dtype, tiles=tiles) # Requiring constants for indices + updates. Something more efficient?s - indices = tile_constant_sharded(np.arange(0, N, dtype=np.uint32).reshape(N, 1, 1), tiles=tiles) - updates = tile_constant_replicated(np.array([1], dtype=dtype), tiles=tiles) + with jax.named_scope("indices"): + indices = tile_constant_sharded(np.arange(0, N, dtype=np.uint32).reshape(N, 1, 1), tiles=tiles) + with jax.named_scope("updates"): + updates = tile_constant_replicated(np.array([1], dtype=dtype), tiles=tiles) # Not the simplest way ever of updating diagonal terms! scatter_dnums = jax.lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,) diff --git a/tessellate_ipu/lib/tessellate_ipu_ops_jax.cpp b/tessellate_ipu/lib/tessellate_ipu_ops_jax.cpp index 96e1ca2..77c0611 100644 --- a/tessellate_ipu/lib/tessellate_ipu_ops_jax.cpp +++ b/tessellate_ipu/lib/tessellate_ipu_ops_jax.cpp @@ -42,7 +42,8 @@ class TilePutShardedPrimitive : public TilePutBase { poplar::Graph& graph, const std::vector& inputs, std::vector& outputs, const std::string& attributes, const std::string& debug_prefix) { - const auto debug_context = poplar::DebugContext(debug_prefix); + const auto debug_context = poplar::DebugContext( + makeTileOpDebugPrefix(debug_prefix, "tile_put_sharded")); // Passing the tile array as attributes. 
const auto tile_array = extractTileArray(attributes); return lowerTilePutShardedToPoplar(graph, inputs, outputs, tile_array, @@ -54,12 +55,15 @@ class TilePutShardedPrimitive : public TilePutBase { poplar::Type type, const std::string& attributes, const std::string& debug_prefix) { + const auto debug_context = poplar::DebugContext( + makeTileOpDebugPrefix(debug_prefix, "tile_put_sharded")); const auto tile_array = extractTileArray(attributes); const auto item_shape = poplar::ArrayRef(shape.data() + 1, shape.size() - 1); // If not allocated => already pre-allocate input with proper tile mapping. // TODO: fix (unnecessary) on-tile-copy when doing that? - return createShardedVariable(graph, type, item_shape, tile_array); + return createShardedVariable(graph, type, item_shape, tile_array, + debug_context); } }; @@ -83,7 +87,8 @@ class TilePutReplicatedPrimitive : public TilePutBase { poplar::Graph& graph, const std::vector& inputs, std::vector& outputs, const std::string& attributes, const std::string& debug_prefix) { - const auto debug_context = poplar::DebugContext(debug_prefix); + const auto debug_context = poplar::DebugContext( + makeTileOpDebugPrefix(debug_prefix, "tile_put_replicated")); const auto tile_array = extractTileArray(attributes); return lowerTilePutReplicatedToPoplar(graph, inputs, outputs, tile_array, debug_context); @@ -109,7 +114,8 @@ class TileGatherPrimitive : public jax::ipu::PrimitiveInterface { poplar::Graph& graph, const std::vector& inputs, std::vector& outputs, const std::string& attributes, const std::string& debug_prefix) { - const auto debug_context = poplar::DebugContext(debug_prefix); + const auto debug_context = poplar::DebugContext( + makeTileOpDebugPrefix(debug_prefix, "tile_gather")); // Tile gather parameters. 
const auto params = ipu::from_json_str(attributes); return lowerTileGatherToPoplar(graph, inputs, outputs, params, @@ -138,7 +144,8 @@ class TileDataBarrierPrimitive : public jax::ipu::PrimitiveInterface { poplar::Graph& graph, const std::vector& inputs, std::vector& outputs, const std::string& attributes, const std::string& debug_prefix) { - const auto debug_context = poplar::DebugContext(debug_prefix); + const auto debug_context = poplar::DebugContext( + makeTileOpDebugPrefix(debug_prefix, "tile_data_barrier")); // Tile barrier parameters (with tile sharding). const auto params = ipu::from_json_str(attributes); return lowerTileDataBarrierToPoplar(graph, inputs, outputs, params, @@ -165,7 +172,8 @@ class TileConstantReplicatedPrimitive : public jax::ipu::PrimitiveInterface { poplar::Graph& graph, const std::vector& inputs, std::vector& outputs, const std::string& attributes, const std::string& debug_prefix) { - const auto debug_context = poplar::DebugContext(debug_prefix); + const auto debug_context = poplar::DebugContext( + makeTileOpDebugPrefix(debug_prefix, "tile_constant_replicated")); const auto params = ipu::from_json_str(attributes); return lowerTileConstantReplicatedToPoplar(graph, inputs, outputs, params, debug_context); @@ -191,7 +199,8 @@ class TileConstantShardedPrimitive : public jax::ipu::PrimitiveInterface { poplar::Graph& graph, const std::vector& inputs, std::vector& outputs, const std::string& attributes, const std::string& debug_prefix) { - const auto debug_context = poplar::DebugContext(debug_prefix); + const auto debug_context = poplar::DebugContext( + makeTileOpDebugPrefix(debug_prefix, "tile_constant_sharded")); const auto params = ipu::from_json_str(attributes); return lowerTileConstantShardedToPoplar(graph, inputs, outputs, params, debug_context); diff --git a/tessellate_ipu/lib/tile_array_ops.cpp b/tessellate_ipu/lib/tile_array_ops.cpp index 826bf62..34683eb 100644 --- a/tessellate_ipu/lib/tile_array_ops.cpp +++ 
b/tessellate_ipu/lib/tile_array_ops.cpp @@ -12,6 +12,25 @@ namespace ipu { +std::string makeTileOpDebugPrefix(const std::string& raw_debug_prefix, + const std::string& basename) { + const auto format_debug_prefix = [&raw_debug_prefix, + &basename](std::size_t idx) { + const std::string debug_prefix = + fmt::format("{}{}", raw_debug_prefix.substr(0, idx), basename); + return debug_prefix; + }; + std::string::size_type idx; + // A bit of ugly string pattern matching to remove the metadata, but keep + // the existing namespace. + idx = raw_debug_prefix.rfind(basename + "["); + if (idx != std::string::npos) { + return format_debug_prefix(idx); + } + // Not found => keep the same debug prefix. + return raw_debug_prefix; +} + poplar::Tensor tileBarrierReinterpretTensor(const poplar::Tensor& t, bool is_half_accurate) { // 8 bits data types. @@ -69,8 +88,8 @@ poplar::program::Program lowerTilePutShardedToPoplar( // Create output tensor, with proper tile mapping. // TODO: link to Slack discussion on VarRegion contiguity. - auto output = createShardedVariable(graph, input.elementType(), - input[0].shape(), tile_array); + auto output = createShardedVariable( + graph, input.elementType(), input[0].shape(), tile_array, debug_context); // Copy data tensor into the output. auto prog = poplar::program::Copy(input, output); outputs.push_back(output); @@ -91,7 +110,7 @@ poplar::program::Program lowerTilePutReplicatedToPoplar( // Create output tensor, with proper tile mapping. auto input_broadcasted = input.expand({0}).broadcast(tile_array.size(), 0); auto output = createShardedVariable(graph, input.elementType(), input.shape(), - tile_array); + tile_array, debug_context); // Copy data tensor into the output. 
auto prog = poplar::program::Copy(input_broadcasted, output, false); outputs.push_back(output); diff --git a/tessellate_ipu/lib/tile_array_ops.hpp b/tessellate_ipu/lib/tile_array_ops.hpp index 46f1c22..260a768 100644 --- a/tessellate_ipu/lib/tile_array_ops.hpp +++ b/tessellate_ipu/lib/tile_array_ops.hpp @@ -6,6 +6,14 @@ #include "base_types.hpp" namespace ipu { + +/** + * @brief Make a (readable/clean) tile op debug prefix. + * Helps produce more readable names in the PopVision profile. + */ +std::string makeTileOpDebugPrefix(const std::string& raw_debug_prefix, + const std::string& basename); + /** * @brief IPU tile gather op parameters. */ diff --git a/tessellate_ipu/lib/tile_array_utils.cpp b/tessellate_ipu/lib/tile_array_utils.cpp index 680b9cb..32a435e 100644 --- a/tessellate_ipu/lib/tile_array_utils.cpp +++ b/tessellate_ipu/lib/tile_array_utils.cpp @@ -55,18 +55,23 @@ poplar::Tensor createReplicatedConstantTensor( poplar::ArrayRef tiles, const poplar::DebugContext& debug_context) { // TODO: check raw_values, dtype and shape are consistent. - // TODO: get it working with FP16! - // Expanded shape (used in concat). - const auto expand_shape = shapePrependAxis(1, shape); - // Create Poplar constant per tile. Should I create a single one? - std::vector tensor_list; + // Replicating raw values on the host. Should never be >1GB (worst case!). + // Allows creating a single constant tensor, which is better for PopVision + // profile. + std::vector replicated_raw_values(raw_values.size() * tiles.size()); + auto it = replicated_raw_values.begin(); for (size_t idx = 0; idx < tiles.size(); ++idx) { - auto t = createConstantTensor(graph, ipu_type, expand_shape, raw_values, - debug_context); - graph.setTileMapping(t, tiles[idx]); - tensor_list.push_back(t); + it = std::copy(raw_values.begin(), raw_values.end(), it); } - return poplar::concat(tensor_list, 0); + // Build the full constant tensor at once. + // TODO: make sure it works with FP16?
+ const auto replicated_shape = shapePrependAxis(tiles.size(), shape); + auto t = createConstantTensor(graph, ipu_type, replicated_shape, + replicated_raw_values, debug_context); + for (size_t idx = 0; idx < tiles.size(); ++idx) { + graph.setTileMapping(t[idx], tiles[idx]); + } + return t; } poplar::Tensor createShardedConstantTensor( @@ -74,25 +79,14 @@ poplar::Tensor createShardedConstantTensor( poplar::ArrayRef shape, poplar::ArrayRef raw_values, poplar::ArrayRef tiles, const poplar::DebugContext& debug_context) { - // TODO: check consistent raw values size. - // Expanded shape on every tile. - const auto expand_shape = - shapePrependAxis(1, arraySlice(shape, 1, shape.size())); - const auto dtype_size = ipuTypeSize(ipu_type); - const std::size_t bytes_size = sizeFromShape(expand_shape) * dtype_size; - auto poplar_type = toPoplar(ipu_type); - // Create Poplar constant per tile. Should I create a single one? - std::vector tensor_list; + // TODO: check raw_values, dtype and shape are consistent. + // Creating a single tensor, to avoid Popvision profile bloating. + auto t = + createConstantTensor(graph, ipu_type, shape, raw_values, debug_context); for (size_t idx = 0; idx < tiles.size(); ++idx) { - // Slicing the raw data corresponding to the tile. 
- auto raw_values_tile = - arraySlice(raw_values, idx * bytes_size, (idx + 1) * bytes_size); - auto t = createConstantTensor(graph, ipu_type, expand_shape, - raw_values_tile, debug_context); - graph.setTileMapping(t, tiles[idx]); - tensor_list.push_back(t); + graph.setTileMapping(t[idx], tiles[idx]); } - return poplar::concat(tensor_list, 0); + return t; } } // namespace ipu diff --git a/tessellate_ipu/lib/tile_map_ops.cpp b/tessellate_ipu/lib/tile_map_ops.cpp index baf9f98..b13b581 100644 --- a/tessellate_ipu/lib/tile_map_ops.cpp +++ b/tessellate_ipu/lib/tile_map_ops.cpp @@ -14,8 +14,6 @@ std::string makeTileMapCallDebugPrefix(const std::string& raw_debug_prefix, const std::string& primitive_name) { const auto format_debug_prefix = [&raw_debug_prefix, &primitive_name](std::size_t idx) { - // const std::string debug_prefix = raw_debug_prefix.substr(0, idx) + - // "tile_map"; const std::string debug_prefix = fmt::format("{}{}[{}]", raw_debug_prefix.substr(0, idx), "tile_map", primitive_name);