Commit

Merge branch 'mlcommons:main' into scoring
priyakasimbeg authored Oct 15, 2024
2 parents a0e4502 + ef77fc4 commit 790c282
Showing 6 changed files with 290 additions and 58 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -20,6 +20,7 @@ algorithmic_efficiency/workloads/librispeech_conformer/work_dir
*.vocab
wandb/
*.txt
scoring/plots/

!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv
2 changes: 1 addition & 1 deletion CALL_FOR_SUBMISSIONS.md
@@ -17,7 +17,7 @@ Submissions can compete under two hyperparameter tuning rulesets (with separate
- **Registration deadline to express non-binding intent to submit: February 28th, 2024**.\
Please fill out the (mandatory but non-binding) [**registration form**](https://forms.gle/K7ty8MaYdi2AxJ4N8).
- **Submission deadline: April 04th, 2024** *(moved by a week from the initial March 28th, 2024)*
- [tentative] Announcement of all results: July 15th, 2024
- [Announcement of all results](https://mlcommons.org/2024/08/mlc-algoperf-benchmark-competition/): August 1st, 2024

For a detailed and up-to-date timeline see the [Competition Rules](/COMPETITION_RULES.md).

5 changes: 1 addition & 4 deletions README.md
@@ -27,10 +27,7 @@
---

> [!IMPORTANT]
> Submitters are no longer required to self-report results.
> We are currently in the process of evaluating and scoring received submissions.
> We are aiming to release results by July 15th 2024.
> For other key dates please see [Call for Submissions](CALL_FOR_SUBMISSIONS.md).
> The results of the inaugural AlgoPerf: Training Algorithms benchmark competition have been announced. See the [MLCommons blog post](https://mlcommons.org/2024/08/mlc-algoperf-benchmark-competition/) for an overview and the [results page](https://mlcommons.org/benchmarks/algorithms/) for more details on the results. We are currently preparing an in-depth analysis of the results in the form of a paper and plan the next iteration of the benchmark competition.
## Table of Contents <!-- omit from toc -->

112 changes: 112 additions & 0 deletions scoring/compute_speedups.py
@@ -0,0 +1,112 @@
"""File to compute speedups (i.e. geometric means between runtimes)."""

import pickle

from absl import app
from absl import flags
import numpy as np
import pandas as pd
from performance_profile import BASE_WORKLOADS
from performance_profile import get_workloads_time_to_target
from scipy import stats

flags.DEFINE_string('results_txt', None, 'Path to full scoring results file.')
flags.DEFINE_string(
'base',
'prize_qualification_baseline',
'Base submission to compare to. Defaults to the `prize_qualification_baseline`.'
)
flags.DEFINE_string('comparison', None, 'Submission to compute the speedup of.')
flags.DEFINE_boolean('self_tuning_ruleset',
False,
'Whether the self-tuning ruleset is being scored.')
flags.DEFINE_boolean('save_results',
False,
'Whether to save the results to disk.')
FLAGS = flags.FLAGS

MAX_BUDGETS = {
'criteo1tb': 7703,
'fastmri': 8859,
'imagenet_resnet': 63_008,
'imagenet_vit': 77_520,
'librispeech_conformer': 61_068,
'librispeech_deepspeech': 55_506,
'ogbg': 18_477,
'wmt': 48_151,
}


def replace_inf(row):
"""Replace ifs with maximum runtime budget (+1 second).
Args:
row (pd.Series): The original row.
Returns:
pd.Series: The row with infs replaced.
"""
workload_name = row.name
# Factor of 3 for self-tuning ruleset
factor = 3 if FLAGS.self_tuning_ruleset else 1
max_runtime_workload = factor * MAX_BUDGETS[workload_name]
row.replace(np.inf, max_runtime_workload + 1, inplace=True)
return row


def compute_speedup():
"""Compute speedup between two algorithms."""
# Load results from disk
with open(FLAGS.results_txt, 'rb') as f:
results = pickle.load(f)

# Compute median over runtimes for both training algorithms
base_results = get_workloads_time_to_target(
results[FLAGS.base],
FLAGS.base,
time_col="score",
self_tuning_ruleset=FLAGS.self_tuning_ruleset,
)
comparison_results = get_workloads_time_to_target(
results[FLAGS.comparison],
FLAGS.comparison,
time_col="score",
self_tuning_ruleset=FLAGS.self_tuning_ruleset,
)

# Merge results
merged_results = pd.concat([base_results, comparison_results]).transpose()

# Ignore workload variants (only consider base workloads) for speedup
merged_results = merged_results.loc[merged_results.index.isin(BASE_WORKLOADS)]

# Replace infs with maximum runtime budget (+1 second)
merged_results = merged_results.apply(replace_inf, axis=1)

# Compute speedup
merged_results['speedup'] = merged_results[
f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}']
speedups = merged_results['speedup'].to_numpy()
mean_speedup = stats.gmean(speedups) # Geometric mean over workload speedups

print(merged_results, end='\n\n')
print(
f"Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1-mean_speedup):.1%}"
)

if FLAGS.save_results:
# Optionally save results to disk
print("Saving results to disk...")
filename = f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1-mean_speedup):.1%}.csv'
merged_results.to_csv(filename)


def main(_):
"""Main function to compute speedup between two algorithms."""
compute_speedup()


if __name__ == '__main__':
flags.mark_flag_as_required('results_txt')
flags.mark_flag_as_required('comparison')
app.run(main)
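For reference, given the flags defined above, the script would presumably be invoked as something like `python3 scoring/compute_speedups.py --results_txt=<path to pickled results> --comparison=<submission name>`, with `--base` defaulting to `prize_qualification_baseline`. The core of the calculation, a geometric mean over per-workload runtime ratios, can be sketched on toy data as follows (the workload names and runtimes below are invented for illustration, not benchmark results):

```python
import pandas as pd
from scipy import stats

# Hypothetical per-workload times-to-target (seconds) for a baseline and a
# comparison submission; in the script, infs (target never reached) are first
# clamped to the workload's runtime budget + 1 second via replace_inf.
runtimes = pd.DataFrame(
    {
        'baseline': [7703.0, 8859.0, 18477.0],
        'comparison': [5000.0, 9000.0, 12000.0],
    },
    index=['criteo1tb', 'fastmri', 'ogbg'])

# Per-workload ratio of comparison to baseline runtime, then the geometric
# mean, mirroring the `speedup` column and `stats.gmean` call above.
ratios = runtimes['comparison'] / runtimes['baseline']
mean_speedup = stats.gmean(ratios.to_numpy())
print(ratios)
print(f'Geometric-mean runtime ratio: {mean_speedup:.3f} '
      f'(roughly {(1 - mean_speedup):.1%} faster on average)')
```

A ratio below 1 means the comparison submission reaches its targets faster than the baseline on average; the printed percentage follows the same `1 - mean_speedup` convention used in the script.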
115 changes: 83 additions & 32 deletions scoring/performance_profile.py
@@ -26,14 +26,17 @@
the dictionary of submissions.
"""
import itertools
import json
import operator
import os
import re

from absl import logging
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tabulate import tabulate

from algorithmic_efficiency.workloads.workloads import get_base_workload_name
import algorithmic_efficiency.workloads.workloads as workloads_registry
@@ -43,6 +46,10 @@
BASE_WORKLOADS = workloads_registry.BASE_WORKLOADS
WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/'
# Open json file to read heldout workloads
# TODO: This probably shouldn't be hardcoded but passed as an argument.
with open("held_out_workloads_algoperf_v05.json", "r") as f:
HELDOUT_WORKLOADS = json.load(f)
# These global variables have to be set according to the current set of
# workloads and rules for the scoring to be correct.
# We do not use the workload registry since it contains test and development
@@ -63,6 +70,37 @@

MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']

#MPL params
mpl.rcParams['figure.figsize'] = (16, 10) # Width, height in inches
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = [
'Times New Roman'
] + mpl.rcParams['font.serif'] # Add Times New Roman as first choice
mpl.rcParams['font.size'] = 22
mpl.rcParams['savefig.dpi'] = 300 # Set resolution for saved figures

# Plot Elements
mpl.rcParams['lines.linewidth'] = 3 # Adjust line thickness if needed
mpl.rcParams['lines.markersize'] = 6 # Adjust marker size if needed
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(
color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
"#9467bd"]) # Example color cycle (consider ColorBrewer or viridis)
mpl.rcParams['axes.labelsize'] = 22 # Axis label font size
mpl.rcParams['xtick.labelsize'] = 20 # Tick label font size
mpl.rcParams['ytick.labelsize'] = 20

# Legends and Gridlines
mpl.rcParams['legend.fontsize'] = 20 # Legend font size
mpl.rcParams[
'legend.loc'] = 'best' # Let matplotlib decide the best legend location
mpl.rcParams['axes.grid'] = True # Enable grid
mpl.rcParams['grid.alpha'] = 0.4 # Gridline transparency


def print_dataframe(df):
tabulated_df = tabulate(df.T, headers='keys', tablefmt='psql')
logging.info(tabulated_df)


def generate_eval_cols(metrics):
splits = ['train', 'validation']
@@ -150,10 +188,10 @@ def get_workloads_time_to_target(submission,
if strict:
raise ValueError(
f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
f'but found {num_workloads} workloads.')
f'but found {num_workloads} workloads for {submission_name}.')
logging.warning(
f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
f'but found {num_workloads} workloads.')
f'but found {num_workloads} workloads for {submission_name}.')

# For each workload, get the submission times to target.
for workload, group in submission.groupby('workload'):
@@ -164,11 +202,13 @@
num_studies = len(group.groupby('study'))
if num_studies != NUM_STUDIES:
if strict:
raise ValueError(f'Expecting {NUM_STUDIES} trials for workload '
f'{workload} but found {num_studies} trials.')
raise ValueError(f'Expecting {NUM_STUDIES} studies for workload '
f'{workload} but found {num_studies} studies '
f'for {submission_name}.')
else:
logging.warning(f'Expecting {NUM_STUDIES} trials for workload '
f'{workload} but found {num_studies} trials.')
logging.warning(f'Expecting {NUM_STUDIES} studies for workload '
f'{workload} but found {num_studies} studies '
f'for {submission_name}.')

# For each study check trials
for study, group in group.groupby('study'):
@@ -177,11 +217,15 @@
num_trials = len(group)
if num_trials != NUM_TRIALS and not self_tuning_ruleset:
if strict:
raise ValueError(f'Expecting {NUM_TRIALS} trials for workload '
f'{workload} but found {num_trials} trials.')
raise ValueError(
f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
f'{workload} but found {num_trials} trials '
f'for {submission_name}.')
else:
logging.warning(f'Expecting {NUM_TRIALS} trials for workload '
f'{workload} but found {num_trials} trials.')
logging.warning(
f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
f'{workload} but found {num_trials} trials '
f'for {submission_name}.')

# Get trial and time index that reaches target
trial_idx, time_idx = get_best_trial_index(
@@ -194,13 +238,12 @@

workloads.append({
'submission': submission_name,
'workload': workload,
'workload': re.sub(r'_(jax|pytorch)$', '', workload),
time_col: np.median(time_vals_per_study),
})

df = pd.DataFrame.from_records(workloads)
df = df.pivot(index='submission', columns='workload', values=time_col)

return df


Expand All @@ -210,6 +253,9 @@ def filter(x):
try:
if x[variant_workload] == np.inf:
return np.inf
# Also check for nan values (e.g. OOMs)
elif np.isnan(x[variant_workload]):
return np.inf
else:
return x[base_workload]
except KeyError as e:
@@ -268,27 +314,33 @@ def compute_performance_profiles(submissions,
self_tuning_ruleset,
strict))
df = pd.concat(dfs)
# Restrict to base and sampled held-out workloads
# (ignore the additional workload variants of the baseline
# as they cause issues when checking for nans in workload variants).
df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS]
# Sort workloads alphabetically (for better display)
df = df.reindex(sorted(df.columns), axis=1)

# For each held-out workload, set its score to inf if the base workload's score is inf or nan
for workload in df.keys():
if workload not in BASE_WORKLOADS:
# If the base workload does not have a finite score, set the variant's score to inf
base_workload = get_base_workload_name(workload)
df[workload] = df.apply(
variant_criteria_filter(workload, base_workload), axis=1)

# Set score to inf if not within 4x of fastest submission
best_scores = df.min(axis=0)
df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf

# For each held-out workload if variant target was not hit set submission to inf
framework = None
# For each base workload if variant target was not hit set submission to inf
for workload in df.keys():
# Check if this is a variant
framework = workload.split('_')[-1]
workload_ = workload.split(f'_{framework}')[0]
if workload_ not in BASE_WORKLOADS:
if workload not in BASE_WORKLOADS:
# If the variant does not have a finite score, set the base workload's score to inf
base_workload = get_base_workload_name(workload_)
base_workload = get_base_workload_name(workload)
df[base_workload] = df.apply(
variant_criteria_filter(base_workload + f'_{framework}', workload),
axis=1)

base_workloads = [w + f'_{framework}' for w in BASE_WORKLOADS]
df = df[base_workloads]
print(df)
variant_criteria_filter(base_workload, workload), axis=1)
df = df[BASE_WORKLOADS]

if verbosity > 0:
logging.info('\n`{time_col}` to reach target:')
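As an aside, the "within 4x of the fastest submission" filter in the hunk above can be checked on a toy score table (submission and workload names below are invented for illustration):

```python
import numpy as np
import pandas as pd

# Rows are submissions, columns are workloads; values are times-to-target.
df = pd.DataFrame(
    {
        'workload_a': [100.0, 500.0],
        'workload_b': [200.0, 300.0],
    },
    index=['submission_1', 'submission_2'])

# Mirrors the filter in compute_performance_profiles: any time more than
# 4x the per-workload best is treated as "target not reached" (inf).
best_scores = df.min(axis=0)
df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf
print(df)
# submission_2's workload_a entry (500 > 4 * 100) becomes inf; the rest stay.
```

The performance profiles then report, for each submission, the proportion of workloads whose time falls within a given ratio of the best submission's time.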
@@ -375,8 +427,7 @@ def plot_performance_profiles(perf_df,
df_col,
scale='linear',
save_dir=None,
figsize=(30, 10),
font_size=18):
figsize=(30, 10)):
"""Plot performance profiles.
Args:
@@ -396,12 +447,12 @@
Returns:
None. If a valid save_dir is provided, save both the plot and perf_df.
"""
fig = perf_df.T.plot(figsize=figsize)
fig = perf_df.T.plot(figsize=figsize, alpha=0.7)
df_col_display = f'log10({df_col})' if scale == 'log' else df_col
fig.set_xlabel(
f'Ratio of `{df_col_display}` to best submission', size=font_size)
fig.set_ylabel('Proportion of workloads', size=font_size)
fig.legend(prop={'size': font_size}, bbox_to_anchor=(1.0, 1.0))
fig.set_xlabel(f'Ratio of `{df_col_display}` to best submission')
fig.set_ylabel('Proportion of workloads')
fig.legend(bbox_to_anchor=(1.0, 1.0))
plt.tight_layout()
maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}')
maybe_save_df_to_csv(save_dir,
perf_df,