Skip to content

Commit

Permalink
Initial commit for feature request #2.
Browse files Browse the repository at this point in the history
- Offers solution for steps 1-3
  • Loading branch information
Philip Colangelo committed Dec 6, 2024
1 parent 6df7d05 commit e33c35c
Show file tree
Hide file tree
Showing 15 changed files with 1,604 additions and 1,076 deletions.
4 changes: 4 additions & 0 deletions examples/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def main(onnx_files: str, output_dir: str):
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.txt")
digest_model.save_txt_report(summary_filepath)

# Model summary yaml report
summary_filepath = os.path.join(output_dir, f"{model_name}_summary.yaml")
digest_model.save_yaml_report(summary_filepath)

# Save csv containing node-level information
nodes_filepath = os.path.join(output_dir, f"{model_name}_nodes.csv")
digest_model.save_nodes_csv_report(nodes_filepath)
Expand Down
2 changes: 1 addition & 1 deletion src/digest/gui_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# For EXE releases we can block certain features e.g. to customers

modules:
huggingface: false
huggingface: true
286 changes: 273 additions & 13 deletions src/digest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from digest.modelsummary import modelSummary
from digest.node_summary import NodeSummary
from digest.qt_utils import apply_dark_style_sheet
from digest.model_class.digest_onnx_model import DigestOnnxModel
from digest.model_class.digest_model import save_node_type_counts_csv_report
from utils import onnx_utils

GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml")
Expand Down Expand Up @@ -161,7 +163,7 @@ def __init__(self, model_file: Optional[str] = None):
self.status_dialog = None
self.err_open_dialog = None
self.temp_dir = tempfile.TemporaryDirectory()
self.digest_models: Dict[str, onnx_utils.DigestOnnxModel] = {}
self.digest_models: Dict[str, DigestOnnxModel] = {}

# QThread containers
self.model_nodes_stats_thread: Dict[str, StatsThread] = {}
Expand Down Expand Up @@ -243,10 +245,10 @@ def uncheck_ingest_buttons(self):
def tab_focused(self, index):
widget = self.ui.tabWidget.widget(index)
if isinstance(widget, modelSummary):
model_id = widget.digest_model.unique_id
unique_id = widget.digest_model.unique_id
if (
self.stats_save_button_flag[model_id]
and self.similarity_save_button_flag[model_id]
self.stats_save_button_flag[unique_id]
and self.similarity_save_button_flag[unique_id]
):
self.ui.saveBtn.setEnabled(True)
else:
Expand All @@ -273,17 +275,28 @@ def closeTab(self, index):

def openFile(self):
filename, _ = QFileDialog.getOpenFileName(
self, "Open File", "", "ONNX Files (*.onnx)"
self, "Open File", "", "ONNX and Report Files (*.onnx *.yaml)"
)

if (
filename and os.path.splitext(filename)[-1] == ".onnx"
): # Only if user selects a file and clicks OK
if not filename:
return

file_ext = os.path.splitext(filename)[-1]

if file_ext == ".onnx":
self.load_onnx(filename)
elif file_ext == ".yaml":
self.load_report(filename)
else:
bad_ext_dialog = StatusDialog(
f"Digest does not support files with the extension {file_ext}",
parent=self,
)
bad_ext_dialog.show()

def update_flops_label(
self,
digest_model: onnx_utils.DigestOnnxModel,
digest_model: DigestOnnxModel,
unique_id: str,
):
self.digest_models[unique_id].model_flops = digest_model.model_flops
Expand Down Expand Up @@ -432,7 +445,7 @@ def load_onnx(self, filepath: str):
basename = os.path.splitext(os.path.basename(filepath))
model_name = basename[0]

digest_model = onnx_utils.DigestOnnxModel(
digest_model = DigestOnnxModel(
onnx_model=model, model_name=model_name, save_proto=False
)
model_id = digest_model.unique_id
Expand Down Expand Up @@ -652,6 +665,251 @@ def load_onnx(self, filepath: str):
except FileNotFoundError as e:
print(f"File not found: {e.filename}")

def load_report(self, filepath: str):

# Ensure the filepath follows a standard formatting:
filepath = os.path.normpath(filepath)

if not os.path.exists(filepath):
return

# Every time a report is loaded we should emulate a model summary button click
self.summary_clicked()

# Before opening the file, check to see if it is already opened.
for index in range(self.ui.tabWidget.count()):
widget = self.ui.tabWidget.widget(index)
if isinstance(widget, modelSummary) and filepath == widget.file:
self.ui.tabWidget.setCurrentIndex(index)
return

try:

progress = ProgressDialog("Loading Digest Report File...", 8, self)
QApplication.processEvents() # Process pending events

with open(filepath, "r", encoding="utf-8") as yaml_f:
report_data = yaml.safe_load(yaml_f)
model_name = report_data["model_name"]

model_id = digest_model.unique_id

# There is no sense in offering to save the report
self.stats_save_button_flag[model_id] = False
self.similarity_save_button_flag[model_id] = False

self.digest_models[model_id] = digest_model

# We must set the proto for the model_summary freeze_inputs
self.digest_models[model_id].model_proto = opt_model

model_summary = modelSummary(self.digest_models[model_id])
model_summary.freeze_inputs.complete_signal.connect(self.load_onnx)

dynamic_input_dims = onnx_utils.get_dynamic_input_dims(opt_model)
if dynamic_input_dims:
model_summary.ui.freezeButton.setVisible(True)
model_summary.ui.warningLabel.setText(
"⚠️ Some model details are unavailable due to dynamic input dimensions. "
"See section Input Tensor(s) Information below for more details."
)
model_summary.ui.warningLabel.show()

elif not opt_passed:
model_summary.ui.warningLabel.setText(
"⚠️ The model could not be optimized either due to an ONNX Runtime "
"session error or it did not pass the ONNX checker."
)
model_summary.ui.warningLabel.show()

progress.step()
progress.setLabelText("Checking for dynamic Inputs")

self.ui.tabWidget.addTab(model_summary, "")
model_summary.ui.flops.setText("Loading...")

# Hide some of the components
model_summary.ui.similarityCorrelation.hide()
model_summary.ui.similarityCorrelationStatic.hide()

model_summary.file = filepath
model_summary.setObjectName(model_name)
model_summary.ui.modelName.setText(model_name)
model_summary.ui.modelFilename.setText(filepath)
model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))

self.digest_models[model_id].model_name = model_name
self.digest_models[model_id].filepath = filepath

self.digest_models[model_id].model_inputs = (
onnx_utils.get_model_input_shapes_types(opt_model)
)
self.digest_models[model_id].model_outputs = (
onnx_utils.get_model_output_shapes_types(opt_model)
)

progress.step()
progress.setLabelText("Calculating Parameter Count")

parameter_count = onnx_utils.get_parameter_count(opt_model)
model_summary.ui.parameters.setText(format(parameter_count, ","))

# Kick off model stats thread
self.model_nodes_stats_thread[model_id] = StatsThread()
self.model_nodes_stats_thread[model_id].completed.connect(
self.update_flops_label
)

self.model_nodes_stats_thread[model_id].model = opt_model
self.model_nodes_stats_thread[model_id].tab_name = model_name
self.model_nodes_stats_thread[model_id].unique_id = model_id
self.model_nodes_stats_thread[model_id].start()

progress.step()
progress.setLabelText("Calculating Node Type Counts")

node_type_counts = onnx_utils.get_node_type_counts(opt_model)
if len(node_type_counts) < 15:
bar_spacing = 40
else:
bar_spacing = 20
model_summary.ui.opHistogramChart.bar_spacing = bar_spacing
model_summary.ui.opHistogramChart.set_data(node_type_counts)
model_summary.ui.nodes.setText(str(sum(node_type_counts.values())))
self.digest_models[model_id].node_type_counts = node_type_counts

progress.step()
progress.setLabelText("Gathering Model Inputs and Outputs")

# Inputs Table
model_summary.ui.inputsTable.setRowCount(
len(self.digest_models[model_id].model_inputs)
)

for row_idx, (input_name, input_info) in enumerate(
self.digest_models[model_id].model_inputs.items()
):
model_summary.ui.inputsTable.setItem(
row_idx, 0, QTableWidgetItem(input_name)
)
model_summary.ui.inputsTable.setItem(
row_idx, 1, QTableWidgetItem(str(input_info.shape))
)
model_summary.ui.inputsTable.setItem(
row_idx, 2, QTableWidgetItem(str(input_info.dtype))
)
model_summary.ui.inputsTable.setItem(
row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes))
)

model_summary.ui.inputsTable.resizeColumnsToContents()
model_summary.ui.inputsTable.resizeRowsToContents()

# Outputs Table
model_summary.ui.outputsTable.setRowCount(
len(self.digest_models[model_id].model_outputs)
)
for row_idx, (output_name, output_info) in enumerate(
self.digest_models[model_id].model_outputs.items()
):
model_summary.ui.outputsTable.setItem(
row_idx, 0, QTableWidgetItem(output_name)
)
model_summary.ui.outputsTable.setItem(
row_idx, 1, QTableWidgetItem(str(output_info.shape))
)
model_summary.ui.outputsTable.setItem(
row_idx, 2, QTableWidgetItem(str(output_info.dtype))
)
model_summary.ui.outputsTable.setItem(
row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes))
)

model_summary.ui.outputsTable.resizeColumnsToContents()
model_summary.ui.outputsTable.resizeRowsToContents()

progress.step()
progress.setLabelText("Gathering Model Proto Data")

# ModelProto Info
model_summary.ui.modelProtoTable.setItem(
0, 1, QTableWidgetItem(str(opt_model.model_version))
)
self.digest_models[model_id].model_version = opt_model.model_version

model_summary.ui.modelProtoTable.setItem(
1, 1, QTableWidgetItem(str(opt_model.graph.name))
)
self.digest_models[model_id].graph_name = opt_model.graph.name

producer_txt = f"{opt_model.producer_name} {opt_model.producer_version}"
model_summary.ui.modelProtoTable.setItem(
2, 1, QTableWidgetItem(producer_txt)
)
self.digest_models[model_id].producer_name = opt_model.producer_name
self.digest_models[model_id].producer_version = opt_model.producer_version

model_summary.ui.modelProtoTable.setItem(
3, 1, QTableWidgetItem(str(opt_model.ir_version))
)
self.digest_models[model_id].ir_version = opt_model.ir_version

for imp in opt_model.opset_import:
row_idx = model_summary.ui.importsTable.rowCount()
model_summary.ui.importsTable.insertRow(row_idx)
if imp.domain == "" or imp.domain == "ai.onnx":
model_summary.ui.opsetVersion.setText(str(imp.version))
domain = "ai.onnx"
self.digest_models[model_id].opset = imp.version
else:
domain = imp.domain
model_summary.ui.importsTable.setItem(
row_idx, 0, QTableWidgetItem(str(domain))
)
model_summary.ui.importsTable.setItem(
row_idx, 1, QTableWidgetItem(str(imp.version))
)
row_idx += 1

self.digest_models[model_id].imports[imp.domain] = imp.version

progress.step()
progress.setLabelText("Wrapping Up Model Analysis")

model_summary.ui.importsTable.resizeColumnsToContents()
model_summary.ui.modelProtoTable.resizeColumnsToContents()
model_summary.setObjectName(model_name)
new_tab_idx = self.ui.tabWidget.count() - 1
self.ui.tabWidget.setTabText(new_tab_idx, "".join(model_name))
self.ui.tabWidget.setCurrentIndex(new_tab_idx)
self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY)
self.ui.singleModelWidget.show()
progress.step()

movie = QMovie(":/assets/gifs/load.gif")
model_summary.ui.similarityImg.setMovie(movie)
movie.start()

# Start similarity Analysis
# Note: Should only be started after the model tab has been created
png_tmp_path = os.path.join(self.temp_dir.name, model_id)
os.makedirs(png_tmp_path, exist_ok=True)
self.model_similarity_thread[model_id] = SimilarityThread()
self.model_similarity_thread[model_id].completed_successfully.connect(
self.update_similarity_widget
)
self.model_similarity_thread[model_id].model_filepath = filepath
self.model_similarity_thread[model_id].png_filepath = os.path.join(
png_tmp_path, f"heatmap_{model_name}.png"
)
self.model_similarity_thread[model_id].model_id = model_id
self.model_similarity_thread[model_id].start()

progress.close()

except FileNotFoundError as e:
print(f"File not found: {e.filename}")

def dragEnterEvent(self, event: QDragEnterEvent):
if event.mimeData().hasUrls():
event.acceptProposedAction()
Expand Down Expand Up @@ -740,9 +998,7 @@ def save_reports(self):
)
node_counter = digest_model.get_node_type_counts()
if node_counter:
onnx_utils.save_node_type_counts_csv_report(
node_counter, node_type_filepath
)
save_node_type_counts_csv_report(node_counter, node_type_filepath)

# Save the similarity image
similarity_png = self.model_similarity_report[digest_model.unique_id].grab()
Expand All @@ -754,6 +1010,10 @@ def save_reports(self):
txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt")
digest_model.save_txt_report(txt_report_filepath)

# Save the yaml report
yaml_report_filepath = os.path.join(save_directory, f"{model_name}_report.yaml")
digest_model.save_yaml_report(yaml_report_filepath)

# Save the node list
nodes_report_filepath = os.path.join(save_directory, f"{model_name}_nodes.csv")
self.save_nodes_csv(nodes_report_filepath, False)
Expand Down
Loading

0 comments on commit e33c35c

Please sign in to comment.