Lint fixes
shubhambhokare1 committed Jan 24, 2024
1 parent 124dcf8 commit abc239c
Showing 9 changed files with 39 additions and 50 deletions.
@@ -127,8 +127,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
root_input = output
break

graph_input_names = set([node.name for node in self.model.graph().input])
graph_output_names = set([node.name for node in self.model.graph().output])
graph_input_names = set([node.name for node in self.model.graph().input])
graph_output_names = set([node.name for node in self.model.graph().output])

v_nodes = self.model.match_parent_path(
matmul_qkv,
@@ -152,13 +152,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
if v_nodes is not None:
(transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes
# For initial pass through encoder-decoder_with_past to get starting past values (beam search)
#present_v = add_v.output[0]

add_v_children = self.model.get_children(add_v)
for child in add_v_children:
if child.op_type == "Reshape":
#if child.output[0] in graph_output_names:
#present_v = child.output[0]
reshape_v_children = self.model.get_children(child)
for reshape_child in reshape_v_children:
if reshape_child.op_type == "Transpose":
@@ -205,9 +202,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
present_v = present_v if present_v in graph_output_names else ""

qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
qk_nodes_2 = self.model.match_parent_path(
matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]
)
qk_nodes_2 = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
if qk_nodes_1 is not None:
_, matmul_qk = qk_nodes_1
qk_nodes = qk_nodes_1
@@ -256,13 +251,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
k_nodes = k_nodes_with_bias
present_k = matmul_k.output[0]
mat_k_out_tmp = matmul_k.output[0] + "_temp"
#matmul_k.output[0] = matmul_k.output[0] + "_temp"

matmul_k_children = self.model.get_children(matmul_k)
for child in matmul_k_children:
if child.op_type == "Reshape":
#if child.output[0] in graph_output_names:
# present_k = child.output[0]
reshape_k_children = self.model.get_children(child)
for reshape_child in reshape_k_children:
if reshape_child.op_type == "Transpose":
@@ -285,8 +277,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
if reshape_parent.op_type == "Transpose":
if reshape_parent.input[0] in graph_input_names:
past_k = reshape_parent.input[0]
#else:
# matmul_k.output[0] = mat_k_out_tmp


elif k_nodes_no_bias is not None:
@@ -328,7 +318,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
add_name = self.model.create_node_name("Add")
add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name)

'''
"""
if not past_k and not self.check_runtime_shape_path(
reshape_qkv_2,
reshape_qkv_1,
@@ -338,7 +328,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
root_input,
):
return
'''
"""

three_root_inputs = past_k and past_v and matmul_k is None and "matmul_v" not in locals()
one_root_input = (
@@ -381,7 +371,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
)
if mask_nodes_whisper is not None:
pass
#mask_index = mask_nodes_whisper[0].output[-1]
# mask_index = mask_nodes_whisper[0].output[-1]
elif mask_nodes_bart is not None:
mask_index = mask_nodes_bart[0].output[-1]

@@ -331,12 +331,8 @@ def export_onnx_models(
device = torch.device("cuda:0" if use_gpu else "cpu")

models = WhisperHelper.load_model(
model_name_or_path,
model_impl,
cache_dir,
device,
merge_encoder_and_decoder_init,
state_dict_path
model_name_or_path, model_impl, cache_dir, device,
merge_encoder_and_decoder_init, state_dict_path
)
config = models["decoder"].config

@@ -68,7 +68,7 @@ def forward(
class WhisperDecoder(torch.nn.Module):
"""A Whisper decoder with past key values"""

def __init__(self, decoder, config, model_impl: str = 'hf', model=None):
def __init__(self, decoder, config, model_impl: str = "hf", model=None):
super().__init__()
self.decoder = decoder
self.config = config
@@ -83,9 +83,11 @@ def forward(self, decoder_input_ids, *past):
encoder_outputs["hidden_states"] = dummy_encoder_hidden_states
encoder_outputs["attentions"] = None

if self.model_impl == 'openai':
if self.model_impl == "openai":
dummy_encoder_hidden_states.unsqueeze(0)
dec_out, present = self.whisper_decoder_openai_init(decoder_input_ids, dummy_encoder_hidden_states, past=past)
dec_out, present = self.whisper_decoder_openai_init(
decoder_input_ids, dummy_encoder_hidden_states, past=past
)
return dec_out, present

if len(past) == 0:
@@ -25,7 +25,7 @@
class WhisperEncoder(torch.nn.Module):
"""Whisper encoder outputs only the last hidden state"""

def __init__(self, encoder, config: WhisperConfig, model_impl: str = 'hf'):
def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"):
super().__init__()
self.encoder = encoder
self.config = config
@@ -4,10 +4,10 @@
# license information.
# --------------------------------------------------------------------------

import copy
import logging
import os
import tempfile
import copy
from pathlib import Path
from typing import List, Optional

@@ -36,7 +36,7 @@ def __init__(
decoder: torch.nn.Module,
config: WhisperConfig,
decoder_start_token_id: Optional[int] = None,
model_impl: str = 'hf',
model_impl: str = "hf",
model: torch.nn.Module = None,
):
super().__init__()
@@ -55,7 +55,7 @@ def forward(
):
encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids)
# Decoder out: (logits, past_key_values, encoder_hidden_state)
if self.model_impl == 'openai':
if self.model_impl == "openai":
encoder_hidden_states.unsqueeze(0)
decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states)
return decinit_out, encoder_hidden_states, present
@@ -4,24 +4,22 @@
# license information.
# --------------------------------------------------------------------------

import io
import logging
import os
import io
import sys
from pathlib import Path
from typing import Dict, Tuple, Union

import numpy as np
import torch
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
from whisper import _MODELS, _ALIGNMENT_HEADS, _download
from whisper.model import Whisper, ModelDimensions
from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper

from whisper.model import Whisper, ModelDimensions
from whisper import _MODELS, _ALIGNMENT_HEADS
from whisper import _download

from onnxruntime import InferenceSession

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -94,16 +92,14 @@ def load_model_openai(

in_memory = False

model_name = model_name_or_path.split('/')[-1][8:]
model_name = model_name_or_path.split("/")[-1][8:]
checkpoint_file = None
if model_name in _MODELS:
checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory)
alignment_heads = _ALIGNMENT_HEADS[model_name]


with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
with io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") as fp:

Code scanning / CodeQL warning (File is not always closed): File is opened but is not closed.
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
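
The CodeQL annotation above concerns the conditional expression inside the with statement: static analysis cannot prove that the handle returned by open() is always closed. A minimal sketch of one way to satisfy the check, assuming the two branches can simply be split (an assumption, not necessarily how the repository resolves the warning; load_checkpoint is an illustrative name):

import io

import torch


def load_checkpoint(checkpoint_file, in_memory: bool, device):
    # Each branch owns its own context manager, so every file-like object
    # is provably closed when the block exits.
    if in_memory:
        with io.BytesIO(checkpoint_file) as fp:
            return torch.load(fp, map_location=device)
    with open(checkpoint_file, "rb") as fp:
        return torch.load(fp, map_location=device)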

@@ -26,6 +26,7 @@

class WhisperDecoderInitOpenai(torch.nn.Module):
"""WhisperDecoderInit for Openai."""

def __init__(
self,
model: torch.nn.Module,
@@ -43,13 +44,12 @@ def forward(
audio_features,
past=None,
):

# Create a kv_cache for past_values
past_kv_cache = dict()
if past is not None:
# Convert past values from 4D to 3D
past = [torch.transpose(val, 1, 2) for val in past]
past = [val.reshape(val.shape[:2] + (-1, )) for val in past]
past = [val.reshape(val.shape[:2] + (-1,)) for val in past]
half_idx = len(past) // 2
for idx, block in enumerate(self.whisper_decoder.blocks):
past_kv_cache[block.attn.key] = past[2 * idx]
@@ -65,8 +65,12 @@ def forward(
# Add concat node for past values
if past is not None:
for idx, block in enumerate(self.whisper_decoder.blocks):

Code scanning / lintrunner warning (RUFF/B007): Loop control variable idx not used within loop body. See https://docs.astral.sh/ruff/rules/unused-loop-control-variable
self.kv_cache[block.attn.key] = torch.cat([past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1).detach()
self.kv_cache[block.attn.value] = torch.cat([past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1).detach()
self.kv_cache[block.attn.key] = torch.cat(
[past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1
).detach()
self.kv_cache[block.attn.value] = torch.cat(
[past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1
).detach()

present_self, present_cross = [], []
# Group self and cross values
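
The RUFF/B007 annotation above flags the for idx, block in enumerate(...) loop whose body no longer reads idx. A tiny self-contained illustration of the usual remedy, using placeholder data rather than the repository's objects:

blocks = ["block0", "block1"]

# Flagged pattern: idx is never used inside the loop body.
for idx, block in enumerate(blocks):
    print(block)

# Typical fix: drop enumerate(), or rename the index to _ if it must stay.
for block in blocks:
    print(block)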
@@ -79,7 +83,7 @@

present_self = present_self + present_cross
# Add reshape and transpose ops to convert from 3D to 4D
present_self = [present_val.reshape(
present_val.shape[:2] + (-1, 64)
).transpose(1, 2) for present_val in present_self]
present_self = [
present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
]
return logits, present_self
5 changes: 1 addition & 4 deletions onnxruntime/python/tools/transformers/onnx_model_bart.py
@@ -127,10 +127,7 @@ def __init__(self, model, num_heads, hidden_size, model_impl="hf"):
self.attention_mask = AttentionMask(self)
if model_impl == "openai":
self.attention_fusion = FusionBartAttentionOpenai(
self,
self.hidden_size,
self.num_heads,
self.attention_mask
self, self.hidden_size, self.num_heads, self.attention_mask
)

Code scanning / CodeQL warning (Overwriting attribute in super-class or sub-class): Assignment overwrites attribute attention_fusion, which was previously defined in superclass BertOnnxModel.
else:
self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

Code scanning / CodeQL warning (Overwriting attribute in super-class or sub-class): Assignment overwrites attribute attention_fusion, which was previously defined in superclass BertOnnxModel.
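
Both CodeQL annotations point at the same pattern: the subclass __init__ reassigns self.attention_fusion, an attribute that the BertOnnxModel base class already sets. A small self-contained sketch of one common way to avoid that reassignment, by letting the base class call a factory hook that subclasses override (class names and return values below are placeholders, not the repository's API):

class BaseOnnxModel:
    def __init__(self):
        # The base class asks a hook for the fusion object instead of hard-coding it.
        self.attention_fusion = self._create_attention_fusion()

    def _create_attention_fusion(self):
        return "generic-attention-fusion"  # placeholder for FusionAttention(...)


class BartLikeModel(BaseOnnxModel):
    def __init__(self, model_impl: str = "hf"):
        self.model_impl = model_impl
        super().__init__()

    def _create_attention_fusion(self):
        # placeholder for choosing FusionBartAttentionOpenai vs. FusionBartAttention
        if self.model_impl == "openai":
            return "openai-attention-fusion"
        return "bart-attention-fusion"


print(BartLikeModel("openai").attention_fusion)  # prints openai-attention-fusion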
6 changes: 5 additions & 1 deletion onnxruntime/python/tools/transformers/onnx_model_bert.py
@@ -414,7 +414,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo

if options is not None:
self.attention_mask.set_mask_format(options.attention_mask_format)
if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention) and not isinstance(self.attention_fusion, FusionBartAttentionOpenai):
if (
options.use_multi_head_attention
and not isinstance(self.attention_fusion, FusionBartAttention)
and not isinstance(self.attention_fusion, FusionBartAttentionOpenai)
):
self.attention_fusion = FusionAttention(
self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention
)