-
Notifications
You must be signed in to change notification settings - Fork 201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for device
in safetensors.torch.load_model
#449
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
device
in safetensors.torch.load_model
When false, the function simply returns missing and unexpected names. | ||
device (`Dict[str, any]`, *optional*, defaults to `cpu`): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
device (`Dict[str, any]`, *optional*, defaults to `cpu`): | |
device (`str`, *optional*, defaults to `cpu`): |
is it dict or str ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took the exact same definition as in load_file docstring:
device (`Dict[str, any]`, *optional*, defaults to `cpu`): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
defaults to
cpu
I see. I guess it is either str
or Union[Dict[str, any], str]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made the change in d723ff7 :) (and used the Union[Dict[str, any], str]
type).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Go down to the rust code here:
#[derive(Debug, Clone, PartialEq, Eq)]
enum Device {
Cpu,
Cuda(usize),
Mps,
Npu(usize),
Xpu(usize),
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is it Union[Dict[str, any], str]
rather than str
?
Doesn't the enum mean that value is str rather than dict?
enum Device {
Cpu,
Cuda(usize),
Mps,
Npu(usize),
Xpu(usize),
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes sorry so I went back to the source code and it looks like the expected type should be Union[str, int]
. Here is the code:
- expecting string
- expecting int (which has been discussed/added in Some PyO3 specifics. #35 (comment)) => defaults to "Cuda:<int_value>".
I have updated the types accordingly in 3bad1e2. Let's wait for @Narsil's return just to be sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
safetensors.torch.load_file
has a "device" parameter to load the tensors directly to the correct device. This PR adds support for this parameter insafetensors.torch.load_model
too.(also fix
device
type -see #449 (comment))