Skip to content

Commit

Permalink
Make layernorm assert instead of silently giving garbage output when subblock_w too large (#14223)
Browse files Browse the repository at this point in the history

* #14222: Add assert for subblock_w when fp32 is not used as well

* #14223: Use TT_FATAL and handle dst_full_sync_en
  • Loading branch information
yieldthought authored Oct 28, 2024
1 parent 9b21ae3 commit 0911990
Showing 1 changed file with 15 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
bool is_post_all_gather = distributed_norm_stage == DistributedLayerNormStage::POST_ALL_GATHER;

////////////////////////////////////////////////////////////////////////////
// Grayskull Device Setup
// Device Setup
////////////////////////////////////////////////////////////////////////////
Device *device = a.device();

Expand All @@ -422,8 +422,20 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] =
get_compute_kernel_config_args(device->arch(), compute_kernel_config);

if (fp32_dest_acc_en) {
TT_ASSERT(subblock_wt <= 4, "subblock width must less than 4 in fp32 mode");
if (dst_full_sync_en == false) {
if (fp32_dest_acc_en) {
TT_FATAL(subblock_wt <= 4, "subblock_wt={}, but subblock width must less than 4 tiles in fp32 mode when dst_full_sync_en is false", subblock_wt);
}
else {
TT_FATAL(subblock_wt <= 8, "subblock_wt={}, but subblock width must less than 8 tiles when dst_full_sync_en is false", subblock_wt);
}
} else {
if (fp32_dest_acc_en) {
TT_FATAL(subblock_wt <= 8, "subblock_wt={}, but subblock width must less than 8 tiles in fp32 mode when dst_full_sync_en is true", subblock_wt);
}
else {
TT_FATAL(subblock_wt <= 16, "subblock_wt={}, but subblock width must less than 16 tiles when dst_full_sync_en is true", subblock_wt);
}
}

tt::DataFormat out_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype());
Expand Down

0 comments on commit 0911990

Please sign in to comment.