
Commit

Merge pull request #781 from mlcommons/dev
Dev -> main
priyakasimbeg authored Aug 29, 2024
2 parents 2d1ac6f + 3b832f4 commit bdece3b
Showing 4 changed files with 137 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -20,6 +20,7 @@ algorithmic_efficiency/workloads/librispeech_conformer/work_dir
*.vocab
wandb/
*.txt
scoring/plots/

!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv
112 changes: 112 additions & 0 deletions scoring/compute_speedups.py
@@ -0,0 +1,112 @@
"""File to compute speedups (i.e. geometric means between runtimes)."""

import pickle

from absl import app
from absl import flags
import numpy as np
import pandas as pd
from performance_profile import BASE_WORKLOADS
from performance_profile import get_workloads_time_to_target
from scipy import stats

flags.DEFINE_string('results_txt', None, 'Path to full scoring results file.')
flags.DEFINE_string(
    'base',
    'prize_qualification_baseline',
    'Base submission to compare to. Defaults to the `prize_qualification_baseline`.'
)
flags.DEFINE_string('comparison', None, 'Submission to compute the speedup of.')
flags.DEFINE_boolean('self_tuning_ruleset',
                     False,
                     'Whether the self-tuning ruleset is being scored.')
flags.DEFINE_boolean('save_results',
                     False,
                     'Whether to save the results to disk.')
FLAGS = flags.FLAGS

MAX_BUDGETS = {
    'criteo1tb': 7703,
    'fastmri': 8859,
    'imagenet_resnet': 63_008,
    'imagenet_vit': 77_520,
    'librispeech_conformer': 61_068,
    'librispeech_deepspeech': 55_506,
    'ogbg': 18_477,
    'wmt': 48_151,
}


def replace_inf(row):
  """Replace infs with the maximum runtime budget (+1 second).

  Args:
    row (pd.Series): The original row.

  Returns:
    pd.Series: The row with infs replaced.
  """
  workload_name = row.name
  # Factor of 3 for self-tuning ruleset
  factor = 3 if FLAGS.self_tuning_ruleset else 1
  max_runtime_workload = factor * MAX_BUDGETS[workload_name]
  row.replace(np.inf, max_runtime_workload + 1, inplace=True)
  return row


def compute_speedup():
  """Compute speedup between two algorithms."""
  # Load results from disk
  with open(FLAGS.results_txt, 'rb') as f:
    results = pickle.load(f)

  # Compute median over runtimes for both training algorithms
  base_results = get_workloads_time_to_target(
      results[FLAGS.base],
      FLAGS.base,
      time_col="score",
      self_tuning_ruleset=FLAGS.self_tuning_ruleset,
  )
  comparison_results = get_workloads_time_to_target(
      results[FLAGS.comparison],
      FLAGS.comparison,
      time_col="score",
      self_tuning_ruleset=FLAGS.self_tuning_ruleset,
  )

  # Merge results
  merged_results = pd.concat([base_results, comparison_results]).transpose()

  # Ignore workload variants (only consider base workloads) for the speedup
  merged_results = merged_results.loc[merged_results.index.isin(BASE_WORKLOADS)]

  # Replace infs with the maximum runtime budget (+1 second)
  merged_results = merged_results.apply(replace_inf, axis=1)

  # Compute the per-workload speedup
  merged_results['speedup'] = merged_results[
      f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}']
  speedups = merged_results['speedup'].to_numpy()
  mean_speedup = stats.gmean(speedups)  # Geometric mean over workload speedups

  print(merged_results, end='\n\n')
  print(f"Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: "
        f"{mean_speedup}, i.e. roughly {(1 - mean_speedup):.1%} less runtime.")

  if FLAGS.save_results:
    # Optionally save results to disk
    print("Saving results to disk...")
    filename = f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1 - mean_speedup):.1%}.csv'
    merged_results.to_csv(filename)


def main(_):
  """Main function to compute the speedup between two algorithms."""
  compute_speedup()


if __name__ == '__main__':
  flags.mark_flag_as_required('results_txt')
  flags.mark_flag_as_required('comparison')
  app.run(main)
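
For illustration only (toy runtimes and an invented submission name, not part of this commit), a minimal sketch of the quantity the new script reports, i.e. the geometric mean of per-workload runtime ratios:

import pandas as pd
from scipy import stats

# Hypothetical time-to-target runtimes in seconds (illustrative only).
runtimes = pd.DataFrame(
    {
        'prize_qualification_baseline': [7000.0, 8000.0, 40000.0],
        'my_submission': [5600.0, 6400.0, 30000.0],
    },
    index=['criteo1tb', 'fastmri', 'wmt'])

# Per-workload speedup: comparison runtime divided by base runtime.
runtimes['speedup'] = (
    runtimes['my_submission'] / runtimes['prize_qualification_baseline'])

# Geometric mean over workloads, as in compute_speedup() above.
mean_speedup = stats.gmean(runtimes['speedup'].to_numpy())
print(f'Geometric-mean speedup: {mean_speedup:.3f} '
      f'(roughly {(1 - mean_speedup):.1%} less runtime)')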
25 changes: 23 additions & 2 deletions scoring/performance_profile.py
@@ -26,6 +26,7 @@
the dictionary of submissions.
"""
import itertools
import json
import operator
import os
import re
@@ -45,6 +46,10 @@
BASE_WORKLOADS = workloads_registry.BASE_WORKLOADS
WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/'
# Open the JSON file to read the held-out workloads.
# TODO: This probably shouldn't be hardcoded but passed as an argument.
with open("held_out_workloads_algoperf_v05.json", "r") as f:
  HELDOUT_WORKLOADS = json.load(f)
# These global variables have to be set according to the current set of
# workloads and rules for the scoring to be correct.
# We do not use the workload registry since it contains test and development
@@ -248,6 +253,9 @@ def filter(x):
    try:
      if x[variant_workload] == np.inf:
        return np.inf
      # Also check for nan values (e.g. OOMs)
      elif np.isnan(x[variant_workload]):
        return np.inf
      else:
        return x[base_workload]
    except KeyError as e:
@@ -306,19 +314,32 @@ def compute_performance_profiles(submissions,
self_tuning_ruleset,
strict))
  df = pd.concat(dfs)
  # Restrict to base and sampled held-out workloads
  # (ignore the additional workload variants of the baseline
  # as they cause issues when checking for nans in workload variants).
  df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS]
  # Sort workloads alphabetically (for better display)
  df = df.reindex(sorted(df.columns), axis=1)

  # For each held-out workload, set the score to inf if the base workload
  # score is inf or nan.
  for workload in df.keys():
    if workload not in BASE_WORKLOADS:
      # If the base workload does not have a finite score,
      # set the variant score to inf.
      base_workload = get_base_workload_name(workload)
      df[workload] = df.apply(
          variant_criteria_filter(workload, base_workload), axis=1)

  # Set the score to inf if it is not within 4x of the fastest submission.
  best_scores = df.min(axis=0)
  df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf

  # For each held-out workload if variant target was not hit set submission to inf
  # For each base workload, if the variant target was not hit,
  # set the submission's score to inf.
  for workload in df.keys():
    if workload not in BASE_WORKLOADS:
      # If the variant does not have a finite score, set the base workload
      # score to inf.
      base_workload = get_base_workload_name(workload)
      df[base_workload] = df.apply(
          variant_criteria_filter(base_workload, workload), axis=1)

  df = df[BASE_WORKLOADS]

if verbosity > 0:
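
To illustrate the held-out-variant handling added in this hunk, here is a rough sketch with invented workload names and toy scores (not the repository's actual helper): a variant without a finite score (inf, or nan from e.g. an OOM) invalidates its base workload, mirroring variant_criteria_filter.

import numpy as np
import pandas as pd

# Toy time-to-target scores (rows: submissions, columns: workloads).
df = pd.DataFrame(
    {
        'fastmri': [8000.0],
        'fastmri_variant': [np.nan],  # hypothetical variant that OOMed
        'ogbg': [15000.0],
        'ogbg_variant': [16000.0],
    },
    index=['my_submission'])


def variant_filter(row, variant, base):
  """Return the base-workload score, or inf if the variant is inf or nan."""
  if np.isinf(row[variant]) or np.isnan(row[variant]):
    return np.inf
  return row[base]


# A variant without a finite score sets its base workload to inf.
df['fastmri'] = df.apply(variant_filter, axis=1, variant='fastmri_variant', base='fastmri')
df['ogbg'] = df.apply(variant_filter, axis=1, variant='ogbg_variant', base='ogbg')
print(df[['fastmri', 'ogbg']])  # fastmri -> inf, ogbg keeps 15000.0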
2 changes: 1 addition & 1 deletion scoring/score_submissions.py
@@ -198,7 +198,7 @@ def main(_):
results,
time_col='score',
min_tau=1.0,
max_tau=None,
max_tau=4.0,
reference_submission_tag=None,
num_points=100,
scale='linear',
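
For context, a performance profile reports, for each ratio tau, the fraction of workloads on which a submission's runtime is within tau times the best runtime across submissions; setting max_tau=4.0 caps that ratio axis, in line with the 4x cutoff applied above. A rough, self-contained sketch of the idea with toy numbers (not the repository's implementation):

import numpy as np

# Toy time-to-target runtimes (rows: submissions, columns: workloads).
runtimes = np.array([
    [7000.0, 8000.0, 40000.0],  # submission A
    [5600.0, 9600.0, np.inf],   # submission B missed one target
])

best = runtimes.min(axis=0)        # best runtime per workload
ratios = runtimes / best           # performance ratios
taus = np.linspace(1.0, 4.0, 100)  # min_tau=1.0, max_tau=4.0

# Fraction of workloads solved within tau times the best, per submission.
profiles = np.stack([(ratios <= tau).mean(axis=1) for tau in taus], axis=1)
print(profiles[:, 0], profiles[:, -1])  # values at tau=1 and tau=4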
