diff --git a/tensorlayer/files.py b/tensorlayer/files.py index 0e9253baa..7b9568a7d 100644 --- a/tensorlayer/files.py +++ b/tensorlayer/files.py @@ -646,14 +646,14 @@ def save_npz(save_list=[], name='model.npz', sess=None): """ ## save params into a list save_list_var = [] - for k, value in enumerate(save_list): - if sess: - save_list_var.append( sess.run(value) ) - else: - try: - save_list_var.append( value.eval() ) - except: - print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)") + if sess: + save_list_var = sess.run(save_list) + else: + try: + for k, value in enumerate(save_list): + save_list_var.append(value.eval()) + except: + print(" Fail to save model, Hint: pass the session into this function, save_npz(network.all_params, name='model.npz', sess=sess)") np.savez(name, params=save_list_var) save_list_var = None del save_list_var @@ -734,9 +734,10 @@ def assign_params(sess, params, network): ---------- - `Assign value to a TensorFlow variable `_ """ + ops = [] for idx, param in enumerate(params): - assign_op = network.all_params[idx].assign(param) - sess.run(assign_op) + ops.append(network.all_params[idx].assign(param)) + sess.run(ops)