Skip to content

Commit

Permalink
Load results in process.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed Apr 22, 2024
1 parent 3945f54 commit 1903951
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion benchmarks/results/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,18 @@ def process():
filter_params = ["tol", "seed", "max_epochs", "lr", "batch_size"]

data = {}

# Load existing data
try:
with open("results/data.json", "r") as f:
data = json.load(f)
except FileNotFoundError:
pass

benchmarks = sorted(df["benchmark"].unique())
for bm in benchmarks:
data[bm] = []
if bm not in data:
data[bm] = []

df_benchmark = df[df["benchmark"] == bm]
operators = sorted(df_benchmark["operator"].unique())
Expand All @@ -40,6 +49,22 @@ def process():
# Get the best operator run for this benchmark
best = operator_data.sort_values("loss/test").iloc[0].to_dict()

# Check if the operator already exists with a better loss
existing_operators = [v["operator"] for v in data[bm]]
if operator in existing_operators:
old_best_value = float(
[v["loss/test"] for v in data[bm] if v["operator"] == operator][0]
)
if best["loss/test"] >= old_best_value:
print(
f"Skipping {bm} {operator} because it has a better loss already: "
f"old={old_best_value:.4e} new={best['loss/test']:.4e}"
)
continue
else:
# Remove the old operator
data[bm] = [v for v in data[bm] if v["operator"] != operator]

# Filter out parameters
best["params"] = {
k: v for k, v in best["params"].items() if k not in filter_params
Expand Down

0 comments on commit 1903951

Please sign in to comment.