Skip to content

Commit

Permalink
Allow for specifying a tensor device in AnnLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
austinv11 committed Nov 21, 2023
1 parent 3e340e1 commit b953702
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions anndata/experimental/pytorch/_annloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import partial
from math import ceil
from typing import TYPE_CHECKING
from warnings import warn

import numpy as np
from scipy.sparse import issparse
Expand Down Expand Up @@ -54,18 +55,17 @@ def __len__(self):
return length


# maybe replace use_cuda with explicit device option
def default_converter(arr, use_cuda, pin_memory):
def default_converter(arr, device, pin_memory):
if isinstance(arr, torch.Tensor):
if use_cuda:
arr = arr.cuda()
if device != "cpu":
arr = arr.to(device)
elif pin_memory:
arr = arr.pin_memory()
elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number):
if issparse(arr):
arr = arr.toarray()
if use_cuda:
arr = torch.tensor(arr, device="cuda")
if device != "cpu":
arr = torch.tensor(arr, device=device)
else:
arr = torch.tensor(arr)
arr = arr.pin_memory() if pin_memory else arr
Expand Down Expand Up @@ -114,12 +114,15 @@ class AnnLoader(DataLoader):
Set to `True` to have the data reshuffled at every epoch.
use_default_converter
Use the default converter to convert arrays to pytorch tensors, transfer to
the default cuda device (if `use_cuda=True`), do memory pinning (if `pin_memory=True`).
the specified device (if `device!=None`), do memory pinning (if `pin_memory=True`).
If you pass an AnnCollection object with prespecified converters, the default converter
won't overwrite these converters but will be applied on top of them.
use_cuda
Transfer pytorch tensors to the default cuda device after conversion.
Only works if `use_default_converter=True`
Only works if `use_default_converter=True`. DEPRECATED in favour of `device`.
device
The device to which to transfer pytorch tensors after conversion (example: "cuda").
Only works if `use_default_converter=True`.
**kwargs
Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also
arguments for `AnnCollection` initialization.
Expand All @@ -132,8 +135,16 @@ def __init__(
shuffle: bool = False,
use_default_converter: bool = True,
use_cuda: bool = False,
device: str | None = None,
**kwargs,
):
if use_cuda:
warn(
"Argument use_cuda has been deprecated in favour of `device`. ",
FutureWarning,
)
device = "cuda"

if isinstance(adatas, AnnData):
adatas = [adatas]

Expand Down Expand Up @@ -171,7 +182,7 @@ def __init__(
if use_default_converter:
pin_memory = kwargs.pop("pin_memory", False)
_converter = partial(
default_converter, use_cuda=use_cuda, pin_memory=pin_memory
default_converter, device=device, pin_memory=pin_memory
)
dataset.convert = _convert_on_top(
dataset.convert, _converter, dict(dataset.attrs_keys, X=[])
Expand Down

0 comments on commit b953702

Please sign in to comment.