test_models_mt5.py
import pytest
import tempfile
from gluonnlp.models.mt5 import (
    MT5Model, MT5Inference, mt5_cfg_reg, list_pretrained_mt5, get_pretrained_mt5
)


def test_list_pretrained_mt5():
    assert len(list_pretrained_mt5()) > 0


@pytest.mark.parametrize('cfg_key', mt5_cfg_reg.list_keys())
def test_mt5_model_and_inference(cfg_key, ctx):
    # since MT5Model and MT5Inference simply inherit from T5Model and T5Inference,
    # we just want to make sure the model can be properly loaded, and leave
    # the correctness tests to test_models_t5.py
    with ctx:
        cfg = MT5Model.get_cfg(cfg_key)
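        # shrink every config except google_mt5_small so the smoke test stays lightweight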
        if cfg_key != 'google_mt5_small':
            cfg.defrost()
            cfg.MODEL.vocab_size = 256
            cfg.MODEL.d_model = 128
            cfg.MODEL.d_ff = 512
            cfg.MODEL.num_layers = 2
            cfg.MODEL.num_heads = 4
            cfg.freeze()
        mt5_model = MT5Model.from_cfg(cfg)
        mt5_model.initialize()
        mt5_model.hybridize()
        if cfg_key == 'google_mt5_small':
            inference_model = MT5Inference(mt5_model)
            inference_model.hybridize()


def test_mt5_get_pretrained(ctx):
    with tempfile.TemporaryDirectory() as root, ctx:
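        # fetch the pretrained config, tokenizer, and backbone parameter file for google_mt5_small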
        cfg, tokenizer, backbone_params_path, _ = get_pretrained_mt5('google_mt5_small')
        # we exclude <extra_id>s in the comparison below by avoiding len(tokenizer.vocab)
        assert cfg.MODEL.vocab_size >= len(tokenizer._sp_model)
        mt5_model = MT5Model.from_cfg(cfg)
        mt5_model.load_parameters(backbone_params_path)
        mt5_model.hybridize()
        mt5_inference_model = MT5Inference(mt5_model)
        mt5_inference_model.hybridize()