-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
117 lines (102 loc) · 5.37 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# encoding: utf-8
# author: liaochangzeng
# github: https://github.com/changzeng
import os
import time
import argparse
import numpy as np
import tensorflow as tf
from model import LFM
from random import shuffle
def get_div(a, b):
    """Safely divide a by b as floats; a zero denominator yields 0.0."""
    return a / b if b else 0.0
def shuffle_file(origin, destination):
with open("data/"+origin, encoding="utf-8") as fd:
txt = fd.read().strip()
shuffle_result = txt.split("\n")
shuffle(shuffle_result)
with open(destination, "w", encoding="utf-8") as fd:
fd.write("\n".join(shuffle_result))
class Trainer(object):
    """Drives LFM training.

    Builds a per-run directory tree under model/<model_time>/ (shuffled data,
    checkpoints, TF summaries), feeds fixed-size mini-batches to the model,
    and periodically validates and checkpoints during training.
    """

    def __init__(self, args):
        """Store hyper-parameters from the parsed CLI namespace and build the model.

        args: argparse.Namespace with train_data, test_data, user_num, item_num,
              batch_size, hidden_dim, max_epoch, validate_epoch, checkpoint_epoch,
              model_time, gpu_use_rate, learning_rate.
        """
        self.args = args
        # Reuse an existing run directory when --model_time is given,
        # otherwise stamp a fresh one with the current unix time.
        self.model_time = int(time.time()) if args.model_time is None else args.model_time
        self.model_path = "model/" + str(self.model_time) + "/"
        self.model_data_path = self.model_path + "data/"
        self.checkpoint_path = self.model_path + "checkpoint/"
        self.summary_path = self.model_path + "summary/"
        self.max_epoch = args.max_epoch
        self.train_data = args.train_data
        self.test_data = args.test_data
        self.user_num = args.user_num
        self.item_num = args.item_num
        self.batch_size = args.batch_size
        self.hidden_dim = args.hidden_dim
        self.validate_epoch = args.validate_epoch
        self.checkpoint_epoch = args.checkpoint_epoch
        self.lfm = LFM(user_num=args.user_num, item_num=args.item_num,
                       batch_size=args.batch_size, hidden_dim=args.hidden_dim,
                       learning_rate=args.learning_rate)
        self.check_model_path()

    def model_parameter(self):
        """Return a hyper-parameter tag string, used to name checkpoint files."""
        return "hidden_dim:{}-batch_size:{}".format(self.hidden_dim, self.batch_size)

    def check_model_path(self):
        """Create the run's directory tree (model/, data/, checkpoint/, summary/)."""
        for _path in ["model/", self.model_path, self.model_data_path,
                      self.checkpoint_path, self.summary_path]:
            # exist_ok makes re-running with an existing --model_time a no-op.
            os.makedirs(_path, exist_ok=True)

    def train(self):
        """Run the training loop for max_epoch passes over the training file.

        Every validate_epoch global steps, evaluates loss on the test set;
        every checkpoint_epoch global steps, saves a checkpoint. Summaries are
        written to the run's summary/ directory for TensorBoard.
        """
        tf_config = tf.ConfigProto()
        # BUG FIX: was the module-level global `args`; use the instance's copy so
        # Trainer also works when constructed outside this script's __main__.
        tf_config.gpu_options.per_process_gpu_memory_fraction = self.args.gpu_use_rate
        with tf.Session(config=tf_config) as sess:
            sess.run(tf.global_variables_initializer())
            file_writer = tf.summary.FileWriter(self.summary_path, sess.graph)
            merged = tf.summary.merge_all()
            for epoch in range(self.max_epoch):
                for feed_dict in self.gen_batch(self.train_data):
                    _, global_step, summary = sess.run(
                        [self.lfm.train, self.lfm.global_step, merged],
                        feed_dict=feed_dict)
                    if global_step % self.validate_epoch == 0:
                        self.validate(sess)
                    if global_step % self.checkpoint_epoch == 0:
                        self.lfm.save(sess, self.checkpoint_path + self.model_parameter(), global_step)
                    file_writer.add_summary(summary, global_step)
                print("cur_epoch/total_epoch: ({:3d}/{:3d}), global_step: {:4d}".format(epoch + 1, self.max_epoch, global_step))

    def gen_batch(self, file_name):
        """Yield feed dicts of batch_size examples from a freshly shuffled copy of data/<file_name>.

        Each line is expected as "user,item,score,<ignored>" with 1-based user/item
        ids (converted to 0-based here). A trailing partial batch (< batch_size
        lines) is dropped, matching the model's fixed batch size.
        """
        shuffle_file(file_name, self.model_data_path + file_name)
        with open(self.model_data_path + file_name, encoding="utf-8") as fd:
            while True:
                input_list = []
                score_list = []
                for i in range(self.batch_size):
                    line = fd.readline().strip()
                    if len(line) == 0:
                        return
                    _user, _item, _score, _ = line.split(",")
                    input_list.append([int(_user) - 1, int(_item) - 1])
                    score_list.append(_score)
                yield {"input:0": np.array(input_list, dtype=np.int32),
                       "score:0": np.array(score_list, dtype=np.float32)}

    def validate(self, sess):
        """Report total/average loss over the test set using the current weights."""
        total_loss = 0
        batch_num = 0
        for feed_dict in self.gen_batch(self.test_data):
            loss = sess.run(self.lfm.loss, feed_dict=feed_dict)
            total_loss += loss
            batch_num += 1
        print("Total loss is: {:4.2f}, Batch num is: {:4.2f}, Average loss is {:4.2f}".format(total_loss, batch_num, get_div(total_loss, batch_num)))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='LFM Model Training')
parser.add_argument('--train_data', type=str, default="data.train", help='training data')
parser.add_argument('--test_data', type=str, default="data.test", help='testing data')
parser.add_argument('--user_num', type=int, default=6040, help='user num')
parser.add_argument('--item_num', type=int, default=3952, help='item num')
parser.add_argument('--batch_size', type=int, default=500, help='batch size')
parser.add_argument('--hidden_dim', type=int, default=20, help='hidden dim')
parser.add_argument('--max_epoch', type=int, default=20, help='opoch num')
parser.add_argument('--validate_epoch', type=int, default=2000, help='validate opoch num')
parser.add_argument('--checkpoint_epoch', type=int, default=500, help='checkpoint opoch num')
parser.add_argument('--model_time', type=str, default=None, help='time when training new model')
parser.add_argument('--gpu_use_rate', type=float, default=0.38, help='GPU memory using rate per process')
parser.add_argument('--learning_rate', type=float, default=0.1, help='learning rate')
parser.add_argument('--decay', type=float, default=0.5, help='learning decay rate per 2000 batch')
args = parser.parse_args()
trainer = Trainer(args)
trainer.train()