Skip to content

Commit

Permalink
fix: making sure elements in pytorch files are tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Nov 8, 2024
1 parent 83961ea commit c40b642
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/core/handlers/pytorch/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import argparse
import numpy as np


def main():
Expand Down Expand Up @@ -51,6 +52,13 @@ def main():
model = model["model"]

for tensor_name, tensor in model.items():
# make sure it's a tensor
if not isinstance(tensor, torch.Tensor):
try:
tensor = torch.tensor(tensor)
except:
continue

inspection["data_size"] += tensor.shape.numel() * tensor.element_size()

shape = list(tensor.shape)
Expand Down

0 comments on commit c40b642

Please sign in to comment.