merge with main
israfelsr committed Jan 9, 2024
2 parents 3f823cf + 60981b7 · commit 1d6f480
Showing 42 changed files with 1,975 additions and 318 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -7,7 +7,7 @@ on:

 jobs:
   lint_and_typecheck:
-    if: ${{ github.event.name == 'push' || github.event.label.name == 'run-ci' }}
+    if: ${{ github.event_name == 'push' || github.event.label.name == 'run-ci' }}
     runs-on: ubuntu-latest

     steps:
20 changes: 20 additions & 0 deletions .github/workflows/docs.yml
@@ -0,0 +1,20 @@
+name: Deploy docs to GitHub Pages
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  deploy:
+    name: Deploy docs
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout main
+        uses: actions/checkout@v2
+
+      - name: Deploy MkDocs
+        uses: mhausenblas/mkdocs-deploy-gh-pages@master
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          REQUIREMENTS: ./requirements.docs.txt
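Note: this workflow publishes the MkDocs site on every push to main; the mhausenblas/mkdocs-deploy-gh-pages action reads the REQUIREMENTS env var to install doc dependencies (here requirements.docs.txt, added later in this merge) before building.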
15 changes: 3 additions & 12 deletions README.md
@@ -92,7 +92,7 @@ from PIL import Image

 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
 from refiners.foundationals.latent_diffusion import SDXLIPAdapter, SDXLT2IAdapter
-from refiners.fluxion.utils import manual_seed, image_to_tensor, load_from_safetensors
+from refiners.fluxion.utils import manual_seed, no_grad, image_to_tensor, load_from_safetensors

 # Load inputs
 init_image = Image.open("dropy_logo.png")
@@ -122,22 +122,13 @@ t2i_adapter.set_scale(0.8)
 sdxl.set_num_inference_steps(50)
 sdxl.set_self_attention_guidance(enable=True, scale=0.75)

-with torch.no_grad():
+with no_grad():
     # Note: default text prompts for IP-Adapter
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
     )
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))
-
-    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
-
-    clip_text_embedding = torch.cat(
-        (
-            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-        )
-    )
+    ip_adapter.set_clip_image_embedding(clip_image_embedding)
     time_ids = sdxl.default_time_ids

     condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
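Note: the hunk above replaces the manual bookkeeping (chunking text and image embeddings into negative/conditional halves and re-concatenating them) with a single ip_adapter.set_clip_image_embedding(clip_image_embedding) call; the adapter now takes care of routing the image conditioning itself, presumably via its internal context, so the README example shrinks by nine lines.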
3 changes: 3 additions & 0 deletions docs/index.md
@@ -0,0 +1,3 @@
+# Refiners - Docs
+
+WIP
4 changes: 4 additions & 0 deletions mkdocs.yml
@@ -0,0 +1,4 @@
+site_name: Refiners
+
+theme:
+  name: material
1,574 changes: 1,574 additions & 0 deletions notebooks/basics.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
@@ -54,10 +54,11 @@ build-backend = "hatchling.build"

 [tool.rye]
 managed = true
 dev-dependencies = [
-    "pyright == 1.1.333",
+    "pyright == 1.1.342",
     "ruff>=0.0.292",
     "docformatter>=1.7.5",
     "pytest>=7.4.2",
+    "mkdocs-material>=9.5.3",
 ]


@@ -66,6 +67,7 @@ allow-direct-references = true

 [tool.rye.scripts]
 lint = { chain = ["ruff format .", "ruff --fix ."] }
+serve-docs = "mkdocs serve"

 [tool.black]
 line-length = 120
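Note: with serve-docs registered under [tool.rye.scripts], the docs can presumably be previewed locally via rye's script runner (rye run serve-docs), which wraps mkdocs serve using the mkdocs-material dev-dependency added above.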
1 change: 1 addition & 0 deletions requirements.docs.txt
@@ -0,0 +1 @@
+mkdocs-material==9.5.3
4 changes: 2 additions & 2 deletions scripts/conversion/convert_diffusers_controlnet.py
@@ -7,7 +7,7 @@
 from torch import nn

 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import no_grad, save_to_safetensors
 from refiners.foundationals.latent_diffusion import (
     DPMSolver,
     SD1ControlnetAdapter,
@@ -20,7 +20,7 @@ class Args(argparse.Namespace):
     output_path: str | None


-@torch.no_grad()
+@no_grad()
 def convert(args: Args) -> dict[str, torch.Tensor]:
     # low_cpu_mem_usage=False stops some annoying console messages asking us to `pip install accelerate`
     controlnet_src: nn.Module = ControlNetModel.from_pretrained(  # type: ignore
18 changes: 4 additions & 14 deletions scripts/conversion/convert_diffusers_ip_adapter.py
@@ -133,24 +133,14 @@ def main() -> None:
     ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
     assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2

-    for i, cross_attn in enumerate(ip_adapter.sub_adapters):
+    for i, _ in enumerate(ip_adapter.sub_adapters):
         cross_attn_index = cross_attn_mapping[i]
         k_ip = f"{cross_attn_index}.to_k_ip.weight"
         v_ip = f"{cross_attn_index}.to_v_ip.weight"

-        # Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights
-
-        names = [k for k, _ in cross_attn.named_parameters()]
-        assert len(names) == 2
-
-        cross_attn_state_dict: dict[str, Any] = {
-            names[0]: ip_adapter_weights[k_ip],
-            names[1]: ip_adapter_weights[v_ip],
-        }
-        cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)
-
-        for k, v in cross_attn_state_dict.items():
-            state_dict[f"ip_adapter.{i:03d}.{k}"] = v
+        # the name of the key is not checked at runtime, so we keep the original name
+        state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
+        state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]

     if args.half:
         state_dict = {key: value.half() for key, value in state_dict.items()}
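Note: the rewrite above drops the per-module load_state_dict round-trip and simply re-keys the diffusers checkpoint, since key names are not checked at runtime. A minimal self-contained sketch of that re-keying (the tensors, shapes, and index mapping below are hypothetical stand-ins for the real checkpoint):

import torch

ip_adapter_weights = {  # hypothetical stand-in for weights["ip_adapter"]
    "1.to_k_ip.weight": torch.randn(320, 768),
    "1.to_v_ip.weight": torch.randn(320, 768),
}
cross_attn_mapping = [1]  # hypothetical: refiners sub-adapter index -> diffusers block index

state_dict: dict[str, torch.Tensor] = {}
for i, cross_attn_index in enumerate(cross_attn_mapping):
    # Re-key "<block>.to_{k,v}_ip.weight" as "ip_adapter.<i:03d>.to_{k,v}_ip.weight"
    state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[f"{cross_attn_index}.to_k_ip.weight"]
    state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[f"{cross_attn_index}.to_v_ip.weight"]

assert sorted(state_dict) == ["ip_adapter.000.to_k_ip.weight", "ip_adapter.000.to_v_ip.weight"]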
4 changes: 2 additions & 2 deletions scripts/conversion/convert_diffusers_lora.py
@@ -11,7 +11,7 @@
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.lora import Lora, LoraAdapter
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import no_grad, save_to_safetensors
 from refiners.foundationals.latent_diffusion import SD1UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
@@ -37,7 +37,7 @@ class Args(argparse.Namespace):
     verbose: bool


-@torch.no_grad()
+@no_grad()
 def process(args: Args) -> None:
     diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu"))  # type: ignore
     # low_cpu_mem_usage=False stops some annoying console messages asking us to `pip install accelerate`
29 changes: 26 additions & 3 deletions scripts/conversion/convert_segment_anything.py
@@ -37,13 +37,36 @@ class Args(argparse.Namespace):


 def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
+    manual_seed(seed=0)
+    refiners_mask_encoder = MaskEncoder()
+
+    converter = ModelConverter(
+        source_model=prompt_encoder.mask_downscaling,
+        target_model=refiners_mask_encoder,
+        custom_layer_mapping=custom_layers,  # type: ignore
+    )
+
+    x = torch.randn(1, 256, 256)
+    mapping = converter.map_state_dicts(source_args=(x,))
+    assert mapping
+
+    source_state_dict = prompt_encoder.mask_downscaling.state_dict()
+    target_state_dict = refiners_mask_encoder.state_dict()
+
+    # Mapping handled manually (see below) because nn.Parameter is a special case
+    del target_state_dict["no_mask_embedding"]
+
+    converted_source = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
+        source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
+    )
+
     state_dict: dict[str, Tensor] = {
         "no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()),  # type: ignore
     }

-    refiners_mask_encoder = MaskEncoder()
-    # TODO: handle other weights
-    refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False)
+    state_dict.update(converted_source)
+
+    refiners_mask_encoder.load_state_dict(state_dict=state_dict)

     return state_dict
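Note: convert_mask_encoder no longer loads only no_mask_embedding with strict=False and leaves the rest as a TODO. It now traces prompt_encoder.mask_downscaling against the refiners MaskEncoder with a dummy (1, 256, 256) input to derive a key mapping, converts the source weights with that mapping, and re-adds no_mask_embedding by hand (an nn.Parameter, which the traced mapping does not cover), so the final load_state_dict can run strict.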
5 changes: 4 additions & 1 deletion src/refiners/fluxion/layers/sampling.py
@@ -1,3 +1,5 @@
+from typing import Callable
+
 from torch import Size, Tensor, device as Device, dtype as DType
 from torch.nn.functional import pad
@@ -40,7 +42,8 @@ def __init__(
             ),
         )
         if padding == 0:
-            self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1))))
+            zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
+            self.insert(0, Lambda(zero_pad))
         if register_shape:
             self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
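Note: here and in the two files further down, inline lambdas are bound to names with explicit Callable annotations, presumably to satisfy the pyright 1.1.342 bump in pyproject.toml; behavior is unchanged. As a quick sanity check of what zero_pad computes (pad with (0, 1, 0, 1) adds one column on the right and one row at the bottom of the last two dimensions):

import torch
from torch.nn.functional import pad

x = torch.randn(1, 3, 8, 8)
y = pad(x, (0, 1, 0, 1))  # (left, right, top, bottom) over the last two dims
assert y.shape == (1, 3, 9, 9)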
6 changes: 3 additions & 3 deletions src/refiners/fluxion/model_converter.py
@@ -7,7 +7,7 @@
 from torch import Tensor, nn
 from torch.utils.hooks import RemovableHandle

-from refiners.fluxion.utils import norm, save_to_safetensors
+from refiners.fluxion.utils import no_grad, norm, save_to_safetensors

 TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
     nn.Conv1d,
@@ -512,7 +512,7 @@ def _verify_missing_basic_layers(self) -> bool:

         return True

-    @torch.no_grad()
+    @no_grad()
     def _trace_module_execution_order(
         self,
         module: nn.Module,
@@ -603,7 +603,7 @@ def _convert_state_dict(

         return converted_state_dict

-    @torch.no_grad()
+    @no_grad()
     def _collect_layers_outputs(
         self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
     ) -> list[tuple[str, Tensor]]:
16 changes: 14 additions & 2 deletions src/refiners/fluxion/utils.py
@@ -1,13 +1,20 @@
 from pathlib import Path
-from typing import Iterable, Literal, TypeVar
+from typing import Any, Iterable, Literal, TypeVar

 import torch
 from jaxtyping import Float
 from numpy import array, float32
 from PIL import Image
 from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
-from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm  # type: ignore
+from torch import (
+    Tensor,
+    device as Device,
+    dtype as DType,
+    manual_seed as _manual_seed,  # type: ignore
+    no_grad as _no_grad,  # type: ignore
+    norm as _norm,  # type: ignore
+)
 from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad  # type: ignore

 T = TypeVar("T")
@@ -22,6 +29,11 @@ def manual_seed(seed: int) -> None:
     _manual_seed(seed)


+class no_grad(_no_grad):
+    def __new__(cls, orig_func: Any | None = None) -> "no_grad":  # type: ignore
+        return object.__new__(cls)
+
+
 def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor:
     return _pad(input=x, pad=pad, value=value, mode=mode)  # type: ignore
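Note: the new no_grad subclass keeps torch.no_grad's behavior but overrides __new__ (accepting and ignoring an optional orig_func) so that a plain no_grad() call constructs an instance cleanly; it is used both as a decorator and as a context manager elsewhere in this merge. A minimal usage sketch, assuming the class behaves like torch.no_grad:

import torch
from refiners.fluxion.utils import no_grad  # as defined above

@no_grad()  # decorator form, as in the conversion scripts
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

with no_grad():  # context-manager form, as in the README example
    y = torch.ones(2) + 1

x = torch.ones(2, requires_grad=True)
assert not double(x).requires_grad  # gradient tracking is disabled in both forms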
7 changes: 5 additions & 2 deletions src/refiners/foundationals/clip/image_encoder.py
@@ -1,4 +1,6 @@
-from torch import device as Device, dtype as DType
+from typing import Callable
+
+from torch import Tensor, device as Device, dtype as DType

 import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.common import FeedForward, PositionalEncoder
@@ -126,6 +128,7 @@ def __init__(
         self.num_layers = num_layers
         self.num_attention_heads = num_attention_heads
         self.feedforward_dim = feedforward_dim
+        cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
         super().__init__(
             ViTEmbeddings(
                 image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
)
for _ in range(num_layers)
),
fl.Lambda(func=lambda x: x[:, 0, :]),
fl.Lambda(func=cls_token_pooling),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
)
Expand Down
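Note: cls_token_pooling keeps only the first token of the ViT sequence (the class token) as the image representation before the final LayerNorm and projection. Illustrative shapes (the sizes below are hypothetical):

import torch

x = torch.randn(2, 257, 768)  # (batch, 1 CLS token + 256 patch tokens, embedding_dim)
cls_embedding = x[:, 0, :]    # what fl.Lambda(func=cls_token_pooling) computes
assert cls_embedding.shape == (2, 768)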
5 changes: 3 additions & 2 deletions src/refiners/foundationals/latent_diffusion/freeu.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Generic, TypeVar
+from typing import Any, Callable, Generic, TypeVar

 import torch
 from torch import Tensor
@@ -54,9 +54,10 @@ def forward(self, x: Tensor) -> Tensor:

 class FreeUSkipFeatures(fl.Chain):
     def __init__(self, n: int, skip_scale: float) -> None:
+        apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
         super().__init__(
             fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
-            fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)),
+            fl.Lambda(apply_filter),
         )