Skip to content

Commit

Permalink
Allow for specifying a tensor device in AnnLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
austinv11 committed Nov 21, 2023
1 parent b953702 commit 91f66fb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 2 additions & 2 deletions anndata/experimental/pytorch/_annloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ def __len__(self):

def default_converter(arr, device, pin_memory):
if isinstance(arr, torch.Tensor):
if device != "cpu":
if device is not None:
arr = arr.to(device)
elif pin_memory:
arr = arr.pin_memory()
elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number):
if issparse(arr):
arr = arr.toarray()
if device != "cpu":
if device is not None:
arr = torch.tensor(arr, device=device)
else:
arr = torch.tensor(arr)
Expand Down
8 changes: 8 additions & 0 deletions docs/release-notes/0.10.4.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
### 0.10.4 {small}`the future`

```{rubric} New features
```
* `AnnLoader` now accepts a `device` argument to specify the device to load the data to {pr}`1240` {user}`austinv11`

```{rubric} Bugfix
```
* Only try to use `Categorical.map(na_action=…)` in actually supported Pandas ≥2.1 {pr}`1226` {user}`flying-sheep`
Expand All @@ -10,3 +14,7 @@

```{rubric} Performance
```

```{rubric} Deprecations
```
* `AnnLoader(use_cuda=…)` is deprecated in favour of `AnnLoader(device=…)` {pr}`1240` {user}`austinv11

0 comments on commit 91f66fb

Please sign in to comment.