Skip to content

Commit

Permalink
Removing old doc.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jan 17, 2024
1 parent 9c81742 commit 78caafd
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
3 changes: 0 additions & 3 deletions bindings/python/py_src/safetensors/flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]:
Args:
filename (`str`, or `os.PathLike`)):
The name of the file which contains the tensors
device (`Dict[str, any]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
available options are all regular flax device locations
Returns:
`Dict[str, Array]`: dictionary that contains name as key, value as `Array`
Expand Down
3 changes: 0 additions & 3 deletions bindings/python/py_src/safetensors/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]:
Args:
filename (`str`, or `os.PathLike`)):
The name of the file which contains the tensors
device (`Dict[str, any]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
available options are all regular numpy device locations
Returns:
`Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray`
Expand Down
3 changes: 0 additions & 3 deletions bindings/python/py_src/safetensors/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]:
Args:
filename (`str`, or `os.PathLike`)):
The name of the file which contains the tensors
device (`Dict[str, any]`, *optional*, defaults to `cpu`):
The device where the tensors need to be located after load.
available options are all regular tensorflow device locations
Returns:
`Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor`
Expand Down
11 changes: 11 additions & 0 deletions bindings/python/tests/test_flax_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def test_zero_sized(self):
# instead
self.assertEqual(data["test"].shape, reloaded["test"].shape)

def test_gpu(self):
data = {
"test": jnp.zeros((2, 0), dtype=jnp.float32),
}
local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
save_file(data.copy(), local)
reloaded = load_file(local, device="cuda")
# Empty tensor != empty tensor on numpy, so comparing shapes
# instead
self.assertEqual(data["test"].shape, reloaded["test"].shape)

def test_deserialization_safe(self):
weights = load_file(self.sf_filename)

Expand Down

0 comments on commit 78caafd

Please sign in to comment.