diff --git a/src/fairchem/core/modules/scaling/fit.py b/src/fairchem/core/modules/scaling/fit.py index 63c36e8f50..4bfc4bb62a 100644 --- a/src/fairchem/core/modules/scaling/fit.py +++ b/src/fairchem/core/modules/scaling/fit.py @@ -37,18 +37,9 @@ def _train_batch(trainer: BaseTrainer, batch) -> None: del out, loss -def main(*, num_batches: int = 16) -> None: - # region args/config setup - setup_logging() - - parser = flags.get_parser() - args, override_args = parser.parse_known_args() - _config = build_config(args, override_args) - _config["logger"] = "wandb" - # endregion +def compute_scaling_factors(config, num_batches: int = 16) -> None: - assert not args.distributed, "This doesn't work with DDP" - with new_trainer_context(args=args, config=_config) as ctx: + with new_trainer_context(config=config) as ctx: config = ctx.config trainer = ctx.trainer @@ -61,8 +52,8 @@ def main(*, num_batches: int = 16) -> None: logging.info(f"Input checkpoint path: {ckpt_file}, {ckpt_file.exists()=}") model: nn.Module = trainer.model - val_loader = trainer.val_loader - assert val_loader is not None, "Val dataset is required for making predictions" + data_loader = trainer.train_loader + assert data_loader is not None, "Train set required to load batches" if ckpt_file.exists(): trainer.load_checkpoint(checkpoint_path=str(ckpt_file)) @@ -122,15 +113,8 @@ def main(*, num_batches: int = 16) -> None: sys.exit(-1) # endregion - # region get the output path - out_path = Path( - _prefilled_input( - "Enter output path for fitted scale factors: ", - prefill=str(ckpt_file), - ) - ) - if out_path.exists(): - logging.warning(f"Already found existing file: {out_path}") + if ckpt_file.exists(): + logging.warning(f"Already found existing file: {ckpt_file}") flag = input( "Do you want to continue and overwrite existing file (1), " "or exit (2)? " @@ -142,7 +126,7 @@ def main(*, num_batches: int = 16) -> None: sys.exit() logging.info( - f"Output path for fitted scale factors: {out_path}, {out_path.exists()=}" + f"Output path for fitted scale factors: {ckpt_file}, {ckpt_file.exists()=}" ) # endregion @@ -175,7 +159,7 @@ def index_fn(name: str = name) -> None: module.initialize_(index_fn=index_fn) # single pass through network - _train_batch(trainer, next(iter(val_loader))) + _train_batch(trainer, next(iter(data_loader))) # sort the scale factors by their computation order sorted_factors = sorted( @@ -200,7 +184,7 @@ def index_fn(name: str = name) -> None: logging.info(f"Fitting {name}...") with module.fit_context_(): - for batch in islice(val_loader, num_batches): + for batch in islice(data_loader, num_batches): _train_batch(trainer, batch) stats, ratio, value = module.fit_() @@ -216,19 +200,27 @@ def index_fn(name: str = name) -> None: assert module.fitted, f"{name} is not fitted" # region save the scale factors to the checkpoint file - trainer.config["cmd"]["checkpoint_dir"] = out_path.parent + trainer.config["cmd"]["checkpoint_dir"] = ckpt_file.parent trainer.is_debug = False - out_file = trainer.save( - metrics=None, - checkpoint_file=out_path.name, - training_state=False, + + torch.save( + { + x[0].replace(".scale_factor", ""): x[1] + for x in trainer.model.to("cpu").named_parameters() + if ".scale_" in x[0] + }, + str(ckpt_file), ) - assert out_file is not None, "Failed to save checkpoint" - out_file = Path(out_file) - assert out_file.exists(), f"Failed to save checkpoint to {out_file}" - # endregion - logging.info(f"Saved results to: {out_file}") + logging.info(f"Saved results to: {ckpt_file}") if __name__ == "__main__": - main() + # region args/config setup + setup_logging() + + parser = flags.get_parser() + args, override_args = parser.parse_known_args() + assert not args.distributed, "This doesn't work with DDP" + config = build_config(args, override_args) + + compute_scaling_factors(config) diff --git a/tests/core/e2e/test_e2e_commons.py b/tests/core/e2e/test_e2e_commons.py index ff3ea36342..ef2b860bff 100644 --- a/tests/core/e2e/test_e2e_commons.py +++ b/tests/core/e2e/test_e2e_commons.py @@ -93,6 +93,16 @@ def merge_dictionary(d, u): return d +def update_yaml_with_dict(input_yaml, output_yaml, update_dict_with): + with open(input_yaml) as yaml_file: + yaml_config = yaml.safe_load(yaml_file) + if update_dict_with is not None: + yaml_config = merge_dictionary(yaml_config, update_dict_with) + yaml_config["backend"] = "gloo" + with open(str(output_yaml), "w") as yaml_file: + yaml.dump(yaml_config, yaml_file) + + def _run_main( rundir, input_yaml, @@ -103,14 +113,7 @@ def _run_main( world_size=0, ): config_yaml = Path(rundir) / "train_and_val_on_val.yml" - - with open(input_yaml) as yaml_file: - yaml_config = yaml.safe_load(yaml_file) - if update_dict_with is not None: - yaml_config = merge_dictionary(yaml_config, update_dict_with) - yaml_config["backend"] = "gloo" - with open(str(config_yaml), "w") as yaml_file: - yaml.dump(yaml_config, yaml_file) + update_yaml_with_dict(input_yaml, config_yaml, update_dict_with) run_args = { "run_dir": rundir, "logdir": f"{rundir}/logs", diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 2f7dfa3730..10e3203c91 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -8,11 +8,19 @@ import numpy as np import numpy.testing as npt import pytest -from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths +from fairchem.core._cli import Runner +from fairchem.core.modules.scaling.fit import compute_scaling_factors +from test_e2e_commons import ( + _run_main, + oc20_lmdb_train_and_val_from_paths, + update_yaml_with_dict, +) -from fairchem.core.common.utils import setup_logging +from fairchem.core.common.utils import build_config, setup_logging from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes +from fairchem.core.common.flags import flags + setup_logging() @@ -98,6 +106,61 @@ def smoke_test_train( energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6 ) + def test_gemnet_fit_scaling(self, configs, tutorial_val_src): + + with tempfile.TemporaryDirectory() as tempdirname: + # (1) generate scaling factors for gemnet config + config_yaml = f"{tempdirname}/train_and_val_on_val.yml" + scaling_pt = f"{tempdirname}/scaling.pt" + # run + parser = flags.get_parser() + args, override_args = parser.parse_known_args( + [ + "--mode", + "train", + "--seed", + "100", + "--config-yml", + config_yaml, + "--cpu", + "--checkpoint", + scaling_pt, + ] + ) + update_yaml_with_dict( + configs["gemnet_oc"], + config_yaml, + update_dict_with={ + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + ) + config = build_config(args, override_args) + + # (2) if existing scaling factors are present remove them + if "scale_file" in config["model"]: + config["model"].pop("scale_file") + + compute_scaling_factors(config) + + # (3) try to run the config with the newly generated scaling factors + _ = _run_main( + rundir=tempdirname, + update_dict_with={ + "optim": {"max_epochs": 1}, + "model": {"use_pbc_single": True, "scale_file": scaling_pt}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + input_yaml=configs["gemnet_oc"], + ) + # not all models are tested with otf normalization estimation # only gemnet_oc, escn, equiformer, and their hydra versions @pytest.mark.parametrize(