forked from mlcommons/GaNDLF
-
Notifications
You must be signed in to change notification settings - Fork 1
/
gandlf_collectStats
133 lines (106 loc) · 4.62 KB
/
gandlf_collectStats
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import argparse
import subprocess
import pathlib
from pathlib import Path
from datetime import date
import numpy as np
import pandas as pd
from io import StringIO
import seaborn as sns
import matplotlib.pyplot as plt
def _read_final_row(log_file, header):
    """Return the last non-blank line of ``log_file``, stripped of whitespace.

    Args:
        log_file: Path of the score log to read.
        header: Column header line; a log whose last non-blank line is just
            this header carries no scores.

    Returns:
        The final data row as a string without a trailing newline, or ``None``
        when the file is empty, holds only blank lines, or holds only the
        header.
    """
    final_row = None
    with open(log_file) as handle:
        for raw_line in handle:
            candidate = raw_line.strip()
            if candidate:  # trailing blank lines must not hide the real last row
                final_row = candidate
    if final_row is None or final_row == header:
        return None
    return final_row


def _collect_final_rows(input_dir, header):
    """Gather the final score row of every training log under ``input_dir``.

    Only sub-directories whose name contains "testing_" are visited; inside
    each of their sub-directories, every file whose name contains
    "trainingScores_log" contributes its last data row.

    Args:
        input_dir: Root directory holding the testing/validation tree.
        header: Column header line, used to recognise header-only logs.

    Returns:
        A list of CSV row strings (no trailing newlines).
    """
    rows = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the row
    # order of the output CSV reproducible between runs and machines.
    for testing_name in sorted(os.listdir(input_dir)):
        testing_dir = os.path.join(input_dir, testing_name)
        if "testing_" not in testing_name or not os.path.isdir(testing_dir):
            continue
        for validation_name in sorted(os.listdir(testing_dir)):
            validation_dir = os.path.join(testing_dir, validation_name)
            if not os.path.isdir(validation_dir):
                continue
            for file_name in sorted(os.listdir(validation_dir)):
                if "trainingScores_log" not in file_name:
                    continue
                row = _read_final_row(
                    os.path.join(validation_dir, file_name), header
                )
                if row is not None:
                    rows.append(row)
    return rows


def _draw_boxplot(data, ax, metric):
    """Draw one box per column of ``data`` on ``ax`` and label it for ``metric``."""
    bplot = sns.boxplot(data=data, width=0.5, palette="colorblind", ax=ax)
    # Fixed y-range for both panels; values outside [0, 1] fall out of view.
    bplot.set(ylim=(0, 1))
    bplot.set(xlabel="Dataset", ylabel=metric, title=metric + " plot")
    # Rotate the tick labels so that everything is visible.
    bplot.set_xticklabels(bplot.get_xticklabels(), rotation=15, ha="right")


def main():
    """Collect the final scores of every run and save them as a CSV and a plot.

    Command-line arguments:
        -inputDir: directory containing the "testing_*" sub-directories.
        -outputDir: directory that receives "data.csv" and "plot.png"
            (created if it does not exist).

    The last row of each "trainingScores_log" file is gathered, the "Epoch"
    column is dropped, the table is written to "data.csv", and side-by-side
    Dice and Loss box plots are written to "plot.png". When no scores are
    found, the header-only CSV is still written but the plot is skipped.
    """
    # NOTE(review): the contact address below looks redacted in this copy of
    # the file -- confirm the intended address before release.
    copyrightMessage = (
        "Contact: [email protected]\n\n"
        + "This program is NOT FDA/CE approved and NOT intended for clinical use.\nCopyright (c) "
        + str(date.today().year)
        + " University of Pennsylvania. All rights reserved."
    )
    parser = argparse.ArgumentParser(
        prog="GANDLF_CollectCSV",
        formatter_class=argparse.RawTextHelpFormatter,
        description="Collect statistics from different testing/validation combinations from output directory.\n\n"
        + copyrightMessage,
    )
    parser.add_argument(
        "-inputDir",
        type=str,
        help="Input directory which contains testing and validation models",
        required=True,
    )
    parser.add_argument(
        "-outputDir",
        type=str,
        help="Output directory to save stats and plot",
        required=True,
    )
    args = parser.parse_args()

    inputDir = os.path.normpath(args.inputDir)
    outputDir = os.path.normpath(args.outputDir)
    if not os.path.isdir(inputDir):
        # Fail with a usage error before creating any output.
        parser.error("inputDir is not an existing directory: " + inputDir)
    Path(outputDir).mkdir(parents=True, exist_ok=True)
    outputFile = os.path.join(outputDir, "data.csv")  # data file name
    outputPlot = os.path.join(outputDir, "plot.png")  # plot file

    # The columns that need to be present in the collected data; epoch is
    # always removed before saving.
    header = "Epoch,Train_Loss,Train_Dice,Val_Loss,Val_Dice,Testing_Loss,Testing_Dice"
    rows = _collect_final_rows(inputDir, header)

    # Join with explicit newlines so a log whose last line lacks a trailing
    # newline cannot be fused with the next log's row.
    csv_text = "\n".join([header] + rows) + "\n"
    data_full = pd.read_csv(StringIO(csv_text), sep=",")
    data_full = data_full.drop(columns="Epoch")  # no need for epoch
    data_full.to_csv(outputFile, index=False)  # save updated data

    if data_full.empty:
        print(
            "No training scores found under " + inputDir + "; skipping plot.",
            file=sys.stderr,
        )
        return

    datasets = ["Train", "Val", "Testing"]  # the datasets that need to be plotted
    # One frame per metric, with columns renamed to the bare dataset name so
    # the x-axis reads "Train", "Val", "Testing".
    data_dice = data_full[[name + "_Dice" for name in datasets]].rename(
        columns={name + "_Dice": name for name in datasets}
    )
    data_loss = data_full[[name + "_Loss" for name in datasets]].rename(
        columns={name + "_Loss": name for name in datasets}
    )

    fig, axes = plt.subplots(nrows=1, ncols=2, constrained_layout=True)
    _draw_boxplot(data_dice, axes[0], "Dice")
    _draw_boxplot(data_loss, axes[1], "Loss")
    fig.savefig(outputPlot, dpi=600)
    plt.close(fig)  # release the figure once it is on disk
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()