Skip to content

Commit

Permalink
Different logs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jul 29, 2024
1 parent 3a84ffb commit 684cebd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
7 changes: 3 additions & 4 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,16 +578,13 @@ impl Open {
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;

println!("Byte order {byteorder}");

if byteorder == "big" {
let version: String =
torch.getattr(intern!(py, "__version__"))?.extract()?;
let version =
Version::from_string(&version).map_err(SafetensorError::new_err)?;
if version >= Version::new(2, 1, 0) {
let dtype: PyObject = get_pydtype(torch, info.dtype, false)?;
println!("Using torch byteswap");
tensor
.getattr(intern!(py, "untyped_storage"))?
.call0()?
Expand All @@ -600,7 +597,6 @@ impl Open {
} else {
let inplace_kwargs = [(intern!(py, "inplace"), false.into_py(py))]
.into_py_dict_bound(py);
println!("Using numpy byteswap");
let numpy = tensor
.getattr(intern!(py, "numpy"))?
.call0()?
Expand Down Expand Up @@ -931,7 +927,9 @@ impl PySafeSlice {
.call((storage_slice,), Some(&kwargs))?
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
println!("Byte order {byteorder}");
if byteorder == "big" {
println!("Using torch byteswap");
let version: String = torch.getattr(intern!(py, "__version__"))?.extract()?;
let version =
Version::from_string(&version).map_err(SafetensorError::new_err)?;
Expand All @@ -947,6 +945,7 @@ impl PySafeSlice {
"PyTorch 2.1 or later is required for big-endian machine and bfloat16 support.",
));
} else {
println!("Using numpy byteswap");
let inplace_kwargs =
[(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py);
let numpy = tensor
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def test_torch_slice(self):

tensor = slice_[:2]
self.assertEqual(list(tensor.shape), [2, 5])
torch.testing.assert_close(tensor, A[:2])
torch.testing.assert_close(tensor, A[:2], f"{tensor} != {A[:2]}")

tensor = slice_[:, :2]
self.assertEqual(list(tensor.shape), [10, 2])
Expand Down

0 comments on commit 684cebd

Please sign in to comment.