diff --git a/src/fairchem/core/scripts/convert_hydra_to_release.py b/src/fairchem/core/scripts/convert_hydra_to_release.py new file mode 100644 index 000000000..c54874bb4 --- /dev/null +++ b/src/fairchem/core/scripts/convert_hydra_to_release.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import argparse +import logging + +import torch +import yaml + + +def convert_fine_tune_checkpoint( + fine_tune_checkpoint_fn, + output_checkpoint_fn, + fine_tune_yaml_fn=None, + output_yaml_fn=None, +): + fine_tune_checkpoint = torch.load(fine_tune_checkpoint_fn, map_location="cpu") + + if "config" not in fine_tune_checkpoint: + raise KeyError("Finetune checkpoint does not have a valid 'config' field") + + try: + starting_checkpoint_fn = fine_tune_checkpoint["config"]["model"][ + "finetune_config" + ]["starting_checkpoint"] + except KeyError as e: + logging.error( + f"Finetune config missing entry config/model/finetune_config/starting_checkpoint {fine_tune_checkpoint['config']}" + ) + raise e + + starting_checkpoint = torch.load(starting_checkpoint_fn, map_location="cpu") + start_checkpoint_model_config = starting_checkpoint["config"]["model"] + + fine_tune_checkpoint["config"]["model"]["backbone"] = start_checkpoint_model_config[ + "backbone" + ] + # if we are data only, then copy over the heads config too + ft_data_only = "heads" not in fine_tune_checkpoint["config"]["model"] + if ft_data_only: + fine_tune_checkpoint["config"]["model"]["heads"] = ( + start_checkpoint_model_config["heads"] + ) + + fine_tune_checkpoint["config"]["model"].pop("finetune_config") + + torch.save(fine_tune_checkpoint, output_checkpoint_fn) + + if fine_tune_yaml_fn is not None: + with open(fine_tune_yaml_fn) as yaml_f: + fine_tune_yaml = yaml.safe_load(yaml_f) + fine_tune_yaml["model"].pop("finetune_config") + fine_tune_yaml["model"]["backbone"] = start_checkpoint_model_config["backbone"] + if ft_data_only: + fine_tune_yaml["model"]["heads"] = start_checkpoint_model_config["heads"] + with open(output_yaml_fn, "w") as yaml_file: + yaml.dump(fine_tune_yaml, yaml_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--fine-tune-checkpoint", + help="path to fine tuned checkpoint", + type=str, + required=True, + ) + parser.add_argument( + "--output-release-checkpoint", + help="path to output checkpoint", + type=str, + required=True, + ) + parser.add_argument( + "--fine-tune-yaml", + help="path to fine tune yaml config", + type=str, + required=False, + default=None, + ) + parser.add_argument( + "--output-release-yaml", + help="path to output yaml config", + type=str, + required=False, + default=None, + ) + args = parser.parse_args() + + convert_fine_tune_checkpoint( + fine_tune_yaml_fn=args.fine_tune_yaml, + fine_tune_checkpoint_fn=args.fine_tune_checkpoint, + output_checkpoint_fn=args.output_release_checkpoint, + output_yaml_fn=args.output_release_yaml, + ) diff --git a/tests/core/e2e/test_e2e_commons.py b/tests/core/e2e/test_e2e_commons.py index 256d3b087..9ae6abd1d 100644 --- a/tests/core/e2e/test_e2e_commons.py +++ b/tests/core/e2e/test_e2e_commons.py @@ -106,7 +106,7 @@ def _run_main( save_predictions_to=None, world_size=1, ): - config_yaml = Path(rundir) / "train_and_val_on_val.yml" + config_yaml = Path(rundir) / "test_run.yml" update_yaml_with_dict(input_yaml, config_yaml, update_dict_with) run_args = { "run_dir": rundir, @@ -119,7 +119,17 @@ def _run_main( # run parser = flags.get_parser() args, override_args = parser.parse_known_args( - ["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu", "--num-gpus", str(world_size)] + [ + "--mode", + "train", + "--seed", + "100", + "--config-yml", + "config.yml", + "--cpu", + "--num-gpus", + str(world_size), + ] ) for arg_name, arg_value in run_args.items(): setattr(args, arg_name, arg_value) diff --git a/tests/core/e2e/test_e2e_finetune_hydra.py b/tests/core/e2e/test_e2e_finetune_hydra.py index cec60e394..9a36e09ef 100644 --- a/tests/core/e2e/test_e2e_finetune_hydra.py +++ b/tests/core/e2e/test_e2e_finetune_hydra.py @@ -5,6 +5,7 @@ from pathlib import Path import pytest +from fairchem.core.scripts.convert_hydra_to_release import convert_fine_tune_checkpoint import torch from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths @@ -43,12 +44,14 @@ def make_checkpoint(tempdir: str, data_source: Path, seed: int) -> str: return ck_path -def run_main_with_ft_hydra(tempdir: str, - yaml: str, - data_src: str, - run_args: dict, - model_config: str, - output_checkpoint: str): +def run_main_with_ft_hydra( + tempdir: str, + yaml: str, + data_src: str, + run_args: dict, + model_config: str, + output_checkpoint: str, +): _run_main( tempdir, yaml, @@ -58,7 +61,7 @@ def run_main_with_ft_hydra(tempdir: str, "eval_every": 8, "batch_size": 1, "num_workers": 0, - "lr_initial": 0.0 # don't learn anything + "lr_initial": 0.0, # don't learn anything }, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(data_src), @@ -74,6 +77,39 @@ def run_main_with_ft_hydra(tempdir: str, ) +def verify_release_checkpoint(release_yaml_fn, release_checkpoint_fn, ft_state_dict): + with tempfile.TemporaryDirectory() as temp_dir: + # now lets run the new release checkpoint for a few iterations at lr0.0 + ck_release_ft_afterload_path = os.path.join( + temp_dir, "checkpoint_ft_release.pt" + ) + release_ft_temp_dir = os.path.join(temp_dir, "release_ft") + os.makedirs(release_ft_temp_dir) + + _run_main( + release_ft_temp_dir, + release_yaml_fn, + update_run_args_with={"seed": 1337, "checkpoint": release_checkpoint_fn}, + save_checkpoint_to=ck_release_ft_afterload_path, + world_size=1, + update_dict_with={ + "optim": { + "max_epochs": 2, + } + }, + ) + + # make sure the checkpoint after running with lr0.0 is identical + # to the previous checkpoint + assert os.path.isfile(ck_release_ft_afterload_path) + ft_after_state_dict = torch.load(ck_release_ft_afterload_path)["state_dict"] + for key in ft_after_state_dict: + if key.startswith("module.backbone"): + assert torch.allclose(ft_after_state_dict[key], ft_state_dict[key]) + elif key.startswith("module.output_heads") and key.endswith("weight"): + assert torch.allclose(ft_after_state_dict[key], ft_state_dict[key]) + + def test_finetune_hydra_retain_backbone(tutorial_val_src): with tempfile.TemporaryDirectory() as orig_ckpt_dir: starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0) @@ -83,36 +119,57 @@ def test_finetune_hydra_retain_backbone(tutorial_val_src): ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") model_config = { - "name" : "hydra", - "finetune_config": {'starting_checkpoint': starting_ckpt}, + "name": "hydra", + "finetune_config": {"starting_checkpoint": starting_ckpt}, "heads": { - "energy": { - "module": "equiformer_v2_energy_head" - }, - "forces": { - "module": "equiformer_v2_force_head" - } - } + "energy": {"module": "equiformer_v2_energy_head"}, + "forces": {"module": "equiformer_v2_force_head"}, + }, } - run_main_with_ft_hydra(tempdir = ft_temp_dir, - yaml = ft_yml, - data_src = tutorial_val_src, - run_args = {"seed": 1000}, - model_config = model_config, - output_checkpoint = ck_ft_path) + run_main_with_ft_hydra( + tempdir=ft_temp_dir, + yaml=ft_yml, + data_src=tutorial_val_src, + run_args={"seed": 1000}, + model_config=model_config, + output_checkpoint=ck_ft_path, + ) assert os.path.isfile(ck_ft_path) ft_ckpt = torch.load(ck_ft_path) assert "config" in ft_ckpt assert ft_ckpt["config"]["model"]["name"] == "hydra" # check that the backbone weights are the same, and other weights are not the same - new_state_dict = ft_ckpt["state_dict"] - for key in new_state_dict: - if key.startswith("backbone"): + ft_state_dict = ft_ckpt["state_dict"] + for key in ft_state_dict: + if key.startswith("module.backbone"): # backbone should be identical - assert torch.allclose(new_state_dict[key], old_state_dict[key]) - elif key.startswith("output_heads") and key.endswith("weight"): + assert torch.allclose(ft_state_dict[key], old_state_dict[key]) + elif key.startswith("module.output_heads") and key.endswith("weight"): # heads weight should be different because the seeds are different - assert not torch.allclose(new_state_dict[key], old_state_dict[key]) + assert not torch.allclose(ft_state_dict[key], old_state_dict[key]) + + # Add a test to convert the FT hydra checkpoint to a release checkpoint + # This could be a separate test but we would need to generate the FT checkpoint + # all over again + + # Convert FT hydra checkpoint to release checkpoint + ck_release_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft_release.pt") + yml_release_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft_release.yml") + # the actual on disk yaml used in the previous run, after argument updates + fine_tune_yaml_fn = os.path.join(ft_temp_dir, "test_run.yml") + convert_fine_tune_checkpoint( + fine_tune_checkpoint_fn=ck_ft_path, + fine_tune_yaml_fn=fine_tune_yaml_fn, + output_checkpoint_fn=ck_release_ft_path, + output_yaml_fn=yml_release_ft_path, + ) + + # remove starting checkpoint, so that we cant accidentally load it + os.remove(ck_ft_path) + + verify_release_checkpoint( + yml_release_ft_path, ck_release_ft_path, ft_state_dict + ) def test_finetune_hydra_data_only(tutorial_val_src): @@ -124,25 +181,50 @@ def test_finetune_hydra_data_only(tutorial_val_src): ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") model_config = { - "name" : "hydra", - "finetune_config": {'starting_checkpoint': starting_ckpt}, + "name": "hydra", + "finetune_config": {"starting_checkpoint": starting_ckpt}, } - run_main_with_ft_hydra(tempdir = ft_temp_dir, - yaml = ft_yml, - data_src = tutorial_val_src, - run_args = {"seed": 1000}, - model_config = model_config, - output_checkpoint = ck_ft_path) + run_main_with_ft_hydra( + tempdir=ft_temp_dir, + yaml=ft_yml, + data_src=tutorial_val_src, + run_args={"seed": 1000}, + model_config=model_config, + output_checkpoint=ck_ft_path, + ) assert os.path.isfile(ck_ft_path) ft_ckpt = torch.load(ck_ft_path) assert "config" in ft_ckpt config_model = ft_ckpt["config"]["model"] assert config_model["name"] == "hydra" # check that the entire model weights are the same - new_state_dict = ft_ckpt["state_dict"] - assert len(new_state_dict) == len(old_state_dict) - for key in new_state_dict: - assert torch.allclose(new_state_dict[key], old_state_dict[key]) + ft_state_dict = ft_ckpt["state_dict"] + assert len(ft_state_dict) == len(old_state_dict) + for key in ft_state_dict: + assert torch.allclose(ft_state_dict[key], old_state_dict[key]) + + # Add a test to convert the FT hydra checkpoint to a release checkpoint + # This could be a separate test but we would need to generate the FT checkpoint + # all over again + + # Convert FT hydra checkpoint to release checkpoint + ck_release_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft_release.pt") + yml_release_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft_release.yml") + # the actual on disk yaml used in the previous run, after argument updates + fine_tune_yaml_fn = os.path.join(ft_temp_dir, "test_run.yml") + convert_fine_tune_checkpoint( + fine_tune_checkpoint_fn=ck_ft_path, + fine_tune_yaml_fn=fine_tune_yaml_fn, + output_checkpoint_fn=ck_release_ft_path, + output_yaml_fn=yml_release_ft_path, + ) + + # remove starting checkpoint, so that we cant accidentally load it + os.remove(ck_ft_path) + + verify_release_checkpoint( + yml_release_ft_path, ck_release_ft_path, ft_state_dict + ) def test_finetune_from_finetunehydra(tutorial_val_src): @@ -153,15 +235,17 @@ def test_finetune_from_finetunehydra(tutorial_val_src): ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") ck_ft_path = os.path.join(finetune_run1_dir, "checkpoint_ft.pt") model_config_1 = { - "name" : "hydra", - "finetune_config": {'starting_checkpoint': starting_ckpt}, + "name": "hydra", + "finetune_config": {"starting_checkpoint": starting_ckpt}, } - run_main_with_ft_hydra(tempdir = finetune_run1_dir, - yaml = ft_yml, - data_src = tutorial_val_src, - run_args = {"seed": 1000}, - model_config = model_config_1, - output_checkpoint = ck_ft_path) + run_main_with_ft_hydra( + tempdir=finetune_run1_dir, + yaml=ft_yml, + data_src=tutorial_val_src, + run_args={"seed": 1000}, + model_config=model_config_1, + output_checkpoint=ck_ft_path, + ) assert os.path.isfile(ck_ft_path) # now that we have a second checkpoint, try finetuning again from this checkpoint @@ -169,15 +253,17 @@ def test_finetune_from_finetunehydra(tutorial_val_src): with tempfile.TemporaryDirectory() as finetune_run2_dir: ck_ft2_path = os.path.join(finetune_run2_dir, "checkpoint_ft.pt") model_config_2 = { - "name" : "hydra", - "finetune_config": {'starting_checkpoint': ck_ft_path}, + "name": "hydra", + "finetune_config": {"starting_checkpoint": ck_ft_path}, } - run_main_with_ft_hydra(tempdir = finetune_run2_dir, - yaml = ft_yml, - data_src = tutorial_val_src, - run_args = {"seed": 1000}, - model_config = model_config_2, - output_checkpoint = ck_ft2_path) + run_main_with_ft_hydra( + tempdir=finetune_run2_dir, + yaml=ft_yml, + data_src=tutorial_val_src, + run_args={"seed": 1000}, + model_config=model_config_2, + output_checkpoint=ck_ft2_path, + ) ft_ckpt2 = torch.load(ck_ft2_path) assert "config" in ft_ckpt2 config_model = ft_ckpt2["config"]["model"] diff --git a/tests/core/models/test_configs/test_finetune_hydra.yml b/tests/core/models/test_configs/test_finetune_hydra.yml index e1be6d20b..d3456a69b 100644 --- a/tests/core/models/test_configs/test_finetune_hydra.yml +++ b/tests/core/models/test_configs/test_finetune_hydra.yml @@ -36,7 +36,7 @@ logger: name: tensorboard model: - name: finetune_hydra + name: hydra finetune_config: {}