#4003: debugging whisper
arakhmati committed Jan 19, 2024
1 parent a974de4 commit d6c2d42
Showing 8 changed files with 151 additions and 309 deletions.
@@ -56,15 +56,12 @@ def split_query_key_value_and_split_heads(config, fused_qkv):
 
     query_states = torch.reshape(query_states, shape=(batch_size, seq_length, num_heads, head_size))
     query_states = torch.permute(query_states, (0, 2, 1, 3))
-    query_states = query_states.contiguous()
 
     key_states = torch.reshape(key_states, shape=(batch_size, seq_length, num_heads, head_size))
     key_states = torch.permute(key_states, (0, 2, 1, 3))
-    key_states = key_states.contiguous()
 
     value_states = torch.reshape(value_states, shape=(batch_size, seq_length, num_heads, head_size))
     value_states = torch.permute(value_states, (0, 2, 1, 3))
-    value_states = value_states.contiguous()
 
     return query_states, key_states, value_states
 
@@ -84,7 +81,6 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, *
         query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias
         query_states = torch.reshape(query_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size))
         query_states = torch.permute(query_states, (0, 2, 1, 3))
-        query_states = query_states.contiguous()
         key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters)
     else:
         query_states, key_states, value_states = calculate_query_key_values(
@@ -93,11 +89,11 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, *
     query_states *= scaling
 
     proj_shape = (bsz * config.encoder_attention_heads, -1, head_size)
-    query_states = torch.reshape(query_states, shape=proj_shape).contiguous()
-    key_states = torch.reshape(key_states, shape=proj_shape).contiguous()
-    value_states = torch.reshape(value_states, shape=proj_shape).contiguous()
+    query_states = torch.reshape(query_states, shape=proj_shape)
+    key_states = torch.reshape(key_states, shape=proj_shape)
+    value_states = torch.reshape(value_states, shape=proj_shape)
 
-    attn_weights = query_states @ torch.permute(key_states, (0, 2, 1)).contiguous()
+    attn_weights = query_states @ torch.permute(key_states, (0, 2, 1))
     if attention_mask is not None:
         bsz, _, tgt_len, src_len = attention_mask.size()
         attn_weights = (
@@ -118,6 +114,7 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, *
 
 def encoder_layer(config, hidden_states, *, parameters):
     residual = hidden_states
+
     hidden_states = F.layer_norm(
         hidden_states,
         (config.d_model,),
@@ -168,44 +165,6 @@ def encoder(config, inputs_embeds, *, parameters):
     return hidden_states
 
 
-def encoder_original(config, input_features, *, parameters):
-    inputs_embeds = gelu(
-        conv(
-            input_features,
-            weight=parameters.conv1.weight,
-            bias=parameters.conv1.bias,
-            padding=1,
-        )
-    )
-    inputs_embeds = gelu(
-        conv(
-            inputs_embeds,
-            weight=parameters.conv2.weight,
-            bias=parameters.conv2.bias,
-            stride=2,
-            padding=1,
-        )
-    )
-    inputs_embeds = inputs_embeds.permute(0, 2, 1)
-    hidden_states = inputs_embeds + parameters.embed_positions.weight
-    hidden_states = dropout(hidden_states, p=0, training=False)
-
-    for encoder_layer_parameter in parameters.layers:
-        hidden_states = encoder_layer(
-            config,
-            hidden_states,
-            parameters=encoder_layer_parameter,
-        )
-
-    hidden_states = F.layer_norm(
-        hidden_states,
-        (config.d_model,),
-        parameters.layer_norm.weight,
-        parameters.layer_norm.bias,
-    )
-    return hidden_states
-
-
 def make_causal_mask(input_ids_shape, dtype):
     bsz, tgt_len = input_ids_shape
     mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
@@ -326,35 +285,6 @@ def decoder(config, hidden_states, decoder_attention_mask, encoder_hidden_states
     return hidden_states
 
 
-def decoder_original(config, input_ids, attention_mask, encoder_hidden_states, parameters):
-    input_shape = input_ids.size()
-    input_ids = torch.reshape(input_ids, (-1, input_shape[-1]))
-    inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight)
-    attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds)
-    positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]]
-
-    hidden_states = inputs_embeds + positions
-    hidden_states = dropout(hidden_states, p=0, training=False)
-
-    for decoder_layer_parameter in parameters.layers:
-        hidden_states = decoder_layer(
-            config,
-            hidden_states,
-            attention_mask,
-            encoder_hidden_states,
-            parameters=decoder_layer_parameter,
-        )
-
-    hidden_states = F.layer_norm(
-        hidden_states,
-        (config.d_model,),
-        parameters.layer_norm.weight,
-        parameters.layer_norm.bias,
-    )
-
-    return hidden_states
-
-
 def preprocess_encoder_inputs(input_features, parameters):
     inputs_embeds = gelu(
         conv(
@@ -383,8 +313,8 @@ def preprocess_decoder_inputs(input_ids, attention_mask, *, parameters):
     input_ids = torch.reshape(input_ids, (-1, input_shape[-1]))
     inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight)
     attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds)
-    positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]]
 
+    positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]]
     decoder_hidden_states = inputs_embeds + positions
 
     return decoder_hidden_states, attention_mask
@@ -404,17 +334,6 @@ def preprocess_inputs(
     return input_embeds, decoder_hidden_states, attention_mask
 
 
-def whisper_original(config, input_features, decoder_input_ids, attention_mask, *, parameters):
-    encoder_hidden_states = encoder_original(config, input_features, parameters=parameters.encoder)
-    return decoder_original(
-        config,
-        input_ids=decoder_input_ids,
-        attention_mask=attention_mask,
-        encoder_hidden_states=encoder_hidden_states,
-        parameters=parameters.decoder,
-    )
-
-
 def whisper(config, input_embeds, decoder_hidden_states, decoder_attention_mask, *, parameters):
     encoder_hidden_states = encoder(config, input_embeds, parameters=parameters.encoder)
     return decoder(
@@ -433,10 +352,10 @@ def custom_preprocessor(torch_model, name):
         parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}}
         preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0)
         preprocessed_bias = torch.cat([torch.zeros_like(torch_model.v_proj.bias), torch_model.v_proj.bias], dim=0)
-        parameters["key_value"]["weight"] = preprocessed_weight.type(torch.bfloat16).T
-        parameters["key_value"]["bias"] = preprocessed_bias.type(dtype=torch.bfloat16)
-        parameters["q_proj"]["weight"] = torch_model.q_proj.weight.type(torch.bfloat16).T
-        parameters["q_proj"]["bias"] = torch_model.q_proj.bias.type(dtype=torch.bfloat16)
+        parameters["key_value"]["weight"] = preprocessed_weight.T.contiguous()
+        parameters["key_value"]["bias"] = preprocessed_bias
+        parameters["q_proj"]["weight"] = torch_model.q_proj.weight.T.contiguous()
+        parameters["q_proj"]["bias"] = torch_model.q_proj.bias
     else:
         parameters = {"query_key_value": {}, "out_proj": {}}
         preprocessed_weight = torch.cat(
@@ -445,11 +364,11 @@ def custom_preprocessor(torch_model, name):
         preprocessed_bias = torch.cat(
             [torch_model.q_proj.bias, torch.zeros_like(torch_model.v_proj.bias), torch_model.v_proj.bias], dim=0
         )
-        parameters["query_key_value"]["weight"] = preprocessed_weight.type(torch.bfloat16).T
-        parameters["query_key_value"]["bias"] = preprocessed_bias.type(dtype=torch.bfloat16)
+        parameters["query_key_value"]["weight"] = preprocessed_weight.T.contiguous()
+        parameters["query_key_value"]["bias"] = preprocessed_bias
 
-    parameters["out_proj"]["weight"] = torch_model.out_proj.weight.type(torch.bfloat16).T
-    parameters["out_proj"]["bias"] = torch_model.out_proj.bias.type(dtype=torch.bfloat16)
+    parameters["out_proj"]["weight"] = torch_model.out_proj.weight.T.contiguous()
+    parameters["out_proj"]["bias"] = torch_model.out_proj.bias
     return parameters
 
 
@@ -466,7 +385,7 @@ def custom_preprocessor(torch_model, name):
     feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
     ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt")
-    input_features = inputs.input_features.type(torch.bfloat16)
+    input_features = inputs.input_features
     decoder_input_ids = torch.ones(1, 1).type(torch.int32) * model.config.decoder_start_token_id
 
     model_graph = draw_graph(
@@ -155,9 +155,12 @@ def encoder(config, inputs_embeds, *, parameters):
     return hidden_states
 
 
+BIG_NEGATIVE_NUMBER = -1000
+
+
 def make_causal_mask(input_ids_shape, dtype):
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(BIG_NEGATIVE_NUMBER))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
@@ -176,7 +179,7 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] =
 
     inverted_mask = 1.0 - expanded_mask
 
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), BIG_NEGATIVE_NUMBER)
 
 
 def decoder_layer(config, hidden_states, attention_mask, encoder_hidden_states, *, parameters):
@@ -301,8 +304,7 @@ def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1):
         )
     )
     input_embeds = input_embeds.permute(0, 2, 1)
-    input_embeds = ttnn.from_torch(input_embeds, dtype=ttnn.bfloat16)
-    input_embeds = ttnn.to_device(input_embeds, device)
+    input_embeds = ttnn.from_torch(input_embeds, dtype=ttnn.bfloat16, device=device)
 
     return input_embeds
 
@@ -318,10 +320,8 @@ def preprocess_decoder_inputs(config, input_ids, attention_mask, *, parameters,
     positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]]
     decoder_hidden_states = inputs_embeds + positions
 
-    decoder_hidden_states = ttnn.from_torch(decoder_hidden_states, dtype=ttnn.bfloat16)
-    decoder_hidden_states = ttnn.to_device(decoder_hidden_states, device)
-    attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16)
-    attention_mask = ttnn.to_device(attention_mask, device)
+    decoder_hidden_states = ttnn.from_torch(decoder_hidden_states, dtype=ttnn.bfloat16, device=device)
+    attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, device=device)
 
     return decoder_hidden_states, attention_mask
