Skip to content

Commit

Permalink
#9059: adjust matmul parameters for rounding up in some scenarios (#9105
Browse files Browse the repository at this point in the history
)

* #9059: adjust matmul parameters for rounding up in some scenarios

* #9059: Adjust some matmul parameters to use div_up
  • Loading branch information
bbradelTT authored Jun 5, 2024
1 parent 6bd0fbf commit 4276e5c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <algorithm>
#include <cmath>
#include <optional>
#include <numeric>

#include "third_party/magic_enum/magic_enum.hpp"
#include "tt_dnn/op_library/run_operation.hpp"
Expand Down Expand Up @@ -368,7 +369,7 @@ tt::operations::primary::MatmulProgramConfig get_matmul_program_config(
mcast_in0 = true;
per_core_M = M;
per_core_N = div_up(N, input_tensor_a.shard_spec().value().grid.num_cores());
in0_block_w = shard_shape[1] / TILE_WIDTH;
in0_block_w = std::gcd(shard_shape[1] / TILE_WIDTH, K);
} else if (input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
mcast_in0 = false;
per_core_M = shard_shape[0] / TILE_HEIGHT;
Expand Down Expand Up @@ -413,14 +414,16 @@ tt::operations::primary::MatmulProgramConfig get_matmul_program_config(
auto shard_shape = input_tensor_a.shard_spec().value().shape;
uint32_t virtual_x = transpose_mcast ? grid_size.y : grid_size.x;
uint32_t virtual_y = transpose_mcast ? grid_size.x : grid_size.y;
bool cores_along_x_match_grid_size = virtual_x == (K / (shard_shape[1] / TILE_WIDTH));
bool cores_along_y_match_grid_size = virtual_y == (M / (shard_shape[0] / TILE_HEIGHT));
TT_FATAL(
virtual_y == (M / (shard_shape[0] / TILE_HEIGHT)), "Num cores along y must match provided grid size!");
cores_along_y_match_grid_size || virtual_y == div_up(M, (shard_shape[0] / TILE_HEIGHT)), "Num cores along y must match provided grid size!");
TT_FATAL(
virtual_x == (K / (shard_shape[1] / TILE_WIDTH)), "Num cores along x must match provided grid size!");
cores_along_x_match_grid_size || virtual_x == div_up(K, (shard_shape[1] / TILE_WIDTH)), "Num cores along x must match provided grid size!");

uint32_t per_core_M = M / virtual_y;
uint32_t per_core_N = N / virtual_x;
uint32_t in0_block_w = shard_shape[1] / TILE_WIDTH;
uint32_t per_core_M = (M < virtual_y) ? 1 : M / virtual_y;
uint32_t per_core_N = (N < virtual_x) ? 1 : N / virtual_x;
uint32_t in0_block_w = cores_along_x_match_grid_size ? shard_shape[1] / TILE_WIDTH : 1;

auto subblock_hw = get_matmul_subblock_params(
per_core_M, per_core_N, false, per_core_N_equals_subblock_w_constraint, fp32_dest_acc_en);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void kernel_main() {
in0_mcast_sender_semaphore_valid_addr_ptr[0] =
VALID; // Load const 1 to be used as semaphore valid value sent from sender to receivers

constexpr uint32_t num_remote_senders = num_blocks / num_blocks_per_shard;
constexpr uint32_t num_remote_senders = (num_blocks + num_blocks_per_shard - 1) / num_blocks_per_shard;
uint64_t remote_sender_noc_addrs[num_remote_senders];
if constexpr (transpose_mcast) {
uint32_t x = 0, y = 0;
Expand Down

0 comments on commit 4276e5c

Please sign in to comment.