diff --git a/anndata/experimental/pytorch/_annloader.py b/anndata/experimental/pytorch/_annloader.py index 8cc883921..de95264b9 100644 --- a/anndata/experimental/pytorch/_annloader.py +++ b/anndata/experimental/pytorch/_annloader.py @@ -4,6 +4,7 @@ from functools import partial from math import ceil from typing import TYPE_CHECKING +from warnings import warn import numpy as np from scipy.sparse import issparse @@ -54,18 +55,17 @@ def __len__(self): return length -# maybe replace use_cuda with explicit device option -def default_converter(arr, use_cuda, pin_memory): +def default_converter(arr, device, pin_memory): if isinstance(arr, torch.Tensor): - if use_cuda: - arr = arr.cuda() + if device != "cpu": + arr = arr.to(device) elif pin_memory: arr = arr.pin_memory() elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number): if issparse(arr): arr = arr.toarray() - if use_cuda: - arr = torch.tensor(arr, device="cuda") + if device != "cpu": + arr = torch.tensor(arr, device=device) else: arr = torch.tensor(arr) arr = arr.pin_memory() if pin_memory else arr @@ -114,12 +114,15 @@ class AnnLoader(DataLoader): Set to `True` to have the data reshuffled at every epoch. use_default_converter Use the default converter to convert arrays to pytorch tensors, transfer to - the default cuda device (if `use_cuda=True`), do memory pinning (if `pin_memory=True`). + the specified device (if `device!=None`), do memory pinning (if `pin_memory=True`). If you pass an AnnCollection object with prespecified converters, the default converter won't overwrite these converters but will be applied on top of them. use_cuda Transfer pytorch tensors to the default cuda device after conversion. - Only works if `use_default_converter=True` + Only works if `use_default_converter=True`. DEPRECATED in favour of `device`. + device + The device to which to transfer pytorch tensors after conversion (example: "cuda"). + Only works if `use_default_converter=True`. **kwargs Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also arguments for `AnnCollection` initialization. @@ -132,8 +135,16 @@ def __init__( shuffle: bool = False, use_default_converter: bool = True, use_cuda: bool = False, + device: str | None = None, **kwargs, ): + if use_cuda: + warn( + "Argument use_cuda has been deprecated in favour of `device`. ", + FutureWarning, + ) + device = "cuda" + if isinstance(adatas, AnnData): adatas = [adatas] @@ -171,7 +182,7 @@ def __init__( if use_default_converter: pin_memory = kwargs.pop("pin_memory", False) _converter = partial( - default_converter, use_cuda=use_cuda, pin_memory=pin_memory + default_converter, device=device, pin_memory=pin_memory ) dataset.convert = _convert_on_top( dataset.convert, _converter, dict(dataset.attrs_keys, X=[])