Merge pull request #44 from helicalAI/fix-integration-evaluation-bug
Fix integration evaluation bug
bputzeys authored Jul 5, 2024
2 parents a2e17fb + f690050 commit a0ca392
Showing 3 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ci/tests/test_benchmark/test_benchmark.py
@@ -69,7 +69,7 @@ def test_evaluate_integration():
data.obs["str_labels"] = pd.Categorical(['type1', 'type2'] * (data.shape[0] // 2))
data.obsm["X_scgpt"] = np.zeros((data.shape[0], 10))

evaluations = evaluate_integration([("scgpt", data, "X_scgpt"), ("different_model", data, "X_scgpt")], config)
evaluations = evaluate_integration([("scgpt", "X_scgpt"), ("different_model", "X_scgpt")], data, config)

# scgpt
assert_near_exact(evaluations["scgpt"]["BATCH"]["ASW_batch"], 1.0)
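To make the new calling convention concrete, here is a minimal sketch modelled on the test above. It assumes evaluate_integration is importable from helical.benchmark.benchmark (the module changed below); the batch column name and the contents of the config, in particular the "integration" section, are assumptions rather than values taken from this commit.

import anndata as ad
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from helical.benchmark.benchmark import evaluate_integration  # import path inferred from the file layout in this commit

n_cells = 100
data = ad.AnnData(X=np.random.rand(n_cells, 50))
data.obs["batch"] = pd.Categorical(["batch1", "batch2"] * (n_cells // 2))        # assumed batch column name
data.obs["str_labels"] = pd.Categorical(["type1", "type2"] * (n_cells // 2))
data.obsm["X_scgpt"] = np.zeros((n_cells, 10))
data.obsm["X_uce"] = np.zeros((n_cells, 10))

# the config only needs to expose cfg["data"] and cfg["integration"] (see the docstring below);
# the keys shown here mirror the accesses visible in this diff, nothing more
config = OmegaConf.create({
    "data": {"batch_key": "batch", "label_key": "str_labels"},
    "integration": {},  # scib metric switches would go here (assumption)
})

# new signature: (model name, obsm key) pairs first, then the shared AnnData, then the config
evaluations = evaluate_integration([("scgpt", "X_scgpt"), ("uce", "X_uce")], data, config)

The same pattern is what examples/run_benchmark.py adopts in the next file.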
8 changes: 4 additions & 4 deletions examples/run_benchmark.py
@@ -112,10 +112,10 @@ def run_integration_example(adata: ad.AnnData, cfg: DictConfig) -> dict[str, dic

return evaluate_integration(
[
("scgpt", adata, "X_scgpt"),
("uce", adata, "X_uce"),
("scanorama", adata, "X_scanorama")
], cfg
("scgpt", "X_scgpt"),
("uce", "X_uce"),
("scanorama", "X_scanorama")
], adata, cfg
)


15 changes: 10 additions & 5 deletions helical/benchmark/benchmark.py
Expand Up @@ -6,10 +6,11 @@
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
from scib.metrics import metrics
from omegaconf import DictConfig
+ from copy import deepcopy

LOGGER = logging.getLogger(__name__)

- def evaluate_integration(data_list: list[tuple[str, AnnData, str]], cfg: DictConfig) -> dict[str, dict[str, float]]:
+ def evaluate_integration(data_list: list[tuple[str, str]], adata: AnnData, cfg: DictConfig) -> dict[str, dict[str, float]]:
"""
Evaluate the data integration of the anndata object using the scib metrics.
@@ -18,8 +19,9 @@ def evaluate_integration(data_list: list[tuple[str, AnnData, str]], cfg: DictConfig) -> dict[str, dict[str, float]]:
data_list : list[tuple[str, AnnData, str]]
A list of tuples containing:
The name of the model that was used to generate the embeddings.
- The AnnData object that contains the embeddings.
The name of the obsm attribute that contains the embeddings.
+ adata : AnnData
+ The AnnData object that contains the embeddings.
cfg : DictConfig
The configuration of the data and the integration.
Ie. the config must enable access to cfg["data"] and cfg["integration"].
@@ -30,10 +32,13 @@ def evaluate_integration(data_list: list[tuple[str, AnnData, str]], cfg: DictConfig) -> dict[str, dict[str, float]]:
"""
evaluations = {}
- for model, adata, embed_obsm_name in data_list:
+ for model, embed_obsm_name in data_list:
LOGGER.info(f"Processing integration evaluation using...")
- evaluation = _get_integration_evaluations(adata,
- adata,
+
+ # because scib library modifies the adata object, we need to deepcopy it for each model
+ # otherwise, some evaluations will be identical and thus incorrect
+ evaluation = _get_integration_evaluations(deepcopy(adata),
+ deepcopy(adata),
cfg["data"]["batch_key"],
cfg["data"]["label_key"],
embed_obsm_name,
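The new inline comment is the heart of the fix: scib's metrics modify the AnnData they are given, so feeding the same object to every model lets state from the first evaluation leak into the next, and the scores come out identical. A toy illustration of that failure mode and of why a fresh deepcopy per model avoids it (stand-in code, not the scib API):

from copy import deepcopy

def mutating_metric(store: dict, embedding_key: str) -> float:
    # stand-in for a library that caches intermediate results on its input object
    if "score" not in store:
        store["score"] = float(len(embedding_key))
    return store["score"]

shared = {"X_scgpt": None, "X_uce": None}

naive = [mutating_metric(shared, key) for key in ("X_scgpt", "X_uce")]
fixed = [mutating_metric(deepcopy(shared), key) for key in ("X_scgpt", "X_uce")]

print(naive)  # [7.0, 7.0] -- the second model silently reuses the first model's cached result
print(fixed)  # [7.0, 5.0] -- each model is scored on an independent copy, as in this commit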
