Skip to content

Commit

Permalink
Improve KNN label_transfer in PerturbationSpace (#658)
Browse files Browse the repository at this point in the history
* Add uncertainty score in KNN label_transfer in PerturbationSpace
Certainty is quantified as the fraction of nearest neighbors belonging to the classified (i.e. the most abundant) label compared to the total number of nearest neighbors.

* Update pre-commit-config.yaml
Replaces yanked dependency of mypy "types-pkg-resources" with "types-setuptools" as recommended: https://pypi.org/project/types-pkg-resources/

* Improve label imputation in PerturbationSpace class
Key changes:
- Now uses KNN graph in adata: saves cost and increases consistency
- Vectorized operations instead of expensive for loop
- Distance weighting for KNN imputation
- Quantifies uncertainty as local KNN label entropy
  • Loading branch information
stefanpeidli authored Sep 30, 2024
1 parent ae048dd commit 98e2bdb
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
- id: mypy
args: [--no-strict-optional, --ignore-missing-imports]
additional_dependencies:
["types-pkg-resources", "types-requests", "types-attrs"]
["types-setuptools", "types-requests", "types-attrs"]
- repo: local
hooks:
- id: forbid-to-commit
Expand Down
69 changes: 39 additions & 30 deletions pertpy/tools/_perturbation_space/_perturbation_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from anndata import AnnData
from lamin_utils import logger
from rich import print
from scipy.stats import entropy

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -364,50 +365,58 @@ def label_transfer(
self,
adata: AnnData,
column: str = "perturbation",
column_uncertainty_score_key: str = "perturbation_transfer_uncertainty",
target_val: str = "unknown",
n_neighbors: int = 5,
use_rep: str = "X_umap",
neighbors_key: str = "neighbors",
) -> None:
"""Impute missing values in the specified column using KNN imputation in the space defined by `use_rep`.
Uncertainty is calculated as the entropy of the label distribution in the neighborhood of the target cell.
In other words, a cell where all neighbors have the same set of labels will have an uncertainty of 0, whereas a cell
where all neighbors have many different labels will have high uncertainty.
Args:
adata: The AnnData object containing single-cell data.
column: The column name in AnnData object to perform imputation on.
column: The column name in adata.obs to perform imputation on.
column_uncertainty_score_key: The column name in adata.obs to store the uncertainty score of the label transfer.
target_val: The target value to impute.
n_neighbors: Number of neighbors to use for imputation.
use_rep: The key in `adata.obsm` where the embedding (UMAP, PCA, etc.) is stored.
neighbors_key: The key in adata.uns where the neighbors are stored.
Examples:
>>> import pertpy as pt
>>> import scanpy as sc
>>> import numpy as np
>>> adata = sc.datasets.pbmc68k_reduced()
>>> rng = np.random.default_rng()
>>> adata.obs["perturbation"] = rng.choice(
... ["A", "B", "C", "unknown"], size=adata.n_obs, p=[0.33, 0.33, 0.33, 0.01]
... )
>>> # randomly dropout 10% of the data annotations
>>> adata.obs["perturbation"] = adata.obs["louvain"].astype(str).copy()
>>> random_cells = np.random.choice(adata.obs.index, int(adata.obs.shape[0] * 0.1), replace=False)
>>> adata.obs.loc[random_cells, "perturbation"] = "unknown"
>>> sc.pp.neighbors(adata)
>>> sc.tl.umap(adata)
>>> ps = pt.tl.PseudobulkSpace()
>>> ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
>>> ps.label_transfer(adata)
"""
if use_rep not in adata.obsm:
raise ValueError(f"Representation {use_rep} not found in the AnnData object.")

embedding = adata.obsm[use_rep]

from pynndescent import NNDescent

nnd = NNDescent(embedding, n_neighbors=n_neighbors)
indices, _ = nnd.query(embedding, k=n_neighbors)

perturbations = np.array(adata.obs[column])
missing_mask = perturbations == target_val

for idx in np.where(missing_mask)[0]:
neighbor_indices = indices[idx]
neighbor_categories = perturbations[neighbor_indices]
most_common = pd.Series(neighbor_categories).mode()[0]
perturbations[idx] = most_common

adata.obs[column] = perturbations
if neighbors_key not in adata.uns:
raise ValueError(f"Key {neighbors_key} not found in adata.uns. Please run `sc.pp.neighbors` first.")

labels = adata.obs[column].astype(str)
target_cells = labels == target_val

connectivities = adata.obsp[adata.uns[neighbors_key]["connectivities_key"]]
# convert labels to an incidence matrix
one_hot_encoded_labels = adata.obs[column].astype(str).str.get_dummies()
# convert to distance-weighted neighborhood incidence matrix
weighted_label_occurence = pd.DataFrame(
(one_hot_encoded_labels.values.T * connectivities).T,
index=adata.obs_names,
columns=one_hot_encoded_labels.columns,
)
# choose best label for each target cell
best_labels = weighted_label_occurence.drop(target_val, axis=1)[target_cells].idxmax(axis=1)
adata.obs[column] = labels
adata.obs.loc[target_cells, column] = best_labels

# calculate uncertainty
uncertainty = np.zeros(adata.n_obs)
uncertainty[target_cells] = entropy(weighted_label_occurence.drop(target_val, axis=1)[target_cells], axis=1)
adata.obs[column_uncertainty_score_key] = uncertainty
12 changes: 10 additions & 2 deletions tests/tools/_perturbation_space/test_simple_perturbation_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,17 @@ def test_label_transfer():
adata = AnnData(X)
perturbations = np.array(["A", "B", "C"] * 22 + ["unknown"] * 3)
adata.obs["perturbation"] = perturbations

with pytest.raises(ValueError):
ps = pt.tl.PseudobulkSpace()
ps.label_transfer(adata)

sc.pp.neighbors(adata, use_rep="X")
sc.tl.umap(adata)

ps = pt.tl.PseudobulkSpace()
ps.label_transfer(adata, n_neighbors=5, use_rep="X_umap")
ps.label_transfer(adata)
assert "unknown" not in adata.obs["perturbation"]
assert all(adata.obs["perturbation_transfer_uncertainty"] >= 0)
assert not all(adata.obs["perturbation_transfer_uncertainty"] == 0)
is_known = perturbations != "unknown"
assert all(adata.obs.loc[is_known, "perturbation_transfer_uncertainty"] == 0)

0 comments on commit 98e2bdb

Please sign in to comment.