-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
211 additions
and
237 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,6 @@ | |
# -------------------------------------------------------------------------- | ||
import logging | ||
import os | ||
Check notice Code scanning / CodeQL Unused import Note
Import of 'os' is not used.
Check warning Code scanning / lintrunner RUFF/F401 Warning
os imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import |
||
from pathlib import Path | ||
|
||
import onnx | ||
|
||
|
@@ -14,31 +13,29 @@ class DynamoOnnxHelper: | |
Helper class for processing ONNX models exported by torch Dynamo. | ||
""" | ||
|
||
def __init__(self): | ||
pass | ||
def __init__(self, model: onnx.ModelProto): | ||
self.model = model | ||
|
||
def update_edges(self, model: onnx.ModelProto, edge_mapping: dict) -> onnx.ModelProto: | ||
def update_edges(self, edge_mapping: dict) -> None: | ||
""" | ||
Updates the edges in the model according to the given mapping. | ||
""" | ||
for node in model.graph.node: | ||
for node in self.model.graph.node: | ||
for i in range(len(node.input)): | ||
if node.input[i] in edge_mapping: | ||
node.input[i] = edge_mapping[node.input[i]] | ||
for i in range(len(node.output)): | ||
if node.output[i] in edge_mapping: | ||
node.output[i] = edge_mapping[node.output[i]] | ||
|
||
for graph_input in model.graph.input: | ||
for graph_input in self.model.graph.input: | ||
if graph_input.name in edge_mapping: | ||
graph_input.name = edge_mapping[graph_input.name] | ||
for graph_output in model.graph.output: | ||
for graph_output in self.model.graph.output: | ||
if graph_output.name in edge_mapping: | ||
graph_output.name = edge_mapping[graph_output.name] | ||
|
||
return model | ||
|
||
def unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto: | ||
def unroll_function(self, func_name: str) -> None: | ||
""" | ||
Unrolls the function with the given name in the model. | ||
""" | ||
|
@@ -47,13 +44,13 @@ def unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelP | |
nodes_to_add = [] | ||
edges_to_remove = [] | ||
edges_to_add = [] | ||
for node in model.graph.node: | ||
for node in self.model.graph.node: | ||
if node.op_type == func_name: | ||
nodes_to_remove.append(node) | ||
edges_to_remove.extend(list(node.input) + list(node.output)) | ||
|
||
func_to_remove = None | ||
for f in model.functions: | ||
for f in self.model.functions: | ||
if f.name == func_name: | ||
nodes_to_add.extend(list(f.node)) | ||
edges_to_add.extend(list(f.input) + list(f.output)) | ||
|
@@ -62,11 +59,11 @@ def unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelP | |
assert len(edges_to_remove) == len(edges_to_add) | ||
|
||
for node in nodes_to_remove: | ||
model.graph.node.remove(node) | ||
self.model.graph.node.remove(node) | ||
for node in nodes_to_add: | ||
model.graph.node.append(node) | ||
self.model.graph.node.append(node) | ||
if func_to_remove is not None: | ||
model.functions.remove(func_to_remove) | ||
self.model.functions.remove(func_to_remove) | ||
|
||
edge_mapping = {} | ||
for i in range(len(edges_to_remove)): | ||
|
@@ -75,40 +72,22 @@ def unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelP | |
if k != v: | ||
edge_mapping[k] = v | ||
|
||
return self.update_edges(model, edge_mapping) | ||
return self.update_edges(edge_mapping) | ||
|
||
def remove_dropout_layer(self, model: onnx.ModelProto) -> onnx.ModelProto: | ||
def remove_dropout_layer(self) -> None: | ||
""" | ||
Removes the dropout layer in the model. | ||
""" | ||
logging.info("Removing dropout layer...") | ||
edge_mapping = {} | ||
nodes_to_remove = [] | ||
for node in model.graph.node: | ||
for node in self.model.graph.node: | ||
if node.op_type.find("Dropout") != -1: | ||
assert len(node.input) == 1 | ||
assert len(node.output) == 1 | ||
edge_mapping[node.output[0]] = node.input[0] | ||
nodes_to_remove.append(node) | ||
for node in nodes_to_remove: | ||
model.graph.node.remove(node) | ||
|
||
return self.update_edges(model, edge_mapping) | ||
|
||
def erase_onnx_model(self, onnx_path: str) -> None: | ||
assert onnx_path.endswith(".onnx") | ||
if not os.path.exists(onnx_path): | ||
return | ||
self.model.graph.node.remove(node) | ||
|
||
model = onnx.load_model(onnx_path, load_external_data=False) | ||
onnx_data_path = None | ||
for initializer in model.graph.initializer: | ||
if initializer.data_location == 1 and initializer.external_data[0].key == "location": | ||
onnx_data_path = "./" + initializer.external_data[0].value | ||
break | ||
logging.info(f"Erasing {onnx_path}...") | ||
os.remove(onnx_path) | ||
if onnx_data_path is not None: | ||
onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path) | ||
logging.info(f"Erasing {onnx_data_path}...") | ||
os.remove(onnx_data_path) | ||
self.update_edges(edge_mapping) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.