From d6c2d42900f052b53066e885b16ae91b1e133d5d Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Wed, 17 Jan 2024 09:21:09 +0000 Subject: [PATCH] #4003: debugging whisper --- .../reference/torch_functional_whisper.py | 111 +++--------------- .../tt/ttnn_functional_whisper.py | 16 +-- .../t5/test_ttnn_functional_t5.py | 83 +++++-------- .../t5/test_ttnn_optimized_functional_t5.py | 101 ++++++---------- .../whisper/test_performance.py | 5 +- .../whisper/test_torch_functional_whisper.py | 58 ++++----- .../whisper/test_ttnn_functional_whisper.py | 47 ++++---- .../test_ttnn_optimized_functional_whisper.py | 39 +++--- 8 files changed, 151 insertions(+), 309 deletions(-) diff --git a/models/experimental/functional_whisper/reference/torch_functional_whisper.py b/models/experimental/functional_whisper/reference/torch_functional_whisper.py index d5ffefd347c..52f8bd70121 100644 --- a/models/experimental/functional_whisper/reference/torch_functional_whisper.py +++ b/models/experimental/functional_whisper/reference/torch_functional_whisper.py @@ -56,15 +56,12 @@ def split_query_key_value_and_split_heads(config, fused_qkv): query_states = torch.reshape(query_states, shape=(batch_size, seq_length, num_heads, head_size)) query_states = torch.permute(query_states, (0, 2, 1, 3)) - query_states = query_states.contiguous() key_states = torch.reshape(key_states, shape=(batch_size, seq_length, num_heads, head_size)) key_states = torch.permute(key_states, (0, 2, 1, 3)) - key_states = key_states.contiguous() value_states = torch.reshape(value_states, shape=(batch_size, seq_length, num_heads, head_size)) value_states = torch.permute(value_states, (0, 2, 1, 3)) - value_states = value_states.contiguous() return query_states, key_states, value_states @@ -84,7 +81,6 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, * query_states = hidden_states @ parameters.q_proj.weight + parameters.q_proj.bias query_states = torch.reshape(query_states, shape=(bsz, tgt_len, config.encoder_attention_heads, head_size)) query_states = torch.permute(query_states, (0, 2, 1, 3)) - query_states = query_states.contiguous() key_states, value_states = calculate_key_values(config, key_value_states, parameters=parameters) else: query_states, key_states, value_states = calculate_query_key_values( @@ -93,11 +89,11 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, * query_states *= scaling proj_shape = (bsz * config.encoder_attention_heads, -1, head_size) - query_states = torch.reshape(query_states, shape=proj_shape).contiguous() - key_states = torch.reshape(key_states, shape=proj_shape).contiguous() - value_states = torch.reshape(value_states, shape=proj_shape).contiguous() + query_states = torch.reshape(query_states, shape=proj_shape) + key_states = torch.reshape(key_states, shape=proj_shape) + value_states = torch.reshape(value_states, shape=proj_shape) - attn_weights = query_states @ torch.permute(key_states, (0, 2, 1)).contiguous() + attn_weights = query_states @ torch.permute(key_states, (0, 2, 1)) if attention_mask is not None: bsz, _, tgt_len, src_len = attention_mask.size() attn_weights = ( @@ -118,6 +114,7 @@ def whisper_attention(config, hidden_states, attention_mask, key_value_states, * def encoder_layer(config, hidden_states, *, parameters): residual = hidden_states + hidden_states = F.layer_norm( hidden_states, (config.d_model,), @@ -168,44 +165,6 @@ def encoder(config, inputs_embeds, *, parameters): return hidden_states -def encoder_original(config, input_features, *, parameters): - inputs_embeds = gelu( - conv( - input_features, - weight=parameters.conv1.weight, - bias=parameters.conv1.bias, - padding=1, - ) - ) - inputs_embeds = gelu( - conv( - inputs_embeds, - weight=parameters.conv2.weight, - bias=parameters.conv2.bias, - stride=2, - padding=1, - ) - ) - inputs_embeds = inputs_embeds.permute(0, 2, 1) - hidden_states = inputs_embeds + parameters.embed_positions.weight - hidden_states = dropout(hidden_states, p=0, training=False) - - for encoder_layer_parameter in parameters.layers: - hidden_states = encoder_layer( - config, - hidden_states, - parameters=encoder_layer_parameter, - ) - - hidden_states = F.layer_norm( - hidden_states, - (config.d_model,), - parameters.layer_norm.weight, - parameters.layer_norm.bias, - ) - return hidden_states - - def make_causal_mask(input_ids_shape, dtype): bsz, tgt_len = input_ids_shape mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) @@ -326,35 +285,6 @@ def decoder(config, hidden_states, decoder_attention_mask, encoder_hidden_states return hidden_states -def decoder_original(config, input_ids, attention_mask, encoder_hidden_states, parameters): - input_shape = input_ids.size() - input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) - inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) - attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) - positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] - - hidden_states = inputs_embeds + positions - hidden_states = dropout(hidden_states, p=0, training=False) - - for decoder_layer_parameter in parameters.layers: - hidden_states = decoder_layer( - config, - hidden_states, - attention_mask, - encoder_hidden_states, - parameters=decoder_layer_parameter, - ) - - hidden_states = F.layer_norm( - hidden_states, - (config.d_model,), - parameters.layer_norm.weight, - parameters.layer_norm.bias, - ) - - return hidden_states - - def preprocess_encoder_inputs(input_features, parameters): inputs_embeds = gelu( conv( @@ -383,8 +313,8 @@ def preprocess_decoder_inputs(input_ids, attention_mask, *, parameters): input_ids = torch.reshape(input_ids, (-1, input_shape[-1])) inputs_embeds = F.embedding(input_ids, parameters.embed_tokens.weight) attention_mask = prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds) - positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] + positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] decoder_hidden_states = inputs_embeds + positions return decoder_hidden_states, attention_mask @@ -404,17 +334,6 @@ def preprocess_inputs( return input_embeds, decoder_hidden_states, attention_mask -def whisper_original(config, input_features, decoder_input_ids, attention_mask, *, parameters): - encoder_hidden_states = encoder_original(config, input_features, parameters=parameters.encoder) - return decoder_original( - config, - input_ids=decoder_input_ids, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - parameters=parameters.decoder, - ) - - def whisper(config, input_embeds, decoder_hidden_states, decoder_attention_mask, *, parameters): encoder_hidden_states = encoder(config, input_embeds, parameters=parameters.encoder) return decoder( @@ -433,10 +352,10 @@ def custom_preprocessor(torch_model, name): parameters = {"key_value": {}, "q_proj": {}, "out_proj": {}} preprocessed_weight = torch.cat([torch_model.k_proj.weight, torch_model.v_proj.weight], dim=0) preprocessed_bias = torch.cat([torch.zeros_like(torch_model.v_proj.bias), torch_model.v_proj.bias], dim=0) - parameters["key_value"]["weight"] = preprocessed_weight.type(torch.bfloat16).T - parameters["key_value"]["bias"] = preprocessed_bias.type(dtype=torch.bfloat16) - parameters["q_proj"]["weight"] = torch_model.q_proj.weight.type(torch.bfloat16).T - parameters["q_proj"]["bias"] = torch_model.q_proj.bias.type(dtype=torch.bfloat16) + parameters["key_value"]["weight"] = preprocessed_weight.T.contiguous() + parameters["key_value"]["bias"] = preprocessed_bias + parameters["q_proj"]["weight"] = torch_model.q_proj.weight.T.contiguous() + parameters["q_proj"]["bias"] = torch_model.q_proj.bias else: parameters = {"query_key_value": {}, "out_proj": {}} preprocessed_weight = torch.cat( @@ -445,11 +364,11 @@ def custom_preprocessor(torch_model, name): preprocessed_bias = torch.cat( [torch_model.q_proj.bias, torch.zeros_like(torch_model.v_proj.bias), torch_model.v_proj.bias], dim=0 ) - parameters["query_key_value"]["weight"] = preprocessed_weight.type(torch.bfloat16).T - parameters["query_key_value"]["bias"] = preprocessed_bias.type(dtype=torch.bfloat16) + parameters["query_key_value"]["weight"] = preprocessed_weight.T.contiguous() + parameters["query_key_value"]["bias"] = preprocessed_bias - parameters["out_proj"]["weight"] = torch_model.out_proj.weight.type(torch.bfloat16).T - parameters["out_proj"]["bias"] = torch_model.out_proj.bias.type(dtype=torch.bfloat16) + parameters["out_proj"]["weight"] = torch_model.out_proj.weight.T.contiguous() + parameters["out_proj"]["bias"] = torch_model.out_proj.bias return parameters @@ -466,7 +385,7 @@ def custom_preprocessor(torch_model, name): feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - input_features = inputs.input_features.type(torch.bfloat16) + input_features = inputs.input_features decoder_input_ids = torch.ones(1, 1).type(torch.int32) * model.config.decoder_start_token_id model_graph = draw_graph( diff --git a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py b/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py index 9ecadcf9081..8359bd1d5f6 100644 --- a/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py +++ b/models/experimental/functional_whisper/tt/ttnn_functional_whisper.py @@ -155,9 +155,12 @@ def encoder(config, inputs_embeds, *, parameters): return hidden_states +BIG_NEGATIVE_NUMBER = -1000 + + def make_causal_mask(input_ids_shape, dtype): bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask = torch.full((tgt_len, tgt_len), torch.tensor(BIG_NEGATIVE_NUMBER)) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -176,7 +179,7 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), BIG_NEGATIVE_NUMBER) def decoder_layer(config, hidden_states, attention_mask, encoder_hidden_states, *, parameters): @@ -301,8 +304,7 @@ def conv(input, weight, bias, stride=1, padding=1, dilation=1, groups=1): ) ) input_embeds = input_embeds.permute(0, 2, 1) - input_embeds = ttnn.from_torch(input_embeds, dtype=ttnn.bfloat16) - input_embeds = ttnn.to_device(input_embeds, device) + input_embeds = ttnn.from_torch(input_embeds, dtype=ttnn.bfloat16, device=device) return input_embeds @@ -318,10 +320,8 @@ def preprocess_decoder_inputs(config, input_ids, attention_mask, *, parameters, positions = parameters.embed_positions.weight[0 : input_ids.shape[-1]] decoder_hidden_states = inputs_embeds + positions - decoder_hidden_states = ttnn.from_torch(decoder_hidden_states, dtype=ttnn.bfloat16) - decoder_hidden_states = ttnn.to_device(decoder_hidden_states, device) - attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16) - attention_mask = ttnn.to_device(attention_mask, device) + decoder_hidden_states = ttnn.from_torch(decoder_hidden_states, dtype=ttnn.bfloat16, device=device) + attention_mask = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, device=device) return decoder_hidden_states, attention_mask diff --git a/tests/ttnn/integration_tests/t5/test_ttnn_functional_t5.py b/tests/ttnn/integration_tests/t5/test_ttnn_functional_t5.py index 798ed4556a2..eb769373936 100644 --- a/tests/ttnn/integration_tests/t5/test_ttnn_functional_t5.py +++ b/tests/ttnn/integration_tests/t5/test_ttnn_functional_t5.py @@ -23,17 +23,15 @@ def test_t5_layer_norm(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5LayerNorm(config.d_model).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_layer_norm(config, hidden_states, weight=parameters.weight) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -51,17 +49,15 @@ def test_t5_dense_act_dense(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5DenseActDense(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_dense_act_dense(config, hidden_states, parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -79,17 +75,15 @@ def test_t5_dense_gated_act_dense(device, model_name, batch_size, sequence_size) config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5DenseGatedActDense(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_dense_gated_act_dense(config, hidden_states, parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -107,17 +101,15 @@ def test_t5_layer_ff(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5LayerFF(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_layer_ff(config, hidden_states, parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -135,17 +127,15 @@ def test_t5_attention(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5Attention(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output, *_ = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_attention(config, hidden_states, parameters=parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -163,17 +153,15 @@ def test_t5_layer_self_attention(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5LayerSelfAttention(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output, *_ = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_layer_self_attention(config, hidden_states, parameters=parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -191,20 +179,17 @@ def test_t5_layer_cross_attention(device, model_name, batch_size, sequence_size) config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5LayerCrossAttention(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) - torch_key_value_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) + torch_key_value_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output, *_ = model(torch_hidden_states, torch_key_value_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) - key_value_states = ttnn.from_torch(torch_key_value_states) - key_value_states = ttnn.to_device(key_value_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) + key_value_states = ttnn.from_torch(torch_key_value_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_layer_cross_attention(config, hidden_states, key_value_states, parameters=parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -222,17 +207,15 @@ def test_t5_block_encoder(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5Block(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output, *_ = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_block(config, hidden_states, parameters=parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -251,11 +234,10 @@ def test_t5_block_decoder(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) config.is_decoder = True model = transformers.models.t5.modeling_t5.T5Block(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_encoder_hidden_states = torch_random( - (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16 + (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32 ) torch_output, *_ = model(torch_hidden_states, encoder_hidden_states=torch_encoder_hidden_states) @@ -263,10 +245,8 @@ def test_t5_block_decoder(device, model_name, batch_size, sequence_size): initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_device(hidden_states, device) - encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states) - encoder_hidden_states = ttnn.to_device(encoder_hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, device=device) + encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_block( config, hidden_states, @@ -291,7 +271,6 @@ def test_t5_stack_encoder(device, model_name, batch_size, sequence_size): config.use_cache = False # Can't use cache when running as encoder shared_embedding = torch.nn.Embedding(config.vocab_size, config.d_model) model = transformers.models.t5.modeling_t5.T5Stack(config, shared_embedding).eval() - model = model.to(torch.bfloat16) torch_input_ids = torch_random((batch_size, sequence_size), 0, config.vocab_size, dtype=torch.int64) torch_output = model(torch_input_ids).last_hidden_state @@ -301,8 +280,7 @@ def test_t5_stack_encoder(device, model_name, batch_size, sequence_size): ) shared_embedding = preprocess_model_parameters(initialize_model=lambda: shared_embedding, device=device) - input_ids = ttnn.from_torch(torch_input_ids) - input_ids = ttnn.to_device(input_ids, device) + input_ids = ttnn.from_torch(torch_input_ids, device=device) output = functional_t5.t5_stack( config, input_ids, @@ -327,11 +305,10 @@ def test_t5_stack_decoder(device, model_name, batch_size, sequence_size): config.is_decoder = True shared_embedding = torch.nn.Embedding(config.vocab_size, config.d_model) model = transformers.models.t5.modeling_t5.T5Stack(config, shared_embedding).eval() - model = model.to(torch.bfloat16) torch_input_ids = torch_random((batch_size, sequence_size), 0, config.vocab_size, dtype=torch.int64) torch_encoder_hidden_states = torch_random( - (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16 + (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32 ) torch_output = model(torch_input_ids, encoder_hidden_states=torch_encoder_hidden_states).last_hidden_state @@ -340,10 +317,8 @@ def test_t5_stack_decoder(device, model_name, batch_size, sequence_size): ) shared_embedding = preprocess_model_parameters(initialize_model=lambda: shared_embedding, device=device) - input_ids = ttnn.from_torch(torch_input_ids) - input_ids = ttnn.to_device(input_ids, device) - encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states) - encoder_hidden_states = ttnn.to_device(encoder_hidden_states, device) + input_ids = ttnn.from_torch(torch_input_ids, device=device) + encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states, dtype=ttnn.bfloat16, device=device) output = functional_t5.t5_stack( config, input_ids, @@ -379,10 +354,8 @@ def test_t5_for_conditional_generation(device, model_name, batch_size, sequence_ device=device, ) - input_ids = ttnn.from_torch(torch_input_ids) - input_ids = ttnn.to_device(input_ids, device) - decoder_input_ids = ttnn.from_torch(torch_decoder_input_ids) - decoder_input_ids = ttnn.to_device(decoder_input_ids, device) + input_ids = ttnn.from_torch(torch_input_ids, device=device) + decoder_input_ids = ttnn.from_torch(torch_decoder_input_ids, device=device) output, *_ = functional_t5.t5_for_conditional_generation( config, input_ids, diff --git a/tests/ttnn/integration_tests/t5/test_ttnn_optimized_functional_t5.py b/tests/ttnn/integration_tests/t5/test_ttnn_optimized_functional_t5.py index e2248ddb4fb..31f0bb910a0 100644 --- a/tests/ttnn/integration_tests/t5/test_ttnn_optimized_functional_t5.py +++ b/tests/ttnn/integration_tests/t5/test_ttnn_optimized_functional_t5.py @@ -23,18 +23,15 @@ def test_t5_layer_norm(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5LayerNorm(config.d_model).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) output = functional_t5.t5_layer_norm(config, hidden_states, weight=parameters.weight) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -52,18 +49,15 @@ def test_t5_dense_act_dense(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5DenseActDense(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) output = functional_t5.t5_dense_act_dense(config, hidden_states, parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -81,18 +75,15 @@ def test_t5_dense_gated_act_dense(device, model_name, batch_size, sequence_size) config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5DenseGatedActDense(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) output = functional_t5.t5_dense_gated_act_dense(config, hidden_states, parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -110,18 +101,15 @@ def test_t5_layer_ff(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5LayerFF(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) output = functional_t5.t5_layer_ff(config, hidden_states, parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -139,9 +127,8 @@ def test_t5_attention(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5Attention(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output, *_ = model(torch_hidden_states) parameters = preprocess_model_parameters( @@ -150,9 +137,7 @@ def test_t5_attention(device, model_name, batch_size, sequence_size): device=device, ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) output = functional_t5.t5_attention(config, hidden_states, parameters=parameters) output = ttnn.from_device(output) output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT) @@ -170,18 +155,15 @@ def test_t5_layer_self_attention(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5LayerSelfAttention(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output, *_ = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) output = functional_t5.t5_layer_self_attention( config, hidden_states, @@ -203,10 +185,9 @@ def test_t5_layer_cross_attention(device, model_name, batch_size, sequence_size) config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5LayerCrossAttention(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) - torch_key_value_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) + torch_key_value_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output, *_ = model(torch_hidden_states, torch_key_value_states) parameters = preprocess_model_parameters( @@ -216,12 +197,10 @@ def test_t5_layer_cross_attention(device, model_name, batch_size, sequence_size) prefix="EncDecAttention", ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) - key_value_states = ttnn.from_torch(torch_key_value_states) - key_value_states = ttnn.to_layout(key_value_states, ttnn.TILE_LAYOUT) - key_value_states = ttnn.to_device(key_value_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + key_value_states = ttnn.from_torch( + torch_key_value_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) output = functional_t5.t5_layer_cross_attention( config, hidden_states, @@ -244,18 +223,15 @@ def test_t5_block_encoder(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) model = transformers.models.t5.modeling_t5.T5Block(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_output, *_ = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) output = functional_t5.t5_block( config, hidden_states, @@ -278,11 +254,10 @@ def test_t5_block_decoder(device, model_name, batch_size, sequence_size): config = transformers.T5Config.from_pretrained(model_name) config.is_decoder = True model = transformers.models.t5.modeling_t5.T5Block(config).eval() - model = model.to(torch.bfloat16) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) torch_encoder_hidden_states = torch_random( - (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16 + (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32 ) torch_output, *_ = model(torch_hidden_states, encoder_hidden_states=torch_encoder_hidden_states) @@ -290,12 +265,10 @@ def test_t5_block_decoder(device, model_name, batch_size, sequence_size): initialize_model=lambda: model, custom_preprocessor=functional_t5.custom_preprocessor, device=device ) - hidden_states = ttnn.from_torch(torch_hidden_states) - hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT) - hidden_states = ttnn.to_device(hidden_states, device) - encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states) - encoder_hidden_states = ttnn.to_layout(encoder_hidden_states, ttnn.TILE_LAYOUT) - encoder_hidden_states = ttnn.to_device(encoder_hidden_states, device) + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + encoder_hidden_states = ttnn.from_torch( + torch_encoder_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) output = functional_t5.t5_block( config, hidden_states, @@ -320,7 +293,6 @@ def test_t5_stack_encoder(device, model_name, batch_size, sequence_size): config.use_cache = False # Can't use cache when running as encoder shared_embedding = torch.nn.Embedding(config.vocab_size, config.d_model) model = transformers.models.t5.modeling_t5.T5Stack(config, shared_embedding).eval() - model = model.to(torch.bfloat16) torch_input_ids = torch_random((batch_size, sequence_size), 0, config.vocab_size, dtype=torch.int64) torch_output = model(torch_input_ids).last_hidden_state @@ -330,8 +302,7 @@ def test_t5_stack_encoder(device, model_name, batch_size, sequence_size): ) shared_embedding = preprocess_model_parameters(initialize_model=lambda: shared_embedding, device=device) - input_ids = ttnn.from_torch(torch_input_ids) - input_ids = ttnn.to_device(input_ids, device) + input_ids = ttnn.from_torch(torch_input_ids, device=device) output = functional_t5.t5_stack( config, input_ids, @@ -356,11 +327,10 @@ def test_t5_stack_decoder(device, model_name, batch_size, sequence_size): config.is_decoder = True shared_embedding = torch.nn.Embedding(config.vocab_size, config.d_model) model = transformers.models.t5.modeling_t5.T5Stack(config, shared_embedding).eval() - model = model.to(torch.bfloat16) torch_input_ids = torch_random((batch_size, sequence_size), 0, 1, dtype=torch.int64) torch_encoder_hidden_states = torch_random( - (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16 + (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32 ) torch_output = model(torch_input_ids, encoder_hidden_states=torch_encoder_hidden_states).last_hidden_state @@ -369,11 +339,10 @@ def test_t5_stack_decoder(device, model_name, batch_size, sequence_size): ) shared_embedding = preprocess_model_parameters(initialize_model=lambda: shared_embedding, device=device) - input_ids = ttnn.from_torch(torch_input_ids) - input_ids = ttnn.to_device(input_ids, device) - encoder_hidden_states = ttnn.from_torch(torch_encoder_hidden_states) - encoder_hidden_states = ttnn.to_layout(encoder_hidden_states, ttnn.TILE_LAYOUT) - encoder_hidden_states = ttnn.to_device(encoder_hidden_states, device) + input_ids = ttnn.from_torch(torch_input_ids, device=device) + encoder_hidden_states = ttnn.from_torch( + torch_encoder_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) output = functional_t5.t5_stack( config, input_ids, @@ -409,10 +378,8 @@ def test_t5_for_conditional_generation(device, model_name, batch_size, sequence_ device=device, ) - input_ids = ttnn.from_torch(torch_input_ids) - input_ids = ttnn.to_device(input_ids, device) - decoder_input_ids = ttnn.from_torch(torch_decoder_input_ids) - decoder_input_ids = ttnn.to_device(decoder_input_ids, device) + input_ids = ttnn.from_torch(torch_input_ids, device=device) + decoder_input_ids = ttnn.from_torch(torch_decoder_input_ids, device=device) output, *_ = functional_t5.t5_for_conditional_generation( config, input_ids, diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/tests/ttnn/integration_tests/whisper/test_performance.py index e4b7ee3c360..95cf4df62bb 100644 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ b/tests/ttnn/integration_tests/whisper/test_performance.py @@ -44,15 +44,14 @@ def test_performance(device, use_program_cache, model_name, batch_size, sequence feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - dtype_to_use = torch.bfloat16 - input_features = inputs.input_features.type(dtype_to_use) + input_features = inputs.input_features decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id attention_mask = None parameters = preprocess_model_parameters( tt_model_name, - initialize_model=lambda: WhisperModel.from_pretrained(model_name).to(dtype_to_use).eval(), + initialize_model=lambda: WhisperModel.from_pretrained(model_name).eval(), convert_to_ttnn=functional_whisper.convert_to_ttnn, custom_preprocessor=functional_whisper.custom_preprocessor, device=device, diff --git a/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py index 7ad6e009240..9e648c83a03 100644 --- a/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_torch_functional_whisper.py @@ -10,7 +10,7 @@ import torch from tests.ttnn.utils_for_testing import assert_with_pcc from models.utility_functions import torch_random -from ttnn.model_preprocessing import preprocess_model_parameters, make_parameter_dict +from ttnn.model_preprocessing import preprocess_model_parameters # MODEL_NAME = "openai/whisper-base" MODEL_NAME = "openai/whisper-tiny.en" @@ -23,16 +23,12 @@ def test_whisper_attention(model_name, batch_size, sequence_size, use_key_value_states): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) - model = ( - transformers.models.whisper.modeling_whisper.WhisperAttention( - embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout - ) - .to(torch.bfloat16) - .eval() - ) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + model = transformers.models.whisper.modeling_whisper.WhisperAttention( + embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout + ).eval() + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) if use_key_value_states: - key_value_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + key_value_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) else: key_value_states = None torch_output = model(torch_hidden_states, key_value_states=key_value_states) @@ -61,10 +57,10 @@ def test_whisper_attention(model_name, batch_size, sequence_size, use_key_value_ def test_encoder_layer(model_name, batch_size, sequence_size): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) - model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config).to(torch.bfloat16).eval() + model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config).eval() embed_dim = config.d_model - torch_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) attention_mask = None layer_head_mask = None @@ -87,13 +83,11 @@ def test_encoder_layer(model_name, batch_size, sequence_size): def test_encoder(model_name, batch_size, feature_size, sequence_length): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) - model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config).to(torch.bfloat16).eval() + model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config).eval() - torch_hidden_states = torch_random((batch_size, feature_size, sequence_length), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, feature_size, sequence_length), -0.1, 0.1, dtype=torch.float32) - attention_mask = None - head_mask = None - torch_output = model(torch_hidden_states, attention_mask, head_mask) + torch_output = model(torch_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -117,18 +111,16 @@ def test_encoder(model_name, batch_size, feature_size, sequence_length): def test_decoder_layer(model_name, batch_size, sequence_size): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) - model = transformers.models.whisper.modeling_whisper.WhisperDecoderLayer(config).to(torch.bfloat16).eval() - model = model.to(torch.bfloat16) + model = transformers.models.whisper.modeling_whisper.WhisperDecoderLayer(config).eval() + model = model embed_dim = config.d_model - torch_hidden_states = torch_random((batch_size, 32, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, 32, embed_dim), -0.1, 0.1, dtype=torch.float32) - torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) - attention_mask = torch_random((batch_size, 1, 32, 32), -0.1, 0.1, dtype=torch.bfloat16) - layer_head_mask = None - cross_attn_layer_head_mask = None - torch_output = model(torch_hidden_states, attention_mask, layer_head_mask, cross_attn_layer_head_mask) + attention_mask = None + torch_output = model(torch_hidden_states, attention_mask, torch_encoder_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -149,20 +141,16 @@ def test_decoder_layer(model_name, batch_size, sequence_size): def test_decoder(model_name, batch_size, sequence_size): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) - model = transformers.models.whisper.modeling_whisper.WhisperDecoder(config).to(torch.bfloat16).eval() + model = transformers.models.whisper.modeling_whisper.WhisperDecoder(config).eval() embed_dim = config.d_model - torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) - decoder_input_ids = torch.ones(1, 32).type(torch.int32) * config.decoder_start_token_id + torch_decoder_input_ids = torch.ones(1, 32).type(torch.int32) * config.decoder_start_token_id attention_mask = None - head_mask = None - cross_attn_layer_head_mask = None - torch_output = model( - decoder_input_ids, attention_mask, torch_encoder_hidden_states, head_mask, cross_attn_layer_head_mask - ) + torch_output = model(torch_decoder_input_ids, attention_mask, torch_encoder_hidden_states) parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -171,7 +159,7 @@ def test_decoder(model_name, batch_size, sequence_size): ) (decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_decoder_inputs( - decoder_input_ids, attention_mask, parameters=parameters + torch_decoder_input_ids, attention_mask, parameters=parameters ) output = torch_functional_whisper.decoder( @@ -192,7 +180,7 @@ def test_torch_whisper(): feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - dtype_to_use = torch.bfloat16 + dtype_to_use = torch.float32 input_features = inputs.input_features.type(dtype_to_use) decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py index 77de5dce7c1..1e01db8cae7 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_functional_whisper.py @@ -28,19 +28,15 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) - model = ( - transformers.models.whisper.modeling_whisper.WhisperAttention( - embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout - ) - .to(torch.bfloat16) - .eval() - ) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + model = transformers.models.whisper.modeling_whisper.WhisperAttention( + embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout + ).eval() + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) ttnn_hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16) ttnn_hidden_states = ttnn.to_device(ttnn_hidden_states, device) if use_key_value_states: torch_key_value_states = torch_random( - (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16 + (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32 ) ttnn_key_value_states = ttnn.from_torch(torch_key_value_states, dtype=ttnn.bfloat16) ttnn_key_value_states = ttnn.to_device(ttnn_key_value_states, device) @@ -95,10 +91,10 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config).eval() - model = model.to(torch.bfloat16) + model = model embed_dim = config.d_model - torch_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -132,9 +128,9 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config).eval() - model = model.to(torch.bfloat16) + model = model - torch_input_features = torch_random((batch_size, feature_size, sequence_length), -0.1, 0.1, dtype=torch.bfloat16) + torch_input_features = torch_random((batch_size, feature_size, sequence_length), -0.1, 0.1, dtype=torch.float32) parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -184,15 +180,15 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperDecoderLayer(config).eval() - model = model.to(torch.bfloat16) + model = model num_heads = config.encoder_attention_heads embed_dim = config.d_model - torch_hidden_states = torch_random((batch_size, 2, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, 2, embed_dim), -0.1, 0.1, dtype=torch.float32) - torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) - attention_mask = torch_random((batch_size, 1, 2, 2), -0.1, 0.1, dtype=torch.bfloat16) + attention_mask = torch_random((batch_size, 1, 2, 2), -0.1, 0.1, dtype=torch.float32) # Putting 8 in the channel because the add does not support broadcasting outside of the h and w dimensions. attention_mask = attention_mask.expand(-1, num_heads, -1, -1) @@ -244,11 +240,11 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperDecoder(config).eval() - model = model.to(torch.bfloat16) + model = model embed_dim = config.d_model - torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) # decoder_input_ids = torch.ones(1, 32).type(torch.int32) * config.decoder_start_token_id decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id @@ -316,13 +312,12 @@ def test_ttnn_whisper(device, ttnn_model): feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - dtype_to_use = torch.bfloat16 - input_features = inputs.input_features.type(dtype_to_use) + input_features = inputs.input_features decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id attention_mask = None - model = WhisperModel.from_pretrained(model_name).to(dtype_to_use).eval() + model = WhisperModel.from_pretrained(model_name).eval() parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -330,6 +325,9 @@ def test_ttnn_whisper(device, ttnn_model): custom_preprocessor=torch_functional_whisper.custom_preprocessor, ) + parameters.encoder.layers = parameters.encoder.layers[:1] + parameters.decoder.layers = parameters.decoder.layers[:1] + (encoder_hidden_states, decoder_hidden_states, decoder_attention_mask) = torch_functional_whisper.preprocess_inputs( input_features=input_features, input_ids=decoder_input_ids, @@ -352,6 +350,9 @@ def test_ttnn_whisper(device, ttnn_model): device=device, ) + ttnn_parameters.encoder.layers = ttnn_parameters.encoder.layers[:1] + ttnn_parameters.decoder.layers = ttnn_parameters.decoder.layers[:1] + (input_embeds, decoder_hidden_states, decoder_attention_mask) = ttnn_model.preprocess_inputs( config=config, input_features=input_features, @@ -371,4 +372,4 @@ def test_ttnn_whisper(device, ttnn_model): last_hidden_state = ttnn.from_device(last_hidden_state) last_hidden_state = ttnn.to_torch(last_hidden_state) - assert_with_pcc(expected_last_hidden_state, last_hidden_state, pcc=0.99) + assert_with_pcc(expected_last_hidden_state, last_hidden_state, 0.9912) diff --git a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py index eb36ac4acee..342c293f4af 100644 --- a/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py +++ b/tests/ttnn/integration_tests/whisper/test_ttnn_optimized_functional_whisper.py @@ -27,20 +27,16 @@ def test_whisper_attention(device, ttnn_model, model_name, batch_size, sequence_size, use_key_value_states): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) - model = ( - transformers.models.whisper.modeling_whisper.WhisperAttention( - embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout - ) - .to(torch.bfloat16) - .eval() - ) - torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16) + model = transformers.models.whisper.modeling_whisper.WhisperAttention( + embed_dim=config.d_model, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout + ).eval() + torch_hidden_states = torch_random((batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32) ttnn_hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16) ttnn_hidden_states = ttnn.to_layout(ttnn_hidden_states, ttnn.TILE_LAYOUT) ttnn_hidden_states = ttnn.to_device(ttnn_hidden_states, device) if use_key_value_states: torch_key_value_states = torch_random( - (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.bfloat16 + (batch_size, sequence_size, config.d_model), -0.1, 0.1, dtype=torch.float32 ) ttnn_key_value_states = ttnn.from_torch(torch_key_value_states, dtype=ttnn.bfloat16) ttnn_key_value_states = ttnn.to_layout(ttnn_key_value_states, ttnn.TILE_LAYOUT) @@ -96,10 +92,10 @@ def test_encoder_layer(device, ttnn_model, model_name, batch_size, sequence_size torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperEncoderLayer(config).eval() - model = model.to(torch.bfloat16) + model = model embed_dim = config.d_model - torch_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -135,9 +131,9 @@ def test_encoder(device, ttnn_model, model_name, batch_size, feature_size, seque torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperEncoder(config).eval() - model = model.to(torch.bfloat16) + model = model - torch_input_features = torch_random((batch_size, feature_size, sequence_length), -0.1, 0.1, dtype=torch.bfloat16) + torch_input_features = torch_random((batch_size, feature_size, sequence_length), -0.1, 0.1, dtype=torch.float32) parameters = preprocess_model_parameters( initialize_model=lambda: model, @@ -188,15 +184,15 @@ def test_decoder_layer(device, ttnn_model, model_name, batch_size, sequence_size torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperDecoderLayer(config).eval() - model = model.to(torch.bfloat16) + model = model num_heads = config.encoder_attention_heads embed_dim = config.d_model - torch_hidden_states = torch_random((batch_size, 2, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_hidden_states = torch_random((batch_size, 2, embed_dim), -0.1, 0.1, dtype=torch.float32) - torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) - attention_mask = torch_random((batch_size, 1, 2, 2), -0.1, 0.1, dtype=torch.bfloat16) + attention_mask = torch_random((batch_size, 1, 2, 2), -0.1, 0.1, dtype=torch.float32) # Putting num_heads in the channel because the add does not support broadcasting outside of the h and w dimensions. attention_mask = attention_mask.expand(-1, num_heads, -1, -1) @@ -246,11 +242,11 @@ def test_decoder(device, ttnn_model, model_name, batch_size, sequence_size): torch.manual_seed(0) config = transformers.WhisperConfig.from_pretrained(model_name) model = transformers.models.whisper.modeling_whisper.WhisperDecoder(config).eval() - model = model.to(torch.bfloat16) + model = model embed_dim = config.d_model - torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.bfloat16) + torch_encoder_hidden_states = torch_random((batch_size, sequence_size, embed_dim), -0.1, 0.1, dtype=torch.float32) # decoder_input_ids = torch.ones(1, 32).type(torch.int32) * config.decoder_start_token_id decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id @@ -318,13 +314,12 @@ def test_ttnn_whisper(device, ttnn_model): feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt") - dtype_to_use = torch.bfloat16 - input_features = inputs.input_features.type(dtype_to_use) + input_features = inputs.input_features decoder_input_ids = torch.tensor([[1, 1]]) * config.decoder_start_token_id attention_mask = None - model = WhisperModel.from_pretrained(model_name).to(dtype_to_use).eval() + model = WhisperModel.from_pretrained(model_name).eval() parameters = preprocess_model_parameters( initialize_model=lambda: model,