Commit

Merge branch 'main' into local-min-deps

flying-sheep authored Nov 11, 2024
2 parents ba0bd21 + d0adc25 commit 3abf9dc
Showing 37 changed files with 492 additions and 217 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.7.0
+    rev: v0.7.2
     hooks:
       - id: ruff
         types_or: [python, pyi, jupyter]
7 changes: 4 additions & 3 deletions benchmarks/benchmarks/_utils.py
@@ -14,15 +14,16 @@
 import scanpy as sc

 if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence, Set
+    from collections.abc import Callable, Sequence
+    from collections.abc import Set as AbstractSet
     from typing import Literal, Protocol, TypeVar

     from anndata import AnnData

     C = TypeVar("C", bound=Callable)

     class ParamSkipper(Protocol):
-        def __call__(self, **skipped: Set) -> Callable[[C], C]: ...
+        def __call__(self, **skipped: AbstractSet) -> Callable[[C], C]: ...

     Dataset = Literal["pbmc68k_reduced", "pbmc3k", "bmmc", "lung93k"]
     KeyX = Literal[None, "off-axis"]
@@ -195,7 +196,7 @@ def param_skipper(
     b 5
     """

-    def skip(**skipped: Set) -> Callable[[C], C]:
+    def skip(**skipped: AbstractSet) -> Callable[[C], C]:
         skipped_combs = [
             tuple(record.values())
             for record in (
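Note: the `Set` → `AbstractSet` rename above follows the convention enforced by ruff's PYI rule group (enabled in pyproject.toml in this same commit) of aliasing `collections.abc.Set` so it can't be mistaken for the builtin `set`. A minimal sketch of the idea — the function here is illustrative, not part of the diff:

from collections.abc import Set as AbstractSet

# `AbstractSet` matches any read-only set-like type:
# set, frozenset, and dict key views all qualify.
def is_allowed(candidate: AbstractSet[str], allowed: AbstractSet[str]) -> bool:
    return candidate <= allowed

print(is_allowed(frozenset({"a"}), {"a", "b"}))       # True
print(is_allowed({"c"}, dict.fromkeys("ab").keys()))  # False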
1 change: 1 addition & 0 deletions docs/release-notes/3307.feature.md
@@ -0,0 +1 @@
+Add support for {class}`dask.array.Array` to {func}`scanpy.pp.calculate_qc_metrics` {smaller}`I Gold`
1 change: 1 addition & 0 deletions docs/release-notes/3335.feature.md
@@ -0,0 +1 @@
+Run numba functions single-threaded when called from inside a ThreadPool {smaller}`P Angerer`
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -232,6 +232,7 @@ select = [
"TID251", # Banned imports
"ICN", # Follow import conventions
"PTH", # Pathlib instead of os.path
"PYI", # Typing
"PLR0917", # Ban APIs with too many positional parameters
"FBT", # No positional boolean parameters
"PT", # Pytest style
@@ -246,6 +247,8 @@ ignore = [
"E262",
# allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation
"E741",
# `Literal["..."] | str` is useful for autocompletion
"PYI051",
]
[tool.ruff.lint.per-file-ignores]
# Do not assign a lambda expression, use a def
@@ -259,6 +262,8 @@ required-imports = ["from __future__ import annotations"]
"pandas.value_counts".msg = "Use pd.Series(a).value_counts() instead"
"legacy_api_wrap.legacy_api".msg = "Use scanpy._compat.old_positionals instead"
"numpy.bool".msg = "Use `np.bool_` instead for numpy>=1.24<2 compatibility"
"numba.jit".msg = "Use `scanpy._compat.njit` instead"
"numba.njit".msg = "Use `scanpy._compat.njit` instead"
[tool.ruff.lint.flake8-type-checking]
exempt-modules = []
strict = true
Expand Down
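The two new `numba.jit`/`numba.njit` entries extend the TID251 banned-imports table, so ruff now points contributors at the wrapper added in `src/scanpy/_compat.py` below. A hypothetical module that would trip the ban:

# hypothetical scanpy module
# from numba import njit          # flagged: TID251 — Use `scanpy._compat.njit` instead
from scanpy._compat import njit   # the sanctioned wrapper (added in this commit)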
108 changes: 106 additions & 2 deletions src/scanpy/_compat.py
@@ -1,17 +1,23 @@
 from __future__ import annotations

 import os
+import sys
+import warnings
 from dataclasses import dataclass, field
-from functools import cache, partial
+from functools import cache, partial, wraps
 from importlib.util import find_spec
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload

 from packaging.version import Version

 if TYPE_CHECKING:
+    from collections.abc import Callable
     from importlib.metadata import PackageMetadata

+P = ParamSpec("P")
+R = TypeVar("R")


 if TYPE_CHECKING:
     # type checkers are confused and can only see …core.Array
@@ -90,3 +96,101 @@ def pkg_version(package: str) -> Version:
     # but this code makes it possible to run scanpy without it.
     def old_positionals(*old_positionals: str):
         return lambda func: func
+
+
+@overload
+def njit(fn: Callable[P, R], /) -> Callable[P, R]: ...
+@overload
+def njit() -> Callable[[Callable[P, R]], Callable[P, R]]: ...
+def njit(
+    fn: Callable[P, R] | None = None, /
+) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
+    """\
+    Jit-compile a function using numba.
+
+    On call, this function dispatches to a parallel or sequential numba function,
+    depending on if it has been called from a thread pool.
+
+    See <https://github.com/numbagg/numbagg/pull/201/files#r1409374809>
+    """
+
+    def decorator(f: Callable[P, R], /) -> Callable[P, R]:
+        import numba
+
+        fns: dict[bool, Callable[P, R]] = {
+            parallel: numba.njit(f, cache=True, parallel=parallel)  # noqa: TID251
+            for parallel in (True, False)
+        }
+
+        @wraps(f)
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+            parallel = not _is_in_unsafe_thread_pool()
+            if not parallel:
+                msg = (
+                    "Detected unsupported threading environment. "
+                    f"Trying to run {f.__name__} in serial mode. "
+                    "In case of problems, install `tbb`."
+                )
+                warnings.warn(msg, stacklevel=2)
+            return fns[parallel](*args, **kwargs)
+
+        return wrapper
+
+    return decorator if fn is None else decorator(fn)
+
+
+LayerType = Literal["default", "safe", "threadsafe", "forksafe"]
+Layer = Literal["tbb", "omp", "workqueue"]
+
+
+LAYERS: dict[LayerType, set[Layer]] = {
+    "default": {"tbb", "omp", "workqueue"},
+    "safe": {"tbb"},
+    "threadsafe": {"tbb", "omp"},
+    "forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})},
+}
+
+
+def _is_in_unsafe_thread_pool() -> bool:
+    import threading
+
+    current_thread = threading.current_thread()
+    # ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1'
+    return (
+        current_thread.name.startswith("ThreadPoolExecutor")
+        and _numba_threading_layer() not in LAYERS["threadsafe"]
+    )
+
+
+@cache
+def _numba_threading_layer() -> Layer:
+    """\
+    Get numba’s threading layer.
+
+    This function implements the algorithm as described in
+    <https://numba.readthedocs.io/en/stable/user/threading-layer.html>
+    """
+    import importlib
+
+    import numba
+
+    if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None:
+        # given by direct name
+        return numba.config.THREADING_LAYER
+
+    # given by layer type (safe, …)
+    for layer in cast(list[Layer], numba.config.THREADING_LAYER_PRIORITY):
+        if layer not in available:
+            continue
+        if layer != "workqueue":
+            try:  # `importlib.util.find_spec` doesn’t work here
+                importlib.import_module(f"numba.np.ufunc.{layer}pool")
+            except ImportError:
+                continue
+        # the layer has been found
+        return layer
+    msg = (
+        f"No loadable threading layer: {numba.config.THREADING_LAYER=} "
+        f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})"
+    )
+    raise ValueError(msg)
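A minimal sketch of how the new decorator is used — the kernel below is illustrative, not scanpy code, and assumes scanpy and numba are installed. Calls from the main thread dispatch to the `parallel=True` build; calls from a `ThreadPoolExecutor` worker fall back to the serial build (with a warning) unless the threading layer is tbb or omp:

import numpy as np
from concurrent.futures import ThreadPoolExecutor

from scanpy._compat import njit


@njit  # bare use; `@njit()` works too, via the second overload
def row_sums(x):
    out = np.zeros(x.shape[0])
    for i in range(x.shape[0]):  # numba compiles this loop nest
        for j in range(x.shape[1]):
            out[i] += x[i, j]
    return out


x = np.random.rand(200, 50)
row_sums(x)  # main thread → fns[True], the parallel build

# Worker threads are detected by name ("ThreadPoolExecutor-…");
# with a workqueue threading layer this warns and uses fns[False].
with ThreadPoolExecutor() as pool:
    print(sum(r.sum() for r in pool.map(row_sums, [x, x])))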
4 changes: 2 additions & 2 deletions src/scanpy/_settings.py
@@ -19,7 +19,7 @@

 # Collected from the print_* functions in matplotlib.backends
 _Format = (
-    Literal["png", "jpg", "tif", "tiff"]
+    Literal["png", "jpg", "tif", "tiff"]  # noqa: PYI030
     | Literal["pdf", "ps", "eps", "svg", "svgz", "pgf"]
     | Literal["raw", "rgba"]
 )
@@ -340,7 +340,7 @@ def max_memory(self) -> int | float:
         return self._max_memory

     @max_memory.setter
-    def max_memory(self, max_memory: int | float):
+    def max_memory(self, max_memory: float):
         _type_check(max_memory, "max_memory", (int, float))
         self._max_memory = max_memory
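The `int | float` → `float` change is purely cosmetic for type checkers: per PEP 484's numeric tower, a parameter annotated `float` already accepts `int` arguments, so the union was redundant (ruff's PYI041 flags exactly this). A tiny illustration with a hypothetical stand-in for the setter:

def set_max_memory(limit: float) -> None:  # hypothetical, not the real setter
    print(f"max_memory = {limit} GB")

set_max_memory(4)    # fine: int is implicitly acceptable where float is expected
set_max_memory(7.5)  # fine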
78 changes: 62 additions & 16 deletions src/scanpy/_utils/__init__.py
@@ -12,14 +12,21 @@
 import re
 import sys
 import warnings
-from collections import namedtuple
 from contextlib import contextmanager, suppress
 from enum import Enum
-from functools import partial, singledispatch, wraps
-from operator import mul, truediv
+from functools import partial, reduce, singledispatch, wraps
+from operator import mul, or_, truediv
 from textwrap import dedent
-from types import MethodType, ModuleType
-from typing import TYPE_CHECKING, overload
+from types import MethodType, ModuleType, UnionType
+from typing import (
+    TYPE_CHECKING,
+    Literal,
+    NamedTuple,
+    Union,
+    get_args,
+    get_origin,
+    overload,
+)
 from weakref import WeakSet

 import h5py
@@ -42,19 +49,20 @@
 from anndata._core.sparse_dataset import SparseDataset

 if TYPE_CHECKING:
-    from collections.abc import Callable, Iterable, Mapping
+    from collections.abc import Callable, Iterable, KeysView, Mapping
     from pathlib import Path
-    from typing import Any, Literal, TypeVar
+    from typing import Any, TypeVar

     from anndata import AnnData
-    from numpy.typing import DTypeLike, NDArray
+    from numpy.typing import ArrayLike, DTypeLike, NDArray

     from ..neighbors import NeighborsParams, RPForestDict


 # e.g. https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
 # maybe in the future random.Generator
 AnyRandom = int | np.random.RandomState | None
+LegacyUnionType = type(Union[int, str])  # noqa: UP007


 class Empty(Enum):
@@ -296,6 +304,11 @@ def get_igraph_from_adjacency(adjacency, directed=None):
 # --------------------------------------------------------------------------------


+class AssoResult(NamedTuple):
+    asso_names: list[str]
+    asso_matrix: NDArray[np.floating]
+
+
 def compute_association_matrix_of_groups(
     adata: AnnData,
     prediction: str,
@@ -304,7 +317,7 @@ def compute_association_matrix_of_groups(
     normalization: Literal["prediction", "reference"] = "prediction",
     threshold: float = 0.01,
     max_n_names: int | None = 2,
-):
+) -> AssoResult:
     """Compute overlaps between groups.

     See ``identify_groups`` for identifying the groups.
@@ -346,8 +359,8 @@
f"Ignoring category {cat!r} "
"as it’s in `settings.categories_to_ignore`."
)
asso_names = []
asso_matrix = []
asso_names: list[str] = []
asso_matrix: list[list[float]] = []
for ipred_group, pred_group in enumerate(adata.obs[prediction].cat.categories):
if "?" in pred_group:
pred_group = str(ipred_group)
@@ -380,13 +393,12 @@
             if asso_matrix[-1][i] > threshold
         ]
         asso_names += ["\n".join(name_list_pred[:max_n_names])]
-    Result = namedtuple(
-        "compute_association_matrix_of_groups", ["asso_names", "asso_matrix"]
-    )
-    return Result(asso_names=asso_names, asso_matrix=np.array(asso_matrix))
+    return AssoResult(asso_names=asso_names, asso_matrix=np.array(asso_matrix))


-def get_associated_colors_of_groups(reference_colors, asso_matrix):
+def get_associated_colors_of_groups(
+    reference_colors: Mapping[int, str], asso_matrix: NDArray[np.floating]
+) -> list[dict[str, float]]:
     return [
         {
             reference_colors[i_ref]: asso_matrix[i_pred, i_ref]
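Swapping the runtime-built `collections.namedtuple` for a module-level `typing.NamedTuple` subclass keeps tuple semantics while giving type checkers the field names and types, and avoids re-creating the class on every call. A quick sketch using the `AssoResult` class from the diff above (dummy values):

import numpy as np

res = AssoResult(asso_names=["CD4 T\nCD8 T"], asso_matrix=np.eye(2))
names, matrix = res        # still unpacks like a plain tuple
print(res.asso_matrix.T)   # attribute access, now statically typed
print(res._fields)         # ('asso_names', 'asso_matrix')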
@@ -532,6 +544,19 @@ def update_params(
     return updated_params


+# `get_args` returns `tuple[Any]` so I don’t think it’s possible to get the correct type here
+def get_literal_vals(typ: UnionType | Any) -> KeysView[Any]:
+    """Get all literal values from a Literal or Union of … of Literal type."""
+    if isinstance(typ, UnionType | LegacyUnionType):
+        return reduce(
+            or_, (dict.fromkeys(get_literal_vals(t)) for t in get_args(typ))
+        ).keys()
+    if get_origin(typ) is Literal:
+        return dict.fromkeys(get_args(typ)).keys()
+    msg = f"{typ} is not a valid Literal"
+    raise TypeError(msg)
+
+
 # --------------------------------------------------------------------------------
 # Others
 # --------------------------------------------------------------------------------
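Given these definitions, `get_literal_vals` flattens nested `Literal`/union types into an ordered, de-duplicated `KeysView`. It handles both PEP 604 unions (`types.UnionType`, from `A | B` on plain types) and old-style `typing.Union` objects — which is what `Literal[...] | Literal[...]` produces at runtime, hence the `LegacyUnionType` trick above. A quick check, assuming the function is importable as defined here (it mirrors the `_Format` pattern in `_settings.py`):

from typing import Literal

Fmt = Literal["png", "jpg"] | Literal["pdf", "svg", "png"]
print(list(get_literal_vals(Fmt)))            # ['png', 'jpg', 'pdf', 'svg']
print(list(get_literal_vals(Literal[0, 1])))  # [0, 1]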
@@ -713,6 +738,27 @@ def _(
     )


+@singledispatch
+def axis_nnz(X: ArrayLike, axis: Literal[0, 1]) -> np.ndarray:
+    return np.count_nonzero(X, axis=axis)
+
+
+@axis_nnz.register(sparse.spmatrix)
+def _(X: sparse.spmatrix, axis: Literal[0, 1]) -> np.ndarray:
+    return X.getnnz(axis=axis)
+
+
+@axis_nnz.register(DaskArray)
+def _(X: DaskArray, axis: Literal[0, 1]) -> DaskArray:
+    return X.map_blocks(
+        partial(axis_nnz, axis=axis),
+        dtype=np.int64,
+        meta=np.array([], dtype=np.int64),
+        drop_axis=0,
+        chunks=len(X.to_delayed()) * (X.chunksize[int(not axis)],),
+    )
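`axis_nnz` counts nonzeros along one axis and, presumably as part of the dask QC-metrics feature noted above, gains a dask registration that maps the count over blocks and stitches the per-block results along the non-reduced axis. A small sketch of the dense and sparse paths — assumes the merged `scanpy._utils` is importable:

import numpy as np
from scipy import sparse

from scanpy._utils import axis_nnz

x = sparse.random(100, 50, density=0.1, format="csr")
print(axis_nnz(x, axis=0).shape)            # (50,)  — spmatrix.getnnz per column
print(axis_nnz(x.toarray(), axis=1).shape)  # (100,) — np.count_nonzero per row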


 @overload
 def axis_sum(
     X: sparse.spmatrix,