Skip to content

Commit

Permalink
Update torch.py (#356)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhu authored Sep 8, 2023
1 parent cc5d941 commit 0282296
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions bindings/python/py_src/safetensors/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _remove_duplicate_names(

keep_name = sorted(list(complete_names))[0]

# Mecanism to preferentially select keys to keep
# Mechanism to preferentially select keys to keep
# coming from the on-disk file to allow
# loading models saved with a different choice
# of keep_name
Expand Down Expand Up @@ -173,7 +173,7 @@ def save_model(
raise ValueError(msg)


def load_model(model: torch.nn.Module, filename: str, strict=True) -> Tuple[List[str], List[str]]:
def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict=True) -> Tuple[List[str], List[str]]:
"""
Loads a given filename onto a torch model.
This method exists specifically to avoid tensor sharing issues which are
Expand All @@ -182,7 +182,7 @@ def load_model(model: torch.nn.Module, filename: str, strict=True) -> Tuple[List
Args:
model (`torch.nn.Module`):
The model to load onto.
filename (`str`):
filename (`str`, or `os.PathLike`):
The filename location to load the file from.
strict (`bool`, *optional*, defaults to True):
Wether to fail if you're missing keys or having unexpected ones
Expand Down Expand Up @@ -286,7 +286,7 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torc
Loads a safetensors file into torch format.
Args:
filename (`str`, or `os.PathLike`)):
filename (`str`, or `os.PathLike`):
The name of the file which contains the tensors
device (`Dict[str, any]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
Expand Down

0 comments on commit 0282296

Please sign in to comment.