From e978b3665d68416b782a34acb3dfca4f1f14e8b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Thu, 14 Dec 2023 17:39:41 +0100 Subject: [PATCH] convert_dinov2: ignore pyright errors And save converted weights into safetensors instead of pickle --- scripts/conversion/convert_dinov2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/conversion/convert_dinov2.py b/scripts/conversion/convert_dinov2.py index 98dafb5b0..5e8af3574 100644 --- a/scripts/conversion/convert_dinov2.py +++ b/scripts/conversion/convert_dinov2.py @@ -2,6 +2,8 @@ import torch +from refiners.fluxion.utils import save_to_safetensors + def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None: """Convert a DINOv2 weights from facebook to refiners.""" @@ -126,9 +128,9 @@ def main() -> None: parser.add_argument("--output_path", type=str, required=True) args = parser.parse_args() - weights = torch.load(args.weights_path) + weights = torch.load(args.weights_path) # type: ignore convert_dinov2_facebook(weights) - torch.save(weights, args.output_path) + save_to_safetensors(path=args.output_path, tensors=weights) if __name__ == "__main__":