diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py index ad523abc4d8..5c64b87d7ac 100644 --- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py +++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py @@ -641,9 +641,6 @@ def out(self, hidden_states): fused_activation=None, ) - # TODO: bug in MM means these sizes need to be interleaved for now - if size == 4096: - output_mem_config = self.l1_interleaved_memory_config hidden_states = ttnn.experimental.operations.primary.matmul( hidden_states, self.parameters.to_out[0].weight, @@ -653,10 +650,7 @@ def out(self, hidden_states): output_dtype=ttnn.experimental.tensor.DataType.BFLOAT8_B, compute_kernel_config=self.compute_kernel_config, ) - if size == 4096: - hidden_states = self.reshard_to( - hidden_states, (5, 8), ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED - ) + return hidden_states def reshard_to(self, tensor, grid_size, layout, col_major=False): diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention_down_block_2d.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention_down_block_2d.py index 4db9b9d709f..7a5601c4449 100644 --- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention_down_block_2d.py +++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention_down_block_2d.py @@ -114,5 +114,5 @@ def __call__( use_conv=True, ) hidden_states = ttnn.reallocate(hidden_states) - output_states += (hidden_states,) + output_states += (ttnn.to_memory_config(hidden_states, ttnn.DRAM_MEMORY_CONFIG),) return hidden_states, output_states diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_resnetblock2d.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_resnetblock2d.py index 0e77f9d7d6c..6ad303eb2f1 100644 --- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_resnetblock2d.py +++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_resnetblock2d.py @@ -51,13 +51,13 @@ def ttnn_to_torch(input): } split_chunks = { - (320, 960, 64, 64): 2, + # (320, 960, 64, 64): 2, (640, 1920, 32, 32): 2, # (640, 1280, 32, 32): 2, # (640, 960, 32, 32): 2, (1280, 1920, 16, 16): 2, # (1280, 2560, 8, 8): 2, - (1280, 2560, 16, 16): 2, # TODO: Can remove with reallocation + # (1280, 2560, 16, 16): 2, }