diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index ba0a1a7a..cb418f28 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -58,6 +58,12 @@ jobs: pip install .[jax] shell: bash + - name: Install (mlx) + if: matrix.os == 'macos-latest' && matrix.version.python == "3.10" + run: | + pip install .[mlx] + shell: bash + - name: Check style run: | pip install .[quality] diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 5274ae9c..78220ae2 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -9,7 +9,7 @@ name = "safetensors_rust" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.19.2", features = ["extension-module"] } +pyo3 = { version = "0.20.2", features = ["extension-module"] } memmap2 = "0.5" serde_json = "1.0" diff --git a/bindings/python/py_src/safetensors/flax.py b/bindings/python/py_src/safetensors/flax.py index 208264ab..aa906273 100644 --- a/bindings/python/py_src/safetensors/flax.py +++ b/bindings/python/py_src/safetensors/flax.py @@ -106,9 +106,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]: Args: filename (`str`, or `os.PathLike`)): The name of the file which contains the tensors - device (`Dict[str, any]`, *optional*, defaults to `cpu`): - The device where the tensors need to be located after load. - available options are all regular flax device locations Returns: `Dict[str, Array]`: dictionary that contains name as key, value as `Array` diff --git a/bindings/python/py_src/safetensors/numpy.py b/bindings/python/py_src/safetensors/numpy.py index 71814afc..a852ff35 100644 --- a/bindings/python/py_src/safetensors/numpy.py +++ b/bindings/python/py_src/safetensors/numpy.py @@ -111,9 +111,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]: Args: filename (`str`, or `os.PathLike`)): The name of the file which contains the tensors - device (`Dict[str, any]`, *optional*, defaults to `cpu`): - The device where the tensors need to be located after load. - available options are all regular numpy device locations Returns: `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` diff --git a/bindings/python/py_src/safetensors/tensorflow.py b/bindings/python/py_src/safetensors/tensorflow.py index 449d5157..65b8aeb4 100644 --- a/bindings/python/py_src/safetensors/tensorflow.py +++ b/bindings/python/py_src/safetensors/tensorflow.py @@ -105,9 +105,6 @@ def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]: Args: filename (`str`, or `os.PathLike`)): The name of the file which contains the tensors - device (`Dict[str, any]`, *optional*, defaults to `cpu`): - The device where the tensors need to be located after load. - available options are all regular tensorflow device locations Returns: `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index e6156c38..306da132 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -50,6 +50,9 @@ jax = [ "jax>=0.3.25", "jaxlib>=0.3.25", ] +mlx = [ + "mlx>=0.0.9", +] paddlepaddle = [ "safetensors[numpy]", "paddlepaddle>=2.4.1",