Skip to content

Commit

Permalink
Update files.py, speed up model saving and restoring process.
Browse files Browse the repository at this point in the history
The run() and eval() will run the whole graph from scratch, so combining ops to an array and executing together will result in significant speed-up.
  • Loading branch information
sczhengyabin authored Dec 22, 2016
1 parent 400d717 commit 9eff9a6
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions tensorlayer/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,14 +646,14 @@ def save_npz(save_list=[], name='model.npz', sess=None):
"""
## save params into a list
save_list_var = []
for k, value in enumerate(save_list):
if sess:
save_list_var.append( sess.run(value) )
else:
try:
save_list_var.append( value.eval() )
except:
print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)")
if sess:
save_list_var = sess.run(save_list)
else:
try:
for k, value in enumerate(save_list):
save_list_var.append(value.eval())
except:
print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)")
np.savez(name, params=save_list_var)
save_list_var = None
del save_list_var
Expand Down Expand Up @@ -734,9 +734,10 @@ def assign_params(sess, params, network):
----------
- `Assign value to a TensorFlow variable <http://stackoverflow.com/questions/34220532/how-to-assign-value-to-a-tensorflow-variable>`_
"""
ops = []
for idx, param in enumerate(params):
assign_op = network.all_params[idx].assign(param)
sess.run(assign_op)
ops.append(network.all_params[idx].assign(param))
sess.run(ops)



Expand Down

0 comments on commit 9eff9a6

Please sign in to comment.