Skip to content

Commit

Permalink
Merge pull request #537 from chandramouli-sastry/dev
Browse files Browse the repository at this point in the history
Update scoring-utils to collect results from workdirs run with different timestamps
  • Loading branch information
priyakasimbeg authored Nov 30, 2023
2 parents 8e21b4e + 157fead commit 98146be
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 48 deletions.
77 changes: 41 additions & 36 deletions scoring/scoring_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import glob
import json
import os
import re
Expand All @@ -12,6 +13,7 @@
METRICS_LINE_REGEX = '(.*) Metrics: ({.*})'
TRIAL_DIR_REGEX = 'trial_(\d+)'
MEASUREMENTS_FILENAME = 'eval_measurements.csv'
TIMESTAMP = r"-\d{4}(-\d{2}){5}"

WORKLOADS = workloads_registry.WORKLOADS
WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
Expand Down Expand Up @@ -143,7 +145,9 @@ def get_experiment_df(experiment_dir):
performance_profile.compute_performance_profiles.
Args:
experiment_dir: path to experiment directory containing
results for workloads.
results for workloads. Measurements from experiments
sharing the same prefix but different timestamps are
collected together.
The directory structure is assumed to be:
+ experiment_dir
+ <workload>
Expand All @@ -154,43 +158,44 @@ def get_experiment_df(experiment_dir):
df: DataFrame where indices are trials, columns are
metric names and values are lists.
e.g
+----+-----------+---------+--------------------+--------------------+
| | workload | trial | validation/accuracy| score |
|----+-----------+---------+--------------------+--------------------|
| 0 | mnist_jax | trial_1 | [0.0911, 0.0949] | [10.6396, 10.6464] |
+----+-----------+---------+--------------------+--------------------+
+----+-----------+-----------------------------+--------------------+--------------------+
| | workload | trial | validation/accuracy| score |
|----+-----------+-----------------------------+--------------------+--------------------|
| 0 | mnist_jax | (trial_1, <experiment_dir>) | [0.0911, 0.0949] | [10.6396, 10.6464] |
+----+-----------+-----------------------------+--------------------+--------------------+
"""
df = pd.DataFrame()
workload_dirs = os.listdir(experiment_dir)
num_workloads = len(workload_dirs)
for workload in workload_dirs:
data = {
'workload': workload,
}
trial_dirs = [
t for t in os.listdir(os.path.join(experiment_dir, workload))
if re.match(TRIAL_DIR_REGEX, t)
]
workload_df = pd.DataFrame()
for trial in trial_dirs:
eval_measurements_filepath = os.path.join(
experiment_dir,
workload,
trial,
MEASUREMENTS_FILENAME,
)
try:
trial_df = pd.read_csv(eval_measurements_filepath)
except FileNotFoundError:
logging.info(f'Could not read {eval_measurements_filepath}')
continue
data['trial'] = trial
for column in trial_df.columns:
values = trial_df[column].to_numpy()
data[column] = values
trial_df = pd.DataFrame([data])
workload_df = pd.concat([workload_df, trial_df], ignore_index=True)
df = pd.concat([df, workload_df], ignore_index=True)
paths = filter(
lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir,
glob.glob(f"{experiment_dir}*"))
for experiment_dir in list(paths):
workload_dirs = os.listdir(experiment_dir)
for workload in workload_dirs:
data = {
'workload': workload,
}
trial_dirs = [
t for t in os.listdir(os.path.join(experiment_dir, workload))
if re.match(TRIAL_DIR_REGEX, t)
]
for trial in trial_dirs:
eval_measurements_filepath = os.path.join(
experiment_dir,
workload,
trial,
MEASUREMENTS_FILENAME,
)
try:
trial_df = pd.read_csv(eval_measurements_filepath)
except FileNotFoundError as e:
logging.info(f'Could not read {eval_measurements_filepath}')
continue
data['trial'] = (trial, experiment_dir)
for column in trial_df.columns:
values = trial_df[column].to_numpy()
data[column] = values
trial_df = pd.DataFrame([data])
df = pd.concat([df, trial_df], ignore_index=True)
return df


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples
3.3074607849121094,0.0,10.884385585784912,1,0,10.884385585784912,0.1094000041484832,2.584994316101074,10000,14.191919088363647,0.1090000048279762,2.5650947093963623,0.1155000030994415,2.579528570175171,10000
3.337209701538086,0.0272295475006103,20.800366163253784,2260,0,20.800366163253784,0.8156000375747681,0.8071792125701904,10000,24.222018003463745,0.8114000558853149,0.8132649064064026,0.8164000511169434,0.8204169273376465,10000
3.3682632446289062,0.0540821552276611,30.71659564971924,4595,0,30.71659564971924,0.8079000115394592,0.728191077709198,10000,34.255709409713745,0.8051000237464905,0.7325593829154968,0.8007000088691711,0.7405639290809631,10000
3.392564058303833,0.079397439956665,40.63307213783264,6971,0,40.63307213783264,0.8212000131607056,0.6618639826774597,10000,44.28224039077759,0.8146000504493713,0.6732342839241028,0.8176000118255615,0.6771405339241028,10000
3.4162850379943848,0.1048784255981445,50.54938673973084,9310,0,50.54938673973084,0.8224000334739685,0.6520460844039917,10000,54.30677127838135,0.8114000558853149,0.6680686473846436,0.8162000179290771,0.6716861724853516,10000
3.44294548034668,0.1316289901733398,60.46591234207153,11640,0,60.46591234207153,0.8291000127792358,0.612684965133667,10000,64.33635830879211,0.8278000354766846,0.6238877773284912,0.829200029373169,0.6163886785507202,10000
3.4705374240875244,0.1571955680847168,70.38250923156738,13962,0,70.38250923156738,0.8330000638961792,0.6149598956108093,10000,74.36598539352417,0.8288000226020813,0.6226670742034912,0.829800009727478,0.6122417449951172,10000
3.496747732162476,0.1836600303649902,80.29951977729797,16249,0,80.29951977729797,0.8243000507354736,0.6327693462371826,10000,84.3950743675232,0.8252000212669373,0.644512414932251,0.8238000273704529,0.6446439623832703,10000
3.522087574005127,0.20969820022583,90.21488404273988,18545,0,90.21488404273988,0.8294000625610352,0.6092277765274048,10000,94.4209051132202,0.8259000182151794,0.6247073411941528,0.8263000249862671,0.6222184896469116,10000
3.549894332885742,0.2363710403442382,100.13060092926024,20791,0,100.13060092926024,0.8614000678062439,0.52787846326828,10000,104.44931364059448,0.8500000238418579,0.553955078125,0.858500063419342,0.5388756990432739,10000
3.5757312774658203,0.2636079788208008,110.0454170703888,23088,0,110.0454170703888,0.8564000129699707,0.5537337064743042,10000,114.4781494140625,0.8472000360488892,0.5771669745445251,0.8512000441551208,0.5673128366470337,10000
3.600522756576538,0.2896809577941894,119.95860981941225,25491,0,119.95860981941225,0.8591000437736511,0.5227596163749695,10000,124.5031967163086,0.8520000576972961,0.5452638268470764,0.8614000678062439,0.5311681628227234,10000
3.625296115875244,0.3164572715759277,129.8751039505005,27834,0,129.8751039505005,0.8648000359535217,0.5044941306114197,10000,134.5310184955597,0.8707000613212585,0.4973550140857696,0.872700035572052,0.5061097741127014,10000
3.652939796447754,0.3452167510986328,139.7876615524292,30160,0,139.7876615524292,0.8766000270843506,0.4868762791156769,10000,144.55956268310547,0.8693000674247742,0.5040879845619202,0.876300036907196,0.4923528432846069,10000
3.67879319190979,0.3744730949401855,149.7024416923523,32448,0,149.7024416923523,0.8764000535011292,0.4738911986351013,10000,154.5886163711548,0.8797000646591187,0.4765508472919464,0.8817000389099121,0.4693307876586914,10000
3.70514726638794,0.401404857635498,159.61914157867432,34683,0,159.61914157867432,0.8802000284194946,0.4768852889537811,10000,164.61644506454468,0.8692000508308411,0.4997503161430359,0.8842000365257263,0.4773700535297394,10000
3.73129153251648,0.4311785697937011,169.52961444854736,36914,0,169.52961444854736,0.8880000710487366,0.4485557973384857,10000,174.64349031448364,0.8759000301361084,0.4759842455387115,0.8856000304222107,0.4616841077804565,10000
3.756492137908936,0.4605388641357422,179.44147443771362,39160,0,179.44147443771362,0.8827000260353088,0.4693212509155273,10000,184.6703307628632,0.8759000301361084,0.4877949357032776,0.8823000192642212,0.47818323969841,10000
3.782692193984986,0.4901380538940429,189.3587987422943,41392,0,189.3587987422943,0.8880000710487366,0.4533746242523193,10000,194.70058631896973,0.8836000561714172,0.4600184261798858,0.8873000144958496,0.4541208744049072,10000
3.809891700744629,0.5197367668151855,199.274918794632,43666,0,199.274918794632,0.8891000151634216,0.4392279982566833,10000,204.731507062912,0.8875000476837158,0.4534733593463897,0.8879000544548035,0.4517350792884826,10000
3.8350577354431152,0.5492231845855713,209.1874165534973,45945,0,209.1874165534973,0.8857000470161438,0.4512746930122375,10000,214.7579679489136,0.8837000131607056,0.4590539634227752,0.8818000555038452,0.4652692973613739,10000
3.85998797416687,0.5787074565887451,219.0983612537384,48237,0,219.0983612537384,0.89000004529953,0.4437272548675537,10000,224.7834641933441,0.8831000328063965,0.4595882892608642,0.8875000476837158,0.4498450458049774,10000
3.885103702545166,0.6087353229522705,229.0138251781464,50487,0,229.0138251781464,0.894800066947937,0.4131699800491333,10000,234.81255435943604,0.8968000411987305,0.4132199287414551,0.8938000202178955,0.4232735633850097,10000
3.911269426345825,0.6387021541595459,238.9281618595124,52699,0,238.9281618595124,0.9029000401496888,0.4060218632221222,10000,244.8411045074463,0.8994000554084778,0.4132817685604095,0.8960000276565552,0.4201119840145111,10000
3.936365365982056,0.668968677520752,248.83987426757807,54968,0,248.83987426757807,0.9014000296592712,0.3983155488967895,10000,254.8668367862701,0.8975000381469727,0.4028871953487396,0.9005000591278076,0.4035448133945465,10000
3.963345289230346,0.6995627880096436,258.7516326904297,57238,0,258.7516326904297,0.9070000648498536,0.3919171690940857,10000,264.8952040672302,0.905500054359436,0.3880215883255005,0.9032000303268432,0.3942824900150299,10000
3.9895851612091056,0.7304470539093018,268.66518568992615,59479,0,268.66518568992615,0.910700023174286,0.3861576020717621,10000,274.9236674308777,0.903700053691864,0.4019041359424591,0.9068000316619872,0.3913732171058655,10000
4.01618766784668,0.7612087726593018,278.5771634578705,61780,0,278.5771634578705,0.9115000367164612,0.3799691796302795,10000,284.9529445171356,0.9104000329971312,0.3839159905910492,0.9086000323295592,0.3893324434757232,10000
4.04093861579895,0.7921364307403564,288.49143171310425,64050,0,288.49143171310425,0.9142000675201416,0.3721667230129242,10000,294.9814648628235,0.9113000631332396,0.374419093132019,0.9112000465393066,0.3832390904426574,10000
4.066902160644531,0.8264949321746826,298.4053153991699,66301,0,298.4053153991699,0.9122000336647034,0.372380793094635,10000,305.0141706466675,0.9129000306129456,0.3773751556873321,0.9122000336647034,0.3790977895259857,10000
4.0928122997283936,0.8593041896820068,308.3142282962799,68549,0,308.3142282962799,0.9160000681877136,0.3669502735137939,10000,315.04016852378845,0.9104000329971312,0.3748669624328613,0.9105000495910645,0.3784786462783813,10000
4.1184492111206055,0.8911542892456055,318.22746777534485,70800,0,318.22746777534485,0.9153000712394714,0.3636593520641327,10000,325.0690577030182,0.9185000658035278,0.3687613904476166,0.910800039768219,0.3743433952331543,10000
4.146052122116089,0.922600030899048,328.1395530700684,72999,0,328.1395530700684,0.9163000583648682,0.3645511567592621,10000,335.09958815574646,0.916100025177002,0.364233136177063,0.9119000434875488,0.3745094835758209,10000
4.172115802764893,0.9543182849884032,338.0537300109863,75213,0,338.0537300109863,0.9157000184059144,0.3641790449619293,10000,345.1285364627838,0.9118000268936156,0.3679656088352203,0.9119000434875488,0.3735902607440948,10000
4.197761535644531,0.9881937503814696,347.9646620750427,77485,0,347.9646620750427,0.9155000448226928,0.3638691902160644,10000,355.1568307876587,0.9173000454902648,0.3572725653648376,0.9121000170707704,0.3734744191169739,10000
4.22634482383728,1.0207302570343018,357.88095235824585,79736,0,357.88095235824585,0.915600061416626,0.3638694882392883,10000,365.1885812282562,0.9158000349998474,0.3646068274974823,0.9122000336647034,0.3734698891639709,10000
4.25168514251709,1.0548341274261477,367.7923603057861,81999,0,367.7923603057861,0.915600061416626,0.3638694882392883,10000,375.2162811756134,0.9165000319480896,0.3642504215240478,0.9122000336647034,0.3734698891639709,10000
4.2778027057647705,1.0879690647125244,377.7043735980988,84267,0,377.7043735980988,0.915600061416626,0.3638694882392883,10000,385.2464287281037,0.914500057697296,0.3697480857372284,0.9122000336647034,0.3734698891639709,10000
4.303594589233398,1.1212027072906494,387.61431884765625,86553,0,387.61431884765625,0.915600061416626,0.3638694882392883,10000,395.2737421989441,0.917400062084198,0.3631039559841156,0.9122000336647034,0.3734698891639709,10000
4.330204248428345,1.1557085514068604,397.5266599655152,88769,0,397.5266599655152,0.915600061416626,0.3638694882392883,10000,405.3015666007996,0.907300055027008,0.3805611133575439,0.9122000336647034,0.3734698891639709,10000
4.3575522899627686,1.1906015872955322,407.4391739368439,90941,0,407.4391739368439,0.915600061416626,0.3638694882392883,10000,415.3328382968903,0.9160000681877136,0.3628231287002563,0.9122000336647034,0.3734698891639709,10000
4.3869407176971436,1.2255029678344727,417.34814286231995,93156,0,417.34814286231995,0.915600061416626,0.36386948823928833,10000,425.36263489723206,0.9160000681877136,0.3684694468975067,0.9122000336647034,0.37346988916397095,10000
26 changes: 14 additions & 12 deletions scoring/test_scoring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@ def test_get_trials_df(self):
self.assertEqual(len(df.at['1', column]), NUM_EVALS)

def test_get_experiment_df(self):
_ = scoring_utils.get_experiment_df(TEST_DIR)
self.assertWarnsRegex(
Warning,
f'There should be {NUM_WORKLOADS} workloads but there are 1.',
scoring_utils.get_experiment_df,
TEST_DIR)
self.assertWarnsRegex(
Warning,
f'There should be {NUM_TRIALS} trials for workload mnist_jax but there '
'are only 1.',
scoring_utils.get_experiment_df,
TEST_DIR)
df = scoring_utils.get_experiment_df(TEST_DIR)
assert len(df) == 2

def test_scores(self):
df = scoring_utils.get_experiment_df(TEST_DIR)
performance_profile_df = scoring.compute_performance_profiles(
{'my.submission': df},
time_col='score',
min_tau=1.0,
max_tau=None,
reference_submission_tag=None,
num_points=100,
scale='linear',
verbosity=0)


if __name__ == '__main__':
Expand Down

0 comments on commit 98146be

Please sign in to comment.