
Commit
Much better fix that doesn't depend on torch version.
Narsil committed Jul 30, 2024
1 parent 10ebfca commit 4e5a9b4
Showing 1 changed file with 59 additions and 50 deletions.
109 changes: 59 additions & 50 deletions bindings/python/src/lib.rs
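Judging from the commit message, the big-endian path no longer needs the torch >= 2.1 `untyped_storage().byteswap()` API: the byte swap is routed through numpy instead, with bfloat16 tensors viewed as float16 first (numpy has no bfloat16 dtype) and viewed back afterwards. A minimal Python sketch of that trick, assuming a CPU bfloat16 tensor; `byteswap_bf16` is a hypothetical helper for illustration, not code from this commit:

```python
import torch

def byteswap_bf16(t: torch.Tensor) -> torch.Tensor:
    """Byte-swap a bfloat16 tensor without torch >= 2.1 storage APIs."""
    as_f16 = t.view(torch.float16)          # reinterpret: same 2-byte elements, no copy
    arr = as_f16.numpy()                    # numpy array sharing memory with the tensor
    arr.byteswap(inplace=True)              # swap the raw bytes in place
    return torch.from_numpy(arr).view(torch.bfloat16)  # reinterpret back as bfloat16

t = torch.tensor([1.5, -2.0], dtype=torch.bfloat16)
swapped = byteswap_bf16(t.clone())          # clone so the original stays untouched
restored = byteswap_bf16(swapped)           # swapping twice is the identity
assert torch.equal(restored, t)
```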
@@ -579,31 +579,32 @@ impl Open {
.call((), Some(&view_kwargs))?;

if byteorder == "big" {
let version: String =
torch.getattr(intern!(py, "__version__"))?.extract()?;
let version =
Version::from_string(&version).map_err(SafetensorError::new_err)?;
if version >= Version::new(2, 1, 0) {
let inplace_kwargs =
[(intern!(py, "inplace"), true.into_py(py))].into_py_dict_bound(py);

if info.dtype == Dtype::BF16 {
// Reinterpret to f16 for numpy compatibility.
let dtype: PyObject = get_pydtype(torch, Dtype::F16, false)?;
let view_kwargs =
[(intern!(py, "dtype"), dtype)].into_py_dict_bound(py);
tensor = tensor
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
}
let numpy = tensor
.getattr(intern!(py, "numpy"))?
.call0()?
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;
if info.dtype == Dtype::BF16 {
// Reinterpret to f16 for numpy compatibility.
let dtype: PyObject = get_pydtype(torch, info.dtype, false)?;
tensor = tensor.getattr(intern!(py, "clone"))?.call0()?;
tensor
.getattr(intern!(py, "untyped_storage"))?
.call0()?
.getattr(intern!(py, "byteswap"))?
.call1((dtype,))?;
} else if info.dtype == Dtype::BF16 {
return Err(SafetensorError::new_err(
"PyTorch 2.1 or later is required for big-endian machine and bfloat16 support.",
));
} else {
let inplace_kwargs = [(intern!(py, "inplace"), false.into_py(py))]
.into_py_dict_bound(py);
let numpy = tensor
.getattr(intern!(py, "numpy"))?
.call0()?
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;
let view_kwargs =
[(intern!(py, "dtype"), dtype)].into_py_dict_bound(py);
tensor = tensor
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
}
}

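The hunk above also shows both flavours of numpy's `byteswap` keyword: an in-place swap (`inplace` set to true) and a copying one (`inplace` set to false). For reference, the flag's behaviour in plain numpy (a generic illustration, not code from this repository):

```python
import numpy as np

a = np.array([1.0], dtype=np.float32)
copied = a.byteswap(inplace=False)   # returns a swapped copy; `a` is untouched
assert a[0] == 1.0 and copied[0] != 1.0

a.byteswap(inplace=True)             # swaps `a`'s own buffer (the result is also returned)
assert a[0] != 1.0
a.byteswap(inplace=True)             # swapping twice restores the value
assert a[0] == 1.0
```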
@@ -929,32 +930,30 @@ impl PySafeSlice {
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
if byteorder == "big" {
let version: String = torch.getattr(intern!(py, "__version__"))?.extract()?;
let version =
Version::from_string(&version).map_err(SafetensorError::new_err)?;
if version >= Version::new(2, 1, 0) {
let inplace_kwargs =
[(intern!(py, "inplace"), true.into_py(py))].into_py_dict_bound(py);

if self.info.dtype == Dtype::BF16 {
// Reinterpret to f16 for numpy compatibility.
let dtype: PyObject = get_pydtype(torch, Dtype::F16, false)?;
let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py);
tensor = tensor
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
}
let numpy = tensor
.getattr(intern!(py, "numpy"))?
.call0()?
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;
if self.info.dtype == Dtype::BF16 {
// Reinterpret to f16 for numpy compatibility.
let dtype: PyObject = get_pydtype(torch, self.info.dtype, false)?;
// Clone is required, otherwise the storage is shared with previous slices,
// which would apply the byteswap once per slice.
tensor = tensor.getattr(intern!(py, "clone"))?.call0()?;
tensor
.getattr(intern!(py, "untyped_storage"))?
.call0()?
.getattr(intern!(py, "byteswap"))?
.call1((dtype,))?;
} else if self.info.dtype == Dtype::BF16 {
return Err(SafetensorError::new_err(
"PyTorch 2.1 or later is required for big-endian machine and bfloat16 support.",
));
} else {
let inplace_kwargs =
[(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py);
let numpy = tensor
.getattr(intern!(py, "numpy"))?
.call0()?
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;
let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py);
tensor = tensor
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
}
}
tensor = tensor
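The comment about `clone()` near the `untyped_storage()` call points at a real pitfall: tensors produced by slicing share storage with their parent, so an in-place, storage-level swap would otherwise be applied once per slice. The same sharing can be demonstrated in plain torch (a generic illustration, not this binding's code):

```python
import torch

base = torch.arange(1, 5, dtype=torch.float32)   # [1., 2., 3., 4.]
view = base[:2]                                  # slicing shares storage with `base`
view.zero_()                                     # in-place edits show through
assert base[0] == 0.0

base = torch.arange(1, 5, dtype=torch.float32)
owned = base[:2].clone()                         # clone() gives the slice its own storage
owned.zero_()                                    # `base` stays intact
assert base[0] == 1.0
```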
@@ -1030,7 +1029,17 @@ fn create_tensor<'a>(
(intern!(py, "dtype"), dtype),
]
.into_py_dict_bound(py);
module.call_method("frombuffer", (), Some(&kwargs))?
let tensor = module.call_method("frombuffer", (), Some(&kwargs))?;
let sys = PyModule::import_bound(py, intern!(py, "sys"))?;
let byteorder: String = sys.getattr(intern!(py, "byteorder"))?.extract()?;
if byteorder == "big" {
let inplace_kwargs =
[(intern!(py, "inplace"), true.into_py(py))].into_py_dict_bound(py);
tensor
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
}
tensor
};
let mut tensor: PyBound<'_, PyAny> = tensor.call_method1("reshape", (shape,))?;
let tensor = match framework {

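The last hunk adds an endianness fix-up to the numpy `frombuffer` path in `create_tensor`: on a big-endian host the freshly created array is byte-swapped in place, since safetensors stores tensor data little-endian. A rough Python equivalent of that logic (a standalone sketch, not the binding code itself):

```python
import sys
import numpy as np

# Simulated on-disk bytes: safetensors data is little-endian.
raw = bytearray(np.array([1.0, 2.0], dtype="<f4").tobytes())

arr = np.frombuffer(raw, dtype=np.float32)   # interpreted in the host's native byte order
if sys.byteorder == "big":
    arr.byteswap(inplace=True)               # fix the values up in place
print(arr)                                   # [1. 2.] on either kind of host
```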