diff --git a/.pylintrc b/.pylintrc index b831756..efb14e7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -81,7 +81,6 @@ enable = expression-not-assigned, confusing-with-statement, unnecessary-lambda, - assign-to-new-keyword, redeclared-assigned-name, pointless-statement, pointless-string-statement, @@ -123,7 +122,6 @@ enable = invalid-length-returned, protected-access, attribute-defined-outside-init, - no-init, abstract-method, invalid-overridden-method, arguments-differ, @@ -165,9 +163,7 @@ enable = ### format # Line length, indentation, whitespace: bad-indentation, - mixed-indentation, unnecessary-semicolon, - bad-whitespace, missing-final-newline, line-too-long, mixed-line-endings, @@ -187,7 +183,6 @@ enable = import-self, preferred-module, reimported, - relative-import, deprecated-module, wildcard-import, misplaced-future, @@ -282,12 +277,6 @@ indent-string = ' ' # black doesn't always obey its own limit. See pyproject.toml. max-line-length = 100 -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check = - # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt = no diff --git a/examples/analysis.py b/examples/analysis.py index e9b9c63..0910637 100644 --- a/examples/analysis.py +++ b/examples/analysis.py @@ -6,14 +6,16 @@ import csv from collections import Counter, defaultdict from tqdm import tqdm +from digest.model_class.digest_model import ( + NodeShapeCounts, + NodeTypeCounts, + save_node_shape_counts_csv_report, + save_node_type_counts_csv_report, +) +from digest.model_class.digest_onnx_model import DigestOnnxModel from utils.onnx_utils import ( get_dynamic_input_dims, load_onnx, - DigestOnnxModel, - save_node_shape_counts_csv_report, - save_node_type_counts_csv_report, - NodeTypeCounts, - NodeShapeCounts, ) GLOBAL_MODEL_HEADERS = [ @@ -71,7 +73,7 @@ def main(onnx_files: str, output_dir: str): print(f"dim: {dynamic_shape}") digest_model = DigestOnnxModel( - model_proto, onnx_filepath=onnx_file, model_name=model_name + model_proto, onnx_file_path=onnx_file, model_name=model_name ) # Update the global model dictionary @@ -82,46 +84,46 @@ def main(onnx_files: str, output_dir: str): global_model_data[model_name] = { "opset": digest_model.opset, - "parameters": digest_model.model_parameters, - "flops": digest_model.model_flops, + "parameters": digest_model.parameters, + "flops": digest_model.flops, } # Model summary text report summary_filepath = os.path.join(output_dir, f"{model_name}_summary.txt") - digest_model.save_txt_report(summary_filepath) + digest_model.save_text_report(summary_filepath) + + # Model summary yaml report + summary_filepath = os.path.join(output_dir, f"{model_name}_summary.yaml") + digest_model.save_yaml_report(summary_filepath) # Save csv containing node-level information nodes_filepath = os.path.join(output_dir, f"{model_name}_nodes.csv") digest_model.save_nodes_csv_report(nodes_filepath) # Save csv containing node type counter - node_type_counter = digest_model.get_node_type_counts() node_type_filepath = os.path.join( output_dir, f"{model_name}_node_type_counts.csv" ) - if node_type_counter: - save_node_type_counts_csv_report(node_type_counter, node_type_filepath) + + digest_model.save_node_type_counts_csv_report(node_type_filepath) # Update global data structure 
for node type counter - global_node_type_counter.update(node_type_counter) + global_node_type_counter.update(digest_model.node_type_counts) # Save csv containing node shape counts per op_type - node_shape_counts = digest_model.get_node_shape_counts() node_shape_filepath = os.path.join( output_dir, f"{model_name}_node_shape_counts.csv" ) - save_node_shape_counts_csv_report(node_shape_counts, node_shape_filepath) + digest_model.save_node_shape_counts_csv_report(node_shape_filepath) # Update global data structure for node shape counter - for node_type, shape_counts in node_shape_counts.items(): + for node_type, shape_counts in digest_model.get_node_shape_counts().items(): global_node_shape_counter[node_type].update(shape_counts) if len(onnx_file_list) > 1: global_filepath = os.path.join(output_dir, "global_node_type_counts.csv") - global_node_type_counter = NodeTypeCounts( - global_node_type_counter.most_common() - ) - save_node_type_counts_csv_report(global_node_type_counter, global_filepath) + global_node_type_counts = NodeTypeCounts(global_node_type_counter.most_common()) + save_node_type_counts_csv_report(global_node_type_counts, global_filepath) global_filepath = os.path.join(output_dir, "global_node_shape_counts.csv") save_node_shape_counts_csv_report(global_node_shape_counter, global_filepath) diff --git a/setup.py b/setup.py index ca21f4a..d6a16f7 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="digestai", - version="1.0.0", + version="1.2.0", description="Model analysis toolkit", author="Philip Colangelo, Daniel Holanda", packages=find_packages(where="src"), @@ -25,6 +25,8 @@ "platformdirs>=4.2.2", "pyyaml>=6.0.1", "psutil>=6.0.0", + "torch", + "transformers", ], classifiers=[], entry_points={"console_scripts": ["digest = digest.main:main"]}, diff --git a/src/digest/dialog.py b/src/digest/dialog.py index d2f834e..ae9986d 100644 --- a/src/digest/dialog.py +++ b/src/digest/dialog.py @@ -125,13 +125,23 @@ class WarnDialog(QDialog): def __init__(self, warning_message: str, parent=None): super().__init__(parent) - self.setWindowTitle("Warning Message") + self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) + + self.setWindowTitle("Warning Message") + self.setWindowFlags(Qt.WindowType.Dialog) self.setMinimumWidth(300) + self.setWindowModality(Qt.WindowModality.WindowModal) + layout = QVBoxLayout() # Application Version - layout.addWidget(QLabel("Something went wrong")) + layout.addWidget(QLabel("Warning")) layout.addWidget(QLabel(warning_message)) + + ok_button = QPushButton("OK") + ok_button.clicked.connect(self.accept) # Close dialog when clicked + layout.addWidget(ok_button) + self.setLayout(layout) diff --git a/src/digest/gui_config.yaml b/src/digest/gui_config.yaml index baffd47..dbd1c08 100644 --- a/src/digest/gui_config.yaml +++ b/src/digest/gui_config.yaml @@ -2,4 +2,4 @@ # For EXE releases we can block certain features e.g. 
to customers modules: - huggingface: false \ No newline at end of file + huggingface: true \ No newline at end of file diff --git a/src/digest/histogramchartwidget.py b/src/digest/histogramchartwidget.py index 97d5f16..f72befb 100644 --- a/src/digest/histogramchartwidget.py +++ b/src/digest/histogramchartwidget.py @@ -140,7 +140,7 @@ def __init__(self, *args, **kwargs): super(StackedHistogramWidget, self).__init__(*args, **kwargs) self.plot_widget = pg.PlotWidget() - self.plot_widget.setMaximumHeight(150) + self.plot_widget.setMaximumHeight(200) plot_item = self.plot_widget.getPlotItem() if plot_item: plot_item.setContentsMargins(0, 0, 0, 0) @@ -157,7 +157,6 @@ def __init__(self, *args, **kwargs): self.bar_spacing = 25 def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=False): - title_color = "rgb(0,0,0)" if set_ticks else "rgb(200,200,200)" self.plot_widget.setLabel( "left", @@ -173,7 +172,8 @@ def set_data(self, data: OrderedDict, model_name, y_max, title="", set_ticks=Fal x_positions = list(range(len(op_count))) total_count = sum(op_count) width = 0.6 - self.plot_widget.setFixedWidth(len(op_names) * self.bar_spacing) + self.plot_widget.setFixedWidth(500) + for count, x_pos, tick in zip(op_count, x_positions, op_names): x0 = x_pos - width / 2 y0 = 0 diff --git a/src/digest/main.py b/src/digest/main.py index 08c401a..70ab66a 100644 --- a/src/digest/main.py +++ b/src/digest/main.py @@ -3,11 +3,13 @@ import os import sys +import shutil import argparse from datetime import datetime -from typing import Dict, Tuple, Optional +from typing import Dict, Tuple, Optional, Union import tempfile from enum import IntEnum +import pandas as pd import yaml # This is a temporary workaround since the Qt designer generated files @@ -33,17 +35,22 @@ QMenu, ) from PySide6.QtGui import QDragEnterEvent, QDropEvent, QPixmap, QMovie, QIcon, QFont -from PySide6.QtCore import Qt, QDir +from PySide6.QtCore import Qt, QSize from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog -from digest.thread import StatsThread, SimilarityThread -from digest.popup_window import PopupWindow +from digest.thread import StatsThread, SimilarityThread, post_process +from digest.popup_window import PopupWindow, PopupDialog from digest.huggingface_page import HuggingfacePage +from digest.pytorch_ingest import PyTorchIngest from digest.multi_model_selection_page import MultiModelSelectionPage from digest.ui.mainwindow_ui import Ui_MainWindow from digest.modelsummary import modelSummary from digest.node_summary import NodeSummary from digest.qt_utils import apply_dark_style_sheet +from digest.model_class.digest_model import DigestModel +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import DigestReportModel +from digest.model_class.digest_pytorch_model import DigestPyTorchModel from utils import onnx_utils GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml") @@ -161,11 +168,16 @@ def __init__(self, model_file: Optional[str] = None): self.status_dialog = None self.err_open_dialog = None self.temp_dir = tempfile.TemporaryDirectory() - self.digest_models: Dict[str, onnx_utils.DigestOnnxModel] = {} + self.digest_models: Dict[ + str, Union[DigestOnnxModel, DigestReportModel, DigestPyTorchModel] + ] = {} + + self.pytorch_ingest_window: Optional[PopupDialog] = None # QThread containers self.model_nodes_stats_thread: Dict[str, StatsThread] = {} self.model_similarity_thread: Dict[str, SimilarityThread] = {} + 
self.model_similarity_report: Dict[str, SimilarityAnalysisReport] = {} self.ui.singleModelWidget.hide() @@ -209,7 +221,7 @@ def __init__(self, model_file: Optional[str] = None): # Set up the HUGGINGFACE Page huggingface_page = HuggingfacePage() - huggingface_page.model_signal.connect(self.load_onnx) + huggingface_page.model_signal.connect(self.load_model) self.ui.stackedWidget.insertWidget(self.Page.HUGGINGFACE, huggingface_page) # Set up the multi model page and relevant button @@ -217,15 +229,19 @@ def __init__(self, model_file: Optional[str] = None): self.ui.stackedWidget.insertWidget( self.Page.MULTIMODEL, self.multimodelselection_page ) - self.multimodelselection_page.model_signal.connect(self.load_onnx) + self.multimodelselection_page.model_signal.connect(self.load_model) + + # Set up the pyptorch ingest page + self.pytorch_ingest: Optional[PyTorchIngest] = None # Load model file if given as input to the executable if model_file: - if ( - os.path.exists(model_file) - and os.path.splitext(model_file)[-1] == ".onnx" - ): + exists = os.path.exists(model_file) + ext = os.path.splitext(model_file)[-1] + if exists and ext == ".onnx": self.load_onnx(model_file) + elif exists and ext == ".yaml": + self.load_report(model_file) else: self.err_open_dialog = StatusDialog( f"Could not open {model_file}", parent=self @@ -243,10 +259,11 @@ def uncheck_ingest_buttons(self): def tab_focused(self, index): widget = self.ui.tabWidget.widget(index) if isinstance(widget, modelSummary): - model_id = widget.digest_model.unique_id + unique_id = widget.digest_model.unique_id if ( - self.stats_save_button_flag[model_id] - and self.similarity_save_button_flag[model_id] + self.stats_save_button_flag[unique_id] + and self.similarity_save_button_flag[unique_id] + and not isinstance(widget.digest_model, DigestReportModel) ): self.ui.saveBtn.setEnabled(True) else: @@ -257,14 +274,20 @@ def closeTab(self, index): if isinstance(summary_widget, modelSummary): unique_id = summary_widget.digest_model.unique_id summary_widget.deleteLater() - tab_thread = self.model_nodes_stats_thread[unique_id] + + tab_thread = self.model_nodes_stats_thread.get(unique_id) if tab_thread: tab_thread.exit() + tab_thread.wait(5000) + if not tab_thread.isRunning(): del self.model_nodes_stats_thread[unique_id] + else: + print(f"Warning: Thread for {unique_id} did not finish in time") + # delete the digest model to free up used memory if unique_id in self.digest_models: - del self.digest_models[unique_id] + self.digest_models.pop(unique_id) self.ui.tabWidget.removeTab(index) if self.ui.tabWidget.count() == 0: @@ -272,40 +295,44 @@ def closeTab(self, index): self.ui.singleModelWidget.hide() def openFile(self): - filename, _ = QFileDialog.getOpenFileName( - self, "Open File", "", "ONNX Files (*.onnx)" + file_name, _ = QFileDialog.getOpenFileName( + self, + "Open File", + "", + "ONNX, PyTorch, and Report Files (*.onnx *.pt *.yaml)", ) - if ( - filename and os.path.splitext(filename)[-1] == ".onnx" - ): # Only if user selects a file and clicks OK - self.load_onnx(filename) + if not file_name: + return - def update_flops_label( + self.load_model(file_name) + + def update_cards( self, - digest_model: onnx_utils.DigestOnnxModel, + digest_model: DigestModel, unique_id: str, ): - self.digest_models[unique_id].model_flops = digest_model.model_flops + self.digest_models[unique_id].flops = digest_model.flops self.digest_models[unique_id].node_type_flops = digest_model.node_type_flops - self.digest_models[unique_id].model_parameters = 
digest_model.model_parameters + self.digest_models[unique_id].parameters = digest_model.parameters self.digest_models[unique_id].node_type_parameters = ( digest_model.node_type_parameters ) - self.digest_models[unique_id].per_node_info = digest_model.per_node_info + self.digest_models[unique_id].node_data = digest_model.node_data # We must iterate over the tabWidget and match to the tab_name because the user # may have switched the currentTab during the threads execution. + curr_index = -1 for index in range(self.ui.tabWidget.count()): widget = self.ui.tabWidget.widget(index) if ( isinstance(widget, modelSummary) and widget.digest_model.unique_id == unique_id ): - if digest_model.model_flops is None: + if digest_model.flops is None: flops_str = "--" else: - flops_str = format(digest_model.model_flops, ",") + flops_str = format(digest_model.flops, ",") # Set up the pie chart pie_chart_labels, pie_chart_data = zip( @@ -328,11 +355,14 @@ def update_flops_label( pie_chart_labels, pie_chart_data, ) + curr_index = index break self.stats_save_button_flag[unique_id] = True - if self.ui.tabWidget.currentIndex() == index: - if self.similarity_save_button_flag[unique_id]: + if self.ui.tabWidget.currentIndex() == curr_index: + if self.similarity_save_button_flag[unique_id] and not isinstance( + digest_model, DigestReportModel + ): self.ui.saveBtn.setEnabled(True) def open_similarity_report(self, model_id: str, image_path, most_similar_models): @@ -346,10 +376,12 @@ def update_similarity_widget( completed_successfully: bool, model_id: str, most_similar: str, - png_filepath: str, + png_file_path: Optional[str] = None, + df_sorted: Optional[pd.DataFrame] = None, ): - widget = None + digest_model = None + curr_index = -1 for index in range(self.ui.tabWidget.count()): tab_widget = self.ui.tabWidget.widget(index) if ( @@ -357,83 +389,146 @@ def update_similarity_widget( and tab_widget.digest_model.unique_id == model_id ): widget = tab_widget + digest_model = tab_widget.digest_model + curr_index = index break - if completed_successfully and isinstance(widget, modelSummary): + # convert back to a List[str] + most_similar_list = most_similar.split(",") + + if ( + completed_successfully + and isinstance(widget, modelSummary) + and digest_model + and png_file_path + ): + + if df_sorted is not None: + post_process( + digest_model.model_name, most_similar_list, df_sorted, png_file_path + ) + + widget.load_gif.stop() + widget.ui.similarityImg.clear() + # We give the image a 10% haircut to fit it more aesthetically widget_width = widget.ui.similarityImg.width() - widget.ui.similarityImg.setPixmap( - QPixmap(png_filepath).scaledToWidth(widget_width) + + pixmap = QPixmap(png_file_path) + aspect_ratio = pixmap.width() / pixmap.height() + target_height = int(widget_width / aspect_ratio) + pixmap_scaled = pixmap.scaled( + QSize(widget_width, target_height), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, ) + + widget.ui.similarityImg.setPixmap(pixmap_scaled) widget.ui.similarityImg.setText("") widget.ui.similarityImg.setCursor(Qt.CursorShape.PointingHandCursor) # Show most correlated models widget.ui.similarityCorrelation.show() widget.ui.similarityCorrelationStatic.show() - most_similar_models = most_similar.split(",") - text = ( - "\n" - f"{most_similar_models[0]}, {most_similar_models[1]}, and {most_similar_models[2]}." 
- "" - ) + + most_similar_list = most_similar_list[1:4] + if most_similar: + text = ( + "\n" + f"{most_similar_list[0]}, {most_similar_list[1]}, " + f"and {most_similar_list[2]}. " + "" + ) + else: + # currently the similarity widget expects the most_similar_models + # to allows contains 3 models. For now we will just send three empty + # strings but at some point we should handle an arbitrary case. + most_similar_list = ["", "", ""] + text = "NTD" # Create option to click to enlarge image widget.ui.similarityImg.mousePressEvent = ( lambda event: self.open_similarity_report( - model_id, png_filepath, most_similar_models + model_id, png_file_path, most_similar_list ) ) # Create option to click to enlarge image self.model_similarity_report[model_id] = SimilarityAnalysisReport( - png_filepath, most_similar_models + png_file_path, most_similar_list ) widget.ui.similarityCorrelation.setText(text) elif isinstance(widget, modelSummary): # Remove animation and set text to failing message - widget.ui.similarityImg.setMovie(QMovie(None)) + widget.load_gif.stop() + widget.ui.similarityImg.clear() widget.ui.similarityImg.setText("Failed to perform similarity analysis") else: - print("Tab widget is not of type modelSummary which is unexpected.") + print( + f"Tab widget is of type {type(widget)} and not of type modelSummary " + "which is unexpected." + ) - # self.similarity_save_button_flag[model_id] = True - if self.ui.tabWidget.currentIndex() == index: - if self.stats_save_button_flag[model_id]: + if self.ui.tabWidget.currentIndex() == curr_index: + if self.stats_save_button_flag[model_id] and not isinstance( + digest_model, DigestReportModel + ): self.ui.saveBtn.setEnabled(True) - def load_onnx(self, filepath: str): + def load_onnx(self, file_path: str): - # Ensure the filepath follows a standard formatting: - filepath = os.path.normpath(filepath) + # Ensure the file_path follows a standard formatting: + file_path = os.path.normpath(file_path) - if not os.path.exists(filepath): + if not os.path.exists(file_path): return # Every time an onnx is loaded we should emulate a model summary button click self.summary_clicked() - # Before opening the file, check to see if it is already opened. + model_proto = None + + # Before opening the ONNX file, check to see if it is already opened. for index in range(self.ui.tabWidget.count()): widget = self.ui.tabWidget.widget(index) - if isinstance(widget, modelSummary) and filepath == widget.file: - self.ui.tabWidget.setCurrentIndex(index) - return + if ( + isinstance(widget, modelSummary) + and isinstance(widget.digest_model, DigestOnnxModel) + and file_path == widget.file + ): + # Check if the model proto is different + if widget.digest_model.model_proto: + model_proto = onnx_utils.load_onnx( + file_path, load_external_data=False + ) + # If they are equivalent, set the GUI to show the existing + # report and return + if model_proto == widget.digest_model.model_proto: + self.ui.tabWidget.setCurrentIndex(index) + return + # If they aren't equivalent, then the proto has been modified. In this case, + # we close the tab associated with the stale model, remove from the model list, + # then go through the standard process of adding it to the tabWidget. In the + # future, it may be slightly better to have an update tab function. 
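# --- Illustrative aside (not part of the patch) ---------------------------
# The stale-tab check above hinges on protobuf equality of ModelProto
# objects: the file is reloaded without external data and compared against
# the proto cached on the already-open tab. A minimal sketch of that idea
# using the public onnx API; the file name is hypothetical:
import onnx

cached_proto = onnx.load("model.onnx", load_external_data=False)
proto_on_disk = onnx.load("model.onnx", load_external_data=False)
if proto_on_disk == cached_proto:
    print("Unchanged on disk; switch to the existing tab")
else:
    print("Proto was modified; close the stale tab and re-ingest")
# ---------------------------------------------------------------------------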
+ else: + self.closeTab(index) try: progress = ProgressDialog("Loading & Optimizing ONNX Model...", 8, self) QApplication.processEvents() # Process pending events - model = onnx_utils.load_onnx(filepath, load_external_data=False) - opt_model, opt_passed = onnx_utils.optimize_onnx_model(model) + if not model_proto: + model_proto = onnx_utils.load_onnx(file_path, load_external_data=False) + opt_model, opt_passed = onnx_utils.optimize_onnx_model(model_proto) progress.step() - basename = os.path.splitext(os.path.basename(filepath)) + basename = os.path.splitext(os.path.basename(file_path)) model_name = basename[0] - digest_model = onnx_utils.DigestOnnxModel( - onnx_model=model, model_name=model_name, save_proto=False + # Save the model proto so we can use the Freeze Inputs feature + digest_model = DigestOnnxModel( + onnx_model=opt_model, model_name=model_name, save_proto=True ) model_id = digest_model.unique_id @@ -442,11 +537,9 @@ def load_onnx(self, filepath: str): self.digest_models[model_id] = digest_model - # We must set the proto for the model_summary freeze_inputs - self.digest_models[model_id].model_proto = opt_model - - model_summary = modelSummary(self.digest_models[model_id]) - model_summary.freeze_inputs.complete_signal.connect(self.load_onnx) + model_summary = modelSummary(digest_model) + if model_summary.freeze_inputs: + model_summary.freeze_inputs.complete_signal.connect(self.load_onnx) dynamic_input_dims = onnx_utils.get_dynamic_input_dims(opt_model) if dynamic_input_dims: @@ -474,20 +567,19 @@ def load_onnx(self, filepath: str): model_summary.ui.similarityCorrelation.hide() model_summary.ui.similarityCorrelationStatic.hide() - model_summary.file = filepath + model_summary.file = file_path model_summary.setObjectName(model_name) model_summary.ui.modelName.setText(model_name) - model_summary.ui.modelFilename.setText(filepath) + model_summary.ui.modelFilename.setText(file_path) model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) - self.digest_models[model_id].model_name = model_name - self.digest_models[model_id].filepath = filepath - - self.digest_models[model_id].model_inputs = ( - onnx_utils.get_model_input_shapes_types(opt_model) + digest_model.model_name = model_name + digest_model.file_path = file_path + digest_model.model_inputs = onnx_utils.get_model_input_shapes_types( + opt_model ) - self.digest_models[model_id].model_outputs = ( - onnx_utils.get_model_output_shapes_types(opt_model) + digest_model.model_outputs = onnx_utils.get_model_output_shapes_types( + opt_model ) progress.step() @@ -498,9 +590,7 @@ def load_onnx(self, filepath: str): # Kick off model stats thread self.model_nodes_stats_thread[model_id] = StatsThread() - self.model_nodes_stats_thread[model_id].completed.connect( - self.update_flops_label - ) + self.model_nodes_stats_thread[model_id].completed.connect(self.update_cards) self.model_nodes_stats_thread[model_id].model = opt_model self.model_nodes_stats_thread[model_id].tab_name = model_name @@ -518,7 +608,7 @@ def load_onnx(self, filepath: str): model_summary.ui.opHistogramChart.bar_spacing = bar_spacing model_summary.ui.opHistogramChart.set_data(node_type_counts) model_summary.ui.nodes.setText(str(sum(node_type_counts.values()))) - self.digest_models[model_id].node_type_counts = node_type_counts + digest_model.node_type_counts = node_type_counts progress.step() progress.setLabelText("Gathering Model Inputs and Outputs") @@ -577,24 +667,24 @@ def load_onnx(self, filepath: str): model_summary.ui.modelProtoTable.setItem( 0, 1, 
QTableWidgetItem(str(opt_model.model_version)) ) - self.digest_models[model_id].model_version = opt_model.model_version + digest_model.model_version = opt_model.model_version model_summary.ui.modelProtoTable.setItem( 1, 1, QTableWidgetItem(str(opt_model.graph.name)) ) - self.digest_models[model_id].graph_name = opt_model.graph.name + digest_model.graph_name = opt_model.graph.name producer_txt = f"{opt_model.producer_name} {opt_model.producer_version}" model_summary.ui.modelProtoTable.setItem( 2, 1, QTableWidgetItem(producer_txt) ) - self.digest_models[model_id].producer_name = opt_model.producer_name - self.digest_models[model_id].producer_version = opt_model.producer_version + digest_model.producer_name = opt_model.producer_name + digest_model.producer_version = opt_model.producer_version model_summary.ui.modelProtoTable.setItem( 3, 1, QTableWidgetItem(str(opt_model.ir_version)) ) - self.digest_models[model_id].ir_version = opt_model.ir_version + digest_model.ir_version = opt_model.ir_version for imp in opt_model.opset_import: row_idx = model_summary.ui.importsTable.rowCount() @@ -602,7 +692,7 @@ def load_onnx(self, filepath: str): if imp.domain == "" or imp.domain == "ai.onnx": model_summary.ui.opsetVersion.setText(str(imp.version)) domain = "ai.onnx" - self.digest_models[model_id].opset = imp.version + digest_model.opset = imp.version else: domain = imp.domain model_summary.ui.importsTable.setItem( @@ -613,7 +703,7 @@ def load_onnx(self, filepath: str): ) row_idx += 1 - self.digest_models[model_id].imports[imp.domain] = imp.version + digest_model.imports[imp.domain] = imp.version progress.step() progress.setLabelText("Wrapping Up Model Analysis") @@ -628,20 +718,17 @@ def load_onnx(self, filepath: str): self.ui.singleModelWidget.show() progress.step() - movie = QMovie(":/assets/gifs/load.gif") - model_summary.ui.similarityImg.setMovie(movie) - movie.start() - # Start similarity Analysis # Note: Should only be started after the model tab has been created png_tmp_path = os.path.join(self.temp_dir.name, model_id) os.makedirs(png_tmp_path, exist_ok=True) + assert os.path.exists(png_tmp_path), f"Error with creating {png_tmp_path}" self.model_similarity_thread[model_id] = SimilarityThread() self.model_similarity_thread[model_id].completed_successfully.connect( self.update_similarity_widget ) - self.model_similarity_thread[model_id].model_filepath = filepath - self.model_similarity_thread[model_id].png_filepath = os.path.join( + self.model_similarity_thread[model_id].model_file_path = file_path + self.model_similarity_thread[model_id].png_file_path = os.path.join( png_tmp_path, f"heatmap_{model_name}.png" ) self.model_similarity_thread[model_id].model_id = model_id @@ -652,6 +739,243 @@ def load_onnx(self, filepath: str): except FileNotFoundError as e: print(f"File not found: {e.filename}") + def load_report(self, file_path: str): + + # Ensure the file_path follows a standard formatting: + file_path = os.path.normpath(file_path) + + if not os.path.exists(file_path): + return + + # Every time a report is loaded we should emulate a model summary button click + self.summary_clicked() + + # Before opening the file, check to see if it is already opened. 
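# --- Illustrative aside (not part of the patch) ---------------------------
# load_report consumes the YAML files produced by save_yaml_report later in
# this patch. A rough sketch of inspecting one directly with plain pyyaml;
# the file name is hypothetical and the field names follow the report
# writer in digest_onnx_model.py:
import yaml

with open("resnet18_report.yaml", "r", encoding="utf-8") as f_p:
    report = yaml.safe_load(f_p)

print(report["model_name"], report["opset"])
print("parameters:", report["parameters"], "flops:", report["flops"])
print("node type counts:", report["node_type_counts"])
# ---------------------------------------------------------------------------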
+ for index in range(self.ui.tabWidget.count()): + widget = self.ui.tabWidget.widget(index) + if isinstance(widget, modelSummary) and file_path == widget.file: + self.ui.tabWidget.setCurrentIndex(index) + return + + try: + + progress = ProgressDialog("Loading Digest Report File...", 2, self) + QApplication.processEvents() # Process pending events + + digest_model = DigestReportModel(file_path) + + if not digest_model.is_valid: + progress.close() + invalid_yaml_dialog = StatusDialog( + title="Warning", + status_message=f"YAML file {file_path} is not a valid digest report", + ) + invalid_yaml_dialog.show() + + return + + model_id = digest_model.unique_id + + # There is no sense in offering to save the report + self.stats_save_button_flag[model_id] = False + self.similarity_save_button_flag[model_id] = False + + self.digest_models[model_id] = digest_model + + model_summary = modelSummary(digest_model) + + self.ui.tabWidget.addTab(model_summary, "") + model_summary.ui.flops.setText("Loading...") + + # Hide some of the components + model_summary.ui.similarityCorrelation.hide() + model_summary.ui.similarityCorrelationStatic.hide() + + model_summary.file = file_path + model_summary.setObjectName(digest_model.model_name) + model_summary.ui.modelName.setText(digest_model.model_name) + model_summary.ui.modelFilename.setText(file_path) + model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y")) + + model_summary.ui.parameters.setText(format(digest_model.parameters, ",")) + + node_type_counts = digest_model.node_type_counts + if len(node_type_counts) < 15: + bar_spacing = 40 + else: + bar_spacing = 20 + + model_summary.ui.opHistogramChart.bar_spacing = bar_spacing + model_summary.ui.opHistogramChart.set_data(node_type_counts) + model_summary.ui.nodes.setText(str(sum(node_type_counts.values()))) + + progress.step() + progress.setLabelText("Gathering Model Inputs and Outputs") + + # Inputs Table + model_summary.ui.inputsTable.setRowCount( + len(self.digest_models[model_id].model_inputs) + ) + + for row_idx, (input_name, input_info) in enumerate( + self.digest_models[model_id].model_inputs.items() + ): + model_summary.ui.inputsTable.setItem( + row_idx, 0, QTableWidgetItem(input_name) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 1, QTableWidgetItem(str(input_info.shape)) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 2, QTableWidgetItem(str(input_info.dtype)) + ) + model_summary.ui.inputsTable.setItem( + row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes)) + ) + + model_summary.ui.inputsTable.resizeColumnsToContents() + model_summary.ui.inputsTable.resizeRowsToContents() + + # Outputs Table + model_summary.ui.outputsTable.setRowCount( + len(self.digest_models[model_id].model_outputs) + ) + for row_idx, (output_name, output_info) in enumerate( + self.digest_models[model_id].model_outputs.items() + ): + model_summary.ui.outputsTable.setItem( + row_idx, 0, QTableWidgetItem(output_name) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 1, QTableWidgetItem(str(output_info.shape)) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 2, QTableWidgetItem(str(output_info.dtype)) + ) + model_summary.ui.outputsTable.setItem( + row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes)) + ) + + model_summary.ui.outputsTable.resizeColumnsToContents() + model_summary.ui.outputsTable.resizeRowsToContents() + + progress.step() + progress.setLabelText("Gathering Model Proto Data") + + # ModelProto Info + model_summary.ui.modelProtoTable.setItem( + 0, 1, 
QTableWidgetItem(str(digest_model.model_data["model_version"])) + ) + + model_summary.ui.modelProtoTable.setItem( + 1, 1, QTableWidgetItem(str(digest_model.model_data["graph_name"])) + ) + + producer_txt = ( + f"{digest_model.model_data['producer_name']} " + f"{digest_model.model_data['producer_version']}" + ) + model_summary.ui.modelProtoTable.setItem( + 2, 1, QTableWidgetItem(producer_txt) + ) + + model_summary.ui.modelProtoTable.setItem( + 3, 1, QTableWidgetItem(str(digest_model.model_data["ir_version"])) + ) + + for domain, version in digest_model.model_data["import_list"].items(): + row_idx = model_summary.ui.importsTable.rowCount() + model_summary.ui.importsTable.insertRow(row_idx) + if domain == "" or domain == "ai.onnx": + model_summary.ui.opsetVersion.setText(str(version)) + domain = "ai.onnx" + + model_summary.ui.importsTable.setItem( + row_idx, 0, QTableWidgetItem(str(domain)) + ) + model_summary.ui.importsTable.setItem( + row_idx, 1, QTableWidgetItem(str(version)) + ) + row_idx += 1 + + progress.step() + progress.setLabelText("Wrapping Up Model Analysis") + + model_summary.ui.importsTable.resizeColumnsToContents() + model_summary.ui.modelProtoTable.resizeColumnsToContents() + model_summary.setObjectName(digest_model.model_name) + new_tab_idx = self.ui.tabWidget.count() - 1 + self.ui.tabWidget.setTabText(new_tab_idx, "".join(digest_model.model_name)) + self.ui.tabWidget.setCurrentIndex(new_tab_idx) + self.ui.stackedWidget.setCurrentIndex(self.Page.SUMMARY) + self.ui.singleModelWidget.show() + progress.step() + + self.update_cards(digest_model, digest_model.unique_id) + + movie = QMovie(":/assets/gifs/load.gif") + model_summary.ui.similarityImg.setMovie(movie) + movie.start() + + self.update_similarity_widget( + completed_successfully=bool(digest_model.similarity_heatmap_path), + model_id=digest_model.unique_id, + most_similar="", + png_file_path=digest_model.similarity_heatmap_path, + ) + + progress.close() + + except FileNotFoundError as e: + print(f"File not found: {e.filename}") + + def load_pytorch(self, file_path: str): + # Ensure the file_path follows a standard formatting: + file_path = os.path.normpath(file_path) + + if not os.path.exists(file_path): + return + + basename = os.path.splitext(os.path.basename(file_path)) + model_name = basename[0] + + # The current support for PyTorch includes exporting it to ONNX. In this case, + # an ingest window will pop up giving the user options to export. 
This window + # will block the main GUI until the ingest window is closed + self.pytorch_ingest = PyTorchIngest(file_path, model_name) + self.pytorch_ingest_window = PopupDialog( + self.pytorch_ingest, "PyTorch Ingest", self + ) + self.pytorch_ingest_window.open() + + # The above code will block until the user has completed the pytorch ingest form + # The form will exit upon a successful export at which point the path will be set + if self.pytorch_ingest.digest_pytorch_model.onnx_file_path: + self.load_onnx(self.pytorch_ingest.digest_pytorch_model.onnx_file_path) + + def load_model(self, file_path: str): + + # Ensure the file_path follows a standard formatting: + file_path = os.path.normpath(file_path) + + if not os.path.exists(file_path): + return + + file_ext = os.path.splitext(file_path)[-1] + + if file_ext == ".onnx": + self.load_onnx(file_path) + elif file_ext == ".yaml": + self.load_report(file_path) + elif file_ext == ".pt" or file_ext == ".pth": + self.load_pytorch(file_path) + else: + bad_ext_dialog = StatusDialog( + f"Digest does not support files with the extension {file_ext}", + parent=self, + ) + bad_ext_dialog.show() + def dragEnterEvent(self, event: QDragEnterEvent): if event.mimeData().hasUrls(): event.acceptProposedAction() @@ -660,9 +984,7 @@ def dropEvent(self, event: QDropEvent): if event.mimeData().hasUrls(): for url in event.mimeData().urls(): file_path = url.toLocalFile() - if file_path.endswith(".onnx"): - self.load_onnx(file_path) - break + self.load_model(file_path) ## functions for changing menu page def logo_clicked(self): @@ -710,14 +1032,8 @@ def save_reports(self): self, "Select Directory" ) - if not save_directory: - return - - # Create a QDir object - directory = QDir(save_directory) - # Check if the directory exists and is writable - if not directory.exists() and directory.isWritable(): # type: ignore + if not os.path.exists(save_directory) or not os.access(save_directory, os.W_OK): self.show_warning_dialog( f"The directory {save_directory} is not valid or writable." 
) @@ -726,43 +1042,57 @@ def save_reports(self): save_directory, str(digest_model.model_name) + "_reports" ) - os.makedirs(save_directory, exist_ok=True) - - # Save the node histogram image - node_histogram = current_tab.ui.opHistogramChart.grab() - node_histogram.save( - os.path.join(save_directory, f"{model_name}_histogram.png"), "PNG" - ) + try: + os.makedirs(save_directory, exist_ok=True) - # Save csv of node type counts - node_type_filepath = os.path.join( - save_directory, f"{model_name}_node_type_counts.csv" - ) - node_counter = digest_model.get_node_type_counts() - if node_counter: - onnx_utils.save_node_type_counts_csv_report( - node_counter, node_type_filepath + # Save the node histogram image + node_histogram = current_tab.ui.opHistogramChart.grab() + node_histogram.save( + os.path.join(save_directory, f"{model_name}_histogram.png"), "PNG" ) - # Save the similarity image - similarity_png = self.model_similarity_report[digest_model.unique_id].grab() - similarity_png.save( - os.path.join(save_directory, f"{model_name}_heatmap.png"), "PNG" - ) + # Save csv of node type counts + node_type_file_path = os.path.join( + save_directory, f"{model_name}_node_type_counts.csv" + ) + digest_model.save_node_type_counts_csv_report(node_type_file_path) + + # Save (copy) the similarity image + png_file_path = self.model_similarity_thread[ + digest_model.unique_id + ].png_file_path + png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png") + if png_file_path and os.path.exists(png_file_path): + shutil.copy(png_file_path, png_save_path) + + # Save the text report + txt_report_file_path = os.path.join( + save_directory, f"{model_name}_report.txt" + ) + digest_model.save_text_report(txt_report_file_path) - # Save the text report - txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt") - digest_model.save_txt_report(txt_report_filepath) + # Save the yaml report + yaml_report_file_path = os.path.join( + save_directory, f"{model_name}_report.yaml" + ) + digest_model.save_yaml_report(yaml_report_file_path) - # Save the node list - nodes_report_filepath = os.path.join(save_directory, f"{model_name}_nodes.csv") - self.save_nodes_csv(nodes_report_filepath, False) + # Save the node list + nodes_report_file_path = os.path.join( + save_directory, f"{model_name}_nodes.csv" + ) + self.save_nodes_csv(nodes_report_file_path, False) - self.status_dialog = StatusDialog( - f"Saved reports to: \n{os.path.abspath(save_directory)}", - "Successfully saved reports!", - ) - self.status_dialog.show() + self.save_nodes_csv(nodes_report_file_path, False) + except Exception as exception: # pylint: disable=broad-exception-caught + self.status_dialog = StatusDialog(f"{exception}") + self.status_dialog.show() + else: + self.status_dialog = StatusDialog( + f"Saved reports to: \n{os.path.abspath(save_directory)}", + "Successfully saved reports!", + ) + self.status_dialog.show() def on_dialog_closed(self): self.infoDialog = None @@ -792,20 +1122,20 @@ def save_file_dialog( ) return path, filter_type - def save_parameters_csv(self, filepath: str, open_dialog: bool = True): - self.save_nodes_csv(filepath, open_dialog) + def save_parameters_csv(self, file_path: str, open_dialog: bool = True): + self.save_nodes_csv(file_path, open_dialog) - def save_flops_csv(self, filepath: str, open_dialog: bool = True): - self.save_nodes_csv(filepath, open_dialog) + def save_flops_csv(self, file_path: str, open_dialog: bool = True): + self.save_nodes_csv(file_path, open_dialog) - def save_nodes_csv(self, 
csv_filepath: Optional[str], open_dialog: bool = True): + def save_nodes_csv(self, csv_file_path: Optional[str], open_dialog: bool = True): if open_dialog: - csv_filepath, _ = self.save_file_dialog() - if not csv_filepath: - raise ValueError("A filepath must be given.") + csv_file_path, _ = self.save_file_dialog() + if not csv_file_path: + raise ValueError("A file_path must be given.") current_tab = self.ui.tabWidget.currentWidget() if isinstance(current_tab, modelSummary): - current_tab.digest_model.save_nodes_csv_report(csv_filepath) + current_tab.digest_model.save_nodes_csv_report(csv_file_path) def save_chart(self, chart_view): path, _ = self.save_file_dialog("Save PNG", "PNG(*.png)") @@ -829,7 +1159,7 @@ def open_node_summary(self): digest_models = self.digest_models[model_id] node_summary = NodeSummary( - model_name=model_name, node_data=digest_models.per_node_info + model_name=model_name, node_data=digest_models.node_data ) self.nodes_window[model_id] = PopupWindow( diff --git a/src/digest/model_class/digest_model.py b/src/digest/model_class/digest_model.py new file mode 100644 index 0000000..49b4e52 --- /dev/null +++ b/src/digest/model_class/digest_model.py @@ -0,0 +1,235 @@ +# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. + +import os +import csv +from enum import Enum +from dataclasses import dataclass, field +from uuid import uuid4 +from abc import ABC, abstractmethod +from collections import Counter, OrderedDict, defaultdict +from typing import List, Dict, Optional, Any, Union + + +class SupportedModelTypes(Enum): + ONNX = "onnx" + REPORT = "report" + PYTORCH = "pytorch" + + +class NodeParsingException(Exception): + pass + + +# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias +class NodeShapeCounts(defaultdict[str, Counter]): + def __init__(self): + super().__init__(Counter) # Initialize with the Counter factory + + +class NodeTypeCounts(Dict[str, int]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +@dataclass +class TensorInfo: + "Used to store node input and output tensor information" + dtype: Optional[str] = None + dtype_bytes: Optional[int] = None + size_kbytes: Optional[float] = None + shape: List[Union[int, str]] = field(default_factory=list) + + +class TensorData(OrderedDict[str, TensorInfo]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class NodeInfo: + def __init__(self) -> None: + self.flops: Optional[int] = None + self.parameters: int = 0 # TODO: should we make this Optional[int] = None? + self.node_type: Optional[str] = None + self.attributes: OrderedDict[str, Any] = OrderedDict() + # We use an ordered dictionary because the order in which + # the inputs and outputs are listed in the node matter. 
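# --- Illustrative aside (not part of the patch) ---------------------------
# TensorData is an ordered mapping of tensor name -> TensorInfo, so a node's
# inputs and outputs keep their positional order and can be fetched by index
# via get_input()/get_output(). A hand-built example with made-up values:
node = NodeInfo()
node.node_type = "Conv"
node.inputs["input.1"] = TensorInfo(
    dtype="float32", dtype_bytes=4, size_kbytes=588.0, shape=[1, 3, 224, 224]
)
node.inputs["conv1.weight"] = TensorInfo(
    dtype="float32", dtype_bytes=4, size_kbytes=36.75, shape=[64, 3, 7, 7]
)
assert node.get_input(0).shape == [1, 3, 224, 224]
# ---------------------------------------------------------------------------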
+ self.inputs = TensorData() + self.outputs = TensorData() + + def get_input(self, index: int) -> TensorInfo: + return list(self.inputs.values())[index] + + def get_output(self, index: int) -> TensorInfo: + return list(self.outputs.values())[index] + + def __str__(self): + """Provides a human-readable string representation of NodeInfo.""" + output = [ + f"Node Type: {self.node_type}", + f"FLOPs: {self.flops if self.flops is not None else 'N/A'}", + f"Parameters: {self.parameters}", + ] + + if self.attributes: + output.append("Attributes:") + for key, value in self.attributes.items(): + output.append(f" - {key}: {value}") + + if self.inputs: + output.append("Inputs:") + for name, tensor in self.inputs.items(): + output.append(f" - {name}: {tensor}") + + if self.outputs: + output.append("Outputs:") + for name, tensor in self.outputs.items(): + output.append(f" - {name}: {tensor}") + + return "\n".join(output) + + +# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias +class NodeData(OrderedDict[str, NodeInfo]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class DigestModel(ABC): + def __init__( + self, file_path: str, model_name: str, model_type: SupportedModelTypes + ): + # Public members exposed to the API + self.unique_id: str = str(uuid4()) + self.file_path: Optional[str] = os.path.abspath(file_path) + self.model_name: str = model_name + self.model_type: SupportedModelTypes = model_type + self.node_type_counts: NodeTypeCounts = NodeTypeCounts() + self.flops: Optional[int] = None + self.parameters: int = 0 + self.node_type_flops: Dict[str, int] = {} + self.node_type_parameters: Dict[str, int] = {} + self.node_data = NodeData() + self.model_inputs = TensorData() + self.model_outputs = TensorData() + + def get_node_shape_counts(self) -> NodeShapeCounts: + tensor_shape_counter = NodeShapeCounts() + for _, info in self.node_data.items(): + shape_hash = tuple([tuple(v.shape) for _, v in info.inputs.items()]) + if info.node_type: + tensor_shape_counter[info.node_type][shape_hash] += 1 + return tensor_shape_counter + + @abstractmethod + def parse_model_nodes(self, *args, **kwargs) -> None: + pass + + @abstractmethod + def save_yaml_report(self, file_path: str) -> None: + pass + + @abstractmethod + def save_text_report(self, file_path: str) -> None: + pass + + def save_nodes_csv_report(self, file_path: str) -> None: + save_nodes_csv_report(self.node_data, file_path) + + def save_node_type_counts_csv_report(self, file_path: str) -> None: + if self.node_type_counts: + save_node_type_counts_csv_report(self.node_type_counts, file_path) + + def save_node_shape_counts_csv_report(self, file_path: str) -> None: + save_node_shape_counts_csv_report(self.get_node_shape_counts(), file_path) + + +def save_nodes_csv_report(node_data: NodeData, file_path: str) -> None: + + parent_dir = os.path.dirname(os.path.abspath(file_path)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + flattened_data = [] + fieldnames = ["Node Name", "Node Type", "Parameters", "FLOPs", "Attributes"] + input_fieldnames = [] + output_fieldnames = [] + for name, node_info in node_data.items(): + row = OrderedDict() + row["Node Name"] = name + row["Node Type"] = str(node_info.node_type) + row["Parameters"] = str(node_info.parameters) + row["FLOPs"] = str(node_info.flops) + if node_info.attributes: + row["Attributes"] = str({k: v for k, v in node_info.attributes.items()}) + else: + row["Attributes"] = "" + 
+ for i, (input_name, input_info) in enumerate(node_info.inputs.items()): + column_name = f"Input{i+1} (Shape, Dtype, Size (kB))" + row[column_name] = ( + f"{input_name} ({input_info.shape}, {input_info.dtype}, {input_info.size_kbytes})" + ) + + # Dynamically add input column names to fieldnames if not already present + if column_name not in input_fieldnames: + input_fieldnames.append(column_name) + + for i, (output_name, output_info) in enumerate(node_info.outputs.items()): + column_name = f"Output{i+1} (Shape, Dtype, Size (kB))" + row[column_name] = ( + f"{output_name} ({output_info.shape}, " + f"{output_info.dtype}, {output_info.size_kbytes})" + ) + + # Dynamically add input column names to fieldnames if not already present + if column_name not in output_fieldnames: + output_fieldnames.append(column_name) + + flattened_data.append(row) + + fieldnames = fieldnames + input_fieldnames + output_fieldnames + try: + with open(file_path, "w", encoding="utf-8", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") + writer.writeheader() + writer.writerows(flattened_data) + except PermissionError as exception: + raise PermissionError( + f"Saving reports to {file_path} failed with error {exception}" + ) + + +def save_node_type_counts_csv_report( + node_type_counts: NodeTypeCounts, file_path: str +) -> None: + + parent_dir = os.path.dirname(os.path.abspath(file_path)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + header = ["Node Type", "Count"] + + with open(file_path, "w", encoding="utf-8", newline="") as csvfile: + writer = csv.writer(csvfile, lineterminator="\n") + writer.writerow(header) + for node_type, node_count in node_type_counts.items(): + writer.writerow([node_type, node_count]) + + +def save_node_shape_counts_csv_report( + node_shape_counts: NodeShapeCounts, file_path: str +) -> None: + + parent_dir = os.path.dirname(os.path.abspath(file_path)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + header = ["Node Type", "Input Tensors Shapes", "Count"] + + with open(file_path, "w", encoding="utf-8", newline="") as csvfile: + writer = csv.writer(csvfile, dialect="excel", lineterminator="\n") + writer.writerow(header) + for node_type, node_info in node_shape_counts.items(): + info_iter = iter(node_info.items()) + for shape, count in info_iter: + writer.writerow([node_type, shape, count]) diff --git a/src/digest/model_class/digest_onnx_model.py b/src/digest/model_class/digest_onnx_model.py new file mode 100644 index 0000000..17d1201 --- /dev/null +++ b/src/digest/model_class/digest_onnx_model.py @@ -0,0 +1,654 @@ +# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. 
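# --- Illustrative aside (not part of the patch) ---------------------------
# The module-level CSV writers above remain available for aggregate,
# multi-model reports, which is how examples/analysis.py uses them for its
# global counters. A rough sketch with made-up counts and an output path in
# the current directory:
from digest.model_class.digest_model import (
    NodeTypeCounts,
    save_node_type_counts_csv_report,
)

global_counts = NodeTypeCounts({"Conv": 53, "Relu": 49, "Add": 16})
save_node_type_counts_csv_report(global_counts, "global_node_type_counts.csv")
# ---------------------------------------------------------------------------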
+ +import os +from typing import List, Dict, Optional, Tuple, cast +from datetime import datetime +import importlib.metadata +from collections import OrderedDict +import yaml +import numpy as np +import onnx +from prettytable import PrettyTable +from digest.model_class.digest_model import ( + DigestModel, + SupportedModelTypes, + NodeInfo, + TensorData, + TensorInfo, +) +import utils.onnx_utils as onnx_utils + + +class DigestOnnxModel(DigestModel): + def __init__( + self, + onnx_model: onnx.ModelProto, + onnx_file_path: str = "", + model_name: str = "", + save_proto: bool = True, + ) -> None: + super().__init__(onnx_file_path, model_name, SupportedModelTypes.ONNX) + + # Public members exposed to the API + self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None + self.model_version: Optional[int] = None + self.graph_name: Optional[str] = None + self.producer_name: Optional[str] = None + self.producer_version: Optional[str] = None + self.ir_version: Optional[int] = None + self.opset: Optional[int] = None + self.imports: OrderedDict[str, int] = OrderedDict() + + # Private members not intended to be exposed + self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {} + self.output_tensors_: Dict[str, onnx.ValueInfoProto] = {} + self.value_tensors_: Dict[str, onnx.ValueInfoProto] = {} + self.init_tensors_: Dict[str, onnx.TensorProto] = {} + + self.update_state(onnx_model) + + def update_state(self, model_proto: onnx.ModelProto) -> None: + self.model_version = model_proto.model_version + self.graph_name = model_proto.graph.name + self.producer_name = model_proto.producer_name + self.producer_version = model_proto.producer_version + self.ir_version = model_proto.ir_version + self.opset = onnx_utils.get_opset(model_proto) + self.imports = OrderedDict( + sorted( + (import_.domain, import_.version) + for import_ in model_proto.opset_import + ) + ) + + self.model_inputs = onnx_utils.get_model_input_shapes_types(model_proto) + self.model_outputs = onnx_utils.get_model_output_shapes_types(model_proto) + + self.node_type_counts = onnx_utils.get_node_type_counts(model_proto) + self.parse_model_nodes(model_proto) + + def get_node_tensor_info_( + self, onnx_node: onnx.NodeProto + ) -> Tuple[TensorData, TensorData]: + """ + This function is set to private because it is not intended to be used + outside of the DigestOnnxModel class. 
+ """ + + input_tensor_info = TensorData() + for node_input in onnx_node.input: + input_tensor_info[node_input] = TensorInfo() + if ( + node_input in self.input_tensors_ + or node_input in self.value_tensors_ + or node_input in self.output_tensors_ + ): + tensor = ( + self.input_tensors_.get(node_input) + or self.value_tensors_.get(node_input) + or self.output_tensors_.get(node_input) + ) + if tensor: + for dim in tensor.type.tensor_type.shape.dim: + if dim.HasField("dim_value"): + input_tensor_info[node_input].shape.append(dim.dim_value) + elif dim.HasField("dim_param"): + input_tensor_info[node_input].shape.append(dim.dim_param) + + dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size( + tensor.type.tensor_type.elem_type + ) + elif node_input in self.init_tensors_: + input_tensor_info[node_input].shape.extend( + [dim for dim in self.init_tensors_[node_input].dims] + ) + dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size( + self.init_tensors_[node_input].data_type + ) + else: + dtype_str = None + dtype_bytes = None + + input_tensor_info[node_input].dtype = dtype_str + input_tensor_info[node_input].dtype_bytes = dtype_bytes + + if ( + all(isinstance(s, int) for s in input_tensor_info[node_input].shape) + and dtype_bytes + ): + tensor_size = float( + np.prod(np.array(input_tensor_info[node_input].shape)) + ) + input_tensor_info[node_input].size_kbytes = ( + tensor_size * float(dtype_bytes) / 1024.0 + ) + + output_tensor_info = TensorData() + for node_output in onnx_node.output: + output_tensor_info[node_output] = TensorInfo() + if ( + node_output in self.input_tensors_ + or node_output in self.value_tensors_ + or node_output in self.output_tensors_ + ): + tensor = ( + self.input_tensors_.get(node_output) + or self.value_tensors_.get(node_output) + or self.output_tensors_.get(node_output) + ) + if tensor: + output_tensor_info[node_output].shape.extend( + [ + int(dim.dim_value) + for dim in tensor.type.tensor_type.shape.dim + ] + ) + dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size( + tensor.type.tensor_type.elem_type + ) + elif node_output in self.init_tensors_: + output_tensor_info[node_output].shape.extend( + [dim for dim in self.init_tensors_[node_output].dims] + ) + dtype_str, dtype_bytes = onnx_utils.tensor_type_to_str_and_size( + self.init_tensors_[node_output].data_type + ) + + else: + dtype_str = None + dtype_bytes = None + + output_tensor_info[node_output].dtype = dtype_str + output_tensor_info[node_output].dtype_bytes = dtype_bytes + + if ( + all(isinstance(s, int) for s in output_tensor_info[node_output].shape) + and dtype_bytes + ): + tensor_size = float( + np.prod(np.array(output_tensor_info[node_output].shape)) + ) + output_tensor_info[node_output].size_kbytes = ( + tensor_size * float(dtype_bytes) / 1024.0 + ) + + return input_tensor_info, output_tensor_info + + def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None: + """ + Calculate total number of FLOPs found in the onnx model. + FLOP is defined as one floating-point operation. This distinguishes + from multiply-accumulates (MACs) where FLOPs == 2 * MACs. + """ + + # Initialze to zero so we can accumulate. Set to None during the + # model FLOPs calculation if it errors out. 
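# --- Illustrative aside (not part of the patch) ---------------------------
# The per-op formulas below count multiplies and adds separately, i.e.
# FLOPs == 2 * MACs. For a MatMul of shapes (M, K) x (K, N) that gives
# 2 * M * K * N, matching the expression used in the MatMul branch.
# The shapes here are made up:
import numpy as np

input_a = [32, 768]    # (M, K)
input_b = [768, 3072]  # (K, N)
matmul_flops = int(2 * np.prod(np.array(input_a), dtype=np.int64) * input_b[-1])
assert matmul_flops == 2 * 32 * 768 * 3072
# ---------------------------------------------------------------------------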
+ self.flops = 0 + + # Check to see if the model inputs have any dynamic shapes + if onnx_utils.get_dynamic_input_dims(onnx_model): + self.flops = None + + try: + onnx_model, _ = onnx_utils.optimize_onnx_model(onnx_model) + + onnx_model = onnx.shape_inference.infer_shapes( + onnx_model, strict_mode=True, data_prop=True + ) + except Exception as e: # pylint: disable=broad-except + print(f"ONNX utils: {str(e)}") + self.flops = None + + # If the ONNX model contains one of the following unsupported ops, then this + # function will return None since the FLOP total is expected to be incorrect + unsupported_ops = [ + "Einsum", + "RNN", + "GRU", + "DeformConv", + ] + + if not self.input_tensors_: + self.input_tensors_ = { + tensor.name: tensor for tensor in onnx_model.graph.input + } + + if not self.output_tensors_: + self.output_tensors_ = { + tensor.name: tensor for tensor in onnx_model.graph.output + } + + if not self.value_tensors_: + self.value_tensors_ = { + tensor.name: tensor for tensor in onnx_model.graph.value_info + } + + if not self.init_tensors_: + self.init_tensors_ = { + tensor.name: tensor for tensor in onnx_model.graph.initializer + } + + for node in onnx_model.graph.node: # pylint: disable=E1101 + + node_info = NodeInfo() + + # TODO: I have encountered models containing nodes with no name. It would be a good idea + # to have this type of model info fed back to the user through a warnings section. + if not node.name: + node.name = f"{node.op_type}_{len(self.node_data)}" + + node_info.node_type = node.op_type + input_tensor_info, output_tensor_info = self.get_node_tensor_info_(node) + node_info.inputs = input_tensor_info + node_info.outputs = output_tensor_info + + # Check if this node has parameters through the init tensors + for input_name, input_tensor in node_info.inputs.items(): + if input_name in self.init_tensors_: + if all(isinstance(dim, int) for dim in input_tensor.shape): + input_parameters = int(np.prod(np.array(input_tensor.shape))) + node_info.parameters += input_parameters + self.parameters += input_parameters + self.node_type_parameters[node.op_type] = ( + self.node_type_parameters.get(node.op_type, 0) + + input_parameters + ) + else: + print(f"Tensor with params has unknown shape: {input_name}") + + for attribute in node.attribute: + node_info.attributes.update(onnx_utils.attribute_to_dict(attribute)) + + # if node.name in self.node_data: + # print(f"Node name {node.name} is a duplicate.") + + self.node_data[node.name] = node_info + + if node.op_type in unsupported_ops: + self.flops = None + node_info.flops = None + + try: + + if ( + node.op_type == "MatMul" + or node.op_type == "MatMulInteger" + or node.op_type == "QLinearMatMul" + ): + + input_a = node_info.get_input(0).shape + if node.op_type == "QLinearMatMul": + input_b = node_info.get_input(3).shape + else: + input_b = node_info.get_input(1).shape + + if not all( + isinstance(dim, int) for dim in input_a + ) or not isinstance(input_b[-1], int): + node_info.flops = None + self.flops = None + continue + + node_info.flops = int( + 2 * np.prod(np.array(input_a), dtype=np.int64) * input_b[-1] + ) + + elif ( + node.op_type == "Mul" + or node.op_type == "Div" + or node.op_type == "Add" + ): + input_a = node_info.get_input(0).shape + input_b = node_info.get_input(1).shape + + if not all(isinstance(dim, int) for dim in input_a) or not all( + isinstance(dim, int) for dim in input_b + ): + node_info.flops = None + self.flops = None + continue + + node_info.flops = int( + np.prod(np.array(input_a), dtype=np.int64) + ) + 
int(np.prod(np.array(input_b), dtype=np.int64)) + + elif node.op_type == "Gemm" or node.op_type == "QGemm": + x_shape = node_info.get_input(0).shape + if node.op_type == "Gemm": + w_shape = node_info.get_input(1).shape + else: + w_shape = node_info.get_input(3).shape + + if not all(isinstance(dim, int) for dim in x_shape) or not all( + isinstance(dim, int) for dim in w_shape + ): + node_info.flops = None + self.flops = None + continue + + mm_dims = [ + ( + x_shape[0] + if not node_info.attributes.get("transA", 0) + else x_shape[1] + ), + ( + x_shape[1] + if not node_info.attributes.get("transA", 0) + else x_shape[0] + ), + ( + w_shape[1] + if not node_info.attributes.get("transB", 0) + else w_shape[0] + ), + ] + + node_info.flops = int( + 2 * np.prod(np.array(mm_dims), dtype=np.int64) + ) + + if len(mm_dims) == 3: # if there is a bias input + bias_shape = node_info.get_input(2).shape + node_info.flops += int(np.prod(np.array(bias_shape))) + + elif ( + node.op_type == "Conv" + or node.op_type == "ConvInteger" + or node.op_type == "QLinearConv" + or node.op_type == "ConvTranspose" + ): + # N, C, d1, ..., dn + x_shape = node_info.get_input(0).shape + + # M, C/group, k1, ..., kn. Note C and M are swapped for ConvTranspose + if node.op_type == "QLinearConv": + w_shape = node_info.get_input(3).shape + else: + w_shape = node_info.get_input(1).shape + + if not all(isinstance(dim, int) for dim in x_shape): + node_info.flops = None + self.flops = None + continue + + x_shape_ints = cast(List[int], x_shape) + w_shape_ints = cast(List[int], w_shape) + + has_bias = False # Note, ConvInteger has no bias + if node.op_type == "Conv" and len(node_info.inputs) == 3: + has_bias = True + elif node.op_type == "QLinearConv" and len(node_info.inputs) == 9: + has_bias = True + + num_dims = len(x_shape_ints) - 2 + strides = node_info.attributes.get( + "strides", [1] * num_dims + ) # type: List[int] + dilation = node_info.attributes.get( + "dilations", [1] * num_dims + ) # type: List[int] + kernel_shape = w_shape_ints[2:] + batch_size = x_shape_ints[0] + out_channels = w_shape_ints[0] + out_dims = [batch_size, out_channels] + output_shape = node_info.attributes.get( + "output_shape", [] + ) # type: List[int] + + # If output_shape is given then we do not need to compute it ourselves + # The output_shape attribute does not include batch_size or channels and + # is only valid for ConvTranspose + if output_shape: + out_dims.extend(output_shape) + else: + auto_pad = node_info.attributes.get( + "auto_pad", "NOTSET".encode() + ).decode() + # SAME expects padding so that the output_shape = CEIL(input_shape / stride) + if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": + out_dims.extend( + [x * s for x, s in zip(x_shape_ints[2:], strides)] + ) + else: + # NOTSET means just use pads attribute + if auto_pad == "NOTSET": + pads = node_info.attributes.get( + "pads", [0] * num_dims * 2 + ) + # VALID essentially means no padding + elif auto_pad == "VALID": + pads = [0] * num_dims * 2 + + for i in range(num_dims): + dim_in = x_shape_ints[i + 2] # type: int + + if node.op_type == "ConvTranspose": + out_dim = ( + strides[i] * (dim_in - 1) + + ((kernel_shape[i] - 1) * dilation[i] + 1) + - pads[i] + - pads[i + num_dims] + ) + else: + out_dim = ( + dim_in + + pads[i] + + pads[i + num_dims] + - dilation[i] * (kernel_shape[i] - 1) + - 1 + ) // strides[i] + 1 + + out_dims.append(out_dim) + + kernel_flops = int( + np.prod(np.array(kernel_shape)) * w_shape_ints[1] + ) + output_points = int(np.prod(np.array(out_dims))) + bias_ops = 
output_points if has_bias else int(0) + node_info.flops = 2 * kernel_flops * output_points + bias_ops + + elif node.op_type == "LSTM" or node.op_type == "DynamicQuantizeLSTM": + + x_shape = node_info.get_input( + 0 + ).shape # seq_length, batch_size, input_dim + + if not all(isinstance(dim, int) for dim in x_shape): + node_info.flops = None + self.flops = None + continue + + x_shape_ints = cast(List[int], x_shape) + hidden_size = node_info.attributes["hidden_size"] + direction = ( + 2 + if node_info.attributes.get("direction") + == "bidirectional".encode() + else 1 + ) + + has_bias = True if len(node_info.inputs) >= 4 else False + if has_bias: + bias_shape = node_info.get_input(3).shape + if isinstance(bias_shape[1], int): + bias_ops = bias_shape[1] + else: + bias_ops = 0 + else: + bias_ops = 0 + # seq_length, batch_size, input_dim = x_shape + if not isinstance(bias_ops, int): + bias_ops = int(0) + num_gates = int(4) + gate_input_flops = int(2 * x_shape_ints[2] * hidden_size) + gate_hid_flops = int(2 * hidden_size * hidden_size) + unit_flops = ( + num_gates * (gate_input_flops + gate_hid_flops) + bias_ops + ) + node_info.flops = ( + x_shape_ints[1] * x_shape_ints[0] * direction * unit_flops + ) + # In this case we just hit an op that doesn't have FLOPs + else: + node_info.flops = None + + except IndexError as err: + print(f"Error parsing node {node.name}: {err}") + node_info.flops = None + self.flops = None + continue + + # Update the model level flops count + if node_info.flops is not None and self.flops is not None: + self.flops += node_info.flops + + # Update the node type flops count + self.node_type_flops[node.op_type] = ( + self.node_type_flops.get(node.op_type, 0) + node_info.flops + ) + + def save_yaml_report(self, file_path: str) -> None: + + parent_dir = os.path.dirname(os.path.abspath(file_path)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + report_date = datetime.now().strftime("%B %d, %Y") + + input_tensors = dict({k: vars(v) for k, v in self.model_inputs.items()}) + output_tensors = dict({k: vars(v) for k, v in self.model_outputs.items()}) + digest_version = importlib.metadata.version("digestai") + + yaml_data = { + "report_date": report_date, + "digest_version": digest_version, + "model_type": self.model_type.value, + "model_file": self.file_path, + "model_name": self.model_name, + "model_version": self.model_version, + "graph_name": self.graph_name, + "producer_name": self.producer_name, + "producer_version": self.producer_version, + "ir_version": self.ir_version, + "opset": self.opset, + "import_list": dict(self.imports), + "graph_nodes": sum(self.node_type_counts.values()), + "parameters": self.parameters, + "flops": self.flops, + "node_type_counts": dict(self.node_type_counts), + "node_type_flops": dict(self.node_type_flops), + "node_type_parameters": self.node_type_parameters, + "input_tensors": input_tensors, + "output_tensors": output_tensors, + } + + with open(file_path, "w", encoding="utf-8") as f_p: + yaml.dump(yaml_data, f_p, sort_keys=False) + + def save_text_report(self, file_path: str) -> None: + + parent_dir = os.path.dirname(os.path.abspath(file_path)) + if not os.path.exists(parent_dir): + raise FileNotFoundError(f"Directory {parent_dir} does not exist.") + + report_date = datetime.now().strftime("%B %d, %Y") + + digest_version = importlib.metadata.version("digestai") + + with open(file_path, "w", encoding="utf-8") as f_p: + f_p.write(f"Report created on {report_date}\n") + f_p.write(f"Digest 
version: {digest_version}\n") + f_p.write(f"Model type: {self.model_type.name}\n") + if self.file_path: + f_p.write(f"ONNX file: {self.file_path}\n") + f_p.write(f"Name of the model: {self.model_name}\n") + f_p.write(f"Model version: {self.model_version}\n") + f_p.write(f"Name of the graph: {self.graph_name}\n") + f_p.write(f"Producer: {self.producer_name} {self.producer_version}\n") + f_p.write(f"Ir version: {self.ir_version}\n") + f_p.write(f"Opset: {self.opset}\n\n") + f_p.write("Import list\n") + for name, version in self.imports.items(): + f_p.write(f"\t{name}: {version}\n") + + f_p.write("\n") + f_p.write(f"Total graph nodes: {sum(self.node_type_counts.values())}\n") + f_p.write(f"Number of parameters: {self.parameters}\n") + if self.flops: + f_p.write(f"Number of FLOPs: {self.flops}\n") + f_p.write("\n") + + table_op_intensity = PrettyTable() + table_op_intensity.field_names = ["Operation", "FLOPs", "Intensity (%)"] + for op_type, count in self.node_type_flops.items(): + if count > 0: + table_op_intensity.add_row( + [ + op_type, + count, + 100.0 * float(count) / float(self.flops), + ] + ) + + f_p.write("Op intensity:\n") + f_p.write(table_op_intensity.get_string()) + f_p.write("\n\n") + + node_counts_table = PrettyTable() + node_counts_table.field_names = ["Node", "Occurrences"] + for op, count in self.node_type_counts.items(): + node_counts_table.add_row([op, count]) + f_p.write("Nodes and their occurrences:\n") + f_p.write(node_counts_table.get_string()) + f_p.write("\n\n") + + input_table = PrettyTable() + input_table.field_names = [ + "Input Name", + "Shape", + "Type", + "Tensor Size (KB)", + ] + for input_name, input_details in self.model_inputs.items(): + if input_details.size_kbytes: + kbytes = f"{input_details.size_kbytes:.2f}" + else: + kbytes = "" + + input_table.add_row( + [ + input_name, + input_details.shape, + input_details.dtype, + kbytes, + ] + ) + f_p.write("Input Tensor(s) Information:\n") + f_p.write(input_table.get_string()) + f_p.write("\n\n") + + output_table = PrettyTable() + output_table.field_names = [ + "Output Name", + "Shape", + "Type", + "Tensor Size (KB)", + ] + for output_name, output_details in self.model_outputs.items(): + if output_details.size_kbytes: + kbytes = f"{output_details.size_kbytes:.2f}" + else: + kbytes = "" + + output_table.add_row( + [ + output_name, + output_details.shape, + output_details.dtype, + kbytes, + ] + ) + f_p.write("Output Tensor(s) Information:\n") + f_p.write(output_table.get_string()) + f_p.write("\n\n") diff --git a/src/digest/model_class/digest_pytorch_model.py b/src/digest/model_class/digest_pytorch_model.py new file mode 100644 index 0000000..9f159e9 --- /dev/null +++ b/src/digest/model_class/digest_pytorch_model.py @@ -0,0 +1,107 @@ +# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. + +import os +from collections import OrderedDict +from typing import List, Tuple, Optional, Union +import inspect +import onnx +import torch +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_model import ( + DigestModel, + SupportedModelTypes, +) + + +class DigestPyTorchModel(DigestModel): + """The idea of this class is to first support PyTorch models by converting them to ONNX + Eventually, we will want to support a PyTorch specific interface that has a custom GUI. + To facilitate this process, it makes the most sense to use this class as helper class + to convert the PyTorch model to ONNX and store the ONNX info in a member DigestOnnxModel + object. 
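For illustration only (not part of this diff; the file path, tensor name, and shape below are
hypothetical), the intended flow is roughly:

    model = DigestPyTorchModel("/path/to/model.pt", model_name="my_model")
    model.input_tensor_info["x"] = (torch.float32, [1, 3, 224, 224])
    proto = model.export_to_onnx("/tmp/my_model.onnx")
    if proto is not None:
        model.digest_onnx_model = DigestOnnxModel(
            proto, onnx_file_path=model.onnx_file_path, model_name=model.model_name
        )
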
We can also store various PyTorch specific details in this class as well. + """ + + def __init__( + self, + pytorch_file_path: str = "", + model_name: str = "", + ) -> None: + super().__init__(pytorch_file_path, model_name, SupportedModelTypes.PYTORCH) + + assert os.path.exists( + pytorch_file_path + ), f"PyTorch file {pytorch_file_path} does not exist." + + # Default opset value + self.opset = 17 + + # Input dictionary to contain the names and shapes + # required for exporting the ONNX model + self.input_tensor_info: OrderedDict[ + str, Tuple[torch.dtype, List[Union[str, int]]] + ] = OrderedDict() + + self.pytorch_model = torch.load(pytorch_file_path) + + # Data needed for exporting to ONNX + self.do_constant_folding = True + self.export_params = True + + self.onnx_file_path: Optional[str] = None + + self.digest_onnx_model: Optional[DigestOnnxModel] = None + + def parse_model_nodes(self) -> None: + """This will be done in the DigestOnnxModel""" + + def save_yaml_report(self, file_path: str) -> None: + """This will be done in the DigestOnnxModel""" + + def save_text_report(self, file_path: str) -> None: + """This will be done in the DigestOnnxModel""" + + def generate_random_tensor(self, dtype: torch.dtype, shape: List[Union[str, int]]): + static_shape = [dim if isinstance(dim, int) else 1 for dim in shape] + if dtype in (torch.float16, torch.float32, torch.float64): + return torch.rand(static_shape, dtype=dtype) + else: + return torch.randint(0, 100, static_shape, dtype=dtype) + + def export_to_onnx(self, output_onnx_path: str) -> Union[onnx.ModelProto, None]: + + dummy_input_names: List[str] = list(self.input_tensor_info.keys()) + dummy_inputs: List[torch.Tensor] = [] + + for dtype, shape in self.input_tensor_info.values(): + dummy_inputs.append(self.generate_random_tensor(dtype, shape)) + + dynamic_axes = { + name: {i: dim for i, dim in enumerate(shape) if isinstance(dim, str)} + for name, (_, shape) in self.input_tensor_info.items() + } + + try: + torch.onnx.export( + self.pytorch_model, + tuple(dummy_inputs), + output_onnx_path, + input_names=dummy_input_names, + do_constant_folding=self.do_constant_folding, + export_params=self.export_params, + opset_version=self.opset, + dynamic_axes=dynamic_axes, + verbose=False, + ) + + self.onnx_file_path = output_onnx_path + + return onnx.load(output_onnx_path) + + except (ValueError, TypeError, RuntimeError) as err: + print(f"Failed to export ONNX: {err}") + raise + + +def get_model_fwd_parameters(torch_file_path): + torch_model = torch.load(torch_file_path) + return inspect.signature(torch_model.forward).parameters diff --git a/src/digest/model_class/digest_report_model.py b/src/digest/model_class/digest_report_model.py new file mode 100644 index 0000000..04da8e3 --- /dev/null +++ b/src/digest/model_class/digest_report_model.py @@ -0,0 +1,242 @@ +import os +from collections import OrderedDict +import csv +import ast +import re +from typing import Tuple, Optional, List, Dict, Any, Union +import yaml +from digest.model_class.digest_model import ( + DigestModel, + SupportedModelTypes, + NodeData, + NodeInfo, + TensorData, + TensorInfo, +) + + +def parse_tensor_info( + csv_tensor_cell_value, +) -> Tuple[str, list, str, Optional[float]]: + """This is a helper function that expects the input to come from parsing + the nodes csv and extracting either an input or output tensor.""" + + # Use regex to split the string into name and details + match = re.match(r"(.*?)\s*\((.*)\)$", csv_tensor_cell_value) + if not match: + raise ValueError(f"Invalid format 
for tensor info: {csv_tensor_cell_value}") + + name, details = match.groups() + + # Split details, but keep the shape as a single item + match = re.match(r"(\[.*?\])\s*,\s*(.*?)\s*,\s*(.*)", details) + if not match: + raise ValueError(f"Invalid format for tensor details: {details}") + + shape_str, dtype, size = match.groups() + + # Ensure shape is stored as a list + shape = ast.literal_eval(shape_str) + if not isinstance(shape, list): + shape = list(shape) + + if size == "None": + size = None + else: + size = float(size.split()[0]) + + return name.strip(), shape, dtype.strip(), size + + +class DigestReportModel(DigestModel): + def __init__( + self, + report_file_path: str, + ) -> None: + + self.model_type = SupportedModelTypes.REPORT + + self.is_valid = validate_yaml(report_file_path) + + if not self.is_valid: + print(f"The yaml file {report_file_path} is not a valid digest report.") + return + + self.model_data = OrderedDict() + with open(report_file_path, "r", encoding="utf-8") as yaml_f: + self.model_data = yaml.safe_load(yaml_f) + + model_name = self.model_data["model_name"] + super().__init__(report_file_path, model_name, SupportedModelTypes.REPORT) + + self.similarity_heatmap_path: Optional[str] = None + self.node_data = NodeData() + + # Given the path to the digest report, let's check if its a complete cache + # and we can grab the nodes csv data and the similarity heatmap + cache_dir = os.path.dirname(os.path.abspath(report_file_path)) + expected_heatmap_file = os.path.join(cache_dir, f"{model_name}_heatmap.png") + if os.path.exists(expected_heatmap_file): + self.similarity_heatmap_path = expected_heatmap_file + + expected_nodes_file = os.path.join(cache_dir, f"{model_name}_nodes.csv") + if os.path.exists(expected_nodes_file): + with open(expected_nodes_file, "r", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + node_name = row["Node Name"] + node_info = NodeInfo() + node_info.node_type = row["Node Type"] + if row["Parameters"]: + node_info.parameters = int(row["Parameters"]) + if ast.literal_eval(row["FLOPs"]): + node_info.flops = int(row["FLOPs"]) + node_info.attributes = ( + OrderedDict(ast.literal_eval(row["Attributes"])) + if row["Attributes"] + else OrderedDict() + ) + + node_info.inputs = TensorData() + node_info.outputs = TensorData() + + # Process inputs and outputs + for key, value in row.items(): + if key.startswith("Input") and value: + input_name, shape, dtype, size = parse_tensor_info(value) + node_info.inputs[input_name] = TensorInfo() + node_info.inputs[input_name].shape = shape + node_info.inputs[input_name].dtype = dtype + node_info.inputs[input_name].size_kbytes = size + + elif key.startswith("Output") and value: + output_name, shape, dtype, size = parse_tensor_info(value) + node_info.outputs[output_name] = TensorInfo() + node_info.outputs[output_name].shape = shape + node_info.outputs[output_name].dtype = dtype + node_info.outputs[output_name].size_kbytes = size + + self.node_data[node_name] = node_info + + # Unpack the model type agnostic values + self.flops = self.model_data["flops"] + self.parameters = self.model_data["parameters"] + self.node_type_flops = self.model_data["node_type_flops"] + self.node_type_parameters = self.model_data["node_type_parameters"] + self.node_type_counts = self.model_data["node_type_counts"] + + self.model_inputs = TensorData( + { + key: TensorInfo(**val) + for key, val in self.model_data["input_tensors"].items() + } + ) + self.model_outputs = TensorData( + { + key: TensorInfo(**val) + for 
key, val in self.model_data["output_tensors"].items() + } + ) + + def parse_model_nodes(self) -> None: + """There are no model nodes to parse""" + + def save_yaml_report(self, file_path: str) -> None: + """Report models are not intended to be saved""" + + def save_text_report(self, file_path: str) -> None: + """Report models are not intended to be saved""" + + +def validate_yaml(report_file_path: str) -> bool: + """Check that the provided yaml file is indeed a Digest Report file.""" + expected_keys = [ + "report_date", + "model_file", + "model_type", + "model_name", + "flops", + "node_type_flops", + "node_type_parameters", + "node_type_counts", + "input_tensors", + "output_tensors", + ] + try: + with open(report_file_path, "r", encoding="utf-8") as file: + yaml_content = yaml.safe_load(file) + + if not isinstance(yaml_content, dict): + print("Error: YAML content is not a dictionary") + return False + + for key in expected_keys: + if key not in yaml_content: + # print(f"Error: Missing required key '{key}'") + return False + + return True + except yaml.YAMLError as _: + # print(f"Error parsing YAML file: {e}") + return False + except IOError as _: + # print(f"Error reading file: {e}") + return False + + +def compare_yaml_files( + file1: str, file2: str, skip_keys: Optional[List[str]] = None +) -> bool: + """ + Compare two YAML files, ignoring specified keys. + + :param file1: Path to the first YAML file + :param file2: Path to the second YAML file + :param skip_keys: List of keys to ignore in the comparison + :return: True if the files are equal (ignoring specified keys), False otherwise + """ + + def load_yaml(file_path: str) -> Dict[str, Any]: + with open(file_path, "r", encoding="utf-8") as file: + return yaml.safe_load(file) + + def compare_dicts( + dict1: Dict[str, Any], dict2: Dict[str, Any], path: str = "" + ) -> List[str]: + differences = [] + all_keys = set(dict1.keys()) | set(dict2.keys()) + + for key in all_keys: + if skip_keys and key in skip_keys: + continue + + current_path = f"{path}.{key}" if path else key + + if key not in dict1: + differences.append(f"Key '{current_path}' is missing in the first file") + elif key not in dict2: + differences.append( + f"Key '{current_path}' is missing in the second file" + ) + elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict): + differences.extend(compare_dicts(dict1[key], dict2[key], current_path)) + elif dict1[key] != dict2[key]: + differences.append( + f"Value mismatch for key '{current_path}': {dict1[key]} != {dict2[key]}" + ) + + return differences + + yaml1 = load_yaml(file1) + yaml2 = load_yaml(file2) + + differences = compare_dicts(yaml1, yaml2) + + if differences: + # print("Differences found:") + # for diff in differences: + # print(f"- {diff}") + return False + else: + # print("No differences found.") + return True diff --git a/src/digest/modelsummary.py b/src/digest/modelsummary.py index 1e3872e..a92b756 100644 --- a/src/digest/modelsummary.py +++ b/src/digest/modelsummary.py @@ -3,10 +3,12 @@ import os # pylint: disable=invalid-name -from typing import Optional +from typing import Optional, Union # pylint: disable=no-name-in-module from PySide6.QtWidgets import QWidget +from PySide6.QtGui import QMovie +from PySide6.QtCore import QSize from onnx import ModelProto @@ -14,37 +16,54 @@ from digest.freeze_inputs import FreezeInputs from digest.popup_window import PopupWindow from digest.qt_utils import apply_dark_style_sheet -from utils import onnx_utils +from digest.model_class.digest_onnx_model import 
DigestOnnxModel +from digest.model_class.digest_report_model import DigestReportModel + ROOT_FOLDER = os.path.dirname(os.path.abspath(__file__)) class modelSummary(QWidget): - def __init__(self, digest_model: onnx_utils.DigestOnnxModel, parent=None): + def __init__( + self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None + ): super().__init__(parent) self.ui = Ui_modelSummary() self.ui.setupUi(self) apply_dark_style_sheet(self) self.file: Optional[str] = None - self.ui.freezeButton.setVisible(False) - self.ui.freezeButton.clicked.connect(self.open_freeze_inputs) self.ui.warningLabel.hide() self.digest_model = digest_model - self.model_proto: ModelProto = ( - digest_model.model_proto if digest_model.model_proto else ModelProto() - ) + self.model_proto: Optional[ModelProto] = None model_name: str = digest_model.model_name if digest_model.model_name else "" - self.freeze_inputs = FreezeInputs(self.model_proto, model_name) - self.freeze_inputs.complete_signal.connect(self.close_freeze_window) + + self.load_gif = QMovie(":/assets/gifs/load.gif") + # We set the size of the GIF to half the original + self.load_gif.setScaledSize(QSize(214, 120)) + self.ui.similarityImg.setMovie(self.load_gif) + self.load_gif.start() + + # There is no freezing if the model is not ONNX + self.ui.freezeButton.setVisible(False) + self.freeze_inputs: Optional[FreezeInputs] = None self.freeze_window: Optional[QWidget] = None + if isinstance(digest_model, DigestOnnxModel): + self.model_proto = ( + digest_model.model_proto if digest_model.model_proto else ModelProto() + ) + self.freeze_inputs = FreezeInputs(self.model_proto, model_name) + self.ui.freezeButton.clicked.connect(self.open_freeze_inputs) + self.freeze_inputs.complete_signal.connect(self.close_freeze_window) + def open_freeze_inputs(self): - self.freeze_window = PopupWindow( - self.freeze_inputs, "Freeze Model Inputs", self - ) - self.freeze_window.open() + if self.freeze_inputs: + self.freeze_window = PopupWindow( + self.freeze_inputs, "Freeze Model Inputs", self + ) + self.freeze_window.open() def close_freeze_window(self): if self.freeze_window: diff --git a/src/digest/multi_model_analysis.py b/src/digest/multi_model_analysis.py index d7f6bab..1c09905 100644 --- a/src/digest/multi_model_analysis.py +++ b/src/digest/multi_model_analysis.py @@ -1,17 +1,27 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. 
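# Illustrative sketch, not part of this diff: the analysis page below accepts either a
# DigestOnnxModel or a DigestReportModel, so anything model-type specific (the opset, the
# ModelProto, freezing inputs) is reached through isinstance checks. A hypothetical helper
# showing the same dispatch pattern used further down:

    def opset_of(model):
        if isinstance(model, DigestOnnxModel):
            return model.opset
        if isinstance(model, DigestReportModel):
            return model.model_data.get("opset", "")
        return None
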
import os +from datetime import datetime import csv from typing import List, Dict, Union from collections import Counter, defaultdict, OrderedDict # pylint: disable=no-name-in-module from PySide6.QtWidgets import QWidget, QTableWidgetItem, QFileDialog +from PySide6.QtCore import Qt from digest.dialog import ProgressDialog, StatusDialog from digest.ui.multimodelanalysis_ui import Ui_multiModelAnalysis from digest.histogramchartwidget import StackedHistogramWidget from digest.qt_utils import apply_dark_style_sheet -from utils import onnx_utils +from digest.model_class.digest_model import ( + NodeTypeCounts, + NodeShapeCounts, + save_node_shape_counts_csv_report, + save_node_type_counts_csv_report, +) +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import DigestReportModel +import utils.onnx_utils as onnx_utils ROOT_FOLDER = os.path.dirname(__file__) @@ -21,7 +31,7 @@ class MultiModelAnalysis(QWidget): def __init__( self, - model_list: List[onnx_utils.DigestOnnxModel], + model_list: List[Union[DigestOnnxModel, DigestReportModel]], parent=None, ): super().__init__(parent) @@ -34,6 +44,9 @@ def __init__( self.ui.individualCheckBox.stateChanged.connect(self.check_box_changed) self.ui.multiCheckBox.stateChanged.connect(self.check_box_changed) + # For some reason setting alignments in designer lead to bugs in *ui.py files + self.ui.opHistogramChart.layout().setAlignment(Qt.AlignmentFlag.AlignTop) + if not model_list: return @@ -41,41 +54,60 @@ def __init__( self.global_node_type_counter: Counter[str] = Counter() # Holds the data for node shape counts across all models - self.global_node_shape_counter: onnx_utils.NodeShapeCounts = defaultdict( - Counter - ) + self.global_node_shape_counter: NodeShapeCounts = defaultdict(Counter) # Holds the data for all models statistics - self.global_model_data: Dict[str, Dict[str, Union[int, None]]] = {} + self.global_model_data: Dict[str, Dict[str, Union[int, str, None]]] = {} progress = ProgressDialog("", len(model_list), self) - header_labels = ["Model", "Opset", "Total Nodes", "Parameters", "FLOPs"] + header_labels = [ + "Model Name", + "Model Type", + "Opset", + "Total Nodes", + "Parameters", + "FLOPs", + ] self.ui.dataTable.setRowCount(len(model_list)) self.ui.dataTable.setColumnCount(len(header_labels)) self.ui.dataTable.setHorizontalHeaderLabels(header_labels) self.ui.dataTable.setSortingEnabled(False) for row, model in enumerate(model_list): + item = QTableWidgetItem(str(model.model_name)) self.ui.dataTable.setItem(row, 0, item) - item = QTableWidgetItem(str(model.opset)) + item = QTableWidgetItem(str(model.model_type.name)) self.ui.dataTable.setItem(row, 1, item) - item = QTableWidgetItem(str(len(model.per_node_info))) + if isinstance(model, DigestOnnxModel): + item = QTableWidgetItem(str(model.opset)) + elif isinstance(model, DigestReportModel): + item = QTableWidgetItem(str(model.model_data.get("opset", ""))) + self.ui.dataTable.setItem(row, 2, item) - item = QTableWidgetItem(str(model.model_parameters)) + item = QTableWidgetItem(str(len(model.node_data))) self.ui.dataTable.setItem(row, 3, item) - item = QTableWidgetItem(str(model.model_flops)) + item = QTableWidgetItem(str(model.parameters)) self.ui.dataTable.setItem(row, 4, item) + item = QTableWidgetItem(str(model.flops)) + self.ui.dataTable.setItem(row, 5, item) + self.ui.dataTable.resizeColumnsToContents() self.ui.dataTable.resizeRowsToContents() - node_type_counter = {} + # Until we use the unique_id to represent the model contents 
we store + # the entire model as the key so that we can store models that happen to have + # the same name. There is a guarantee that the models will not be duplicates. + node_type_counter: Dict[ + Union[DigestOnnxModel, DigestReportModel], NodeTypeCounts + ] = {} + for i, digest_model in enumerate(model_list): progress.step() progress.setLabelText(f"Analyzing model {digest_model.model_name}") @@ -83,39 +115,52 @@ def __init__( if digest_model.model_name is None: digest_model.model_name = f"model_{i}" - if digest_model.model_proto: - dynamic_input_dims = onnx_utils.get_dynamic_input_dims( - digest_model.model_proto - ) - if dynamic_input_dims: - print( - "Found the following non-static input dims in your model. " - "It is recommended to make all dims static before generating reports." + if isinstance(digest_model, DigestOnnxModel): + opset = digest_model.opset + if digest_model.model_proto: + dynamic_input_dims = onnx_utils.get_dynamic_input_dims( + digest_model.model_proto ) - for dynamic_shape in dynamic_input_dims: - print(f"dim: {dynamic_shape}") + if dynamic_input_dims: + print( + "Found the following non-static input dims in your model. " + "It is recommended to make all dims static before generating reports." + ) + for dynamic_shape in dynamic_input_dims: + print(f"dim: {dynamic_shape}") + + elif isinstance(digest_model, DigestReportModel): + opset = digest_model.model_data.get("opset", "") # Update the global model dictionary - if digest_model.model_name in self.global_model_data: + if digest_model.unique_id in self.global_model_data: print( - f"Warning! {digest_model.model_name} has already been processed, " + f"Warning! {digest_model.model_name} with id " + f"{digest_model.unique_id} has already been processed, " "skipping the duplicate model." ) - - self.global_model_data[digest_model.model_name] = { - "opset": digest_model.opset, - "parameters": digest_model.model_parameters, - "flops": digest_model.model_flops, + continue + + self.global_model_data[digest_model.unique_id] = { + "model_name": digest_model.model_name, + "model_type": digest_model.model_type.name, + "opset": opset, + "parameters": digest_model.parameters, + "flops": digest_model.flops, } - node_type_counter[digest_model.model_name] = ( - digest_model.get_node_type_counts() - ) + if digest_model in node_type_counter: + print( + f"Warning! {digest_model.model_name} with model type " + f"{digest_model.model_type.value} and id {digest_model.unique_id} " + "has already been added to the stacked histogram, skipping." 
+ ) + continue + + node_type_counter[digest_model] = digest_model.node_type_counts # Update global data structure for node type counter - self.global_node_type_counter.update( - node_type_counter[digest_model.model_name] - ) + self.global_node_type_counter.update(node_type_counter[digest_model]) node_shape_counts = digest_model.get_node_shape_counts() @@ -133,31 +178,31 @@ def __init__( # Create stacked op histograms max_count = 0 top_ops = [key for key, _ in self.global_node_type_counter.most_common(20)] - for model_name, _ in node_type_counter.items(): - max_local = Counter(node_type_counter[model_name]).most_common()[0][1] + for model, _ in node_type_counter.items(): + max_local = Counter(node_type_counter[model]).most_common()[0][1] if max_local > max_count: max_count = max_local - for idx, model_name in enumerate(node_type_counter): + for idx, model in enumerate(node_type_counter): stacked_histogram_widget = StackedHistogramWidget() ordered_dict = OrderedDict() - model_counter = Counter(node_type_counter[model_name]) + model_counter = Counter(node_type_counter[model]) for key in top_ops: ordered_dict[key] = model_counter.get(key, 0) title = "Stacked Op Histogram" if idx == 0 else "" stacked_histogram_widget.set_data( ordered_dict, - model_name=model_name, + model_name=model.model_name, y_max=max_count, title=title, set_ticks=False, ) frame_layout = self.ui.stackedHistogramFrame.layout() - frame_layout.addWidget(stacked_histogram_widget) + if frame_layout: + frame_layout.addWidget(stacked_histogram_widget) # Add a "ghost" histogram to allow us to set the x axis label vertically - model_name = list(node_type_counter.keys())[0] stacked_histogram_widget = StackedHistogramWidget() - ordered_dict = {key: 1 for key in top_ops} + ordered_dict = OrderedDict({key: 1 for key in top_ops}) stacked_histogram_widget.set_data( ordered_dict, model_name="_", @@ -165,18 +210,39 @@ def __init__( set_ticks=True, ) frame_layout = self.ui.stackedHistogramFrame.layout() - frame_layout.addWidget(stacked_histogram_widget) + if frame_layout: + frame_layout.addWidget(stacked_histogram_widget) self.model_list = model_list def save_reports(self): - # Model summary text report - save_directory = QFileDialog(self).getExistingDirectory( + """This function saves all available reports for the models that are opened + in the multi-model analysis page.""" + + base_directory = QFileDialog(self).getExistingDirectory( self, "Select Directory" ) - if not save_directory: - return + # Check if the directory exists and is writable + if not os.path.exists(base_directory) or not os.access(base_directory, os.W_OK): + bad_ext_dialog = StatusDialog( + f"The directory {base_directory} is not valid or writable.", + parent=self, + ) + bad_ext_dialog.show() + + # Append a subdirectory to the save_directory so that all reports are co-located + name_id = datetime.now().strftime("%Y%m%d%H%M%S") + sub_directory = f"multi_model_reports_{name_id}" + save_directory = os.path.join(base_directory, sub_directory) + try: + os.makedirs(save_directory) + except OSError as os_err: + bad_ext_dialog = StatusDialog( + f"Failed to create {save_directory} with error {os_err}", + parent=self, + ) + bad_ext_dialog.show() save_individual_reports = self.ui.individualCheckBox.isChecked() save_multi_reports = self.ui.multiCheckBox.isChecked() @@ -192,29 +258,21 @@ def save_reports(self): save_directory, f"{digest_model.model_name}_summary.txt" ) - digest_model.save_txt_report(summary_filepath) + digest_model.save_text_report(summary_filepath) # Save csv of node 
type counts node_type_filepath = os.path.join( save_directory, f"{digest_model.model_name}_node_type_counts.csv" ) - # Save csv containing node type counter - node_type_counter = digest_model.get_node_type_counts() - - if node_type_counter: - onnx_utils.save_node_type_counts_csv_report( - node_type_counter, node_type_filepath - ) + if digest_model.node_type_counts: + digest_model.save_node_type_counts_csv_report(node_type_filepath) # Save csv containing node shape counts per op_type - node_shape_counts = digest_model.get_node_shape_counts() node_shape_filepath = os.path.join( save_directory, f"{digest_model.model_name}_node_shape_counts.csv" ) - onnx_utils.save_node_shape_counts_csv_report( - node_shape_counts, node_shape_filepath - ) + digest_model.save_node_shape_counts_csv_report(node_shape_filepath) # Save csv containing all node-level information nodes_filepath = os.path.join( @@ -231,17 +289,17 @@ def save_reports(self): global_filepath = os.path.join( save_directory, "global_node_type_counts.csv" ) - global_node_type_counter = onnx_utils.NodeTypeCounts( + global_node_type_counter = NodeTypeCounts( self.global_node_type_counter.most_common() ) - onnx_utils.save_node_type_counts_csv_report( + save_node_type_counts_csv_report( global_node_type_counter, global_filepath ) global_filepath = os.path.join( save_directory, "global_node_shape_counts.csv" ) - onnx_utils.save_node_shape_counts_csv_report( + save_node_shape_counts_csv_report( self.global_node_shape_counter, global_filepath ) @@ -253,10 +311,18 @@ def save_reports(self): ) as csvfile: writer = csv.writer(csvfile) rows = [ - [model, data["opset"], data["parameters"], data["flops"]] - for model, data in self.global_model_data.items() + [ + data["model_name"], + data["model_type"], + data["opset"], + data["parameters"], + data["flops"], + ] + for _, data in self.global_model_data.items() ] - writer.writerow(["Model", "Opset", "Parameters", "FLOPs"]) + writer.writerow( + ["Model Name", "Model Type", "Opset", "Parameters", "FLOPs"] + ) writer.writerows(rows) if save_individual_reports or save_multi_reports: diff --git a/src/digest/multi_model_selection_page.py b/src/digest/multi_model_selection_page.py index d7b6a39..3290083 100644 --- a/src/digest/multi_model_selection_page.py +++ b/src/digest/multi_model_selection_page.py @@ -2,7 +2,7 @@ import os import glob -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Union from collections import defaultdict from google.protobuf.message import DecodeError import onnx @@ -22,6 +22,8 @@ from digest.ui.multimodelselection_page_ui import Ui_MultiModelSelection from digest.multi_model_analysis import MultiModelAnalysis from digest.qt_utils import apply_dark_style_sheet, prompt_user_ram_limit +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import DigestReportModel, compare_yaml_files from utils import onnx_utils @@ -33,7 +35,9 @@ class AnalysisThread(QThread): def __init__(self): super().__init__() - self.model_dict: Dict[str, Optional[onnx_utils.DigestOnnxModel]] = {} + self.model_dict: Dict[ + str, Optional[Union[DigestOnnxModel, DigestReportModel]] + ] = {} self.user_canceled = False def run(self): @@ -47,19 +51,21 @@ def run(self): self.step_progress.emit() if model: continue - model_name = os.path.splitext(os.path.basename(file))[0] - model_proto = onnx_utils.load_onnx(file, False) - self.model_dict[file] = onnx_utils.DigestOnnxModel( - model_proto, onnx_filepath=file, model_name=model_name, 
save_proto=False - ) + model_name, file_ext = os.path.splitext(os.path.basename(file)) + if file_ext == ".onnx": + model_proto = onnx_utils.load_onnx(file, False) + self.model_dict[file] = DigestOnnxModel( + model_proto, + onnx_file_path=file, + model_name=model_name, + save_proto=False, + ) + elif file_ext == ".yaml": + self.model_dict[file] = DigestReportModel(file) self.close_progress.emit() - model_list = [ - model - for model in self.model_dict.values() - if isinstance(model, onnx_utils.DigestOnnxModel) - ] + model_list = [model for model in self.model_dict.values()] self.completed.emit(model_list) @@ -82,8 +88,10 @@ def __init__( self.ui.warningLabel.hide() self.item_model = QStandardItemModel() self.item_model.itemChanged.connect(self.update_num_selected_label) - self.ui.selectAllBox.setCheckState(Qt.CheckState.Checked) - self.ui.selectAllBox.stateChanged.connect(self.update_list_view_items) + self.ui.radioAll.setChecked(True) + self.ui.radioAll.toggled.connect(self.update_list_view_items) + self.ui.radioONNX.toggled.connect(self.update_list_view_items) + self.ui.radioReports.toggled.connect(self.update_list_view_items) self.ui.selectFolderBtn.clicked.connect(self.openFolder) self.ui.duplicateLabel.hide() self.ui.modelListView.setModel(self.item_model) @@ -94,7 +102,9 @@ def __init__( self.ui.openAnalysisBtn.clicked.connect(self.start_analysis) - self.model_dict: Dict[str, Optional[onnx_utils.DigestOnnxModel]] = {} + self.model_dict: Dict[ + str, Optional[Union[DigestOnnxModel, DigestReportModel]] + ] = {} self.analysis_thread: Optional[AnalysisThread] = None self.progress: Optional[ProgressDialog] = None @@ -165,14 +175,24 @@ def update_num_selected_label(self): self.ui.openAnalysisBtn.setEnabled(False) def update_list_view_items(self): - state = self.ui.selectAllBox.checkState() + radio_all_state = self.ui.radioAll.isChecked() + radio_onnx_state = self.ui.radioONNX.isChecked() + radio_reports_state = self.ui.radioReports.isChecked() for row in range(self.item_model.rowCount()): item = self.item_model.item(row) - item.setCheckState(state) + value = item.data(Qt.ItemDataRole.DisplayRole) + if radio_all_state: + item.setCheckState(Qt.CheckState.Checked) + elif os.path.splitext(value)[-1] == ".onnx" and radio_onnx_state: + item.setCheckState(Qt.CheckState.Checked) + elif os.path.splitext(value)[-1] == ".yaml" and radio_reports_state: + item.setCheckState(Qt.CheckState.Checked) + else: + item.setCheckState(Qt.CheckState.Unchecked) def set_directory(self, directory: str): """ - Recursively searches a directory for onnx models. + Recursively searches a directory for onnx models and yaml report files. 
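Only YAML files that validate as Digest reports (DigestReportModel.is_valid) are kept.
The per-file dispatch, here and in AnalysisThread.run above, follows roughly this sketch
(illustrative only; the variable names are not part of this diff):

    name, ext = os.path.splitext(os.path.basename(file_path))
    if ext == ".onnx":
        model = DigestOnnxModel(
            onnx_utils.load_onnx(file_path, False),
            onnx_file_path=file_path,
            model_name=name,
            save_proto=False,
        )
    elif ext == ".yaml":
        model = DigestReportModel(file_path)
        if not model.is_valid:
            model = None
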
""" if not os.path.exists(directory): @@ -183,36 +203,57 @@ def set_directory(self, directory: str): else: return - progress = ProgressDialog("Searching Directory for ONNX Files", 0, self) + progress = ProgressDialog("Searching directory for model files", 0, self) + onnx_file_list = list( glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True) ) + onnx_file_list = [os.path.normpath(model_file) for model_file in onnx_file_list] + + yaml_file_list = list( + glob.glob(os.path.join(directory, "**/*.yaml"), recursive=True) + ) + yaml_file_list = [os.path.normpath(model_file) for model_file in yaml_file_list] + + # Filter out YAML files that are not valid reports + report_file_list = [] + for yaml_file in yaml_file_list: + digest_report = DigestReportModel(yaml_file) + if digest_report.is_valid: + report_file_list.append(yaml_file) + + total_num_models = len(onnx_file_list) + len(report_file_list) - onnx_file_list = [os.path.normpath(onnx_file) for onnx_file in onnx_file_list] serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list) progress.close() - progress = ProgressDialog("Loading ONNX Models", len(onnx_file_list), self) + progress = ProgressDialog("Loading models", total_num_models, self) memory_limit_percentage = 90 models_loaded = 0 - for filepath in onnx_file_list: + for filepath in onnx_file_list + report_file_list: progress.step() if progress.user_canceled: break try: models_loaded += 1 - model = onnx.load(filepath, load_external_data=False) - dialog_msg = f"""Warning: System RAM has exceeded the threshold of {memory_limit_percentage}%. - No further models will be loaded. - """ + extension = os.path.splitext(filepath)[-1] + if extension == ".onnx": + model = onnx.load(filepath, load_external_data=False) + serialized_models_paths[model.SerializeToString()].append(filepath) + elif extension == ".yaml": + pass + dialog_msg = ( + "Warning: System RAM has exceeded the threshold of " + f"{memory_limit_percentage}%. No further models will be loaded. " + ) if prompt_user_ram_limit( sys_ram_percent_limit=memory_limit_percentage, message=dialog_msg, parent=self, ): self.update_warning_label( - f"Loaded only {models_loaded - 1} out of {len(onnx_file_list)} models " + f"Loaded only {models_loaded - 1} out of {total_num_models} models " f"as memory consumption has reached {memory_limit_percentage}% of " "system memory. Preventing further loading of models." 
) @@ -223,15 +264,13 @@ def set_directory(self, directory: str): break else: self.ui.warningLabel.hide() - serialized_models_paths[model.SerializeToString()].append(filepath) + except DecodeError as error: print(f"Error decoding model {filepath}: {error}") progress.close() - progress = ProgressDialog( - "Processing ONNX Models", len(serialized_models_paths), self - ) + progress = ProgressDialog("Processing Models", total_num_models, self) num_duplicates = 0 self.item_model.clear() @@ -245,15 +284,42 @@ def set_directory(self, directory: str): self.ui.duplicateListWidget.addItem(paths[0]) for dupe in paths[1:]: self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}") - item = QStandardItem(paths[0]) - item.setCheckable(True) - item.setCheckState(Qt.CheckState.Checked) - self.item_model.appendRow(item) - else: - item = QStandardItem(paths[0]) - item.setCheckable(True) - item.setCheckState(Qt.CheckState.Checked) - self.item_model.appendRow(item) + item = QStandardItem(paths[0]) + item.setCheckable(True) + item.setCheckState(Qt.CheckState.Checked) + self.item_model.appendRow(item) + + # Use a standard nested loop to detect duplicate reports + duplicate_reports: Dict[str, List[str]] = {} + processed_files = set() + for i in range(len(report_file_list)): + progress.step() + if progress.user_canceled: + break + path1 = report_file_list[i] + if path1 in processed_files: + continue # Skip already processed files + + # We will use path1 as the unique model and save a list of duplicates + duplicate_reports[path1] = [] + for j in range(i + 1, len(report_file_list)): + path2 = report_file_list[j] + if compare_yaml_files( + path1, path2, ["report_date", "model_files", "digest_version"] + ): + num_duplicates += 1 + duplicate_reports[path1].append(path2) + processed_files.add(path2) + + for path, dupes in duplicate_reports.items(): + if dupes: + self.ui.duplicateListWidget.addItem(path) + for dupe in dupes: + self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}") + item = QStandardItem(path) + item.setCheckable(True) + item.setCheckState(Qt.CheckState.Checked) + self.item_model.appendRow(item) progress.close() @@ -270,7 +336,7 @@ def set_directory(self, directory: str): self.update_num_selected_label() self.update_message_label( - f"Found a total of {len(onnx_file_list)} ONNX files. " + f"Found a total of {total_num_models} model files. " "Right click a model below " "to open it up in the model summary view." 
) @@ -289,7 +355,9 @@ def start_analysis(self): self.analysis_thread.model_dict = self.model_dict self.analysis_thread.start() - def open_analysis(self, model_list: List[onnx_utils.DigestOnnxModel]): + def open_analysis( + self, model_list: List[Union[DigestOnnxModel, DigestReportModel]] + ): multi_model_analysis = MultiModelAnalysis(model_list) self.analysis_window.setCentralWidget(multi_model_analysis) self.analysis_window.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) diff --git a/src/digest/node_summary.py b/src/digest/node_summary.py index 99eb35f..01aaf09 100644 --- a/src/digest/node_summary.py +++ b/src/digest/node_summary.py @@ -6,6 +6,10 @@ from PySide6.QtWidgets import QWidget, QTableWidgetItem, QFileDialog from digest.ui.nodessummary_ui import Ui_nodesSummary from digest.qt_utils import apply_dark_style_sheet +from digest.model_class.digest_model import ( + save_node_shape_counts_csv_report, + save_nodes_csv_report, +) from utils import onnx_utils ROOT_FOLDER = os.path.dirname(__file__) @@ -111,8 +115,6 @@ def save_csv_file(self): self, "Save CSV", os.getcwd(), "CSV(*.csv)" ) if filepath and self.ui.allNodesBtn.isChecked(): - onnx_utils.save_nodes_csv_report(self.node_data, filepath) + save_nodes_csv_report(self.node_data, filepath) elif filepath and self.ui.shapeCountsBtn.isChecked(): - onnx_utils.save_node_shape_counts_csv_report( - self.node_shape_counts, filepath - ) + save_node_shape_counts_csv_report(self.node_shape_counts, filepath) diff --git a/src/digest/popup_window.py b/src/digest/popup_window.py index 09d1971..6e4e5ea 100644 --- a/src/digest/popup_window.py +++ b/src/digest/popup_window.py @@ -1,11 +1,14 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. # pylint: disable=no-name-in-module -from PySide6.QtWidgets import QApplication, QMainWindow, QWidget +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QApplication, QMainWindow, QWidget, QDialog, QVBoxLayout from PySide6.QtGui import QIcon class PopupWindow(QWidget): + """Opens new window that runs separate from the main digest window""" + def __init__(self, widget: QWidget, window_title: str = "", parent=None): super().__init__(parent) @@ -24,3 +27,36 @@ def open(self): def close(self): self.main_window.close() + + +class PopupDialog(QDialog): + """Opens a new window that takes focus and must be closed before returning + to the main digest window""" + + def __init__(self, widget: QWidget, window_title: str = "", parent=None): + super().__init__(parent) + + if hasattr(widget, "close_signal"): + widget.close_signal.connect(self.on_widget_closed) # type: ignore + + self.setWindowModality(Qt.WindowModality.WindowModal) + self.setWindowFlags(Qt.WindowType.Window) + + layout = QVBoxLayout() + layout.addWidget(widget) + self.setLayout(layout) + + self.setWindowIcon(QIcon(":/assets/images/digest_logo_500.jpg")) + self.setWindowTitle(window_title) + screen = QApplication.primaryScreen() + screen_geometry = screen.geometry() + self.resize( + int(screen_geometry.width() / 1.5), int(screen_geometry.height() * 0.80) + ) + + def open(self): + self.show() + self.exec() + + def on_widget_closed(self): + self.close() diff --git a/src/digest/pytorch_ingest.py b/src/digest/pytorch_ingest.py new file mode 100644 index 0000000..4f3d8cf --- /dev/null +++ b/src/digest/pytorch_ingest.py @@ -0,0 +1,289 @@ +# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. 
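# Illustrative usage sketch, not part of this diff: the PopupDialog added in popup_window.py
# above hosts a widget modally until that widget emits close_signal. The PyTorchIngest form
# defined in this file exposes such a signal, so it could (hypothetically) be opened as:

    from digest.popup_window import PopupDialog

    ingest = PyTorchIngest("/path/to/model.pt", "my_model")  # hypothetical path and name
    dialog = PopupDialog(ingest, "PyTorch Ingest", parent=main_window)  # main_window assumed
    dialog.open()  # show() + exec(); returns once the form emits close_signal
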
+ +import os +from collections import OrderedDict +from typing import Optional, Callable, Union, List +from platformdirs import user_cache_dir +import torch + +# pylint: disable=no-name-in-module +from PySide6.QtWidgets import ( + QWidget, + QLabel, + QLineEdit, + QSizePolicy, + QFormLayout, + QFileDialog, + QHBoxLayout, + QComboBox, +) +from PySide6.QtGui import QFont +from PySide6.QtCore import Qt, Signal +from utils import onnx_utils +from digest.ui.pytorchingest_ui import Ui_pytorchIngest +from digest.qt_utils import apply_dark_style_sheet +from digest.model_class.digest_pytorch_model import ( + get_model_fwd_parameters, + DigestPyTorchModel, +) + +torch_tensor_types = { + "torch.float16": torch.float16, + "torch.float32": torch.float32, + "torch.float64": torch.float64, + "torch.uint8": torch.uint8, + "torch.uint16": torch.uint16, + "torch.uint32": torch.uint32, + "torch.uint64": torch.uint64, + "torch.int8": torch.int8, + "torch.int16": torch.int16, + "torch.int32": torch.int32, + "torch.int64": torch.int64, + "torch.bool": torch.bool, +} + + +class UserModelInputsForm: + def __init__(self, form_layout: QFormLayout): + self.form_layout = form_layout + self.num_rows = 0 + + def add_row( + self, + label_text: str, + text_width: int, + edit_finished_fnc: Optional[Callable] = None, + ) -> int: + + # The label displays the tensor name + font = QFont("Inter", 10) + label = QLabel(f"{label_text}:") + label.setContentsMargins(0, 0, 0, 0) + label.setFont(font) + + # The combo box enables users to specify the tensor data type + dtype_combo_box = QComboBox() + for tensor_type in torch_tensor_types.keys(): + dtype_combo_box.addItem(tensor_type) + dtype_combo_box.setCurrentIndex(1) # float32 by default + dtype_combo_box.currentIndexChanged.connect(edit_finished_fnc) + + # Line edit is where the user specifies the tensor shape + line_edit = QLineEdit() + line_edit.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed) + line_edit.setMinimumWidth(text_width) + line_edit.setMinimumHeight(20) + line_edit.setPlaceholderText("Set tensor shape here") + if edit_finished_fnc: + line_edit.editingFinished.connect(edit_finished_fnc) + + row_layout = QHBoxLayout() + row_layout.setAlignment(Qt.AlignmentFlag.AlignLeft) + row_layout.setSpacing(5) + row_layout.setObjectName(f"row{self.num_rows}_layout") + row_layout.addWidget(label, alignment=Qt.AlignmentFlag.AlignHCenter) + row_layout.addWidget(dtype_combo_box, alignment=Qt.AlignmentFlag.AlignHCenter) + row_layout.addWidget(line_edit, alignment=Qt.AlignmentFlag.AlignHCenter) + + self.num_rows += 1 + self.form_layout.addRow(row_layout) + + return self.num_rows + + def get_row_tensor_name(self, row_idx: int) -> str: + form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) + row_layout = form_item.layout() + assert isinstance(row_layout, QHBoxLayout) + line_edit_item = row_layout.itemAt(0) + line_edit_widget = line_edit_item.widget() + assert isinstance(line_edit_widget, QLabel) + return line_edit_widget.text().split(":")[0] + + def get_row_tensor_dtype(self, row_idx: int) -> torch.dtype: + form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) + row_layout = form_item.layout() + combo_box = row_layout.itemAt(1) + assert combo_box, "The combo box was not found which is unexpected!" 
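# Note (illustration, not part of this diff): get_row_tensor_shape below converts the text a
# user types into the shape QLineEdit into a mix of static and symbolic dims, for example:
#   "1, 3, 224, 224"          -> [1, 3, 224, 224]
#   "batch_size, 3, 224, 224" -> ["batch_size", 3, 224, 224]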
+ combo_box_widget = combo_box.widget() + assert isinstance(combo_box_widget, QComboBox) + return torch_tensor_types[combo_box_widget.currentText()] + + def get_row_tensor_shape(self, row_idx: int) -> List[Union[str, int]]: + shape_widget = self.get_row_tensor_shape_widget(row_idx) + shape_str = shape_widget.text() + shape_list: List[Union[str, int]] = [] + if not shape_str: + return shape_list + shape_list_str = shape_str.split(",") + + for dim in shape_list_str: + dim = dim.strip() + # Integer based shape + if all(char.isdigit() for char in dim): + shape_list.append(int(dim)) + # Symbolic shape + else: + shape_list.append(dim) + return shape_list + + def get_row_tensor_shape_widget(self, row_idx: int) -> QLineEdit: + form_item = self.form_layout.itemAt(row_idx, QFormLayout.ItemRole.FieldRole) + row_layout = form_item.layout() + line_edit_item = row_layout.itemAt(2) + assert line_edit_item + line_edit_widget = line_edit_item.widget() + assert isinstance(line_edit_widget, QLineEdit) + return line_edit_widget + + +class PyTorchIngest(QWidget): + """PyTorchIngest is the pop up window that enables users to set static shapes and export + PyTorch models to ONNX models.""" + + # This enables the widget to close the parent window + close_signal = Signal() + + def __init__( + self, + model_file: str, + model_name: str, + parent=None, + ): + super().__init__(parent) + self.ui = Ui_pytorchIngest() + self.ui.setupUi(self) + apply_dark_style_sheet(self) + + self.ui.exportWarningLabel.hide() + + # We use a cache dir to save the exported ONNX model + # Users have the option to choose a different location + # if they wish to keep the exported model. + user_cache_directory = user_cache_dir("digest") + os.makedirs(user_cache_directory, exist_ok=True) + self.save_directory: str = user_cache_directory + + self.ui.selectDirBtn.clicked.connect(self.select_directory) + self.ui.exportOnnxBtn.clicked.connect(self.export_onnx) + + self.ui.modelName.setText(str(model_name)) + + self.ui.modelFilename.setText(str(model_file)) + + self.ui.foldingCheckBox.stateChanged.connect(self.on_checkbox_folding_changed) + self.ui.exportParamsCheckBox.stateChanged.connect( + self.on_checkbox_export_params_changed + ) + + self.digest_pytorch_model = DigestPyTorchModel(model_file, model_name) + self.digest_pytorch_model.do_constant_folding = ( + self.ui.foldingCheckBox.isChecked() + ) + self.digest_pytorch_model.export_params = ( + self.ui.exportParamsCheckBox.isChecked() + ) + + self.user_input_form = UserModelInputsForm(self.ui.inputsFormLayout) + + # Set up the opset form + self.lowest_supported_opset = 7 # this requirement came from pytorch + self.supported_opset_version = onnx_utils.get_supported_opset() + self.ui.opsetLineEdit.setText(str(self.digest_pytorch_model.opset)) + self.ui.opsetInfoLabel.setStyleSheet("color: grey;") + self.ui.opsetInfoLabel.setText( + f"(accepted range is {self.lowest_supported_opset} - {self.supported_opset_version}):" + ) + self.ui.opsetLineEdit.editingFinished.connect(self.update_opset_version) + + # Present each input in the forward function + self.fwd_parameters = OrderedDict(get_model_fwd_parameters(model_file)) + for val in self.fwd_parameters.values(): + self.user_input_form.add_row( + str(val), + 250, + self.update_tensor_info, + ) + + def set_widget_invalid(self, widget: QWidget): + widget.setStyleSheet("border: 1px solid red;") + + def set_widget_valid(self, widget: QWidget): + widget.setStyleSheet("") + + def on_checkbox_folding_changed(self): + 
self.digest_pytorch_model.do_constant_folding = ( + self.ui.foldingCheckBox.isChecked() + ) + + def on_checkbox_export_params_changed(self): + self.digest_pytorch_model.export_params = ( + self.ui.exportParamsCheckBox.isChecked() + ) + + def select_directory(self): + dir = QFileDialog(self).getExistingDirectory(self, "Select Directory") + if os.path.exists(dir): + self.save_directory = dir + info_message = f"The ONNX model will be exported to {self.save_directory}" + self.update_message_label(info_message=info_message) + + def update_message_label( + self, info_message: Optional[str] = None, warn_message: Optional[str] = None + ) -> None: + if info_message: + message = f"ℹ️ {info_message}" + elif warn_message: + message = f"⚠️ {warn_message}" + + self.ui.selectDirLabel.setText(message) + + def update_opset_version(self): + opset_text_item = self.ui.opsetLineEdit.text() + if all(char.isdigit() for char in opset_text_item): + opset_text_item = int(opset_text_item) + if ( + opset_text_item + and opset_text_item < self.lowest_supported_opset + or opset_text_item > self.supported_opset_version + ): + self.set_widget_invalid(self.ui.opsetLineEdit) + else: + self.digest_pytorch_model.opset = opset_text_item + self.set_widget_valid(self.ui.opsetLineEdit) + + def update_tensor_info(self): + """Because this is an external function to the UserInputFormWithInfo class + we go through each input everytime there is an update.""" + for row_idx in range(self.user_input_form.form_layout.rowCount()): + widget = self.user_input_form.get_row_tensor_shape_widget(row_idx) + tensor_name = self.user_input_form.get_row_tensor_name(row_idx) + tensor_dtype = self.user_input_form.get_row_tensor_dtype(row_idx) + try: + tensor_shape = self.user_input_form.get_row_tensor_shape(row_idx) + except ValueError as err: + print(f"Shape invalid: {err}") + self.set_widget_invalid(widget) + else: + if tensor_name and tensor_shape: + self.set_widget_valid(widget) + self.digest_pytorch_model.input_tensor_info[tensor_name] = ( + tensor_dtype, + tensor_shape, + ) + + def export_onnx(self): + onnx_file_path = os.path.join( + self.save_directory, f"{self.digest_pytorch_model.model_name}.onnx" + ) + try: + self.digest_pytorch_model.export_to_onnx(onnx_file_path) + except (ValueError, TypeError, RuntimeError) as err: + self.ui.exportWarningLabel.setText(f"Failed to export ONNX: {err}") + self.ui.exportWarningLabel.show() + else: + self.ui.exportWarningLabel.hide() + self.close_widget() + + def close_widget(self): + self.close_signal.emit() + self.close() diff --git a/src/digest/resource.qrc b/src/digest/resource.qrc index 5a70586..6d2e347 100644 --- a/src/digest/resource.qrc +++ b/src/digest/resource.qrc @@ -1,21 +1,22 @@ - - assets/icons/close-window-64.ico - assets/icons/info.png - assets/icons/open.png - assets/icons/digest_logo.ico - assets/images/digest_logo_500.jpg - assets/images/remove_background_500_zoom.png - assets/images/remove_background_200_zoom.png - assets/icons/huggingface.png - assets/icons/huggingface_64px.png - assets/gifs/load.gif - assets/icons/save.png - assets/icons/node_list.png - assets/icons/search.png - assets/icons/models.png - assets/icons/file.png - assets/icons/freeze.png - assets/icons/summary.png - + + assets/icons/64px-PyTorch_logo_icon.svg.png + assets/icons/close-window-64.ico + assets/icons/info.png + assets/icons/open.png + assets/icons/digest_logo.ico + assets/images/digest_logo_500.jpg + assets/images/remove_background_500_zoom.png + assets/images/remove_background_200_zoom.png + 
assets/icons/huggingface.png + assets/icons/huggingface_64px.png + assets/gifs/load.gif + assets/icons/save.png + assets/icons/node_list.png + assets/icons/search.png + assets/icons/models.png + assets/icons/file.png + assets/icons/freeze.png + assets/icons/summary.png + diff --git a/src/digest/resource_rc.py b/src/digest/resource_rc.py index cf29584..79c2adf 100644 --- a/src/digest/resource_rc.py +++ b/src/digest/resource_rc.py @@ -1,6 +1,6 @@ # Resource object code (Python 3) # Created by: object code -# Created by: The Resource Compiler for Qt version 6.8.0 +# Created by: The Resource Compiler for Qt version 6.8.1 # WARNING! All changes made in this file will be lost! from PySide6 import QtCore @@ -19134,6 +19134,125 @@ \x00\x9dOif\xf4\x11\xbb\xa4\xfbG\xfe\xfb\x7f\x8c \ \xf7\xde\xa1\x08\xbb~\x00\x00\x00\x00IEND\xaeB\ `\x82\ +\x00\x00\x07A\ +\x89\ +PNG\x0d\x0a\x1a\x0a\x00\x00\x00\x0dIHDR\x00\ +\x00\x00@\x00\x00\x00N\x08\x03\x00\x00\x00\xa7\xbd\xe0\x9c\ +\x00\x00\x00\x04gAMA\x00\x00\xb1\x8f\x0b\xfca\x05\ +\x00\x00\x00 cHRM\x00\x00z&\x00\x00\x80\x84\ +\x00\x00\xfa\x00\x00\x00\x80\xe8\x00\x00u0\x00\x00\xea`\ +\x00\x00:\x98\x00\x00\x17p\x9c\xbaQ<\x00\x00\x02(\ +PLTE\x00\x00\x00\xff\x00\x00\xeeK,\xe3U9\ +\xeeL-\xeeK-\xeeK,\xff@@\xf1L)\xec\ +L+\xefP0\xffUU\xefN*\xeeL,\xeeM\ +,\xeeL,\xdf@ \xf0M.\xeeL,\xefM,\ +\xeeM+\xf0K-\xeeK,\xefL,\xeeL,\xee\ +L+\xefM,\xeeL,\xecM-\xf3I1\xeeL\ +,\xedL,\xedO,\xedL,\xefL,\xeeL,\ +\xff\x80\x00\xeaJ+\xeeL+\xecK-\xedM+\xed\ +L-\xedL,\xeeM+\xeeM-\xedK,\xeeL\ +,\xedK,\xeeK+\xf1G+\xeeL,\xf0L.\ +\xefL-\xeeL-\xeeM,\xedL-\xffI$\xec\ +L/\xeeL,\xeeL,\xff33\xeeM,\xeeD\ +3\xf0M.\xeeL,\xeeK,\xeeL,\xeeL,\ +\xf0J,\xeeL,\xedM,\xedL-\xeeM,\xee\ +K+\xedL+\xe6M3\xeeL,\xeeL,\xebG\ +)\xedK,\xeeL,\xefP0\xedM+\xeeL,\ +\xffU+\xefL+\xeeM-\xefL,\xefL,\xee\ +K,\xeeL,\xebJ/\xedM.\xeeL,\xeeL\ +,\xedL,\xefL+\xeeL-\xeeL,\xeeK.\ +\xeeL,\xefL,\xedM,\xedL,\xeeL,\xee\ +L,\xefJ+\xeeL,\xeeL,\xf0I,\xeeM\ +,\xeeM+\xeeL,\xeeL,\xeeL,\xedM*\ +\xeeL,\xeeL,\xeeM-\xeeK,\xeeM,\xee\ +L+\xefM,\xeeM,\xefK,\xedK+\xefL\ +,\xedK-\xe9N,\xedJ+\xefK,\xeeL,\ +\xefK-\xf0K-\xeeL,\xeeL-\xefL+\xef\ +L,\xeeM-\xedL+\xefM-\xeeL+\xefM\ +-\xedL,\xecJ*\xf2Q(\xeeM,\xeeL,\ +\xedM-\xeaU+\xeeL,\xeeM+\xefK+\xed\ +L,\xefL,\xeeL,\xeeL,\xeeM,\xeeL\ +-\xeeL,\xefJ-\xeeL,\xefL-\xf1N+\ +\xeeL,\xeeM,\xf0M+\xedM+\xeeL,\xee\ +L,\xeeL,\xefL+\xeeM+\xeeM,\xeeM\ +,\xeeJ-\xeeL,\xeeL,\xeeM,\xedM,\ +\xf0K-\xebN'\xeeL,\xff\xff\xff\x8c_ $\ +\x00\x00\x00\xb6tRNS\x00\x01\x95\x09\x9a\xa6\xa9\x04\ +%6\x10\x031\xcf\xfa\x97\x08!\xef\xaa\x993\xd5o\ +\xd1k\x8c\xfd(\x15\xe2\x91\x1d\xae\xe8u\x02\x18\xa5D\ +\x82r\x80jgs\xfb:\xdf\x12\xa7C\xab\xd4\xd2\xbd\ +\x07\x1b\xe0\xd6\x05\xdc\x0f2\xf8i\xdb\xf94\xddV\x83\ +\xc1X\xbc\x0a\xfc\xc3\x19b\xfe q\xe4\x06/x\xbb\ +\x7f\xc4\xf4&8\xd0h\x90\x8d\x94\xb3,\xbf\xcaF\xda\ +\xd7\xf30\xe7\xe3#\xe6\x1e\xf7\xd8\xf1+\xeb\xcc\xb7\x5c\ +\x92Mn\x85\xba\x81@U\x17Hm\xa1}\x11\xf2J\ +|\xc9\xc8e`\x93~\xe97\x13\xbe\xde\x8f\x0c\xb5Y\ +p\x9e\x9d\xed\xb2L\x89\xa2>\xc6a$\xa8{BS\ +\xf6\xea\xee^\x88\xf0\x96-\x86\xec\xa3\xad\x22\x0d@!\ +\x93S\x00\x00\x00\x01bKGD\xb7\xdd\x00;g\x00\ +\x00\x00\x07tIME\x07\xe8\x07\x08\x10*\x1dQ\x1c\ +\xc41\x00\x00\x03TIDATX\xc3\xa5V\xf9C\ +LQ\x14>\xd3\x944K\x85\xb4\xa8\x946\xb4\x8e\xb2\ +f\x10)\xb4\x0e\x8d(\xfb\x12\x225-H\x8aP\x96\ +\xac\xa5\x12\xca.\xd9I\xee\xdfg\xe6\xbds\xdf{\xb3\ +\xf4\xde\xbbw\xeeO\xe7|\xf7|\xdf\xbd\xf7\xdd\xfb\xce\ +9\x00\xaa\xc3\x10b\x84`\x86!\x84\x18\x83\xe4\x07%\ +\xe0\xe1\x07# \xf0\x83\x10\x10\xf9\xfc\x02\xc8\xe7\x16\xa0\ +|^\x01\x89\xcf) \xf3\xf9\x04\x14|.\x01C\xa8\ 
+\xcc\x0f3\x04\xb7\xfe\x82p\x19_\x18ab_\xdf,\ +\xc1\x16+!\x91Q\xd1\xdc\xeb/Z, K\xb8\xf7\ +\x1f\x83\xd8RN>\xc4\x22\x18\xc7\xc9\x87xD\x138\ +\xf9\xb0LD\x13\x938\xf9\x90\xbc\x5c\x80Sx\xf9\x00\ +\xa9+\x22IZz\x067\xdf=\x8c\x99\x19\xaa\x1f \ +J\xe6g\x99\xb5n{\xa5?\xb4Jc}\xaf\xb1\x9a\ +d\xfbB9,\xeb\xa7\xbb\xa3r\xbd\xa1\x88H\x86\xf5\ +\xf3\xae\x98&\ +\xef\xde\xfb\xf4[\x1f\x06Z\x94\xf3\xd6\x00\xed\x98\xf3\xa3\ +2\x82\xb8z?M\xa7\x0a\xdbt&}\x1e\xb5z\xcd\ +\x91\xea\xc1@G\xccH!\xbec\xe6K\x99\xedk\x95\ +\x1f<9_\x1d\xfc\xd6Lt\x0c\xc7\xf7\xf9_|\x92\ +U\x9b\xdf:\xad\xfa\xd0\xe2~\xa8\xd3\xf3c\x865\xde\ +j\xffO\x97\xca\xee'\x7f\xe9\xf8\xdd\xfa\x7f\xff\x09L\ +/\x99\xd5\xdb\x85\xd8\xa7\xb2\xfcNRZQkb\xc9\ +\xdbf\xcb\xc4\xec\x10\xfeb3\xf1\x7f\xfb\xa6+\x81g\ +D\x0f\xcf\xcde\xfeS\x0d\xf9\x0f|\xbd\x92\x22\xc7k\ +h\x91\x00\x00\x00%tEXtdate:c\ +reate\x002024-07-08\ +T16:42:28+00:00\x91\ +\x1c5\xf8\x00\x00\x00%tEXtdate:\ +modify\x002024-07-0\ +8T16:42:28+00:00\ +\xe0A\x8dD\x00\x00\x00\x00IEND\xaeB`\x82\ +\ \x00\x00\x171\ \x89\ PNG\x0d\x0a\x1a\x0a\x00\x00\x00\x0dIHDR\x00\ @@ -19653,6 +19772,11 @@ \x04\xd2YG\ \x00i\ \x00n\x00f\x00o\x00.\x00p\x00n\x00g\ +\x00\x1e\ +\x01_\x1b\xa7\ +\x006\ +\x004\x00p\x00x\x00-\x00P\x00y\x00T\x00o\x00r\x00c\x00h\x00_\x00l\x00o\x00g\x00o\ +\x00_\x00i\x00c\x00o\x00n\x00.\x00s\x00v\x00g\x00.\x00p\x00n\x00g\ \x00\x0f\ \x0cYr'\ \x00h\ @@ -19669,46 +19793,48 @@ \x00\x00\x00\x00\x00\x00\x00\x00\ \x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x02\ \x00\x00\x00\x00\x00\x00\x00\x00\ -\x00\x00\x00\x22\x00\x02\x00\x00\x00\x01\x00\x00\x00\x15\ +\x00\x00\x00\x22\x00\x02\x00\x00\x00\x01\x00\x00\x00\x16\ \x00\x00\x00\x00\x00\x00\x00\x00\ -\x00\x00\x00\x12\x00\x02\x00\x00\x00\x0d\x00\x00\x00\x08\ +\x00\x00\x00\x12\x00\x02\x00\x00\x00\x0e\x00\x00\x00\x08\ \x00\x00\x00\x00\x00\x00\x00\x00\ \x00\x00\x000\x00\x02\x00\x00\x00\x03\x00\x00\x00\x05\ \x00\x00\x00\x00\x00\x00\x00\x00\ \x00\x00\x00B\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\ -\x00\x00\x01\x93Ju\xc2>\ +\x00\x00\x01\x93K\x85\xbd\xbb\ \x00\x00\x00\xb0\x00\x00\x00\x00\x00\x01\x00\x01\x86\x02\ -\x00\x00\x01\x93Ju\xc2A\ +\x00\x00\x01\x93K\x85\xbd\xbb\ \x00\x00\x00\x84\x00\x00\x00\x00\x00\x01\x00\x01 bool: + + loop = QEventLoop() + timer = QTimer() + timer.setSingleShot(True) + timer.timeout.connect(loop.quit) + + def check_threads(): + if all(thread.isFinished() for thread in threads): + loop.quit() + + check_timer = QTimer() + check_timer.timeout.connect(check_threads) + check_timer.start(100) # Check every 100ms + + timer.start(timeout) + loop.exec() + + check_timer.stop() + timer.stop() + + # Return True if all threads finished, False if timed out + return all(thread.isFinished() for thread in threads) + + class StatsThread(QThread): - completed = Signal(onnx_utils.DigestOnnxModel, str) + completed = Signal(DigestOnnxModel, str) def __init__( self, @@ -31,50 +59,104 @@ def run(self): if not self.unique_id: raise ValueError("You must specify a unique id.") - digest_model = onnx_utils.DigestOnnxModel(self.model, save_proto=False) + digest_model = DigestOnnxModel(self.model, save_proto=False) self.completed.emit(digest_model, self.unique_id) + def wait(self, timeout=10000): + wait_threads([self], timeout) + class SimilarityThread(QThread): - completed_successfully = Signal(bool, str, str, str) + completed_successfully = Signal(bool, str, str, str, pd.DataFrame) def __init__( self, - model_filepath: Optional[str] = None, - png_filepath: Optional[str] = None, + model_file_path: Optional[str] = None, + png_file_path: Optional[str] = None, model_id: 
Optional[str] = None, ): super().__init__() - self.model_filepath = model_filepath - self.png_filepath = png_filepath + self.model_file_path = model_file_path + self.png_file_path = png_file_path self.model_id = model_id def run(self): - if not self.model_filepath: - raise ValueError("You must set the model filepath") - if not self.png_filepath: - raise ValueError("You must set the png filepath") + if not self.model_file_path: + raise ValueError("You must set the model file_path") + if not self.png_file_path: + raise ValueError("You must set the png file_path") if not self.model_id: raise ValueError("You must set the model id") try: - most_similar, _ = find_match( - self.model_filepath, - self.png_filepath, + most_similar, _, df_sorted = find_match( + self.model_file_path, dequantize=False, replace=True, - dark_mode=True, ) most_similar = [os.path.basename(path) for path in most_similar] - most_similar = ",".join(most_similar[1:4]) + # We convert List[str] to str to send through the signal + most_similar = ",".join(most_similar) self.completed_successfully.emit( - True, self.model_id, most_similar, self.png_filepath + True, self.model_id, most_similar, self.png_file_path, df_sorted ) except Exception as e: # pylint: disable=broad-exception-caught most_similar = "" self.completed_successfully.emit( - False, self.model_id, most_similar, self.png_filepath + False, self.model_id, most_similar, self.png_file_path, df_sorted ) print(f"Issue creating similarity analysis: {e}") + + def wait(self, timeout=10000): + wait_threads([self], timeout) + + +def post_process( + model_name: str, + name_list: List[str], + df_sorted: pd.DataFrame, + png_file_path: str, + dark_mode: bool = True, +): + """Matplotlib is not thread safe so we must do post_processing on the main thread""" + if dark_mode: + plt.style.use("dark_background") + fig, ax = plt.subplots(figsize=(12, 10)) + im = ax.imshow(df_sorted, cmap="viridis") + + # Show all ticks and label them with the respective list entries + ax.set_xticks(np.arange(len(df_sorted.columns))) + ax.set_yticks(np.arange(len(name_list))) + ax.set_xticklabels([a[:5] for a in df_sorted.columns]) + ax.set_yticklabels(name_list) + + # Rotate the tick labels and set their alignment + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + ax.set_title(f"Model Similarity Heatmap - {model_name}") + + cb = plt.colorbar( + im, + ax=ax, + shrink=0.5, + format="%.2f", + label="Correlation Ratio", + orientation="vertical", + # pad=0.02, + ) + cb.set_ticks([0, 0.5, 1]) # Set colorbar ticks at 0, 0.5, and 1 + cb.set_ticklabels( + ["0.0 (Low)", "0.5 (Medium)", "1.0 (High)"] + ) # Set corresponding labels + cb.set_label("Correlation Ratio", labelpad=-100) + + fig.tight_layout() + + if png_file_path is None: + png_file_path = "heatmap.png" + + fig.savefig(png_file_path) + + plt.close(fig) diff --git a/src/digest/ui/freezeinputs_ui.py b/src/digest/ui/freezeinputs_ui.py index 3838e57..85e8211 100644 --- a/src/digest/ui/freezeinputs_ui.py +++ b/src/digest/ui/freezeinputs_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'freezeinputs.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! 
################################################################################ diff --git a/src/digest/ui/huggingface_page_ui.py b/src/digest/ui/huggingface_page_ui.py index a06a573..5cfcbfe 100644 --- a/src/digest/ui/huggingface_page_ui.py +++ b/src/digest/ui/huggingface_page_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'huggingface_page.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ diff --git a/src/digest/ui/mainwindow.ui b/src/digest/ui/mainwindow.ui index 8643efa..e7e28f3 100644 --- a/src/digest/ui/mainwindow.ui +++ b/src/digest/ui/mainwindow.ui @@ -179,7 +179,7 @@ Qt::FocusPolicy::NoFocus - <html><head/><body><p>Open a local model file (Ctrl-O)</p></body></html> + <html><head/><body><p>Open (Ctrl-O)</p></body></html> QPushButton { diff --git a/src/digest/ui/mainwindow_ui.py b/src/digest/ui/mainwindow_ui.py index 9904c77..61119a1 100644 --- a/src/digest/ui/mainwindow_ui.py +++ b/src/digest/ui/mainwindow_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'mainwindow.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ @@ -520,7 +520,7 @@ def setupUi(self, MainWindow): def retranslateUi(self, MainWindow): MainWindow.setWindowTitle(QCoreApplication.translate("MainWindow", u"DigestAI", None)) #if QT_CONFIG(tooltip) - self.openFileBtn.setToolTip(QCoreApplication.translate("MainWindow", u"

<html><head/><body><p>Open a local model file (Ctrl-O)</p></body></html>", None))
+        self.openFileBtn.setToolTip(QCoreApplication.translate("MainWindow", u"<html><head/><body><p>Open (Ctrl-O)</p></body></html>
", None)) #endif // QT_CONFIG(tooltip) self.openFileBtn.setText("") #if QT_CONFIG(shortcut) diff --git a/src/digest/ui/modelsummary.ui b/src/digest/ui/modelsummary.ui index 180fed4..737cf33 100644 --- a/src/digest/ui/modelsummary.ui +++ b/src/digest/ui/modelsummary.ui @@ -6,8 +6,8 @@ 0 0 - 980 - 687 + 1138 + 837
@@ -153,11 +153,17 @@ border-top-right-radius: 10px; 0 - 0 + -776 991 - 1453 + 1443 + + + 0 + 0 + + background-color: black; @@ -244,7 +250,7 @@ QFrame:hover { 6 - + @@ -271,7 +277,7 @@ QFrame:hover { - + @@ -667,20 +673,32 @@ QFrame:hover { - + 0 0 + + + 300 + 500 + + - + - + 0 0 + + + 0 + 0 + + 16777215 @@ -690,6 +708,9 @@ QFrame:hover { Loading... + + false + Qt::AlignmentFlag::AlignCenter @@ -834,7 +855,7 @@ QFrame:hover { - + 0 0 @@ -853,6 +874,9 @@ QFrame:hover { + + 6 + 20 @@ -861,8 +885,14 @@ QFrame:hover { - + + + + 0 + 0 + + QLabel { font-size: 18px; @@ -875,11 +905,11 @@ QFrame:hover { - + - - 0 + + 1 0 @@ -975,7 +1005,7 @@ QScrollBar::handle:vertical { - + @@ -983,6 +1013,18 @@ QScrollBar::handle:vertical { 0 + + + 0 + 0 + + + + + 16777215 + 16777215 + + PointingHandCursor @@ -1067,7 +1109,7 @@ QPushButton:pressed { - + 0 0 @@ -1218,7 +1260,7 @@ QScrollBar::handle:vertical { - + diff --git a/src/digest/ui/modelsummary_ui.py b/src/digest/ui/modelsummary_ui.py index 1102e3a..e217372 100644 --- a/src/digest/ui/modelsummary_ui.py +++ b/src/digest/ui/modelsummary_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'modelsummary.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ @@ -29,7 +29,7 @@ class Ui_modelSummary(object): def setupUi(self, modelSummary): if not modelSummary.objectName(): modelSummary.setObjectName(u"modelSummary") - modelSummary.resize(980, 687) + modelSummary.resize(1138, 837) sizePolicy = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) @@ -115,17 +115,22 @@ def setupUi(self, modelSummary): self.scrollArea.setWidgetResizable(True) self.scrollAreaWidgetContents = QWidget() self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents") - self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 991, 1453)) + self.scrollAreaWidgetContents.setGeometry(QRect(0, -776, 991, 1443)) + sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding) + sizePolicy2.setHorizontalStretch(0) + sizePolicy2.setVerticalStretch(0) + sizePolicy2.setHeightForWidth(self.scrollAreaWidgetContents.sizePolicy().hasHeightForWidth()) + self.scrollAreaWidgetContents.setSizePolicy(sizePolicy2) self.scrollAreaWidgetContents.setStyleSheet(u"background-color: black;") self.verticalLayout_20 = QVBoxLayout(self.scrollAreaWidgetContents) self.verticalLayout_20.setObjectName(u"verticalLayout_20") self.cardFrame = QFrame(self.scrollAreaWidgetContents) self.cardFrame.setObjectName(u"cardFrame") - sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) - sizePolicy2.setHorizontalStretch(0) - sizePolicy2.setVerticalStretch(0) - sizePolicy2.setHeightForWidth(self.cardFrame.sizePolicy().hasHeightForWidth()) - self.cardFrame.setSizePolicy(sizePolicy2) + sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + sizePolicy3.setHorizontalStretch(0) + sizePolicy3.setVerticalStretch(0) + sizePolicy3.setHeightForWidth(self.cardFrame.sizePolicy().hasHeightForWidth()) + self.cardFrame.setSizePolicy(sizePolicy3) self.cardFrame.setStyleSheet(u"background: transparent; /*rgb(40,40,40)*/") 
self.cardFrame.setFrameShape(QFrame.Shape.StyledPanel) self.cardFrame.setFrameShadow(QFrame.Shadow.Raised) @@ -134,19 +139,19 @@ def setupUi(self, modelSummary): self.horizontalLayout.setContentsMargins(-1, -1, -1, 1) self.cardWidget = QWidget(self.cardFrame) self.cardWidget.setObjectName(u"cardWidget") - sizePolicy2.setHeightForWidth(self.cardWidget.sizePolicy().hasHeightForWidth()) - self.cardWidget.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.cardWidget.sizePolicy().hasHeightForWidth()) + self.cardWidget.setSizePolicy(sizePolicy3) self.horizontalLayout_2 = QHBoxLayout(self.cardWidget) self.horizontalLayout_2.setSpacing(13) self.horizontalLayout_2.setObjectName(u"horizontalLayout_2") self.horizontalLayout_2.setContentsMargins(-1, 6, 25, 35) self.opsetFrame = QFrame(self.cardWidget) self.opsetFrame.setObjectName(u"opsetFrame") - sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Fixed) - sizePolicy3.setHorizontalStretch(0) - sizePolicy3.setVerticalStretch(0) - sizePolicy3.setHeightForWidth(self.opsetFrame.sizePolicy().hasHeightForWidth()) - self.opsetFrame.setSizePolicy(sizePolicy3) + sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Fixed) + sizePolicy4.setHorizontalStretch(0) + sizePolicy4.setVerticalStretch(0) + sizePolicy4.setHeightForWidth(self.opsetFrame.sizePolicy().hasHeightForWidth()) + self.opsetFrame.setSizePolicy(sizePolicy4) self.opsetFrame.setMinimumSize(QSize(220, 70)) self.opsetFrame.setMaximumSize(QSize(16777215, 80)) self.opsetFrame.setStyleSheet(u"QFrame {\n" @@ -164,11 +169,11 @@ def setupUi(self, modelSummary): self.verticalLayout_5.setContentsMargins(-1, -1, 6, -1) self.opsetLabel = QLabel(self.opsetFrame) self.opsetLabel.setObjectName(u"opsetLabel") - sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed) - sizePolicy4.setHorizontalStretch(0) - sizePolicy4.setVerticalStretch(0) - sizePolicy4.setHeightForWidth(self.opsetLabel.sizePolicy().hasHeightForWidth()) - self.opsetLabel.setSizePolicy(sizePolicy4) + sizePolicy5 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Fixed) + sizePolicy5.setHorizontalStretch(0) + sizePolicy5.setVerticalStretch(0) + sizePolicy5.setHeightForWidth(self.opsetLabel.sizePolicy().hasHeightForWidth()) + self.opsetLabel.setSizePolicy(sizePolicy5) self.opsetLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -178,12 +183,12 @@ def setupUi(self, modelSummary): self.opsetLabel.setAlignment(Qt.AlignmentFlag.AlignCenter) self.opsetLabel.setTextInteractionFlags(Qt.TextInteractionFlag.TextSelectableByMouse) - self.verticalLayout_5.addWidget(self.opsetLabel, 0, Qt.AlignmentFlag.AlignHCenter) + self.verticalLayout_5.addWidget(self.opsetLabel) self.opsetVersion = QLabel(self.opsetFrame) self.opsetVersion.setObjectName(u"opsetVersion") - sizePolicy4.setHeightForWidth(self.opsetVersion.sizePolicy().hasHeightForWidth()) - self.opsetVersion.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.opsetVersion.sizePolicy().hasHeightForWidth()) + self.opsetVersion.setSizePolicy(sizePolicy5) self.opsetVersion.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -192,15 +197,15 @@ def setupUi(self, modelSummary): self.opsetVersion.setAlignment(Qt.AlignmentFlag.AlignCenter) self.opsetVersion.setTextInteractionFlags(Qt.TextInteractionFlag.LinksAccessibleByMouse|Qt.TextInteractionFlag.TextSelectableByKeyboard|Qt.TextInteractionFlag.TextSelectableByMouse) - 
self.verticalLayout_5.addWidget(self.opsetVersion, 0, Qt.AlignmentFlag.AlignHCenter) + self.verticalLayout_5.addWidget(self.opsetVersion) self.horizontalLayout_2.addWidget(self.opsetFrame) self.nodesFrame = QFrame(self.cardWidget) self.nodesFrame.setObjectName(u"nodesFrame") - sizePolicy3.setHeightForWidth(self.nodesFrame.sizePolicy().hasHeightForWidth()) - self.nodesFrame.setSizePolicy(sizePolicy3) + sizePolicy4.setHeightForWidth(self.nodesFrame.sizePolicy().hasHeightForWidth()) + self.nodesFrame.setSizePolicy(sizePolicy4) self.nodesFrame.setMinimumSize(QSize(220, 70)) self.nodesFrame.setMaximumSize(QSize(16777215, 80)) self.nodesFrame.setStyleSheet(u"QFrame {\n" @@ -218,8 +223,8 @@ def setupUi(self, modelSummary): self.verticalLayout_12.setContentsMargins(-1, 9, -1, -1) self.nodesLabel = QLabel(self.nodesFrame) self.nodesLabel.setObjectName(u"nodesLabel") - sizePolicy4.setHeightForWidth(self.nodesLabel.sizePolicy().hasHeightForWidth()) - self.nodesLabel.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.nodesLabel.sizePolicy().hasHeightForWidth()) + self.nodesLabel.setSizePolicy(sizePolicy5) self.nodesLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -233,8 +238,8 @@ def setupUi(self, modelSummary): self.nodes = QLabel(self.nodesFrame) self.nodes.setObjectName(u"nodes") - sizePolicy4.setHeightForWidth(self.nodes.sizePolicy().hasHeightForWidth()) - self.nodes.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.nodes.sizePolicy().hasHeightForWidth()) + self.nodes.setSizePolicy(sizePolicy5) self.nodes.setMinimumSize(QSize(150, 32)) self.nodes.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" @@ -254,8 +259,8 @@ def setupUi(self, modelSummary): self.paramFrame = QFrame(self.cardWidget) self.paramFrame.setObjectName(u"paramFrame") - sizePolicy3.setHeightForWidth(self.paramFrame.sizePolicy().hasHeightForWidth()) - self.paramFrame.setSizePolicy(sizePolicy3) + sizePolicy4.setHeightForWidth(self.paramFrame.sizePolicy().hasHeightForWidth()) + self.paramFrame.setSizePolicy(sizePolicy4) self.paramFrame.setMinimumSize(QSize(220, 70)) self.paramFrame.setMaximumSize(QSize(16777215, 80)) self.paramFrame.setStyleSheet(u"QFrame {\n" @@ -272,8 +277,8 @@ def setupUi(self, modelSummary): self.verticalLayout_9.setObjectName(u"verticalLayout_9") self.parametersLabel = QLabel(self.paramFrame) self.parametersLabel.setObjectName(u"parametersLabel") - sizePolicy4.setHeightForWidth(self.parametersLabel.sizePolicy().hasHeightForWidth()) - self.parametersLabel.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.parametersLabel.sizePolicy().hasHeightForWidth()) + self.parametersLabel.setSizePolicy(sizePolicy5) self.parametersLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -287,8 +292,8 @@ def setupUi(self, modelSummary): self.parameters = QLabel(self.paramFrame) self.parameters.setObjectName(u"parameters") - sizePolicy4.setHeightForWidth(self.parameters.sizePolicy().hasHeightForWidth()) - self.parameters.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.parameters.sizePolicy().hasHeightForWidth()) + self.parameters.setSizePolicy(sizePolicy5) self.parameters.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -304,8 +309,8 @@ def setupUi(self, modelSummary): self.flopsFrame = QFrame(self.cardWidget) self.flopsFrame.setObjectName(u"flopsFrame") - sizePolicy3.setHeightForWidth(self.flopsFrame.sizePolicy().hasHeightForWidth()) - self.flopsFrame.setSizePolicy(sizePolicy3) + 
sizePolicy4.setHeightForWidth(self.flopsFrame.sizePolicy().hasHeightForWidth()) + self.flopsFrame.setSizePolicy(sizePolicy4) self.flopsFrame.setMinimumSize(QSize(220, 70)) self.flopsFrame.setMaximumSize(QSize(16777215, 80)) self.flopsFrame.setCursor(QCursor(Qt.CursorShape.ArrowCursor)) @@ -323,8 +328,8 @@ def setupUi(self, modelSummary): self.verticalLayout_11.setObjectName(u"verticalLayout_11") self.flopsLabel = QLabel(self.flopsFrame) self.flopsLabel.setObjectName(u"flopsLabel") - sizePolicy4.setHeightForWidth(self.flopsLabel.sizePolicy().hasHeightForWidth()) - self.flopsLabel.setSizePolicy(sizePolicy4) + sizePolicy5.setHeightForWidth(self.flopsLabel.sizePolicy().hasHeightForWidth()) + self.flopsLabel.setSizePolicy(sizePolicy5) self.flopsLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" @@ -338,11 +343,11 @@ def setupUi(self, modelSummary): self.flops = QLabel(self.flopsFrame) self.flops.setObjectName(u"flops") - sizePolicy5 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Fixed) - sizePolicy5.setHorizontalStretch(0) - sizePolicy5.setVerticalStretch(0) - sizePolicy5.setHeightForWidth(self.flops.sizePolicy().hasHeightForWidth()) - self.flops.setSizePolicy(sizePolicy5) + sizePolicy6 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Fixed) + sizePolicy6.setHorizontalStretch(0) + sizePolicy6.setVerticalStretch(0) + sizePolicy6.setHeightForWidth(self.flops.sizePolicy().hasHeightForWidth()) + self.flops.setSizePolicy(sizePolicy6) self.flops.setMinimumSize(QSize(200, 32)) self.flops.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" @@ -380,11 +385,11 @@ def setupUi(self, modelSummary): self.parametersPieChart = PieChartWidget(self.scrollAreaWidgetContents) self.parametersPieChart.setObjectName(u"parametersPieChart") - sizePolicy6 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) - sizePolicy6.setHorizontalStretch(0) - sizePolicy6.setVerticalStretch(0) - sizePolicy6.setHeightForWidth(self.parametersPieChart.sizePolicy().hasHeightForWidth()) - self.parametersPieChart.setSizePolicy(sizePolicy6) + sizePolicy7 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + sizePolicy7.setHorizontalStretch(0) + sizePolicy7.setVerticalStretch(0) + sizePolicy7.setHeightForWidth(self.parametersPieChart.sizePolicy().hasHeightForWidth()) + self.parametersPieChart.setSizePolicy(sizePolicy7) self.parametersPieChart.setMinimumSize(QSize(300, 500)) self.firstRowChartsLayout.addWidget(self.parametersPieChart) @@ -397,23 +402,29 @@ def setupUi(self, modelSummary): self.secondRowChartsLayout.setContentsMargins(-1, 20, -1, -1) self.similarityWidget = QWidget(self.scrollAreaWidgetContents) self.similarityWidget.setObjectName(u"similarityWidget") - sizePolicy.setHeightForWidth(self.similarityWidget.sizePolicy().hasHeightForWidth()) - self.similarityWidget.setSizePolicy(sizePolicy) + sizePolicy7.setHeightForWidth(self.similarityWidget.sizePolicy().hasHeightForWidth()) + self.similarityWidget.setSizePolicy(sizePolicy7) + self.similarityWidget.setMinimumSize(QSize(300, 500)) self.placeholderWidget = QVBoxLayout(self.similarityWidget) self.placeholderWidget.setObjectName(u"placeholderWidget") self.similarityImg = ClickableLabel(self.similarityWidget) self.similarityImg.setObjectName(u"similarityImg") - sizePolicy.setHeightForWidth(self.similarityImg.sizePolicy().hasHeightForWidth()) - self.similarityImg.setSizePolicy(sizePolicy) + sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Expanding, 
QSizePolicy.Policy.Expanding) + sizePolicy8.setHorizontalStretch(0) + sizePolicy8.setVerticalStretch(0) + sizePolicy8.setHeightForWidth(self.similarityImg.sizePolicy().hasHeightForWidth()) + self.similarityImg.setSizePolicy(sizePolicy8) + self.similarityImg.setMinimumSize(QSize(0, 0)) self.similarityImg.setMaximumSize(QSize(16777215, 16777215)) + self.similarityImg.setScaledContents(False) self.similarityImg.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.placeholderWidget.addWidget(self.similarityImg, 0, Qt.AlignmentFlag.AlignHCenter) + self.placeholderWidget.addWidget(self.similarityImg) self.similarityCorrelationStatic = QLabel(self.similarityWidget) self.similarityCorrelationStatic.setObjectName(u"similarityCorrelationStatic") - sizePolicy2.setHeightForWidth(self.similarityCorrelationStatic.sizePolicy().hasHeightForWidth()) - self.similarityCorrelationStatic.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.similarityCorrelationStatic.sizePolicy().hasHeightForWidth()) + self.similarityCorrelationStatic.setSizePolicy(sizePolicy3) self.similarityCorrelationStatic.setFont(font) self.similarityCorrelationStatic.setAlignment(Qt.AlignmentFlag.AlignCenter) @@ -421,8 +432,8 @@ def setupUi(self, modelSummary): self.similarityCorrelation = QLabel(self.similarityWidget) self.similarityCorrelation.setObjectName(u"similarityCorrelation") - sizePolicy2.setHeightForWidth(self.similarityCorrelation.sizePolicy().hasHeightForWidth()) - self.similarityCorrelation.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.similarityCorrelation.sizePolicy().hasHeightForWidth()) + self.similarityCorrelation.setSizePolicy(sizePolicy3) palette = QPalette() brush = QBrush(QColor(0, 0, 0, 255)) brush.setStyle(Qt.SolidPattern) @@ -446,9 +457,6 @@ def setupUi(self, modelSummary): self.flopsPieChart = PieChartWidget(self.scrollAreaWidgetContents) self.flopsPieChart.setObjectName(u"flopsPieChart") - sizePolicy7 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Preferred) - sizePolicy7.setHorizontalStretch(0) - sizePolicy7.setVerticalStretch(0) sizePolicy7.setHeightForWidth(self.flopsPieChart.sizePolicy().hasHeightForWidth()) self.flopsPieChart.setSizePolicy(sizePolicy7) self.flopsPieChart.setMinimumSize(QSize(300, 500)) @@ -462,19 +470,25 @@ def setupUi(self, modelSummary): self.verticalLayout_20.addLayout(self.chartsLayout) self.thirdRowInputsLayout = QHBoxLayout() + self.thirdRowInputsLayout.setSpacing(6) self.thirdRowInputsLayout.setObjectName(u"thirdRowInputsLayout") self.thirdRowInputsLayout.setContentsMargins(20, 30, -1, -1) self.inputsLayout = QVBoxLayout() self.inputsLayout.setObjectName(u"inputsLayout") self.inputsLabel = QLabel(self.scrollAreaWidgetContents) self.inputsLabel.setObjectName(u"inputsLabel") + sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Maximum) + sizePolicy9.setHorizontalStretch(0) + sizePolicy9.setVerticalStretch(0) + sizePolicy9.setHeightForWidth(self.inputsLabel.sizePolicy().hasHeightForWidth()) + self.inputsLabel.setSizePolicy(sizePolicy9) self.inputsLabel.setStyleSheet(u"QLabel {\n" " font-size: 18px;\n" " font-weight: bold;\n" " background: transparent;\n" "}") - self.inputsLayout.addWidget(self.inputsLabel, 0, Qt.AlignmentFlag.AlignVCenter) + self.inputsLayout.addWidget(self.inputsLabel) self.inputsTable = QTableWidget(self.scrollAreaWidgetContents) if (self.inputsTable.columnCount() < 4): @@ -488,11 +502,11 @@ def setupUi(self, modelSummary): __qtablewidgetitem3 = QTableWidgetItem() 
self.inputsTable.setHorizontalHeaderItem(3, __qtablewidgetitem3) self.inputsTable.setObjectName(u"inputsTable") - sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Preferred) - sizePolicy8.setHorizontalStretch(0) - sizePolicy8.setVerticalStretch(0) - sizePolicy8.setHeightForWidth(self.inputsTable.sizePolicy().hasHeightForWidth()) - self.inputsTable.setSizePolicy(sizePolicy8) + sizePolicy10 = QSizePolicy(QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding) + sizePolicy10.setHorizontalStretch(1) + sizePolicy10.setVerticalStretch(0) + sizePolicy10.setHeightForWidth(self.inputsTable.sizePolicy().hasHeightForWidth()) + self.inputsTable.setSizePolicy(sizePolicy10) self.inputsTable.setStyleSheet(u"QTableWidget {\n" " gridline-color: #353535; /* Grid lines */\n" " selection-background-color: #3949AB; /* Blue selection */\n" @@ -543,18 +557,20 @@ def setupUi(self, modelSummary): self.inputsTable.verticalHeader().setVisible(False) self.inputsTable.verticalHeader().setHighlightSections(True) - self.inputsLayout.addWidget(self.inputsTable, 0, Qt.AlignmentFlag.AlignVCenter) + self.inputsLayout.addWidget(self.inputsTable) self.thirdRowInputsLayout.addLayout(self.inputsLayout) self.freezeButton = QPushButton(self.scrollAreaWidgetContents) self.freezeButton.setObjectName(u"freezeButton") - sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) - sizePolicy9.setHorizontalStretch(0) - sizePolicy9.setVerticalStretch(0) - sizePolicy9.setHeightForWidth(self.freezeButton.sizePolicy().hasHeightForWidth()) - self.freezeButton.setSizePolicy(sizePolicy9) + sizePolicy11 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + sizePolicy11.setHorizontalStretch(0) + sizePolicy11.setVerticalStretch(0) + sizePolicy11.setHeightForWidth(self.freezeButton.sizePolicy().hasHeightForWidth()) + self.freezeButton.setSizePolicy(sizePolicy11) + self.freezeButton.setMinimumSize(QSize(0, 0)) + self.freezeButton.setMaximumSize(QSize(16777215, 16777215)) self.freezeButton.setCursor(QCursor(Qt.CursorShape.PointingHandCursor)) self.freezeButton.setStyleSheet(u"QPushButton {\n" " color: white;\n" @@ -580,7 +596,7 @@ def setupUi(self, modelSummary): self.freezeButton.setIcon(icon) self.freezeButton.setIconSize(QSize(32, 32)) - self.thirdRowInputsLayout.addWidget(self.freezeButton, 0, Qt.AlignmentFlag.AlignTop) + self.thirdRowInputsLayout.addWidget(self.freezeButton) self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) @@ -616,8 +632,11 @@ def setupUi(self, modelSummary): __qtablewidgetitem7 = QTableWidgetItem() self.outputsTable.setHorizontalHeaderItem(3, __qtablewidgetitem7) self.outputsTable.setObjectName(u"outputsTable") - sizePolicy8.setHeightForWidth(self.outputsTable.sizePolicy().hasHeightForWidth()) - self.outputsTable.setSizePolicy(sizePolicy8) + sizePolicy12 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Expanding) + sizePolicy12.setHorizontalStretch(0) + sizePolicy12.setVerticalStretch(0) + sizePolicy12.setHeightForWidth(self.outputsTable.sizePolicy().hasHeightForWidth()) + self.outputsTable.setSizePolicy(sizePolicy12) self.outputsTable.setStyleSheet(u"QTableWidget {\n" " gridline-color: #353535; /* Grid lines */\n" " selection-background-color: #3949AB; /* Blue selection */\n" @@ -684,8 +703,8 @@ def setupUi(self, modelSummary): self.sidePaneFrame = QFrame(modelSummary) self.sidePaneFrame.setObjectName(u"sidePaneFrame") - 
sizePolicy2.setHeightForWidth(self.sidePaneFrame.sizePolicy().hasHeightForWidth()) - self.sidePaneFrame.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.sidePaneFrame.sizePolicy().hasHeightForWidth()) + self.sidePaneFrame.setSizePolicy(sizePolicy3) self.sidePaneFrame.setMinimumSize(QSize(0, 0)) self.sidePaneFrame.setStyleSheet(u"QFrame {\n" " /*background: rgb(30,30,30);*/\n" @@ -741,8 +760,8 @@ def setupUi(self, modelSummary): __qtablewidgetitem21 = QTableWidgetItem() self.modelProtoTable.setItem(3, 1, __qtablewidgetitem21) self.modelProtoTable.setObjectName(u"modelProtoTable") - sizePolicy2.setHeightForWidth(self.modelProtoTable.sizePolicy().hasHeightForWidth()) - self.modelProtoTable.setSizePolicy(sizePolicy2) + sizePolicy3.setHeightForWidth(self.modelProtoTable.sizePolicy().hasHeightForWidth()) + self.modelProtoTable.setSizePolicy(sizePolicy3) self.modelProtoTable.setMinimumSize(QSize(0, 0)) self.modelProtoTable.setMaximumSize(QSize(16777215, 100)) self.modelProtoTable.setStyleSheet(u"QTableWidget::item {\n" @@ -770,7 +789,7 @@ def setupUi(self, modelSummary): self.modelProtoTable.verticalHeader().setMinimumSectionSize(20) self.modelProtoTable.verticalHeader().setDefaultSectionSize(20) - self.verticalLayout_3.addWidget(self.modelProtoTable, 0, Qt.AlignmentFlag.AlignRight) + self.verticalLayout_3.addWidget(self.modelProtoTable) self.importsLabel = QLabel(self.sidePaneFrame) self.importsLabel.setObjectName(u"importsLabel") @@ -791,11 +810,8 @@ def setupUi(self, modelSummary): __qtablewidgetitem23 = QTableWidgetItem() self.importsTable.setHorizontalHeaderItem(1, __qtablewidgetitem23) self.importsTable.setObjectName(u"importsTable") - sizePolicy10 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding) - sizePolicy10.setHorizontalStretch(0) - sizePolicy10.setVerticalStretch(0) - sizePolicy10.setHeightForWidth(self.importsTable.sizePolicy().hasHeightForWidth()) - self.importsTable.setSizePolicy(sizePolicy10) + sizePolicy2.setHeightForWidth(self.importsTable.sizePolicy().hasHeightForWidth()) + self.importsTable.setSizePolicy(sizePolicy2) self.importsTable.setStyleSheet(u"QTableWidget::item {\n" " color: white;\n" " padding: 5px;\n" diff --git a/src/digest/ui/multimodelanalysis.ui b/src/digest/ui/multimodelanalysis.ui index cf044e3..16109d0 100644 --- a/src/digest/ui/multimodelanalysis.ui +++ b/src/digest/ui/multimodelanalysis.ui @@ -6,8 +6,8 @@ 0 0 - 908 - 647 + 1085 + 866 @@ -51,7 +51,7 @@ QFrame::Shadow::Raised - + @@ -176,7 +176,7 @@ - + 0 0 @@ -198,8 +198,8 @@ 0 0 - 888 - 464 + 1065 + 688 @@ -242,6 +242,12 @@ + + + 0 + 0 + + QFrame::Shape::StyledPanel @@ -258,7 +264,7 @@ QFrame::Shadow::Raised - + @@ -279,17 +285,19 @@ + + + 0 + 0 + + QFrame::Shape::StyledPanel QFrame::Shadow::Raised - - - - - + diff --git a/src/digest/ui/multimodelanalysis_ui.py b/src/digest/ui/multimodelanalysis_ui.py index 54aa6d6..9f4b359 100644 --- a/src/digest/ui/multimodelanalysis_ui.py +++ b/src/digest/ui/multimodelanalysis_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'multimodelanalysis.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! 
################################################################################ @@ -26,7 +26,7 @@ class Ui_multiModelAnalysis(object): def setupUi(self, multiModelAnalysis): if not multiModelAnalysis.objectName(): multiModelAnalysis.setObjectName(u"multiModelAnalysis") - multiModelAnalysis.resize(908, 647) + multiModelAnalysis.resize(1085, 866) sizePolicy = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) sizePolicy.setHorizontalStretch(0) sizePolicy.setVerticalStretch(0) @@ -71,7 +71,7 @@ def setupUi(self, multiModelAnalysis): self.modelName.setIndent(5) self.modelName.setTextInteractionFlags(Qt.TextInteractionFlag.LinksAccessibleByMouse|Qt.TextInteractionFlag.TextSelectableByKeyboard|Qt.TextInteractionFlag.TextSelectableByMouse) - self.verticalLayout_17.addWidget(self.modelName, 0, Qt.AlignmentFlag.AlignTop) + self.verticalLayout_17.addWidget(self.modelName) self.summaryTopBannerLayout.addWidget(self.modelNameFrame) @@ -127,7 +127,7 @@ def setupUi(self, multiModelAnalysis): self.scrollArea = QScrollArea(multiModelAnalysis) self.scrollArea.setObjectName(u"scrollArea") - sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.MinimumExpanding) + sizePolicy4 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) sizePolicy4.setHorizontalStretch(0) sizePolicy4.setVerticalStretch(0) sizePolicy4.setHeightForWidth(self.scrollArea.sizePolicy().hasHeightForWidth()) @@ -138,7 +138,7 @@ def setupUi(self, multiModelAnalysis): self.scrollArea.setWidgetResizable(True) self.scrollAreaWidgetContents = QWidget() self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents") - self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 888, 464)) + self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 1065, 688)) sizePolicy5 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) sizePolicy5.setHorizontalStretch(0) sizePolicy5.setVerticalStretch(100) @@ -165,6 +165,11 @@ def setupUi(self, multiModelAnalysis): self.frame_2 = QFrame(self.scrollAreaWidgetContents) self.frame_2.setObjectName(u"frame_2") + sizePolicy7 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.Preferred) + sizePolicy7.setHorizontalStretch(0) + sizePolicy7.setVerticalStretch(0) + sizePolicy7.setHeightForWidth(self.frame_2.sizePolicy().hasHeightForWidth()) + self.frame_2.setSizePolicy(sizePolicy7) self.frame_2.setFrameShape(QFrame.Shape.StyledPanel) self.frame_2.setFrameShadow(QFrame.Shadow.Raised) self.horizontalLayout_2 = QHBoxLayout(self.frame_2) @@ -177,29 +182,29 @@ def setupUi(self, multiModelAnalysis): self.verticalLayout_3.setObjectName(u"verticalLayout_3") self.opHistogramChart = HistogramChartWidget(self.combinedHistogramFrame) self.opHistogramChart.setObjectName(u"opHistogramChart") - sizePolicy7 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Minimum) - sizePolicy7.setHorizontalStretch(0) - sizePolicy7.setVerticalStretch(0) - sizePolicy7.setHeightForWidth(self.opHistogramChart.sizePolicy().hasHeightForWidth()) - self.opHistogramChart.setSizePolicy(sizePolicy7) + sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Minimum) + sizePolicy8.setHorizontalStretch(0) + sizePolicy8.setVerticalStretch(0) + sizePolicy8.setHeightForWidth(self.opHistogramChart.sizePolicy().hasHeightForWidth()) + self.opHistogramChart.setSizePolicy(sizePolicy8) self.opHistogramChart.setMinimumSize(QSize(500, 300)) - self.verticalLayout_3.addWidget(self.opHistogramChart, 0, 
Qt.AlignmentFlag.AlignTop) + self.verticalLayout_3.addWidget(self.opHistogramChart) self.horizontalLayout_2.addWidget(self.combinedHistogramFrame) self.stackedHistogramFrame = QFrame(self.frame_2) self.stackedHistogramFrame.setObjectName(u"stackedHistogramFrame") + sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + sizePolicy9.setHorizontalStretch(0) + sizePolicy9.setVerticalStretch(0) + sizePolicy9.setHeightForWidth(self.stackedHistogramFrame.sizePolicy().hasHeightForWidth()) + self.stackedHistogramFrame.setSizePolicy(sizePolicy9) self.stackedHistogramFrame.setFrameShape(QFrame.Shape.StyledPanel) self.stackedHistogramFrame.setFrameShadow(QFrame.Shadow.Raised) self.verticalLayout_5 = QVBoxLayout(self.stackedHistogramFrame) self.verticalLayout_5.setObjectName(u"verticalLayout_5") - self.verticalLayout_4 = QVBoxLayout() - self.verticalLayout_4.setObjectName(u"verticalLayout_4") - - self.verticalLayout_5.addLayout(self.verticalLayout_4) - self.horizontalLayout_2.addWidget(self.stackedHistogramFrame) diff --git a/src/digest/ui/multimodelselection_page.ui b/src/digest/ui/multimodelselection_page.ui index c5d12f8..034ed88 100644 --- a/src/digest/ui/multimodelselection_page.ui +++ b/src/digest/ui/multimodelselection_page.ui @@ -52,7 +52,7 @@ - + @@ -68,7 +68,7 @@ - + false @@ -128,7 +128,7 @@ - Warning: The chosen folder contains more than MAX_ONNX_MODELS + Warning 2 @@ -141,7 +141,7 @@ - + 0 @@ -161,25 +161,77 @@ - Select All + All + + + true - + + + + 0 + 0 + + + + + 0 + 33 + + + + false + - 0 selected models - - - true + ONNX + + + + 0 + 0 + + + + + 0 + 33 + + + + false + + + + + + Reports + + + + + + + 0 + 0 + + + + + 550 + 0 + + @@ -191,8 +243,34 @@ + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + + + 0 selected models + + + true + + + diff --git a/src/digest/ui/multimodelselection_page_ui.py b/src/digest/ui/multimodelselection_page_ui.py index e6acb66..0e25178 100644 --- a/src/digest/ui/multimodelselection_page_ui.py +++ b/src/digest/ui/multimodelselection_page_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'multimodelselection_page.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! 
################################################################################ @@ -15,9 +15,9 @@ QFont, QFontDatabase, QGradient, QIcon, QImage, QKeySequence, QLinearGradient, QPainter, QPalette, QPixmap, QRadialGradient, QTransform) -from PySide6.QtWidgets import (QAbstractItemView, QApplication, QCheckBox, QHBoxLayout, - QLabel, QListView, QListWidget, QListWidgetItem, - QPushButton, QSizePolicy, QSpacerItem, QVBoxLayout, +from PySide6.QtWidgets import (QAbstractItemView, QApplication, QHBoxLayout, QLabel, + QListView, QListWidget, QListWidgetItem, QPushButton, + QRadioButton, QSizePolicy, QSpacerItem, QVBoxLayout, QWidget) class Ui_MultiModelSelection(object): @@ -59,7 +59,7 @@ def setupUi(self, MultiModelSelection): self.selectFolderBtn.setSizePolicy(sizePolicy) self.selectFolderBtn.setStyleSheet(u"") - self.horizontalLayout_2.addWidget(self.selectFolderBtn, 0, Qt.AlignmentFlag.AlignLeft|Qt.AlignmentFlag.AlignVCenter) + self.horizontalLayout_2.addWidget(self.selectFolderBtn) self.openAnalysisBtn = QPushButton(MultiModelSelection) self.openAnalysisBtn.setObjectName(u"openAnalysisBtn") @@ -68,7 +68,7 @@ def setupUi(self, MultiModelSelection): self.openAnalysisBtn.setSizePolicy(sizePolicy) self.openAnalysisBtn.setStyleSheet(u"") - self.horizontalLayout_2.addWidget(self.openAnalysisBtn, 0, Qt.AlignmentFlag.AlignLeft|Qt.AlignmentFlag.AlignVCenter) + self.horizontalLayout_2.addWidget(self.openAnalysisBtn) self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) @@ -104,33 +104,64 @@ def setupUi(self, MultiModelSelection): self.horizontalLayout_3 = QHBoxLayout() self.horizontalLayout_3.setObjectName(u"horizontalLayout_3") - self.selectAllBox = QCheckBox(MultiModelSelection) - self.selectAllBox.setObjectName(u"selectAllBox") - sizePolicy.setHeightForWidth(self.selectAllBox.sizePolicy().hasHeightForWidth()) - self.selectAllBox.setSizePolicy(sizePolicy) - self.selectAllBox.setMinimumSize(QSize(0, 33)) - self.selectAllBox.setAutoFillBackground(False) - self.selectAllBox.setStyleSheet(u"") - - self.horizontalLayout_3.addWidget(self.selectAllBox) - - self.numSelectedLabel = QLabel(MultiModelSelection) - self.numSelectedLabel.setObjectName(u"numSelectedLabel") - self.numSelectedLabel.setStyleSheet(u"") - self.numSelectedLabel.setWordWrap(True) - - self.horizontalLayout_3.addWidget(self.numSelectedLabel) + self.radioAll = QRadioButton(MultiModelSelection) + self.radioAll.setObjectName(u"radioAll") + sizePolicy.setHeightForWidth(self.radioAll.sizePolicy().hasHeightForWidth()) + self.radioAll.setSizePolicy(sizePolicy) + self.radioAll.setMinimumSize(QSize(0, 33)) + self.radioAll.setAutoFillBackground(False) + self.radioAll.setStyleSheet(u"") + self.radioAll.setChecked(True) + + self.horizontalLayout_3.addWidget(self.radioAll) + + self.radioONNX = QRadioButton(MultiModelSelection) + self.radioONNX.setObjectName(u"radioONNX") + sizePolicy.setHeightForWidth(self.radioONNX.sizePolicy().hasHeightForWidth()) + self.radioONNX.setSizePolicy(sizePolicy) + self.radioONNX.setMinimumSize(QSize(0, 33)) + self.radioONNX.setAutoFillBackground(False) + self.radioONNX.setStyleSheet(u"") + + self.horizontalLayout_3.addWidget(self.radioONNX) + + self.radioReports = QRadioButton(MultiModelSelection) + self.radioReports.setObjectName(u"radioReports") + sizePolicy.setHeightForWidth(self.radioReports.sizePolicy().hasHeightForWidth()) + self.radioReports.setSizePolicy(sizePolicy) + self.radioReports.setMinimumSize(QSize(0, 33)) + self.radioReports.setAutoFillBackground(False) + 
self.radioReports.setStyleSheet(u"") + + self.horizontalLayout_3.addWidget(self.radioReports) self.duplicateLabel = QLabel(MultiModelSelection) self.duplicateLabel.setObjectName(u"duplicateLabel") + sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + sizePolicy2.setHorizontalStretch(0) + sizePolicy2.setVerticalStretch(0) + sizePolicy2.setHeightForWidth(self.duplicateLabel.sizePolicy().hasHeightForWidth()) + self.duplicateLabel.setSizePolicy(sizePolicy2) + self.duplicateLabel.setMinimumSize(QSize(550, 0)) self.duplicateLabel.setStyleSheet(u"") self.duplicateLabel.setWordWrap(True) self.horizontalLayout_3.addWidget(self.duplicateLabel) + self.horizontalSpacer_2 = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.horizontalLayout_3.addItem(self.horizontalSpacer_2) + self.verticalLayout.addLayout(self.horizontalLayout_3) + self.numSelectedLabel = QLabel(MultiModelSelection) + self.numSelectedLabel.setObjectName(u"numSelectedLabel") + self.numSelectedLabel.setStyleSheet(u"") + self.numSelectedLabel.setWordWrap(True) + + self.verticalLayout.addWidget(self.numSelectedLabel) + self.columnsLayout = QHBoxLayout() self.columnsLayout.setObjectName(u"columnsLayout") self.leftColumnLayout = QVBoxLayout() @@ -151,11 +182,11 @@ def setupUi(self, MultiModelSelection): self.rightColumnLayout.setObjectName(u"rightColumnLayout") self.duplicateListWidget = QListWidget(MultiModelSelection) self.duplicateListWidget.setObjectName(u"duplicateListWidget") - sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) - sizePolicy2.setHorizontalStretch(0) - sizePolicy2.setVerticalStretch(0) - sizePolicy2.setHeightForWidth(self.duplicateListWidget.sizePolicy().hasHeightForWidth()) - self.duplicateListWidget.setSizePolicy(sizePolicy2) + sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + sizePolicy3.setHorizontalStretch(0) + sizePolicy3.setVerticalStretch(0) + sizePolicy3.setHeightForWidth(self.duplicateListWidget.sizePolicy().hasHeightForWidth()) + self.duplicateListWidget.setSizePolicy(sizePolicy3) self.duplicateListWidget.setStyleSheet(u"") self.duplicateListWidget.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) self.duplicateListWidget.setSelectionMode(QAbstractItemView.SelectionMode.MultiSelection) @@ -184,9 +215,11 @@ def retranslateUi(self, MultiModelSelection): self.selectFolderBtn.setText(QCoreApplication.translate("MultiModelSelection", u"Select Folder", None)) self.openAnalysisBtn.setText(QCoreApplication.translate("MultiModelSelection", u"Open Analysis", None)) self.infoLabel.setText("") - self.warningLabel.setText(QCoreApplication.translate("MultiModelSelection", u"Warning: The chosen folder contains more than MAX_ONNX_MODELS", None)) - self.selectAllBox.setText(QCoreApplication.translate("MultiModelSelection", u"Select All", None)) - self.numSelectedLabel.setText(QCoreApplication.translate("MultiModelSelection", u"0 selected models", None)) + self.warningLabel.setText(QCoreApplication.translate("MultiModelSelection", u"Warning", None)) + self.radioAll.setText(QCoreApplication.translate("MultiModelSelection", u"All", None)) + self.radioONNX.setText(QCoreApplication.translate("MultiModelSelection", u"ONNX", None)) + self.radioReports.setText(QCoreApplication.translate("MultiModelSelection", u"Reports", None)) self.duplicateLabel.setText(QCoreApplication.translate("MultiModelSelection", u"The following models were found to be duplicates and have been deselected 
from the list on the left.", None)) + self.numSelectedLabel.setText(QCoreApplication.translate("MultiModelSelection", u"0 selected models", None)) # retranslateUi diff --git a/src/digest/ui/nodessummary_ui.py b/src/digest/ui/nodessummary_ui.py index 7efc69d..e0e400c 100644 --- a/src/digest/ui/nodessummary_ui.py +++ b/src/digest/ui/nodessummary_ui.py @@ -3,7 +3,7 @@ ################################################################################ ## Form generated from reading UI file 'nodessummary.ui' ## -## Created by: Qt User Interface Compiler version 6.8.0 +## Created by: Qt User Interface Compiler version 6.8.1 ## ## WARNING! All changes made in this file will be lost when recompiling UI file! ################################################################################ diff --git a/src/digest/ui/pytorchingest.ui b/src/digest/ui/pytorchingest.ui new file mode 100644 index 0000000..6be230b --- /dev/null +++ b/src/digest/ui/pytorchingest.ui @@ -0,0 +1,565 @@ + + + pytorchIngest + + + + 0 + 0 + 1060 + 748 + + + + + 0 + 0 + + + + Form + + + + :/assets/images/digest_logo_500.jpg:/assets/images/digest_logo_500.jpg + + + + + + + + + + 0 + 0 + + + + + + + + 0 + 0 + + + + + 16777215 + 16777215 + + + + + + + :/assets/icons/64px-PyTorch_logo_icon.svg.png + + + true + + + 5 + + + + + + + + 0 + 0 + + + + false + + + + + + QFrame::Shape::NoFrame + + + QFrame::Shadow::Raised + + + + + + + 0 + 0 + + + + + true + + + + + + + PyTorch Ingest + + + true + + + 1 + + + 5 + + + Qt::TextInteractionFlag::LinksAccessibleByMouse|Qt::TextInteractionFlag::TextSelectableByKeyboard|Qt::TextInteractionFlag::TextSelectableByMouse + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + + + + + + + 0 + 0 + + + + Qt::ScrollBarPolicy::ScrollBarAsNeeded + + + Qt::ScrollBarPolicy::ScrollBarAsNeeded + + + QAbstractScrollArea::SizeAdjustPolicy::AdjustToContents + + + true + + + + + 0 + 0 + 1040 + 616 + + + + + 0 + 100 + + + + + + + + 10 + + + + + + 0 + 0 + + + + QLabel { + font-size: 28px; + font-weight: bold; + margin-bottom: -5px; +} + + + model name + + + + + + + path to the model file + + + 5 + + + + + + + 20 + + + 10 + + + + + + 0 + 0 + + + + PointingHandCursor + + + + + + Select Directory + + + false + + + + + + + + + + Select a directory if you would like to save the ONNX model file + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + + + 0 + 0 + + + + + 13 + + + + + + + Export Options + + + + 15 + + + 35 + + + 9 + + + + + 0 + + + 0 + + + + + + 0 + 0 + + + + + 10 + + + + Do constant folding + + + true + + + + + + + + + 10 + + + + + + 0 + 0 + + + + + 10 + + + + Export params + + + true + + + + + + + + + + + + 0 + 0 + + + + + 12 + false + + + + Opset + + + + + + + + 0 + 0 + + + + + 10 + false + + + + (accepted range is 7 - 21): + + + 0 + + + + + + + + 0 + 0 + + + + + 35 + 16777215 + + + + + 10 + + + + 17 + + + + + + + Qt::Orientation::Horizontal + + + + 40 + 20 + + + + + + + + + + + + + + 14 + + + + Inputs + + + + 15 + + + 25 + + + + + + 12 + + + + color: lightgrey; + + + The following inputs were taken from the PyTorch model's forward function. Please set the type and dimensions for each required input. Shape dimensions can be set by specifying a combination of symbolic and integer values separated by a comma, for example: batch_size, 3, 224, 244. 
+ + + true + + + 5 + + + + + + + 10 + + + 10 + + + 20 + + + + + + + + + + + 0 + 0 + + + + QLabel { + font-size: 10px; + background-color: #FFCC00; + border: 1px solid #996600; + color: #333333; + font-weight: bold; + border-radius: 0px; +} + + + <html><head/><body><p>This is a warning message that we can use for now to prompt the user.</p></body></html> + + + 5 + + + + + + + + 0 + 0 + + + + PointingHandCursor + + + + + + Export ONNX + + + false + + + + + + + Qt::Orientation::Vertical + + + + 20 + 40 + + + + + + + + + + + + + + + diff --git a/src/digest/ui/pytorchingest_ui.py b/src/digest/ui/pytorchingest_ui.py new file mode 100644 index 0000000..f658051 --- /dev/null +++ b/src/digest/ui/pytorchingest_ui.py @@ -0,0 +1,360 @@ +# -*- coding: utf-8 -*- + +################################################################################ +## Form generated from reading UI file 'pytorchingest.ui' +## +## Created by: Qt User Interface Compiler version 6.8.1 +## +## WARNING! All changes made in this file will be lost when recompiling UI file! +################################################################################ + +from PySide6.QtCore import (QCoreApplication, QDate, QDateTime, QLocale, + QMetaObject, QObject, QPoint, QRect, + QSize, QTime, QUrl, Qt) +from PySide6.QtGui import (QBrush, QColor, QConicalGradient, QCursor, + QFont, QFontDatabase, QGradient, QIcon, + QImage, QKeySequence, QLinearGradient, QPainter, + QPalette, QPixmap, QRadialGradient, QTransform) +from PySide6.QtWidgets import (QAbstractScrollArea, QApplication, QCheckBox, QFormLayout, + QFrame, QGroupBox, QHBoxLayout, QLabel, + QLineEdit, QPushButton, QScrollArea, QSizePolicy, + QSpacerItem, QVBoxLayout, QWidget) +import resource_rc + +class Ui_pytorchIngest(object): + def setupUi(self, pytorchIngest): + if not pytorchIngest.objectName(): + pytorchIngest.setObjectName(u"pytorchIngest") + pytorchIngest.resize(1060, 748) + sizePolicy = QSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + sizePolicy.setHorizontalStretch(0) + sizePolicy.setVerticalStretch(0) + sizePolicy.setHeightForWidth(pytorchIngest.sizePolicy().hasHeightForWidth()) + pytorchIngest.setSizePolicy(sizePolicy) + icon = QIcon() + icon.addFile(u":/assets/images/digest_logo_500.jpg", QSize(), QIcon.Mode.Normal, QIcon.State.Off) + pytorchIngest.setWindowIcon(icon) + pytorchIngest.setStyleSheet(u"") + self.verticalLayout = QVBoxLayout(pytorchIngest) + self.verticalLayout.setObjectName(u"verticalLayout") + self.summaryTopBanner = QWidget(pytorchIngest) + self.summaryTopBanner.setObjectName(u"summaryTopBanner") + sizePolicy1 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Maximum) + sizePolicy1.setHorizontalStretch(0) + sizePolicy1.setVerticalStretch(0) + sizePolicy1.setHeightForWidth(self.summaryTopBanner.sizePolicy().hasHeightForWidth()) + self.summaryTopBanner.setSizePolicy(sizePolicy1) + self.summaryTopBannerLayout = QHBoxLayout(self.summaryTopBanner) + self.summaryTopBannerLayout.setObjectName(u"summaryTopBannerLayout") + self.pytorchLogo = QLabel(self.summaryTopBanner) + self.pytorchLogo.setObjectName(u"pytorchLogo") + sizePolicy2 = QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + sizePolicy2.setHorizontalStretch(0) + sizePolicy2.setVerticalStretch(0) + sizePolicy2.setHeightForWidth(self.pytorchLogo.sizePolicy().hasHeightForWidth()) + self.pytorchLogo.setSizePolicy(sizePolicy2) + self.pytorchLogo.setMaximumSize(QSize(16777215, 16777215)) + 
self.pytorchLogo.setPixmap(QPixmap(u":/assets/icons/64px-PyTorch_logo_icon.svg.png")) + self.pytorchLogo.setScaledContents(True) + self.pytorchLogo.setMargin(5) + + self.summaryTopBannerLayout.addWidget(self.pytorchLogo) + + self.headerFrame = QFrame(self.summaryTopBanner) + self.headerFrame.setObjectName(u"headerFrame") + sizePolicy3 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Expanding) + sizePolicy3.setHorizontalStretch(0) + sizePolicy3.setVerticalStretch(0) + sizePolicy3.setHeightForWidth(self.headerFrame.sizePolicy().hasHeightForWidth()) + self.headerFrame.setSizePolicy(sizePolicy3) + self.headerFrame.setAutoFillBackground(False) + self.headerFrame.setStyleSheet(u"") + self.headerFrame.setFrameShape(QFrame.Shape.NoFrame) + self.headerFrame.setFrameShadow(QFrame.Shadow.Raised) + self.horizontalLayout = QHBoxLayout(self.headerFrame) + self.horizontalLayout.setObjectName(u"horizontalLayout") + self.titleLabel = QLabel(self.headerFrame) + self.titleLabel.setObjectName(u"titleLabel") + sizePolicy4 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + sizePolicy4.setHorizontalStretch(0) + sizePolicy4.setVerticalStretch(0) + sizePolicy4.setHeightForWidth(self.titleLabel.sizePolicy().hasHeightForWidth()) + self.titleLabel.setSizePolicy(sizePolicy4) + font = QFont() + font.setBold(True) + self.titleLabel.setFont(font) + self.titleLabel.setStyleSheet(u"") + self.titleLabel.setWordWrap(True) + self.titleLabel.setMargin(1) + self.titleLabel.setIndent(5) + self.titleLabel.setTextInteractionFlags(Qt.TextInteractionFlag.LinksAccessibleByMouse|Qt.TextInteractionFlag.TextSelectableByKeyboard|Qt.TextInteractionFlag.TextSelectableByMouse) + + self.horizontalLayout.addWidget(self.titleLabel) + + self.horizontalSpacer = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.horizontalLayout.addItem(self.horizontalSpacer) + + + self.summaryTopBannerLayout.addWidget(self.headerFrame) + + + self.verticalLayout.addWidget(self.summaryTopBanner) + + self.scrollArea = QScrollArea(pytorchIngest) + self.scrollArea.setObjectName(u"scrollArea") + sizePolicy5 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) + sizePolicy5.setHorizontalStretch(0) + sizePolicy5.setVerticalStretch(0) + sizePolicy5.setHeightForWidth(self.scrollArea.sizePolicy().hasHeightForWidth()) + self.scrollArea.setSizePolicy(sizePolicy5) + self.scrollArea.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + self.scrollArea.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + self.scrollArea.setSizeAdjustPolicy(QAbstractScrollArea.SizeAdjustPolicy.AdjustToContents) + self.scrollArea.setWidgetResizable(True) + self.scrollAreaWidgetContents = QWidget() + self.scrollAreaWidgetContents.setObjectName(u"scrollAreaWidgetContents") + self.scrollAreaWidgetContents.setGeometry(QRect(0, 0, 1040, 616)) + sizePolicy6 = QSizePolicy(QSizePolicy.Policy.MinimumExpanding, QSizePolicy.Policy.MinimumExpanding) + sizePolicy6.setHorizontalStretch(0) + sizePolicy6.setVerticalStretch(100) + sizePolicy6.setHeightForWidth(self.scrollAreaWidgetContents.sizePolicy().hasHeightForWidth()) + self.scrollAreaWidgetContents.setSizePolicy(sizePolicy6) + self.scrollAreaWidgetContents.setStyleSheet(u"") + self.verticalLayout_20 = QVBoxLayout(self.scrollAreaWidgetContents) + self.verticalLayout_20.setSpacing(10) + self.verticalLayout_20.setObjectName(u"verticalLayout_20") + self.modelName = QLabel(self.scrollAreaWidgetContents) + 
self.modelName.setObjectName(u"modelName") + sizePolicy7 = QSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum) + sizePolicy7.setHorizontalStretch(0) + sizePolicy7.setVerticalStretch(0) + sizePolicy7.setHeightForWidth(self.modelName.sizePolicy().hasHeightForWidth()) + self.modelName.setSizePolicy(sizePolicy7) + self.modelName.setStyleSheet(u"QLabel {\n" +" font-size: 28px;\n" +" font-weight: bold;\n" +" margin-bottom: -5px;\n" +"}") + + self.verticalLayout_20.addWidget(self.modelName) + + self.modelFilename = QLabel(self.scrollAreaWidgetContents) + self.modelFilename.setObjectName(u"modelFilename") + self.modelFilename.setMargin(5) + + self.verticalLayout_20.addWidget(self.modelFilename) + + self.selectDirLayout = QHBoxLayout() + self.selectDirLayout.setSpacing(20) + self.selectDirLayout.setObjectName(u"selectDirLayout") + self.selectDirLayout.setContentsMargins(-1, -1, -1, 10) + self.selectDirBtn = QPushButton(self.scrollAreaWidgetContents) + self.selectDirBtn.setObjectName(u"selectDirBtn") + sizePolicy8 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Fixed) + sizePolicy8.setHorizontalStretch(0) + sizePolicy8.setVerticalStretch(0) + sizePolicy8.setHeightForWidth(self.selectDirBtn.sizePolicy().hasHeightForWidth()) + self.selectDirBtn.setSizePolicy(sizePolicy8) + self.selectDirBtn.setCursor(QCursor(Qt.CursorShape.PointingHandCursor)) + self.selectDirBtn.setStyleSheet(u"") + self.selectDirBtn.setAutoExclusive(False) + + self.selectDirLayout.addWidget(self.selectDirBtn) + + self.selectDirLabel = QLabel(self.scrollAreaWidgetContents) + self.selectDirLabel.setObjectName(u"selectDirLabel") + self.selectDirLabel.setStyleSheet(u"") + + self.selectDirLayout.addWidget(self.selectDirLabel) + + self.horizontalSpacer_2 = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.selectDirLayout.addItem(self.horizontalSpacer_2) + + + self.verticalLayout_20.addLayout(self.selectDirLayout) + + self.exportOptionsGroupBox = QGroupBox(self.scrollAreaWidgetContents) + self.exportOptionsGroupBox.setObjectName(u"exportOptionsGroupBox") + sizePolicy4.setHeightForWidth(self.exportOptionsGroupBox.sizePolicy().hasHeightForWidth()) + self.exportOptionsGroupBox.setSizePolicy(sizePolicy4) + font1 = QFont() + font1.setPointSize(13) + self.exportOptionsGroupBox.setFont(font1) + self.exportOptionsGroupBox.setStyleSheet(u"") + self.verticalLayout_2 = QVBoxLayout(self.exportOptionsGroupBox) + self.verticalLayout_2.setSpacing(15) + self.verticalLayout_2.setObjectName(u"verticalLayout_2") + self.verticalLayout_2.setContentsMargins(-1, 35, -1, 9) + self.horizontalLayout_3 = QHBoxLayout() + self.horizontalLayout_3.setSpacing(0) + self.horizontalLayout_3.setObjectName(u"horizontalLayout_3") + self.horizontalLayout_3.setContentsMargins(-1, 0, -1, -1) + self.foldingCheckBox = QCheckBox(self.exportOptionsGroupBox) + self.foldingCheckBox.setObjectName(u"foldingCheckBox") + sizePolicy4.setHeightForWidth(self.foldingCheckBox.sizePolicy().hasHeightForWidth()) + self.foldingCheckBox.setSizePolicy(sizePolicy4) + font2 = QFont() + font2.setPointSize(10) + self.foldingCheckBox.setFont(font2) + self.foldingCheckBox.setChecked(True) + + self.horizontalLayout_3.addWidget(self.foldingCheckBox) + + + self.verticalLayout_2.addLayout(self.horizontalLayout_3) + + self.horizontalLayout_4 = QHBoxLayout() + self.horizontalLayout_4.setSpacing(10) + self.horizontalLayout_4.setObjectName(u"horizontalLayout_4") + self.exportParamsCheckBox = QCheckBox(self.exportOptionsGroupBox) + 
self.exportParamsCheckBox.setObjectName(u"exportParamsCheckBox") + sizePolicy4.setHeightForWidth(self.exportParamsCheckBox.sizePolicy().hasHeightForWidth()) + self.exportParamsCheckBox.setSizePolicy(sizePolicy4) + self.exportParamsCheckBox.setFont(font2) + self.exportParamsCheckBox.setChecked(True) + + self.horizontalLayout_4.addWidget(self.exportParamsCheckBox) + + + self.verticalLayout_2.addLayout(self.horizontalLayout_4) + + self.opsetLayout = QHBoxLayout() + self.opsetLayout.setObjectName(u"opsetLayout") + self.opsetLabel = QLabel(self.exportOptionsGroupBox) + self.opsetLabel.setObjectName(u"opsetLabel") + sizePolicy2.setHeightForWidth(self.opsetLabel.sizePolicy().hasHeightForWidth()) + self.opsetLabel.setSizePolicy(sizePolicy2) + font3 = QFont() + font3.setPointSize(12) + font3.setBold(False) + self.opsetLabel.setFont(font3) + + self.opsetLayout.addWidget(self.opsetLabel) + + self.opsetInfoLabel = QLabel(self.exportOptionsGroupBox) + self.opsetInfoLabel.setObjectName(u"opsetInfoLabel") + sizePolicy2.setHeightForWidth(self.opsetInfoLabel.sizePolicy().hasHeightForWidth()) + self.opsetInfoLabel.setSizePolicy(sizePolicy2) + font4 = QFont() + font4.setPointSize(10) + font4.setItalic(False) + self.opsetInfoLabel.setFont(font4) + self.opsetInfoLabel.setMargin(0) + + self.opsetLayout.addWidget(self.opsetInfoLabel) + + self.opsetLineEdit = QLineEdit(self.exportOptionsGroupBox) + self.opsetLineEdit.setObjectName(u"opsetLineEdit") + sizePolicy2.setHeightForWidth(self.opsetLineEdit.sizePolicy().hasHeightForWidth()) + self.opsetLineEdit.setSizePolicy(sizePolicy2) + self.opsetLineEdit.setMaximumSize(QSize(35, 16777215)) + self.opsetLineEdit.setFont(font2) + + self.opsetLayout.addWidget(self.opsetLineEdit) + + self.horizontalSpacer_4 = QSpacerItem(40, 20, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum) + + self.opsetLayout.addItem(self.horizontalSpacer_4) + + + self.verticalLayout_2.addLayout(self.opsetLayout) + + + self.verticalLayout_20.addWidget(self.exportOptionsGroupBox) + + self.inputsGroupBox = QGroupBox(self.scrollAreaWidgetContents) + self.inputsGroupBox.setObjectName(u"inputsGroupBox") + font5 = QFont() + font5.setPointSize(14) + self.inputsGroupBox.setFont(font5) + self.verticalLayout_3 = QVBoxLayout(self.inputsGroupBox) + self.verticalLayout_3.setSpacing(15) + self.verticalLayout_3.setObjectName(u"verticalLayout_3") + self.verticalLayout_3.setContentsMargins(-1, 25, -1, -1) + self.label = QLabel(self.inputsGroupBox) + self.label.setObjectName(u"label") + font6 = QFont() + font6.setPointSize(12) + self.label.setFont(font6) + self.label.setStyleSheet(u"color: lightgrey;") + self.label.setWordWrap(True) + self.label.setMargin(5) + + self.verticalLayout_3.addWidget(self.label) + + self.inputsFormLayout = QFormLayout() + self.inputsFormLayout.setObjectName(u"inputsFormLayout") + self.inputsFormLayout.setHorizontalSpacing(10) + self.inputsFormLayout.setVerticalSpacing(10) + self.inputsFormLayout.setContentsMargins(20, -1, -1, -1) + + self.verticalLayout_3.addLayout(self.inputsFormLayout) + + + self.verticalLayout_20.addWidget(self.inputsGroupBox) + + self.exportWarningLabel = QLabel(self.scrollAreaWidgetContents) + self.exportWarningLabel.setObjectName(u"exportWarningLabel") + sizePolicy9 = QSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Preferred) + sizePolicy9.setHorizontalStretch(0) + sizePolicy9.setVerticalStretch(0) + sizePolicy9.setHeightForWidth(self.exportWarningLabel.sizePolicy().hasHeightForWidth()) + self.exportWarningLabel.setSizePolicy(sizePolicy9) + 
self.exportWarningLabel.setStyleSheet(u"QLabel {\n" +" font-size: 10px;\n" +" background-color: #FFCC00; \n" +" border: 1px solid #996600; \n" +" color: #333333;\n" +" font-weight: bold;\n" +" border-radius: 0px;\n" +"}") + self.exportWarningLabel.setMargin(5) + + self.verticalLayout_20.addWidget(self.exportWarningLabel) + + self.exportOnnxBtn = QPushButton(self.scrollAreaWidgetContents) + self.exportOnnxBtn.setObjectName(u"exportOnnxBtn") + sizePolicy8.setHeightForWidth(self.exportOnnxBtn.sizePolicy().hasHeightForWidth()) + self.exportOnnxBtn.setSizePolicy(sizePolicy8) + self.exportOnnxBtn.setCursor(QCursor(Qt.CursorShape.PointingHandCursor)) + self.exportOnnxBtn.setStyleSheet(u"") + self.exportOnnxBtn.setAutoExclusive(False) + + self.verticalLayout_20.addWidget(self.exportOnnxBtn) + + self.verticalSpacer = QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding) + + self.verticalLayout_20.addItem(self.verticalSpacer) + + self.scrollArea.setWidget(self.scrollAreaWidgetContents) + + self.verticalLayout.addWidget(self.scrollArea) + + + self.retranslateUi(pytorchIngest) + + QMetaObject.connectSlotsByName(pytorchIngest) + # setupUi + + def retranslateUi(self, pytorchIngest): + pytorchIngest.setWindowTitle(QCoreApplication.translate("pytorchIngest", u"Form", None)) + self.pytorchLogo.setText("") + self.titleLabel.setText(QCoreApplication.translate("pytorchIngest", u"PyTorch Ingest", None)) + self.modelName.setText(QCoreApplication.translate("pytorchIngest", u"model name", None)) + self.modelFilename.setText(QCoreApplication.translate("pytorchIngest", u"path to the model file", None)) + self.selectDirBtn.setText(QCoreApplication.translate("pytorchIngest", u"Select Directory", None)) + self.selectDirLabel.setText(QCoreApplication.translate("pytorchIngest", u"Select a directory if you would like to save the ONNX model file", None)) + self.exportOptionsGroupBox.setTitle(QCoreApplication.translate("pytorchIngest", u"Export Options", None)) + self.foldingCheckBox.setText(QCoreApplication.translate("pytorchIngest", u"Do constant folding", None)) + self.exportParamsCheckBox.setText(QCoreApplication.translate("pytorchIngest", u"Export params", None)) + self.opsetLabel.setText(QCoreApplication.translate("pytorchIngest", u"Opset", None)) + self.opsetInfoLabel.setText(QCoreApplication.translate("pytorchIngest", u"(accepted range is 7 - 21):", None)) + self.opsetLineEdit.setText(QCoreApplication.translate("pytorchIngest", u"17", None)) + self.inputsGroupBox.setTitle(QCoreApplication.translate("pytorchIngest", u"Inputs", None)) + self.label.setText(QCoreApplication.translate("pytorchIngest", u"The following inputs were taken from the PyTorch model's forward function. Please set the type and dimensions for each required input. Shape dimensions can be set by specifying a combination of symbolic and integer values separated by a comma, for example: batch_size, 3, 224, 244.", None)) + self.exportWarningLabel.setText(QCoreApplication.translate("pytorchIngest", u"

<html><head/><body><p>This is a warning message that we can use for now to prompt the user.</p></body></html>
", None)) + self.exportOnnxBtn.setText(QCoreApplication.translate("pytorchIngest", u"Export ONNX", None)) + # retranslateUi + diff --git a/src/utils/onnx_utils.py b/src/utils/onnx_utils.py index d8a6894..9b92be1 100644 --- a/src/utils/onnx_utils.py +++ b/src/utils/onnx_utils.py @@ -1,95 +1,19 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. import os -import csv import tempfile -from uuid import uuid4 -from collections import Counter, OrderedDict, defaultdict -from typing import List, Dict, Optional, Any, Tuple, Union, cast -from datetime import datetime +from collections import Counter +from typing import List, Optional, Tuple, Union import numpy as np import onnx import onnxruntime as ort -from prettytable import PrettyTable - - -class NodeParsingException(Exception): - pass - - -# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias -class NodeShapeCounts(defaultdict[str, Counter]): - def __init__(self): - super().__init__(Counter) # Initialize with the Counter factory - - -class NodeTypeCounts(Dict[str, int]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class TensorInfo: - "Used to store node input and output tensor information" - - def __init__(self) -> None: - self.dtype: Optional[str] = None - self.dtype_bytes: Optional[int] = None - self.size_kbytes: Optional[float] = None - self.shape: List[Union[int, str]] = [] - - -class TensorData(OrderedDict[str, TensorInfo]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class NodeInfo: - def __init__(self) -> None: - self.flops: Optional[int] = None - self.parameters: int = 0 - self.node_type: Optional[str] = None - self.attributes: OrderedDict[str, Any] = OrderedDict() - # We use an ordered dictionary because the order in which - # the inputs and outputs are listed in the node matter. - self.inputs = TensorData() - self.outputs = TensorData() - - def get_input(self, index: int) -> TensorInfo: - return list(self.inputs.values())[index] - - def get_output(self, index: int) -> TensorInfo: - return list(self.outputs.values())[index] - - def __str__(self): - """Provides a human-readable string representation of NodeInfo.""" - output = [ - f"Node Type: {self.node_type}", - f"FLOPs: {self.flops if self.flops is not None else 'N/A'}", - f"Parameters: {self.parameters}", - ] - - if self.attributes: - output.append("Attributes:") - for key, value in self.attributes.items(): - output.append(f" - {key}: {value}") - - if self.inputs: - output.append("Inputs:") - for name, tensor in self.inputs.items(): - output.append(f" - {name}: {tensor}") - - if self.outputs: - output.append("Outputs:") - for name, tensor in self.outputs.items(): - output.append(f" - {name}: {tensor}") - - return "\n".join(output) - - -# The classes are for type aliasing. 
Once python 3.10 is the minimum we can switch to TypeAlias -class NodeData(OrderedDict[str, NodeInfo]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +from digest.model_class.digest_model import ( + NodeTypeCounts, + NodeData, + NodeShapeCounts, + TensorData, + TensorInfo, +) # Convert tensor type to human-readable string and size in bytes @@ -117,706 +41,6 @@ def tensor_type_to_str_and_size(elem_type) -> Tuple[str, int]: return type_mapping.get(elem_type, ("unknown", 0)) -class DigestOnnxModel: - def __init__( - self, - onnx_model: onnx.ModelProto, - onnx_filepath: Optional[str] = None, - model_name: Optional[str] = None, - save_proto: bool = True, - ) -> None: - # Public members exposed to the API - self.unique_id: str = str(uuid4()) - self.filepath: Optional[str] = onnx_filepath - self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None - self.model_name: Optional[str] = model_name - self.model_version: Optional[int] = None - self.graph_name: Optional[str] = None - self.producer_name: Optional[str] = None - self.producer_version: Optional[str] = None - self.ir_version: Optional[int] = None - self.opset: Optional[int] = None - self.imports: Dict[str, int] = {} - self.node_type_counts: NodeTypeCounts = NodeTypeCounts() - self.model_flops: Optional[int] = None - self.model_parameters: int = 0 - self.node_type_flops: Dict[str, int] = {} - self.node_type_parameters: Dict[str, int] = {} - self.per_node_info = NodeData() - self.model_inputs = TensorData() - self.model_outputs = TensorData() - - # Private members not intended to be exposed - self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.output_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.value_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.init_tensors_: Dict[str, onnx.TensorProto] = {} - - self.update_state(onnx_model) - - def update_state(self, model_proto: onnx.ModelProto) -> None: - self.model_version = model_proto.model_version - self.graph_name = model_proto.graph.name - self.producer_name = model_proto.producer_name - self.producer_version = model_proto.producer_version - self.ir_version = model_proto.ir_version - self.opset = get_opset(model_proto) - self.imports = { - import_.domain: import_.version for import_ in model_proto.opset_import - } - - self.model_inputs = get_model_input_shapes_types(model_proto) - self.model_outputs = get_model_output_shapes_types(model_proto) - - self.node_type_counts = get_node_type_counts(model_proto) - self.parse_model_nodes(model_proto) - - def get_node_tensor_info_( - self, onnx_node: onnx.NodeProto - ) -> Tuple[TensorData, TensorData]: - """ - This function is set to private because it is not intended to be used - outside of the DigestOnnxModel class. 
- """ - - input_tensor_info = TensorData() - for node_input in onnx_node.input: - input_tensor_info[node_input] = TensorInfo() - if ( - node_input in self.input_tensors_ - or node_input in self.value_tensors_ - or node_input in self.output_tensors_ - ): - tensor = ( - self.input_tensors_.get(node_input) - or self.value_tensors_.get(node_input) - or self.output_tensors_.get(node_input) - ) - if tensor: - for dim in tensor.type.tensor_type.shape.dim: - if dim.HasField("dim_value"): - input_tensor_info[node_input].shape.append(dim.dim_value) - elif dim.HasField("dim_param"): - input_tensor_info[node_input].shape.append(dim.dim_param) - - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - tensor.type.tensor_type.elem_type - ) - elif node_input in self.init_tensors_: - input_tensor_info[node_input].shape.extend( - [dim for dim in self.init_tensors_[node_input].dims] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - self.init_tensors_[node_input].data_type - ) - else: - dtype_str = None - dtype_bytes = None - - input_tensor_info[node_input].dtype = dtype_str - input_tensor_info[node_input].dtype_bytes = dtype_bytes - - if ( - all(isinstance(s, int) for s in input_tensor_info[node_input].shape) - and dtype_bytes - ): - tensor_size = float( - np.prod(np.array(input_tensor_info[node_input].shape)) - ) - input_tensor_info[node_input].size_kbytes = ( - tensor_size * float(dtype_bytes) / 1024.0 - ) - - output_tensor_info = TensorData() - for node_output in onnx_node.output: - output_tensor_info[node_output] = TensorInfo() - if ( - node_output in self.input_tensors_ - or node_output in self.value_tensors_ - or node_output in self.output_tensors_ - ): - tensor = ( - self.input_tensors_.get(node_output) - or self.value_tensors_.get(node_output) - or self.output_tensors_.get(node_output) - ) - if tensor: - output_tensor_info[node_output].shape.extend( - [ - int(dim.dim_value) - for dim in tensor.type.tensor_type.shape.dim - ] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - tensor.type.tensor_type.elem_type - ) - elif node_output in self.init_tensors_: - output_tensor_info[node_output].shape.extend( - [dim for dim in self.init_tensors_[node_output].dims] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - self.init_tensors_[node_output].data_type - ) - - else: - dtype_str = None - dtype_bytes = None - - output_tensor_info[node_output].dtype = dtype_str - output_tensor_info[node_output].dtype_bytes = dtype_bytes - - if ( - all(isinstance(s, int) for s in output_tensor_info[node_output].shape) - and dtype_bytes - ): - tensor_size = float( - np.prod(np.array(output_tensor_info[node_output].shape)) - ) - output_tensor_info[node_output].size_kbytes = ( - tensor_size * float(dtype_bytes) / 1024.0 - ) - - return input_tensor_info, output_tensor_info - - def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None: - """ - Calculate total number of FLOPs found in the onnx model. - FLOP is defined as one floating-point operation. This distinguishes - from multiply-accumulates (MACs) where FLOPs == 2 * MACs. - """ - - # Initialze to zero so we can accumulate. Set to None during the - # model FLOPs calculation if it errors out. 
- self.model_flops = 0 - - # Check to see if the model inputs have any dynamic shapes - if get_dynamic_input_dims(onnx_model): - self.model_flops = None - - try: - onnx_model, _ = optimize_onnx_model(onnx_model) - - onnx_model = onnx.shape_inference.infer_shapes( - onnx_model, strict_mode=True, data_prop=True - ) - except Exception as e: # pylint: disable=broad-except - print(f"ONNX utils: {str(e)}") - self.model_flops = None - - # If the ONNX model contains one of the following unsupported ops, then this - # function will return None since the FLOP total is expected to be incorrect - unsupported_ops = [ - "Einsum", - "RNN", - "GRU", - "DeformConv", - ] - - if not self.input_tensors_: - self.input_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.input - } - - if not self.output_tensors_: - self.output_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.output - } - - if not self.value_tensors_: - self.value_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.value_info - } - - if not self.init_tensors_: - self.init_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.initializer - } - - for node in onnx_model.graph.node: # pylint: disable=E1101 - - node_info = NodeInfo() - - # TODO: I have encountered models containing nodes with no name. It would be a good idea - # to have this type of model info fed back to the user through a warnings section. - if not node.name: - node.name = f"{node.op_type}_{len(self.per_node_info)}" - - node_info.node_type = node.op_type - input_tensor_info, output_tensor_info = self.get_node_tensor_info_(node) - node_info.inputs = input_tensor_info - node_info.outputs = output_tensor_info - - # Check if this node has parameters through the init tensors - for input_name, input_tensor in node_info.inputs.items(): - if input_name in self.init_tensors_: - if all(isinstance(dim, int) for dim in input_tensor.shape): - input_parameters = int(np.prod(np.array(input_tensor.shape))) - node_info.parameters += input_parameters - self.model_parameters += input_parameters - self.node_type_parameters[node.op_type] = ( - self.node_type_parameters.get(node.op_type, 0) - + input_parameters - ) - else: - print(f"Tensor with params has unknown shape: {input_name}") - - for attribute in node.attribute: - node_info.attributes.update(attribute_to_dict(attribute)) - - # if node.name in self.per_node_info: - # print(f"Node name {node.name} is a duplicate.") - - self.per_node_info[node.name] = node_info - - if node.op_type in unsupported_ops: - self.model_flops = None - node_info.flops = None - - try: - - if ( - node.op_type == "MatMul" - or node.op_type == "MatMulInteger" - or node.op_type == "QLinearMatMul" - ): - - input_a = node_info.get_input(0).shape - if node.op_type == "QLinearMatMul": - input_b = node_info.get_input(3).shape - else: - input_b = node_info.get_input(1).shape - - if not all( - isinstance(dim, int) for dim in input_a - ) or not isinstance(input_b[-1], int): - node_info.flops = None - self.model_flops = None - continue - - node_info.flops = int( - 2 * np.prod(np.array(input_a), dtype=np.int64) * input_b[-1] - ) - - elif ( - node.op_type == "Mul" - or node.op_type == "Div" - or node.op_type == "Add" - ): - input_a = node_info.get_input(0).shape - input_b = node_info.get_input(1).shape - - if not all(isinstance(dim, int) for dim in input_a) or not all( - isinstance(dim, int) for dim in input_b - ): - node_info.flops = None - self.model_flops = None - continue - - node_info.flops = int( - np.prod(np.array(input_a), 
dtype=np.int64) - ) + int(np.prod(np.array(input_b), dtype=np.int64)) - - elif node.op_type == "Gemm" or node.op_type == "QGemm": - x_shape = node_info.get_input(0).shape - if node.op_type == "Gemm": - w_shape = node_info.get_input(1).shape - else: - w_shape = node_info.get_input(3).shape - - if not all(isinstance(dim, int) for dim in x_shape) or not all( - isinstance(dim, int) for dim in w_shape - ): - node_info.flops = None - self.model_flops = None - continue - - mm_dims = [ - ( - x_shape[0] - if not node_info.attributes.get("transA", 0) - else x_shape[1] - ), - ( - x_shape[1] - if not node_info.attributes.get("transA", 0) - else x_shape[0] - ), - ( - w_shape[1] - if not node_info.attributes.get("transB", 0) - else w_shape[0] - ), - ] - - node_info.flops = int( - 2 * np.prod(np.array(mm_dims), dtype=np.int64) - ) - - if len(mm_dims) == 3: # if there is a bias input - bias_shape = node_info.get_input(2).shape - node_info.flops += int(np.prod(np.array(bias_shape))) - - elif ( - node.op_type == "Conv" - or node.op_type == "ConvInteger" - or node.op_type == "QLinearConv" - or node.op_type == "ConvTranspose" - ): - # N, C, d1, ..., dn - x_shape = node_info.get_input(0).shape - - # M, C/group, k1, ..., kn. Note C and M are swapped for ConvTranspose - if node.op_type == "QLinearConv": - w_shape = node_info.get_input(3).shape - else: - w_shape = node_info.get_input(1).shape - - if not all(isinstance(dim, int) for dim in x_shape): - node_info.flops = None - self.model_flops = None - continue - - x_shape_ints = cast(List[int], x_shape) - w_shape_ints = cast(List[int], w_shape) - - has_bias = False # Note, ConvInteger has no bias - if node.op_type == "Conv" and len(node_info.inputs) == 3: - has_bias = True - elif node.op_type == "QLinearConv" and len(node_info.inputs) == 9: - has_bias = True - - num_dims = len(x_shape_ints) - 2 - strides = node_info.attributes.get( - "strides", [1] * num_dims - ) # type: List[int] - dilation = node_info.attributes.get( - "dilations", [1] * num_dims - ) # type: List[int] - kernel_shape = w_shape_ints[2:] - batch_size = x_shape_ints[0] - out_channels = w_shape_ints[0] - out_dims = [batch_size, out_channels] - output_shape = node_info.attributes.get( - "output_shape", [] - ) # type: List[int] - - # If output_shape is given then we do not need to compute it ourselves - # The output_shape attribute does not include batch_size or channels and - # is only valid for ConvTranspose - if output_shape: - out_dims.extend(output_shape) - else: - auto_pad = node_info.attributes.get( - "auto_pad", "NOTSET".encode() - ).decode() - # SAME expects padding so that the output_shape = CEIL(input_shape / stride) - if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": - out_dims.extend( - [x * s for x, s in zip(x_shape_ints[2:], strides)] - ) - else: - # NOTSET means just use pads attribute - if auto_pad == "NOTSET": - pads = node_info.attributes.get( - "pads", [0] * num_dims * 2 - ) - # VALID essentially means no padding - elif auto_pad == "VALID": - pads = [0] * num_dims * 2 - - for i in range(num_dims): - dim_in = x_shape_ints[i + 2] # type: int - - if node.op_type == "ConvTranspose": - out_dim = ( - strides[i] * (dim_in - 1) - + ((kernel_shape[i] - 1) * dilation[i] + 1) - - pads[i] - - pads[i + num_dims] - ) - else: - out_dim = ( - dim_in - + pads[i] - + pads[i + num_dims] - - dilation[i] * (kernel_shape[i] - 1) - - 1 - ) // strides[i] + 1 - - out_dims.append(out_dim) - - kernel_flops = int( - np.prod(np.array(kernel_shape)) * w_shape_ints[1] - ) - output_points = 
int(np.prod(np.array(out_dims))) - bias_ops = output_points if has_bias else int(0) - node_info.flops = 2 * kernel_flops * output_points + bias_ops - - elif node.op_type == "LSTM" or node.op_type == "DynamicQuantizeLSTM": - - x_shape = node_info.get_input( - 0 - ).shape # seq_length, batch_size, input_dim - - if not all(isinstance(dim, int) for dim in x_shape): - node_info.flops = None - self.model_flops = None - continue - - x_shape_ints = cast(List[int], x_shape) - hidden_size = node_info.attributes["hidden_size"] - direction = ( - 2 - if node_info.attributes.get("direction") - == "bidirectional".encode() - else 1 - ) - - has_bias = True if len(node_info.inputs) >= 4 else False - if has_bias: - bias_shape = node_info.get_input(3).shape - if isinstance(bias_shape[1], int): - bias_ops = bias_shape[1] - else: - bias_ops = 0 - else: - bias_ops = 0 - # seq_length, batch_size, input_dim = x_shape - if not isinstance(bias_ops, int): - bias_ops = int(0) - num_gates = int(4) - gate_input_flops = int(2 * x_shape_ints[2] * hidden_size) - gate_hid_flops = int(2 * hidden_size * hidden_size) - unit_flops = ( - num_gates * (gate_input_flops + gate_hid_flops) + bias_ops - ) - node_info.flops = ( - x_shape_ints[1] * x_shape_ints[0] * direction * unit_flops - ) - # In this case we just hit an op that doesn't have FLOPs - else: - node_info.flops = None - - except IndexError as err: - print(f"Error parsing node {node.name}: {err}") - node_info.flops = None - self.model_flops = None - continue - - # Update the model level flops count - if node_info.flops is not None and self.model_flops is not None: - self.model_flops += node_info.flops - - # Update the node type flops count - self.node_type_flops[node.op_type] = ( - self.node_type_flops.get(node.op_type, 0) + node_info.flops - ) - - def save_txt_report(self, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - report_date = datetime.now().strftime("%B %d, %Y") - - with open(filepath, "w", encoding="utf-8") as f_p: - f_p.write(f"Report created on {report_date}\n") - if self.filepath: - f_p.write(f"ONNX file: {self.filepath}\n") - f_p.write(f"Name of the model: {self.model_name}\n") - f_p.write(f"Model version: {self.model_version}\n") - f_p.write(f"Name of the graph: {self.graph_name}\n") - f_p.write(f"Producer: {self.producer_name} {self.producer_version}\n") - f_p.write(f"Ir version: {self.ir_version}\n") - f_p.write(f"Opset: {self.opset}\n\n") - f_p.write("Import list\n") - for name, version in self.imports.items(): - f_p.write(f"\t{name}: {version}\n") - - f_p.write("\n") - f_p.write(f"Total graph nodes: {sum(self.node_type_counts.values())}\n") - f_p.write(f"Number of parameters: {self.model_parameters}\n") - if self.model_flops: - f_p.write(f"Number of FLOPs: {self.model_flops}\n") - f_p.write("\n") - - table_op_intensity = PrettyTable() - table_op_intensity.field_names = ["Operation", "FLOPs", "Intensity (%)"] - for op_type, count in self.node_type_flops.items(): - if count > 0: - table_op_intensity.add_row( - [ - op_type, - count, - 100.0 * float(count) / float(self.model_flops), - ] - ) - - f_p.write("Op intensity:\n") - f_p.write(table_op_intensity.get_string()) - f_p.write("\n\n") - - node_counts_table = PrettyTable() - node_counts_table.field_names = ["Node", "Occurrences"] - for op, count in self.node_type_counts.items(): - node_counts_table.add_row([op, count]) - f_p.write("Nodes and their occurrences:\n") 
- f_p.write(node_counts_table.get_string()) - f_p.write("\n\n") - - input_table = PrettyTable() - input_table.field_names = [ - "Input Name", - "Shape", - "Type", - "Tensor Size (KB)", - ] - for input_name, input_details in self.model_inputs.items(): - if input_details.size_kbytes: - kbytes = f"{input_details.size_kbytes:.2f}" - else: - kbytes = "" - - input_table.add_row( - [ - input_name, - input_details.shape, - input_details.dtype, - kbytes, - ] - ) - f_p.write("Input Tensor(s) Information:\n") - f_p.write(input_table.get_string()) - f_p.write("\n\n") - - output_table = PrettyTable() - output_table.field_names = [ - "Output Name", - "Shape", - "Type", - "Tensor Size (KB)", - ] - for output_name, output_details in self.model_outputs.items(): - if output_details.size_kbytes: - kbytes = f"{output_details.size_kbytes:.2f}" - else: - kbytes = "" - - output_table.add_row( - [ - output_name, - output_details.shape, - output_details.dtype, - kbytes, - ] - ) - f_p.write("Output Tensor(s) Information:\n") - f_p.write(output_table.get_string()) - f_p.write("\n\n") - - def save_nodes_csv_report(self, filepath: str) -> None: - save_nodes_csv_report(self.per_node_info, filepath) - - def get_node_type_counts(self) -> Union[NodeTypeCounts, None]: - if not self.node_type_counts and self.model_proto: - self.node_type_counts = get_node_type_counts(self.model_proto) - return self.node_type_counts if self.node_type_counts else None - - def get_node_shape_counts(self) -> NodeShapeCounts: - tensor_shape_counter = NodeShapeCounts() - for _, info in self.per_node_info.items(): - shape_hash = tuple([tuple(v.shape) for _, v in info.inputs.items()]) - if info.node_type: - tensor_shape_counter[info.node_type][shape_hash] += 1 - return tensor_shape_counter - - -def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - flattened_data = [] - fieldnames = ["Node Name", "Node Type", "Parameters", "FLOPs", "Attributes"] - input_fieldnames = [] - output_fieldnames = [] - for name, node_info in node_data.items(): - row = OrderedDict() - row["Node Name"] = name - row["Node Type"] = str(node_info.node_type) - row["Parameters"] = str(node_info.parameters) - row["FLOPs"] = str(node_info.flops) - if node_info.attributes: - row["Attributes"] = str({k: v for k, v in node_info.attributes.items()}) - else: - row["Attributes"] = "" - - for i, (input_name, input_info) in enumerate(node_info.inputs.items()): - column_name = f"Input{i+1} (Shape, Dtype, Size (kB))" - row[column_name] = ( - f"{input_name} ({input_info.shape}, {input_info.dtype}, {input_info.size_kbytes})" - ) - - # Dynamically add input column names to fieldnames if not already present - if column_name not in input_fieldnames: - input_fieldnames.append(column_name) - - for i, (output_name, output_info) in enumerate(node_info.outputs.items()): - column_name = f"Output{i+1} (Shape, Dtype, Size (kB))" - row[column_name] = ( - f"{output_name} ({output_info.shape}, " - f"{output_info.dtype}, {output_info.size_kbytes})" - ) - - # Dynamically add input column names to fieldnames if not already present - if column_name not in output_fieldnames: - output_fieldnames.append(column_name) - - flattened_data.append(row) - - fieldnames = fieldnames + input_fieldnames + output_fieldnames - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.DictWriter(csvfile, 
fieldnames=fieldnames, lineterminator="\n") - writer.writeheader() - writer.writerows(flattened_data) - - -def save_node_type_counts_csv_report(node_data: NodeTypeCounts, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - header = ["Node Type", "Count"] - - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.writer(csvfile, lineterminator="\n") - writer.writerow(header) - for node_type, node_count in node_data.items(): - writer.writerow([node_type, node_count]) - - -def save_node_shape_counts_csv_report( - node_data: NodeShapeCounts, filepath: str -) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - header = ["Node Type", "Input Tensors Shapes", "Count"] - - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.writer(csvfile, dialect="excel", lineterminator="\n") - writer.writerow(header) - for node_type, node_info in node_data.items(): - info_iter = iter(node_info.items()) - for shape, count in info_iter: - writer.writerow([node_type, shape, count]) - - def load_onnx(onnx_path: str, load_external_data: bool = True) -> onnx.ModelProto: if os.path.exists(onnx_path): return onnx.load(onnx_path, load_external_data=load_external_data) @@ -987,3 +211,9 @@ def optimize_onnx_model( except onnx.checker.ValidationError: print("Model did not pass checker!") return model_proto, False + + +def get_supported_opset() -> int: + """This function will return the opset version associated + with the currently installed ONNX library""" + return onnx.defs.onnx_opset_version() diff --git a/test/resnet18_reports/resnet18_heatmap.png b/test/resnet18_reports/resnet18_heatmap.png new file mode 100644 index 0000000..1fb614e Binary files /dev/null and b/test/resnet18_reports/resnet18_heatmap.png differ diff --git a/test/resnet18_reports/resnet18_histogram.png b/test/resnet18_reports/resnet18_histogram.png new file mode 100644 index 0000000..eb13e01 Binary files /dev/null and b/test/resnet18_reports/resnet18_histogram.png differ diff --git a/test/resnet18_reports/resnet18_node_type_counts.csv b/test/resnet18_reports/resnet18_node_type_counts.csv new file mode 100644 index 0000000..29504ba --- /dev/null +++ b/test/resnet18_reports/resnet18_node_type_counts.csv @@ -0,0 +1,8 @@ +Node Type,Count +Conv,20 +Relu,17 +Add,8 +MaxPool,1 +GlobalAveragePool,1 +Flatten,1 +Gemm,1 diff --git a/test/resnet18_test_nodes.csv b/test/resnet18_reports/resnet18_nodes.csv similarity index 100% rename from test/resnet18_test_nodes.csv rename to test/resnet18_reports/resnet18_nodes.csv diff --git a/test/resnet18_test_summary.txt b/test/resnet18_reports/resnet18_report.txt similarity index 86% rename from test/resnet18_test_summary.txt rename to test/resnet18_reports/resnet18_report.txt index a5b4cfb..fdda0bf 100644 --- a/test/resnet18_test_summary.txt +++ b/test/resnet18_reports/resnet18_report.txt @@ -1,5 +1,5 @@ -Report created on June 02, 2024 -ONNX file: resnet18.onnx +Report created on December 06, 2024 +ONNX file: C:\Users\pcolange\Projects\digestai\test\resnet18.onnx Name of the model: resnet18 Model version: 0 Name of the graph: main_graph @@ -9,6 +9,13 @@ Opset: 17 Import list : 17 + ai.onnx.ml: 5 + ai.onnx.preview.training: 1 + ai.onnx.training: 1 + com.microsoft: 1 + com.microsoft.experimental: 1 + 
com.microsoft.nchwc: 1 + org.pytorch.aten: 1 Total graph nodes: 49 Number of parameters: 11684712 diff --git a/test/resnet18_reports/resnet18_report.yaml b/test/resnet18_reports/resnet18_report.yaml new file mode 100644 index 0000000..9df22be --- /dev/null +++ b/test/resnet18_reports/resnet18_report.yaml @@ -0,0 +1,56 @@ +report_date: December 06, 2024 +model_file: C:\Users\pcolange\Projects\digestai\test\resnet18.onnx +model_type: onnx +model_name: resnet18 +model_version: 0 +graph_name: main_graph +producer_name: pytorch +producer_version: 2.1.0 +ir_version: 8 +opset: 17 +import_list: + ? '' + : 17 + ai.onnx.ml: 5 + ai.onnx.preview.training: 1 + ai.onnx.training: 1 + com.microsoft: 1 + com.microsoft.experimental: 1 + com.microsoft.nchwc: 1 + org.pytorch.aten: 1 +graph_nodes: 49 +parameters: 11684712 +flops: 3632136680 +node_type_counts: + Conv: 20 + Relu: 17 + Add: 8 + MaxPool: 1 + GlobalAveragePool: 1 + Flatten: 1 + Gemm: 1 +node_type_flops: + Conv: 3629606400 + Add: 1505280 + Gemm: 1025000 +node_type_parameters: + Conv: 11171712 + Gemm: 513000 +input_tensors: + input.1: + dtype: float32 + dtype_bytes: 4 + size_kbytes: 588.0 + shape: + - 1 + - 3 + - 224 + - 224 +output_tensors: + '191': + dtype: float32 + dtype_bytes: 4 + size_kbytes: 3.90625 + shape: + - 1 + - 1000 diff --git a/test/test_gui.py b/test/test_gui.py index 0e1d351..9a06f3e 100644 --- a/test/test_gui.py +++ b/test/test_gui.py @@ -5,63 +5,166 @@ import tempfile import unittest from unittest.mock import patch +import timm +import torch # pylint: disable=no-name-in-module from PySide6.QtTest import QTest -from PySide6.QtCore import Qt, QDeadlineTimer +from PySide6.QtCore import Qt from PySide6.QtWidgets import QApplication import digest.main from digest.node_summary import NodeSummary +from digest.model_class.digest_pytorch_model import DigestPyTorchModel +from digest.pytorch_ingest import PyTorchIngest -ONNX_BASENAME = "resnet18" -TEST_DIR = os.path.abspath(os.path.dirname(__file__)) -ONNX_FILEPATH = os.path.normpath(os.path.join(TEST_DIR, f"{ONNX_BASENAME}.onnx")) + +def save_resnet18_pt(directory: str) -> str: + """Simply saves a PyTorch resnet18 model and returns its file path""" + model = timm.models.create_model("resnet18", pretrained=True) # type: ignore + model.eval() + file_path = os.path.join(directory, "resnet18.pt") + # Save the model + try: + torch.save(model, file_path) + return file_path + except Exception as e: # pylint: disable=broad-exception-caught + print(f"Error saving model: {e}") + return "" class DigestGuiTest(unittest.TestCase): + RESNET18_BASENAME = "resnet18" + + TEST_DIR = os.path.abspath(os.path.dirname(__file__)) + ONNX_FILE_PATH = os.path.normpath( + os.path.join(TEST_DIR, f"{RESNET18_BASENAME}.onnx") + ) + YAML_FILE_PATH = os.path.normpath( + os.path.join( + TEST_DIR, f"{RESNET18_BASENAME}_reports", f"{RESNET18_BASENAME}_report.yaml" + ) + ) @classmethod def setUpClass(cls): cls.app = QApplication(sys.argv) + return super().setUpClass() + + @classmethod + def tearDownClass(cls): + if isinstance(cls.app, QApplication): + cls.app.closeAllWindows() + cls.app = None def setUp(self): self.digest_app = digest.main.DigestApp() self.digest_app.show() def tearDown(self): - self.wait_all_threads() self.digest_app.close() - def wait_all_threads(self): + def wait_all_threads(self, timeout=10000) -> bool: + all_threads = list(self.digest_app.model_nodes_stats_thread.values()) + list( + self.digest_app.model_similarity_thread.values() + ) - for thread in self.digest_app.model_nodes_stats_thread.values(): - 
thread.wait(deadline=QDeadlineTimer.Forever) + for thread in all_threads: + thread.wait(timeout) - for thread in self.digest_app.model_similarity_thread.values(): - thread.wait(deadline=QDeadlineTimer.Forever) + # Return True if all threads finished, False if timed out + return all(thread.isFinished() for thread in all_threads) def test_open_valid_onnx(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ( - ONNX_FILEPATH, + self.ONNX_FILE_PATH, + "", + ) + + num_tabs_prior = self.digest_app.ui.tabWidget.count() + + QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) + + self.assertTrue(self.wait_all_threads()) + + self.assertTrue( + self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1 + ) # Check if a tab was added + + self.digest_app.closeTab(num_tabs_prior) + + def test_open_valid_yaml(self): + with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: + mock_dialog.return_value = ( + self.YAML_FILE_PATH, "", ) + num_tabs_prior = self.digest_app.ui.tabWidget.count() + QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) self.assertTrue( - self.digest_app.ui.tabWidget.count() > 0 + self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1 ) # Check if a tab was added + self.digest_app.closeTab(num_tabs_prior) + + def test_open_valid_pytorch(self): + """We test the PyTorch path slightly different than the others + since Digest opens a modal window that blocks the main thread. This makes it difficult + to interact with the Window in this test.""" + + with tempfile.TemporaryDirectory() as tmpdir: + pt_file_path = save_resnet18_pt(tmpdir) + self.assertTrue(os.path.exists(tmpdir)) + basename = os.path.splitext(os.path.basename(pt_file_path)) + model_name = basename[0] + digest_model = DigestPyTorchModel(pt_file_path, model_name) + self.assertTrue(isinstance(digest_model.file_path, str)) + pytorch_ingest = PyTorchIngest(pt_file_path, digest_model.model_name) + pytorch_ingest.show() + + input_shape_edit = ( + pytorch_ingest.user_input_form.get_row_tensor_shape_widget(0) + ) + + assert input_shape_edit + input_shape_edit.setText("batch_size, 3, 224, 224") + pytorch_ingest.update_tensor_info() + + with patch( + "PySide6.QtWidgets.QFileDialog.getExistingDirectory" + ) as mock_save_dialog: + print("TMPDIR", tmpdir) + mock_save_dialog.return_value = tmpdir + pytorch_ingest.select_directory() + + pytorch_ingest.ui.exportOnnxBtn.click() + + timeout_ms = 10000 + interval_ms = 100 + for _ in range(timeout_ms // interval_ms): + QTest.qWait(interval_ms) + onnx_file_path = pytorch_ingest.digest_pytorch_model.onnx_file_path + if onnx_file_path and os.path.exists(onnx_file_path): + break # File found! 
+ + assert isinstance(pytorch_ingest.digest_pytorch_model.onnx_file_path, str) + self.assertTrue( + os.path.exists(pytorch_ingest.digest_pytorch_model.onnx_file_path) + ) + def test_open_invalid_file(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ("invalid_file.txt", "") + num_tabs_prior = self.digest_app.ui.tabWidget.count() QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - self.wait_all_threads() - self.assertEqual(self.digest_app.ui.tabWidget.count(), 0) + self.assertTrue(self.wait_all_threads()) + self.assertEqual(self.digest_app.ui.tabWidget.count(), num_tabs_prior) def test_save_reports(self): with patch( @@ -70,7 +173,7 @@ def test_save_reports(self): "PySide6.QtWidgets.QFileDialog.getExistingDirectory" ) as mock_save_dialog: - mock_open_dialog.return_value = (ONNX_FILEPATH, "") + mock_open_dialog.return_value = (self.ONNX_FILE_PATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = tmpdirname @@ -79,44 +182,56 @@ def test_save_reports(self): Qt.MouseButton.LeftButton, ) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) - # This is a slight hack but the issue is that model similarity takes - # a bit longer to complete and we must have it done before the save - # button is enabled guaranteeing all the artifacts are saved. - # wait_all_threads() above doesn't seem to work. The only thing that - # does is just waiting 5 seconds. - QTest.qWait(5000) + self.assertTrue( + self.digest_app.ui.saveBtn.isEnabled(), "Save button is disabled!" + ) QTest.mouseClick(self.digest_app.ui.saveBtn, Qt.MouseButton.LeftButton) mock_save_dialog.assert_called_once() - result_basepath = os.path.join(tmpdirname, f"{ONNX_BASENAME}_reports") + result_basepath = os.path.join( + tmpdirname, f"{self.RESNET18_BASENAME}_reports" + ) # Text report test - txt_report_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_report.txt" + text_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_report.txt" + ) + self.assertTrue( + os.path.isfile(text_report_FILE_PATH), + f"{text_report_FILE_PATH} not found!", ) - self.assertTrue(os.path.isfile(txt_report_filepath)) + + # YAML report test + yaml_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_report.yaml" + ) + self.assertTrue(os.path.isfile(yaml_report_FILE_PATH)) # Nodes test - nodes_csv_report_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_nodes.csv" + nodes_csv_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_nodes.csv" ) - self.assertTrue(os.path.isfile(nodes_csv_report_filepath)) + self.assertTrue(os.path.isfile(nodes_csv_report_FILE_PATH)) # Histogram test - histogram_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_histogram.png" + histogram_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_histogram.png" ) - self.assertTrue(os.path.isfile(histogram_filepath)) + self.assertTrue(os.path.isfile(histogram_FILE_PATH)) # Heatmap test - heatmap_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_heatmap.png" + heatmap_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_heatmap.png" ) - self.assertTrue(os.path.isfile(heatmap_filepath)) + self.assertTrue(os.path.isfile(heatmap_FILE_PATH)) + + num_tabs = self.digest_app.ui.tabWidget.count() + self.assertTrue(num_tabs == 1) + self.digest_app.closeTab(0) def test_save_tables(self): with patch( @@ -125,10 
+240,10 @@ def test_save_tables(self): "PySide6.QtWidgets.QFileDialog.getSaveFileName" ) as mock_save_dialog: - mock_open_dialog.return_value = (ONNX_FILEPATH, "") + mock_open_dialog.return_value = (self.ONNX_FILE_PATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = ( - os.path.join(tmpdirname, f"{ONNX_BASENAME}_nodes.csv"), + os.path.join(tmpdirname, f"{self.RESNET18_BASENAME}_nodes.csv"), "", ) @@ -136,17 +251,19 @@ def test_save_tables(self): self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton ) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) QTest.mouseClick( self.digest_app.ui.nodesListBtn, Qt.MouseButton.LeftButton ) - # We assume there is only model loaded + # We assume there is only one model loaded _, node_window = self.digest_app.nodes_window.popitem() node_summary = node_window.main_window.centralWidget() self.assertIsInstance(node_summary, NodeSummary) + + # This line of code seems redundant but we do this to clean pylance if isinstance(node_summary, NodeSummary): QTest.mouseClick( node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton @@ -156,11 +273,15 @@ def test_save_tables(self): self.assertTrue( os.path.exists( - os.path.join(tmpdirname, f"{ONNX_BASENAME}_nodes.csv") + os.path.join(tmpdirname, f"{self.RESNET18_BASENAME}_nodes.csv") ), "Nodes csv file not found.", ) + num_tabs = self.digest_app.ui.tabWidget.count() + self.assertTrue(num_tabs == 1) + self.digest_app.closeTab(0) + if __name__ == "__main__": unittest.main() diff --git a/test/test_reports.py b/test/test_reports.py index a16c4d8..e4d327e 100644 --- a/test/test_reports.py +++ b/test/test_reports.py @@ -1,17 +1,22 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. -"""Unit tests for Vitis ONNX Model Analyzer """ - import os import unittest import tempfile import csv -from utils.onnx_utils import DigestOnnxModel, load_onnx +import utils.onnx_utils as onnx_utils +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import compare_yaml_files TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_ONNX = os.path.join(TEST_DIR, "resnet18.onnx") -TEST_SUMMARY_TXT_REPORT = os.path.join(TEST_DIR, "resnet18_test_summary.txt") -TEST_NODES_CSV_REPORT = os.path.join(TEST_DIR, "resnet18_test_nodes.csv") +TEST_SUMMARY_TEXT_REPORT = os.path.join( + TEST_DIR, "resnet18_reports/resnet18_report.txt" +) +TEST_SUMMARY_YAML_REPORT = os.path.join( + TEST_DIR, "resnet18_reports/resnet18_report.yaml" +) +TEST_NODES_CSV_REPORT = os.path.join(TEST_DIR, "resnet18_reports/resnet18_nodes.csv") class TestDigestReports(unittest.TestCase): @@ -46,27 +51,35 @@ def compare_csv_files(self, file1, file2, skip_lines=0): self.assertEqual(row1, row2, msg=f"Difference in row: {row1} vs {row2}") def test_against_example_reports(self): - model_proto = load_onnx(TEST_ONNX) + model_proto = onnx_utils.load_onnx(TEST_ONNX, load_external_data=False) model_name = os.path.splitext(os.path.basename(TEST_ONNX))[0] + opt_model, _ = onnx_utils.optimize_onnx_model(model_proto) digest_model = DigestOnnxModel( - model_proto, onnx_filepath=TEST_ONNX, model_name=model_name, save_proto=False, + opt_model, + onnx_file_path=TEST_ONNX, + model_name=model_name, + save_proto=False, ) with tempfile.TemporaryDirectory() as tmpdir: - # Model summary text report - summary_filepath = os.path.join(tmpdir, f"{model_name}_summary.txt") - digest_model.save_txt_report(summary_filepath) - - with self.subTest("Testing summary text file"): - 
self.compare_files_line_by_line( - TEST_SUMMARY_TXT_REPORT, - summary_filepath, - skip_lines=2, + # Model yaml report + yaml_report_filepath = os.path.join(tmpdir, f"{model_name}_report.yaml") + digest_model.save_yaml_report(yaml_report_filepath) + with self.subTest("Testing report yaml file"): + self.assertTrue( + compare_yaml_files( + TEST_SUMMARY_YAML_REPORT, + yaml_report_filepath, + skip_keys=["report_date", "model_file", "digest_version"], + ) ) # Save CSV containing node-level information nodes_filepath = os.path.join(tmpdir, f"{model_name}_nodes.csv") digest_model.save_nodes_csv_report(nodes_filepath) - with self.subTest("Testing nodes csv file"): self.compare_csv_files(TEST_NODES_CSV_REPORT, nodes_filepath) + + +if __name__ == "__main__": + unittest.main()