diff --git a/Readme.md b/Readme.md index 84f1e34..1c655fa 100644 --- a/Readme.md +++ b/Readme.md @@ -117,6 +117,7 @@ If we enable Tiny decoder(TAESD) we can save some memory(2GB approx) for example - 1 step fast inference support for SDXL and SD1.5 - Experimental support for single file Safetensors SD 1.5 models(Civitai models), simply add local model path to configs/stable-diffusion-models.txt file. - Add REST API support +- Add Aura SR (4x)/GigaGAN based upscaler support diff --git a/requirements.txt b/requirements.txt index 3f7ec90..ebb6831 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ onnxruntime==1.17.3 pydantic==2.4.2 typing-extensions==4.8.0 pyyaml==6.0.1 -gradio==4.21.0 +gradio==4.23.0 peft==0.6.1 opencv-python==4.8.1.78 omegaconf==2.3.0 diff --git a/src/backend/lcm_text_to_image.py b/src/backend/lcm_text_to_image.py index 580edd5..94d035c 100644 --- a/src/backend/lcm_text_to_image.py +++ b/src/backend/lcm_text_to_image.py @@ -83,6 +83,9 @@ def _add_freeu(self): b2=1.2, ) + def _enable_vae_tiling(self): + self.pipeline.vae.enable_tiling() + def _update_lcm_scheduler_params(self): if isinstance(self.pipeline.scheduler, LCMScheduler): self.pipeline.scheduler = LCMScheduler.from_config( diff --git a/src/backend/models/upscale.py b/src/backend/models/upscale.py index 5b9072f..e065fed 100644 --- a/src/backend/models/upscale.py +++ b/src/backend/models/upscale.py @@ -6,3 +6,4 @@ class UpscaleMode(str, Enum): normal = "normal" sd_upscale = "sd_upscale" + aura_sr = "aura_sr" diff --git a/src/backend/upscale/aura_sr.py b/src/backend/upscale/aura_sr.py new file mode 100644 index 0000000..be6efa3 --- /dev/null +++ b/src/backend/upscale/aura_sr.py @@ -0,0 +1,834 @@ +# AuraSR: GAN-based Super-Resolution for real-world, a reproduction of the GigaGAN* paper. Implementation is +# based on the unofficial lucidrains/gigagan-pytorch repository. Heavily modified from there. +# +# https://mingukkang.github.io/GigaGAN/ +from math import log2, ceil +from functools import partial +from typing import Any, Optional, List, Iterable + +import torch +from torchvision import transforms +from PIL import Image +from torch import nn, einsum, Tensor +import torch.nn.functional as F + +from einops import rearrange, repeat, reduce +from einops.layers.torch import Rearrange + + +def get_same_padding(size, kernel, dilation, stride): + return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 + + +class AdaptiveConv2DMod(nn.Module): + def __init__( + self, + dim, + dim_out, + kernel, + *, + demod=True, + stride=1, + dilation=1, + eps=1e-8, + num_conv_kernels=1, # set this to be greater than 1 for adaptive + ): + super().__init__() + self.eps = eps + + self.dim_out = dim_out + + self.kernel = kernel + self.stride = stride + self.dilation = dilation + self.adaptive = num_conv_kernels > 1 + + self.weights = nn.Parameter( + torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel)) + ) + + self.demod = demod + + nn.init.kaiming_normal_( + self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu" + ) + + def forward( + self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None + ): + """ + notation + + b - batch + n - convs + o - output + i - input + k - kernel + """ + + b, h = fmap.shape[0], fmap.shape[-2] + + # account for feature map that has been expanded by the scale in the first dimension + # due to multiscale inputs and outputs + + if mod.shape[0] != b: + mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0]) + + if exists(kernel_mod): + kernel_mod_has_el = kernel_mod.numel() > 0 + + assert self.adaptive or not kernel_mod_has_el + + if kernel_mod_has_el and kernel_mod.shape[0] != b: + kernel_mod = repeat( + kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0] + ) + + # prepare weights for modulation + + weights = self.weights + + if self.adaptive: + weights = repeat(weights, "... -> b ...", b=b) + + # determine an adaptive weight and 'select' the kernel to use with softmax + + assert exists(kernel_mod) and kernel_mod.numel() > 0 + + kernel_attn = kernel_mod.softmax(dim=-1) + kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1") + + weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum") + + # do the modulation, demodulation, as done in stylegan2 + + mod = rearrange(mod, "b i -> b 1 i 1 1") + + weights = weights * (mod + 1) + + if self.demod: + inv_norm = ( + reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum") + .clamp(min=self.eps) + .rsqrt() + ) + weights = weights * inv_norm + + fmap = rearrange(fmap, "b c h w -> 1 (b c) h w") + + weights = rearrange(weights, "b o ... -> (b o) ...") + + padding = get_same_padding(h, self.kernel, self.dilation, self.stride) + fmap = F.conv2d(fmap, weights, padding=padding, groups=b) + + return rearrange(fmap, "1 (b o) ... -> b o ...", b=b) + + +class Attend(nn.Module): + def __init__(self, dropout=0.0, flash=False): + super().__init__() + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + self.scale = nn.Parameter(torch.randn(1)) + self.flash = flash + + def flash_attn(self, q, k, v): + q, k, v = map(lambda t: t.contiguous(), (q, k, v)) + out = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout if self.training else 0.0 + ) + return out + + def forward(self, q, k, v): + if self.flash: + return self.flash_attn(q, k, v) + + scale = q.shape[-1] ** -0.5 + + # similarity + sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale + + # attention + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + return out + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cast_tuple(t, length=1): + if isinstance(t, tuple): + return t + return (t,) * length + + +def identity(t, *args, **kwargs): + return t + + +def is_power_of_two(n): + return log2(n).is_integer() + + +def null_iterator(): + while True: + yield None + +def Downsample(dim, dim_out=None): + return nn.Sequential( + Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), + nn.Conv2d(dim * 4, default(dim_out, dim), 1), + ) + + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + self.eps = 1e-4 + + def forward(self, x): + return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5) + + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0): + super().__init__() + self.proj = AdaptiveConv2DMod( + dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels + ) + self.kernel = 3 + self.dilation = 1 + self.stride = 1 + + self.act = nn.SiLU() + + def forward(self, x, conv_mods_iter: Optional[Iterable] = None): + conv_mods_iter = default(conv_mods_iter, null_iterator()) + + x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter)) + + x = self.act(x) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = [] + ): + super().__init__() + style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels]) + + self.block1 = Block( + dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels + ) + self.block2 = Block( + dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels + ) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, conv_mods_iter: Optional[Iterable] = None): + h = self.block1(x, conv_mods_iter=conv_mods_iter) + h = self.block2(h, conv_mods_iter=conv_mods_iter) + + return h + self.res_conv(x) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + + self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim)) + + def forward(self, x): + b, c, h, w = x.shape + + x = self.norm(x) + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv + ) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) + + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) + return self.to_out(out) + + +class Attention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32, flash=False): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + + self.norm = RMSNorm(dim) + + self.attend = Attend(flash=flash) + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + x = self.norm(x) + qkv = self.to_qkv(x).chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv + ) + + out = self.attend(q, k, v) + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) + + return self.to_out(out) + + +# feedforward +def FeedForward(dim, mult=4): + return nn.Sequential( + RMSNorm(dim), + nn.Conv2d(dim, dim * mult, 1), + nn.GELU(), + nn.Conv2d(dim * mult, dim, 1), + ) + + +# transformers +class Transformer(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4): + super().__init__() + self.layers = nn.ModuleList([]) + + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention( + dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn + ), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return x + + +class LinearTransformer(nn.Module): + def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4): + super().__init__() + self.layers = nn.ModuleList([]) + + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + LinearAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return x + + +class NearestNeighborhoodUpsample(nn.Module): + def __init__(self, dim, dim_out=None): + super().__init__() + dim_out = default(dim_out, dim) + self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + + if x.shape[0] >= 64: + x = x.contiguous() + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + +class EqualLinear(nn.Module): + def __init__(self, dim, dim_out, lr_mul=1, bias=True): + super().__init__() + self.weight = nn.Parameter(torch.randn(dim_out, dim)) + if bias: + self.bias = nn.Parameter(torch.zeros(dim_out)) + + self.lr_mul = lr_mul + + def forward(self, input): + return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul) + + +class StyleGanNetwork(nn.Module): + def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_text_latent = dim_text_latent + + layers = [] + for i in range(depth): + is_first = i == 0 + + if is_first: + dim_in_layer = dim_in + dim_text_latent + else: + dim_in_layer = dim_out + + dim_out_layer = dim_out + + layers.extend( + [EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)] + ) + + self.net = nn.Sequential(*layers) + + def forward(self, x, text_latent=None): + x = F.normalize(x, dim=1) + if self.dim_text_latent > 0: + assert exists(text_latent) + x = torch.cat((x, text_latent), dim=-1) + return self.net(x) + + +class UnetUpsampler(torch.nn.Module): + + def __init__( + self, + dim: int, + *, + image_size: int, + input_image_size: int, + init_dim: Optional[int] = None, + out_dim: Optional[int] = None, + style_network: Optional[dict] = None, + up_dim_mults: tuple = (1, 2, 4, 8, 16), + down_dim_mults: tuple = (4, 8, 16), + channels: int = 3, + resnet_block_groups: int = 8, + full_attn: tuple = (False, False, False, True, True), + flash_attn: bool = True, + self_attn_dim_head: int = 64, + self_attn_heads: int = 8, + attn_depths: tuple = (2, 2, 2, 2, 4), + mid_attn_depth: int = 4, + num_conv_kernels: int = 4, + resize_mode: str = "bilinear", + unconditional: bool = True, + skip_connect_scale: Optional[float] = None, + ): + super().__init__() + self.style_network = style_network = StyleGanNetwork(**style_network) + self.unconditional = unconditional + assert not ( + unconditional + and exists(style_network) + and style_network.dim_text_latent > 0 + ) + + assert is_power_of_two(image_size) and is_power_of_two( + input_image_size + ), "both output image size and input image size must be power of 2" + assert ( + input_image_size < image_size + ), "input image size must be smaller than the output image size, thus upsampling" + + self.image_size = image_size + self.input_image_size = input_image_size + + style_embed_split_dims = [] + + self.channels = channels + input_channels = channels + + init_dim = default(init_dim, dim) + + up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)] + init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)] + down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)] + self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3) + + up_in_out = list(zip(up_dims[:-1], up_dims[1:])) + down_in_out = list(zip(down_dims[:-1], down_dims[1:])) + + block_klass = partial( + ResnetBlock, + groups=resnet_block_groups, + num_conv_kernels=num_conv_kernels, + style_dims=style_embed_split_dims, + ) + + FullAttention = partial(Transformer, flash_attn=flash_attn) + *_, mid_dim = up_dims + + self.skip_connect_scale = default(skip_connect_scale, 2**-0.5) + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + block_count = 6 + + for ind, ( + (dim_in, dim_out), + layer_full_attn, + layer_attn_depth, + ) in enumerate(zip(down_in_out, full_attn, attn_depths)): + attn_klass = FullAttention if layer_full_attn else LinearTransformer + + blocks = [] + for i in range(block_count): + blocks.append(block_klass(dim_in, dim_in)) + + self.downs.append( + nn.ModuleList( + [ + nn.ModuleList(blocks), + nn.ModuleList( + [ + ( + attn_klass( + dim_in, + dim_head=self_attn_dim_head, + heads=self_attn_heads, + depth=layer_attn_depth, + ) + if layer_full_attn + else None + ), + nn.Conv2d( + dim_in, dim_out, kernel_size=3, stride=2, padding=1 + ), + ] + ), + ] + ) + ) + + self.mid_block1 = block_klass(mid_dim, mid_dim) + self.mid_attn = FullAttention( + mid_dim, + dim_head=self_attn_dim_head, + heads=self_attn_heads, + depth=mid_attn_depth, + ) + self.mid_block2 = block_klass(mid_dim, mid_dim) + + *_, last_dim = up_dims + + for ind, ( + (dim_in, dim_out), + layer_full_attn, + layer_attn_depth, + ) in enumerate( + zip( + reversed(up_in_out), + reversed(full_attn), + reversed(attn_depths), + ) + ): + attn_klass = FullAttention if layer_full_attn else LinearTransformer + + blocks = [] + input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in + for i in range(block_count): + blocks.append(block_klass(input_dim, dim_in)) + + self.ups.append( + nn.ModuleList( + [ + nn.ModuleList(blocks), + nn.ModuleList( + [ + NearestNeighborhoodUpsample( + last_dim if ind == 0 else dim_out, + dim_in, + ), + ( + attn_klass( + dim_in, + dim_head=self_attn_dim_head, + heads=self_attn_heads, + depth=layer_attn_depth, + ) + if layer_full_attn + else None + ), + ] + ), + ] + ) + ) + + self.out_dim = default(out_dim, channels) + self.final_res_block = block_klass(dim, dim) + self.final_to_rgb = nn.Conv2d(dim, channels, 1) + self.resize_mode = resize_mode + self.style_to_conv_modulations = nn.Linear( + style_network.dim_out, sum(style_embed_split_dims) + ) + self.style_embed_split_dims = style_embed_split_dims + + @property + def allowable_rgb_resolutions(self): + input_res_base = int(log2(self.input_image_size)) + output_res_base = int(log2(self.image_size)) + allowed_rgb_res_base = list(range(input_res_base, output_res_base)) + return [*map(lambda p: 2**p, allowed_rgb_res_base)] + + @property + def device(self): + return next(self.parameters()).device + + @property + def total_params(self): + return sum([p.numel() for p in self.parameters()]) + + def resize_image_to(self, x, size): + return F.interpolate(x, (size, size), mode=self.resize_mode) + + def forward( + self, + lowres_image: torch.Tensor, + styles: Optional[torch.Tensor] = None, + noise: Optional[torch.Tensor] = None, + global_text_tokens: Optional[torch.Tensor] = None, + return_all_rgbs: bool = False, + ): + x = lowres_image + + noise_scale = 0.001 # Adjust the scale of the noise as needed + noise_aug = torch.randn_like(x) * noise_scale + x = x + noise_aug + x = x.clamp(0, 1) + + shape = x.shape + batch_size = shape[0] + + assert shape[-2:] == ((self.input_image_size,) * 2) + + # styles + if not exists(styles): + assert exists(self.style_network) + + noise = default( + noise, + torch.randn( + (batch_size, self.style_network.dim_in), device=self.device + ), + ) + styles = self.style_network(noise, global_text_tokens) + + # project styles to conv modulations + conv_mods = self.style_to_conv_modulations(styles) + conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1) + conv_mods = iter(conv_mods) + + x = self.init_conv(x) + + h = [] + for blocks, (attn, downsample) in self.downs: + for block in blocks: + x = block(x, conv_mods_iter=conv_mods) + h.append(x) + + if attn is not None: + x = attn(x) + + x = downsample(x) + + x = self.mid_block1(x, conv_mods_iter=conv_mods) + x = self.mid_attn(x) + x = self.mid_block2(x, conv_mods_iter=conv_mods) + + for ( + blocks, + ( + upsample, + attn, + ), + ) in self.ups: + x = upsample(x) + for block in blocks: + if h != []: + res = h.pop() + res = res * self.skip_connect_scale + x = torch.cat((x, res), dim=1) + + x = block(x, conv_mods_iter=conv_mods) + + if attn is not None: + x = attn(x) + + x = self.final_res_block(x, conv_mods_iter=conv_mods) + rgb = self.final_to_rgb(x) + + if not return_all_rgbs: + return rgb + + return rgb, [] + + +def tile_image(image, chunk_size=64): + c, h, w = image.shape + h_chunks = ceil(h / chunk_size) + w_chunks = ceil(w / chunk_size) + tiles = [] + for i in range(h_chunks): + for j in range(w_chunks): + tile = image[:, i * chunk_size:(i + 1) * chunk_size, j * chunk_size:(j + 1) * chunk_size] + tiles.append(tile) + return tiles, h_chunks, w_chunks + + +def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64): + # Determine the shape of the output tensor + c = tiles[0].shape[0] + h = h_chunks * chunk_size + w = w_chunks * chunk_size + + # Create an empty tensor to hold the merged image + merged = torch.zeros((c, h, w), dtype=tiles[0].dtype) + + # Iterate over the tiles and place them in the correct position + for idx, tile in enumerate(tiles): + i = idx // w_chunks + j = idx % w_chunks + + h_start = i * chunk_size + w_start = j * chunk_size + + tile_h, tile_w = tile.shape[1:] + merged[:, h_start:h_start+tile_h, w_start:w_start+tile_w] = tile + + return merged + + +class AuraSR: + def __init__(self, config: dict[str, Any], device: str = "cuda"): + self.upsampler = UnetUpsampler(**config).to(device) + self.input_image_size = config["input_image_size"] + + @classmethod + def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_safetensors: bool = True): + import json + import torch + from pathlib import Path + from huggingface_hub import snapshot_download + + # Check if model_id is a local file + if Path(model_id).is_file(): + local_file = Path(model_id) + if local_file.suffix == '.safetensors': + use_safetensors = True + elif local_file.suffix == '.ckpt': + use_safetensors = False + else: + raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.") + + # For local files, we need to provide the config separately + config_path = local_file.with_name('config.json') + if not config_path.exists(): + raise FileNotFoundError( + f"Config file not found: {config_path}. " + f"When loading from a local file, ensure that 'config.json' " + f"is present in the same directory as '{local_file.name}'. " + f"If you're trying to load a model from Hugging Face, " + f"please provide the model ID instead of a file path." + ) + + config = json.loads(config_path.read_text()) + hf_model_path = local_file.parent + else: + hf_model_path = Path(snapshot_download(model_id)) + config = json.loads((hf_model_path / "config.json").read_text()) + + model = cls(config,device) + + if use_safetensors: + try: + from safetensors.torch import load_file + checkpoint = load_file(hf_model_path / "model.safetensors" if not Path(model_id).is_file() else model_id) + except ImportError: + raise ImportError( + "The safetensors library is not installed. " + "Please install it with `pip install safetensors` " + "or use `use_safetensors=False` to load the model with PyTorch." + ) + else: + checkpoint = torch.load(hf_model_path / "model.ckpt" if not Path(model_id).is_file() else model_id) + + model.upsampler.load_state_dict(checkpoint, strict=True) + return model + + @torch.no_grad() + def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image: + tensor_transform = transforms.ToTensor() + device = self.upsampler.device + + image_tensor = tensor_transform(image).unsqueeze(0) + _, _, h, w = image_tensor.shape + pad_h = (self.input_image_size - h % self.input_image_size) % self.input_image_size + pad_w = (self.input_image_size - w % self.input_image_size) % self.input_image_size + + # Pad the image + image_tensor = torch.nn.functional.pad(image_tensor, (0, pad_w, 0, pad_h), mode='reflect').squeeze(0) + tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size) + + # Batch processing of tiles + num_tiles = len(tiles) + batches = [tiles[i:i + max_batch_size] for i in range(0, num_tiles, max_batch_size)] + reconstructed_tiles = [] + + for batch in batches: + model_input = torch.stack(batch).to(device) + generator_output = self.upsampler( + lowres_image=model_input, + noise=torch.randn(model_input.shape[0], 128, device=device) + ) + reconstructed_tiles.extend(list(generator_output.clamp_(0, 1).detach().cpu())) + + merged_tensor = merge_tiles(reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4) + unpadded = merged_tensor[:, :h * 4, :w * 4] + + to_pil = transforms.ToPILImage() + return to_pil(unpadded) + diff --git a/src/backend/upscale/aura_sr_upscale.py b/src/backend/upscale/aura_sr_upscale.py new file mode 100644 index 0000000..932487c --- /dev/null +++ b/src/backend/upscale/aura_sr_upscale.py @@ -0,0 +1,9 @@ +from backend.upscale.aura_sr import AuraSR +from PIL import Image + + +def upscale_aura_sr(image_path: str): + + aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR", device="cpu") + image_in = Image.open(image_path) # .resize((256, 256)) + return aura_sr.upscale_4x(image_in) diff --git a/src/backend/upscale/upscaler.py b/src/backend/upscale/upscaler.py index a923dfd..563e588 100644 --- a/src/backend/upscale/upscaler.py +++ b/src/backend/upscale/upscaler.py @@ -1,6 +1,7 @@ from backend.models.lcmdiffusion_setting import DiffusionTask from backend.models.upscale import UpscaleMode from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x +from backend.upscale.aura_sr_upscale import upscale_aura_sr from backend.upscale.tiled_upscale import generate_upscaled_image from context import Context from PIL import Image @@ -22,6 +23,10 @@ def upscale_image( upscaled_img = upscale_edsr_2x(src_image_path) upscaled_img.save(dst_image_path) print(f"Upscaled image saved {dst_image_path}") + elif upscale_mode == UpscaleMode.aura_sr.value: + upscaled_img = upscale_aura_sr(src_image_path) + upscaled_img.save(dst_image_path) + print(f"Upscaled image saved {dst_image_path}") else: config.settings.lcm_diffusion_setting.strength = ( 0.3 if config.settings.lcm_diffusion_setting.use_openvino else 0.1 diff --git a/src/constants.py b/src/constants.py index 4b23e5f..6074748 100644 --- a/src/constants.py +++ b/src/constants.py @@ -1,6 +1,6 @@ from os import environ -APP_VERSION = "v1.0.0 beta 32" +APP_VERSION = "v1.0.0 beta 33" LCM_DEFAULT_MODEL = "stabilityai/sd-turbo" LCM_DEFAULT_MODEL_OPENVINO = "rupeshs/sd-turbo-openvino" APP_NAME = "FastSD CPU" diff --git a/src/frontend/webui/generation_settings_ui.py b/src/frontend/webui/generation_settings_ui.py index 455fbd3..be18e8d 100644 --- a/src/frontend/webui/generation_settings_ui.py +++ b/src/frontend/webui/generation_settings_ui.py @@ -97,7 +97,7 @@ def get_generation_settings_ui() -> None: ) guidance_scale = gr.Slider( 1.0, - 2.0, + 10.0, value=app_settings.settings.lcm_diffusion_setting.guidance_scale, step=0.1, label="Guidance Scale", diff --git a/src/frontend/webui/upscaler_ui.py b/src/frontend/webui/upscaler_ui.py index 41d7cde..ec58312 100644 --- a/src/frontend/webui/upscaler_ui.py +++ b/src/frontend/webui/upscaler_ui.py @@ -24,6 +24,9 @@ def create_upscaled_image( scale_factor = 2 if upscale_mode == "SD": mode = UpscaleMode.sd_upscale.value + elif upscale_mode == "AURA-SR": + mode = UpscaleMode.aura_sr.value + scale_factor = 4 else: mode = UpscaleMode.normal.value @@ -48,8 +51,8 @@ def get_upscaler_ui() -> None: input_image = gr.Image(label="Image", type="filepath") with gr.Row(): upscale_mode = gr.Radio( - ["EDSR", "SD"], - label="Upscale Mode (2x)", + ["EDSR", "SD", "AURA-SR"], + label="Upscale Mode (2x) | AURA-SR (4x)", info="Select upscale method, SD Upscale is experimental", value="EDSR", )