Allow for specifying a tensor device in AnnLoader #1240

Open · wants to merge 2 commits into main
29 changes: 20 additions & 9 deletions anndata/experimental/pytorch/_annloader.py
@@ -4,6 +4,7 @@
 from functools import partial
 from math import ceil
 from typing import TYPE_CHECKING
+from warnings import warn
 
 import numpy as np
 from scipy.sparse import issparse
@@ -54,18 +55,17 @@ def __len__(self):
         return length
 
 
-# maybe replace use_cuda with explicit device option
-def default_converter(arr, use_cuda, pin_memory):
+def default_converter(arr, device, pin_memory):
     if isinstance(arr, torch.Tensor):
-        if use_cuda:
-            arr = arr.cuda()
+        if device is not None:
+            arr = arr.to(device)
         elif pin_memory:
             arr = arr.pin_memory()
     elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number):
         if issparse(arr):
             arr = arr.toarray()
-        if use_cuda:
-            arr = torch.tensor(arr, device="cuda")
+        if device is not None:
+            arr = torch.tensor(arr, device=device)
         else:
             arr = torch.tensor(arr)
         arr = arr.pin_memory() if pin_memory else arr
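For reference, a minimal sketch (not part of this diff) of what the updated converter does with the new `device` argument; it assumes a local NumPy/SciPy/PyTorch install, and `device` may be any string PyTorch accepts, e.g. `"cpu"` or `"cuda:0"`:

```python
# Illustrative only: mimics default_converter's handling of the new `device`
# argument; this is not the PR's test code.
import numpy as np
import torch
from scipy.sparse import csr_matrix

device = "cuda" if torch.cuda.is_available() else "cpu"

# Sparse numeric input is densified and created directly on the target device,
# matching `torch.tensor(arr, device=device)` in the hunk above.
sparse_arr = csr_matrix(np.eye(3, dtype=np.float32))
dense_tensor = torch.tensor(sparse_arr.toarray(), device=device)

# An existing tensor is simply moved with `.to(device)` rather than the old
# hard-coded `.cuda()` call.
existing = torch.zeros(3)
moved = existing.to(device)
```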
@@ -114,12 +114,15 @@ class AnnLoader(DataLoader):
         Set to `True` to have the data reshuffled at every epoch.
     use_default_converter
         Use the default converter to convert arrays to pytorch tensors, transfer to
-        the default cuda device (if `use_cuda=True`), do memory pinning (if `pin_memory=True`).
+        the specified device (if `device` is not `None`), do memory pinning (if `pin_memory=True`).
         If you pass an AnnCollection object with prespecified converters, the default converter
         won't overwrite these converters but will be applied on top of them.
     use_cuda
         Transfer pytorch tensors to the default cuda device after conversion.
-        Only works if `use_default_converter=True`
+        Only works if `use_default_converter=True`. DEPRECATED in favour of `device`.
+    device
+        The device to which to transfer pytorch tensors after conversion (example: `"cuda"`).
+        Only works if `use_default_converter=True`.
     **kwargs
         Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also
         arguments for `AnnCollection` initialization.
@@ -132,8 +135,16 @@ def __init__(
         shuffle: bool = False,
         use_default_converter: bool = True,
         use_cuda: bool = False,
+        device: str | None = None,
         **kwargs,
     ):
+        if use_cuda:
+            warn(
+                "Argument `use_cuda` has been deprecated in favour of `device`.",
+                FutureWarning,
+            )
+            device = "cuda"
+
         if isinstance(adatas, AnnData):
             adatas = [adatas]
 
@@ -171,7 +182,7 @@ def __init__(
         if use_default_converter:
             pin_memory = kwargs.pop("pin_memory", False)
             _converter = partial(
-                default_converter, use_cuda=use_cuda, pin_memory=pin_memory
+                default_converter, device=device, pin_memory=pin_memory
             )
             dataset.convert = _convert_on_top(
                 dataset.convert, _converter, dict(dataset.attrs_keys, X=[])
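A hedged usage sketch of the change (the import path follows the current anndata docs; `device="cuda"` assumes a CUDA-enabled PyTorch build, and the batch attribute access mirrors the existing AnnLoader tutorial):

```python
# Sketch only; exact behaviour depends on the merged implementation.
import numpy as np
from anndata import AnnData
from anndata.experimental import AnnLoader

adata = AnnData(np.random.rand(100, 20).astype(np.float32))

# New style introduced by this PR: name the target device explicitly.
loader = AnnLoader(adata, batch_size=16, shuffle=True, device="cuda")

# Old style keeps working for now, but emits a FutureWarning and is
# treated as device="cuda".
legacy_loader = AnnLoader(adata, batch_size=16, shuffle=True, use_cuda=True)

batch = next(iter(loader))
print(batch.X.device)  # expected: cuda:0
```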
8 changes: 8 additions & 0 deletions docs/release-notes/0.10.4.md
@@ -1,5 +1,9 @@
 ### 0.10.4 {small}`the future`
 
+```{rubric} New features
+```
+* `AnnLoader` now accepts a `device` argument to specify the device to load the data to {pr}`1240` {user}`austinv11`
+
 ```{rubric} Bugfix
 ```
 * Only try to use `Categorical.map(na_action=…)` in actually supported Pandas ≥2.1 {pr}`1226` {user}`flying-sheep`
@@ -10,3 +14,7 @@
 
 ```{rubric} Performance
 ```
+
+```{rubric} Deprecations
+```
+* `AnnLoader(use_cuda=…)` is deprecated in favour of `AnnLoader(device=…)` {pr}`1240` {user}`austinv11`