config_train.py
"""Config file for GPT2 training.
"""
pickle_data_dir = "data/toy"
max_seq_length = 128
max_decoding_length = max_seq_length
train_batch_size = 32
max_train_epoch = 100
display_steps = 1 # Print training loss every display_steps; -1 to disable
eval_steps = 1 # Eval on the dev set every eval_steps; -1 to disable
eval_batch_size = 8
test_batch_size = 8
# Optimization configs
opt = {
    'optimizer': {
        'type': 'Adam',
        'kwargs': {
            'lr': 0.001
        }
    }
}
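# Illustration (not from the original file): a dict of this shape is typically
# consumed by Texar-PyTorch's `tx.core.get_optimizer`, which resolves 'type'
# to `torch.optim.Adam` and forwards 'kwargs'. A minimal sketch, assuming that
# API and a hypothetical `model`:
#
#     import texar.torch as tx
#     optimizer = tx.core.get_optimizer(model.parameters(), hparams=opt)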
# Data configs
feature_types = {
    # Features are read from the pickle data files.
    # E.g., feature "text_ids" is read as dtype `int64`;
    # "stacked_tensor" means its length is fixed across all data instances,
    # and the sequence length is capped at `max_seq_length`.
    "text_ids": ["int64", "stacked_tensor", max_seq_length],
    "length": ["int64", "stacked_tensor"]
}
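# Illustration (an assumption about the pickle layout, not taken from this
# file): each pickled record is expected to supply the two features declared
# above, e.g.
#
#     {"text_ids": [50256, 318, ...],  # token ids, length == max_seq_length
#      "length": 12}                   # number of non-padding tokens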
train_hparam = {
    "allow_smaller_final_batch": False,
    "batch_size": train_batch_size,
    "dataset": {
        "data_name": "data",
        "feature_types": feature_types,
        "files": "{}/train.pkl".format(pickle_data_dir)
    },
    "shuffle": True,
    "shuffle_buffer_size": 10000
}
eval_hparam = {
    "allow_smaller_final_batch": True,
    "batch_size": eval_batch_size,
    "dataset": {
        "data_name": "data",
        "feature_types": feature_types,
        "files": "{}/dev.pkl".format(pickle_data_dir)
    },
    "shuffle": False
}
# Set `test_hparam` to `None` to generate from scratch
# (instead of generating continuations) at test time.
test_hparam = {
    "allow_smaller_final_batch": True,
    "batch_size": test_batch_size,
    "dataset": {
        "data_name": "data",
        "feature_types": feature_types,
        "files": "{}/test.pkl".format(pickle_data_dir)
    },
    "shuffle": False
}
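# Minimal usage sketch (illustrative, not part of the original config). It
# assumes these hparams feed Texar-PyTorch's `RecordData`, whose field names
# ("dataset", "feature_types", "files", etc.) they match. Guarded so that
# importing this config stays side-effect free:
if __name__ == "__main__":
    import texar.torch as tx

    # Build the training dataset and iterate one batch to sanity-check shapes.
    train_data = tx.data.RecordData(hparams=train_hparam)
    iterator = tx.data.DataIterator(train_data)
    for batch in iterator:
        # `text_ids`: [batch_size, max_seq_length]; `length`: [batch_size]
        print(batch["text_ids"].shape, batch["length"].shape)
        break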