-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#8536: Allow in0 and output to be sharded/produced on different grids for mcast 1D in0
- Fully decouple the in0 sender grid (i.e. cores that hold in0 data) from the receiver grid (i.e. cores that produce work). This means the user can now width-shard in0 along K and specify a per_core_N that divides the output width across an arbitrary number of cores. See tests/ttnn/sweep_tests/sweeps/sweeps/matmul/short/matmul_user_program_config_mcast_1d.py for examples.
Changes:
- Remove this assert: TT_FATAL(div_up(N, per_core_N) == input_tensor_a.shard_spec().value().grid.num_cores());
- Separate in0 sender/receiver cores into 3 kernel quadrants so all new logic is compile time:
  * in0_mcast_cores_with_work_and_in_receiver_grid
  * in0_mcast_cores_without_work_and_in_receiver_grid
  * in0_mcast_cores_without_work_and_not_in_receiver_grid
- Only load compute and writer kernels onto cores that produce output work
- Add a new short matmul sweep to test mcast 1D matmul with different in0 and output grids
- Fork tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp to *width_sharded.cpp
- TODO: Merge these kernels back once mcast 2D matmul is uplifted to support this feature
- Loading branch information
1 parent
432117b
commit a237d77
Showing
4 changed files
with
2,276 additions
and
1,035 deletions.
There are no files selected for viewing
364 changes: 364 additions & 0 deletions
364
tests/ttnn/sweep_tests/sweeps/sweeps/matmul/short/matmul_user_program_config_mcast_1d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,364 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Optional, Tuple | ||
from loguru import logger | ||
import enum | ||
|
||
import torch | ||
|
||
import ttnn | ||
|
||
from tests.ttnn.utils_for_testing import check_with_pcc | ||
from models.utility_functions import torch_random | ||
|
||
|
||
# Upper bound on the device grid used when deriving row-wise in0 core range sets.
core_grid = ttnn.CoreCoord(8, 7)


def _in0_shard_grid(num_cores):
    # Row-wise core range set of `num_cores` cores inside `core_grid`, used to
    # width-shard in0 on a grid that may differ from the output grid.
    return ttnn.experimental.tensor.num_cores_to_core_range_set(num_cores, core_grid, row_wise=True)


def _mcast_1d_spec(input_shapes, grid_size, in0_block_w, out_subblock_w, in0_shard_grid):
    """Build one matmul_specs entry: (batch_sizes, input_shapes, program_config, in0 memory config).

    All specs share per_core_M=2, per_core_N=3, out_subblock_h=1 and a (64, 64)
    row-major L1 width shard for in0. in0_block_w=1 loops along the in0 shard
    width; in0_block_w=2 consumes the shard width without looping.
    """
    return (
        (1,),
        input_shapes,
        ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
            compute_with_storage_grid_size=grid_size,
            in0_block_w=in0_block_w,
            out_subblock_h=1,
            out_subblock_w=out_subblock_w,
            per_core_M=2,
            per_core_N=3,
            fuse_batch=True,
            fused_activation=None,
            mcast_in0=True,
        ),
        ttnn.MemoryConfig(
            memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            buffer_type=ttnn.BufferType.L1,
            shard_spec=ttnn.ShardSpec(
                in0_shard_grid,
                (64, 64),
                ttnn.ShardOrientation.ROW_MAJOR,
                False,
            ),
        ),
    )


# Full 8x4 grid used for the case where the in0 grid equals the output grid.
_FULL_8X4_GRID = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))})

parameters = {
    "matmul_specs": [
        # Matmul 1D mcast in0: in0 grid == output grid
        # (first entry loops along in0 shard width, second does not)
        _mcast_1d_spec((64, 32 * 64, 32 * 96), (8, 4), 1, 1, _FULL_8X4_GRID),
        _mcast_1d_spec((64, 32 * 64, 32 * 96), (8, 4), 2, 1, _FULL_8X4_GRID),
        # Matmul 1D mcast in0: in0 grid < output grid
        _mcast_1d_spec((64, 28 * 64, 35 * 96), (8, 5), 1, 3, _in0_shard_grid(28)),
        _mcast_1d_spec((64, 28 * 64, 35 * 96), (8, 5), 2, 3, _in0_shard_grid(28)),
        # Matmul 1D mcast in0: in0 grid > output grid
        _mcast_1d_spec((64, 35 * 64, 28 * 96), (8, 5), 1, 3, _in0_shard_grid(35)),
        _mcast_1d_spec((64, 35 * 64, 28 * 96), (8, 5), 2, 3, _in0_shard_grid(35)),
        # Matmul 1D mcast in0: in0 grid.y == output grid.y but in0 grid.x < output grid.x
        # and output grid.x isn't a full row; tests mcast logic for num_active_cores
        _mcast_1d_spec((64, 28 * 64, 30 * 96), (8, 4), 1, 1, _in0_shard_grid(28)),
        _mcast_1d_spec((64, 28 * 64, 30 * 96), (8, 4), 2, 1, _in0_shard_grid(28)),
        # Matmul 1D mcast in0: in0 grid.y == output grid.y but in0 grid.x > output grid.x
        # and in0 grid.x isn't a full row; tests mcast logic for num_active_cores
        _mcast_1d_spec((64, 30 * 64, 28 * 96), (8, 4), 1, 1, _in0_shard_grid(30)),
        _mcast_1d_spec((64, 30 * 64, 28 * 96), (8, 4), 2, 1, _in0_shard_grid(30)),
    ],
    "batch_matrix_multiply": [False],
    "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
    "output_memory_config": [ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG],
    "input_a_dtype": [ttnn.bfloat16],
    "input_b_dtype": [ttnn.bfloat8_b],
    "output_dtype": [ttnn.bfloat16],
    "input_layout": [ttnn.TILE_LAYOUT],
    "compute_kernel_config": [None],
}
|
||
|
||
def skip(**_) -> Tuple[bool, Optional[str]]:
    """Never skip: every generated test vector in this sweep is runnable."""
    return (False, None)
|
||
|
||
def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]:
    """No vector in this sweep is an expected failure."""
    return (False, None)
|
||
|
||
def run(
    matmul_specs,
    batch_matrix_multiply,
    input_b_memory_config,
    output_memory_config,
    input_a_dtype,
    input_b_dtype,
    output_dtype,
    input_layout,
    compute_kernel_config,
    *,
    device,
) -> Tuple[bool, Optional[str]]:
    """Execute one mcast-1D matmul sweep vector on `device` and check it against torch.

    Unpacks the spec tuple, builds random torch inputs, runs ttnn.matmul with the
    spec's program config, and returns the (passed, message) result of check_with_pcc
    at a 0.99 threshold.
    """
    batch_sizes, input_shapes, program_config, input_a_memory_config = matmul_specs
    m_size, k_size, n_size = input_shapes

    shape_a = (*batch_sizes, m_size, k_size)
    # Input b only carries the batch dimensions for a true batched matmul.
    shape_b = (*batch_sizes, k_size, n_size) if batch_matrix_multiply else (k_size, n_size)

    torch_a = torch_random(shape_a, -0.1, 0.1, dtype=torch.float32)
    torch_b = torch_random(shape_b, -0.1, 0.1, dtype=torch.float32)
    torch_expected = torch.matmul(torch_a, torch_b)

    tensor_a = ttnn.from_torch(
        torch_a,
        layout=input_layout,
        dtype=input_a_dtype,
        device=device,
        memory_config=input_a_memory_config,
    )
    tensor_b = ttnn.from_torch(
        torch_b,
        layout=input_layout,
        dtype=input_b_dtype,
        device=device,
        memory_config=input_b_memory_config,
    )

    result = ttnn.matmul(
        tensor_a,
        tensor_b,
        memory_config=output_memory_config,
        dtype=output_dtype,
        program_config=program_config,
        compute_kernel_config=compute_kernel_config,
    )
    result = ttnn.to_torch(result)

    return check_with_pcc(torch_expected, result, 0.99)
Oops, something went wrong.