From 485fad90aa0fb444d08e952114c7f582cc002c3d Mon Sep 17 00:00:00 2001
From: Emmanuel Benazera
Date: Mon, 18 Sep 2023 14:47:09 +0000
Subject: [PATCH] fix: improved diffusion inference scripts, including video
 generation

---
 scripts/gen_single_image_diffusion.py | 63 +++++++++++++++------------
 scripts/gen_video_diffusion.py        | 20 +++++++--
 2 files changed, 52 insertions(+), 31 deletions(-)

diff --git a/scripts/gen_single_image_diffusion.py b/scripts/gen_single_image_diffusion.py
index cb710559d..3c637d4e8 100644
--- a/scripts/gen_single_image_diffusion.py
+++ b/scripts/gen_single_image_diffusion.py
@@ -148,6 +148,8 @@ def cond_augment(cond, rotation, persp_horizontal, persp_vertical):
 def generate(
     seed,
     model_in_file,
+    lmodel,
+    lopt,
     cpu,
     gpuid,
     sampling_steps,
@@ -179,27 +181,33 @@ def generate(
     min_crop_bbox_ratio,
     ddim_num_steps,
     ddim_eta,
+    model_prior_321_backwardcompatibility,
     **unused_options,
 ):
     # seed
     if seed >= 0:
         torch.manual_seed(seed)
 
-    # loading model
-    modelpath = model_in_file.replace(os.path.basename(model_in_file), "")
-
     if not cpu:
         device = torch.device("cuda:" + str(gpuid))
     else:
         device = torch.device("cpu")
-    model, opt = load_model(
-        modelpath,
-        os.path.basename(model_in_file),
-        device,
-        sampling_steps,
-        sampling_method,
-        args.model_prior_321_backwardcompatibility,
-    )
+
+    # loading model
+    if lmodel is None:
+        modelpath = model_in_file.replace(os.path.basename(model_in_file), "")
+
+        model, opt = load_model(
+            modelpath,
+            os.path.basename(model_in_file),
+            device,
+            sampling_steps,
+            sampling_method,
+            model_prior_321_backwardcompatibility,
+        )
+    else:
+        model = lmodel
+        opt = lopt
 
     if alg_palette_cond_image_creation is not None:
         opt.alg_palette_cond_image_creation = alg_palette_cond_image_creation
@@ -572,21 +580,18 @@ def generate(
         out_img = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0
         out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)"""
 
-    if img_width > 0 or img_height > 0 or crop_width > 0 or crop_height > 0:
-        # img_orig = cv2.cvtColor(img_orig, cv2.COLOR_RGB2BGR)
-
-        if bbox_in:
-            out_img_resized = cv2.resize(
-                out_img,
-                (
-                    min(img_orig.shape[1], bbox_select[2] - bbox_select[0]),
-                    min(img_orig.shape[0], bbox_select[3] - bbox_select[1]),
-                ),
-            )
+    if bbox_in:
+        out_img_resized = cv2.resize(
+            out_img,
+            (
+                min(img_orig.shape[1], bbox_select[2] - bbox_select[0]),
+                min(img_orig.shape[0], bbox_select[3] - bbox_select[1]),
+            ),
+        )
 
-            out_img_real_size = img_orig.copy()
-        else:
-            out_img_real_size = out_img
+        out_img_real_size = img_orig.copy()
+    else:
+        out_img_real_size = out_img
 
     # fill out crop into original image
     if bbox_in:
@@ -620,7 +625,7 @@ def generate(
 
     print("Successfully generated image ", name)
 
-    return out_img_real_size
+    return out_img_real_size, model, opt
 
 
 if __name__ == "__main__":
@@ -796,6 +801,10 @@ def generate(
 
     real_name = args.name
 
+    args.lmodel = None
+    args.lopt = None
     for i in tqdm(range(args.nb_samples)):
         args.name = real_name + "_" + str(i).zfill(len(str(args.nb_samples)))
-        generate(**vars(args))
+        frame, lmodel, lopt = generate(**vars(args))
+        args.lmodel = lmodel
+        args.lopt = lopt
diff --git a/scripts/gen_video_diffusion.py b/scripts/gen_video_diffusion.py
index d8bdf48a9..57d76bc79 100644
--- a/scripts/gen_video_diffusion.py
+++ b/scripts/gen_video_diffusion.py
@@ -68,7 +68,7 @@ def natural_keys(text):
         args.video_width = args.img_width
 
     if args.video_height == -1:
-        args.video_width = args.img_width
+        args.video_height = args.img_height
 
     with open(args.dataroot, "r") as f:
         paths_list = f.read().split("\n")
@@ -103,6 +103,8 @@ def natural_keys(text):
 
     args.previous_frame = None
 
+    lmodel = None
+    lopt = None
     for i, (image, label) in tqdm(enumerate(zip(images, labels)), total=len(images)):
 
         args.img_in = args.data_prefix + image
@@ -125,8 +127,20 @@ def natural_keys(text):
         args.write = False"""
 
         args.bbox_ref_id = -1
+        args.cond_in = None
+        args.cond_keep_ratio = True
+        args.alg_palette_guidance_scale = 0.0
+        args.alg_palette_cond_image_creation = None
+        args.alg_palette_sketch_canny_thresholds = None
+        args.alg_palette_super_resolution_downsample = False
+        args.data_refined_mask = False
+        args.min_crop_bbox_ratio = 0.0
+        args.model_prior_321_backwardcompatibility = False
+
+        args.lmodel = lmodel
+        args.lopt = lopt
 
-        frame = generate(**vars(args))
+        frame, lmodel, lopt = generate(**vars(args))
 
         if args.cond == "previous":  # use_real_previous:
             args.previous_frame = args.data_prefix + image
@@ -140,8 +154,6 @@ def natural_keys(text):
         elif args.cond == "zero":
             args.previous_frame = None
 
-        colored_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)
-
         out.write(frame)
 
     # When everything done, release the video write objects
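
Note: the core change threads the loaded model and its options back out of
generate(), so the sample loop in gen_single_image_diffusion.py and the frame
loop in gen_video_diffusion.py read the checkpoint from disk once and reuse it
afterwards. A minimal, self-contained sketch of that caching pattern follows;
load_model() and generate() here are simplified stand-ins for the real
functions in the scripts, and "model.pth" is a placeholder path:

    def load_model(path):
        # stands in for the expensive checkpoint load in the real script
        print("loading", path)
        return {"weights": path}, {"sampling_method": "ddpm"}

    def generate(lmodel, lopt, model_in_file):
        if lmodel is None:
            # first call: load from disk
            model, opt = load_model(model_in_file)
        else:
            # later calls: reuse the cached model and options
            model, opt = lmodel, lopt
        frame = None  # diffusion inference would produce the frame here
        return frame, model, opt

    lmodel = lopt = None
    for _ in range(5):
        frame, lmodel, lopt = generate(lmodel, lopt, "model.pth")
    # "loading model.pth" is printed once, not five times

Keeping the cache in the caller (args.lmodel / args.lopt) rather than in a
global keeps generate() free of hidden state and lets gen_video_diffusion.py
drive the same entry point without reloading between frames.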
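Note: dropping the img_width/img_height guard means the bbox branch now always
resizes the generated crop back to the selected region before it is filled into
the original image. A rough sketch of that resize-and-paste step, assuming
numpy/OpenCV arrays and a hypothetical bbox_select = (x0, y0, x1, y1); the
paste indexing is illustrative, since the diff only shows the resize and the
"# fill out crop into original image" step:

    import cv2
    import numpy as np

    img_orig = np.zeros((480, 640, 3), dtype=np.uint8)     # original frame (stand-in)
    out_img = np.full((128, 128, 3), 255, dtype=np.uint8)  # generated crop (stand-in)
    bbox_select = (100, 50, 300, 200)                      # x0, y0, x1, y1

    # resize the crop to the bbox size, clamped to the original image bounds;
    # cv2.resize takes dsize as (width, height)
    out_img_resized = cv2.resize(
        out_img,
        (
            min(img_orig.shape[1], bbox_select[2] - bbox_select[0]),
            min(img_orig.shape[0], bbox_select[3] - bbox_select[1]),
        ),
    )

    # fill the resized crop back into a copy of the original image
    out_img_real_size = img_orig.copy()
    h, w = out_img_resized.shape[:2]
    out_img_real_size[
        bbox_select[1] : bbox_select[1] + h,
        bbox_select[0] : bbox_select[0] + w,
    ] = out_img_resized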