diff --git a/benchmarks/results/process.py b/benchmarks/results/process.py index 80e2e991..762c14c5 100644 --- a/benchmarks/results/process.py +++ b/benchmarks/results/process.py @@ -27,9 +27,18 @@ def process(): filter_params = ["tol", "seed", "max_epochs", "lr", "batch_size"] data = {} + + # Load existing data + try: + with open("results/data.json", "r") as f: + data = json.load(f) + except FileNotFoundError: + pass + benchmarks = sorted(df["benchmark"].unique()) for bm in benchmarks: - data[bm] = [] + if bm not in data: + data[bm] = [] df_benchmark = df[df["benchmark"] == bm] operators = sorted(df_benchmark["operator"].unique()) @@ -40,6 +49,22 @@ def process(): # Get the best operator run for this benchmark best = operator_data.sort_values("loss/test").iloc[0].to_dict() + # Check if the operator already exists with a better loss + existing_operators = [v["operator"] for v in data[bm]] + if operator in existing_operators: + old_best_value = float( + [v["loss/test"] for v in data[bm] if v["operator"] == operator][0] + ) + if best["loss/test"] >= old_best_value: + print( + f"Skipping {bm} {operator} because it has a better loss already: " + f"old={old_best_value:.4e} new={best['loss/test']:.4e}" + ) + continue + else: + # Remove the old operator + data[bm] = [v for v in data[bm] if v["operator"] != operator] + # Filter out parameters best["params"] = { k: v for k, v in best["params"].items() if k not in filter_params