diff --git a/lib_layerdiffusion/models.py b/lib_layerdiffusion/models.py
index 3d0cf82..59866ec 100644
--- a/lib_layerdiffusion/models.py
+++ b/lib_layerdiffusion/models.py
@@ -3,6 +3,7 @@
 import cv2
 import numpy as np
+from PIL import Image
 from tqdm import tqdm
 from typing import Optional, Tuple
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -11,6 +12,8 @@
 import ldm_patched.modules.model_management as model_management
 from ldm_patched.modules.model_patcher import ModelPatcher
 
+from modules import images, processing
+from modules.shared import opts
 
 def zero_module(module):
     """
@@ -234,7 +237,7 @@ def estimate_augmented(self, pixel, latent):
         median = torch.median(result, dim=0).values
         return median
 
-    def patch(self, p, vae_patcher, output_origin):
+    def patch(self, p, vae_patcher, output_origin, transparentImages):
         @torch.no_grad()
         def wrapper(func, latent):
             pixel = func(latent).movedim(-1, 1).to(device=self.load_device, dtype=self.dtype)
@@ -260,7 +263,7 @@ def wrapper(func, latent):
                 fg = y[..., 1:]
                 B, H, W, C = fg.shape
 
-                cb = checkerboard(shape=(H // 64, W // 64))
+                cb = checkerboard(shape=(H, W))  # 1px cells give a near-uniform preview background
                 cb = cv2.resize(cb, (W, H), interpolation=cv2.INTER_NEAREST)
                 cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None]
                 cb = torch.from_numpy(cb).to(fg)
@@ -270,6 +273,25 @@ def wrapper(func, latent):
                 png = torch.cat([fg, alpha], dim=3)[0]
                 png = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)
+
+                # Save the transparent RGBA result alongside the regular output.
+                xpng = Image.fromarray(png)
+
+                infotext = processing.Processed(p, []).infotext(p, i)
+
+                transparentImages.append(xpng)
+                images.save_image(
+                    image=xpng,
+                    path=p.outpath_samples,
+                    basename="",
+                    seed=p.seeds[i],
+                    prompt=p.prompts[i],
+                    extension=getattr(opts, 'samples_format', 'png'),
+                    info=infotext,
+                    p=p,
+                    suffix="-transparent"
+                )
+
                 p.extra_result_images.append(png)
 
             vis_list = torch.cat(vis_list, dim=0)
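How the decoder hunk assembles its outputs: `y` leaves `estimate_augmented` as (1, H, W, 4) with alpha in channel 0 and RGB in channels 1..3, so `torch.cat([fg, alpha], dim=3)` reorders that into RGBA before the uint8 conversion, while the faint checkerboard is only blended into the on-screen preview. A minimal standalone sketch of both steps, assuming `checkerboard(shape)` is the upstream 0/1 parity helper and using random stand-in data for `y`:

    import numpy as np
    import torch

    def checkerboard(shape):
        # 0/1 parity grid, as in lib_layerdiffusion/models.py upstream.
        return np.indices(shape).sum(axis=0) % 2

    y = torch.rand(1, 8, 8, 4)          # stand-in decoder output
    alpha, fg = y[..., :1], y[..., 1:]  # channel 0 = alpha, channels 1..3 = RGB

    # RGBA for saving: RGB first, alpha last, scaled to uint8.
    png = torch.cat([fg, alpha], dim=3)[0]
    png = (png * 255.0).cpu().numpy().clip(0, 255).astype(np.uint8)

    # Preview: blend the foreground over a low-contrast checkerboard.
    # With 1x1 cells (the change above) the board is a near-uniform gray.
    B, H, W, C = fg.shape
    cb = checkerboard((H, W)).astype(np.float32)
    cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None]
    vis = fg * alpha + torch.from_numpy(cb).to(fg) * (1 - alpha)

`Image.fromarray(png)` then turns the array into the PIL RGBA image that is appended to `transparentImages` and saved with the "-transparent" suffix.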
diff --git a/scripts/forge_layerdiffusion.py b/scripts/forge_layerdiffusion.py
index d08e71b..25df429 100644
--- a/scripts/forge_layerdiffusion.py
+++ b/scripts/forge_layerdiffusion.py
@@ -5,6 +5,9 @@
 import numpy as np
 import copy
 
+from PIL import Image
+from modules import images
+from modules import script_callbacks, shared
 from modules import scripts
 from modules.processing import StableDiffusionProcessing
 from lib_layerdiffusion.enums import ResizeMode
@@ -19,6 +22,7 @@
 from lib_layerdiffusion.attention_sharing import AttentionSharingPatcher
 from ldm_patched.modules import model_management
 
+from modules.shared import opts
 
 def is_model_loaded(model):
     return any(model == m.model for m in current_loaded_models)
@@ -50,6 +54,9 @@ def load_layer_model_state_dict(filename):
 
 class LayerDiffusionForForge(scripts.Script):
+
+    transparentImages = []
+
     def title(self):
         return "LayerDiffuse"
 
@@ -59,6 +66,7 @@ def show(self, is_img2img):
     def ui(self, *args, **kwargs):
         with gr.Accordion(open=False, label=self.title()):
             enabled = gr.Checkbox(label='Enabled', value=False)
+            enabledSaveRebuild = gr.Checkbox(label='Save rebuilt image', value=True)
             method = gr.Dropdown(choices=[e.value for e in LayerMethod], value=LayerMethod.FG_ONLY_ATTN.value, label="Method", type='value')
             gr.HTML('</br>')  # some strange gradio problems
@@ -106,7 +114,7 @@ def method_changed(m):
 
         method.change(method_changed, inputs=method, outputs=[fg_image, bg_image, blend_image, resize_mode, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt], show_progress=False, queue=False)
 
-        return enabled, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt
+        return enabled, enabledSaveRebuild, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt
 
     def process_before_every_sampling(self, p: StableDiffusionProcessing, *script_args, **kwargs):
         global vae_transparent_decoder, vae_transparent_encoder
@@ -114,7 +122,7 @@ def process_before_every_sampling(self, p: StableDiffusionProcessing, *script_ar
 
         # This will be called before every sampling.
         # If you use highres fix, this will be called twice.
-        enabled, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt = script_args
+        enabled, enabledSaveRebuild, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt = script_args
 
         if not enabled:
             return
@@ -159,7 +167,7 @@
                     file_name='vae_transparent_decoder.safetensors'
                 )
                 vae_transparent_decoder = TransparentVAEDecoder(load_torch_file(model_path))
-            vae_transparent_decoder.patch(p, vae.patcher, output_origin)
+            vae_transparent_decoder.patch(p, vae.patcher, output_origin, self.transparentImages)
 
             if vae_transparent_encoder is None:
                 model_path = load_file_from_url(
@@ -182,7 +190,7 @@
             vae_transparent_decoder.mod_number = 3
         if method == LayerMethod.BG_TO_FG_SD15:
             vae_transparent_decoder.mod_number = 2
-            vae_transparent_decoder.patch(p, vae.patcher, output_origin)
+            vae_transparent_decoder.patch(p, vae.patcher, output_origin, self.transparentImages)
 
         if vae_transparent_encoder is None:
             model_path = load_file_from_url(
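How the patch hunks above feed the postprocess hunk below: `patch(...)` receives `self.transparentImages` by reference, so every RGBA frame the decoder wrapper appends during sampling is visible to `postprocess()` after the run. A minimal sketch of that hand-off, including the one pitfall of keeping the list as a class attribute (names here are illustrative, not from the patch):

    class Decoder:
        def __init__(self, sink):
            self.sink = sink  # same list object the script passed in

        def decode(self, frame):
            self.sink.append(frame)  # producer side, runs during sampling

    class Script:
        frames = []  # class attribute: shared by every instance of Script

        def process(self):
            Decoder(self.frames).decode("rgba-frame")

        def postprocess(self):
            done, self.frames = self.frames, []  # rebinds an *instance* attribute;
            return done                          # the class attribute keeps the old list

Because of that rebinding, a second script instance (Forge keeps separate instances for the txt2img and img2img tabs) can still see the stale class-level list; initializing the list per instance in `__init__` would avoid frames leaking across tabs.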
@@ -357,3 +365,57 @@ def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_op
         p.sd_model.forge_objects.unet = unet
         p.sd_model.forge_objects.vae = vae
         return
+
+
+    def postprocess(self, p, processed, *script_args):
+        """
+        Called after processing ends (AlwaysVisible scripts); script_args
+        contains all values returned by components from ui().
+
+        Composites each collected transparent image over the matching image
+        in processed.images and saves the result with a "-rebuild" suffix.
+        """
+        enabled, enabledSaveRebuild, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt = script_args
+
+        if not enabled or not enabledSaveRebuild:
+            return
+
+        if processed.images is not None and len(processed.images) > 0:
+            images_copy = processed.images[:]
+
+            if len(self.transparentImages) != len(processed.images) and len(images_copy) != 1:
+                # Batch run: processed.images[0] is the grid; drop it so the lists align.
+                images_copy.pop(0)
+
+            for image_a, image_b in zip(self.transparentImages, images_copy):
+                image_b = image_b.convert("RGBA")
+
+                # Binary mask: 255 where the transparent render is (mostly) opaque.
+                alpha_mask = image_a.getchannel('A').point(lambda x: 255 if x > 30 else 0)
+
+                # Split both images into channels.
+                r_a, g_a, b_a, a_a = image_a.split()
+                r_b, g_b, b_b, _ = image_b.split()
+
+                # Image.composite(im1, im2, mask) takes im1 where the mask is 255:
+                # processed RGB inside the mask, the transparent render's RGB elsewhere.
+                r_final = Image.composite(r_b, r_a, alpha_mask)
+                g_final = Image.composite(g_b, g_a, alpha_mask)
+                b_final = Image.composite(b_b, b_a, alpha_mask)
+
+                # Recombine with the untouched alpha channel of the transparent render.
+                final_image = Image.merge("RGBA", (r_final, g_final, b_final, a_a))
+
+                images.save_image(
+                    image=final_image,
+                    path=p.outpath_samples,
+                    basename="",
+                    extension=getattr(opts, 'samples_format', 'png'),
+                    p=p,
+                    suffix="-rebuild"
+                )
+
+            self.transparentImages = []
+            processed.images = []
+        else:
+            print("LayerDiffuse rebuild: processed.images is empty, nothing to rebuild.")
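A note on the rebuild compositing: `Image.composite(im1, im2, mask)` selects `im1` wherever the mask is 255 and `im2` elsewhere, so the three per-channel calls take the processed RGB inside the thresholded alpha mask and the transparent render's RGB outside it, then reattach the soft alpha. A compact equivalent of the loop body, as a sketch rather than a drop-in replacement (`threshold=30` mirrors the mask above):

    from PIL import Image

    def rebuild(image_a: Image.Image, image_b: Image.Image, threshold: int = 30) -> Image.Image:
        # image_a: transparent RGBA render; image_b: flattened processed image.
        image_b = image_b.convert("RGBA")
        mask = image_a.getchannel("A").point(lambda x: 255 if x > threshold else 0)
        out = Image.composite(image_b, image_a, mask)  # image_b where mask is opaque
        out.putalpha(image_a.getchannel("A"))          # restore the original soft alpha
        return out

Compositing the whole images with one mask and then calling `putalpha` yields the same pixels as the per-channel merge, since all three color channels use the same mask and the final alpha is taken wholesale from the transparent render.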