Adding support for wrapper tensor subclass for safetensor #516
Conversation
Summary:
Similar to huggingface/huggingface_hub#2440, we want to allow safetensors to handle wrapper tensor subclasses. We mainly added:
1. Tensor storage size: computed by flattening the wrapper tensor subclass and recursively summing the storage sizes of all sub-tensors.
2. storage_ptr: constructed as a tuple of the "storage_ptr" of each flattened tensor; this can be a nested tuple of tuples of ints, e.g. ((1, 2), 3, (4, (5, 6)))

Test Plan:
Added a test in test_pt_model.py; will also test manually.

Reviewers:

Subscribers:

Tasks:

Tags:
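The two additions above can be illustrated with a plain-Python sketch. The class names here are toy stand-ins for torch.Tensor and a wrapper subclass (not the actual safetensors or PyTorch code); the point is only the recursion over nested wrappers:

```python
# Toy model: a "plain" tensor has a storage size and a storage pointer;
# a "wrapper" holds inner tensors, which may themselves be wrappers.
class PlainTensor:
    def __init__(self, nbytes, data_ptr):
        self.nbytes = nbytes
        self.data_ptr = data_ptr

class WrapperTensor:
    def __init__(self, *inner):
        self.inner = inner  # sub-tensors, possibly wrappers themselves

def storage_size(t):
    """1. Flatten recursively and sum the storage sizes of all sub-tensors."""
    if isinstance(t, WrapperTensor):
        return sum(storage_size(s) for s in t.inner)
    return t.nbytes

def storage_ptr(t):
    """2. Build a (possibly nested) tuple of the flattened storage pointers."""
    if isinstance(t, WrapperTensor):
        return tuple(storage_ptr(s) for s in t.inner)
    return t.data_ptr

w = WrapperTensor(WrapperTensor(PlainTensor(8, 1), PlainTensor(8, 2)),
                  PlainTensor(4, 3))
print(storage_size(w))  # 20
print(storage_ptr(w))   # ((1, 2), 3)
```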
Re-iterating what I've mentioned in private: private functions/modules will never get used here. They are private in torch, so we cannot rely on them (that's why they are marked private). Reading more about what those wrapper classes are, it feels like something users should handle. safetensors doesn't know about anything other than dense, native-dtype tensors. Anything more complex is currently handled in user space. This seems like the same situation here.
@Narsil I updated the PR to use try/catch so it won't affect old code. If it's a strong requirement that no private function be used here, then I think we'll wait for the larger refactor of the sharding logic and will have to use non-safe serialization for now.
Yes, it's a strong requirement. Private interfaces can change and break arbitrarily, and some breakage doesn't necessarily surface as an exception.
I think we can remove the private function by defining the function in safetensors directly; it's a simple function.
We can, but I think it would complicate the UX with safetensors a lot, since we'd also need to recover the wrapper tensor subclass from the flattened tensors, and I'm not sure that's possible. Right now we have a util to flatten wrapper tensor subclasses to plain tensors:
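The snippet from the original comment was not captured here. As a hypothetical sketch of such a util, assuming the wrapper follows PyTorch's `__tensor_flatten__` protocol (the subclass reports the attribute names of its inner tensors plus an opaque context), a recursive flattener might look like the following; `ToySubclass` and `flatten_to_plain` are illustrative names, not the actual implementation:

```python
def flatten_to_plain(prefix, tensor, out):
    """Recursively flatten a wrapper tensor subclass into a dict of
    plain tensors keyed by dotted attribute paths."""
    if hasattr(type(tensor), "__tensor_flatten__"):
        attr_names, _ctx = tensor.__tensor_flatten__()
        for name in attr_names:
            flatten_to_plain(f"{prefix}.{name}", getattr(tensor, name), out)
    else:
        out[prefix] = tensor  # plain tensor: keep as-is
    return out

# Toy class implementing the protocol, standing in for a real subclass:
class ToySubclass:
    def __init__(self, scale, data):
        self.scale = scale
        self.data = data
    def __tensor_flatten__(self):
        return ["scale", "data"], None

print(flatten_to_plain("w", ToySubclass(0.5, [1, 2]), {}))
# {'w.scale': 0.5, 'w.data': [1, 2]}
```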
but if we supported safetensors serialization directly, the UX would be:
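The snippet from the original comment was not captured here. As a hypothetical pseudocode sketch (`MyQuantizedTensor` is an assumed wrapper subclass, and round-tripping a wrapper through `save_file`/`load_file` is the proposed behavior, not current safetensors functionality), the direct UX might look like:

```python
from safetensors.torch import save_file, load_file

quantized = MyQuantizedTensor(...)  # some wrapper tensor subclass
save_file({"weight": quantized}, "model.safetensors")
state = load_file("model.safetensors")  # would recover the wrapper subclass
```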
It would be a much simpler UX, I think. A wrapper tensor subclass is really just a generalized torch.Tensor, and it's also the official extension point of PyTorch, so I feel supporting this directly would make sense. @Narsil please let me know what you think about this.
It is perfectly reasonable for us to own part of the complexity (we already own the logic for determining sharing of tensors in torch).
There is a fundamental conflict here. In order to know that some tensor was wrapped, we would need to store that information in the file, and there's no fundamental way to communicate that at the moment, right? My understanding is that the wrappers can always be cast to dense tensors, which makes them eligible for the same solution. We also have
Sparse and ragged tensors are also official members of torch, have low usage, and were not implemented in safetensors either, and so far not many issues have been raised.
@Narsil thanks for the detailed reply! I think I understand better now. I had only thought about working around the sharding logic, not about actually serializing the wrapper tensor subclasses in the safetensors format. Thinking about actually storing a wrapper tensor subclass in safetensors format, it seems we would have to define something like https://github.com/neuralmagic/compressed-tensors, as you mentioned. This might be a bit too complicated without a clear benefit, I think, since we will potentially support many variations of quantized tensors in the future. In that case, it seems the most reasonable thing is simply not to support serialization of wrapper tensor subclasses in the safetensors format at this point?