#4003: deleted split_heads and split_key_value_and_split_heads
arakhmati committed Jan 9, 2024
1 parent 3a5fcae commit d5d6a1f
Showing 6 changed files with 19 additions and 182 deletions.
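The change consolidates head splitting into a single op: `ttnn.transformer.split_query_key_value_and_split_heads` now accepts an optional `kv_input_tensor` and no longer takes a `core_grid` argument (the optimized path reads the grid from the device via `compute_with_storage_grid_size()`). Below is a minimal migration sketch, assuming `query_proj` and `key_value_proj` are device tensors produced by earlier `ttnn.linear` calls, as in the `t5_attention` hunks further down.

```python
import ttnn

# Old pattern, removed in this commit: two separate ops.
# query = ttnn.transformer.split_heads(query_proj, num_heads=num_heads, order=(0, 2, 1, 3))
# key, value = ttnn.transformer.split_key_value_and_split_heads(key_value_proj, num_heads=num_heads)

# New pattern: one call covers both the query projection and the fused key/value
# projection; drop kv_input_tensor when Q, K and V are fused into a single tensor.
query, key, value = ttnn.transformer.split_query_key_value_and_split_heads(
    query_proj,
    key_value_proj,
    num_heads=num_heads,
)
ttnn.deallocate(query_proj)
ttnn.deallocate(key_value_proj)
```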
@@ -36,7 +36,6 @@ def bert_attention(
) = ttnn.transformer.split_query_key_value_and_split_heads(
query_key_value_output,
memory_config=ttnn.L1_MEMORY_CONFIG,
core_grid=(batch_size, num_cores_x),
num_heads=num_heads,
)
ttnn.deallocate(query_key_value_output)
@@ -69,9 +69,8 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
def split_query_key_value_and_split_heads(
query_key_value: torch.Tensor, num_heads: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, *_ = query_key_value.shape
output = ttnn.transformer.split_query_key_value_and_split_heads(
query_key_value, core_grid=(batch_size, 12), memory_config=BLOOM_MEMORY_CONFIG, num_heads=num_heads
query_key_value, memory_config=BLOOM_MEMORY_CONFIG, num_heads=num_heads
)
return output

@@ -121,7 +121,6 @@ def t5_attention(
) = ttnn.transformer.split_query_key_value_and_split_heads(
query_key_value_output,
memory_config=ttnn.L1_MEMORY_CONFIG,
core_grid=(batch_size, num_cores_x),
num_heads=config.num_heads,
)
ttnn.deallocate(query_key_value_output)
@@ -134,12 +133,6 @@ def t5_attention(
# dtype=ttnn.bfloat8_b,
core_grid=(batch_size, num_cores_x),
)
query = ttnn.transformer.split_heads(
query_proj,
num_heads=config.num_heads,
order=(0, 2, 1, 3),
)
ttnn.deallocate(query_proj)

key_value_proj = ttnn.linear(
key_value_states,
@@ -148,7 +141,10 @@ def t5_attention(
# dtype=ttnn.bfloat8_b,
core_grid=(batch_size, num_cores_x),
)
key, value = ttnn.transformer.split_key_value_and_split_heads(key_value_proj, num_heads=config.num_heads)
query, key, value = ttnn.transformer.split_query_key_value_and_split_heads(
query_proj, key_value_proj, num_heads=config.num_heads
)
ttnn.deallocate(query_proj)
ttnn.deallocate(key_value_proj)

attention_scores = ttnn.matmul(
@@ -43,16 +43,6 @@ def calculate_key_values(config, key_value_states, *, parameters):
return key_states, value_states


# def split_fused_qkv_and_split_heads(
# fused_qkv: torch.Tensor, head_size
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# batch_size, *_ = fused_qkv.shape
# output = ttnn.transformer.split_fused_qkv_and_split_heads(
# fused_qkv, core_grid=(batch_size, 12), memory_config=WHISPER_MEMORY_CONFIG
# )
# return output


def split_fused_qkv_and_split_heads(config, fused_qkv: ttnn.Tensor) -> Tuple[ttnn.Tensor, ttnn.Tensor, ttnn.Tensor]:
head_size = config.d_model // config.encoder_attention_heads
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
174 changes: 14 additions & 160 deletions ttnn/transformer.py
@@ -18,63 +18,7 @@
from ttnn.decorators import decorate_operation


def _torch_split_heads(input_tensor: Tensor, *, num_heads, order):
import ttnn
import torch

input_tensor = ttnn.from_device(input_tensor)
input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
input_tensor = ttnn.to_torch(input_tensor)

batch_size, sequence_size, hidden_size = input_tensor.shape
head_size = hidden_size // num_heads

output_tensor = torch.reshape(input_tensor, (batch_size, sequence_size, num_heads, head_size)).contiguous().clone()
output_tensor = torch.permute(output_tensor, order).contiguous().clone()
return output_tensor


@decorate_operation(torch_function=_torch_split_heads)
def split_heads(input_tensor: Tensor, *, num_heads: int, order: Tuple[int]) -> Tensor:
if len(input_tensor.shape) != 3:
raise RuntimeError("Input Tensor must have strictly 3 dimensions!")

if input_tensor.layout != TILE_LAYOUT:
raise RuntimeError("Input Tensor must be in a TILE_LAYOUT!")

if not has_storage_type_of(input_tensor, DEVICE_STORAGE_TYPE):
raise RuntimeError("input_tensor must be on device!")

import ttnn
import torch

def impl(tensor):
tensor = torch.reshape(tensor, (batch_size, sequence_size, num_heads, head_size)).contiguous().clone()
tensor = torch.permute(tensor, order).contiguous().clone()
return tensor

impl = ttl.tensor.decorate_external_operation(impl, function_name="ttnn.transformer.split_heads")

device = input_tensor.value.device()
input_dtype = input_tensor.dtype

batch_size, sequence_size, hidden_size = input_tensor.shape
head_size = hidden_size // num_heads

tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
tensor = ttnn.from_device(tensor)
tensor = ttnn.to_torch(tensor)

tensor = impl(tensor)

tensor = ttnn.from_torch(tensor, input_dtype)
tensor = ttnn.to_layout(tensor, ttnn.TILE_LAYOUT)
tensor = ttnn.to_device(tensor, device)

return tensor


def _torch_split_query_key_value_and_split_heads(input_tensor: Tensor, *, num_heads=16, **_):
def _torch_split_query_key_value_and_split_heads(input_tensor: Tensor, *, num_heads, **_):
import ttnn
import torch

@@ -108,9 +52,9 @@ def _torch_split_query_key_value_and_split_heads(input_tensor: Tensor, *, num_he
@decorate_operation(torch_function=_torch_split_query_key_value_and_split_heads)
def split_query_key_value_and_split_heads(
input_tensor: Tensor,
kv_input_tensor: Optional[Tensor] = None,
*,
num_heads: int,
core_grid: Tuple[int, int],
memory_config: MemoryConfig = DRAM_MEMORY_CONFIG,
) -> Tuple[Tensor, Tensor, Tensor]:
"""
@@ -149,28 +93,36 @@ def split_query_key_value_and_split_heads(
if not has_storage_type_of(input_tensor, DEVICE_STORAGE_TYPE):
raise RuntimeError("input_tensor must be on device!")

batch_size, sequence_size, three_times_hidden_size = input_tensor.shape
batch_size, *_ = input_tensor.shape
if input_tensor.shape == (batch_size, 384, 1024 * 3):
batch_size, sequence_size, three_times_hidden_size = input_tensor.shape
input_tensor = reshape(input_tensor, (batch_size, 1, sequence_size, three_times_hidden_size))

ttl_input_tensor = input_tensor.value

core_y, core_x = core_grid
query_key_value = ttl.operations.primary.transformers.split_query_key_value_and_split_heads(
ttl_input_tensor,
ttl.tensor.CoreCoord(core_x, core_y),
ttl_input_tensor.device().compute_with_storage_grid_size(),
memory_config,
)
query_key_value = (Tensor(ttl_tensor) for ttl_tensor in query_key_value)
query, key, value = query_key_value
return query, key, value
else:
input_tensor = reshape(input_tensor, (batch_size, 1, sequence_size, three_times_hidden_size))
batch_size, sequence_size, hidden_size = input_tensor.shape

input_tensor = reshape(input_tensor, (batch_size, 1, sequence_size, hidden_size))
ttl_input_tensor = input_tensor.value

if kv_input_tensor is not None:
kv_input_tensor = reshape(kv_input_tensor, (batch_size, 1, sequence_size, hidden_size * 2))
ttl_kv_input_tensor = kv_input_tensor.value
else:
ttl_kv_input_tensor = None

query_key_value = ttl.tensor.nlp_create_qkv_heads(
ttl_input_tensor,
ttl_kv_input_tensor,
num_heads=num_heads,
output_mem_config=memory_config,
)
@@ -179,104 +131,6 @@ def split_query_key_value_and_split_heads(
return query, key, value


def _torch_split_key_value_and_split_heads(input_tensor: Tensor, *, num_heads, **_):
import ttnn
import torch

input_tensor = ttnn.from_device(input_tensor)
input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
input_tensor = ttnn.to_torch(input_tensor)

batch_size, sequence_size, two_times_hidden_size = input_tensor.shape
hidden_size = two_times_hidden_size // 2
head_size = hidden_size // num_heads

tensor = torch.reshape(input_tensor, (batch_size, sequence_size, 2, num_heads, head_size))
key_layer, value_layer = (
tensor[..., 0, :, :],
tensor[..., 1, :, :],
)

key_layer = torch.reshape(key_layer, (batch_size, sequence_size, num_heads, head_size))
key_layer = torch.permute(key_layer, (0, 2, 3, 1)).contiguous().clone()

value_layer = torch.reshape(value_layer, (batch_size, sequence_size, num_heads, head_size))
value_layer = torch.permute(value_layer, (0, 2, 1, 3)).contiguous().clone()

return key_layer, value_layer


@decorate_operation(torch_function=_torch_split_key_value_and_split_heads)
def split_key_value_and_split_heads(
input_tensor: Tensor,
*,
num_heads: int,
) -> Tuple[Tensor, Tensor]:
"""
split_key_value_and_split_heads(input_tensor: ttnn.Tensor, *, core_grid: Tuple[int, int], memory_config: MemoryConfig = DRAM_MEMORY_CONFIG) -> Tuple[Tensor, Tensor, Tensor]
Splits tensor of shape [batch_size, sequence_size, 2 * hidden_size] into 2 tensors (Key, Value) of shape [batch_size, sequence_size, hidden_size]. Then, reshapes and permutes them, to make them ready for computing attention scores
Args:
* :attr:`input_tensor`: Input Tensor
* :attr:`num_heads`: num heads to split into
"""
if len(input_tensor.shape) != 3:
raise RuntimeError("Input Tensor must have strictly 3 dimensions!")

if input_tensor.layout != TILE_LAYOUT:
raise RuntimeError("Input Tensor must be in a TILE_LAYOUT!")

if not has_storage_type_of(input_tensor, DEVICE_STORAGE_TYPE):
raise RuntimeError("input_tensor must be on device!")

import ttnn
import torch

device = input_tensor.value.device()
input_dtype = input_tensor.dtype

def impl(tensor):
batch_size, sequence_size, two_times_hidden_size = tensor.shape
hidden_size = two_times_hidden_size // 2
head_size = hidden_size // num_heads

tensor = torch.reshape(tensor, (batch_size, sequence_size, 2, num_heads, head_size))
key_layer, value_layer = (
tensor[..., 0, :, :],
tensor[..., 1, :, :],
)

key_layer = torch.reshape(key_layer, (batch_size, sequence_size, num_heads, head_size))
key_layer = torch.permute(key_layer, (0, 2, 3, 1)).contiguous().clone()

value_layer = torch.reshape(value_layer, (batch_size, sequence_size, num_heads, head_size))
value_layer = torch.permute(value_layer, (0, 2, 1, 3)).contiguous().clone()

return key_layer, value_layer

impl = ttl.tensor.decorate_external_operation(
impl, function_name="ttnn.transformer.split_key_value_and_split_heads"
)

input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
input_tensor = ttnn.from_device(input_tensor)
input_tensor = ttnn.to_torch(input_tensor)

key_layer, value_layer = impl(input_tensor)

key_layer = ttnn.from_torch(key_layer, input_dtype)
key_layer = ttnn.to_layout(key_layer, ttnn.TILE_LAYOUT)
key_layer = ttnn.to_device(key_layer, device)

value_layer = ttnn.from_torch(value_layer, input_dtype)
value_layer = ttnn.to_layout(value_layer, ttnn.TILE_LAYOUT)
value_layer = ttnn.to_device(value_layer, device)

return key_layer, value_layer


def _torch_attention_softmax(input_tensor: Tensor, *, head_size: int, attention_mask, **_):
import ttnn
import torch
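For reference, the behaviour of the deleted ops is still documented by the removed `_torch_*` fallbacks above. The following is a plain-PyTorch sketch of those reshapes and permutes, with illustrative sizes only (not the device op itself): the query comes out as `(batch, num_heads, seq, head_size)`, the key pre-transposed as `(batch, num_heads, head_size, seq)`, and the value as `(batch, num_heads, seq, head_size)`.

```python
import torch

batch_size, seq_len, num_heads, head_size = 8, 384, 16, 64  # illustrative sizes
hidden_size = num_heads * head_size

query_proj = torch.randn(batch_size, seq_len, hidden_size)
key_value_proj = torch.randn(batch_size, seq_len, 2 * hidden_size)

# Query: (B, S, H*D) -> (B, H, S, D), matching the removed _torch_split_heads
# with order=(0, 2, 1, 3).
query = query_proj.reshape(batch_size, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

# Fused key/value: split the last dimension into K and V, then permute as in the
# removed _torch_split_key_value_and_split_heads (K is pre-transposed for Q @ K).
kv = key_value_proj.reshape(batch_size, seq_len, 2, num_heads, head_size)
key = kv[..., 0, :, :].permute(0, 2, 3, 1)    # (B, H, D, S)
value = kv[..., 1, :, :].permute(0, 2, 1, 3)  # (B, H, S, D)

print(query.shape, key.shape, value.shape)
```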
1 change: 0 additions & 1 deletion ttnn/tutorials/003.ipynb
@@ -481,7 +481,6 @@
" ) = ttnn.transformer.split_query_key_value_and_split_heads(\n",
" fused_qkv_output,\n",
" memory_config=ttnn.L1_MEMORY_CONFIG,\n",
" core_grid=(batch_size, num_cores_x),\n",
" num_heads=num_heads,\n",
" )\n",
" ttnn.deallocate(fused_qkv_output)\n",
