From c40b6425616232940de4b83454327937fabf282c Mon Sep 17 00:00:00 2001 From: evilsocket Date: Fri, 8 Nov 2024 11:43:08 +0100 Subject: [PATCH] fix: making sure elements in pytorch files are tensors --- src/core/handlers/pytorch/inspect.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/core/handlers/pytorch/inspect.py b/src/core/handlers/pytorch/inspect.py index c0d8c04..808c2f8 100644 --- a/src/core/handlers/pytorch/inspect.py +++ b/src/core/handlers/pytorch/inspect.py @@ -2,6 +2,7 @@ import json import os import argparse +import numpy as np def main(): @@ -51,6 +52,13 @@ def main(): model = model["model"] for tensor_name, tensor in model.items(): + # make sure it's a tensor + if not isinstance(tensor, torch.Tensor): + try: + tensor = torch.tensor(tensor) + except: + continue + inspection["data_size"] += tensor.shape.numel() * tensor.element_size() shape = list(tensor.shape)