Added merging of transparent images and post-processing (ADetailer, Hires. fix, etc.). #99

Open · wants to merge 8 commits into base: main
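
In brief (a reviewer's summary, not text from the PR itself): the patched TransparentVAEDecoder now saves every transparent LayerDiffuse output as its own "-transparent" PNG and collects it, and a new postprocess hook merges the post-processed result (ADetailer, Hires. fix, etc.) back onto the transparent image while keeping its alpha channel. A minimal Pillow sketch of that merge step (the 30-point alpha threshold mirrors the diff; the function name and structure are illustrative):

from PIL import Image

def merge_transparent(transparent: Image.Image, processed: Image.Image) -> Image.Image:
    """Overlay post-processed RGB onto a transparent image, preserving its alpha."""
    image_a = transparent.convert("RGBA")  # transparent LayerDiffuse output
    image_b = processed.convert("RGBA")    # opaque post-processed result

    # Hard mask: 255 where image_a is effectively opaque, 0 elsewhere.
    alpha_mask = image_a.getchannel("A").point(lambda x: 255 if x > 30 else 0)

    r_a, g_a, b_a, a_a = image_a.split()
    r_b, g_b, b_b, _ = image_b.split()

    # Image.composite(img1, img2, mask) picks img1 where mask == 255, so the
    # post-processed RGB replaces image_a's RGB inside the opaque region only.
    r = Image.composite(r_b, r_a, alpha_mask)
    g = Image.composite(g_b, g_a, alpha_mask)
    b = Image.composite(b_b, b_a, alpha_mask)

    # Reattach image_a's alpha so the merged result stays transparent.
    return Image.merge("RGBA", (r, g, b, a_a))

In the diff below, the same compositing runs per image in postprocess(), and each merged result is saved with a "-rebuild" suffix.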
Binary file added 2024-04-16 000652.png
26 changes: 24 additions & 2 deletions lib_layerdiffusion/models.py
@@ -3,6 +3,7 @@
import cv2
import numpy as np

from PIL import Image
from tqdm import tqdm
from typing import Optional, Tuple
from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -11,6 +12,8 @@
import ldm_patched.modules.model_management as model_management
from ldm_patched.modules.model_patcher import ModelPatcher

from modules import images, processing
from modules.shared import opts

def zero_module(module):
"""
@@ -234,7 +237,7 @@ def estimate_augmented(self, pixel, latent):
median = torch.median(result, dim=0).values
return median

def patch(self, p, vae_patcher, output_origin):
def patch(self, p, vae_patcher, output_origin, transparentImages):
@torch.no_grad()
def wrapper(func, latent):
pixel = func(latent).movedim(-1, 1).to(device=self.load_device, dtype=self.dtype)
@@ -260,7 +263,7 @@ def wrapper(func, latent):
fg = y[..., 1:]

B, H, W, C = fg.shape
cb = checkerboard(shape=(H // 64, W // 64))
cb = checkerboard(shape=(H // 1, W // 1))
cb = cv2.resize(cb, (W, H), interpolation=cv2.INTER_NEAREST)
cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None]
cb = torch.from_numpy(cb).to(fg)
@@ -270,6 +273,25 @@

png = torch.cat([fg, alpha], dim=3)[0]
png = (png * 255.0).detach().cpu().float().numpy().clip(0, 255).astype(np.uint8)

# Save the transparent image as its own file and keep it for the postprocess merge.
xpng = Image.fromarray(png)

infotext = processing.Processed(p, []).infotext(p, i)

transparentImages.append(xpng)
images.save_image(
image=xpng,
path=p.outpath_samples,
basename="",
seed=p.seeds[i],
prompt=p.prompts[i],
extension=getattr(opts, 'samples_format', 'png'),
info=infotext,
p=p,
suffix="-transparent"
)

p.extra_result_images.append(png)

vis_list = torch.cat(vis_list, dim=0)
95 changes: 91 additions & 4 deletions scripts/forge_layerdiffusion.py
@@ -5,6 +5,9 @@
import numpy as np
import copy

from PIL import Image
from modules import images
from modules import script_callbacks, shared
from modules import scripts
from modules.processing import StableDiffusionProcessing
from lib_layerdiffusion.enums import ResizeMode
@@ -19,6 +22,7 @@
from lib_layerdiffusion.attention_sharing import AttentionSharingPatcher
from ldm_patched.modules import model_management

from modules.shared import opts

def is_model_loaded(model):
return any(model == m.model for m in current_loaded_models)
@@ -50,6 +54,9 @@ def load_layer_model_state_dict(filename):


class LayerDiffusionForForge(scripts.Script):

# Transparent outputs collected by the patched VAE decoder and consumed in postprocess().
# Note: this is a class-level attribute, so it is shared across script instances.
transparentImages = []

def title(self):
return "LayerDiffuse"

@@ -59,6 +66,7 @@ def show(self, is_img2img):
def ui(self, *args, **kwargs):
with gr.Accordion(open=False, label=self.title()):
enabled = gr.Checkbox(label='Enabled', value=False)
enabledSaveRebuild = gr.Checkbox(label='Save rebuilt image', value=True)
method = gr.Dropdown(choices=[e.value for e in LayerMethod], value=LayerMethod.FG_ONLY_ATTN.value, label="Method", type='value')
gr.HTML('</br>') # some strange gradio problems

@@ -106,15 +114,15 @@ def method_changed(m):

method.change(method_changed, inputs=method, outputs=[fg_image, bg_image, blend_image, resize_mode, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt], show_progress=False, queue=False)

return enabled, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt
return enabled, enabledSaveRebuild, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt

def process_before_every_sampling(self, p: StableDiffusionProcessing, *script_args, **kwargs):
global vae_transparent_decoder, vae_transparent_encoder

# This will be called before every sampling.
# If you use highres fix, this will be called twice.

enabled, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt = script_args
enabled, enabledSaveRebuild, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt = script_args

if not enabled:
return
@@ -151,6 +159,8 @@ def process_before_every_sampling(self, p: StableDiffusionProcessing, *script_ar
vae = p.sd_model.forge_objects.vae.clone()
clip = p.sd_model.forge_objects.clip



if method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.FG_ONLY_CONV, LayerMethod.BG_BLEND_TO_FG]:
if vae_transparent_decoder is None:
model_path = load_file_from_url(
@@ -159,7 +169,7 @@ def process_before_every_sampling(self, p: StableDiffusionProcessing, *script_ar
file_name='vae_transparent_decoder.safetensors'
)
vae_transparent_decoder = TransparentVAEDecoder(load_torch_file(model_path))
vae_transparent_decoder.patch(p, vae.patcher, output_origin)
vae_transparent_decoder.patch(p, vae.patcher, output_origin, self.transparentImages)

if vae_transparent_encoder is None:
model_path = load_file_from_url(
@@ -182,7 +192,7 @@ def process_before_every_sampling(self, p: StableDiffusionProcessing, *script_ar
vae_transparent_decoder.mod_number = 3
if method == LayerMethod.BG_TO_FG_SD15:
vae_transparent_decoder.mod_number = 2
vae_transparent_decoder.patch(p, vae.patcher, output_origin)
vae_transparent_decoder.patch(p, vae.patcher, output_origin, self.transparentImages)

if vae_transparent_encoder is None:
model_path = load_file_from_url(
Expand Down Expand Up @@ -357,3 +367,80 @@ def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_op
p.sd_model.forge_objects.unet = unet
p.sd_model.forge_objects.vae = vae
return


def postprocess(self, p, processed, *script_args):
"""
This function is called after processing ends for AlwaysVisible scripts.
args contains all values returned by components from ui()

p <modules.processing.StableDiffusionProcessingTxt2Img>
processed <modules.processing.Processed>
"""

enabled, enabledSaveRebuild, method, weight, ending_step, fg_image, bg_image, blend_image, resize_mode, output_origin, fg_additional_prompt, bg_additional_prompt, blend_additional_prompt = script_args

if not enabled or not enabledSaveRebuild:
return

# print( f"processed:{processed}" )
print( f"processed images :{processed.images}" )
print( f"self.transparentImages :{self.transparentImages}" )


if processed.images is not None and len(processed.images) > 0:
images_copy = processed.images[:]

if len(self.transparentImages) != len(processed.images) and len(images_copy) != 1:
# Batch run: processed.images[0] is the image grid, not a sample, so drop it.
images_copy.pop(0)

# extra_images = processed.extra_images

# pil_images = []
# for img in extra_images:
# if img.shape[2] == 3: # RGB
# pil_image = Image.fromarray(img, 'RGB')
# pil_image = pil_image.convert("RGBA")
# elif img.shape[2] == 4: # RGBA
# pil_image = Image.fromarray(img, 'RGBA')
# pil_images.append(pil_image)


for image_a, image_b in zip(self.transparentImages, images_copy):
image_b = image_b.convert("RGBA")

# Create alpha mask with strict threshold
alpha_mask = image_a.getchannel('A').point(lambda x: 255 if x > 30 else 0)

# Extract RGB channels and Alpha from image_a
r_a, g_a, b_a, a_a = image_a.split()

# Extract RGB channels from image_b
r_b, g_b, b_b, a_b = image_b.split()

# Where the mask is 255 (image_a opaque), take image_b's post-processed RGB; elsewhere keep image_a's RGB
r_final = Image.composite(r_b, r_a, alpha_mask)
g_final = Image.composite(g_b, g_a, alpha_mask)
b_final = Image.composite(b_b, b_a, alpha_mask)

# Combine the merged RGB channels with the original alpha channel of image_a
final_image = Image.merge("RGBA", (r_final, g_final, b_final, a_a))

print("save rebuild.")
# Save the result
images.save_image(
image=final_image,
path=p.outpath_samples,
basename="",
extension=getattr(opts, 'samples_format', 'png'),
p=p,
suffix="-rebuild"
)

self.transparentImages = []
processed.images = []
else:
print("processed.images is null or zero index.")