-
Notifications
You must be signed in to change notification settings - Fork 0
/
compute_dataset_stats.py
67 lines (54 loc) · 2.08 KB
/
compute_dataset_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from hydra import initialize, compose
from omegaconf import DictConfig
from scipy import ndimage
from skimage.morphology import label
from tqdm import tqdm

from hoechstgan.data import create_dataset
def compute_mask_stats(mask):
    """Compute summary statistics for a single segmentation mask.

    Args:
        mask: 2-D numpy-compatible array; any non-zero entry counts as
            foreground.  (Callers pass a squeezed ``(H, W)`` tensor —
            NOTE(review): confirm upstream masks are always 2-D.)

    Returns:
        dict with:
            "cells":   number of 4-connected foreground components,
            "present": True iff at least one foreground pixel exists,
            "area":    fraction of pixels that are foreground (0.0..1.0).
    """
    # Binarize exactly once; the original re-applied `!= 0` inside label().
    mask = mask != 0
    # scipy's default 2-D structuring element is 4-connectivity, matching the
    # previous skimage.morphology.label(..., connectivity=1) behavior, and the
    # returned count also works for an all-background mask.
    _, num_cells = ndimage.label(mask)
    return {
        "cells": num_cells,
        # .any() is safe on empty arrays, unlike .max() which raises.
        "present": bool(mask.any()),
        "area": mask.sum() / mask.size,
    }
def compute_stats(cfg: DictConfig) -> pd.DataFrame:
    """Compute per-sample mask statistics for every channel in a dataset.

    Args:
        cfg: hydra/omegaconf config consumed by ``create_dataset``; its
            ``phase`` attribute is used only for the progress-bar label.

    Returns:
        DataFrame with one row per sample and one ``"<channel> <stat>"``
        column per combination of channel (Hoechst/CD3/CD8) and stat
        (cells/present/area), as produced by ``compute_mask_stats``.
        (Original annotation said ``-> None`` but a DataFrame was returned;
        callers already rely on the return value.)
    """
    dataset = create_dataset(cfg)
    dataset_size = len(dataset)

    def yield_from_dict(d):
        # Transpose a dict of iterables into an iterator of dicts; stops as
        # soon as any constituent iterable is exhausted (zip-like semantics).
        iters = {k: iter(v) for (k, v) in d.items()}
        try:
            while True:
                yield {k: next(v) for (k, v) in iters.items()}
        except StopIteration:
            pass

    def compute():
        data_generator = itertools.chain.from_iterable(
            map(yield_from_dict, dataset))
        for data in tqdm(data_generator, total=dataset_size,
                         desc=f"processing {cfg.phase} dataset"):
            # Flatten each channel's stat dict into "<channel> <stat>" keys
            # so every sample becomes a single flat row.
            yield dict(
                (f"{channel} {stat}", value)
                for channel in ("Hoechst", "CD3", "CD8")
                for (stat, value) in compute_mask_stats(
                    data[f"{channel}_mask"].numpy().squeeze(axis=0)).items()
            )

    return pd.DataFrame(list(compute()))
def aggregate(df: pd.DataFrame, phase: str):
    """Write count/sum/mean/std/min/max aggregates of *df* to
    ``results/dataset_stats_<phase>.csv`` (creating ``results/`` if needed)."""
    out_dir = Path("results")
    out_dir.mkdir(exist_ok=True)
    summary = df.agg(["count", "sum", "mean", "std", "min", "max"])
    summary.to_csv(out_dir / f"dataset_stats_{phase}.csv")
if __name__ == "__main__":
initialize(config_path="conf", version_base="1.2")
dfs = []
for phase in "train", "test":
cfg = compose("config.yaml", overrides=["+experiment=compute_stats",
f"phase={phase}",
f"is_train={str(phase == 'train').lower()}"])
df = compute_stats(cfg)
dfs.append(df)
aggregate(df, phase)
aggregate(pd.concat(dfs), "all")