Commit cf8aa62

refactor

wangyems committed Feb 2, 2024
1 parent 83d7d4b commit cf8aa62

Showing 5 changed files with 211 additions and 237 deletions.
@@ -4,7 +4,6 @@
 # --------------------------------------------------------------------------
 import logging
 import os
-from pathlib import Path

 import onnx

Code scanning / CodeQL notice on the retained "import os" line: Unused import. Import of 'os' is not used.
Code scanning / lintrunner warning on the same line: RUFF/F401 (unused import).
@@ -14,31 +13,29 @@ class DynamoOnnxHelper:
     Helper class for processing ONNX models exported by torch Dynamo.
     """

-    def __init__(self):
-        pass
+    def __init__(self, model: onnx.ModelProto):
+        self.model = model

-    def update_edges(self, model: onnx.ModelProto, edge_mapping: dict) -> onnx.ModelProto:
+    def update_edges(self, edge_mapping: dict) -> None:
         """
         Updates the edges in the model according to the given mapping.
         """
-        for node in model.graph.node:
+        for node in self.model.graph.node:
             for i in range(len(node.input)):
                 if node.input[i] in edge_mapping:
                     node.input[i] = edge_mapping[node.input[i]]
             for i in range(len(node.output)):
                 if node.output[i] in edge_mapping:
                     node.output[i] = edge_mapping[node.output[i]]

-        for graph_input in model.graph.input:
+        for graph_input in self.model.graph.input:
             if graph_input.name in edge_mapping:
                 graph_input.name = edge_mapping[graph_input.name]
-        for graph_output in model.graph.output:
+        for graph_output in self.model.graph.output:
             if graph_output.name in edge_mapping:
                 graph_output.name = edge_mapping[graph_output.name]

-        return model
-
-    def unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto:
+    def unroll_function(self, func_name: str) -> None:
         """
         Unrolls the function with the given name in the model.
         """
@@ -47,13 +44,13 @@ def unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto:
         nodes_to_add = []
         edges_to_remove = []
         edges_to_add = []
-        for node in model.graph.node:
+        for node in self.model.graph.node:
             if node.op_type == func_name:
                 nodes_to_remove.append(node)
                 edges_to_remove.extend(list(node.input) + list(node.output))

         func_to_remove = None
-        for f in model.functions:
+        for f in self.model.functions:
             if f.name == func_name:
                 nodes_to_add.extend(list(f.node))
                 edges_to_add.extend(list(f.input) + list(f.output))
@@ -62,11 +59,11 @@ def unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto:
         assert len(edges_to_remove) == len(edges_to_add)

         for node in nodes_to_remove:
-            model.graph.node.remove(node)
+            self.model.graph.node.remove(node)
         for node in nodes_to_add:
-            model.graph.node.append(node)
+            self.model.graph.node.append(node)
         if func_to_remove is not None:
-            model.functions.remove(func_to_remove)
+            self.model.functions.remove(func_to_remove)

         edge_mapping = {}
         for i in range(len(edges_to_remove)):
@@ -75,40 +72,22 @@ def unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto:
             if k != v:
                 edge_mapping[k] = v

-        return self.update_edges(model, edge_mapping)
+        return self.update_edges(edge_mapping)

-    def remove_dropout_layer(self, model: onnx.ModelProto) -> onnx.ModelProto:
+    def remove_dropout_layer(self) -> None:
         """
         Removes the dropout layer in the model.
         """
         logging.info("Removing dropout layer...")
         edge_mapping = {}
         nodes_to_remove = []
-        for node in model.graph.node:
+        for node in self.model.graph.node:
             if node.op_type.find("Dropout") != -1:
                 assert len(node.input) == 1
                 assert len(node.output) == 1
                 edge_mapping[node.output[0]] = node.input[0]
                 nodes_to_remove.append(node)
         for node in nodes_to_remove:
-            model.graph.node.remove(node)
-
-        return self.update_edges(model, edge_mapping)
-
-    def erase_onnx_model(self, onnx_path: str) -> None:
-        assert onnx_path.endswith(".onnx")
-        if not os.path.exists(onnx_path):
-            return
-
-        model = onnx.load_model(onnx_path, load_external_data=False)
-        onnx_data_path = None
-        for initializer in model.graph.initializer:
-            if initializer.data_location == 1 and initializer.external_data[0].key == "location":
-                onnx_data_path = "./" + initializer.external_data[0].value
-                break
-        logging.info(f"Erasing {onnx_path}...")
-        os.remove(onnx_path)
-        if onnx_data_path is not None:
-            onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path)
-            logging.info(f"Erasing {onnx_data_path}...")
-            os.remove(onnx_data_path)
+            self.model.graph.node.remove(node)
+
+        self.update_edges(edge_mapping)
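With this refactor the helper owns the ModelProto and mutates it in place, rather than threading a model argument through every method and returning it. A minimal usage sketch of the new interface (the file paths, the local-function name, and the import path below are hypothetical):

import onnx

from dynamo_onnx_helper import DynamoOnnxHelper  # assumed import path

# Wrap a Dynamo-exported model; the helper keeps it as self.model.
model = onnx.load("exported_by_dynamo.onnx")  # hypothetical path
helper = DynamoOnnxHelper(model)

# Each call now edits helper.model in place and returns None.
helper.unroll_function("local_fn")  # hypothetical ONNX local-function name
helper.remove_dropout_layer()

onnx.save(helper.model, "processed.onnx")  # hypothetical output path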
15 changes: 15 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_options.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 from argparse import ArgumentParser
+from enum import Enum


 class AttentionMaskFormat:
@@ -19,6 +20,15 @@ class AttentionMaskFormat:
     NoMask = 3


+class AttentionOpType(Enum):
+    Attention = "Attention"
+    MultiHeadAttention = "MultiHeadAttention"
+    GroupQueryAttention = "GroupQueryAttention"
+
+    def __str__(self):
+        return self.value
+
+
 class FusionOptions:
     """Options of fusion in graph optimization"""

@@ -57,6 +67,8 @@ def __init__(self, model_type):
         elif model_type == "vit":
             self.attention_mask_format = AttentionMaskFormat.NoMask

+        self.attention_op_type = None
+
         # options for stable diffusion
         if model_type in ["unet", "vae", "clip"]:
             self.enable_nhwc_conv = True
@@ -76,6 +88,9 @@ def use_raw_attention_mask(self, use_raw_mask=True):
     def disable_attention_mask(self):
         self.attention_mask_format = AttentionMaskFormat.NoMask

+    def set_attention_op_type(self, attn_op_type: AttentionOpType):
+        self.attention_op_type = attn_op_type
+
     @staticmethod
     def parse(args):
         options = FusionOptions(args.model_type)
@@ -109,5 +109,4 @@ python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x -
 The inference example currently supports all models running on CUDA.

 ## Limitations
-- It's a known issue that symbolic shape inference fails. It can be ignored at the moment as it won't affect the optimized model's inference.
 - Torch dynamo export only supports Linux. The model export cannot be run on Windows as of now.