Lintrunner fixes
Peter McAughan committed Nov 2, 2023
1 parent ab0be1a commit 8f9912f
Showing 3 changed files with 22 additions and 13 deletions.
9 changes: 6 additions & 3 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -1271,7 +1271,10 @@ def find_past_seq_len_usage(subg: GraphProto):
nodes_to_remove.append(shape_node)
return tensor_names_to_rename, nodes_to_remove

def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0):

def replace_mha_with_gqa(
model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
):
past_seq_len = past_seq_len_input
if past_seq_len not in model.get_graphs_input_names():
# Add model input for past sequence length
@@ -1287,7 +1290,7 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads
num_heads_mha = att.i
if window_size:
gqa_node = onnx.helper.make_node(
"GroupQueryAttention",
"GroupQueryAttention",
inputs=[
node.input[0], # query
node.input[1], # key
@@ -1331,7 +1334,7 @@ def update_decoder_subgraph_output_cross_attention(subg: GraphProto):
input_self_past_0 = 1
# w/wo attention mask, w/wo hidden_state
graph_input_names = [gi.name for gi in subg.input]
while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"):
input_self_past_0 += 1
output_self_present_0 = 1

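A minimal usage sketch of the reflowed replace_mha_with_gqa signature above, assuming the flat import style used elsewhere in these tools; the model path, head count, and window size are illustrative assumptions, not values taken from this diff:

# Hedged sketch -- import path, file name, and numeric values are assumptions.
import onnx
from onnx_model import OnnxModel
from convert_generation import replace_mha_with_gqa  # assumed import path

# Load a decoder model that still uses MultiHeadAttention (hypothetical file name).
model = OnnxModel(onnx.load_model("decoder_model_fp16.onnx", load_external_data=True))

# Swap MultiHeadAttention for GroupQueryAttention using the new keyword arguments.
model = replace_mha_with_gqa(
    model,
    past_seq_len_input="past_sequence_length",
    kv_num_heads=8,       # assumed; normally config.num_key_value_heads
    world_size=1,
    window_size=4096,     # assumed sliding-window size; 0 disables windowed attention
)
model.update_graph(allow_remove_graph_inputs=True)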
@@ -14,7 +14,7 @@
from onnx_model import OnnxModel
from optimizer import optimize_model
from packaging import version
from transformers import LlamaConfig, LlamaForCausalLM, PretrainedConfig, AutoConfig
from transformers import AutoConfig, LlamaConfig, LlamaForCausalLM, PretrainedConfig

from onnxruntime import quantization as ort_quantization
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
@@ -391,7 +391,7 @@ def run_torchscript_merged_export(


# Optimize the model as FP32
def optimize_export(config: LlamaConfig, input_path: str, output_path: str, remove_model = True):
def optimize_export(config: LlamaConfig, input_path: str, output_path: str, remove_model=True):
from fusion_options import FusionOptions

optimization_options = FusionOptions("gpt2")
@@ -407,7 +407,7 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str, remo
)
model_opt.save_model_to_file(output_path, use_external_data_format=True)
logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
if (remove_model):
if remove_model:
remove_existing_model(input_path)


@@ -438,7 +438,9 @@ def convert_to_float16(
return new_paths


def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1, window_size: int = 0):
def use_group_query_attention(
config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1, window_size: int = 0
):
# Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes
fp16_model_opt = replace_mha_with_gqa(
fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size, window_size
@@ -447,6 +449,7 @@ def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, wo
fp16_model_opt.update_graph(allow_remove_graph_inputs=True)
return fp16_model_opt


def smooth_quant(
args: argparse.Namespace,
decoder_model_fp32_path: str,
@@ -539,17 +542,18 @@ def remove_existing_files(output_path: str):
os.remove(filepath)
logger.warning(f"Removed {filepath}")


def optimize_optimum(config: PretrainedConfig, args):
tmp_file = os.path.join(args.output, args.model_name+".tmp.onnx")
output_file = os.path.join(args.output, args.model_name+".onnx")
tmp_file = os.path.join(args.output, args.model_name + ".tmp.onnx")
output_file = os.path.join(args.output, args.model_name + ".onnx")
optimize_export(config, args.input, tmp_file, remove_model=False)
logger.info(f"Model successfully optimized to {tmp_file}")
opt_model = OnnxModel(onnx.load_model(tmp_file, load_external_data=True))
if args.precision == Precision.FLOAT16:
opt_model.convert_float_to_float16(keep_io_types=False)
window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window
window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window
opt_model = use_group_query_attention(config, opt_model, args.world_size, window_size)
logger.info(f"Model successfully fused and quantized to FP16!")
logger.info("Model successfully fused and quantized to FP16!")
opt_model.save_model_to_file(output_file, use_external_data_format=True)
logger.info(f"Output model successfully saved to {output_file}")
logger.info(f"Removing {tmp_file}")
@@ -745,7 +749,7 @@ def main():
# Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false
# in that scenario. It can be removed when torch v2.2.0 is released in stable.
logger.warning(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.")
#return
# return

args = get_args()
setup_logger(args.verbose)
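As a hedged illustration of the sliding-window handling in optimize_optimum above, the snippet below shows how sliding_window could be read from a Hugging Face config when present; the checkpoint name is an assumption, not part of this change:

# Illustrative sketch -- the checkpoint name is an assumption.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")  # hypothetical sliding-window model
# Mirrors the check in optimize_optimum: fall back to 0 when the config has no sliding window.
window_size = 0 if not hasattr(config, "sliding_window") else config.sliding_window
print(window_size)  # e.g. 4096 for Mistral-style configs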
@@ -222,7 +222,9 @@ def get_msft_sample_inputs(

# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: PretrainedConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
def get_past_kv_inputs(
config: PretrainedConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1
):
num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
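The comment above documents the past_key_values layout; below is a self-contained sketch of tensors with that shape, where every concrete number is an assumption for illustration rather than a value from the truncated diff:

# Hedged sketch of past_key_values with shape (batch_size, num_heads, past_sequence_length, head_size).
import torch

batch_size, past_seq_len, world_size = 2, 8, 1
num_key_value_heads, hidden_size, num_attention_heads, num_layers = 8, 4096, 32, 2  # assumed config values

num_heads = num_key_value_heads // world_size
head_size = hidden_size // num_attention_heads
torch_dtype = torch.float16  # the use_fp16=True case

past_kv = [
    (
        torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
        torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
    )
    for _ in range(num_layers)
]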
