diff --git a/bindings/python/py_src/safetensors/flax.py b/bindings/python/py_src/safetensors/flax.py index 208264ab..aa906273 100644 --- a/bindings/python/py_src/safetensors/flax.py +++ b/bindings/python/py_src/safetensors/flax.py @@ -106,9 +106,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]: Args: filename (`str`, or `os.PathLike`)): The name of the file which contains the tensors - device (`Dict[str, any]`, *optional*, defaults to `cpu`): - The device where the tensors need to be located after load. - available options are all regular flax device locations Returns: `Dict[str, Array]`: dictionary that contains name as key, value as `Array` diff --git a/bindings/python/py_src/safetensors/numpy.py b/bindings/python/py_src/safetensors/numpy.py index 71814afc..a852ff35 100644 --- a/bindings/python/py_src/safetensors/numpy.py +++ b/bindings/python/py_src/safetensors/numpy.py @@ -111,9 +111,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]: Args: filename (`str`, or `os.PathLike`)): The name of the file which contains the tensors - device (`Dict[str, any]`, *optional*, defaults to `cpu`): - The device where the tensors need to be located after load. - available options are all regular numpy device locations Returns: `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` diff --git a/bindings/python/py_src/safetensors/tensorflow.py b/bindings/python/py_src/safetensors/tensorflow.py index 449d5157..65b8aeb4 100644 --- a/bindings/python/py_src/safetensors/tensorflow.py +++ b/bindings/python/py_src/safetensors/tensorflow.py @@ -105,9 +105,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]: Args: filename (`str`, or `os.PathLike`)): The name of the file which contains the tensors - device (`Dict[str, any]`, *optional*, defaults to `cpu`): - The device where the tensors need to be located after load. - available options are all regular tensorflow device locations Returns: `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` diff --git a/bindings/python/tests/test_flax_comparison.py b/bindings/python/tests/test_flax_comparison.py index 23d6c4bb..d537f391 100644 --- a/bindings/python/tests/test_flax_comparison.py +++ b/bindings/python/tests/test_flax_comparison.py @@ -41,6 +41,17 @@ def test_zero_sized(self): # instead self.assertEqual(data["test"].shape, reloaded["test"].shape) + def test_gpu(self): + data = { + "test": jnp.zeros((2, 0), dtype=jnp.float32), + } + local = "./tests/data/out_safe_flat_mmap_small2.safetensors" + save_file(data.copy(), local) + reloaded = load_file(local, device="cuda") + # Empty tensor != empty tensor on numpy, so comparing shapes + # instead + self.assertEqual(data["test"].shape, reloaded["test"].shape) + def test_deserialization_safe(self): weights = load_file(self.sf_filename)