From 97e6b00314c1a96b2acc674cdd4ebe529b7a23ad Mon Sep 17 00:00:00 2001
From: Eyon
Date: Thu, 25 Jan 2024 21:59:14 +0000
Subject: [PATCH] #4899: fix the permute to operate on the intended shape

---
 .../tt/ttnn_functional_bert.py                |  2 +
 .../tt/ttnn_functional_bloom.py               |  5 +-
 .../functional_t5/tt/ttnn_functional_t5.py    |  2 +
 .../tt/ttnn_optimized_functional_whisper.py   | 38 ++++++++------
 .../test_ttnn_optimized_functional_whisper.py |  1 +
 .../unit_tests/operations/test_permute.py     | 14 +++++
 ttnn/ttnn/core.py                             | 24 ++++++++-
 ttnn/ttnn/operations/data_movement.py         | 51 ++++++++++++-------
 8 files changed, 103 insertions(+), 34 deletions(-)

diff --git a/models/experimental/functional_bert/tt/ttnn_functional_bert.py b/models/experimental/functional_bert/tt/ttnn_functional_bert.py
index 53a97c2d420..621d1f2262a 100644
--- a/models/experimental/functional_bert/tt/ttnn_functional_bert.py
+++ b/models/experimental/functional_bert/tt/ttnn_functional_bert.py
@@ -42,7 +42,9 @@ def bert_attention(
     context_layer = attention_probs @ value
 
     context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
+    context_layer = ttnn.to_layout(context_layer, ttnn.ROW_MAJOR_LAYOUT)
     context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))
+    context_layer = ttnn.to_layout(context_layer, ttnn.TILE_LAYOUT)
 
     self_output = context_layer
     self_output = self_output @ parameters.output.dense.weight
diff --git a/models/experimental/functional_bloom/tt/ttnn_functional_bloom.py b/models/experimental/functional_bloom/tt/ttnn_functional_bloom.py
index 38dd4b5b295..cf945e3c9c7 100644
--- a/models/experimental/functional_bloom/tt/ttnn_functional_bloom.py
+++ b/models/experimental/functional_bloom/tt/ttnn_functional_bloom.py
@@ -140,7 +140,10 @@ def merge_heads(x: ttnn.Tensor) -> ttnn.Tensor:
     x = ttnn.permute(x, (0, 2, 1, 3))
 
     # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
-    return ttnn.reshape(x, shape=(batch_size, seq_length, num_heads * head_size))
+    x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
+    x = ttnn.reshape(x, shape=(batch_size, seq_length, num_heads * head_size))
+    x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
+    return x
 
 
 def compute_context_layer(attention_probs, value_layer):
diff --git a/models/experimental/functional_t5/tt/ttnn_functional_t5.py b/models/experimental/functional_t5/tt/ttnn_functional_t5.py
index 6750db43d71..501fbe23b43 100644
--- a/models/experimental/functional_t5/tt/ttnn_functional_t5.py
+++ b/models/experimental/functional_t5/tt/ttnn_functional_t5.py
@@ -106,7 +106,9 @@ def shape(states, head_size, is_key=False):
     def unshape(states, hidden_size):
         """reshape"""
         states = ttnn.permute(states, (0, 2, 1, 3))
+        states = ttnn.to_layout(states, ttnn.ROW_MAJOR_LAYOUT)
         states = ttnn.reshape(states, (batch_size, seq_length, hidden_size))
+        states = ttnn.to_layout(states, ttnn.TILE_LAYOUT)
         return states
 
     def project(hidden_states, weight, is_key=False):
diff --git a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py b/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py
index 9e9248a96db..0f0309592d0 100644
--- a/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py
+++ b/models/experimental/functional_whisper/tt/ttnn_optimized_functional_whisper.py
@@ -51,15 +51,17 @@ def calculate_key_values(config, key_value_states, *, parameters):
     fused_qkv = ttnn.reshape(fused_qkv, shape=(bsz, tgt_len_padded, 2, config.encoder_attention_heads, head_size))
     key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :]
 
+    key_states = ttnn.permute(key_states, (0, 2, 1, 3))
+    value_states = ttnn.permute(value_states, (0, 2, 1, 3))
+    key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT)
+    value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT)
+
     desired_shape = ttnn.Shape(
-        [bsz, tgt_len, config.encoder_attention_heads, head_size],
-        [bsz, tgt_len_padded, config.encoder_attention_heads, head_size],
+        [bsz, config.encoder_attention_heads, tgt_len, head_size],
+        [bsz, config.encoder_attention_heads, tgt_len_padded, head_size],
     )
     key_states = ttnn.reshape(key_states, shape=desired_shape)
-    key_states = ttnn.permute(key_states, (0, 2, 1, 3))
-
     value_states = ttnn.reshape(value_states, shape=desired_shape)
-    value_states = ttnn.permute(value_states, (0, 2, 1, 3))
 
     return key_states, value_states
 
@@ -90,20 +92,21 @@ def split_query_key_value_and_split_heads(
 
     fused_qkv = ttnn.reshape(fused_qkv, shape=(batch_size, padded_seq_length, 3, encoder_attention_heads, head_size))
     query_states, key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :], fused_qkv[..., 2, :, :]
+    query_states = ttnn.permute(query_states, (0, 2, 1, 3))
+    key_states = ttnn.permute(key_states, (0, 2, 1, 3))
+    value_states = ttnn.permute(value_states, (0, 2, 1, 3))
+
+    query_states = ttnn.to_layout(query_states, ttnn.TILE_LAYOUT)
+    key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT)
+    value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT)
 
     desired_shape = ttnn.Shape(
-        [batch_size, seq_length, encoder_attention_heads, head_size],
-        [batch_size, padded_seq_length, encoder_attention_heads, head_size],
+        [batch_size, encoder_attention_heads, seq_length, head_size],
+        [batch_size, encoder_attention_heads, padded_seq_length, head_size],
     )
     query_states = ttnn.reshape(query_states, shape=desired_shape)
-    query_states = ttnn.permute(query_states, (0, 2, 1, 3))
-
     key_states = ttnn.reshape(key_states, shape=desired_shape)
-    key_states = ttnn.permute(key_states, (0, 2, 1, 3))
-
     value_states = ttnn.reshape(value_states, shape=desired_shape)
-    value_states = ttnn.permute(value_states, (0, 2, 1, 3))
-
     return query_states, key_states, value_states
 
 
@@ -152,7 +155,8 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states=No
     key_states = ttnn.reshape(key_states, shape=proj_shape)
     value_states = ttnn.reshape(value_states, shape=proj_shape)
 
-    attn_weights = query_states @ ttnn.permute(key_states, (0, 2, 1))
+    key_states = ttnn.permute(key_states, (0, 2, 1))
+    attn_weights = query_states @ key_states
 
     if attention_mask is not None:
         bsz, *_, padded_tgt_len, src_len = attention_mask.shape.padded()
@@ -187,10 +191,14 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states=No
             [bsz, config.encoder_attention_heads, padded_tgt_len, head_size],
         ),
     )
+    attn_output = ttnn.permute(attn_output, (0, 2, 1, 3))
+    attn_output = ttnn.to_layout(attn_output, ttnn.ROW_MAJOR_LAYOUT)
     attn_output = ttnn.reshape(
-        attn_output, shape=ttnn.Shape([bsz, tgt_len, config.d_model], [bsz, padded_tgt_len, config.d_model])
+        attn_output,
+        shape=ttnn.Shape([bsz, tgt_len, config.d_model], [bsz, tgt_len, config.d_model]),
     )
 
+    attn_output = ttnn.to_layout(attn_output, ttnn.TILE_LAYOUT)
     attn_output = attn_output @ parameters.out_proj.weight + parameters.out_proj.bias
 
     return attn_output
diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py
index 342c293f4af..d8aac05d135 100644
--- a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py
+++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py
@@ -286,6 +286,7 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size):
     ttnn_decoder_input_ids = ttnn.to_device(ttnn_decoder_input_ids, device)
 
     ttnn_encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states, dtype=ttnn.bfloat16)
+    ttnn_encoder_hidden_states = ttnn.to_layout(ttnn_encoder_hidden_states, ttnn.TILE_LAYOUT)
     ttnn_encoder_hidden_states = ttnn.to_device(ttnn_encoder_hidden_states, device)
 
     (decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_decoder_inputs(
diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py
index 43455801cbe..6df41ebb36c 100644
--- a/tests/ttnn/unit_tests/operations/test_permute.py
+++ b/tests/ttnn/unit_tests/operations/test_permute.py
@@ -90,3 +90,17 @@ def test_permute_for_specific_case(device, b, s, h, w):
     output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
     output_tensor = ttnn.to_torch(output_tensor)
     assert torch.allclose(torch_output_tensor, output_tensor, atol=1e-1, rtol=1e-2)
+
+
+def test_add_after_permute(device):
+    torch_a = torch.randn(2, 1280, 8, 8)
+    torch_b = torch.randn(1, 1, 2, 1280)
+    torch_b_permuted = torch.permute(torch_b, (2, 3, 0, 1))
+    torch_output = torch_a + torch_b_permuted
+
+    a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
+    b = ttnn.permute(b, (2, 3, 0, 1))
+    output = a + b
+    output = ttnn.to_torch(output)
+    assert_with_pcc(torch_output, output, 0.9999)
diff --git a/ttnn/ttnn/core.py b/ttnn/ttnn/core.py
index 2fb1283ad41..a61e616cb72 100644
--- a/ttnn/ttnn/core.py
+++ b/ttnn/ttnn/core.py
@@ -59,13 +59,14 @@ def torch_getitem(tensor, slices):
             device = self.device
         else:
             device = None
+        layout = self.layout
 
         tensor = self
         tensor = to_torch(tensor)
         tensor = ttl.tensor.decorate_external_operation(torch_getitem, function_name="torch.Tensor.__getitem__")(
             tensor, slices
         )
-        tensor = from_torch(tensor, dtype=self.dtype, device=device)
+        tensor = from_torch(tensor, dtype=self.dtype, layout=layout, device=device)
        return tensor
 
    def is_contiguous(self: "Shape") -> bool:
@@ -205,6 +206,7 @@ def ttnn_reshape(tensor, shape):
         ttl_input_tensor = tensor.value
         return Tensor(ttl_input_tensor.reshape(shape.value))
 
+    layout = input_tensor.layout
     ttnn_reshape = ttl.tensor.decorate_external_operation(ttnn_reshape, function_name="ttnn.reshape")
 
     if input_tensor.is_contiguous():
@@ -230,6 +232,8 @@ def ttnn_reshape(tensor, shape):
         ttl_output_tensor = ttl.tensor.reshape(ttl_input_tensor, w, z, y, x)
         output_tensor = Tensor(ttl_output_tensor)
         output_tensor = ttnn_reshape(output_tensor, shape)
+        # Unable to handle 5D tensors! See ttnn_optimized_functional_whisper.
+        # output_tensor = to_layout(output_tensor, layout)
         return output_tensor
 
     else:
@@ -247,6 +251,8 @@ def torch_reshape(tensor, shape):
         tensor = to_torch(tensor)
         tensor = ttl.tensor.decorate_external_operation(torch_reshape, function_name="torch.reshape")(tensor, shape)
         tensor = from_torch(tensor, dtype=input_tensor.dtype, device=device)
+        # Unable to handle 5D tensors! See ttnn_optimized_functional_whisper.
+        # tensor = to_layout(tensor, layout)
         tensor = ttnn_reshape(tensor, shape)
         return tensor
 
@@ -266,6 +272,22 @@ def unsqueeze_to_4D(tensor):
     return reshape(tensor, shape=Shape(shape, full_shape))
 
 
+def squeeze(tensor):
+    if len(tensor.shape) == 1:
+        # Tensor must be at least one dimension. Avoid exceptions when at all possible!
+        return tensor
+    if len(tensor.shape) > 4:
+        raise RuntimeError("Tensor cannot have more than 4 dimensions!")
+
+    batch, *shape = tensor.shape
+    _, *full_shape = tensor.shape.padded()
+    if batch != 1:
+        # Do not throw an error. Avoid exceptions when at all possible!
+        return tensor
+
+    return reshape(tensor, shape=Shape(shape, full_shape))
+
+
 @decorate_operation()
 def from_torch(
     tensor: "torch.Tensor",
diff --git a/ttnn/ttnn/operations/data_movement.py b/ttnn/ttnn/operations/data_movement.py
index fbb741e301a..c5aa29c76e9 100644
--- a/ttnn/ttnn/operations/data_movement.py
+++ b/ttnn/ttnn/operations/data_movement.py
@@ -90,30 +90,49 @@ def permute(input_tensor: ttnn.Tensor, order: Tuple[int, ...]) -> ttnn.Tensor:
     if not ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE):
         RuntimeError("input_tensor must be on device!")
 
-    ttl_input_tensor = input_tensor.value
-
     if len(input_tensor.shape) != len(order):
         raise RuntimeError(
             "The number of dimensions in the tensor input does not match the length of the desired ordering"
         )
 
-    original_shape = tuple(input_tensor.shape)
-    original_shape_padded = tuple(input_tensor.shape.padded())
-    desired_shape = ttnn.Shape(
-        list([original_shape[i] for i in order]), list([original_shape_padded[i] for i in order])
-    )
-    if ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE) and len(input_tensor.shape) == 4:
-        output_tensor = ttnn.Tensor(ttl.tensor.permute(ttl_input_tensor, order))
-        # permute is not currently keeping the original padding
-        return ttnn.reshape(output_tensor, desired_shape)
-    elif len(input_tensor.shape) < 4:
+    on_device = ttnn.has_storage_type_of(input_tensor, ttnn.DEVICE_STORAGE_TYPE)
+    device = input_tensor.device
+    layout = input_tensor.layout
+    dtype = input_tensor.dtype
+    rank = len(input_tensor.shape)
+
+    if len(input_tensor.shape) < 4:
         input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
         ttl_input_tensor = input_tensor.value
         adjusted_order_for_4D_tensor = order
         while len(adjusted_order_for_4D_tensor) < 4:
             adjusted_order_for_4D_tensor = (0,) + tuple(x + 1 for x in adjusted_order_for_4D_tensor)
-        output_tensor = ttnn.Tensor(ttl.tensor.permute(ttl_input_tensor, adjusted_order_for_4D_tensor))
-        return ttnn.reshape(output_tensor, desired_shape)
+        order = adjusted_order_for_4D_tensor
+
+    def has_padding(tensor):
+        if len(tensor.shape) > 1:
+            *_, h, w = tensor.shape
+            *_, h_padded, w_padded = tensor.shape.padded()
+            return h != h_padded or w != w_padded
+        return False
+
+    if has_padding(input_tensor):
+        input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
+
+    ttl_input_tensor = input_tensor.value
+
+    if ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE) and len(input_tensor.shape) == 4:
+        output_tensor = ttnn.Tensor(ttl.tensor.permute(ttl_input_tensor, order))
+        output_tensor = ttnn.to_layout(output_tensor, layout)
+        rank_should_be_updated = len(output_tensor.shape) > rank
+        while rank_should_be_updated:
+            prior_rank = len(output_tensor.shape)
+            output_tensor = ttnn.squeeze(output_tensor)
+            rank_should_be_updated = prior_rank != len(output_tensor.shape) and len(output_tensor.shape) > rank
+
+        if on_device and not ttnn.has_storage_type_of(output_tensor, ttnn.DEVICE_STORAGE_TYPE):
+            output_tensor = ttnn.to_device(output_tensor, device)
+        return output_tensor
     else:
 
         def torch_permute(tensor, order):
@@ -121,9 +140,7 @@ def torch_permute(tensor, order):
 
         tensor = ttnn.to_torch(tensor)
         tensor = ttl.tensor.decorate_external_operation(torch_permute, function_name="torch.permute")(tensor, order)
-        tensor = ttnn.from_torch(
-            tensor, dtype=input_tensor.dtype, layout=input_tensor.layout, device=input_tensor.device
-        )
+        tensor = ttnn.from_torch(tensor, dtype=dtype, layout=layout, device=device)
         return tensor
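
Illustration, not part of the patch: a torch-only sketch (with made-up sizes) of the head-merge shape math that the permute-before-reshape ordering in the model code above is meant to preserve.

    import torch

    # Hypothetical sizes, for illustration only.
    batch_size, num_heads, seq_length, head_size = 2, 4, 8, 16

    # Attention output arrives as (batch, heads, seq, head_dim).
    context = torch.randn(batch_size, num_heads, seq_length, head_size)

    # Permute back to (batch, seq, heads, head_dim) before flattening the heads;
    # this is the shape the reshape calls in the patch operate on.
    merged = context.permute(0, 2, 1, 3).reshape(batch_size, seq_length, num_heads * head_size)

    # Reshaping without the permute mixes head and sequence positions.
    not_merged = context.reshape(batch_size, seq_length, num_heads * head_size)
    assert not torch.equal(merged, not_merged)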