Skip to content

Commit

Permalink
Merge pull request #43 from sczhengyabin/sczhengyabin-patch-model_res…
Browse files Browse the repository at this point in the history
…tore_speedup

Update files.py, speed up model saving and restoring process.
  • Loading branch information
zsdonghao authored Dec 22, 2016
2 parents 400d717 + 9eff9a6 commit 8aa4ffd
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions tensorlayer/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,14 +646,14 @@ def save_npz(save_list=[], name='model.npz', sess=None):
"""
## save params into a list
save_list_var = []
for k, value in enumerate(save_list):
if sess:
save_list_var.append( sess.run(value) )
else:
try:
save_list_var.append( value.eval() )
except:
print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)")
if sess:
save_list_var = sess.run(save_list)
else:
try:
for k, value in enumerate(save_list):
save_list_var.append(value.eval())
except:
print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)")
np.savez(name, params=save_list_var)
save_list_var = None
del save_list_var
Expand Down Expand Up @@ -734,9 +734,10 @@ def assign_params(sess, params, network):
----------
- `Assign value to a TensorFlow variable <http://stackoverflow.com/questions/34220532/how-to-assign-value-to-a-tensorflow-variable>`_
"""
ops = []
for idx, param in enumerate(params):
assign_op = network.all_params[idx].assign(param)
sess.run(assign_op)
ops.append(network.all_params[idx].assign(param))
sess.run(ops)



Expand Down

0 comments on commit 8aa4ffd

Please sign in to comment.