diff --git a/tools/restore_model.py b/tools/restore_model.py index 9805b12..f3b30fa 100644 --- a/tools/restore_model.py +++ b/tools/restore_model.py @@ -25,7 +25,11 @@ def get_restorer(): if RESTORE_FROM_RPN: print('___restore from rpn___') model_variables = slim.get_model_variables() - restore_variables = [var for var in model_variables if not var.name.startswith('Fast_Rcnn')] + [slim.get_or_create_global_step()] + if(tf.__version__.startswith("1.") and int(tf.__version__.split(".")[1])<=3) or tf.__version__.startswith("0."): + ### for tf version <=1.3.0 + restore_variables = [var for var in model_variables if not var.name.startswith('Fast_Rcnn')] + [slim.get_or_create_global_step()] + else: ### for tf version >=1.4.0 + restore_variables = [var for var in model_variables if not var.name.startswith('Fast_Rcnn')] + [tf.train.get_or_create_global_step()] for var in restore_variables: print(var.name) restorer = tf.train.Saver(restore_variables) diff --git a/tools/train.py b/tools/train.py index 49f1c0c..ffa7206 100644 --- a/tools/train.py +++ b/tools/train.py @@ -137,7 +137,11 @@ def train(): # train total_loss = slim.losses.get_total_loss() - global_step = slim.get_or_create_global_step() + if(tf.__version__.startswith("1.") and int(tf.__version__.split(".")[1])<=3) or tf.__version__.startswith("0."): + ### for tf version <=1.3.0 + global_step = slim.get_or_create_global_step() + else: ### for tf version >=1.4.0 + global_step = tf.train.get_or_create_global_step() lr = tf.train.piecewise_constant(global_step, boundaries=[np.int64(20000), np.int64(40000)], diff --git a/tools/train1.py b/tools/train1.py index 32cc16c..1eab9b6 100644 --- a/tools/train1.py +++ b/tools/train1.py @@ -148,8 +148,11 @@ def train(): # train total_loss = slim.losses.get_total_loss() - - global_step = slim.get_or_create_global_step() + if(tf.__version__.startswith("1.") and int(tf.__version__.split(".")[1])<=3) or tf.__version__.startswith("0."): + ### for tf version <=1.3.0 + global_step = slim.get_or_create_global_step() + else: ### for tf version >=1.4.0 + global_step = tf.train.get_or_create_global_step() lr = tf.train.piecewise_constant(global_step, boundaries=[np.int64(20000), np.int64(40000)],