Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Resolves] Support ingesting PyTorch models #9

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ enable =
expression-not-assigned,
confusing-with-statement,
unnecessary-lambda,
assign-to-new-keyword,
redeclared-assigned-name,
pointless-statement,
pointless-string-statement,
Expand Down Expand Up @@ -123,7 +122,6 @@ enable =
invalid-length-returned,
protected-access,
attribute-defined-outside-init,
no-init,
abstract-method,
invalid-overridden-method,
arguments-differ,
Expand Down Expand Up @@ -165,9 +163,7 @@ enable =
### format
# Line length, indentation, whitespace:
bad-indentation,
mixed-indentation,
unnecessary-semicolon,
bad-whitespace,
missing-final-newline,
line-too-long,
mixed-line-endings,
Expand All @@ -187,7 +183,6 @@ enable =
import-self,
preferred-module,
reimported,
relative-import,
deprecated-module,
wildcard-import,
misplaced-future,
Expand Down Expand Up @@ -282,12 +277,6 @@ indent-string = ' '
# black doesn't always obey its own limit. See pyproject.toml.
max-line-length = 100

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check =

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt = no
Expand Down
42 changes: 22 additions & 20 deletions examples/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import csv
from collections import Counter, defaultdict
from tqdm import tqdm
from digest.model_class.digest_model import (
NodeShapeCounts,
NodeTypeCounts,
save_node_shape_counts_csv_report,
save_node_type_counts_csv_report,
)
from digest.model_class.digest_onnx_model import DigestOnnxModel
from utils.onnx_utils import (
get_dynamic_input_dims,
load_onnx,
DigestOnnxModel,
save_node_shape_counts_csv_report,
save_node_type_counts_csv_report,
NodeTypeCounts,
NodeShapeCounts,
)

GLOBAL_MODEL_HEADERS = [
Expand Down Expand Up @@ -71,7 +73,7 @@ def main(onnx_files: str, output_dir: str):
print(f"dim: {dynamic_shape}")

digest_model = DigestOnnxModel(
model_proto, onnx_filepath=onnx_file, model_name=model_name
model_proto, onnx_file_path=onnx_file, model_name=model_name
)

# Update the global model dictionary
Expand All @@ -82,46 +84,46 @@ def main(onnx_files: str, output_dir: str):

global_model_data[model_name] = {
"opset": digest_model.opset,
"parameters": digest_model.model_parameters,
"flops": digest_model.model_flops,
"parameters": digest_model.parameters,
"flops": digest_model.flops,
}

# Model summary text report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.txt")
digest_model.save_txt_report(summary_filepath)
digest_model.save_text_report(summary_filepath)

# Model summary yaml report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.yaml")
digest_model.save_yaml_report(summary_filepath)

# Save csv containing node-level information
nodes_filepath = os.path.join(output_dir, f"{model_name}_nodes.csv")
digest_model.save_nodes_csv_report(nodes_filepath)

# Save csv containing node type counter
node_type_counter = digest_model.get_node_type_counts()
node_type_filepath = os.path.join(
output_dir, f"{model_name}_node_type_counts.csv"
)
if node_type_counter:
save_node_type_counts_csv_report(node_type_counter, node_type_filepath)

digest_model.save_node_type_counts_csv_report(node_type_filepath)

# Update global data structure for node type counter
global_node_type_counter.update(node_type_counter)
global_node_type_counter.update(digest_model.node_type_counts)

# Save csv containing node shape counts per op_type
node_shape_counts = digest_model.get_node_shape_counts()
node_shape_filepath = os.path.join(
output_dir, f"{model_name}_node_shape_counts.csv"
)
save_node_shape_counts_csv_report(node_shape_counts, node_shape_filepath)
digest_model.save_node_shape_counts_csv_report(node_shape_filepath)

# Update global data structure for node shape counter
for node_type, shape_counts in node_shape_counts.items():
for node_type, shape_counts in digest_model.get_node_shape_counts().items():
global_node_shape_counter[node_type].update(shape_counts)

if len(onnx_file_list) > 1:
global_filepath = os.path.join(output_dir, "global_node_type_counts.csv")
global_node_type_counter = NodeTypeCounts(
global_node_type_counter.most_common()
)
save_node_type_counts_csv_report(global_node_type_counter, global_filepath)
global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common())
save_node_type_counts_csv_report(global_node_type_counts, global_filepath)

global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv")
save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath)
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.0.0",
version="1.2.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand All @@ -25,6 +25,8 @@
"platformdirs>=4.2.2",
"pyyaml>=6.0.1",
"psutil>=6.0.0",
"torch",
"transformers",
],
classifiers=[],
entry_points={"console_scripts": ["digest = digest.main:main"]},
Expand Down
14 changes: 12 additions & 2 deletions src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,23 @@ class WarnDialog(QDialog):

def __init__(self, warning_message: str, parent=None):
super().__init__(parent)
self.setWindowTitle("Warning Message")

self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg"))

self.setWindowTitle("Warning Message")
self.setWindowFlags(Qt.WindowType.Dialog)
self.setMinimumWidth(300)

self.setWindowModality(Qt.WindowModality.WindowModal)

layout = QVBoxLayout()

# Application Version
layout.addWidget(QLabel("<b>Something went wrong</b>"))
layout.addWidget(QLabel("<b>Warning</b>"))
layout.addWidget(QLabel(warning_message))

ok_button = QPushButton("OK")
ok_button.clicked.connect(self.accept) # Close dialog when clicked
layout.addWidget(ok_button)

self.setLayout(layout)
2 changes: 1 addition & 1 deletion src/digest/gui_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# For EXE releases we can block certain features e.g. to customers

modules:
huggingface: false
huggingface: true
6 changes: 3 additions & 3 deletions src/digest/histogramchartwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs):
super(StackedHistogramWidget, self).__init__(*args, **kwargs)

self.plot_widget = pg.PlotWidget()
self.plot_widget.setMaximumHeight(150)
self.plot_widget.setMaximumHeight(200)
plot_item = self.plot_widget.getPlotItem()
if plot_item:
plot_item.setContentsMargins(0, 0, 0, 0)
Expand All @@ -157,7 +157,6 @@ def __init__(self, *args, **kwargs):
self.bar_spacing = 25

def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=False):

title_color = "rgb(0,0,0)" if set_ticks else "rgb(200,200,200)"
self.plot_widget.setLabel(
"left",
Expand All @@ -173,7 +172,8 @@ def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=Fal
x_positions = list(range(len(op_count)))
total_count = sum(op_count)
width = 0.6
self.plot_widget.setFixedWidth(len(op_names) * self.bar_spacing)
self.plot_widget.setFixedWidth(500)

for count, x_pos, tick in zip(op_count, x_positions, op_names):
x0 = x_pos - width / 2
y0 = 0
Expand Down
Loading
Loading