From 0282296401be5a4e9e7f8a454f7c05106876e499 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Fri, 8 Sep 2023 00:25:12 -0700 Subject: [PATCH] Update torch.py (#356) --- bindings/python/py_src/safetensors/torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py index 12d1f893..7fa59675 100644 --- a/bindings/python/py_src/safetensors/torch.py +++ b/bindings/python/py_src/safetensors/torch.py @@ -109,7 +109,7 @@ def _remove_duplicate_names( keep_name = sorted(list(complete_names))[0] - # Mecanism to preferentially select keys to keep + # Mechanism to preferentially select keys to keep # coming from the on-disk file to allow # loading models saved with a different choice # of keep_name @@ -173,7 +173,7 @@ def save_model( raise ValueError(msg) -def load_model(model: torch.nn.Module, filename: str, strict=True) -> Tuple[List[str], List[str]]: +def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict=True) -> Tuple[List[str], List[str]]: """ Loads a given filename onto a torch model. This method exists specifically to avoid tensor sharing issues which are @@ -182,7 +182,7 @@ def load_model(model: torch.nn.Module, filename: str, strict=True) -> Tuple[List Args: model (`torch.nn.Module`): The model to load onto. - filename (`str`): + filename (`str`, or `os.PathLike`): The filename location to load the file from. strict (`bool`, *optional*, defaults to True): Wether to fail if you're missing keys or having unexpected ones @@ -286,7 +286,7 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torc Loads a safetensors file into torch format. Args: - filename (`str`, or `os.PathLike`)): + filename (`str`, or `os.PathLike`): The name of the file which contains the tensors device (`Dict[str, any]`, *optional*, defaults to `cpu`): The device where the tensors need to be located after load.