Add support for device in safetensors.torch.load_model #449

Merged 4 commits on Apr 15, 2024
2 changes: 1 addition & 1 deletion bindings/python/py_src/safetensors/paddle.py
@@ -105,7 +105,7 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, padd
Args:
filename (`str`, or `os.PathLike`)):
The name of the file which contains the tensors
- device (`Dict[str, any]`, *optional*, defaults to `cpu`):
+ device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
available options are all regular paddle device locations

15 changes: 9 additions & 6 deletions bindings/python/py_src/safetensors/torch.py
@@ -173,7 +173,7 @@ def save_model(
raise ValueError(msg)


- def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict=True) -> Tuple[List[str], List[str]]:
+ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu") -> Tuple[List[str], List[str]]:
"""
Loads a given filename onto a torch model.
This method exists specifically to avoid tensor sharing issues which are
@@ -185,16 +185,19 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict
filename (`str`, or `os.PathLike`):
The filename location to load the file from.
strict (`bool`, *optional*, defaults to True):
- Wether to fail if you're missing keys or having unexpected ones
+ Whether to fail if you're missing keys or having unexpected ones.
When false, the function simply returns missing and unexpected names.
+ device (`Union[str, int]`, *optional*, defaults to `cpu`):
+ The device where the tensors need to be located after load.
+ available options are all regular torch device locations.

Returns:
`(missing, unexpected): (List[str], List[str])`
`missing` are names in the model which were not modified during loading
`unexpected` are names that are on the file, but weren't used during
the load.
"""
- state_dict = load_file(filename)
+ state_dict = load_file(filename, device=device)
model_state_dict = model.state_dict()
to_removes = _remove_duplicate_names(model_state_dict, preferred_names=state_dict.keys())
missing, unexpected = model.load_state_dict(state_dict, strict=False)
@@ -281,16 +284,16 @@ def save_file(
serialize_file(_flatten(tensors), filename, metadata=metadata)


- def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torch.Tensor]:
+ def load_file(filename: Union[str, os.PathLike], device: Union[str, int] = "cpu") -> Dict[str, torch.Tensor]:
"""
Loads a safetensors file into torch format.

Args:
filename (`str`, or `os.PathLike`):
The name of the file which contains the tensors
- device (`Dict[str, any]`, *optional*, defaults to `cpu`):
+ device (`Union[str, int]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
- available options are all regular torch device locations
+ available options are all regular torch device locations.

Returns:
`Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor`