Adding support for wrapper tensor subclass for safetensor #516

Closed
wants to merge 2 commits

Conversation

jerryzh168

Summary:
Similar to huggingface/huggingface_hub#2440, we want to allow safetensors to handle wrapper tensor subclasses. We mainly added:

  1. tensor storage size: computed by flattening the wrapper tensor subclass and recursively summing the storage sizes of all sub-tensors
  2. storage_ptr: computed by constructing a tuple from the "storage_ptr" of the flattened tensors; this can be a nested tuple of ints, e.g. ((1, 2), 3, (4, (5, 6))); a rough sketch of both follows below
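
Roughly, the two helpers could look like this (a simplified sketch, not the exact diff in this PR; it assumes the wrapper subclass implements PyTorch's __tensor_flatten__ protocol for traceable wrapper tensor subclasses):

import torch

def _inner_tensors(t):
    # __tensor_flatten__ returns the attribute names of the inner tensors
    # plus opaque metadata; only the inner tensors matter here
    attrs, _ = t.__tensor_flatten__()
    return [getattr(t, name) for name in attrs]

def storage_size(t: torch.Tensor) -> int:
    # 1. storage size: recurse into the wrapper and sum the storage sizes
    #    of all plain inner tensors
    if hasattr(t, "__tensor_flatten__"):
        return sum(storage_size(inner) for inner in _inner_tensors(t))
    return t.untyped_storage().nbytes()

def storage_ptr(t: torch.Tensor):
    # 2. storage ptr: a (possibly nested) tuple of the inner storage
    #    pointers, e.g. ((1, 2), 3, (4, (5, 6)))
    if hasattr(t, "__tensor_flatten__"):
        return tuple(storage_ptr(inner) for inner in _inner_tensors(t))
    return t.untyped_storage().data_ptr()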

Test Plan:
Added a test in test_pt_model.py; will also test manually.

@Narsil
Collaborator

Narsil commented Sep 3, 2024

Re-iterating what I've mentioned in private.

Private functions/modules will never get used here. They are private in torch; we cannot rely on them (that's why they are marked private).

Reading more about what those wrapper classes are, it feels like something users should handle. safetensors doesn't know about anything other than dense, native-dtype tensors. Anything more complex is currently handled in user space.

This seems to be the same case here.

@jerryzh168
Author

@Narsil I updated the PR to use try/except so it won't affect existing code. If it's a strong requirement that no private function can be used here, then I think we'll wait for the larger refactor of the sharding logic and will have to use non-safe serialization for now.

@Narsil
Collaborator

Narsil commented Sep 4, 2024

it's a strong requirement that no private function can be used here

Yes, it's a strong requirement. Private interfaces can change and break arbitrarily, and a breaking change doesn't necessarily raise an exception.
If all tensors can be flattened, you could simply save those flattened tensors, no?

@jerryzh168
Author

jerryzh168 commented Sep 4, 2024

it's a strong requirement that no private function can be used here

Yes, it's a strong requirement. Private interfaces can change and break arbitrarily, and a breaking change doesn't necessarily raise an exception. If all tensors can be flattened, you could simply save those flattened tensors, no?

I think we can remove the private function by defining it in safetensors directly; it's a simple function.

If all tensors can be flattened, you could simply save those flattened tensors, no?

We can, but I think it will complicate the UX with safetensors a lot, since we also need to recover the wrapper tensor subclass from the flattened tensors, and I'm not sure that's always possible to do:

Right now we have a util to flatten wrapper tensor subclasses into plain tensors: unwrap_tensor_subclass

# save
float_model = ...
from torchao.quantization import quantize_
quantize_(float_model, some_config)
# float_model now becomes a model with wrapper tensor subclasses as weights

from torchao.utils import unwrap_tensor_subclass
unwrap_tensor_subclass(float_model)
float_model.save_pretrained(safe_serialization=True)


# load
# need to quantize the float model with the same config
float_model = ...
# here the config has to match the config of the quantized weights
quantize_(float_model, same_config)

# also have to unwrap it because the saved weights are flattened
from torchao.utils import unwrap_tensor_subclass
unwrap_tensor_subclass(float_model)

# (pseudocode) load the flattened weights back into the prepared model
float_model.load_pretrained(...)

# also we then have to work with flattened weights everywhere, which may not always make sense

But if we supported safetensors serialization for wrapper tensor subclasses directly, the UX would be:

# save
from torchao.quantization import quantize_
quantize_(model_with_wrapper_tensor_subclass, some_config)
model_with_wrapper_tensor_subclass.save_pretrained(safe_serialization=True)


# load
float_model = ...
float_model.load_pretrained(...)

That would be a much simpler UX, I think.

A wrapper tensor subclass is really just a generalized torch.Tensor, and it's also the official extension point of PyTorch, so I feel supporting it directly would make sense.
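
For context, a minimal wrapper tensor subclass looks roughly like this (a toy sketch: the class name and fields are made up for illustration, and a real subclass would also implement __torch_dispatch__ so ops actually work on it):

import torch

class PackedWeight(torch.Tensor):  # hypothetical toy subclass, for illustration only
    # looks like a dense float tensor from the outside, but internally holds
    # packed integer data plus per-group scales

    @staticmethod
    def __new__(cls, packed, scale, shape):
        # the wrapper itself owns no storage; the real data lives in the inner plain tensors
        return torch.Tensor._make_wrapper_subclass(cls, shape, dtype=torch.float32)

    def __init__(self, packed, scale, shape):
        self.packed = packed
        self.scale = scale

    def __tensor_flatten__(self):
        # names of the inner plain tensors, plus metadata needed to rebuild the wrapper
        return ["packed", "scale"], {"shape": tuple(self.shape)}

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
        return cls(inner_tensors["packed"], inner_tensors["scale"], meta["shape"])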

@Narsil please let me know what you think about this.

@Narsil
Collaborator

Narsil commented Sep 5, 2024

I think we can remove the private function by defining it in safetensors directly; it's a simple function.

It's perfectly reasonable for us to own part of the complexity (we already own the logic for determining tensor sharing in torch).

since we also need to recover the wrapper tensor subclass from the flattened tensors, and I'm not sure that's always possible to do:

There is a fundamental conflict here. In order to know that some tensor was wrapped, we would need to store that information in the file, and there's no way to communicate that in the format at the moment, right?
safetensors, not being torch-only, needs to take the burden on non-torch readers into account when implementing things.
Currently, mostly hardware-defined types have made it in. Quantized types seem to be ever-evolving, there's no clear way to store/read them, and they don't seem to share any invariants among them (q2k, bnb, awq, column-grouped, row-grouped, non-linear groupings, exl2, etc.). Most quantized safetensors formats I'm aware of store their own information somewhere other than the safetensors file, and simply reinterpret some raw byte tensor in a meaningful manner later.

My understanding is that the wrappers can always be cast to dense tensors, which makes them eligible for the same solution.
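
A hedged sketch of that user-space pattern (the key names and metadata layout here are made up, not a standard): save the flattened plain tensors with save_file, put your own quantization description in the string-only header metadata, and rebuild the wrapper subclass yourself on load.

import json
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# flattened plain tensors produced from the wrapper subclass (dummy data here)
packed = torch.randint(0, 255, (64, 32), dtype=torch.uint8)
scale = torch.rand(64)

# save: plain tensors, plus user-defined metadata in the header (str -> str only)
meta = {"quant": json.dumps({"linear.weight": {"scheme": "int4", "group_size": 32}})}
save_file({"linear.weight.packed": packed, "linear.weight.scale": scale},
          "model.safetensors", metadata=meta)

# load: read the plain tensors back and reinterpret them in user code
with safe_open("model.safetensors", framework="pt") as f:
    quant = json.loads(f.metadata()["quant"])
    packed = f.get_tensor("linear.weight.packed")
    scale = f.get_tensor("linear.weight.scale")
# ...reconstruct the wrapper tensor subclass from packed, scale and quant here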

We also have save_model/load_model functions here (and derivatives in transformers), where the extra information is stored directly in the model, so the model itself can reinterpret the flattened tensors as it sees fit. Maybe the complexity should be owned directly in from_pretrained/save_pretrained (transformers is not using save_model/load_model because it owns more complexity than those functions do, e.g. handling the meta device)?
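
For reference, a minimal use of those save_model/load_model helpers from safetensors.torch (the Linear module is just a stand-in):

import torch
from safetensors.torch import save_model, load_model

model = torch.nn.Linear(4, 4)
save_model(model, "linear.safetensors")   # deduplicates shared tensors before writing

fresh = torch.nn.Linear(4, 4)
load_model(fresh, "linear.safetensors")   # loads the weights back into an existing module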

A wrapper tensor subclass is really just a generalized torch.Tensor, and it's also the official extension point of PyTorch, so I feel supporting it directly would make sense.

Sparse and ragged tensors are also official parts of torch; they have low usage, were not implemented in safetensors either, and so far not many issues have been raised.

@jerryzh168
Author

@Narsil thanks for the detailed reply! I think I understand better now. I had only thought about working around the sharding logic, and hadn't thought about actually serializing the wrapper tensor subclasses in the safetensors format before. Thinking about actually storing a wrapper tensor subclass in the safetensors format, it seems like we would have to define something like https://github.com/neuralmagic/compressed-tensors, as you mentioned. This might be a bit too complicated without a clear benefit, I think, since we will potentially support a lot of variations of quantized tensors in the future. In that case, it seems the most reasonable thing is just to not support serialization of wrapper tensor subclasses in the safetensors format at this point?

jerryzh168 closed this Sep 16, 2024