forked from yongqyu/DeepFM-tf2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
25 lines (18 loc) · 1.2 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import argparse
def _str2bool(value):
    """Convert a command-line token to a bool, rejecting unrecognised text.

    Used as an argparse ``type`` so that ``--flag False`` really yields
    ``False`` instead of the (truthy) string ``'False'``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def argparser(argv=None):
    """Build the command-line parser and parse the training configuration.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', default=0.01, help='learning rate', type=float)
    parser.add_argument('--train_batch_size', default=128,
                        help='batch size for training', type=int)
    parser.add_argument('--test_batch_size', default=512,
                        help='batch size for testing', type=int)
    parser.add_argument('--epochs', default=50, help='number of epochs', type=int)
    parser.add_argument('--print_step', default=100,
                        help='step size for print log', type=int)
    parser.add_argument('--dropout_rate', default=0.5, help='dropout rate', type=float)
    parser.add_argument('--dataset_dir', default='/data/private/Ad/ml-20m/np_prepro/',
                        help='dataset path', type=str)
    parser.add_argument('--model_path', default='./models/', help='model load path', type=str)
    parser.add_argument('--log_path', default='./logs/',
                        help='log path for tensorboard', type=str)
    # Boolean flags: a bare `--flag` means True (const); an explicit value is
    # parsed strictly, so `--flag False` is False rather than the string 'False'.
    parser.add_argument('--is_reuse', default=False, type=_str2bool, nargs='?', const=True,
                        help='boolean flag (true/false, yes/no, 1/0; bare flag means true)')
    parser.add_argument('--multi_gpu', default=False, type=_str2bool, nargs='?', const=True,
                        help='boolean flag (true/false, yes/no, 1/0; bare flag means true)')
    parser.add_argument('--sparse_emb_dim', default=8,
                        help='dimension for sparse feature', type=int)
    # nargs='+' so the CLI yields a list of ints, matching the list default
    # (previously `--dnn_layers 64` produced a bare int).
    parser.add_argument('--dnn_layers', default=[256, 128], type=int, nargs='+',
                        help='sizes of the DNN layers, e.g. --dnn_layers 256 128')
    args = parser.parse_args(argv)
    return args