From 91f66fbef558d48931f9f9e4e8eadb2349c42bd4 Mon Sep 17 00:00:00 2001 From: austinv11 Date: Tue, 21 Nov 2023 13:29:13 -0500 Subject: [PATCH] Allow for specifying a tensor device in AnnLoader --- anndata/experimental/pytorch/_annloader.py | 4 ++-- docs/release-notes/0.10.4.md | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/anndata/experimental/pytorch/_annloader.py b/anndata/experimental/pytorch/_annloader.py index de95264b9..3563f7191 100644 --- a/anndata/experimental/pytorch/_annloader.py +++ b/anndata/experimental/pytorch/_annloader.py @@ -57,14 +57,14 @@ def __len__(self): def default_converter(arr, device, pin_memory): if isinstance(arr, torch.Tensor): - if device != "cpu": + if device is not None: arr = arr.to(device) elif pin_memory: arr = arr.pin_memory() elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number): if issparse(arr): arr = arr.toarray() - if device != "cpu": + if device is not None: arr = torch.tensor(arr, device=device) else: arr = torch.tensor(arr) diff --git a/docs/release-notes/0.10.4.md b/docs/release-notes/0.10.4.md index 2a9760e6e..960016a2a 100644 --- a/docs/release-notes/0.10.4.md +++ b/docs/release-notes/0.10.4.md @@ -1,5 +1,9 @@ ### 0.10.4 {small}`the future` +```{rubric} New features +``` +* `AnnLoader` now accepts a `device` argument to specify the device to load the data to {pr}`1240` {user}`austinv11` + ```{rubric} Bugfix ``` * Only try to use `Categorical.map(na_action=…)` in actually supported Pandas ≥2.1 {pr}`1226` {user}`flying-sheep` @@ -10,3 +14,7 @@ ```{rubric} Performance ``` + +```{rubric} Deprecations +``` +* `AnnLoader(use_cuda=…)` is deprecated in favour of `AnnLoader(device=…)` {pr}`1240` {user}`austinv11