diff --git a/anndata/experimental/pytorch/_annloader.py b/anndata/experimental/pytorch/_annloader.py index de95264b9..3563f7191 100644 --- a/anndata/experimental/pytorch/_annloader.py +++ b/anndata/experimental/pytorch/_annloader.py @@ -57,14 +57,14 @@ def __len__(self): def default_converter(arr, device, pin_memory): if isinstance(arr, torch.Tensor): - if device != "cpu": + if device is not None: arr = arr.to(device) elif pin_memory: arr = arr.pin_memory() elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number): if issparse(arr): arr = arr.toarray() - if device != "cpu": + if device is not None: arr = torch.tensor(arr, device=device) else: arr = torch.tensor(arr) diff --git a/docs/release-notes/0.10.4.md b/docs/release-notes/0.10.4.md index 2a9760e6e..960016a2a 100644 --- a/docs/release-notes/0.10.4.md +++ b/docs/release-notes/0.10.4.md @@ -1,5 +1,9 @@ ### 0.10.4 {small}`the future` +```{rubric} New features +``` +* `AnnLoader` now accepts a `device` argument to specify the device to load the data to {pr}`1240` {user}`austinv11` + ```{rubric} Bugfix ``` * Only try to use `Categorical.map(na_action=…)` in actually supported Pandas ≥2.1 {pr}`1226` {user}`flying-sheep` @@ -10,3 +14,7 @@ ```{rubric} Performance ``` + +```{rubric} Deprecations +``` +* `AnnLoader(use_cuda=…)` is deprecated in favour of `AnnLoader(device=…)` {pr}`1240` {user}`austinv11