#4899: fix the permute to operate on the intended shape
eyonland committed Jan 26, 2024
1 parent d34df49 commit 97e6b00
Showing 8 changed files with 103 additions and 34 deletions.
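
The model-side changes below all apply the same fix pattern: after ttnn.permute, drop to row-major layout so the subsequent ttnn.reshape operates on the intended (unpadded) shape, then return to tile layout. A minimal sketch of that pattern, assuming an open ttnn device handle named `device` and illustrative tensor shapes:

import torch
import ttnn

# batch_size, num_heads, sequence_size, head_size -> batch_size, sequence_size, hidden_size
batch_size, num_heads, sequence_size, head_size = 8, 4, 384, 64  # illustrative values
hidden_size = num_heads * head_size

context_layer = ttnn.from_torch(
    torch.randn(batch_size, num_heads, sequence_size, head_size),
    layout=ttnn.TILE_LAYOUT,
    device=device,  # assumed open device handle
    dtype=ttnn.bfloat16,
)
context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
context_layer = ttnn.to_layout(context_layer, ttnn.ROW_MAJOR_LAYOUT)  # reshape sees the unpadded row-major data
context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))
context_layer = ttnn.to_layout(context_layer, ttnn.TILE_LAYOUT)  # back to tile layout for subsequent matmuls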
@@ -42,7 +42,9 @@ def bert_attention(

context_layer = attention_probs @ value
context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
context_layer = ttnn.to_layout(context_layer, ttnn.ROW_MAJOR_LAYOUT)
context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))
context_layer = ttnn.to_layout(context_layer, ttnn.TILE_LAYOUT)

self_output = context_layer
self_output = self_output @ parameters.output.dense.weight
@@ -140,7 +140,10 @@ def merge_heads(x: ttnn.Tensor) -> ttnn.Tensor:
x = ttnn.permute(x, (0, 2, 1, 3))

# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
return ttnn.reshape(x, shape=(batch_size, seq_length, num_heads * head_size))
x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.reshape(x, shape=(batch_size, seq_length, num_heads * head_size))
x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
return x


def compute_context_layer(attention_probs, value_layer):
2 changes: 2 additions & 0 deletions models/experimental/functional_t5/tt/ttnn_functional_t5.py
@@ -106,7 +106,9 @@ def shape(states, head_size, is_key=False):
def unshape(states, hidden_size):
"""reshape"""
states = ttnn.permute(states, (0, 2, 1, 3))
states = ttnn.to_layout(states, ttnn.ROW_MAJOR_LAYOUT)
states = ttnn.reshape(states, (batch_size, seq_length, hidden_size))
states = ttnn.to_layout(states, ttnn.TILE_LAYOUT)
return states

def project(hidden_states, weight, is_key=False):
@@ -51,15 +51,17 @@ def calculate_key_values(config, key_value_states, *, parameters):
fused_qkv = ttnn.reshape(fused_qkv, shape=(bsz, tgt_len_padded, 2, config.encoder_attention_heads, head_size))
key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :]

key_states = ttnn.permute(key_states, (0, 2, 1, 3))
value_states = ttnn.permute(value_states, (0, 2, 1, 3))
key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT)
value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT)

desired_shape = ttnn.Shape(
[bsz, tgt_len, config.encoder_attention_heads, head_size],
[bsz, tgt_len_padded, config.encoder_attention_heads, head_size],
[bsz, config.encoder_attention_heads, tgt_len, head_size],
[bsz, config.encoder_attention_heads, tgt_len_padded, head_size],
)
key_states = ttnn.reshape(key_states, shape=desired_shape)
key_states = ttnn.permute(key_states, (0, 2, 1, 3))

value_states = ttnn.reshape(value_states, shape=desired_shape)
value_states = ttnn.permute(value_states, (0, 2, 1, 3))

return key_states, value_states

@@ -90,20 +92,21 @@ def split_query_key_value_and_split_heads(

fused_qkv = ttnn.reshape(fused_qkv, shape=(batch_size, padded_seq_length, 3, encoder_attention_heads, head_size))
query_states, key_states, value_states = fused_qkv[..., 0, :, :], fused_qkv[..., 1, :, :], fused_qkv[..., 2, :, :]
query_states = ttnn.permute(query_states, (0, 2, 1, 3))
key_states = ttnn.permute(key_states, (0, 2, 1, 3))
value_states = ttnn.permute(value_states, (0, 2, 1, 3))

query_states = ttnn.to_layout(query_states, ttnn.TILE_LAYOUT)
key_states = ttnn.to_layout(key_states, ttnn.TILE_LAYOUT)
value_states = ttnn.to_layout(value_states, ttnn.TILE_LAYOUT)

desired_shape = ttnn.Shape(
[batch_size, seq_length, encoder_attention_heads, head_size],
[batch_size, padded_seq_length, encoder_attention_heads, head_size],
[batch_size, encoder_attention_heads, seq_length, head_size],
[batch_size, encoder_attention_heads, padded_seq_length, head_size],
)
query_states = ttnn.reshape(query_states, shape=desired_shape)
query_states = ttnn.permute(query_states, (0, 2, 1, 3))

key_states = ttnn.reshape(key_states, shape=desired_shape)
key_states = ttnn.permute(key_states, (0, 2, 1, 3))

value_states = ttnn.reshape(value_states, shape=desired_shape)
value_states = ttnn.permute(value_states, (0, 2, 1, 3))

return query_states, key_states, value_states


@@ -152,7 +155,8 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states=No
key_states = ttnn.reshape(key_states, shape=proj_shape)
value_states = ttnn.reshape(value_states, shape=proj_shape)

attn_weights = query_states @ ttnn.permute(key_states, (0, 2, 1))
key_states = ttnn.permute(key_states, (0, 2, 1))
attn_weights = query_states @ key_states

if attention_mask is not None:
bsz, *_, padded_tgt_len, src_len = attention_mask.shape.padded()
@@ -187,10 +191,14 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states=No
[bsz, config.encoder_attention_heads, padded_tgt_len, head_size],
),
)

attn_output = ttnn.permute(attn_output, (0, 2, 1, 3))
attn_output = ttnn.to_layout(attn_output, ttnn.ROW_MAJOR_LAYOUT)
attn_output = ttnn.reshape(
attn_output, shape=ttnn.Shape([bsz, tgt_len, config.d_model], [bsz, padded_tgt_len, config.d_model])
attn_output,
shape=ttnn.Shape([bsz, tgt_len, config.d_model], [bsz, tgt_len, config.d_model]),
)
attn_output = ttnn.to_layout(attn_output, ttnn.TILE_LAYOUT)
attn_output = attn_output @ parameters.out_proj.weight + parameters.out_proj.bias
return attn_output

@@ -286,6 +286,7 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size):
ttnn_decoder_input_ids = ttnn.to_device(ttnn_decoder_input_ids, device)

ttnn_encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states, dtype=ttnn.bfloat16)
ttnn_encoder_hidden_states = ttnn.to_layout(ttnn_encoder_hidden_states, ttnn.TILE_LAYOUT)
ttnn_encoder_hidden_states = ttnn.to_device(ttnn_encoder_hidden_states, device)

(decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_decoder_inputs(
14 changes: 14 additions & 0 deletions tests/ttnn/unit_tests/operations/test_permute.py
@@ -90,3 +90,17 @@ def test_permute_for_specific_case(device, b, s, h, w):
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.to_torch(output_tensor)
assert torch.allclose(torch_output_tensor, output_tensor, atol=1e-1, rtol=1e-2)


def test_add_after_permute(device):
torch_a = torch.randn(2, 1280, 8, 8)
torch_b = torch.randn(1, 1, 2, 1280)
torch_b_permuted = torch.permute(torch_b, (2, 3, 0, 1))
torch_output = torch_a + torch_b_permuted

a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
b = ttnn.permute(b, (2, 3, 0, 1))
output = a + b
output = ttnn.to_torch(output)
assert_with_pcc(torch_output, output, 0.9999)
24 changes: 23 additions & 1 deletion ttnn/ttnn/core.py
@@ -59,13 +59,14 @@ def torch_getitem(tensor, slices):
device = self.device
else:
device = None
layout = self.layout

tensor = self
tensor = to_torch(tensor)
tensor = ttl.tensor.decorate_external_operation(torch_getitem, function_name="torch.Tensor.__getitem__")(
tensor, slices
)
tensor = from_torch(tensor, dtype=self.dtype, device=device)
tensor = from_torch(tensor, dtype=self.dtype, layout=layout, device=device)
return tensor

def is_contiguous(self: "Shape") -> bool:
@@ -205,6 +206,7 @@ def ttnn_reshape(tensor, shape):
ttl_input_tensor = tensor.value
return Tensor(ttl_input_tensor.reshape(shape.value))

layout = input_tensor.layout
ttnn_reshape = ttl.tensor.decorate_external_operation(ttnn_reshape, function_name="ttnn.reshape")

if input_tensor.is_contiguous():
@@ -230,6 +232,8 @@ def ttnn_reshape(tensor, shape):
ttl_output_tensor = ttl.tensor.reshape(ttl_input_tensor, w, z, y, x)
output_tensor = Tensor(ttl_output_tensor)
output_tensor = ttnn_reshape(output_tensor, shape)
# Unable to handle 5D tensors! See ttnn_optimized_functional_whisper.
# output_tensor = to_layout(output_tensor, layout)
return output_tensor
else:

@@ -247,6 +251,8 @@ def torch_reshape(tensor, shape):
tensor = to_torch(tensor)
tensor = ttl.tensor.decorate_external_operation(torch_reshape, function_name="torch.reshape")(tensor, shape)
tensor = from_torch(tensor, dtype=input_tensor.dtype, device=device)
# Unable to handle 5D tensors! See ttnn_optimized_functional_whisper.
# tensor = to_layout(tensor, layout)
tensor = ttnn_reshape(tensor, shape)

return tensor
@@ -266,6 +272,22 @@ def unsqueeze_to_4D(tensor):
return reshape(tensor, shape=Shape(shape, full_shape))


def squeeze(tensor):
if len(tensor.shape) == 1:
# Tensor must be at least one dimension. Avoid exceptions when at all possible!
return tensor
if len(tensor.shape) > 4:
raise RuntimeError("Tensor cannot have more than 4 dimensions!")

batch, *shape = tensor.shape
_, *full_shape = tensor.shape.padded()
if batch != 1:
# Do not throw an error. Avoid exceptions when at all possible!
return tensor

return reshape(tensor, shape=Shape(shape, full_shape))


@decorate_operation()
def from_torch(
tensor: "torch.Tensor",
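The new squeeze helper above mirrors unsqueeze_to_4D and is what the reworked permute below uses to restore the caller's original rank. A hypothetical usage sketch (shapes are illustrative; `device` is assumed to be an open ttnn device):

import torch
import ttnn

tensor = ttnn.from_torch(torch.randn(32, 64, 96), layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
tensor_4d = ttnn.unsqueeze_to_4D(tensor)  # rank 3 -> rank 4, leading dimension of 1
tensor_3d = ttnn.squeeze(tensor_4d)       # leading dimension of 1 removed again
assert len(tensor_3d.shape) == 3

# squeeze deliberately avoids raising: with a leading dimension other than 1, the tensor is returned unchanged.
assert len(ttnn.squeeze(tensor_3d).shape) == 3
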
51 changes: 34 additions & 17 deletions ttnn/ttnn/operations/data_movement.py
@@ -90,40 +90,57 @@ def permute(input_tensor: ttnn.Tensor, order: Tuple[int, ...]) -> ttnn.Tensor:
if not ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE):
RuntimeError("input_tensor must be on device!")

ttl_input_tensor = input_tensor.value

if len(input_tensor.shape) != len(order):
raise RuntimeError(
"The number of dimensions in the tensor input does not match the length of the desired ordering"
)

original_shape = tuple(input_tensor.shape)
original_shape_padded = tuple(input_tensor.shape.padded())
desired_shape = ttnn.Shape(
list([original_shape[i] for i in order]), list([original_shape_padded[i] for i in order])
)
if ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE) and len(input_tensor.shape) == 4:
output_tensor = ttnn.Tensor(ttl.tensor.permute(ttl_input_tensor, order))
# permute is not currently keeping the original padding
return ttnn.reshape(output_tensor, desired_shape)
elif len(input_tensor.shape) < 4:
on_device = ttnn.has_storage_type_of(input_tensor, ttnn.DEVICE_STORAGE_TYPE)
device = input_tensor.device
layout = input_tensor.layout
dtype = input_tensor.dtype
rank = len(input_tensor.shape)

if len(input_tensor.shape) < 4:
input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
ttl_input_tensor = input_tensor.value
adjusted_order_for_4D_tensor = order
while len(adjusted_order_for_4D_tensor) < 4:
adjusted_order_for_4D_tensor = (0,) + tuple(x + 1 for x in adjusted_order_for_4D_tensor)
output_tensor = ttnn.Tensor(ttl.tensor.permute(ttl_input_tensor, adjusted_order_for_4D_tensor))
return ttnn.reshape(output_tensor, desired_shape)
order = adjusted_order_for_4D_tensor

def has_padding(tensor):
if len(tensor.shape) > 1:
*_, h, w = tensor.shape
*_, h_padded, w_padded = tensor.shape.padded()
return h != h_padded or w != w_padded
return False

if has_padding(input_tensor):
input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)

ttl_input_tensor = input_tensor.value

if ttnn.has_storage_type_of(input_tensor, ttl.tensor.StorageType.DEVICE) and len(input_tensor.shape) == 4:
output_tensor = ttnn.Tensor(ttl.tensor.permute(ttl_input_tensor, order))
output_tensor = ttnn.to_layout(output_tensor, layout)
rank_should_be_updated = len(output_tensor.shape) > rank
while rank_should_be_updated:
prior_rank = len(output_tensor.shape)
output_tensor = ttnn.squeeze(output_tensor)
rank_should_be_updated = prior_rank != len(output_tensor.shape) and len(output_tensor.shape) > rank

if on_device and not ttnn.has_storage_type_of(output_tensor, ttnn.DEVICE_STORAGE_TYPE):
output_tensor = ttnn.to_device(output_tensor, device)
return output_tensor
else:

def torch_permute(tensor, order):
return tensor.permute(order).contiguous().clone()

tensor = ttnn.to_torch(tensor)
tensor = ttl.tensor.decorate_external_operation(torch_permute, function_name="torch.permute")(tensor, order)
tensor = ttnn.from_torch(
tensor, dtype=input_tensor.dtype, layout=input_tensor.layout, device=input_tensor.device
)
tensor = ttnn.from_torch(tensor, dtype=dtype, layout=layout, device=device)
return tensor


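With the rework above, ttnn.permute converts padded tile tensors to row-major before permuting and restores the caller's rank via ttnn.squeeze after the internal unsqueeze_to_4D. A hypothetical check in the style of the unit tests, with illustrative tile-aligned shapes and an assumed open device handle `device`:

import torch
import ttnn

torch_input = torch.randn(32, 64, 96)
torch_expected = torch.permute(torch_input, (1, 0, 2))

input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
output_tensor = ttnn.permute(input_tensor, (1, 0, 2))  # internally promoted to 4D, permuted, then squeezed back
assert len(output_tensor.shape) == 3

output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.to_torch(output_tensor)
assert torch.allclose(torch_expected, output_tensor.float(), atol=1e-1, rtol=1e-2)
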
