Skip to content

Commit

Permalink
Add script to make release from fine tuned hydra (#875)
Browse files Browse the repository at this point in the history
* add script to convert hydra checkpoint to release

* add tests
  • Loading branch information
misko authored Oct 20, 2024
1 parent 8a9adbb commit 20d2798
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 59 deletions.
94 changes: 94 additions & 0 deletions src/fairchem/core/scripts/convert_hydra_to_release.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from __future__ import annotations

import argparse
import logging

import torch
import yaml


def convert_fine_tune_checkpoint(
    fine_tune_checkpoint_fn,
    output_checkpoint_fn,
    fine_tune_yaml_fn=None,
    output_yaml_fn=None,
):
    """Turn a fine-tuned checkpoint into a self-contained release checkpoint.

    The fine-tune checkpoint's model config points at the checkpoint it was
    started from (``config/model/finetune_config/starting_checkpoint``).
    This function loads that starting checkpoint, copies its ``backbone``
    config into the fine-tune config (and its ``heads`` config too when the
    fine-tune config defines no heads of its own), removes the
    ``finetune_config`` entry and saves the result. When a yaml config is
    supplied, the same rewrite is applied to it and written out.

    Args:
        fine_tune_checkpoint_fn: path of the fine-tuned checkpoint to read.
        output_checkpoint_fn: path the converted checkpoint is written to.
        fine_tune_yaml_fn: optional path of the yaml config used for fine
            tuning; when ``None`` no yaml is read or written.
        output_yaml_fn: path the converted yaml is written to; required
            whenever ``fine_tune_yaml_fn`` is given.

    Raises:
        ValueError: if ``fine_tune_yaml_fn`` is given without
            ``output_yaml_fn``. Raised before anything is read or written.
        KeyError: if the checkpoint has no ``config`` field, or the config
            lacks ``model/finetune_config/starting_checkpoint``.
    """
    # Fail fast: previously this case wrote the checkpoint and then crashed
    # with a TypeError inside open(None, "w"), leaving a half-done conversion.
    if fine_tune_yaml_fn is not None and output_yaml_fn is None:
        raise ValueError(
            "output_yaml_fn must be provided when fine_tune_yaml_fn is given"
        )

    # NOTE(review): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    fine_tune_checkpoint = torch.load(fine_tune_checkpoint_fn, map_location="cpu")

    if "config" not in fine_tune_checkpoint:
        raise KeyError("Finetune checkpoint does not have a valid 'config' field")

    try:
        starting_checkpoint_fn = fine_tune_checkpoint["config"]["model"][
            "finetune_config"
        ]["starting_checkpoint"]
    except KeyError:
        logging.error(
            f"Finetune config missing entry config/model/finetune_config/starting_checkpoint {fine_tune_checkpoint['config']}"
        )
        raise

    # The starting checkpoint path is taken from the fine-tune checkpoint
    # itself, so it must still be reachable at that location.
    starting_checkpoint = torch.load(starting_checkpoint_fn, map_location="cpu")
    start_checkpoint_model_config = starting_checkpoint["config"]["model"]

    ft_model_config = fine_tune_checkpoint["config"]["model"]
    ft_model_config["backbone"] = start_checkpoint_model_config["backbone"]
    # if we are data only, then copy over the heads config too
    ft_data_only = "heads" not in ft_model_config
    if ft_data_only:
        ft_model_config["heads"] = start_checkpoint_model_config["heads"]

    # The release config must no longer reference the starting checkpoint.
    ft_model_config.pop("finetune_config")

    torch.save(fine_tune_checkpoint, output_checkpoint_fn)

    if fine_tune_yaml_fn is not None:
        with open(fine_tune_yaml_fn) as yaml_f:
            fine_tune_yaml = yaml.safe_load(yaml_f)
        # The checkpoint config (validated above) is the source of truth; a
        # yaml that does not carry finetune_config is tolerated.
        fine_tune_yaml["model"].pop("finetune_config", None)
        fine_tune_yaml["model"]["backbone"] = start_checkpoint_model_config["backbone"]
        if ft_data_only:
            fine_tune_yaml["model"]["heads"] = start_checkpoint_model_config["heads"]
        with open(output_yaml_fn, "w") as yaml_file:
            yaml.dump(fine_tune_yaml, yaml_file)


if __name__ == "__main__":
    # Command-line entry point: collect the four paths and hand them to
    # convert_fine_tune_checkpoint.
    cli = argparse.ArgumentParser()

    # (flag, help text, required) -- listed in the order they are registered.
    option_table = (
        ("--fine-tune-checkpoint", "path to fine tuned checkpoint", True),
        ("--output-release-checkpoint", "path to output checkpoint", True),
        ("--fine-tune-yaml", "path to fine tune yaml config", False),
        ("--output-release-yaml", "path to output yaml config", False),
    )
    for flag, help_text, is_required in option_table:
        cli.add_argument(
            flag,
            help=help_text,
            type=str,
            required=is_required,
            default=None,
        )

    cli_args = cli.parse_args()

    convert_fine_tune_checkpoint(
        fine_tune_checkpoint_fn=cli_args.fine_tune_checkpoint,
        output_checkpoint_fn=cli_args.output_release_checkpoint,
        fine_tune_yaml_fn=cli_args.fine_tune_yaml,
        output_yaml_fn=cli_args.output_release_yaml,
    )
14 changes: 12 additions & 2 deletions tests/core/e2e/test_e2e_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _run_main(
save_predictions_to=None,
world_size=1,
):
config_yaml = Path(rundir) / "train_and_val_on_val.yml"
config_yaml = Path(rundir) / "test_run.yml"
update_yaml_with_dict(input_yaml, config_yaml, update_dict_with)
run_args = {
"run_dir": rundir,
Expand All @@ -119,7 +119,17 @@ def _run_main(
# run
parser = flags.get_parser()
args, override_args = parser.parse_known_args(
["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu", "--num-gpus", str(world_size)]
[
"--mode",
"train",
"--seed",
"100",
"--config-yml",
"config.yml",
"--cpu",
"--num-gpus",
str(world_size),
]
)
for arg_name, arg_value in run_args.items():
setattr(args, arg_name, arg_value)
Expand Down
198 changes: 142 additions & 56 deletions tests/core/e2e/test_e2e_finetune_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

import pytest
from fairchem.core.scripts.convert_hydra_to_release import convert_fine_tune_checkpoint
import torch
from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths

Expand Down Expand Up @@ -43,12 +44,14 @@ def make_checkpoint(tempdir: str, data_source: Path, seed: int) -> str:
return ck_path


def run_main_with_ft_hydra(tempdir: str,
yaml: str,
data_src: str,
run_args: dict,
model_config: str,
output_checkpoint: str):
def run_main_with_ft_hydra(
tempdir: str,
yaml: str,
data_src: str,
run_args: dict,
model_config: str,
output_checkpoint: str,
):
_run_main(
tempdir,
yaml,
Expand All @@ -58,7 +61,7 @@ def run_main_with_ft_hydra(tempdir: str,
"eval_every": 8,
"batch_size": 1,
"num_workers": 0,
"lr_initial": 0.0 # don't learn anything
"lr_initial": 0.0, # don't learn anything
},
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(data_src),
Expand All @@ -74,6 +77,39 @@ def run_main_with_ft_hydra(tempdir: str,
)


def verify_release_checkpoint(release_yaml_fn, release_checkpoint_fn, ft_state_dict):
    """Run from a release checkpoint and check the weights match ``ft_state_dict``.

    Starts a short run (``max_epochs`` 2, seed 1337) from
    ``release_checkpoint_fn`` using ``release_yaml_fn`` inside a scratch
    directory, then asserts that every ``module.backbone*`` entry and every
    ``module.output_heads*`` entry ending in ``weight`` of the checkpoint
    saved by that run is close to the same-named entry in ``ft_state_dict``.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Separate run directory plus a path for the checkpoint the run saves.
        run_dir = os.path.join(scratch_dir, "release_ft")
        os.makedirs(run_dir)
        saved_ckpt_path = os.path.join(scratch_dir, "checkpoint_ft_release.pt")

        # Run the release checkpoint for a few iterations.
        # NOTE(review): the "weights unchanged" check below relies on the
        # release yaml carrying a zero learning rate -- confirm in the caller.
        _run_main(
            run_dir,
            release_yaml_fn,
            update_run_args_with={"seed": 1337, "checkpoint": release_checkpoint_fn},
            save_checkpoint_to=saved_ckpt_path,
            world_size=1,
            update_dict_with={"optim": {"max_epochs": 2}},
        )

        # The run must have produced a checkpoint whose backbone and head
        # weights match the reference state dict.
        assert os.path.isfile(saved_ckpt_path)
        reloaded_state = torch.load(saved_ckpt_path)["state_dict"]
        for name, tensor in reloaded_state.items():
            is_backbone = name.startswith("module.backbone")
            is_head_weight = name.startswith("module.output_heads") and name.endswith(
                "weight"
            )
            if is_backbone or is_head_weight:
                assert torch.allclose(tensor, ft_state_dict[name])


def test_finetune_hydra_retain_backbone(tutorial_val_src):
with tempfile.TemporaryDirectory() as orig_ckpt_dir:
starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
Expand All @@ -83,36 +119,57 @@ def test_finetune_hydra_retain_backbone(tutorial_val_src):
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")
model_config = {
"name" : "hydra",
"finetune_config": {'starting_checkpoint': starting_ckpt},
"name": "hydra",
"finetune_config": {"starting_checkpoint": starting_ckpt},
"heads": {
"energy": {
"module": "equiformer_v2_energy_head"
},
"forces": {
"module": "equiformer_v2_force_head"
}
}
"energy": {"module": "equiformer_v2_energy_head"},
"forces": {"module": "equiformer_v2_force_head"},
},
}
run_main_with_ft_hydra(tempdir = ft_temp_dir,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
model_config = model_config,
output_checkpoint = ck_ft_path)
run_main_with_ft_hydra(
tempdir=ft_temp_dir,
yaml=ft_yml,
data_src=tutorial_val_src,
run_args={"seed": 1000},
model_config=model_config,
output_checkpoint=ck_ft_path,
)
assert os.path.isfile(ck_ft_path)
ft_ckpt = torch.load(ck_ft_path)
assert "config" in ft_ckpt
assert ft_ckpt["config"]["model"]["name"] == "hydra"
# check that the backbone weights are the same, and other weights are not the same
new_state_dict = ft_ckpt["state_dict"]
for key in new_state_dict:
if key.startswith("backbone"):
ft_state_dict = ft_ckpt["state_dict"]
for key in ft_state_dict:
if key.startswith("module.backbone"):
# backbone should be identical
assert torch.allclose(new_state_dict[key], old_state_dict[key])
elif key.startswith("output_heads") and key.endswith("weight"):
assert torch.allclose(ft_state_dict[key], old_state_dict[key])
elif key.startswith("module.output_heads") and key.endswith("weight"):
# heads weight should be different because the seeds are different
assert not torch.allclose(new_state_dict[key], old_state_dict[key])
assert not torch.allclose(ft_state_dict[key], old_state_dict[key])

# Add a test to convert the FT hydra checkpoint to a release checkpoint
# This could be a separate test but we would need to generate the FT checkpoint
# all over again

# Convert FT hydra checkpoint to release checkpoint
ck_release_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft_release.pt")
yml_release_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft_release.yml")
# the actual on disk yaml used in the previous run, after argument updates
fine_tune_yaml_fn = os.path.join(ft_temp_dir, "test_run.yml")
convert_fine_tune_checkpoint(
fine_tune_checkpoint_fn=ck_ft_path,
fine_tune_yaml_fn=fine_tune_yaml_fn,
output_checkpoint_fn=ck_release_ft_path,
output_yaml_fn=yml_release_ft_path,
)

# remove the fine-tuned checkpoint, so that we can't accidentally load it
os.remove(ck_ft_path)

verify_release_checkpoint(
yml_release_ft_path, ck_release_ft_path, ft_state_dict
)


def test_finetune_hydra_data_only(tutorial_val_src):
Expand All @@ -124,25 +181,50 @@ def test_finetune_hydra_data_only(tutorial_val_src):
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")
model_config = {
"name" : "hydra",
"finetune_config": {'starting_checkpoint': starting_ckpt},
"name": "hydra",
"finetune_config": {"starting_checkpoint": starting_ckpt},
}
run_main_with_ft_hydra(tempdir = ft_temp_dir,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
model_config = model_config,
output_checkpoint = ck_ft_path)
run_main_with_ft_hydra(
tempdir=ft_temp_dir,
yaml=ft_yml,
data_src=tutorial_val_src,
run_args={"seed": 1000},
model_config=model_config,
output_checkpoint=ck_ft_path,
)
assert os.path.isfile(ck_ft_path)
ft_ckpt = torch.load(ck_ft_path)
assert "config" in ft_ckpt
config_model = ft_ckpt["config"]["model"]
assert config_model["name"] == "hydra"
# check that the entire model weights are the same
new_state_dict = ft_ckpt["state_dict"]
assert len(new_state_dict) == len(old_state_dict)
for key in new_state_dict:
assert torch.allclose(new_state_dict[key], old_state_dict[key])
ft_state_dict = ft_ckpt["state_dict"]
assert len(ft_state_dict) == len(old_state_dict)
for key in ft_state_dict:
assert torch.allclose(ft_state_dict[key], old_state_dict[key])

# Add a test to convert the FT hydra checkpoint to a release checkpoint
# This could be a separate test but we would need to generate the FT checkpoint
# all over again

# Convert FT hydra checkpoint to release checkpoint
ck_release_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft_release.pt")
yml_release_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft_release.yml")
# the actual on disk yaml used in the previous run, after argument updates
fine_tune_yaml_fn = os.path.join(ft_temp_dir, "test_run.yml")
convert_fine_tune_checkpoint(
fine_tune_checkpoint_fn=ck_ft_path,
fine_tune_yaml_fn=fine_tune_yaml_fn,
output_checkpoint_fn=ck_release_ft_path,
output_yaml_fn=yml_release_ft_path,
)

# remove the fine-tuned checkpoint, so that we can't accidentally load it
os.remove(ck_ft_path)

verify_release_checkpoint(
yml_release_ft_path, ck_release_ft_path, ft_state_dict
)


def test_finetune_from_finetunehydra(tutorial_val_src):
Expand All @@ -153,31 +235,35 @@ def test_finetune_from_finetunehydra(tutorial_val_src):
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(finetune_run1_dir, "checkpoint_ft.pt")
model_config_1 = {
"name" : "hydra",
"finetune_config": {'starting_checkpoint': starting_ckpt},
"name": "hydra",
"finetune_config": {"starting_checkpoint": starting_ckpt},
}
run_main_with_ft_hydra(tempdir = finetune_run1_dir,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
model_config = model_config_1,
output_checkpoint = ck_ft_path)
run_main_with_ft_hydra(
tempdir=finetune_run1_dir,
yaml=ft_yml,
data_src=tutorial_val_src,
run_args={"seed": 1000},
model_config=model_config_1,
output_checkpoint=ck_ft_path,
)
assert os.path.isfile(ck_ft_path)

# now that we have a second checkpoint, try finetuning again from this checkpoint
########################################################################################
with tempfile.TemporaryDirectory() as finetune_run2_dir:
ck_ft2_path = os.path.join(finetune_run2_dir, "checkpoint_ft.pt")
model_config_2 = {
"name" : "hydra",
"finetune_config": {'starting_checkpoint': ck_ft_path},
"name": "hydra",
"finetune_config": {"starting_checkpoint": ck_ft_path},
}
run_main_with_ft_hydra(tempdir = finetune_run2_dir,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
model_config = model_config_2,
output_checkpoint = ck_ft2_path)
run_main_with_ft_hydra(
tempdir=finetune_run2_dir,
yaml=ft_yml,
data_src=tutorial_val_src,
run_args={"seed": 1000},
model_config=model_config_2,
output_checkpoint=ck_ft2_path,
)
ft_ckpt2 = torch.load(ck_ft2_path)
assert "config" in ft_ckpt2
config_model = ft_ckpt2["config"]["model"]
Expand Down
Loading

0 comments on commit 20d2798

Please sign in to comment.