diff --git a/test_code/cartoonize.py b/test_code/cartoonize.py index f32e962..4a06f71 100755 --- a/test_code/cartoonize.py +++ b/test_code/cartoonize.py @@ -1,7 +1,11 @@ +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf + import os import cv2 import numpy as np -import tensorflow as tf import network import guided_filter from tqdm import tqdm @@ -23,6 +27,13 @@ def resize_crop(image): def cartoonize(load_folder, save_folder, model_path): + try: + tf.disable_eager_execution() + except: + None + + tf.reset_default_graph() + input_photo = tf.placeholder(tf.float32, [1, None, None, 3]) network_out = network.unet_generator(input_photo) final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3) @@ -63,6 +74,3 @@ def cartoonize(load_folder, save_folder, model_path): if not os.path.exists(save_folder): os.mkdir(save_folder) cartoonize(load_folder, save_folder, model_path) - - - \ No newline at end of file diff --git a/test_code/guided_filter.py b/test_code/guided_filter.py index fd019d1..a13897f 100755 --- a/test_code/guided_filter.py +++ b/test_code/guided_filter.py @@ -1,4 +1,8 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf + import numpy as np diff --git a/test_code/network.py b/test_code/network.py index 6f16cee..a53077c 100755 --- a/test_code/network.py +++ b/test_code/network.py @@ -1,6 +1,11 @@ -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf + import tf_slim as slim +except ImportError: + import tensorflow as tf + import tensorflow.contrib.slim as slim + import numpy as np -import tensorflow.contrib.slim as slim @@ -59,4 +64,4 @@ def unet_generator(inputs, channel=32, num_blocks=4, name='generator', reuse=Fal if __name__ == '__main__': - pass \ No newline at end of file + pass