-
Notifications
You must be signed in to change notification settings - Fork 3
/
eval.py
56 lines (45 loc) · 1.87 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
import os
import sys
import tempfile
from unittest.mock import patch

from slugify import slugify

import run_nb_flax_speech_recognition_seq2seq_streaming as cli
def get_run_name(
    model_name_or_path, dataset_name, dataset_config_name, num_beams, do_normalize_eval, **kwargs
):
    """Build a run name from the evaluation settings.

    The name is the ``_``-joined, lower-cased sequence of:

    * the slugified model name or path,
    * the slugified dataset name,
    * the slugified dataset config name (``"default"`` when it is falsy),
    * ``"<num_beams>beams"`` (``0`` beams when ``num_beams`` is falsy),
    * ``"normalize"``, only when ``do_normalize_eval`` is truthy.

    Extra keyword arguments are accepted and ignored, so a larger dict of
    parameters can be splatted in directly.
    """
    config_label = dataset_config_name if dataset_config_name else "default"
    beam_count = num_beams if num_beams else 0
    segments = [
        slugify(model_name_or_path),
        slugify(dataset_name),
        slugify(config_label),
        f"{beam_count}beams",
    ]
    segments += ["normalize"] if do_normalize_eval else []
    return "_".join(segments).lower()
def _split_option_assignments(argv):
    """Return a copy of *argv* with ``--option=value`` tokens split in two.

    Only tokens starting with ``--`` are split, and only at the first ``=``.
    Values that themselves contain ``=`` (e.g. ``--foo=a=b``) therefore stay
    intact, and non-option tokens (such as the script path) are left alone.
    """
    expanded = []
    for arg in argv:
        if arg.startswith("--") and "=" in arg:
            expanded.extend(arg.split("=", 1))
        else:
            expanded.append(arg)
    return expanded


def _split_dataset_spec(dataset):
    """Split a ``"NAME:CONFIG"`` dataset spec into ``(NAME, CONFIG)``.

    Returns ``None`` when *dataset* is not a string or contains no ``:``.
    Only the first ``:`` separates name from config, so a config that itself
    contains ``:`` does not break tuple unpacking.
    """
    if not isinstance(dataset, str) or ":" not in dataset:
        return None
    dataset_name, _, dataset_config_name = dataset.partition(":")
    return dataset_name, dataset_config_name


def _run_with_cli_args(argv, run_name_params):
    """Rewrite command-line *argv* and run ``cli.main()`` with it.

    ``--dataset_name NAME:CONFIG`` becomes
    ``--dataset_name NAME --dataset_config_name CONFIG``, and a generated
    ``--run_name`` is appended when none was given. *run_name_params* is
    updated in place with the split dataset fields.
    """
    sys_argv = _split_option_assignments(argv)
    if "--dataset_name" in sys_argv:
        dataset_pos = sys_argv.index("--dataset_name") + 1
        spec = _split_dataset_spec(sys_argv[dataset_pos])
        if spec is not None:
            dataset_name, dataset_config_name = spec
            sys_argv[dataset_pos] = dataset_name
            # Appended last so it wins over any earlier --dataset_config_name
            # (argparse's default "store" action keeps the last occurrence).
            sys_argv.extend(["--dataset_config_name", dataset_config_name])
            run_name_params["dataset_name"] = dataset_name
            run_name_params["dataset_config_name"] = dataset_config_name
    if "--run_name" not in sys_argv:
        sys_argv.extend(["--run_name", get_run_name(**run_name_params)])
    with patch.object(sys, "argv", sys_argv):
        cli.main()


def _run_with_json_args(json_file, run_name_params):
    """Rewrite a JSON argument file and run ``cli.main()`` on a temporary copy.

    Applies the same conveniences as command-line mode to the JSON contents:
    the ``NAME:CONFIG`` dataset spec is split, and a ``run_name`` is generated
    when missing or null. The contents are written to a temporary ``.json``
    file, and ``sys.argv`` is pointed at it.

    Appending ``--run_name`` to ``sys.argv`` instead would grow it past two
    entries and defeat ``len(sys.argv) == 2`` JSON detection.
    NOTE(review): this assumes ``cli.main()`` detects JSON input the same
    way ``main`` does; confirm against the wrapped script.
    """
    with open(json_file, encoding="utf-8") as f:
        config = json.load(f)
    spec = _split_dataset_spec(config.get("dataset_name"))
    if spec is not None:
        dataset_name, dataset_config_name = spec
        config["dataset_name"] = dataset_name
        config["dataset_config_name"] = dataset_config_name
        run_name_params["dataset_name"] = dataset_name
        run_name_params["dataset_config_name"] = dataset_config_name
    if config.get("run_name") is None:
        config["run_name"] = get_run_name(**run_name_params)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original basename so the ".json" suffix check still matches.
        patched_file = os.path.join(tmp_dir, os.path.basename(json_file))
        with open(patched_file, "w", encoding="utf-8") as f:
            json.dump(config, f)
        with patch.object(sys, "argv", [sys.argv[0], patched_file]):
            cli.main()


def main():
    """Parse arguments, derive a run name, and delegate to ``cli.main()``.

    Arguments come either from a single ``.json`` file
    (``python eval.py args.json``) or from the command line. They are parsed
    into the wrapped CLI's ``ModelArguments``, ``DataTrainingArguments`` and
    ``Seq2SeqTrainingArguments`` dataclasses, whose merged fields feed
    :func:`get_run_name`.

    In both modes the arguments are then rewritten before ``cli.main()`` is
    invoked with ``sys.argv`` temporarily patched: the dataset
    ``NAME:CONFIG`` spec is split, and a default run name is added.
    """
    parser = cli.HfArgumentParser(
        (cli.ModelArguments, cli.DataTrainingArguments, cli.Seq2SeqTrainingArguments)
    )
    json_mode = len(sys.argv) == 2 and sys.argv[1].endswith(".json")
    if json_mode:
        json_file = os.path.abspath(sys.argv[1])
        model_args, data_args, training_args = parser.parse_json_file(json_file=json_file)
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Later dataclasses win on duplicate field names.
    run_name_params = {**vars(model_args), **vars(data_args), **vars(training_args)}

    if json_mode:
        _run_with_json_args(json_file, run_name_params)
    else:
        _run_with_cli_args(sys.argv, run_name_params)


if __name__ == "__main__":
    main()