add support to complex64 dtype #387
Conversation
And a simple unit test for reference 🤗

```python
import torch, unittest
from safetensors.torch import save_file, load_file


class TestCase(unittest.TestCase):
    def assertTensorEqual(self, tensors1, tensors2, equality_fn):
        self.assertEqual(tensors1.keys(), tensors2.keys(), "tensor keys don't match")
        for k, v1 in tensors1.items():
            v2 = tensors2[k]
            self.assertTrue(equality_fn(v1, v2), f"{k} tensors are different")

    def test_torch_example(self):
        tensors = {
            "a": torch.zeros((2, 2)),
            "b": torch.zeros((2, 3), dtype=torch.uint8),
            "c": torch.randn(2, 2, dtype=torch.complex64),
        }
        tensors2 = tensors.copy()
        save_file(tensors, "./out.safetensors")
        loaded = load_file("./out.safetensors")
        self.assertTensorEqual(tensors2, loaded, torch.allclose)


hug = TestCase()
hug.test_torch_example()
```
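As a side note for readers who want to verify what dtype string a saved file actually records: the safetensors format starts with an 8-byte little-endian unsigned integer giving the byte length of a JSON header, which maps tensor names to their dtype, shape, and data offsets. Here is a minimal pure-Python sketch that reads that header; the hand-built demo file uses an illustrative `"F32"` entry (real files are produced by `save_file`), since the exact dtype string this PR assigns to complex64 is not shown here.

```python
import json
import struct


def read_header(path):
    """Read the JSON header of a safetensors file.

    The file layout is: 8-byte little-endian u64 header length,
    then the JSON header, then the raw tensor bytes.
    """
    with open(path, "rb") as f:
        (length,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(length))


# Self-contained demo: write a minimal file by hand, then read it back.
# (The "F32" entry is illustrative; the dtype string for a complex tensor
# depends on what this PR defines.)
header = {"a": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]}}
payload = json.dumps(header).encode()
with open("demo.safetensors", "wb") as f:
    f.write(struct.pack("<Q", len(payload)))
    f.write(payload)
    f.write(b"\x00" * 16)  # 4 float32 zeros as the tensor payload

print(read_header("demo.safetensors")["a"]["dtype"])  # F32
```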
Could you please trigger the CI, or let me know if you have any questions about this PR? 🤗
Also, I tested the GPU load mode, and everything works as expected 🤗

```python
from safetensors.torch import save_file, load_file, safe_open
import torch
import os

os.environ["SAFETENSORS_FAST_GPU"] = "1"

x = torch.randn(2, 2, dtype=torch.complex64)
tensor = {"torch_complex64": x}
save_file(tensor, "model.safetensors")
loaded = load_file("model.safetensors")

tensors = {}
with safe_open("model.safetensors", framework="pt", device="cuda") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)

print(loaded)
print(tensors)
```

```
❯ python load.py
{'torch_complex64': tensor([[-0.6954-0.4037j, -0.3998+0.1616j],
        [-0.0983+0.5514j,  0.2918-0.1942j]])}
{'torch_complex64': tensor([[-0.6954-0.4037j, -0.3998+0.1616j],
        [-0.0983+0.5514j,  0.2918-0.1942j]], device='cuda:0')}
```
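For anyone stuck on a safetensors build without complex64 support, a common workaround is to reinterpret the complex tensor as a real one with a trailing axis of length 2 (real part, imaginary part), store that, and view it back as complex after loading. In PyTorch this is `torch.view_as_real` / `torch.view_as_complex`; the sketch below shows the same idea in NumPy so it stays self-contained (the round trip relies only on complex64 being stored as interleaved float32 pairs).

```python
import numpy as np

# A complex64 array is laid out in memory as interleaved float32 pairs
# (real, imag), so it can be reinterpreted losslessly in both directions.
rng = np.random.default_rng(0)
x = (rng.standard_normal((2, 2)) + 1j * rng.standard_normal((2, 2))).astype(np.complex64)

# View as float32 with a trailing axis of 2 -- storable by plain-float formats.
as_real = x.view(np.float32).reshape(*x.shape, 2)

# After "loading", view the float32 pairs back as complex64.
restored = as_real.view(np.complex64).reshape(x.shape)

assert restored.dtype == np.complex64
assert np.array_equal(restored, x)  # bit-exact round trip
```

The torch equivalent would save `torch.view_as_real(t).contiguous()` and apply `torch.view_as_complex` after loading.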
@Narsil Thanks for your feedback. Looking at the CI results: the s390x failure is a network error, unrelated to the functionality of this change. For the formatting issue, I have pushed a new commit with a fix; it may need you to re-trigger the CI, thanks. According to #77 (comment), the PyTorch docs list the supported dtypes:

.. _tensor-doc:
torch.Tensor
===================================
A :class:`torch.Tensor` is a multi-dimensional matrix containing elements of
a single data type.
Data types
----------
Torch defines 10 tensor types with CPU and GPU variants which are as follows:
======================================= =========================================== ============================= ================================
Data type dtype CPU tensor GPU tensor
======================================= =========================================== ============================= ================================
32-bit floating point ``torch.float32`` or ``torch.float`` :class:`torch.FloatTensor` :class:`torch.cuda.FloatTensor`
64-bit floating point ``torch.float64`` or ``torch.double`` :class:`torch.DoubleTensor` :class:`torch.cuda.DoubleTensor`
16-bit floating point [1]_ ``torch.float16`` or ``torch.half`` :class:`torch.HalfTensor` :class:`torch.cuda.HalfTensor`
16-bit floating point [2]_ ``torch.bfloat16`` :class:`torch.BFloat16Tensor` :class:`torch.cuda.BFloat16Tensor`
32-bit complex ``torch.complex32`` or ``torch.chalf``
64-bit complex ``torch.complex64`` or ``torch.cfloat``
128-bit complex ``torch.complex128`` or ``torch.cdouble``
8-bit integer (unsigned) ``torch.uint8`` :class:`torch.ByteTensor` :class:`torch.cuda.ByteTensor`
8-bit integer (signed) ``torch.int8`` :class:`torch.CharTensor` :class:`torch.cuda.CharTensor`
16-bit integer (signed) ``torch.int16`` or ``torch.short`` :class:`torch.ShortTensor` :class:`torch.cuda.ShortTensor`
32-bit integer (signed) ``torch.int32`` or ``torch.int`` :class:`torch.IntTensor` :class:`torch.cuda.IntTensor`
64-bit integer (signed) ``torch.int64`` or ``torch.long`` :class:`torch.LongTensor` :class:`torch.cuda.LongTensor`
Boolean ``torch.bool`` :class:`torch.BoolTensor` :class:`torch.cuda.BoolTensor`
quantized 8-bit integer (unsigned) ``torch.quint8`` :class:`torch.ByteTensor` /
quantized 8-bit integer (signed) ``torch.qint8`` :class:`torch.CharTensor` /
quantized 32-bit integer (signed) ``torch.qint32`` :class:`torch.IntTensor` /
quantized 4-bit integer (unsigned) [3]_ ``torch.quint4x2`` :class:`torch.ByteTensor` /
======================================= =========================================== ============================= ================================
.. [1]
Sometimes referred to as binary16: uses 1 sign, 5 exponent, and 10
significand bits. Useful when precision is important at the expense of range.
.. [2]
Sometimes referred to as Brain Floating Point: uses 1 sign, 8 exponent, and 7
significand bits. Useful when range is important, since it has the same
number of exponent bits as ``float32``
.. [3]
quantized 4-bit integer is stored as an 8-bit signed integer. Currently it's only supported in the EmbeddingBag operator.
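For the 64-bit complex row above, a minimal pure-Python sketch of the memory layout: each complex64 element occupies 8 bytes, an interleaved pair of float32s (real part first, then imaginary part), which matches NumPy's and PyTorch's in-memory representation; little-endian byte order is assumed here for the serialized form.

```python
import struct


def pack_complex64(values):
    # Each element: two little-endian float32s, real then imaginary.
    return b"".join(struct.pack("<ff", z.real, z.imag) for z in values)


def unpack_complex64(buf):
    return [complex(r, i) for r, i in struct.iter_unpack("<ff", buf)]


# Values chosen to be exactly representable in float32.
data = [1.5 + 2.0j, -0.25 + 3.0j]
raw = pack_complex64(data)
assert len(raw) == 8 * len(data)  # 8 bytes per complex64 element
assert unpack_complex64(raw) == data
```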
@Narsil Why did the bot close this PR? Can you explain?
Because it's stale. I don't think the situation has really evolved. Do you have any models that use complex64 and see significant usage? That would definitely help the case.
Hi Hugging Face team,

Complex numbers are an important part of real-world applications, especially in medical imaging, e.g. MRI k-space data. However, I discovered that it is not possible to store complex data from PyTorch with safetensors, even though PyTorch itself has supported it since v1.9.0. Therefore, I am submitting this PR to add support for this feature 🤗

The sample code is as follows: