Skip to content

Commit

Permalink
Add concepts learning via textual inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
DoryanKaced committed Aug 31, 2023
1 parent 0f476ea commit 0d8dffa
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 9 deletions.
63 changes: 63 additions & 0 deletions configs/finetune-textual-inversion.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
script = "finetune-ldm-textual_inversion.py" # not used for now

[wandb]
mode = "offline" # "online", "offline", "disabled"
entity = "acme"
project = "test-textual-inversion"

[models]
unet = {checkpoint = "tests/weights/unet.safetensors"}
text_encoder = {checkpoint = "tests/weights/CLIPTextEncoderL.safetensors"}
lda = {checkpoint = "tests/weights/lda.safetensors"}

[latent_diffusion]
unconditional_sampling_probability = 0.05
offset_noise = 0.1

[textual_inversion]
placeholder_token = "<cat-toy>"
initializer_token = "toy"
# style_mode = true

[training]
duration = "2000:step"
seed = 0
gpu_index = 0
batch_size = 4
gradient_accumulation = "1:step"
evaluation_interval = "250:step"
evaluation_seed = 1

[optimizer]
optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
learning_rate = 5e-4
betas = [0.9, 0.999]
eps = 1e-8
weight_decay = 1e-2

[scheduler]
scheduler_type = "ConstantLR"
update_interval = "1:step"

[dropout]
dropout_probability = 0
use_gyro_dropout = false

[dataset]
hf_repo = "acme/cat-toy"
revision = "main"
horizontal_flip = true
random_crop = true
resize_image_max_size = 512

[checkpointing]
# save_folder = "/path/to/ckpts"
save_interval = "250:step"

[test_diffusion]
num_inference_steps = 30
use_short_prompts = false
prompts = [
"<cat-toy>",
# "green grass, <cat-toy>"
]
164 changes: 164 additions & 0 deletions scripts/training/finetune-ldm-textual-inversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from typing import Any
from pydantic import BaseModel
from loguru import logger
from torch.utils.data import Dataset
from torch import randn, Tensor
import random

from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.fluxion.utils import save_to_safetensors
from refiners.training_utils.callback import Callback
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
TextEmbeddingLatentsBatch,
LatentDiffusionTrainer,
LatentDiffusionConfig,
TextEmbeddingLatentsDataset,
)


IMAGENET_TEMPLATES_SMALL = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]

IMAGENET_STYLE_TEMPLATES_SMALL = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]


class TextualInversionDataset(TextEmbeddingLatentsDataset):
templates: list[str] = []
placeholder_token: str = ""

def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
super().__init__(trainer)
self.templates = (
IMAGENET_STYLE_TEMPLATES_SMALL if self.config.textual_inversion.style_mode else IMAGENET_TEMPLATES_SMALL
)
self.placeholder_token = self.config.textual_inversion.placeholder_token

def get_caption(self, index: int) -> str:
# Ignore the dataset caption, if any: use a template instead
return random.choice(self.templates).format(self.placeholder_token)


class TextualInversionConfig(BaseModel):
# The new token to be learned
placeholder_token: str = "*"
# The token to be used as initializer; if None, a random vector is used
initializer_token: str | None = None
style_mode: bool = False

def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
adapter = ConceptExtender(target=text_encoder)
tokenizer = text_encoder.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found in text encoder."
token_encoder = text_encoder.find(layer_type=TokenEncoder)
assert token_encoder is not None, "Token encoder not found in text encoder."
if self.initializer_token is not None:
bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
assert " " not in bpe, "This initializer_token is not a single token."
token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
init_embedding = token_encoder(token).squeeze(0)
else:
token_encoder = text_encoder.find(layer_type=TokenEncoder)
assert token_encoder is not None, "Token encoder not found in text encoder."
init_embedding = randn(token_encoder.embedding_dim)
adapter.add_concept(self.placeholder_token, init_embedding)
adapter.inject()


class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
latent_diffusion: LatentDiffusionConfig
textual_inversion: TextualInversionConfig

def model_post_init(self, __context: Any) -> None:
# Pydantic v2 does post init differently, so we need to override this method too.
logger.info("Freezing models to train only the new embedding.")
self.models["unet"].train = False
self.models["text_encoder"].train = False
self.models["lda"].train = False


class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]):
def __init__(
self,
config: TextualInversionLatentDiffusionConfig,
callbacks: "list[Callback[Any]] | None" = None,
) -> None:
super().__init__(config=config, callbacks=callbacks)
self.callbacks.extend((LoadTextualInversion(), SaveTextualInversion()))

def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
return TextualInversionDataset(trainer=self)


class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
trainer.config.textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder)


class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender)
assert embedding_extender is not None, "Embedding extender not found in text encoder."
tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}

save_to_safetensors(
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", tensors=tensors
)


if __name__ == "__main__":
import sys

config_path = sys.argv[1]
config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=config_path)
trainer = TextualInversionLatentDiffusionTrainer(config=config)
trainer.train()
11 changes: 5 additions & 6 deletions src/refiners/foundationals/clip/concepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def eject(self) -> None:
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
old_weight: Parameter
new_weight: Parameter
weight: Tensor

def __init__(
self,
Expand All @@ -83,22 +82,21 @@ def __init__(
self.new_weight = Parameter(
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
) # requires_grad=True by default
self.weight = cat([self.old_weight, self.new_weight])

# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
def lookup(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)
# Concatenate old and new weights for dynamic embedding updates during training
return F.embedding(x, cat([self.old_weight, self.new_weight]))

def add_embedding(self, embedding: Tensor) -> None:
assert embedding.shape == (self.old_weight.shape[1],)
self.new_weight = Parameter(
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
)
self.weight = cat([self.old_weight, self.new_weight])

@property
def num_embeddings(self) -> int:
return self.weight.shape[0]
return self.old_weight.shape[0] + self.new_weight.shape[0]


class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
Expand All @@ -115,12 +113,13 @@ def __init__(self, target: CLIPTokenizer) -> None:
)

def add_token(self, token: str, token_id: int) -> None:
token = token.lower()
tokenizer = self.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found."
assert token_id not in tokenizer.token_to_id_mapping.values()
tokenizer.token_to_id_mapping[token] = token_id
current_pattern = tokenizer.token_pattern.pattern
new_pattern = token + "|" + current_pattern
new_pattern = re.escape(token) + "|" + current_pattern
tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
# Define the keyword as its own smallest subtoken
tokenizer.byte_pair_encoding_cache[token] = token
10 changes: 8 additions & 2 deletions src/refiners/training_utils/latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,15 @@ def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int =
def process_caption(self, caption: str) -> str:
return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""

def get_caption(self, index: int) -> str:
return self.dataset[index]["caption"]

def get_image(self, index: int) -> Image.Image:
return self.dataset[index]["image"]

def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
item = self.dataset[index]
caption, image = item["caption"], item["image"]
caption = self.get_caption(index=index)
image = self.get_image(index=index)
resized_image = self.resize_image(
image=image,
min_size=self.config.dataset.resize_image_min_size,
Expand Down
25 changes: 24 additions & 1 deletion tests/foundationals/clip/test_concepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from warnings import warn
from pathlib import Path

from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.fluxion.utils import load_from_safetensors
import refiners.fluxion.layers as fl

from diffusers import StableDiffusionPipeline # type: ignore
import transformers # type: ignore
Expand Down Expand Up @@ -84,6 +85,28 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore


def test_tokenizer_with_special_character():
clip_tokenizer = fl.Chain(CLIPTokenizer())
token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer)
new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42
token_extender.add_token("*", new_token_id)
token_extender.inject(clip_tokenizer)

adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer)
assert adapted_clip_tokenizer is not None

assert torch.allclose(
adapted_clip_tokenizer.encode("*"),
torch.Tensor(
[
adapted_clip_tokenizer.start_of_text_token_id,
new_token_id,
adapted_clip_tokenizer.end_of_text_token_id,
]
).to(torch.int64),
)


def test_encoder(
prompt: str,
ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,
Expand Down

0 comments on commit 0d8dffa

Please sign in to comment.