# config_poem.py
import argparse

import numpy as np


def config_poem_train(args=''):
    parser = argparse.ArgumentParser()
    # Data and vocabulary file.
    # parser.add_argument('--data_file', type=str,
    #                     default='../data/poem/poems_space.txt',
    #                     help='data file')
    parser.add_argument('--data_path', type=str,
                        default='./data/poem/',
                        help='data path')
    parser.add_argument('--encoding', type=str,
                        default='utf-8',
                        help='the encoding of the data file.')
    # Parameters for saving models.
    parser.add_argument('--output_dir', type=str, default='output_model',
                        help=('directory to store final and'
                              ' intermediate results and models.'))
    # Parameters for using saved best models.
    parser.add_argument('--init_dir', type=str, default='',
                        help='continue from the outputs in the given directory')
    # Parameters to configure the neural network.
    parser.add_argument('--hidden_size', type=int, default=128,  # 128
                        help='size of RNN hidden state vector')
    parser.add_argument('--embedding_size', type=int, default=128,  # 0
                        help='size of character embeddings, 0 for one-hot')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers in the RNN')
    parser.add_argument('--num_unrollings', type=int, default=64,  # 10
                        help='number of unrolling steps.')
    parser.add_argument('--cell_type', type=str, default='lstm',
                        help='which model to use (rnn, lstm or gru).')
    # Parameters to control the training.
    parser.add_argument('--num_epochs', type=int, default=5,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='minibatch size')
    parser.add_argument('--train_frac', type=float, default=0.9,
                        help='fraction of data used for training.')
    parser.add_argument('--valid_frac', type=float, default=0.05,
                        help='fraction of data used for validation.')
    # test_frac is computed as (1 - train_frac - valid_frac).
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='dropout rate, defaults to 0 (no dropout).')
    parser.add_argument('--input_dropout', type=float, default=0.0,
                        help=('dropout rate on the input layer, defaults to 0 (no dropout); '
                              'no dropout is applied when using the one-hot representation.'))
    # Parameters for gradient descent.
    parser.add_argument('--max_grad_norm', type=float, default=5.,
                        help='clip the global gradient norm')
    parser.add_argument('--learning_rate', type=float, default=5e-3,
                        help='initial learning rate')
    # Parameters for logging.
    parser.add_argument('--progress_freq', type=int, default=100,
                        help='frequency of progress reports during training and evaluation.')
    parser.add_argument('--verbose', type=int, default=0,
                        help='whether to show progress reports during training and evaluation.')
    # Parameters to feed in the initial model and current best model.
    parser.add_argument('--init_model', type=str,
                        default='', help='initial model')
    parser.add_argument('--best_model', type=str,
                        default='', help='current best model')
    parser.add_argument('--best_valid_ppl', type=float,
                        default=np.inf, help='current best validation perplexity')
    # # Parameters for using saved best models.
    # parser.add_argument('--model_dir', type=str, default='',
    #                     help='continue from the outputs in the given directory')
    # Parameters for debugging.
    parser.add_argument('--debug', dest='debug', action='store_true',
                        help='show debug information')
    parser.set_defaults(debug=False)
    # Parameters for unit testing the implementation.
    parser.add_argument('--test', dest='test', action='store_true',
                        help='use the first 1000 characters of the data to test the implementation')
    parser.set_defaults(test=False)
    # input_args = ('--data_path ./data/poem --output_dir output_poem --hidden_size 256 '
    #               '--embedding_size 128 --num_unrollings 128 --debug --encoding utf-8')
    args = parser.parse_args(args.split())
    return args
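# Usage sketch (illustrative; this example is not part of the original script):
# the parser consumes a single flag string, so a training driver could do e.g.
#   train_args = config_poem_train('--data_path ./data/poem --output_dir output_poem '
#                                  '--hidden_size 256 --num_unrollings 128 --debug')
#   train_args.hidden_size  # -> 256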
def config_sample(args=''):
    parser = argparse.ArgumentParser()
    # Hyper-parameters for using saved best models.
    # Hyper-parameters related to logging and result files.
    logging_args = parser.add_argument_group('Logging_Options')
    logging_args.add_argument('--model_dir', type=str,
                              default='demo_model/',
                              help='continue from the outputs in the given directory')
    logging_args.add_argument('--data_dir', type=str,
                              default='./data/poem',
                              help='data file path')
    logging_args.add_argument('--best_model', type=str,
                              default='', help='current best model')
    # Hyper-parameters for sampling.
    testing_args = parser.add_argument_group('Sampling Options')
    testing_args.add_argument('--max_prob', dest='max_prob', action='store_true',
                              help='always pick the most probable next character in sampling')
    testing_args.set_defaults(max_prob=False)
    testing_args.add_argument('--start_text', type=str,
                              default='The meaning of life is ',
                              help='the text to start with')
    testing_args.add_argument('--length', type=int,
                              default=100,
                              help='length of the sampled sequence')
    testing_args.add_argument('--seed', type=int,
                              default=-1,
                              help=('seed for sampling to replicate results, '
                                    'an integer between 0 and 4294967295.'))
    args = parser.parse_args(args.split())
    return args
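# Minimal demo of how the two config helpers could be exercised; this block is a
# sketch added for illustration and is not part of the original repository. The
# flag values below are arbitrary examples.
if __name__ == '__main__':
    demo_train_args = config_poem_train('--hidden_size 256 --num_epochs 2 --test')
    print('training config:', vars(demo_train_args))
    demo_sample_args = config_sample('--length 50 --seed 42')
    print('sampling config:', vars(demo_sample_args))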