Merge pull request #313 from mims-harvard/scdti_benchmark
modify scdti group to use pinnacle single-cell network dataset for ib…
amva13 authored Sep 13, 2024
2 parents 858eefb + f9453c1 commit 80acefc
Showing 3 changed files with 142 additions and 61 deletions.
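
In summary, the SCDTI benchmark group now sources its splits from the PINNACLE single-cell resource instead of the OpenTargets DTI loader. A minimal usage sketch, assuming the default ./data cache and that the import path mirrors the file layout shown in the diff below:

    from tdc.benchmark_group import scdti_group

    group = scdti_group.SCDTIGroup()                # instantiates the PINNACLE resource internally
    splits = group.get_train_valid_split(seed=1)    # {"train": ..., "val": ...} split frames
    test_df = group.get_test(seed=1)["test"]        # columns include "y", "disease", "cell_type_label"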
148 changes: 108 additions & 40 deletions tdc/benchmark_group/scdti_group.py
@@ -4,6 +4,7 @@
import os

from .base_group import BenchmarkGroup
from ..resource.pinnacle import PINNACLE


class SCDTIGroup(BenchmarkGroup):
@@ -15,52 +16,119 @@ class SCDTIGroup(BenchmarkGroup):

def __init__(self, path="./data", file_format="csv"):
"""Create an SCDTI benchmark group class."""
# super().__init__(name="SCDTI_Group", path=path)
self.name = "SCDTI_Group"
self.path = os.path.join(path, self.name)
# self.datasets = ["opentargets_dti"]
self.dataset_names = ["opentargets_dti"]
self.file_format = file_format
self.split = None
self.p = PINNACLE()

def get_train_valid_split(self):
def precision_recall_at_k(self, y, preds, k: int = 5):
"""
Calculate recall@k, precision@k, accuracy@k, and average precision@k for binary classification.
"""
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, average_precision_score
assert preds.shape[0] == y.shape[0]
assert k > 0
if k > preds.shape[0]:
return -1, -1, -1, -1

# Sort the scores and the labels by the scores
sorted_indices = np.argsort(preds.flatten())[::-1]
sorted_preds = preds[sorted_indices]
sorted_y = y[sorted_indices]

# Get the scores of the k highest predictions
topk_preds = sorted_preds[:k]
topk_y = sorted_y[:k]

# Calculate the recall@k and precision@k
recall_k = np.sum(topk_y) / np.sum(y)
precision_k = np.sum(topk_y) / k

# Calculate the accuracy@k
accuracy_k = accuracy_score(topk_y, topk_preds > 0.5)

# Calculate the AP@k
ap_k = average_precision_score(topk_y, topk_preds)

return recall_k, precision_k, accuracy_k, ap_k

def get_train_valid_split(self, seed=1):
"""parameters included for compatibility. this benchmark has a fixed train/test split."""
from ..resource.dataloader import DataLoader
if self.split is None:
dl = DataLoader(name="opentargets_dti")
self.split = dl.get_split()
return self.split["train"], self.split["dev"]

def get_test(self):
from ..resource.dataloader import DataLoader
if self.split is None:
dl = DataLoader(name="opentargets_dti")
self.split = dl.get_split()
return self.split["test"]

def evaluate(self, y_pred):
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
y_true = self.get_test()["Y"]
# Calculate metrics
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
return [precision, recall, accuracy, f1]

def evaluate_many(self, preds):
train = self.p.get_exp_data(seed=seed, split="train")
val = self.p.get_exp_data(seed=seed, split="val")
return {"train": train, "val": val}

def get_test(self, seed=1):
return {"test": self.p.get_exp_data(seed=seed, split="test")}

def evaluate(self, y_pred, k=5, top_k=20):
from numpy import mean
from sklearn.metrics import roc_auc_score
y_true = self.get_test()["test"]
assert "preds" in y_pred.columns, "require 'preds' prediction label in input df"
assert "cell_type_label" in y_pred.columns, "require cell_type_label in input df"
assert "disease" in y_pred.columns, "require 'disease' in input df"
cells = y_true["cell_type_label"].unique()
diseases = y_true["disease"].unique()
assert len(cells) == len(
y_pred["cell_type_label"].unique()
), "number of cell types in input df and test df do not match. expected {}".format(
len(cells))
assert len(diseases) == len(
y_pred["disease"].unique()
), "number of diseases in input df do not match test df. expected {}".format(
len(diseases))
results = {d: [] for d in diseases}
for disease in diseases:
for cell in cells:
preds = y_pred[(y_pred["disease"] == disease) &
(y_pred["cell_type_label"] == cell)]
yt = y_true[(y_true["disease"] == disease) &
(y_true["cell_type_label"] == cell)]
assert len(preds) == len(
yt
), "mismatch in length of predictions and results for a specific disease {} and cell type {}".format(
disease, cell)
if len(yt) == 0:
continue
auc = roc_auc_score(yt["y"], preds["preds"])
recall_k, precision_k, accuracy_k, ap_k = self.precision_recall_at_k(
yt["y"].values, preds["preds"].values, k=k)
results[disease].append({
"auc": auc,
"recall": recall_k,
"precision": precision_k,
"accuracy": accuracy_k,
"ap": ap_k
})
# for now, we benchmark with only AP@k over the top_k (default 20) cell types
for d, scores in results.items():
assert type(
scores
) == list, "scores should be a list. got {} with value {}".format(
scores, type(scores))
assert type(scores[0]
) == dict, "scores should contain dictionary of metrics"
assert "ap" in scores[0], "scores should include 'ap'"
topk_cells = [
x["ap"] for x in sorted(scores, key=lambda s: s["ap"])[-top_k:]
]
results[d] = mean(topk_cells)
return results

def evaluate_many(self, preds: list):
from numpy import mean, std
assert type(
preds
) == list, "expected preds to be a list containing prediction dataframes for multiple seeds"
if len(preds) < 5:
raise Exception(
"Run your model on at least 5 seeds to compare results and provide your outputs in preds."
)
out = dict()
preds = [self.evaluate(p) for p in preds]
out["precision"] = (mean([x[0] for x in preds]),
std([x[0] for x in preds]))
out["recall"] = (mean([x[1] for x in preds]), std([x[1] for x in preds
]))
out["accuracy"] = (mean([x[2] for x in preds]),
std([x[2] for x in preds]))
out["f1"] = (mean([x[3] for x in preds]), std([x[3] for x in preds]))
return out
evals = [self.evaluate(x) for x in preds]
diseases = preds[0]["disease"].unique()
return {
d: [mean([x[d] for x in evals]),
std([x[d] for x in evals])] for d in diseases
}
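
For context, a hedged sketch of how the new evaluation API is meant to be driven. The column names ("preds", "cell_type_label", "disease", "y") come from the asserts above; reusing the ground truth as predictions mirrors the updated test and assumes each disease/cell-type group contains both classes so the AUC is defined:

    import numpy as np
    from tdc.benchmark_group import scdti_group

    group = scdti_group.SCDTIGroup()

    # Standalone check of the ranking metrics on a toy example:
    # the top-3 scores cover 2 of the 3 positives.
    y = np.array([1, 0, 1, 0, 0, 1])
    scores = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2])
    recall_k, precision_k, acc_k, ap_k = group.precision_recall_at_k(y, scores, k=3)
    # recall@3 == precision@3 == 2/3 here

    # Benchmark-style evaluation: predictions are a copy of the test frame
    # with a "preds" column added (ground truth reused as a sanity check).
    test_df = group.get_test()["test"]
    pred_df = test_df.copy()
    pred_df["preds"] = pred_df["y"]

    per_disease = group.evaluate(pred_df, k=5, top_k=20)   # e.g. {"IBD": ..., "RA": ...}
    summary = group.evaluate_many([pred_df] * 5)            # {disease: [mean, std]} across runs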
10 changes: 8 additions & 2 deletions tdc/resource/pinnacle.py
@@ -83,7 +83,10 @@ def get_exp_data(self, seed=1, split="train"):
# clean data directory
file_list = os.listdir("./data")
for file in file_list:
os.remove(os.path.join("./data", file))
try:
os.remove(os.path.join("./data", file))
except:
continue
print("downloading pinancle zip data...")
zip_data_download_wrapper(
filename, "./data",
@@ -94,7 +97,10 @@
f for f in os.listdir("./data") if not f.endswith(".csv")
]
for x in non_csv_files:
os.remove("./data/{}".format(x))
try:
os.remove("./data/{}".format(x))
except:
continue
# Get a list of all CSV files in the unzipped folder
csv_files = [f for f in os.listdir("./data") if f.endswith(".csv")]
if not csv_files:
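
The try/except guards above only harden the cache cleanup inside get_exp_data; the resource interface itself is unchanged. A short sketch of direct use, assuming network access and the default ./data cache directory:

    from tdc.resource.pinnacle import PINNACLE

    p = PINNACLE()
    # Downloads and unzips the expression data into ./data as needed,
    # now skipping any cached files that cannot be deleted.
    train_df = p.get_exp_data(seed=1, split="train")
    val_df = p.get_exp_data(seed=1, split="val")
    test_df = p.get_exp_data(seed=1, split="test")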
45 changes: 26 additions & 19 deletions tdc/test/test_benchmark.py
@@ -63,27 +63,34 @@ def test_ADME_evaluate_many(self):
self.assertTrue(my_group["name"] in results)

def test_SCDTI_benchmark(self):
from tdc.resource.dataloader import DataLoader

data = DataLoader(name="opentargets_dti")
group = scdti_group.SCDTIGroup()
train, val = group.get_train_valid_split()
assert len(val) == 0 # this benchmark has no validation set
# test simple preds
y_true = group.get_test()["Y"]
results = group.evaluate(y_true)
assert results[-1] == 1.0 # should be perfect F1 score
# assert it matches the opentargets official test scores
tst = data.get_split()["test"]["Y"]
train_val = group.get_train_valid_split()
assert "train" in train_val, "no training set"
assert "val" in train_val, "no validation set"
assert len(train_val["train"]) > 0, "no entries in training set"
tst = group.get_test()["test"]
tst["preds"] = tst["y"] # switch predictions to ground truth
results = group.evaluate(tst)
assert results[-1] == 1.0
zero_pred = [0] * len(y_true)
results = group.evaluate(zero_pred)
assert results[-1] != 1.0 # should not be perfect F1 score
many_results = group.evaluate_many([y_true] * 5)
assert "f1" in many_results
assert len(many_results["f1"]
) == 2 # should include mean and standard deviation
assert "IBD" in results, "missing ibd from diseases. got {}".format(
results.keys())
assert "RA" in results, "missing ra from diseases. got {}".format(
results.keys())
assert results["IBD"] == results[
"RA"], "both should be perfect scores but got IBD {} vs RA {}".format(
results["IBD"], results["RA"]) # both should be perfect scores
assert abs(results["IBD"] - 1.0) < 0.000001  # should be a perfect score
many_results = group.evaluate_many([tst] * 5)
assert "IBD" in many_results, "missing ibd from diseases in evaluate many. got {}".format(
many_results.keys())
assert "RA" in many_results, "missing ra from diseases in evaluate many. got {}".format(
many_results.keys())
assert len(many_results["IBD"]) == len(
many_results["RA"]
), "both diseases should include mean and standard deviation"
assert len(many_results["IBD"]
) == 2, "results should include mean and standard deviation"
assert abs(many_results["IBD"][0] -
1.0) < 0.000001, "should get perfect score"

@unittest.skip(
"counterfactual test is taking up too much memory"
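
To exercise the updated benchmark end to end, the test module named in the diff header can be run from the repository root with the standard library alone; a sketch (pytest would work equally well):

    import unittest

    # Discover and run the updated benchmark tests, including test_SCDTI_benchmark.
    suite = unittest.TestLoader().discover("tdc/test", pattern="test_benchmark.py")
    unittest.TextTestRunner(verbosity=2).run(suite)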
