diff --git a/src/templar/__init__.py b/src/templar/__init__.py index 54c9c1a..bc08be2 100644 --- a/src/templar/__init__.py +++ b/src/templar/__init__.py @@ -20,7 +20,7 @@ # mypy: ignore-errors # type: ignore -__version__ = "0.1.18" +__version__ = "0.1.27" version_key = 3000 # Import package. diff --git a/src/templar/comms.py b/src/templar/comms.py index 2deecc8..b6d4765 100644 --- a/src/templar/comms.py +++ b/src/templar/comms.py @@ -537,7 +537,7 @@ async def validate_slice_data(slice_file: str, save_location: str) -> bool: """ try: # Load the slice data - slice_data = torch.load(slice_file) + slice_data = torch.load(slice_file, weights_only=True) # Basic validation checks if not isinstance(slice_data, dict): @@ -1070,7 +1070,7 @@ async def load_checkpoint( additional_state (dict): Dictionary of additional state variables. """ try: - checkpoint = torch.load(filename, map_location=device) + checkpoint = torch.load(filename, map_location=device, weights_only=True) model.load_state_dict(checkpoint["model_state_dict"]) if optimizer and "optimizer_state_dict" in checkpoint: optimizer.load_state_dict(checkpoint["optimizer_state_dict"])