Skip to content

Commit

Permalink
#5337: MLP matmuls 8x8 grid
Browse files: browse the repository at this point in the history
  • Loading branch information
sraizada-tt authored and mtairum committed May 29, 2024
1 parent 16c8fce commit d0cf198
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -46,18 +46,18 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtypes):
self.w3 = ttnn.to_device(self.w3, device_mesh)

self.w1_prg_cfg = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=(6, 7),
in0_block_w=4, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size
compute_with_storage_grid_size=(8, 8),
in0_block_w=2, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size
out_subblock_h=1, # Must be divisible by per_core_M
out_subblock_w=1, # Must be divisible by per_core_N, out_subblock_w * out_subblock_h <= 4
per_core_M=1, # M / TILE_HEIGHT = 32 / 32
per_core_N=11, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size, N = 4096 for num_device=8
per_core_N=7, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size, N = 4096 for num_device=8
fuse_batch=True,
fused_activation=ttnn.experimental.tensor.FusibleActivation.SILU,
mcast_in0=True,
)
self.w3_prg_cfg = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=(6, 7),
compute_with_storage_grid_size=(8, 8),
in0_block_w=4, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size
out_subblock_h=1, # Must be divisible by per_core_M
out_subblock_w=1, # Must be divisible by per_core_N, out_subblock_w * out_subblock_h <= 4
Expand All @@ -67,14 +67,13 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtypes):
fused_activation=None,
mcast_in0=True,
)

self.w2_prg_cfg = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=(6, 7),
in0_block_w=8, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size
compute_with_storage_grid_size=(8, 8),
in0_block_w=7, # K = 8192 / TILE_WIDTH=32 / Grid_Size is based on compute_with_storage_grid_size
out_subblock_h=1, # Must be divisible by per_core_M
out_subblock_w=4, # Must be divisible by per_core_N, out_subblock_w * out_subblock_h <= 4
out_subblock_w=2, # Must be divisible by per_core_N, out_subblock_w * out_subblock_h <= 4
per_core_M=1, # M / TILE_HEIGHT = 32 / 32
per_core_N=4, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size, N = 4096 for num_device=8
per_core_N=2, # N / TILE_WIDTH / Grid_Size is based on compute_with_storage_grid_size, N = 4096 for num_device=8
fuse_batch=True,
fused_activation=None,
mcast_in0=True,
Expand Down

0 comments on commit d0cf198

Please sign in to comment.