fix: improved diffusion inference scripts, including video generation
beniz committed Sep 19, 2023
1 parent a4e1e04 commit 485fad9
Showing 2 changed files with 52 additions and 31 deletions.
63 changes: 36 additions & 27 deletions scripts/gen_single_image_diffusion.py
@@ -148,6 +148,8 @@ def cond_augment(cond, rotation, persp_horizontal, persp_vertical):
 def generate(
     seed,
     model_in_file,
+    lmodel,
+    lopt,
     cpu,
     gpuid,
     sampling_steps,
@@ -179,27 +181,33 @@ def generate(
     min_crop_bbox_ratio,
     ddim_num_steps,
     ddim_eta,
+    model_prior_321_backwardcompatibility,
     **unused_options,
 ):
     # seed
     if seed >= 0:
         torch.manual_seed(seed)
 
-    # loading model
-    modelpath = model_in_file.replace(os.path.basename(model_in_file), "")
-
     if not cpu:
         device = torch.device("cuda:" + str(gpuid))
     else:
         device = torch.device("cpu")
-    model, opt = load_model(
-        modelpath,
-        os.path.basename(model_in_file),
-        device,
-        sampling_steps,
-        sampling_method,
-        args.model_prior_321_backwardcompatibility,
-    )
+
+    # loading model
+    if lmodel is None:
+        modelpath = model_in_file.replace(os.path.basename(model_in_file), "")
+
+        model, opt = load_model(
+            modelpath,
+            os.path.basename(model_in_file),
+            device,
+            sampling_steps,
+            sampling_method,
+            model_prior_321_backwardcompatibility,
+        )
+    else:
+        model = lmodel
+        opt = lopt
 
     if alg_palette_cond_image_creation is not None:
         opt.alg_palette_cond_image_creation = alg_palette_cond_image_creation
@@ -572,21 +580,18 @@ def generate(
         out_img = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0
         out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)"""
 
-    if img_width > 0 or img_height > 0 or crop_width > 0 or crop_height > 0:
-        # img_orig = cv2.cvtColor(img_orig, cv2.COLOR_RGB2BGR)
-
-        if bbox_in:
-            out_img_resized = cv2.resize(
-                out_img,
-                (
-                    min(img_orig.shape[1], bbox_select[2] - bbox_select[0]),
-                    min(img_orig.shape[0], bbox_select[3] - bbox_select[1]),
-                ),
-            )
-
-            out_img_real_size = img_orig.copy()
-        else:
-            out_img_real_size = out_img
+    if bbox_in:
+        out_img_resized = cv2.resize(
+            out_img,
+            (
+                min(img_orig.shape[1], bbox_select[2] - bbox_select[0]),
+                min(img_orig.shape[0], bbox_select[3] - bbox_select[1]),
+            ),
+        )
+
+        out_img_real_size = img_orig.copy()
+    else:
+        out_img_real_size = out_img
 
     # fill out crop into original image
     if bbox_in:
@@ -620,7 +625,7 @@ def generate(
 
     print("Successfully generated image ", name)
 
-    return out_img_real_size
+    return out_img_real_size, model, opt
 
 
 if __name__ == "__main__":
@@ -796,6 +801,10 @@ def generate(
 
     real_name = args.name
 
+    args.lmodel = None
+    args.lopt = None
     for i in tqdm(range(args.nb_samples)):
         args.name = real_name + "_" + str(i).zfill(len(str(args.nb_samples)))
-        generate(**vars(args))
+        frame, lmodel, lopt = generate(**vars(args))
+        args.lmodel = lmodel
+        args.lopt = lopt
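The net effect of the hunks above is that generate() no longer reloads the checkpoint on every call: the caller passes back the model and options returned by the previous call. Below is a minimal, self-contained sketch of that caching pattern; load_model here is a stand-in stub (the real joliGEN helper takes more arguments and restores the diffusion network), and the checkpoint path is a hypothetical placeholder.

import os
import torch


def load_model(modelpath, model_filename, device):
    # Stand-in stub: the real joliGEN helper rebuilds the diffusion network
    # and its options from the checkpoint directory.
    model = torch.nn.Linear(4, 4).to(device)
    opt = {"sampling_steps": 10}
    return model, opt


def generate(model_in_file, lmodel=None, lopt=None):
    device = torch.device("cpu")
    if lmodel is None:
        # First call: load the model from disk, as in the diff above.
        modelpath = model_in_file.replace(os.path.basename(model_in_file), "")
        model, opt = load_model(modelpath, os.path.basename(model_in_file), device)
    else:
        # Later calls: reuse the model cached by the caller.
        model, opt = lmodel, lopt
    out = model(torch.zeros(1, 4))  # placeholder for the diffusion sampling loop
    return out, model, opt


# The __main__ loop feeds the returned model back in, so the checkpoint is
# read only once across all samples.
lmodel, lopt = None, None
for i in range(3):
    frame, lmodel, lopt = generate("checkpoints/latest_net_G.pth", lmodel, lopt)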
20 changes: 16 additions & 4 deletions scripts/gen_video_diffusion.py
@@ -68,7 +68,7 @@ def natural_keys(text):
     args.video_width = args.img_width
 
 if args.video_height == -1:
-    args.video_width = args.img_width
+    args.video_height = args.img_height
 
 with open(args.dataroot, "r") as f:
     paths_list = f.read().split("\n")
@@ -103,6 +103,8 @@ def natural_keys(text):
 
 args.previous_frame = None
 
+lmodel = None
+lopt = None
 for i, (image, label) in tqdm(enumerate(zip(images, labels)), total=len(images)):
 
     args.img_in = args.data_prefix + image
@@ -125,8 +127,20 @@ def natural_keys(text):
             args.write = False"""
 
     args.bbox_ref_id = -1
+    args.cond_in = None
+    args.cond_keep_ratio = True
+    args.alg_palette_guidance_scale = 0.0
+    args.alg_palette_cond_image_creation = None
+    args.alg_palette_sketch_canny_thresholds = None
+    args.alg_palette_super_resolution_downsample = False
+    args.data_refined_mask = False
+    args.min_crop_bbox_ratio = 0.0
+    args.model_prior_321_backwardcompatibility = False
 
-    frame = generate(**vars(args))
+    args.lmodel = lmodel
+    args.lopt = lopt
+
+    frame, lmodel, lopt = generate(**vars(args))
 
     if args.cond == "previous":  # use_real_previous:
         args.previous_frame = args.data_prefix + image
@@ -140,8 +154,6 @@ def natural_keys(text):
     elif args.cond == "zero":
         args.previous_frame = None
 
-    colored_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)
-
     out.write(frame)
 
     # When everything done, release the video write objects
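gen_video_diffusion.py now applies the same caching per frame: lmodel and lopt start as None, the first frame loads the checkpoint, and every later frame reuses it, so the model is read from disk once per video rather than once per frame. A sketch of that loop under the same assumptions as above (generate() is the stub from the previous sketch; the checkpoint path and frame count are placeholders):

import argparse

# Per-frame loop mirroring the diff above.
args = argparse.Namespace(
    model_in_file="checkpoints/latest_net_G.pth", lmodel=None, lopt=None
)

for _ in range(3):  # one iteration per video frame
    frame, lmodel, lopt = generate(**vars(args))
    args.lmodel = lmodel  # cached after the first frame, so no reload later
    args.lopt = lopt
    # out.write(frame) would append the frame to the cv2.VideoWriter here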