Skip to content

Commit

Permalink
add support to complex64 dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
Mon-ius committed Nov 18, 2023
1 parent 829bfa8 commit 0d940f8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
3 changes: 3 additions & 0 deletions bindings/python/py_src/safetensors/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,11 @@ def load(data: bytes) -> Dict[str, torch.Tensor]:
torch.int8: 1,
torch.bool: 1,
torch.float64: 8,
torch.complex64: 8,
}

_TYPES = {
"C64": torch.complex64,
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
Expand Down Expand Up @@ -432,6 +434,7 @@ def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
torch.int8: np.int8,
torch.bool: bool,
torch.float64: np.float64,
torch.complex64: np.complex64,
}
npdtype = NPDTYPES[tensor.dtype]
# Not in place as that would potentially modify a live running model
Expand Down
2 changes: 2 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ fn prepare(tensor_dict: HashMap<String, &PyDict>) -> PyResult<HashMap<String, Te
"float16" => Some(Dtype::F16),
"float32" => Some(Dtype::F32),
"float64" => Some(Dtype::F64),
"complex64" => Some(Dtype::C64),
"bfloat16" => Some(Dtype::BF16),
dtype_str => {
return Err(SafetensorError::new_err(format!(
Expand Down Expand Up @@ -960,6 +961,7 @@ fn create_tensor(
fn get_pydtype(module: &PyModule, dtype: Dtype, is_numpy: bool) -> PyResult<PyObject> {
Python::with_gil(|py| {
let dtype: PyObject = match dtype {
Dtype::C64 => module.getattr(intern!(py, "complex64"))?.into(),
Dtype::F64 => module.getattr(intern!(py, "float64"))?.into(),
Dtype::F32 => module.getattr(intern!(py, "float32"))?.into(),
Dtype::BF16 => {
Expand Down
4 changes: 4 additions & 0 deletions safetensors/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,8 @@ pub enum Dtype {
F32,
/// Floating point (64-bit)
F64,
/// Complex Floating point (64-bit)
C64,
/// Signed integer (64-bit)
I64,
/// Unsigned integer (64-bit)
Expand All @@ -688,6 +690,7 @@ impl Dtype {
Dtype::BF16 => 2,
Dtype::F32 => 4,
Dtype::F64 => 8,
Dtype::C64 => 8,
}
}
}
Expand Down Expand Up @@ -717,6 +720,7 @@ mod tests {
Just(Dtype::BF16),
Just(Dtype::F32),
Just(Dtype::F64),
Just(Dtype::C64),
]
}

Expand Down

0 comments on commit 0d940f8

Please sign in to comment.