diff --git a/jiant/utils/torch_utils.py b/jiant/utils/torch_utils.py
index 81404f4c1..56a007576 100644
--- a/jiant/utils/torch_utils.py
+++ b/jiant/utils/torch_utils.py
@@ -144,3 +144,24 @@ def get_model_for_saving(model: nn.Module) -> nn.Module:
         return model.module
     else:
         return model
+
+
+def eq_state_dicts(state_dict_1, state_dict_2):
+    """Checks if the model weights in state_dict_1 and state_dict_2 are equal.
+
+    Args:
+        state_dict_1 (dict): state_dict of a PyTorch model
+        state_dict_2 (dict): state_dict of a PyTorch model
+
+    Requires:
+        state_dict_1 and state_dict_2 to be from the same model
+
+    Returns:
+        bool: Returns True if all model weights are equal in state_dict_1 and state_dict_2
+    """
+    for key_item_1, key_item_2 in zip(state_dict_1.items(), state_dict_2.items()):
+        if torch.equal(key_item_1[1], key_item_2[1]):
+            pass
+        else:
+            return False
+    return True
diff --git a/tests/proj/simple/test_runscript.py b/tests/proj/simple/test_runscript.py
index 6dc65ca04..981b78e08 100644
--- a/tests/proj/simple/test_runscript.py
+++ b/tests/proj/simple/test_runscript.py
@@ -1,9 +1,11 @@
 import os
 import pytest
+import torch
 
 import jiant.utils.python.io as py_io
 from jiant.proj.simple import runscript as run
 import jiant.scripts.download_data.runscript as downloader
+import jiant.utils.torch_utils as torch_utils
 
 
 @pytest.mark.parametrize("task_name", ["copa"])
@@ -29,3 +31,83 @@ def test_simple_runscript(tmpdir, task_name, model_type):
 
     val_metrics = py_io.read_json(os.path.join(exp_dir, "runs", RUN_NAME, "val_metrics.json"))
     assert val_metrics["aggregated"] > 0
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("task_name", ["copa"])
+@pytest.mark.parametrize("model_type", ["roberta-large"])
+def test_simple_runscript_save(tmpdir, task_name, model_type):
+    run_name = f"{test_simple_runscript.__name__}_{task_name}_{model_type}_save"
+    data_dir = str(tmpdir.mkdir("data"))
+    exp_dir = str(tmpdir.mkdir("exp"))
+
+    downloader.download_data([task_name], data_dir)
+
+    args = run.RunConfiguration(
+        run_name=run_name,
+        exp_dir=exp_dir,
+        data_dir=data_dir,
+        model_type=model_type,
+        tasks=task_name,
+        max_steps=1,
+        train_batch_size=32,
+        do_save=True,
+        eval_every_steps=10,
+        learning_rate=0.01,
+        num_train_epochs=5,
+    )
+    run.run_simple(args)
+
+    # check best_model and last_model exist
+    assert os.path.exists(os.path.join(exp_dir, "runs", run_name, "best_model.p"))
+    assert os.path.exists(os.path.join(exp_dir, "runs", run_name, "best_model.metadata.json"))
+    assert os.path.exists(os.path.join(exp_dir, "runs", run_name, "last_model.p"))
+    assert os.path.exists(os.path.join(exp_dir, "runs", run_name, "last_model.metadata.json"))
+
+    # assert best_model not equal to last_model
+    best_model_weights = torch.load(
+        os.path.join(exp_dir, "runs", run_name, "best_model.p"), map_location=torch.device("cpu")
+    )
+    last_model_weights = torch.load(
+        os.path.join(exp_dir, "runs", run_name, "last_model.p"), map_location=torch.device("cpu")
+    )
+    assert not torch_utils.eq_state_dicts(best_model_weights, last_model_weights)
+
+    run_name = f"{test_simple_runscript.__name__}_{task_name}_{model_type}_save_best"
+    args = run.RunConfiguration(
+        run_name=run_name,
+        exp_dir=exp_dir,
+        data_dir=data_dir,
+        model_type=model_type,
+        tasks=task_name,
+        max_steps=1,
+        train_batch_size=16,
+        do_save_best=True,
+    )
+    run.run_simple(args)
+
+    # check only best_model saved
+    assert os.path.exists(os.path.join(exp_dir, "runs", run_name, "best_model.p"))
+    assert os.path.exists(os.path.join(exp_dir, "runs", run_name, "best_model.metadata.json"))
+    assert not os.path.exists(os.path.join(exp_dir, "runs", run_name, "last_model.p"))
+    assert not os.path.exists(os.path.join(exp_dir, "runs", run_name, "last_model.metadata.json"))
+
+    # check output last model
+    run_name = f"{test_simple_runscript.__name__}_{task_name}_{model_type}_save_last"
+    args = run.RunConfiguration(
+        run_name=run_name,
+        exp_dir=exp_dir,
+        data_dir=data_dir,
+        model_type=model_type,
+        tasks=task_name,
+        max_steps=1,
+        train_batch_size=16,
+        do_save_last=True,
+    )
+    run.run_simple(args)
+
+    # check only last_model saved
+    assert not os.path.exists(os.path.join(exp_dir, "runs", run_name, "best_model.p"))
+    assert not os.path.exists(os.path.join(exp_dir, "runs", run_name, "best_model.metadata.json"))
+    assert os.path.exists(os.path.join(exp_dir, "runs", run_name, "last_model.p"))
+    assert os.path.exists(os.path.join(exp_dir, "runs", run_name, "last_model.metadata.json"))