diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py
index d22af394cf0..4b10f62a6ad 100644
--- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py
+++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py
@@ -90,11 +90,11 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype):
             .unsqueeze(0)
             .unsqueeze(0),
             device=self.device_mesh,
-            mesh_mapper=ShardTensorToMesh(self.device_mesh, dim=-2),
+            mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
             dtype=self.dtype,
             memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"],
             layout=self.model_config["ATTN_W_LAYOUT_TILE"],
-            cache_file_name=cache_name(f"wo_multidevice4d"),
+            cache_file_name=cache_name(f"wo_multidevice4d_H"),
         )
 
         cache_k = torch.zeros(
@@ -129,17 +129,6 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype):
 
         self.scale = self.head_dim**-0.5
 
-        reduce_mask_torch = torch.zeros(1, 1, self.max_batch_size, self.max_batch_size * 8)
-        for i in range(self.max_batch_size):
-            reduce_mask_torch[:, :, i, range(i, self.max_batch_size * 8, self.max_batch_size)] = 1
-        self.reduce_mask = ttnn.from_torch(
-            reduce_mask_torch,
-            device=self.device_mesh,
-            mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
-            dtype=ttnn.bfloat8_b,
-            layout=ttnn.TILE_LAYOUT,
-        )
-
         self.compute_kernel = self.model_args.get_compute_kernel_config()
         self.compute_kernel_attn = self.model_args.get_compute_kernel_attn_config()
 
@@ -300,16 +289,19 @@ def forward(
         )
         attn_output_1B4D.deallocate(True)
 
-        # attn_output_11BH = ttnn.experimental.tensor.sharded_to_interleaved(
-        #     attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG
-        # )
+        attn_output_11BH = ttnn.experimental.tensor.sharded_to_interleaved(
+            attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG
+        )
 
         ###
         # Output matmul
         ###
+        # All gather
+        dense_outputs_11BH_gathered = ttnn.all_gather(attn_output_11BH, dim=3, num_links=1)
 
-        dense_out_11BH = ttnn.experimental.operations.primary.matmul(
-            attn_output_11BH,
+        # return the sum of the outputs
+        dense_outputs_11BH = ttnn.experimental.operations.primary.matmul(
+            dense_outputs_11BH_gathered,
             wo,
             output_mem_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"],
             # compute_with_storage_grid_size=(8, 8),
@@ -317,10 +309,6 @@ def forward(
             compute_kernel_config=self.compute_kernel,
             output_dtype=ttnn.bfloat8_b,
         )
-        attn_output_11BH.deallocate(True)
-        # All gather
-        dense_outputs_11BH = ttnn.all_gather(dense_out_11BH, dim=2, num_links=1)
-        # return the sum of the outputs
-        dense_outputs_11BH = ttnn.experimental.operations.primary.matmul(self.reduce_mask, dense_outputs_11BH)
+        dense_outputs_11BH_gathered.deallocate(True)
 
         return dense_outputs_11BH
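
Below is a torch-only sketch (hypothetical, not part of the diff; shapes, names, and device count are assumptions) of why the new output-projection path is numerically equivalent to the one it replaces: all-gathering the head-sharded attention output along the hidden dimension and multiplying it by the full, replicated wo gives the same result as the old scheme of multiplying each H-shard by its row-shard of wo (ShardTensorToMesh, dim=-2) and summing the per-device partials with the reduce mask.

# Hypothetical illustration only; not taken from the PR.
import torch

num_devices = 8
batch, hidden = 32, 4096
shard = hidden // num_devices                      # per-device slice of the hidden dim

attn_out = torch.randn(1, 1, batch, hidden, dtype=torch.float64)  # gathered head outputs (dim=3)
wo = torch.randn(hidden, hidden, dtype=torch.float64)             # full output projection, already transposed

# New path: one matmul of the gathered activations with the replicated wo.
new_out = attn_out @ wo

# Old path: each device multiplied its H-shard of attn_out by its row-shard of wo,
# producing a partial output; the reduce-mask matmul summed the partials.
partials = [
    attn_out[..., d * shard : (d + 1) * shard] @ wo[d * shard : (d + 1) * shard, :]
    for d in range(num_devices)
]
old_out = torch.stack(partials).sum(dim=0)

assert torch.allclose(new_out, old_out)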