Open a local model file (Ctrl-O)
", None)) + self.openFileBtn.setToolTip(QCoreApplication.translate("MainWindow", u"Open (Ctrl-O)
", None)) #endif // QT_CONFIG(tooltip) self.openFileBtn.setText("") #if QT_CONFIG(shortcut) diff --git a/src/digest/ui/modelsummary.ui b/src/digest/ui/modelsummary.ui index 180fed4..737cf33 100644 --- a/src/digest/ui/modelsummary.ui +++ b/src/digest/ui/modelsummary.ui @@ -6,8 +6,8 @@This is a warning message that we can use for now to prompt the user.
", None)) + self.exportOnnxBtn.setText(QCoreApplication.translate("pytorchIngest", u"Export ONNX", None)) + # retranslateUi + diff --git a/src/utils/onnx_utils.py b/src/utils/onnx_utils.py index d8a6894..9b92be1 100644 --- a/src/utils/onnx_utils.py +++ b/src/utils/onnx_utils.py @@ -1,95 +1,19 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. import os -import csv import tempfile -from uuid import uuid4 -from collections import Counter, OrderedDict, defaultdict -from typing import List, Dict, Optional, Any, Tuple, Union, cast -from datetime import datetime +from collections import Counter +from typing import List, Optional, Tuple, Union import numpy as np import onnx import onnxruntime as ort -from prettytable import PrettyTable - - -class NodeParsingException(Exception): - pass - - -# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias -class NodeShapeCounts(defaultdict[str, Counter]): - def __init__(self): - super().__init__(Counter) # Initialize with the Counter factory - - -class NodeTypeCounts(Dict[str, int]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class TensorInfo: - "Used to store node input and output tensor information" - - def __init__(self) -> None: - self.dtype: Optional[str] = None - self.dtype_bytes: Optional[int] = None - self.size_kbytes: Optional[float] = None - self.shape: List[Union[int, str]] = [] - - -class TensorData(OrderedDict[str, TensorInfo]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class NodeInfo: - def __init__(self) -> None: - self.flops: Optional[int] = None - self.parameters: int = 0 - self.node_type: Optional[str] = None - self.attributes: OrderedDict[str, Any] = OrderedDict() - # We use an ordered dictionary because the order in which - # the inputs and outputs are listed in the node matter. - self.inputs = TensorData() - self.outputs = TensorData() - - def get_input(self, index: int) -> TensorInfo: - return list(self.inputs.values())[index] - - def get_output(self, index: int) -> TensorInfo: - return list(self.outputs.values())[index] - - def __str__(self): - """Provides a human-readable string representation of NodeInfo.""" - output = [ - f"Node Type: {self.node_type}", - f"FLOPs: {self.flops if self.flops is not None else 'N/A'}", - f"Parameters: {self.parameters}", - ] - - if self.attributes: - output.append("Attributes:") - for key, value in self.attributes.items(): - output.append(f" - {key}: {value}") - - if self.inputs: - output.append("Inputs:") - for name, tensor in self.inputs.items(): - output.append(f" - {name}: {tensor}") - - if self.outputs: - output.append("Outputs:") - for name, tensor in self.outputs.items(): - output.append(f" - {name}: {tensor}") - - return "\n".join(output) - - -# The classes are for type aliasing. Once python 3.10 is the minimum we can switch to TypeAlias -class NodeData(OrderedDict[str, NodeInfo]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +from digest.model_class.digest_model import ( + NodeTypeCounts, + NodeData, + NodeShapeCounts, + TensorData, + TensorInfo, +) # Convert tensor type to human-readable string and size in bytes @@ -117,706 +41,6 @@ def tensor_type_to_str_and_size(elem_type) -> Tuple[str, int]: return type_mapping.get(elem_type, ("unknown", 0)) -class DigestOnnxModel: - def __init__( - self, - onnx_model: onnx.ModelProto, - onnx_filepath: Optional[str] = None, - model_name: Optional[str] = None, - save_proto: bool = True, - ) -> None: - # Public members exposed to the API - self.unique_id: str = str(uuid4()) - self.filepath: Optional[str] = onnx_filepath - self.model_proto: Optional[onnx.ModelProto] = onnx_model if save_proto else None - self.model_name: Optional[str] = model_name - self.model_version: Optional[int] = None - self.graph_name: Optional[str] = None - self.producer_name: Optional[str] = None - self.producer_version: Optional[str] = None - self.ir_version: Optional[int] = None - self.opset: Optional[int] = None - self.imports: Dict[str, int] = {} - self.node_type_counts: NodeTypeCounts = NodeTypeCounts() - self.model_flops: Optional[int] = None - self.model_parameters: int = 0 - self.node_type_flops: Dict[str, int] = {} - self.node_type_parameters: Dict[str, int] = {} - self.per_node_info = NodeData() - self.model_inputs = TensorData() - self.model_outputs = TensorData() - - # Private members not intended to be exposed - self.input_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.output_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.value_tensors_: Dict[str, onnx.ValueInfoProto] = {} - self.init_tensors_: Dict[str, onnx.TensorProto] = {} - - self.update_state(onnx_model) - - def update_state(self, model_proto: onnx.ModelProto) -> None: - self.model_version = model_proto.model_version - self.graph_name = model_proto.graph.name - self.producer_name = model_proto.producer_name - self.producer_version = model_proto.producer_version - self.ir_version = model_proto.ir_version - self.opset = get_opset(model_proto) - self.imports = { - import_.domain: import_.version for import_ in model_proto.opset_import - } - - self.model_inputs = get_model_input_shapes_types(model_proto) - self.model_outputs = get_model_output_shapes_types(model_proto) - - self.node_type_counts = get_node_type_counts(model_proto) - self.parse_model_nodes(model_proto) - - def get_node_tensor_info_( - self, onnx_node: onnx.NodeProto - ) -> Tuple[TensorData, TensorData]: - """ - This function is set to private because it is not intended to be used - outside of the DigestOnnxModel class. - """ - - input_tensor_info = TensorData() - for node_input in onnx_node.input: - input_tensor_info[node_input] = TensorInfo() - if ( - node_input in self.input_tensors_ - or node_input in self.value_tensors_ - or node_input in self.output_tensors_ - ): - tensor = ( - self.input_tensors_.get(node_input) - or self.value_tensors_.get(node_input) - or self.output_tensors_.get(node_input) - ) - if tensor: - for dim in tensor.type.tensor_type.shape.dim: - if dim.HasField("dim_value"): - input_tensor_info[node_input].shape.append(dim.dim_value) - elif dim.HasField("dim_param"): - input_tensor_info[node_input].shape.append(dim.dim_param) - - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - tensor.type.tensor_type.elem_type - ) - elif node_input in self.init_tensors_: - input_tensor_info[node_input].shape.extend( - [dim for dim in self.init_tensors_[node_input].dims] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - self.init_tensors_[node_input].data_type - ) - else: - dtype_str = None - dtype_bytes = None - - input_tensor_info[node_input].dtype = dtype_str - input_tensor_info[node_input].dtype_bytes = dtype_bytes - - if ( - all(isinstance(s, int) for s in input_tensor_info[node_input].shape) - and dtype_bytes - ): - tensor_size = float( - np.prod(np.array(input_tensor_info[node_input].shape)) - ) - input_tensor_info[node_input].size_kbytes = ( - tensor_size * float(dtype_bytes) / 1024.0 - ) - - output_tensor_info = TensorData() - for node_output in onnx_node.output: - output_tensor_info[node_output] = TensorInfo() - if ( - node_output in self.input_tensors_ - or node_output in self.value_tensors_ - or node_output in self.output_tensors_ - ): - tensor = ( - self.input_tensors_.get(node_output) - or self.value_tensors_.get(node_output) - or self.output_tensors_.get(node_output) - ) - if tensor: - output_tensor_info[node_output].shape.extend( - [ - int(dim.dim_value) - for dim in tensor.type.tensor_type.shape.dim - ] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - tensor.type.tensor_type.elem_type - ) - elif node_output in self.init_tensors_: - output_tensor_info[node_output].shape.extend( - [dim for dim in self.init_tensors_[node_output].dims] - ) - dtype_str, dtype_bytes = tensor_type_to_str_and_size( - self.init_tensors_[node_output].data_type - ) - - else: - dtype_str = None - dtype_bytes = None - - output_tensor_info[node_output].dtype = dtype_str - output_tensor_info[node_output].dtype_bytes = dtype_bytes - - if ( - all(isinstance(s, int) for s in output_tensor_info[node_output].shape) - and dtype_bytes - ): - tensor_size = float( - np.prod(np.array(output_tensor_info[node_output].shape)) - ) - output_tensor_info[node_output].size_kbytes = ( - tensor_size * float(dtype_bytes) / 1024.0 - ) - - return input_tensor_info, output_tensor_info - - def parse_model_nodes(self, onnx_model: onnx.ModelProto) -> None: - """ - Calculate total number of FLOPs found in the onnx model. - FLOP is defined as one floating-point operation. This distinguishes - from multiply-accumulates (MACs) where FLOPs == 2 * MACs. - """ - - # Initialze to zero so we can accumulate. Set to None during the - # model FLOPs calculation if it errors out. - self.model_flops = 0 - - # Check to see if the model inputs have any dynamic shapes - if get_dynamic_input_dims(onnx_model): - self.model_flops = None - - try: - onnx_model, _ = optimize_onnx_model(onnx_model) - - onnx_model = onnx.shape_inference.infer_shapes( - onnx_model, strict_mode=True, data_prop=True - ) - except Exception as e: # pylint: disable=broad-except - print(f"ONNX utils: {str(e)}") - self.model_flops = None - - # If the ONNX model contains one of the following unsupported ops, then this - # function will return None since the FLOP total is expected to be incorrect - unsupported_ops = [ - "Einsum", - "RNN", - "GRU", - "DeformConv", - ] - - if not self.input_tensors_: - self.input_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.input - } - - if not self.output_tensors_: - self.output_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.output - } - - if not self.value_tensors_: - self.value_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.value_info - } - - if not self.init_tensors_: - self.init_tensors_ = { - tensor.name: tensor for tensor in onnx_model.graph.initializer - } - - for node in onnx_model.graph.node: # pylint: disable=E1101 - - node_info = NodeInfo() - - # TODO: I have encountered models containing nodes with no name. It would be a good idea - # to have this type of model info fed back to the user through a warnings section. - if not node.name: - node.name = f"{node.op_type}_{len(self.per_node_info)}" - - node_info.node_type = node.op_type - input_tensor_info, output_tensor_info = self.get_node_tensor_info_(node) - node_info.inputs = input_tensor_info - node_info.outputs = output_tensor_info - - # Check if this node has parameters through the init tensors - for input_name, input_tensor in node_info.inputs.items(): - if input_name in self.init_tensors_: - if all(isinstance(dim, int) for dim in input_tensor.shape): - input_parameters = int(np.prod(np.array(input_tensor.shape))) - node_info.parameters += input_parameters - self.model_parameters += input_parameters - self.node_type_parameters[node.op_type] = ( - self.node_type_parameters.get(node.op_type, 0) - + input_parameters - ) - else: - print(f"Tensor with params has unknown shape: {input_name}") - - for attribute in node.attribute: - node_info.attributes.update(attribute_to_dict(attribute)) - - # if node.name in self.per_node_info: - # print(f"Node name {node.name} is a duplicate.") - - self.per_node_info[node.name] = node_info - - if node.op_type in unsupported_ops: - self.model_flops = None - node_info.flops = None - - try: - - if ( - node.op_type == "MatMul" - or node.op_type == "MatMulInteger" - or node.op_type == "QLinearMatMul" - ): - - input_a = node_info.get_input(0).shape - if node.op_type == "QLinearMatMul": - input_b = node_info.get_input(3).shape - else: - input_b = node_info.get_input(1).shape - - if not all( - isinstance(dim, int) for dim in input_a - ) or not isinstance(input_b[-1], int): - node_info.flops = None - self.model_flops = None - continue - - node_info.flops = int( - 2 * np.prod(np.array(input_a), dtype=np.int64) * input_b[-1] - ) - - elif ( - node.op_type == "Mul" - or node.op_type == "Div" - or node.op_type == "Add" - ): - input_a = node_info.get_input(0).shape - input_b = node_info.get_input(1).shape - - if not all(isinstance(dim, int) for dim in input_a) or not all( - isinstance(dim, int) for dim in input_b - ): - node_info.flops = None - self.model_flops = None - continue - - node_info.flops = int( - np.prod(np.array(input_a), dtype=np.int64) - ) + int(np.prod(np.array(input_b), dtype=np.int64)) - - elif node.op_type == "Gemm" or node.op_type == "QGemm": - x_shape = node_info.get_input(0).shape - if node.op_type == "Gemm": - w_shape = node_info.get_input(1).shape - else: - w_shape = node_info.get_input(3).shape - - if not all(isinstance(dim, int) for dim in x_shape) or not all( - isinstance(dim, int) for dim in w_shape - ): - node_info.flops = None - self.model_flops = None - continue - - mm_dims = [ - ( - x_shape[0] - if not node_info.attributes.get("transA", 0) - else x_shape[1] - ), - ( - x_shape[1] - if not node_info.attributes.get("transA", 0) - else x_shape[0] - ), - ( - w_shape[1] - if not node_info.attributes.get("transB", 0) - else w_shape[0] - ), - ] - - node_info.flops = int( - 2 * np.prod(np.array(mm_dims), dtype=np.int64) - ) - - if len(mm_dims) == 3: # if there is a bias input - bias_shape = node_info.get_input(2).shape - node_info.flops += int(np.prod(np.array(bias_shape))) - - elif ( - node.op_type == "Conv" - or node.op_type == "ConvInteger" - or node.op_type == "QLinearConv" - or node.op_type == "ConvTranspose" - ): - # N, C, d1, ..., dn - x_shape = node_info.get_input(0).shape - - # M, C/group, k1, ..., kn. Note C and M are swapped for ConvTranspose - if node.op_type == "QLinearConv": - w_shape = node_info.get_input(3).shape - else: - w_shape = node_info.get_input(1).shape - - if not all(isinstance(dim, int) for dim in x_shape): - node_info.flops = None - self.model_flops = None - continue - - x_shape_ints = cast(List[int], x_shape) - w_shape_ints = cast(List[int], w_shape) - - has_bias = False # Note, ConvInteger has no bias - if node.op_type == "Conv" and len(node_info.inputs) == 3: - has_bias = True - elif node.op_type == "QLinearConv" and len(node_info.inputs) == 9: - has_bias = True - - num_dims = len(x_shape_ints) - 2 - strides = node_info.attributes.get( - "strides", [1] * num_dims - ) # type: List[int] - dilation = node_info.attributes.get( - "dilations", [1] * num_dims - ) # type: List[int] - kernel_shape = w_shape_ints[2:] - batch_size = x_shape_ints[0] - out_channels = w_shape_ints[0] - out_dims = [batch_size, out_channels] - output_shape = node_info.attributes.get( - "output_shape", [] - ) # type: List[int] - - # If output_shape is given then we do not need to compute it ourselves - # The output_shape attribute does not include batch_size or channels and - # is only valid for ConvTranspose - if output_shape: - out_dims.extend(output_shape) - else: - auto_pad = node_info.attributes.get( - "auto_pad", "NOTSET".encode() - ).decode() - # SAME expects padding so that the output_shape = CEIL(input_shape / stride) - if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": - out_dims.extend( - [x * s for x, s in zip(x_shape_ints[2:], strides)] - ) - else: - # NOTSET means just use pads attribute - if auto_pad == "NOTSET": - pads = node_info.attributes.get( - "pads", [0] * num_dims * 2 - ) - # VALID essentially means no padding - elif auto_pad == "VALID": - pads = [0] * num_dims * 2 - - for i in range(num_dims): - dim_in = x_shape_ints[i + 2] # type: int - - if node.op_type == "ConvTranspose": - out_dim = ( - strides[i] * (dim_in - 1) - + ((kernel_shape[i] - 1) * dilation[i] + 1) - - pads[i] - - pads[i + num_dims] - ) - else: - out_dim = ( - dim_in - + pads[i] - + pads[i + num_dims] - - dilation[i] * (kernel_shape[i] - 1) - - 1 - ) // strides[i] + 1 - - out_dims.append(out_dim) - - kernel_flops = int( - np.prod(np.array(kernel_shape)) * w_shape_ints[1] - ) - output_points = int(np.prod(np.array(out_dims))) - bias_ops = output_points if has_bias else int(0) - node_info.flops = 2 * kernel_flops * output_points + bias_ops - - elif node.op_type == "LSTM" or node.op_type == "DynamicQuantizeLSTM": - - x_shape = node_info.get_input( - 0 - ).shape # seq_length, batch_size, input_dim - - if not all(isinstance(dim, int) for dim in x_shape): - node_info.flops = None - self.model_flops = None - continue - - x_shape_ints = cast(List[int], x_shape) - hidden_size = node_info.attributes["hidden_size"] - direction = ( - 2 - if node_info.attributes.get("direction") - == "bidirectional".encode() - else 1 - ) - - has_bias = True if len(node_info.inputs) >= 4 else False - if has_bias: - bias_shape = node_info.get_input(3).shape - if isinstance(bias_shape[1], int): - bias_ops = bias_shape[1] - else: - bias_ops = 0 - else: - bias_ops = 0 - # seq_length, batch_size, input_dim = x_shape - if not isinstance(bias_ops, int): - bias_ops = int(0) - num_gates = int(4) - gate_input_flops = int(2 * x_shape_ints[2] * hidden_size) - gate_hid_flops = int(2 * hidden_size * hidden_size) - unit_flops = ( - num_gates * (gate_input_flops + gate_hid_flops) + bias_ops - ) - node_info.flops = ( - x_shape_ints[1] * x_shape_ints[0] * direction * unit_flops - ) - # In this case we just hit an op that doesn't have FLOPs - else: - node_info.flops = None - - except IndexError as err: - print(f"Error parsing node {node.name}: {err}") - node_info.flops = None - self.model_flops = None - continue - - # Update the model level flops count - if node_info.flops is not None and self.model_flops is not None: - self.model_flops += node_info.flops - - # Update the node type flops count - self.node_type_flops[node.op_type] = ( - self.node_type_flops.get(node.op_type, 0) + node_info.flops - ) - - def save_txt_report(self, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - report_date = datetime.now().strftime("%B %d, %Y") - - with open(filepath, "w", encoding="utf-8") as f_p: - f_p.write(f"Report created on {report_date}\n") - if self.filepath: - f_p.write(f"ONNX file: {self.filepath}\n") - f_p.write(f"Name of the model: {self.model_name}\n") - f_p.write(f"Model version: {self.model_version}\n") - f_p.write(f"Name of the graph: {self.graph_name}\n") - f_p.write(f"Producer: {self.producer_name} {self.producer_version}\n") - f_p.write(f"Ir version: {self.ir_version}\n") - f_p.write(f"Opset: {self.opset}\n\n") - f_p.write("Import list\n") - for name, version in self.imports.items(): - f_p.write(f"\t{name}: {version}\n") - - f_p.write("\n") - f_p.write(f"Total graph nodes: {sum(self.node_type_counts.values())}\n") - f_p.write(f"Number of parameters: {self.model_parameters}\n") - if self.model_flops: - f_p.write(f"Number of FLOPs: {self.model_flops}\n") - f_p.write("\n") - - table_op_intensity = PrettyTable() - table_op_intensity.field_names = ["Operation", "FLOPs", "Intensity (%)"] - for op_type, count in self.node_type_flops.items(): - if count > 0: - table_op_intensity.add_row( - [ - op_type, - count, - 100.0 * float(count) / float(self.model_flops), - ] - ) - - f_p.write("Op intensity:\n") - f_p.write(table_op_intensity.get_string()) - f_p.write("\n\n") - - node_counts_table = PrettyTable() - node_counts_table.field_names = ["Node", "Occurrences"] - for op, count in self.node_type_counts.items(): - node_counts_table.add_row([op, count]) - f_p.write("Nodes and their occurrences:\n") - f_p.write(node_counts_table.get_string()) - f_p.write("\n\n") - - input_table = PrettyTable() - input_table.field_names = [ - "Input Name", - "Shape", - "Type", - "Tensor Size (KB)", - ] - for input_name, input_details in self.model_inputs.items(): - if input_details.size_kbytes: - kbytes = f"{input_details.size_kbytes:.2f}" - else: - kbytes = "" - - input_table.add_row( - [ - input_name, - input_details.shape, - input_details.dtype, - kbytes, - ] - ) - f_p.write("Input Tensor(s) Information:\n") - f_p.write(input_table.get_string()) - f_p.write("\n\n") - - output_table = PrettyTable() - output_table.field_names = [ - "Output Name", - "Shape", - "Type", - "Tensor Size (KB)", - ] - for output_name, output_details in self.model_outputs.items(): - if output_details.size_kbytes: - kbytes = f"{output_details.size_kbytes:.2f}" - else: - kbytes = "" - - output_table.add_row( - [ - output_name, - output_details.shape, - output_details.dtype, - kbytes, - ] - ) - f_p.write("Output Tensor(s) Information:\n") - f_p.write(output_table.get_string()) - f_p.write("\n\n") - - def save_nodes_csv_report(self, filepath: str) -> None: - save_nodes_csv_report(self.per_node_info, filepath) - - def get_node_type_counts(self) -> Union[NodeTypeCounts, None]: - if not self.node_type_counts and self.model_proto: - self.node_type_counts = get_node_type_counts(self.model_proto) - return self.node_type_counts if self.node_type_counts else None - - def get_node_shape_counts(self) -> NodeShapeCounts: - tensor_shape_counter = NodeShapeCounts() - for _, info in self.per_node_info.items(): - shape_hash = tuple([tuple(v.shape) for _, v in info.inputs.items()]) - if info.node_type: - tensor_shape_counter[info.node_type][shape_hash] += 1 - return tensor_shape_counter - - -def save_nodes_csv_report(node_data: NodeData, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - flattened_data = [] - fieldnames = ["Node Name", "Node Type", "Parameters", "FLOPs", "Attributes"] - input_fieldnames = [] - output_fieldnames = [] - for name, node_info in node_data.items(): - row = OrderedDict() - row["Node Name"] = name - row["Node Type"] = str(node_info.node_type) - row["Parameters"] = str(node_info.parameters) - row["FLOPs"] = str(node_info.flops) - if node_info.attributes: - row["Attributes"] = str({k: v for k, v in node_info.attributes.items()}) - else: - row["Attributes"] = "" - - for i, (input_name, input_info) in enumerate(node_info.inputs.items()): - column_name = f"Input{i+1} (Shape, Dtype, Size (kB))" - row[column_name] = ( - f"{input_name} ({input_info.shape}, {input_info.dtype}, {input_info.size_kbytes})" - ) - - # Dynamically add input column names to fieldnames if not already present - if column_name not in input_fieldnames: - input_fieldnames.append(column_name) - - for i, (output_name, output_info) in enumerate(node_info.outputs.items()): - column_name = f"Output{i+1} (Shape, Dtype, Size (kB))" - row[column_name] = ( - f"{output_name} ({output_info.shape}, " - f"{output_info.dtype}, {output_info.size_kbytes})" - ) - - # Dynamically add input column names to fieldnames if not already present - if column_name not in output_fieldnames: - output_fieldnames.append(column_name) - - flattened_data.append(row) - - fieldnames = fieldnames + input_fieldnames + output_fieldnames - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n") - writer.writeheader() - writer.writerows(flattened_data) - - -def save_node_type_counts_csv_report(node_data: NodeTypeCounts, filepath: str) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - header = ["Node Type", "Count"] - - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.writer(csvfile, lineterminator="\n") - writer.writerow(header) - for node_type, node_count in node_data.items(): - writer.writerow([node_type, node_count]) - - -def save_node_shape_counts_csv_report( - node_data: NodeShapeCounts, filepath: str -) -> None: - - parent_dir = os.path.dirname(os.path.abspath(filepath)) - if not os.path.exists(parent_dir): - raise FileNotFoundError(f"Directory {parent_dir} does not exist.") - - header = ["Node Type", "Input Tensors Shapes", "Count"] - - with open(filepath, "w", encoding="utf-8", newline="") as csvfile: - writer = csv.writer(csvfile, dialect="excel", lineterminator="\n") - writer.writerow(header) - for node_type, node_info in node_data.items(): - info_iter = iter(node_info.items()) - for shape, count in info_iter: - writer.writerow([node_type, shape, count]) - - def load_onnx(onnx_path: str, load_external_data: bool = True) -> onnx.ModelProto: if os.path.exists(onnx_path): return onnx.load(onnx_path, load_external_data=load_external_data) @@ -987,3 +211,9 @@ def optimize_onnx_model( except onnx.checker.ValidationError: print("Model did not pass checker!") return model_proto, False + + +def get_supported_opset() -> int: + """This function will return the opset version associated + with the currently installed ONNX library""" + return onnx.defs.onnx_opset_version() diff --git a/test/resnet18_reports/resnet18_heatmap.png b/test/resnet18_reports/resnet18_heatmap.png new file mode 100644 index 0000000..1fb614e Binary files /dev/null and b/test/resnet18_reports/resnet18_heatmap.png differ diff --git a/test/resnet18_reports/resnet18_histogram.png b/test/resnet18_reports/resnet18_histogram.png new file mode 100644 index 0000000..eb13e01 Binary files /dev/null and b/test/resnet18_reports/resnet18_histogram.png differ diff --git a/test/resnet18_reports/resnet18_node_type_counts.csv b/test/resnet18_reports/resnet18_node_type_counts.csv new file mode 100644 index 0000000..29504ba --- /dev/null +++ b/test/resnet18_reports/resnet18_node_type_counts.csv @@ -0,0 +1,8 @@ +Node Type,Count +Conv,20 +Relu,17 +Add,8 +MaxPool,1 +GlobalAveragePool,1 +Flatten,1 +Gemm,1 diff --git a/test/resnet18_test_nodes.csv b/test/resnet18_reports/resnet18_nodes.csv similarity index 100% rename from test/resnet18_test_nodes.csv rename to test/resnet18_reports/resnet18_nodes.csv diff --git a/test/resnet18_test_summary.txt b/test/resnet18_reports/resnet18_report.txt similarity index 86% rename from test/resnet18_test_summary.txt rename to test/resnet18_reports/resnet18_report.txt index a5b4cfb..fdda0bf 100644 --- a/test/resnet18_test_summary.txt +++ b/test/resnet18_reports/resnet18_report.txt @@ -1,5 +1,5 @@ -Report created on June 02, 2024 -ONNX file: resnet18.onnx +Report created on December 06, 2024 +ONNX file: C:\Users\pcolange\Projects\digestai\test\resnet18.onnx Name of the model: resnet18 Model version: 0 Name of the graph: main_graph @@ -9,6 +9,13 @@ Opset: 17 Import list : 17 + ai.onnx.ml: 5 + ai.onnx.preview.training: 1 + ai.onnx.training: 1 + com.microsoft: 1 + com.microsoft.experimental: 1 + com.microsoft.nchwc: 1 + org.pytorch.aten: 1 Total graph nodes: 49 Number of parameters: 11684712 diff --git a/test/resnet18_reports/resnet18_report.yaml b/test/resnet18_reports/resnet18_report.yaml new file mode 100644 index 0000000..9df22be --- /dev/null +++ b/test/resnet18_reports/resnet18_report.yaml @@ -0,0 +1,56 @@ +report_date: December 06, 2024 +model_file: C:\Users\pcolange\Projects\digestai\test\resnet18.onnx +model_type: onnx +model_name: resnet18 +model_version: 0 +graph_name: main_graph +producer_name: pytorch +producer_version: 2.1.0 +ir_version: 8 +opset: 17 +import_list: + ? '' + : 17 + ai.onnx.ml: 5 + ai.onnx.preview.training: 1 + ai.onnx.training: 1 + com.microsoft: 1 + com.microsoft.experimental: 1 + com.microsoft.nchwc: 1 + org.pytorch.aten: 1 +graph_nodes: 49 +parameters: 11684712 +flops: 3632136680 +node_type_counts: + Conv: 20 + Relu: 17 + Add: 8 + MaxPool: 1 + GlobalAveragePool: 1 + Flatten: 1 + Gemm: 1 +node_type_flops: + Conv: 3629606400 + Add: 1505280 + Gemm: 1025000 +node_type_parameters: + Conv: 11171712 + Gemm: 513000 +input_tensors: + input.1: + dtype: float32 + dtype_bytes: 4 + size_kbytes: 588.0 + shape: + - 1 + - 3 + - 224 + - 224 +output_tensors: + '191': + dtype: float32 + dtype_bytes: 4 + size_kbytes: 3.90625 + shape: + - 1 + - 1000 diff --git a/test/test_gui.py b/test/test_gui.py index 0e1d351..9a06f3e 100644 --- a/test/test_gui.py +++ b/test/test_gui.py @@ -5,63 +5,166 @@ import tempfile import unittest from unittest.mock import patch +import timm +import torch # pylint: disable=no-name-in-module from PySide6.QtTest import QTest -from PySide6.QtCore import Qt, QDeadlineTimer +from PySide6.QtCore import Qt from PySide6.QtWidgets import QApplication import digest.main from digest.node_summary import NodeSummary +from digest.model_class.digest_pytorch_model import DigestPyTorchModel +from digest.pytorch_ingest import PyTorchIngest -ONNX_BASENAME = "resnet18" -TEST_DIR = os.path.abspath(os.path.dirname(__file__)) -ONNX_FILEPATH = os.path.normpath(os.path.join(TEST_DIR, f"{ONNX_BASENAME}.onnx")) + +def save_resnet18_pt(directory: str) -> str: + """Simply saves a PyTorch resnet18 model and returns its file path""" + model = timm.models.create_model("resnet18", pretrained=True) # type: ignore + model.eval() + file_path = os.path.join(directory, "resnet18.pt") + # Save the model + try: + torch.save(model, file_path) + return file_path + except Exception as e: # pylint: disable=broad-exception-caught + print(f"Error saving model: {e}") + return "" class DigestGuiTest(unittest.TestCase): + RESNET18_BASENAME = "resnet18" + + TEST_DIR = os.path.abspath(os.path.dirname(__file__)) + ONNX_FILE_PATH = os.path.normpath( + os.path.join(TEST_DIR, f"{RESNET18_BASENAME}.onnx") + ) + YAML_FILE_PATH = os.path.normpath( + os.path.join( + TEST_DIR, f"{RESNET18_BASENAME}_reports", f"{RESNET18_BASENAME}_report.yaml" + ) + ) @classmethod def setUpClass(cls): cls.app = QApplication(sys.argv) + return super().setUpClass() + + @classmethod + def tearDownClass(cls): + if isinstance(cls.app, QApplication): + cls.app.closeAllWindows() + cls.app = None def setUp(self): self.digest_app = digest.main.DigestApp() self.digest_app.show() def tearDown(self): - self.wait_all_threads() self.digest_app.close() - def wait_all_threads(self): + def wait_all_threads(self, timeout=10000) -> bool: + all_threads = list(self.digest_app.model_nodes_stats_thread.values()) + list( + self.digest_app.model_similarity_thread.values() + ) - for thread in self.digest_app.model_nodes_stats_thread.values(): - thread.wait(deadline=QDeadlineTimer.Forever) + for thread in all_threads: + thread.wait(timeout) - for thread in self.digest_app.model_similarity_thread.values(): - thread.wait(deadline=QDeadlineTimer.Forever) + # Return True if all threads finished, False if timed out + return all(thread.isFinished() for thread in all_threads) def test_open_valid_onnx(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ( - ONNX_FILEPATH, + self.ONNX_FILE_PATH, + "", + ) + + num_tabs_prior = self.digest_app.ui.tabWidget.count() + + QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) + + self.assertTrue(self.wait_all_threads()) + + self.assertTrue( + self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1 + ) # Check if a tab was added + + self.digest_app.closeTab(num_tabs_prior) + + def test_open_valid_yaml(self): + with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: + mock_dialog.return_value = ( + self.YAML_FILE_PATH, "", ) + num_tabs_prior = self.digest_app.ui.tabWidget.count() + QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) self.assertTrue( - self.digest_app.ui.tabWidget.count() > 0 + self.digest_app.ui.tabWidget.count() == num_tabs_prior + 1 ) # Check if a tab was added + self.digest_app.closeTab(num_tabs_prior) + + def test_open_valid_pytorch(self): + """We test the PyTorch path slightly different than the others + since Digest opens a modal window that blocks the main thread. This makes it difficult + to interact with the Window in this test.""" + + with tempfile.TemporaryDirectory() as tmpdir: + pt_file_path = save_resnet18_pt(tmpdir) + self.assertTrue(os.path.exists(tmpdir)) + basename = os.path.splitext(os.path.basename(pt_file_path)) + model_name = basename[0] + digest_model = DigestPyTorchModel(pt_file_path, model_name) + self.assertTrue(isinstance(digest_model.file_path, str)) + pytorch_ingest = PyTorchIngest(pt_file_path, digest_model.model_name) + pytorch_ingest.show() + + input_shape_edit = ( + pytorch_ingest.user_input_form.get_row_tensor_shape_widget(0) + ) + + assert input_shape_edit + input_shape_edit.setText("batch_size, 3, 224, 224") + pytorch_ingest.update_tensor_info() + + with patch( + "PySide6.QtWidgets.QFileDialog.getExistingDirectory" + ) as mock_save_dialog: + print("TMPDIR", tmpdir) + mock_save_dialog.return_value = tmpdir + pytorch_ingest.select_directory() + + pytorch_ingest.ui.exportOnnxBtn.click() + + timeout_ms = 10000 + interval_ms = 100 + for _ in range(timeout_ms // interval_ms): + QTest.qWait(interval_ms) + onnx_file_path = pytorch_ingest.digest_pytorch_model.onnx_file_path + if onnx_file_path and os.path.exists(onnx_file_path): + break # File found! + + assert isinstance(pytorch_ingest.digest_pytorch_model.onnx_file_path, str) + self.assertTrue( + os.path.exists(pytorch_ingest.digest_pytorch_model.onnx_file_path) + ) + def test_open_invalid_file(self): with patch("PySide6.QtWidgets.QFileDialog.getOpenFileName") as mock_dialog: mock_dialog.return_value = ("invalid_file.txt", "") + num_tabs_prior = self.digest_app.ui.tabWidget.count() QTest.mouseClick(self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton) - self.wait_all_threads() - self.assertEqual(self.digest_app.ui.tabWidget.count(), 0) + self.assertTrue(self.wait_all_threads()) + self.assertEqual(self.digest_app.ui.tabWidget.count(), num_tabs_prior) def test_save_reports(self): with patch( @@ -70,7 +173,7 @@ def test_save_reports(self): "PySide6.QtWidgets.QFileDialog.getExistingDirectory" ) as mock_save_dialog: - mock_open_dialog.return_value = (ONNX_FILEPATH, "") + mock_open_dialog.return_value = (self.ONNX_FILE_PATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = tmpdirname @@ -79,44 +182,56 @@ def test_save_reports(self): Qt.MouseButton.LeftButton, ) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) - # This is a slight hack but the issue is that model similarity takes - # a bit longer to complete and we must have it done before the save - # button is enabled guaranteeing all the artifacts are saved. - # wait_all_threads() above doesn't seem to work. The only thing that - # does is just waiting 5 seconds. - QTest.qWait(5000) + self.assertTrue( + self.digest_app.ui.saveBtn.isEnabled(), "Save button is disabled!" + ) QTest.mouseClick(self.digest_app.ui.saveBtn, Qt.MouseButton.LeftButton) mock_save_dialog.assert_called_once() - result_basepath = os.path.join(tmpdirname, f"{ONNX_BASENAME}_reports") + result_basepath = os.path.join( + tmpdirname, f"{self.RESNET18_BASENAME}_reports" + ) # Text report test - txt_report_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_report.txt" + text_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_report.txt" + ) + self.assertTrue( + os.path.isfile(text_report_FILE_PATH), + f"{text_report_FILE_PATH} not found!", ) - self.assertTrue(os.path.isfile(txt_report_filepath)) + + # YAML report test + yaml_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_report.yaml" + ) + self.assertTrue(os.path.isfile(yaml_report_FILE_PATH)) # Nodes test - nodes_csv_report_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_nodes.csv" + nodes_csv_report_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_nodes.csv" ) - self.assertTrue(os.path.isfile(nodes_csv_report_filepath)) + self.assertTrue(os.path.isfile(nodes_csv_report_FILE_PATH)) # Histogram test - histogram_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_histogram.png" + histogram_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_histogram.png" ) - self.assertTrue(os.path.isfile(histogram_filepath)) + self.assertTrue(os.path.isfile(histogram_FILE_PATH)) # Heatmap test - heatmap_filepath = os.path.join( - result_basepath, f"{ONNX_BASENAME}_heatmap.png" + heatmap_FILE_PATH = os.path.join( + result_basepath, f"{self.RESNET18_BASENAME}_heatmap.png" ) - self.assertTrue(os.path.isfile(heatmap_filepath)) + self.assertTrue(os.path.isfile(heatmap_FILE_PATH)) + + num_tabs = self.digest_app.ui.tabWidget.count() + self.assertTrue(num_tabs == 1) + self.digest_app.closeTab(0) def test_save_tables(self): with patch( @@ -125,10 +240,10 @@ def test_save_tables(self): "PySide6.QtWidgets.QFileDialog.getSaveFileName" ) as mock_save_dialog: - mock_open_dialog.return_value = (ONNX_FILEPATH, "") + mock_open_dialog.return_value = (self.ONNX_FILE_PATH, "") with tempfile.TemporaryDirectory() as tmpdirname: mock_save_dialog.return_value = ( - os.path.join(tmpdirname, f"{ONNX_BASENAME}_nodes.csv"), + os.path.join(tmpdirname, f"{self.RESNET18_BASENAME}_nodes.csv"), "", ) @@ -136,17 +251,19 @@ def test_save_tables(self): self.digest_app.ui.openFileBtn, Qt.MouseButton.LeftButton ) - self.wait_all_threads() + self.assertTrue(self.wait_all_threads()) QTest.mouseClick( self.digest_app.ui.nodesListBtn, Qt.MouseButton.LeftButton ) - # We assume there is only model loaded + # We assume there is only one model loaded _, node_window = self.digest_app.nodes_window.popitem() node_summary = node_window.main_window.centralWidget() self.assertIsInstance(node_summary, NodeSummary) + + # This line of code seems redundant but we do this to clean pylance if isinstance(node_summary, NodeSummary): QTest.mouseClick( node_summary.ui.saveCsvBtn, Qt.MouseButton.LeftButton @@ -156,11 +273,15 @@ def test_save_tables(self): self.assertTrue( os.path.exists( - os.path.join(tmpdirname, f"{ONNX_BASENAME}_nodes.csv") + os.path.join(tmpdirname, f"{self.RESNET18_BASENAME}_nodes.csv") ), "Nodes csv file not found.", ) + num_tabs = self.digest_app.ui.tabWidget.count() + self.assertTrue(num_tabs == 1) + self.digest_app.closeTab(0) + if __name__ == "__main__": unittest.main() diff --git a/test/test_reports.py b/test/test_reports.py index a16c4d8..e4d327e 100644 --- a/test/test_reports.py +++ b/test/test_reports.py @@ -1,17 +1,22 @@ # Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved. -"""Unit tests for Vitis ONNX Model Analyzer """ - import os import unittest import tempfile import csv -from utils.onnx_utils import DigestOnnxModel, load_onnx +import utils.onnx_utils as onnx_utils +from digest.model_class.digest_onnx_model import DigestOnnxModel +from digest.model_class.digest_report_model import compare_yaml_files TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_ONNX = os.path.join(TEST_DIR, "resnet18.onnx") -TEST_SUMMARY_TXT_REPORT = os.path.join(TEST_DIR, "resnet18_test_summary.txt") -TEST_NODES_CSV_REPORT = os.path.join(TEST_DIR, "resnet18_test_nodes.csv") +TEST_SUMMARY_TEXT_REPORT = os.path.join( + TEST_DIR, "resnet18_reports/resnet18_report.txt" +) +TEST_SUMMARY_YAML_REPORT = os.path.join( + TEST_DIR, "resnet18_reports/resnet18_report.yaml" +) +TEST_NODES_CSV_REPORT = os.path.join(TEST_DIR, "resnet18_reports/resnet18_nodes.csv") class TestDigestReports(unittest.TestCase): @@ -46,27 +51,35 @@ def compare_csv_files(self, file1, file2, skip_lines=0): self.assertEqual(row1, row2, msg=f"Difference in row: {row1} vs {row2}") def test_against_example_reports(self): - model_proto = load_onnx(TEST_ONNX) + model_proto = onnx_utils.load_onnx(TEST_ONNX, load_external_data=False) model_name = os.path.splitext(os.path.basename(TEST_ONNX))[0] + opt_model, _ = onnx_utils.optimize_onnx_model(model_proto) digest_model = DigestOnnxModel( - model_proto, onnx_filepath=TEST_ONNX, model_name=model_name, save_proto=False, + opt_model, + onnx_file_path=TEST_ONNX, + model_name=model_name, + save_proto=False, ) with tempfile.TemporaryDirectory() as tmpdir: - # Model summary text report - summary_filepath = os.path.join(tmpdir, f"{model_name}_summary.txt") - digest_model.save_txt_report(summary_filepath) - - with self.subTest("Testing summary text file"): - self.compare_files_line_by_line( - TEST_SUMMARY_TXT_REPORT, - summary_filepath, - skip_lines=2, + # Model yaml report + yaml_report_filepath = os.path.join(tmpdir, f"{model_name}_report.yaml") + digest_model.save_yaml_report(yaml_report_filepath) + with self.subTest("Testing report yaml file"): + self.assertTrue( + compare_yaml_files( + TEST_SUMMARY_YAML_REPORT, + yaml_report_filepath, + skip_keys=["report_date", "model_file", "digest_version"], + ) ) # Save CSV containing node-level information nodes_filepath = os.path.join(tmpdir, f"{model_name}_nodes.csv") digest_model.save_nodes_csv_report(nodes_filepath) - with self.subTest("Testing nodes csv file"): self.compare_csv_files(TEST_NODES_CSV_REPORT, nodes_filepath) + + +if __name__ == "__main__": + unittest.main()