diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp
index 7973e5174e0..d9e387165e4 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/multi_core/layernorm_op_multi_core.cpp
@@ -412,7 +412,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
     bool is_post_all_gather = distributed_norm_stage == DistributedLayerNormStage::POST_ALL_GATHER;
 
     ////////////////////////////////////////////////////////////////////////////
-    //                      Grayskull Device Setup
+    //                      Device Setup
     ////////////////////////////////////////////////////////////////////////////
     Device *device = a.device();
 
@@ -422,8 +422,20 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
     auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] =
         get_compute_kernel_config_args(device->arch(), compute_kernel_config);
 
-    if (fp32_dest_acc_en) {
-        TT_ASSERT(subblock_wt <= 4, "subblock width must less than 4 in fp32 mode");
+    if (dst_full_sync_en == false) {
+        if (fp32_dest_acc_en) {
+            TT_FATAL(subblock_wt <= 4, "subblock_wt={}, but subblock width must be at most 4 tiles in fp32 mode when dst_full_sync_en is false", subblock_wt);
+        }
+        else {
+            TT_FATAL(subblock_wt <= 8, "subblock_wt={}, but subblock width must be at most 8 tiles when dst_full_sync_en is false", subblock_wt);
+        }
+    } else {
+        if (fp32_dest_acc_en) {
+            TT_FATAL(subblock_wt <= 8, "subblock_wt={}, but subblock width must be at most 8 tiles in fp32 mode when dst_full_sync_en is true", subblock_wt);
+        }
+        else {
+            TT_FATAL(subblock_wt <= 16, "subblock_wt={}, but subblock width must be at most 16 tiles when dst_full_sync_en is true", subblock_wt);
+        }
     }
 
     tt::DataFormat out_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype());
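
For context on the four limits above: the bounds suggest the destination register holds 16 tile slots when dst_full_sync_en is true and 8 when it is false, with an fp32 tile taking two slots. The snippet below is a minimal sketch of that rule; the helper max_subblock_wt is hypothetical and not part of this change, it only restates the limits that the TT_FATAL checks enforce.

    #include <cstdint>

    // Hypothetical helper (illustration only, not in the diff): maximum subblock
    // width in tiles that fits in the dst register, assuming the capacities
    // implied by the checks above (16 slots with full sync, 8 without; an fp32
    // tile occupies two slots).
    static uint32_t max_subblock_wt(bool fp32_dest_acc_en, bool dst_full_sync_en) {
        const uint32_t dst_tile_slots = dst_full_sync_en ? 16 : 8;
        return fp32_dest_acc_en ? dst_tile_slots / 2 : dst_tile_slots;
    }

    // Possible usage, collapsing the four branches into one check:
    // TT_FATAL(subblock_wt <= max_subblock_wt(fp32_dest_acc_en, dst_full_sync_en),
    //          "subblock_wt={} exceeds dst register capacity", subblock_wt);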