#4003: implemented functional t5 model
Showing 4 changed files with 710 additions and 1 deletion.
models/experimental/functional_t5/reference/torch_functional_t5.py (375 additions, 0 deletions)
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch


def t5_layer_norm(hidden_states, *, weight, eps=1e-6):
    # T5 uses a layer_norm which only scales and doesn't shift, also known as Root Mean Square
    # Layer Normalization (https://arxiv.org/abs/1910.07467), thus the variance is calculated
    # without the mean and there is no bias. Additionally, we want to make sure that the
    # accumulation for half-precision inputs is done in fp32.

    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)

    # convert into half-precision if necessary
    if weight.dtype in [torch.float16, torch.bfloat16]:
        hidden_states = hidden_states.to(weight.dtype)

    return weight * hidden_states


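# Example (illustrative, not referenced elsewhere): with a weight of ones, t5_layer_norm should
# match a hand-rolled RMS normalization. The shapes and tolerance below are arbitrary.
def _t5_layer_norm_example():
    hidden_states = torch.randn(1, 4, 8)
    weight = torch.ones(8)
    output = t5_layer_norm(hidden_states, weight=weight, eps=1e-6)
    expected = hidden_states / torch.sqrt(hidden_states.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(output, expected, atol=1e-5)

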
def t5_dense_gated_act_dense(hidden_states, parameters):
    hidden_gelu = torch.nn.functional.gelu(hidden_states @ parameters.wi_0.weight)
    hidden_linear = hidden_states @ parameters.wi_1.weight
    hidden_states = hidden_gelu * hidden_linear

    hidden_states = hidden_states @ parameters.wo.weight
    return hidden_states


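# Example (illustrative): wi_0 and wi_1 project d_model -> d_ff and wo projects back, so the
# gated feed-forward preserves the hidden size. SimpleNamespace only stands in for whatever
# parameter container the caller uses; the sizes are arbitrary.
def _t5_dense_gated_act_dense_example():
    from types import SimpleNamespace

    d_model, d_ff = 8, 16
    parameters = SimpleNamespace(
        wi_0=SimpleNamespace(weight=torch.randn(d_model, d_ff)),
        wi_1=SimpleNamespace(weight=torch.randn(d_model, d_ff)),
        wo=SimpleNamespace(weight=torch.randn(d_ff, d_model)),
    )
    output = t5_dense_gated_act_dense(torch.randn(1, 3, d_model), parameters)
    assert output.shape == (1, 3, d_model)

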
def t5_layer_ff(hidden_states, parameters):
    forwarded_states = t5_layer_norm(hidden_states, weight=parameters.layer_norm.weight, eps=1e-6)
    forwarded_states = t5_dense_gated_act_dense(forwarded_states, parameters.DenseReluDense)
    hidden_states = hidden_states + forwarded_states
    return hidden_states


def t5_attention(
    hidden_states,
    key_value_states=None,
    mask=None,
    layer_head_mask=None,
    *,
    parameters,
    num_heads,
):
    """
    Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
    """
    # Input is (batch_size, seq_length, dim)
    # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
    # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
    batch_size, seq_length, _ = hidden_states.shape

    def shape(states, num_heads, head_size):
        """projection"""
        return states.view(batch_size, -1, num_heads, head_size).transpose(1, 2)

    def unshape(states, hidden_size):
        """reshape"""
        return states.transpose(1, 2).contiguous().view(batch_size, -1, hidden_size)

    def project(hidden_states, weight):
        """projects hidden states correctly to key/query states"""
        hidden_size = weight.shape[-1]
        head_size = hidden_size // num_heads
        # self-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(hidden_states @ weight, num_heads, head_size)
        return hidden_states

    # get query states
    hidden_size = parameters.q.weight.shape[-1]
    query_states = project(hidden_states, parameters.q.weight)  # (batch_size, n_heads, seq_length, dim_per_head)

    # get key/value states
    key_states = project(
        hidden_states if key_value_states is None else key_value_states,
        parameters.k.weight,
    )
    value_states = project(
        hidden_states if key_value_states is None else key_value_states,
        parameters.v.weight,
    )

    # compute scores
    scores = torch.matmul(
        query_states, key_states.transpose(3, 2)
    )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
    if mask is not None:
        scores += mask

    attn_weights = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(
        scores
    )  # (batch_size, n_heads, seq_length, key_length)

    # Mask heads if we want to
    if layer_head_mask is not None:
        attn_weights = attn_weights * layer_head_mask

    attn_output = unshape(torch.matmul(attn_weights, value_states), hidden_size)  # (batch_size, seq_length, dim)
    attn_output = attn_output @ parameters.o.weight

    return attn_output


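# Example (illustrative): a quick shape check that also documents the `parameters` layout this
# module assumes: attribute access to q/k/v/o, each holding a `weight` tensor laid out as
# (in_features, out_features) so that `x @ weight` applies the projection. SimpleNamespace is
# only a stand-in for the caller's parameter container; the sizes are arbitrary.
def _t5_attention_example():
    from types import SimpleNamespace

    batch_size, seq_length, hidden_size, num_heads = 1, 5, 8, 2

    def projection():
        return SimpleNamespace(weight=torch.randn(hidden_size, hidden_size))

    parameters = SimpleNamespace(q=projection(), k=projection(), v=projection(), o=projection())
    hidden_states = torch.randn(batch_size, seq_length, hidden_size)
    output = t5_attention(hidden_states, parameters=parameters, num_heads=num_heads)
    assert output.shape == (batch_size, seq_length, hidden_size)

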
def t5_layer_self_attention(
    hidden_states,
    attention_mask=None,
    *,
    parameters,
    num_heads,
):
    normed_hidden_states = t5_layer_norm(hidden_states, weight=parameters.layer_norm.weight, eps=1e-06)
    attention_output = t5_attention(
        normed_hidden_states,
        mask=attention_mask,
        parameters=parameters.SelfAttention,
        num_heads=num_heads,
    )
    hidden_states = hidden_states + attention_output
    return hidden_states


def t5_layer_cross_attention(hidden_states, key_value_states, attention_mask=None, *, parameters, num_heads):
    normed_hidden_states = t5_layer_norm(hidden_states, weight=parameters.layer_norm.weight, eps=1e-06)
    attention_output = t5_attention(
        normed_hidden_states,
        key_value_states,
        mask=attention_mask,
        parameters=parameters.EncDecAttention,
        num_heads=num_heads,
    )
    layer_output = hidden_states + attention_output
    return layer_output


def t5_block(
    hidden_states,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    *,
    parameters,
    num_heads,
):
    hidden_states = t5_layer_self_attention(
        hidden_states,
        attention_mask=attention_mask,
        parameters=parameters.layer[0],
        num_heads=num_heads,
    )

    # clamp inf values to enable fp16 training
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            torch.finfo(hidden_states.dtype).max - 1000,
            torch.finfo(hidden_states.dtype).max,
        )
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    do_cross_attention = encoder_hidden_states is not None
    if do_cross_attention:
        hidden_states = t5_layer_cross_attention(
            hidden_states,
            key_value_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            parameters=parameters.layer[1],
            num_heads=num_heads,
        )

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    # Apply Feed Forward layer
    hidden_states = t5_layer_ff(hidden_states, parameters.layer[-1])

    # clamp inf values to enable fp16 training
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            torch.finfo(hidden_states.dtype).max - 1000,
            torch.finfo(hidden_states.dtype).max,
        )
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    return hidden_states


def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
    device = attention_mask.device
    batch_size, seq_length = input_shape
    seq_ids = torch.arange(seq_length, device=device)
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    # in case past_key_values are used we need to add a prefix ones mask to the causal mask
    # causal and attention masks must have same type with pytorch version < 1.3
    causal_mask = causal_mask.to(attention_mask.dtype)

    if causal_mask.shape[1] < attention_mask.shape[1]:
        prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
        causal_mask = torch.cat(
            [
                torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                causal_mask,
            ],
            axis=-1,
        )

    extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    return extended_attention_mask


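# Example (illustrative): for a batch of one, a sequence length of 3 and a padding mask [1, 1, 0],
# the decoder mask combines a lower-triangular causal pattern with the padding pattern:
#     [[1, 0, 0],
#      [1, 1, 0],
#      [1, 1, 0]]
# get_extended_attention_mask below then turns this multiplicative mask into an additive one.
def _decoder_mask_example():
    attention_mask = torch.tensor([[1, 1, 0]])
    mask = create_extended_attention_mask_for_decoder((1, 3), attention_mask)
    assert mask.shape == (1, 1, 3, 3)
    assert mask[0, 0].tolist() == [[1, 0, 0], [1, 1, 0], [1, 1, 0]]

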
def get_extended_attention_mask(
    attention_mask: torch.Tensor, input_shape: Tuple[int], dtype: torch.float = torch.bfloat16, *, is_decoder
) -> torch.Tensor:
    """
    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
    Arguments:
        attention_mask (`torch.Tensor`):
            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
        input_shape (`Tuple[int]`):
            The shape of the input to the model.
    Returns:
        `torch.Tensor`: The extended attention mask, converted to `dtype`.
    """

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    if attention_mask.dim() == 3:
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        # Provided a padding mask of dimensions [batch_size, seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if is_decoder:
            extended_attention_mask = create_extended_attention_mask_for_decoder(input_shape, attention_mask)
        else:
            extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(
            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
        )

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and the dtype's smallest value for masked positions.
    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
    return extended_attention_mask


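# Example (illustrative): a 2D padding mask becomes a broadcastable additive mask that is 0.0
# where attention is allowed and the dtype's minimum where it is masked out. float32 is used
# here only to make the expected values easy to state.
def _extended_attention_mask_example():
    attention_mask = torch.tensor([[1, 1, 0]])
    extended = get_extended_attention_mask(attention_mask, (1, 3), dtype=torch.float32, is_decoder=False)
    assert extended.shape == (1, 1, 1, 3)
    assert extended[0, 0, 0, 0] == 0.0
    assert extended[0, 0, 0, 2] == torch.finfo(torch.float32).min

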
def invert_attention_mask(encoder_attention_mask: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
    """
    Invert an attention mask (e.g., switches 0. and 1.).
    Args:
        encoder_attention_mask (`torch.Tensor`): An attention mask.
    Returns:
        `torch.Tensor`: The inverted attention mask.
    """
    if encoder_attention_mask.dim() == 3:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
    if encoder_attention_mask.dim() == 2:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
    # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
    # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
    # /transformer/transformer_layers.py#L270
    # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
    # encoder_extended_attention_mask.transpose(-1, -2))
    encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
    encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(dtype).min

    return encoder_extended_attention_mask


def t5_stack(
    input_ids,
    shared_embedding_weight,
    encoder_hidden_states=None,
    *,
    parameters,
    num_heads,
):
    input_shape = input_ids.size()
    input_ids = input_ids.view(-1, input_shape[-1])

    hidden_states = torch.nn.functional.embedding(input_ids, shared_embedding_weight)

    batch_size, seq_length = input_shape

    # required mask seq length can be calculated via length of past
    mask_seq_length = seq_length

    attention_mask = torch.ones(batch_size, mask_seq_length, device=hidden_states.device)
    if encoder_hidden_states is not None:
        encoder_seq_length = encoder_hidden_states.shape[1]
        encoder_attention_mask = torch.ones(
            batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long
        )

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask = get_extended_attention_mask(
        attention_mask, input_shape, is_decoder=encoder_hidden_states is not None
    )

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=hidden_states.device)
        encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
    else:
        encoder_extended_attention_mask = None

    for block_parameters in parameters.block:
        hidden_states = t5_block(
            hidden_states,
            attention_mask=extended_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            parameters=block_parameters,
            num_heads=num_heads,
        )

    hidden_states = t5_layer_norm(hidden_states, weight=parameters.final_layer_norm.weight, eps=1e-06)

    return hidden_states


def t5_for_conditional_generation(
    input_ids: Optional[torch.LongTensor],
    decoder_input_ids: Optional[torch.LongTensor],
    parameters,
    *,
    num_heads,
) -> torch.FloatTensor:
    # Encode
    hidden_states = t5_stack(
        input_ids=input_ids,
        shared_embedding_weight=parameters.shared.weight,
        parameters=parameters.encoder,
        num_heads=num_heads,
    )

    # Decode
    sequence_output = t5_stack(
        input_ids=decoder_input_ids,
        encoder_hidden_states=hidden_states,
        shared_embedding_weight=parameters.shared.weight,
        parameters=parameters.decoder,
        num_heads=num_heads,
    )

    lm_logits = sequence_output @ parameters.lm_head.weight

    return lm_logits
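

# Example (illustrative end-to-end sketch): in practice, `parameters` would presumably come from
# a converted Hugging Face T5 checkpoint with linear weights stored as (in_features, out_features)
# so that `x @ weight` applies each projection. The SimpleNamespace tree and random weights below
# are purely hypothetical stand-ins that only mirror the attribute names accessed above.
def _t5_for_conditional_generation_example():
    from types import SimpleNamespace

    d_model, d_ff, vocab_size, num_heads = 8, 16, 32, 2

    def linear(in_features, out_features):
        return SimpleNamespace(weight=torch.randn(in_features, out_features) * 0.02)

    def attention():
        return SimpleNamespace(
            q=linear(d_model, d_model),
            k=linear(d_model, d_model),
            v=linear(d_model, d_model),
            o=linear(d_model, d_model),
        )

    def layer_norm():
        return SimpleNamespace(weight=torch.ones(d_model))

    def feed_forward():
        return SimpleNamespace(
            layer_norm=layer_norm(),
            DenseReluDense=SimpleNamespace(
                wi_0=linear(d_model, d_ff),
                wi_1=linear(d_model, d_ff),
                wo=linear(d_ff, d_model),
            ),
        )

    def stack(is_decoder):
        layers = [SimpleNamespace(layer_norm=layer_norm(), SelfAttention=attention())]
        if is_decoder:
            layers.append(SimpleNamespace(layer_norm=layer_norm(), EncDecAttention=attention()))
        layers.append(feed_forward())
        return SimpleNamespace(block=[SimpleNamespace(layer=layers)], final_layer_norm=layer_norm())

    parameters = SimpleNamespace(
        shared=SimpleNamespace(weight=torch.randn(vocab_size, d_model) * 0.02),
        encoder=stack(is_decoder=False),
        decoder=stack(is_decoder=True),
        lm_head=linear(d_model, vocab_size),
    )

    input_ids = torch.randint(0, vocab_size, (1, 6))
    decoder_input_ids = torch.randint(0, vocab_size, (1, 4))
    logits = t5_for_conditional_generation(input_ids, decoder_input_ids, parameters, num_heads=num_heads)
    assert logits.shape == (1, 4, vocab_size)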