#4003: implemented functional t5 model
arakhmati committed Dec 7, 2023
1 parent 69b8a52 commit e25a170
Showing 4 changed files with 710 additions and 1 deletion.
375 changes: 375 additions & 0 deletions models/experimental/functional_t5/reference/torch_functional_t5.py
@@ -0,0 +1,375 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch


def t5_layer_norm(hidden_states, *, weight, eps=1e-6):
    # T5 uses a layer norm that only scales and does not shift, also known as Root Mean Square
    # Layer Normalization (https://arxiv.org/abs/1910.07467): the variance is computed without
    # subtracting the mean and there is no bias. Additionally, we want to make sure that the
    # accumulation for half-precision inputs is done in fp32.

variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)

# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(weight.dtype)

return weight * hidden_states


def t5_dense_gated_act_dense(hidden_states, parameters):
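    # Gated-GeLU feed-forward: GeLU(x @ wi_0) gates the linear branch (x @ wi_1),
    # and the elementwise product is projected back to the model dimension by wo.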
hidden_gelu = torch.nn.functional.gelu(hidden_states @ parameters.wi_0.weight)
hidden_linear = hidden_states @ parameters.wi_1.weight
hidden_states = hidden_gelu * hidden_linear

hidden_states = hidden_states @ parameters.wo.weight
return hidden_states


def t5_layer_ff(hidden_states, parameters):
forwarded_states = t5_layer_norm(hidden_states, weight=parameters.layer_norm.weight, eps=1e-6)
forwarded_states = t5_dense_gated_act_dense(forwarded_states, parameters.DenseReluDense)
hidden_states = hidden_states + forwarded_states
return hidden_states


def t5_attention(
hidden_states,
key_value_states=None,
mask=None,
layer_head_mask=None,
*,
parameters,
num_heads,
):
"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
    # Input is (batch_size, seq_length, dim)
    # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
batch_size, seq_length, _ = hidden_states.shape

def shape(states, num_heads, head_size):
"""projection"""
return states.view(batch_size, -1, num_heads, head_size).transpose(1, 2)

def unshape(states, hidden_size):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, hidden_size)

    def project(hidden_states, weight):
        """Project hidden states into (batch_size, n_heads, seq_length, dim_per_head) query/key/value states."""
        hidden_size = weight.shape[-1]
        head_size = hidden_size // num_heads
hidden_states = shape(hidden_states @ weight, num_heads, head_size)
return hidden_states

# get query states
hidden_size = parameters.q.weight.shape[-1]
query_states = project(hidden_states, parameters.q.weight) # (batch_size, n_heads, seq_length, dim_per_head)

# get key/value states
key_states = project(
hidden_states if key_value_states is None else key_value_states,
parameters.k.weight,
)
value_states = project(
hidden_states if key_value_states is None else key_value_states,
parameters.v.weight,
)

# compute scores
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
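    # the attention mask is additive: 0 where attention is allowed, a large negative value where it is blocked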
if mask is not None:
scores += mask

attn_weights = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)

# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask

attn_output = unshape(torch.matmul(attn_weights, value_states), hidden_size) # (batch_size, seq_length, dim)
attn_output = attn_output @ parameters.o.weight

return attn_output


def t5_layer_self_attention(
hidden_states,
attention_mask=None,
*,
parameters,
num_heads,
):
normed_hidden_states = t5_layer_norm(hidden_states, weight=parameters.layer_norm.weight, eps=1e-06)
attention_output = t5_attention(
normed_hidden_states,
mask=attention_mask,
parameters=parameters.SelfAttention,
num_heads=num_heads,
)
hidden_states = hidden_states + attention_output
return hidden_states


def t5_layer_cross_attention(hidden_states, key_value_states, attention_mask=None, *, parameters, num_heads):
normed_hidden_states = t5_layer_norm(hidden_states, weight=parameters.layer_norm.weight, eps=1e-06)
attention_output = t5_attention(
normed_hidden_states,
key_value_states,
mask=attention_mask,
parameters=parameters.EncDecAttention,
num_heads=num_heads,
)
layer_output = hidden_states + attention_output
return layer_output


def t5_block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
*,
parameters,
num_heads,
):
hidden_states = t5_layer_self_attention(
hidden_states,
attention_mask=attention_mask,
parameters=parameters.layer[0],
num_heads=num_heads,
)

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

do_cross_attention = encoder_hidden_states is not None
if do_cross_attention:
hidden_states = t5_layer_cross_attention(
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
parameters=parameters.layer[1],
num_heads=num_heads,
)

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

# Apply Feed Forward layer
hidden_states = t5_layer_ff(hidden_states, parameters.layer[-1])

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    return hidden_states  # (batch_size, seq_length, dim)


def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
device = attention_mask.device
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
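    # query position q may attend to key position k iff k <= q (lower-triangular causal mask)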
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)

if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
causal_mask,
],
axis=-1,
)

extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
return extended_attention_mask


def get_extended_attention_mask(
    attention_mask: torch.Tensor, input_shape: Tuple[int], dtype: torch.dtype = torch.bfloat16, *, is_decoder
) -> torch.Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
Returns:
        `torch.Tensor`: The extended attention mask, cast to `dtype`.
"""

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
extended_attention_mask = create_extended_attention_mask_for_decoder(input_shape, attention_mask)
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
)

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
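    # For example, a mask value of 1.0 becomes 0.0 (keep) and 0.0 becomes torch.finfo(dtype).min (suppress).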
extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
return extended_attention_mask


def invert_attention_mask(encoder_attention_mask: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
"""
Invert an attention mask (e.g., switches 0. and 1.).
Args:
encoder_attention_mask (`torch.Tensor`): An attention mask.
Returns:
`torch.Tensor`: The inverted attention mask.
"""
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
# /transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(dtype).min

return encoder_extended_attention_mask


def t5_stack(
input_ids,
shared_embedding_weight,
encoder_hidden_states=None,
*,
parameters,
num_heads,
):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])

hidden_states = torch.nn.functional.embedding(input_ids, shared_embedding_weight)

batch_size, seq_length = input_shape

    # there is no past key/value cache in this functional implementation, so the mask covers the full sequence
    mask_seq_length = seq_length

attention_mask = torch.ones(batch_size, mask_seq_length, device=hidden_states.device)
if encoder_hidden_states is not None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(
batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long
)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = get_extended_attention_mask(
attention_mask, input_shape, is_decoder=encoder_hidden_states is not None
)

    # If a 2D or 3D attention mask is provided for the cross-attention,
    # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=hidden_states.device)
encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

for block_parameters in parameters.block:
hidden_states = t5_block(
hidden_states,
attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
parameters=block_parameters,
num_heads=num_heads,
)

hidden_states = t5_layer_norm(hidden_states, weight=parameters.final_layer_norm.weight, eps=1e-06)

return hidden_states


def t5_for_conditional_generation(
input_ids: Optional[torch.LongTensor],
decoder_input_ids: Optional[torch.LongTensor],
parameters,
*,
num_heads,
) -> torch.FloatTensor:
# Encode
hidden_states = t5_stack(
input_ids=input_ids,
shared_embedding_weight=parameters.shared.weight,
parameters=parameters.encoder,
num_heads=num_heads,
)

# Decode
sequence_output = t5_stack(
input_ids=decoder_input_ids,
encoder_hidden_states=hidden_states,
shared_embedding_weight=parameters.shared.weight,
parameters=parameters.decoder,
num_heads=num_heads,
)

lm_logits = sequence_output @ parameters.lm_head.weight

return lm_logits
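
For context, here is a minimal driver sketch (not part of this commit) showing one way the functional model above could be exercised against a Hugging Face checkpoint. It assumes a gated-activation checkpoint such as google/flan-t5-small (the feed-forward reads wi_0/wi_1), a hypothetical module_to_parameters helper that repackages the module tree into attribute-style namespaces, and that every nn.Linear weight is pre-transposed so the `x @ weight` convention used in this file applies the projection. Since the reference above does not add T5's relative position bias, this is a wiring and shape check rather than a numerical parity check.

from types import SimpleNamespace

import torch
import transformers


def module_to_parameters(module):
    # Recursively convert an nn.Module tree into SimpleNamespace objects holding plain tensors.
    if isinstance(module, torch.nn.Linear):
        # transpose so that `hidden_states @ weight` applies the projection
        return SimpleNamespace(weight=module.weight.T.contiguous())
    if isinstance(module, torch.nn.Embedding):
        return SimpleNamespace(weight=module.weight)
    if isinstance(module, torch.nn.ModuleList):
        return [module_to_parameters(child) for child in module]
    attributes = {name: module_to_parameters(child) for name, child in module.named_children()}
    for name, parameter in module.named_parameters(recurse=False):
        attributes[name] = parameter  # e.g. the weight of a T5 layer norm
    return SimpleNamespace(**attributes)


model_name = "google/flan-t5-small"  # assumed checkpoint with a gated feed-forward
hf_model = transformers.T5ForConditionalGeneration.from_pretrained(model_name).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

with torch.no_grad():
    parameters = module_to_parameters(hf_model)
    input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
    decoder_input_ids = torch.full((1, 1), hf_model.config.decoder_start_token_id)
    logits = t5_for_conditional_generation(
        input_ids,
        decoder_input_ids,
        parameters,
        num_heads=hf_model.config.num_heads,
    )
    print(logits.shape)  # (1, 1, vocab_size)
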
2 changes: 2 additions & 0 deletions models/utility_functions.py
@@ -50,6 +50,8 @@ def float_to_bits(x):


def torch_random(shape, low, high, dtype):
if dtype == torch.int64:
return torch.randint(low, high, shape, dtype=dtype)
return torch.zeros(shape, dtype=dtype).uniform_(low, high)
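
A hypothetical usage note (not from the commit): with this change, integer dtypes route through torch.randint, which is convenient for generating random token ids, while floating-point dtypes keep the uniform fill. The shapes and ranges below are illustrative only.

input_ids = torch_random((1, 128), 0, 32128, dtype=torch.int64)  # random token ids
hidden_states = torch_random((1, 128, 512), -0.1, 0.1, dtype=torch.bfloat16)  # random activations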

