From 0ad9a662713793864daebea43a9e27cb11cd46a5 Mon Sep 17 00:00:00 2001 From: Wauplin Date: Tue, 5 Mar 2024 11:39:23 +0100 Subject: [PATCH 1/4] fix typo --- bindings/python/py_src/safetensors/torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py index 22915c98..0b18d8db 100644 --- a/bindings/python/py_src/safetensors/torch.py +++ b/bindings/python/py_src/safetensors/torch.py @@ -185,7 +185,7 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict filename (`str`, or `os.PathLike`): The filename location to load the file from. strict (`bool`, *optional*, defaults to True): - Wether to fail if you're missing keys or having unexpected ones + Whether to fail if you're missing keys or having unexpected ones. When false, the function simply returns missing and unexpected names. Returns: From bab680dbd265b1a8ecffbda5d4d5023932146573 Mon Sep 17 00:00:00 2001 From: Wauplin Date: Tue, 5 Mar 2024 11:47:30 +0100 Subject: [PATCH 2/4] allow loading torch model to device --- bindings/python/py_src/safetensors/torch.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py index 0b18d8db..2c4ce20c 100644 --- a/bindings/python/py_src/safetensors/torch.py +++ b/bindings/python/py_src/safetensors/torch.py @@ -173,7 +173,7 @@ def save_model( raise ValueError(msg) -def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict=True) -> Tuple[List[str], List[str]]: +def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device="cpu") -> Tuple[List[str], List[str]]: """ Loads a given filename onto a torch model. This method exists specifically to avoid tensor sharing issues which are @@ -187,6 +187,9 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict strict (`bool`, *optional*, defaults to True): Whether to fail if you're missing keys or having unexpected ones. When false, the function simply returns missing and unexpected names. + device (`Dict[str, any]`, *optional*, defaults to `cpu`): + The device where the tensors need to be located after load. + available options are all regular torch device locations. Returns: `(missing, unexpected): (List[str], List[str])` @@ -194,7 +197,7 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict `unexpected` are names that are on the file, but weren't used during the load. """ - state_dict = load_file(filename) + state_dict = load_file(filename, device=device) model_state_dict = model.state_dict() to_removes = _remove_duplicate_names(model_state_dict, preferred_names=state_dict.keys()) missing, unexpected = model.load_state_dict(state_dict, strict=False) @@ -290,7 +293,7 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torc The name of the file which contains the tensors device (`Dict[str, any]`, *optional*, defaults to `cpu`): The device where the tensors need to be located after load. - available options are all regular torch device locations + available options are all regular torch device locations. Returns: `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor` From d723ff7be806a6a696763b80554579d99791524b Mon Sep 17 00:00:00 2001 From: Wauplin Date: Tue, 5 Mar 2024 16:35:00 +0100 Subject: [PATCH 3/4] fix device type --- bindings/python/py_src/safetensors/paddle.py | 2 +- bindings/python/py_src/safetensors/torch.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bindings/python/py_src/safetensors/paddle.py b/bindings/python/py_src/safetensors/paddle.py index b242237a..cec36866 100644 --- a/bindings/python/py_src/safetensors/paddle.py +++ b/bindings/python/py_src/safetensors/paddle.py @@ -105,7 +105,7 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, padd Args: filename (`str`, or `os.PathLike`)): The name of the file which contains the tensors - device (`Dict[str, any]`, *optional*, defaults to `cpu`): + device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`): The device where the tensors need to be located after load. available options are all regular paddle device locations diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py index 2c4ce20c..5463e7d5 100644 --- a/bindings/python/py_src/safetensors/torch.py +++ b/bindings/python/py_src/safetensors/torch.py @@ -187,7 +187,7 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict strict (`bool`, *optional*, defaults to True): Whether to fail if you're missing keys or having unexpected ones. When false, the function simply returns missing and unexpected names. - device (`Dict[str, any]`, *optional*, defaults to `cpu`): + device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`): The device where the tensors need to be located after load. available options are all regular torch device locations. @@ -291,7 +291,7 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torc Args: filename (`str`, or `os.PathLike`): The name of the file which contains the tensors - device (`Dict[str, any]`, *optional*, defaults to `cpu`): + device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`): The device where the tensors need to be located after load. available options are all regular torch device locations. From 3bad1e2618c16a5db70a4e02b4fcf1ebb8641a95 Mon Sep 17 00:00:00 2001 From: Wauplin Date: Thu, 7 Mar 2024 14:04:34 +0100 Subject: [PATCH 4/4] update device type --- bindings/python/py_src/safetensors/torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py index 5463e7d5..5d98bac4 100644 --- a/bindings/python/py_src/safetensors/torch.py +++ b/bindings/python/py_src/safetensors/torch.py @@ -173,7 +173,7 @@ def save_model( raise ValueError(msg) -def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device="cpu") -> Tuple[List[str], List[str]]: +def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu") -> Tuple[List[str], List[str]]: """ Loads a given filename onto a torch model. This method exists specifically to avoid tensor sharing issues which are @@ -187,7 +187,7 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict strict (`bool`, *optional*, defaults to True): Whether to fail if you're missing keys or having unexpected ones. When false, the function simply returns missing and unexpected names. - device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`): + device (`Union[str, int]`, *optional*, defaults to `cpu`): The device where the tensors need to be located after load. available options are all regular torch device locations. @@ -284,14 +284,14 @@ def save_file( serialize_file(_flatten(tensors), filename, metadata=metadata) -def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torch.Tensor]: +def load_file(filename: Union[str, os.PathLike], device: Union[str, int] = "cpu") -> Dict[str, torch.Tensor]: """ Loads a safetensors file into torch format. Args: filename (`str`, or `os.PathLike`): The name of the file which contains the tensors - device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`): + device (`Union[str, int]`, *optional*, defaults to `cpu`): The device where the tensors need to be located after load. available options are all regular torch device locations.