Skip to content

Commit

Permalink
Add simple save model test (#1227)
Browse files Browse the repository at this point in the history
* add save model test
  • Loading branch information
jeswan authored Jan 1, 2021
1 parent 1ab34a4 commit ce62495
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
21 changes: 21 additions & 0 deletions jiant/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,24 @@ def get_model_for_saving(model: nn.Module) -> nn.Module:
return model.module
else:
return model


def eq_state_dicts(state_dict_1, state_dict_2):
    """Checks if the model weights in state_dict_1 and state_dict_2 are equal.

    Entries are matched by key rather than by position, so insertion order
    does not matter and a dict with missing or extra entries is never
    reported as equal.

    Args:
        state_dict_1 (dict): state_dict of a PyTorch model
        state_dict_2 (dict): state_dict of a PyTorch model

    Returns:
        bool: True if both state dicts contain exactly the same keys and every
            pair of same-named tensors is equal according to ``torch.equal``;
            False otherwise.
    """
    # Pairing the two dicts positionally with zip() would silently stop at the
    # shorter one and never look at the keys, so compare the key sets first.
    if state_dict_1.keys() != state_dict_2.keys():
        return False
    return all(
        torch.equal(weights, state_dict_2[name]) for name, weights in state_dict_1.items()
    )
82 changes: 82 additions & 0 deletions tests/proj/simple/test_runscript.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pytest
import torch

import jiant.utils.python.io as py_io
from jiant.proj.simple import runscript as run
import jiant.scripts.download_data.runscript as downloader
import jiant.utils.torch_utils as torch_utils


@pytest.mark.parametrize("task_name", ["copa"])
Expand All @@ -29,3 +31,83 @@ def test_simple_runscript(tmpdir, task_name, model_type):

val_metrics = py_io.read_json(os.path.join(exp_dir, "runs", RUN_NAME, "val_metrics.json"))
assert val_metrics["aggregated"] > 0


@pytest.mark.gpu
@pytest.mark.parametrize("task_name", ["copa"])
@pytest.mark.parametrize("model_type", ["roberta-large"])
def test_simple_runscript_save(tmpdir, task_name, model_type):
    """Run the simple runscript with each save flag and check the saved files.

    Three runs share one experiment directory: one with ``do_save``, one with
    ``do_save_best`` and one with ``do_save_last``. After each run the test
    asserts which of the ``best_model`` / ``last_model`` weight and metadata
    files exist under ``<exp_dir>/runs/<run_name>``.
    """
    best_files = ("best_model.p", "best_model.metadata.json")
    last_files = ("last_model.p", "last_model.metadata.json")

    run_name = f"{test_simple_runscript.__name__}_{task_name}_{model_type}_save"
    data_dir = str(tmpdir.mkdir("data"))
    exp_dir = str(tmpdir.mkdir("exp"))

    downloader.download_data([task_name], data_dir)

    # Settings that are identical for all three runs below.
    shared_config = dict(
        exp_dir=exp_dir,
        data_dir=data_dir,
        model_type=model_type,
        tasks=task_name,
        max_steps=1,
    )

    def run_file(name, filename):
        # Location of an output file inside the directory of run `name`.
        return os.path.join(exp_dir, "runs", name, filename)

    # Run 1: do_save
    run.run_simple(
        run.RunConfiguration(
            run_name=run_name,
            train_batch_size=32,
            do_save=True,
            eval_every_steps=10,
            learning_rate=0.01,
            num_train_epochs=5,
            **shared_config,
        )
    )

    # check best_model and last_model exist
    for filename in best_files + last_files:
        assert os.path.exists(run_file(run_name, filename))

    # assert best_model not equal to last_model
    cpu = torch.device("cpu")
    best_model_weights = torch.load(run_file(run_name, "best_model.p"), map_location=cpu)
    last_model_weights = torch.load(run_file(run_name, "last_model.p"), map_location=cpu)
    assert not torch_utils.eq_state_dicts(best_model_weights, last_model_weights)

    # Run 2: do_save_best
    run_name = f"{test_simple_runscript.__name__}_{task_name}_{model_type}_save_best"
    run.run_simple(
        run.RunConfiguration(
            run_name=run_name,
            train_batch_size=16,
            do_save_best=True,
            **shared_config,
        )
    )

    # check only best_model saved
    for filename in best_files:
        assert os.path.exists(run_file(run_name, filename))
    for filename in last_files:
        assert not os.path.exists(run_file(run_name, filename))

    # Run 3: do_save_last
    run_name = f"{test_simple_runscript.__name__}_{task_name}_{model_type}_save_last"
    run.run_simple(
        run.RunConfiguration(
            run_name=run_name,
            train_batch_size=16,
            do_save_last=True,
            **shared_config,
        )
    )

    # check only last_model saved
    for filename in best_files:
        assert not os.path.exists(run_file(run_name, filename))
    for filename in last_files:
        assert os.path.exists(run_file(run_name, filename))

0 comments on commit ce62495

Please sign in to comment.