add support to complex64 dtype #387

Closed · wants to merge 3 commits

Conversation

@Mon-ius commented Nov 18, 2023

Hi Huggingface guys

Complex numbers are an important part of real-world applications, especially in the field of medical imaging, such as MRI k-space data. However, I discovered that it is not possible to store complex data from PyTorch with SafeTensors, even though PyTorch itself has supported complex dtypes since v1.9.0.

Since v1.6 (28 July 2020), PyTorch has supported complex vectors and complex gradients as a beta feature.

Therefore, I attempted to submit a PR to support this feature 🤗

The sample code is as follows:

# PyTorch example

import torch
from safetensors.torch import save_file

x = torch.randn(2,2, dtype=torch.complex64)

tensors = {
    "torch_complex64": x
}
save_file(tensors, "model.safetensors")
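
With the dtype added by this PR, the file can be read back and checked for an exact round trip; a minimal sketch continuing from the snippet above (same file and key):

from safetensors.torch import load_file

# Load the file written above and verify the complex dtype survives the round trip.
loaded = load_file("model.safetensors")
restored = loaded["torch_complex64"]

assert restored.dtype == torch.complex64
assert torch.equal(restored, x)  # a plain save/load should be bit-exact
print(restored)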

@Mon-ius (Author) commented Nov 18, 2023

And a simple unit test for reference 🤗

import torch, unittest
from safetensors.torch import save_file, load_file

class TestCase(unittest.TestCase):
    def assertTensorEqual(self, tensors1, tensors2, equality_fn):
        self.assertEqual(tensors1.keys(), tensors2.keys(), "tensor keys don't match")

        for k, v1 in tensors1.items():
            v2 = tensors2[k]

            self.assertTrue(equality_fn(v1, v2), f"{k} tensors are different")

    def test_torch_example(self):
        tensors = {
            "a": torch.zeros((2, 2)),
            "b": torch.zeros((2, 3), dtype=torch.uint8),
            "c": torch.randn(2,2, dtype=torch.complex64),
        }

        tensors2 = tensors.copy()

        save_file(tensors, "./out.safetensors")

        loaded = load_file("./out.safetensors")
        self.assertTensorEqual(tensors2, loaded, torch.allclose)

hug = TestCase()
hug.test_torch_example()
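
torch.allclose handles complex tensors directly, but for clarity the real and imaginary parts can also be compared separately; a self-contained sketch with an arbitrary file name, again assuming the complex64 support from this PR:

import torch
from safetensors.torch import save_file, load_file

original = torch.randn(2, 2, dtype=torch.complex64)
save_file({"c": original}, "./complex_roundtrip.safetensors")
loaded = load_file("./complex_roundtrip.safetensors")["c"]

# Checking the real and imaginary planes separately makes failures easier to read.
assert loaded.dtype == torch.complex64
assert torch.allclose(loaded.real, original.real)
assert torch.allclose(loaded.imag, original.imag)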

@Mon-ius (Author) commented Nov 19, 2023

Could you please trigger the CI? Or do you have any questions regarding this PR? 🤗

@Mon-ius (Author) commented Nov 20, 2023

Also, I tested the GPU loading path, and everything works as expected 🤗

from safetensors.torch import save_file, load_file, safe_open
import torch
import os

os.environ["SAFETENSORS_FAST_GPU"] = "1"
x = torch.randn(2, 2, dtype=torch.complex64)

tensor = {"torch_complex64": x}
save_file(tensor, "model.safetensors")

loaded = load_file("model.safetensors")

tensors = {}
with safe_open("model.safetensors", framework="pt", device="cuda") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)

print(loaded)
print(tensors)
❯ python load.py
{'torch_complex64': tensor([[-0.6954-0.4037j, -0.3998+0.1616j],
        [-0.0983+0.5514j,  0.2918-0.1942j]])}
{'torch_complex64': tensor([[-0.6954-0.4037j, -0.3998+0.1616j],
        [-0.0983+0.5514j,  0.2918-0.1942j]], device='cuda:0')}
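
To double-check that the CPU and GPU loading paths agree, the two results above can be compared directly; a small check continuing from the script above (it assumes a CUDA device is available):

# Compare the CPU-loaded and GPU-loaded copies of each tensor; both come from
# the same file, so they should be identical after moving the GPU copy back.
for key in loaded:
    cpu_t = loaded[key]
    gpu_t = tensors[key].cpu()
    assert cpu_t.dtype == gpu_t.dtype == torch.complex64
    assert torch.equal(cpu_t, gpu_t)
print("CPU and GPU loads match")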

@Narsil (Collaborator) commented Nov 20, 2023

Thanks a lot for this contribution.

Related issues #77 and #256

In general, I think safetensors needs a general rule about which dtypes should make it into the format.
I'll try to explain here: #389

I'll try to gather a bit of feedback on it before merging.

@Mon-ius (Author) commented Nov 21, 2023

@Narsil Thanks for your feedback.

According to the CI result,

For S390X, the failure is caused by a network error and is not related to the functionality of this change.

For the formatting issue, I submitted a new commit with a fix; it may need you to re-trigger the CI, thanks.

According to #77 (comment), PyTorch officially supports all of the complex datatypes, and they are not hardware-specific; for example, complex64 is available as torch.complex64/torch.cfloat. Details are in the PyTorch source:

.. _tensor-doc:

torch.Tensor
===================================

A :class:`torch.Tensor` is a multi-dimensional matrix containing elements of
a single data type.


Data types
----------

Torch defines 10 tensor types with CPU and GPU variants which are as follows:

======================================= =========================================== ============================= ================================
Data type                               dtype                                       CPU tensor                    GPU tensor
======================================= =========================================== ============================= ================================
32-bit floating point                   ``torch.float32`` or ``torch.float``        :class:`torch.FloatTensor`    :class:`torch.cuda.FloatTensor`
64-bit floating point                   ``torch.float64`` or ``torch.double``       :class:`torch.DoubleTensor`   :class:`torch.cuda.DoubleTensor`
16-bit floating point [1]_              ``torch.float16`` or ``torch.half``         :class:`torch.HalfTensor`     :class:`torch.cuda.HalfTensor`
16-bit floating point [2]_              ``torch.bfloat16``                          :class:`torch.BFloat16Tensor` :class:`torch.cuda.BFloat16Tensor`
32-bit complex                          ``torch.complex32`` or ``torch.chalf``
64-bit complex                          ``torch.complex64`` or ``torch.cfloat``
128-bit complex                         ``torch.complex128`` or ``torch.cdouble``
8-bit integer (unsigned)                ``torch.uint8``                             :class:`torch.ByteTensor`     :class:`torch.cuda.ByteTensor`
8-bit integer (signed)                  ``torch.int8``                              :class:`torch.CharTensor`     :class:`torch.cuda.CharTensor`
16-bit integer (signed)                 ``torch.int16`` or ``torch.short``          :class:`torch.ShortTensor`    :class:`torch.cuda.ShortTensor`
32-bit integer (signed)                 ``torch.int32`` or ``torch.int``            :class:`torch.IntTensor`      :class:`torch.cuda.IntTensor`
64-bit integer (signed)                 ``torch.int64`` or ``torch.long``           :class:`torch.LongTensor`     :class:`torch.cuda.LongTensor`
Boolean                                 ``torch.bool``                              :class:`torch.BoolTensor`     :class:`torch.cuda.BoolTensor`
quantized 8-bit integer (unsigned)      ``torch.quint8``                            :class:`torch.ByteTensor`     /
quantized 8-bit integer (signed)        ``torch.qint8``                             :class:`torch.CharTensor`     /
quantized 32-bit integer (signed)       ``torch.qint32``                            :class:`torch.IntTensor`      /
quantized 4-bit integer (unsigned) [3]_ ``torch.quint4x2``                          :class:`torch.ByteTensor`     /
======================================= =========================================== ============================= ================================

.. [1]
  Sometimes referred to as binary16: uses 1 sign, 5 exponent, and 10
  significand bits. Useful when precision is important at the expense of range.
.. [2]
  Sometimes referred to as Brain Floating Point: uses 1 sign, 8 exponent, and 7
  significand bits. Useful when range is important, since it has the same
  number of exponent bits as ``float32``
.. [3]
  quantized 4-bit integer is stored as a 8-bit signed integer. Currently it's only supported in EmbeddingBag operator.
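
As the excerpt above shows, complex64 is a first-class dtype rather than a hardware-specific one; it is stored as interleaved (real, imag) float32 pairs, which is why it maps cleanly onto a byte-for-byte serialization format. A small illustration using torch.view_as_real:

import torch

z = torch.randn(2, 2, dtype=torch.complex64)

# Each complex64 element is a (real, imag) pair of float32 values, i.e. 8 bytes.
print(z.element_size())          # 8
pairs = torch.view_as_real(z)    # real view with an extra trailing dimension of size 2
print(pairs.dtype, pairs.shape)  # torch.float32 torch.Size([2, 2, 2])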

github-actions bot added the Stale label Dec 21, 2023
github-actions bot closed this Dec 27, 2023
@Mon-ius (Author) commented Dec 27, 2023

@Narsil Why did the bot close this PR? Can you explain?

@Narsil (Collaborator) commented Jan 8, 2024

Because it's stale.

I don't think the situation has evolved, really. Do you have any models that use complex64 and have significant usage? That would definitely help the case.
