diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 087dfde6..a7e19ab5 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -578,8 +578,6 @@ impl Open { .getattr(intern!(py, "view"))? .call((), Some(&view_kwargs))?; - println!("Byte order {byteorder}"); - if byteorder == "big" { let version: String = torch.getattr(intern!(py, "__version__"))?.extract()?; @@ -587,7 +585,6 @@ impl Open { Version::from_string(&version).map_err(SafetensorError::new_err)?; if version >= Version::new(2, 1, 0) { let dtype: PyObject = get_pydtype(torch, info.dtype, false)?; - println!("Using torch byteswap"); tensor .getattr(intern!(py, "untyped_storage"))? .call0()? @@ -600,7 +597,6 @@ impl Open { } else { let inplace_kwargs = [(intern!(py, "inplace"), false.into_py(py))] .into_py_dict_bound(py); - println!("Using numpy byteswap"); let numpy = tensor .getattr(intern!(py, "numpy"))? .call0()? @@ -931,7 +927,9 @@ impl PySafeSlice { .call((storage_slice,), Some(&kwargs))? .getattr(intern!(py, "view"))? .call((), Some(&view_kwargs))?; + println!("Byte order {byteorder}"); if byteorder == "big" { + println!("Using torch byteswap"); let version: String = torch.getattr(intern!(py, "__version__"))?.extract()?; let version = Version::from_string(&version).map_err(SafetensorError::new_err)?; @@ -947,6 +945,7 @@ impl PySafeSlice { "PyTorch 2.1 or later is required for big-endian machine and bfloat16 support.", )); } else { + println!("Using numpy byteswap"); let inplace_kwargs = [(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py); let numpy = tensor diff --git a/bindings/python/tests/test_simple.py b/bindings/python/tests/test_simple.py index e8543c72..38ca7236 100644 --- a/bindings/python/tests/test_simple.py +++ b/bindings/python/tests/test_simple.py @@ -248,7 +248,7 @@ def test_torch_slice(self): tensor = slice_[:2] self.assertEqual(list(tensor.shape), [2, 5]) - torch.testing.assert_close(tensor, A[:2]) + torch.testing.assert_close(tensor, A[:2], f"{tensor} != {A[:2]}") tensor = slice_[:, :2] self.assertEqual(list(tensor.shape), [10, 2])