-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathconfig.py
41 lines (33 loc) · 1.17 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# parameters for processing the dataset
TERMINAL_OUTPUT = 'terminal_conversation.txt'
WECHAT_OUTPUT = 'wechat_conversation.txt'
DATA_PATH = 'data/simisimi'
CPT_PATH = 'checkpoints'
THRESHOLD = 2
PAD_ID = 0
UNK_ID = 1
START_ID = 2
EOS_ID = 3
# model parameters
""" Train encoder length distribution:
[175, 92, 11883, 8387, 10656, 13613, 13480, 12850, 11802, 10165,
8973, 7731, 7005, 6073, 5521, 5020, 4530, 4421, 3746, 3474, 3192,
2724, 2587, 2413, 2252, 2015, 1816, 1728, 1555, 1392, 1327, 1248,
1128, 1084, 1010, 884, 843, 755, 705, 660, 649, 594, 558, 517, 475,
426, 444, 388, 349, 337]
These buckets size seem to work the best
"""
# [19530, 17449, 17585, 23444, 22884, 16435, 17085, 18291, 18931]
# BUCKETS = [(6, 8), (8, 10), (10, 12), (13, 15), (16, 19), (19, 22), (23, 26), (29, 32), (39, 44)]
# [37049, 33519, 30223, 33513, 37371]
# BUCKETS = [(8, 10), (12, 14), (16, 19), (23, 26), (39, 43)]
# BUCKETS = [(8, 10), (12, 14), (16, 19)]
# BUCKETS = [(16, 19)]
BUCKETS = [(10, 13), (13, 16), (16, 20)]
# BUCKETS = [(6, 8), (10, 13), (14, 18), (18, 23), (22, 27), (26, 31), (30, 35)]
NUM_LAYERS = 3
HIDDEN_SIZE = 256
BATCH_SIZE = 64
LR = 0.5
MAX_GRAD_NORM = 5.0
NUM_SAMPLES = 512