diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index c63c872c..5ff77820 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -579,31 +579,32 @@ impl Open { .call((), Some(&view_kwargs))?; if byteorder == "big" { - let version: String = - torch.getattr(intern!(py, "__version__"))?.extract()?; - let version = - Version::from_string(&version).map_err(SafetensorError::new_err)?; - if version >= Version::new(2, 1, 0) { + let inplace_kwargs = + [(intern!(py, "inplace"), true.into_py(py))].into_py_dict_bound(py); + + if info.dtype == Dtype::BF16 { + // Reinterpret to f16 for numpy compatibility. + let dtype: PyObject = get_pydtype(torch, Dtype::F16, false)?; + let view_kwargs = + [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py); + tensor = tensor + .getattr(intern!(py, "view"))? + .call((), Some(&view_kwargs))?; + } + let numpy = tensor + .getattr(intern!(py, "numpy"))? + .call0()? + .getattr("byteswap")? + .call((), Some(&inplace_kwargs))?; + tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?; + if info.dtype == Dtype::BF16 { + // Reinterpret back to bf16 now that the byteswap is done. let dtype: PyObject = get_pydtype(torch, info.dtype, false)?; - tensor = tensor.getattr(intern!(py, "clone"))?.call0()?; - tensor - .getattr(intern!(py, "untyped_storage"))? - .call0()? - .getattr(intern!(py, "byteswap"))? - .call1((dtype,))?; - } else if info.dtype == Dtype::BF16 { - return Err(SafetensorError::new_err( - "PyTorch 2.1 or later is required for big-endian machine and bfloat16 support.", - )); - } else { - let inplace_kwargs = [(intern!(py, "inplace"), false.into_py(py))] .into_py_dict_bound(py); - let numpy = tensor - .getattr(intern!(py, "numpy"))? - .call0()? - .getattr("byteswap")? - .call((), Some(&inplace_kwargs))?; - tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?; + let view_kwargs = + [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py); + tensor = tensor + .getattr(intern!(py, "view"))? 
+ .call((), Some(&view_kwargs))?; } } @@ -929,32 +930,30 @@ impl PySafeSlice { .getattr(intern!(py, "view"))? .call((), Some(&view_kwargs))?; if byteorder == "big" { - let version: String = torch.getattr(intern!(py, "__version__"))?.extract()?; - let version = - Version::from_string(&version).map_err(SafetensorError::new_err)?; - if version >= Version::new(2, 1, 0) { + let inplace_kwargs = + [(intern!(py, "inplace"), true.into_py(py))].into_py_dict_bound(py); + + if self.info.dtype == Dtype::BF16 { + // Reinterpret to f16 for numpy compatibility. + let dtype: PyObject = get_pydtype(torch, Dtype::F16, false)?; + let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py); + tensor = tensor + .getattr(intern!(py, "view"))? + .call((), Some(&view_kwargs))?; + } + let numpy = tensor + .getattr(intern!(py, "numpy"))? + .call0()? + .getattr("byteswap")? + .call((), Some(&inplace_kwargs))?; + tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?; + if self.info.dtype == Dtype::BF16 { + // Reinterpret back to bf16 now that the byteswap is done. let dtype: PyObject = get_pydtype(torch, self.info.dtype, false)?; - // Clone is required otherwise storage is shared with previous slices, - // making n amount of byteswaps. - tensor = tensor.getattr(intern!(py, "clone"))?.call0()?; - tensor - .getattr(intern!(py, "untyped_storage"))? - .call0()? - .getattr(intern!(py, "byteswap"))? - .call1((dtype,))?; - } else if self.info.dtype == Dtype::BF16 { - return Err(SafetensorError::new_err( - "PyTorch 2.1 or later is required for big-endian machine and bfloat16 support.", - )); - } else { - let inplace_kwargs = - [(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py); - let numpy = tensor - .getattr(intern!(py, "numpy"))? - .call0()? - .getattr("byteswap")? 
- .call((), Some(&inplace_kwargs))?; - tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?; + let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict_bound(py); + tensor = tensor + .getattr(intern!(py, "view"))? + .call((), Some(&view_kwargs))?; } } tensor = tensor @@ -1030,7 +1029,17 @@ fn create_tensor<'a>( (intern!(py, "dtype"), dtype), ] .into_py_dict_bound(py); - module.call_method("frombuffer", (), Some(&kwargs))? + let tensor = module.call_method("frombuffer", (), Some(&kwargs))?; + let sys = PyModule::import_bound(py, intern!(py, "sys"))?; + let byteorder: String = sys.getattr(intern!(py, "byteorder"))?.extract()?; + if byteorder == "big" { + let inplace_kwargs = + [(intern!(py, "inplace"), true.into_py(py))].into_py_dict_bound(py); + tensor + .getattr("byteswap")? + .call((), Some(&inplace_kwargs))?; + } + tensor }; let mut tensor: PyBound<'_, PyAny> = tensor.call_method1("reshape", (shape,))?; let tensor = match framework {