Skip to content

Commit

Permalink
Fix typing for optional imports (#1692)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Sep 24, 2024
1 parent b930f27 commit 3795125
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 79 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING

import zarr as _ # noqa: F401 # Makes {read,write}_zarr show up in docs.
from docutils import nodes

if TYPE_CHECKING:
Expand Down
10 changes: 4 additions & 6 deletions src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@
import scipy.sparse as ss
from scipy.sparse import _sparsetools

try:
# Not really important, just for IDEs to be more helpful
from scipy.sparse._compressed import _cs_matrix
except ImportError:
from scipy.sparse import spmatrix as _cs_matrix

from .. import abc
from .._settings import settings
from ..compat import H5Group, SpArray, ZarrArray, ZarrGroup, _read_attr
Expand All @@ -41,8 +35,12 @@
from collections.abc import Sequence
from typing import Literal

from scipy.sparse._compressed import _cs_matrix

from .._types import GroupStorageType
from .index import Index
else:
from scipy.sparse import spmatrix as _cs_matrix


class BackedFormat(NamedTuple):
Expand Down
49 changes: 22 additions & 27 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from functools import singledispatch, wraps
from importlib.util import find_spec
from inspect import Parameter, signature
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar, Union
Expand Down Expand Up @@ -90,10 +91,10 @@ def pairwise(iterable):
# Optional deps
#############################

try:
if find_spec("zarr") or TYPE_CHECKING:
from zarr.core import Array as ZarrArray
from zarr.hierarchy import Group as ZarrGroup
except ImportError:
else:

class ZarrArray:
@staticmethod
Expand All @@ -106,60 +107,54 @@ def __repr__():
return "mock zarr.core.Group"


try:
import awkward

AwkArray = awkward.Array

except ImportError:
if find_spec("awkward") or TYPE_CHECKING:
import awkward # noqa: F401
from awkward import Array as AwkArray
else:

class AwkArray:
@staticmethod
def __repr__():
return "mock awkward.highlevel.Array"


try:
if find_spec("zappy") or TYPE_CHECKING:
from zappy.base import ZappyArray
except ImportError:
else:

class ZappyArray:
@staticmethod
def __repr__():
return "mock zappy.base.ZappyArray"


try:
if TYPE_CHECKING:
# type checkers are confused and can only see …core.Array
from dask.array.core import Array as DaskArray
elif find_spec("dask"):
from dask.array import Array as DaskArray
except ImportError:
else:

class DaskArray:
@staticmethod
def __repr__():
return "mock dask.array.core.Array"


try:
if find_spec("cupy") or TYPE_CHECKING:
from cupy import ndarray as CupyArray
from cupyx.scipy.sparse import (
csc_matrix as CupyCSCMatrix,
)
from cupyx.scipy.sparse import (
csr_matrix as CupyCSRMatrix,
)
from cupyx.scipy.sparse import (
spmatrix as CupySparseMatrix,
)
from cupyx.scipy.sparse import csc_matrix as CupyCSCMatrix
from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix
from cupyx.scipy.sparse import spmatrix as CupySparseMatrix

try:
import dask.array as da

da.register_chunk_type(CupyCSRMatrix)
da.register_chunk_type(CupyCSCMatrix)
except ImportError:
pass

except ImportError:
else:
da.register_chunk_type(CupyCSRMatrix)
da.register_chunk_type(CupyCSCMatrix)
else:

class CupySparseMatrix:
@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions src/anndata/experimental/pytorch/_annloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from copy import copy
from functools import partial
from importlib.util import find_spec
from math import ceil
from typing import TYPE_CHECKING

Expand All @@ -14,10 +15,10 @@
if TYPE_CHECKING:
from collections.abc import Sequence

try:
if find_spec("torch") or TYPE_CHECKING:
import torch
from torch.utils.data import BatchSampler, DataLoader, Sampler
except ImportError:
else:
Sampler, BatchSampler, DataLoader = object, object, object


Expand Down
19 changes: 7 additions & 12 deletions src/anndata/io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import sys
from importlib.util import find_spec
from typing import TYPE_CHECKING

from ._core.sparse_dataset import sparse_dataset
Expand All @@ -17,20 +17,15 @@
from ._io.specs import read_elem, write_elem
from ._io.write import write_csvs, write_loom

if "zarr" in sys.modules or TYPE_CHECKING:
if find_spec("zarr") or TYPE_CHECKING:
from ._io.zarr import read_zarr, write_zarr
else:
# In case zarr is not imported (and maybe not installed),
# wrap these functions into shims.
def read_zarr(*args, **kw): # pragma: no cover
from ._io.zarr import read_zarr
else: # pragma: no cover

return read_zarr(*args, **kw)
def read_zarr(*args, **kw):
raise ImportError("zarr is not installed")

def write_zarr(*args, **kw): # pragma: no cover
from ._io.zarr import write_zarr

return write_zarr(*args, **kw)
def write_zarr(*args, **kw):
raise ImportError("zarr is not installed")


__all__ = [
Expand Down
67 changes: 36 additions & 31 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import random
import re
import warnings
from collections import Counter
from collections.abc import Mapping
from contextlib import contextmanager
from functools import partial, singledispatch, wraps
from importlib.util import find_spec
from string import ascii_letters
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -35,7 +37,7 @@
from anndata.utils import asarray

if TYPE_CHECKING:
from collections.abc import Collection
from collections.abc import Collection, Iterable
from typing import Callable, Literal, TypeGuard, TypeVar

DT = TypeVar("DT")
Expand Down Expand Up @@ -1037,46 +1039,49 @@ def shares_memory_sparse(x, y):
),
]

try:
import zarr
if find_spec("zarr") or TYPE_CHECKING:
from zarr import DirectoryStore
else:

class AccessTrackingStore(zarr.DirectoryStore):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._access_count = {}
self._accessed_keys = {}
class DirectoryStore:
def __init__(self, *_args, **_kwargs) -> None:
cls_name = type(self).__name__
msg = f"zarr must be imported to create a {cls_name} instance."
raise ImportError(msg)

def __getitem__(self, key):
for tracked in self._access_count:
if tracked in key:
self._access_count[tracked] += 1
self._accessed_keys[tracked] += [key]
return super().__getitem__(key)

def get_access_count(self, key):
return self._access_count[key]
class AccessTrackingStore(DirectoryStore):
_access_count: Counter[str]
_accessed_keys: dict[str, list[str]]

def get_accessed_keys(self, key):
return self._accessed_keys[key]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._access_count = Counter()
self._accessed_keys = {}

def initialize_key_trackers(self, keys_to_track):
for k in keys_to_track:
self._access_count[k] = 0
self._accessed_keys[k] = []
def __getitem__(self, key: str) -> object:
for tracked in self._access_count:
if tracked in key:
self._access_count[tracked] += 1
self._accessed_keys[tracked] += [key]
return super().__getitem__(key)

def reset_key_trackers(self):
self.initialize_key_trackers(self._access_count.keys())
def get_access_count(self, key: str) -> int:
return self._access_count[key]

except ImportError:
def get_accessed_keys(self, key: str) -> list[str]:
return self._accessed_keys[key]

class AccessTrackingStore:
def __init__(self, *_args, **_kwargs) -> None:
raise ImportError(
"zarr must be imported to create an `AccessTrackingStore` instance."
)
def initialize_key_trackers(self, keys_to_track: Iterable[str]) -> None:
for k in keys_to_track:
self._access_count[k] = 0
self._accessed_keys[k] = []

def reset_key_trackers(self) -> None:
self.initialize_key_trackers(self._access_count.keys())


def get_multiindex_columns_df(shape):
def get_multiindex_columns_df(shape: tuple[int, int]) -> pd.DataFrame:
return pd.DataFrame(
np.random.rand(shape[0], shape[1]),
columns=pd.MultiIndex.from_tuples(
Expand Down

0 comments on commit 3795125

Please sign in to comment.