Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Token Lengths Exceeding 75 Tokens in Text Encoder #450

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions modules/model/FluxModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,18 @@ def encode_text(
if tokens_1 is None and text is not None and self.tokenizer_1 is not None:
tokenizer_output = self.tokenizer_1(
text,
padding='max_length',
truncation=True,
max_length=77,
# padding='max_length',
truncation=False,
return_tensors="pt",
)
tokens_1 = tokenizer_output.input_ids.to(self.text_encoder_1.device)

if tokens_2 is None and text is not None and self.tokenizer_2 is not None:
tokenizer_output = self.tokenizer_2(
text,
padding='max_length',
truncation=True,
max_length=77,
# padding='max_length',
truncation=False,
max_length=99999999,
return_tensors="pt",
)
tokens_2 = tokenizer_output.input_ids.to(self.text_encoder_2.device)
Expand All @@ -241,7 +240,7 @@ def encode_text(
text_encoder_output=None,
add_pooled_output=True,
pooled_text_encoder_output=pooled_text_encoder_1_output,
use_attention_mask=False,
use_attention_mask=True,
)
if pooled_text_encoder_1_output is None:
pooled_text_encoder_1_output = torch.zeros(
Expand Down
9 changes: 4 additions & 5 deletions modules/model/StableDiffusionModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,8 @@ def encode_text(
if tokens is None:
tokenizer_output = self.tokenizer(
text,
padding='max_length',
truncation=True,
max_length=77,
padding="max_length",
truncation=False,
return_tensors="pt",
)
tokens = tokenizer_output.input_ids.to(self.text_encoder.device)
Expand All @@ -239,8 +238,8 @@ def encode_text(
layer_skip=text_encoder_layer_skip,
text_encoder_output=text_encoder_output,
add_pooled_output=False,
use_attention_mask=False,
use_attention_mask=True,
add_layer_norm=True,
)

return text_encoder_output
return text_encoder_output
12 changes: 5 additions & 7 deletions modules/model/StableDiffusionXLModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ def encode_text(
tokenizer_output = self.tokenizer_1(
text,
padding='max_length',
truncation=True,
max_length=77,
truncation=False,
return_tensors="pt",
)
tokens_1 = tokenizer_output.input_ids.to(self.text_encoder_1.device)
Expand All @@ -217,8 +216,7 @@ def encode_text(
tokenizer_output = self.tokenizer_2(
text,
padding='max_length',
truncation=True,
max_length=77,
truncation=False,
return_tensors="pt",
)
tokens_2 = tokenizer_output.input_ids.to(self.text_encoder_2.device)
Expand All @@ -230,7 +228,7 @@ def encode_text(
layer_skip=text_encoder_1_layer_skip,
text_encoder_output=text_encoder_1_output,
add_pooled_output=False,
use_attention_mask=False,
use_attention_mask=True,
add_layer_norm=False,
)

Expand All @@ -242,10 +240,10 @@ def encode_text(
text_encoder_output=text_encoder_2_output,
add_pooled_output=True,
pooled_text_encoder_output=pooled_text_encoder_2_output,
use_attention_mask=False,
use_attention_mask=True,
add_layer_norm=False,
)

text_encoder_output = torch.concat([text_encoder_1_output, text_encoder_2_output], dim=-1)

return text_encoder_output, pooled_text_encoder_2_output
return text_encoder_output, pooled_text_encoder_2_output
74 changes: 58 additions & 16 deletions modules/model/util/clip_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from torch import Tensor
import torch

from transformers import CLIPTextModel, CLIPTextModelWithProjection

Expand All @@ -16,28 +17,69 @@ def encode_clip(
attention_mask: Tensor | None = None,
add_layer_norm: bool = True,
) -> tuple[Tensor, Tensor]:
if (add_output and text_encoder_output is None) \
or (add_pooled_output and pooled_text_encoder_output is None) \
and text_encoder is not None:
chunk_length = 75
max_embeddings_multiples = 3

text_encoder_output = text_encoder(
tokens,
attention_mask=attention_mask if use_attention_mask else None,
if tokens is None or tokens.numel() == 0:
return None, None

chunks = [tokens[:, i:i + chunk_length] for i in range(0, tokens.shape[1], chunk_length)]
chunk_embeddings = []
pooled_outputs = []

for i, chunk in enumerate(chunks):
if chunk.numel() == 0:
continue

# Create attention mask (1 for non-masked, 0 for masked)
chunk_attention_mask = torch.ones_like(chunk, dtype=torch.bool, device=chunk.device)

# First, add BOS and EOS tokens
bos_tokens = torch.full((chunk.shape[0], 1), text_encoder.config.bos_token_id, dtype=chunk.dtype, device=chunk.device)
eos_tokens = torch.full((chunk.shape[0], 1), text_encoder.config.eos_token_id, dtype=chunk.dtype, device=chunk.device)
chunk = torch.cat([bos_tokens, chunk, eos_tokens], dim=1)
chunk_attention_mask = torch.cat([torch.zeros_like(bos_tokens, dtype=torch.bool) if i > 0 else torch.ones_like(bos_tokens, dtype=torch.bool),
chunk_attention_mask,
torch.zeros_like(eos_tokens, dtype=torch.bool) if i < len(chunks) - 1 else torch.ones_like(eos_tokens, dtype=torch.bool)],
dim=1)

# Fill with padding
if chunk.shape[1] < chunk_length + 2: # +2 is for BOS and EOS
padding = torch.full((chunk.shape[0], chunk_length + 2 - chunk.shape[1]), text_encoder.config.eos_token_id, dtype=chunk.dtype, device=chunk.device)
chunk = torch.cat([chunk, padding], dim=1)
chunk_attention_mask = torch.cat([chunk_attention_mask, torch.zeros_like(padding, dtype=torch.bool)], dim=1)

outputs = text_encoder(
chunk,
attention_mask=chunk_attention_mask if use_attention_mask else None,
return_dict=True,
output_hidden_states=True,
)

pooled_text_encoder_output = None

if add_output:
embedding = outputs.hidden_states[default_layer - layer_skip]
chunk_embeddings.append(embedding)
if add_pooled_output:
if hasattr(text_encoder_output, "text_embeds"):
pooled_text_encoder_output = text_encoder_output.text_embeds
if hasattr(text_encoder_output, "pooler_output"):
pooled_text_encoder_output = text_encoder_output.pooler_output

text_encoder_output = text_encoder_output.hidden_states[default_layer - layer_skip] if add_output else None
if hasattr(outputs, "text_embeds"):
pooled_outputs.append(outputs.text_embeds)
elif hasattr(outputs, "pooler_output"):
pooled_outputs.append(outputs.pooler_output)

if add_layer_norm and text_encoder_output is not None:
if add_output:
if chunk_embeddings and len(chunk_embeddings) > max_embeddings_multiples:
chunk_embeddings = chunk_embeddings[:max_embeddings_multiples]
text_encoder_output = torch.cat(chunk_embeddings, dim=1)
if add_layer_norm:
final_layer_norm = text_encoder.text_model.final_layer_norm
text_encoder_output = final_layer_norm(text_encoder_output)
else:
text_encoder_output = None

if add_pooled_output:
if pooled_outputs and len(pooled_outputs) > max_embeddings_multiples:
pooled_outputs = pooled_outputs[:max_embeddings_multiples]
pooled_text_encoder_output = pooled_outputs[0] if pooled_outputs else None
else:
pooled_text_encoder_output = None

return text_encoder_output, pooled_text_encoder_output
return text_encoder_output, pooled_text_encoder_output
3 changes: 2 additions & 1 deletion modules/modelLoader/FluxFineTuneModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def _default_model_spec_name(
) -> str | None:
match model_type:
case ModelType.FLUX_DEV_1:
return "resources/sd_model_spec/flux_dev_1.0.json"
# return "resources/sd_model_spec/flux_dev_1.0.json"
return None
case _:
return None

Expand Down
2 changes: 1 addition & 1 deletion modules/modelSampler/FluxSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch

from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm


class FluxSampler(BaseModelSampler):
Expand Down
2 changes: 1 addition & 1 deletion modules/modelSampler/PixArtAlphaSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch

from PIL.Image import Image
from tqdm import tqdm
from tqdm.auto import tqdm


class PixArtAlphaSampler(BaseModelSampler):
Expand Down
2 changes: 1 addition & 1 deletion modules/modelSampler/StableDiffusion3Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch

from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm


class StableDiffusion3Sampler(BaseModelSampler):
Expand Down
2 changes: 1 addition & 1 deletion modules/modelSampler/StableDiffusionSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchvision.transforms import transforms

from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm


class StableDiffusionSampler(BaseModelSampler):
Expand Down
2 changes: 1 addition & 1 deletion modules/modelSampler/StableDiffusionXLSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchvision.transforms import transforms

from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm


class StableDiffusionXLSampler(BaseModelSampler):
Expand Down
2 changes: 1 addition & 1 deletion modules/modelSampler/WuerstchenSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch

from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm


class WuerstchenSampler(BaseModelSampler):
Expand Down
2 changes: 1 addition & 1 deletion modules/module/BaseImageCaptionModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from modules.util import path_util

from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm


class CaptionSample:
Expand Down
2 changes: 1 addition & 1 deletion modules/module/BaseImageMaskModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchvision.transforms import transforms

from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm


class MaskSample:
Expand Down
2 changes: 1 addition & 1 deletion modules/module/GenerateLossesModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import torch

from tqdm import tqdm
from tqdm.auto import tqdm


class GenerateLossesModel:
Expand Down
26 changes: 21 additions & 5 deletions modules/trainer/GenericTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from modules.util.torch_util import torch_gc
from modules.util.TrainProgress import TrainProgress

from accelerate import Accelerator, DistributedDataParallelKwargs

import torch
from torch import Tensor, nn
from torch.nn import Parameter
Expand All @@ -38,8 +40,7 @@
from torchvision.transforms.functional import pil_to_tensor

from PIL.Image import Image
from tqdm import tqdm

from tqdm.auto import tqdm

class GenericTrainer(BaseTrainer):
model_loader: BaseModelLoader
Expand All @@ -62,6 +63,16 @@ class GenericTrainer(BaseTrainer):
def __init__(self, config: TrainConfig, callbacks: TrainCallbacks, commands: TrainCommands):
super(GenericTrainer, self).__init__(config, callbacks, commands)

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

if hasattr(self.accelerator, 'device') and self.accelerator.device:
print(f"Accelerator device: {self.accelerator.device.type}")
if hasattr(self.accelerator, 'distributed_type') and self.accelerator.distributed_type:
print(f"Distributed type: {self.accelerator.distributed_type}")

print(f"if accelerator is not activated, using {torch.device(self.config.train_device)}")

tensorboard_log_dir = os.path.join(config.workspace_dir, "tensorboard")
os.makedirs(Path(tensorboard_log_dir).absolute(), exist_ok=True)
self.tensorboard = SummaryWriter(os.path.join(tensorboard_log_dir, get_string_timestamp()))
Expand Down Expand Up @@ -138,6 +149,11 @@ def start(self):
self.data_loader = self.create_data_loader(
self.model, self.model.train_progress
)

self.model, self.data_loader = self.accelerator.prepare(
self.model, self.data_loader
)

self.model_saver = self.create_model_saver()

self.model_sampler = self.create_model_sampler(self.model)
Expand Down Expand Up @@ -468,7 +484,7 @@ def __before_eval(self):
self.model.optimizer.eval()

def train(self):
train_device = torch.device(self.config.train_device)
train_device = self.accelerator.device if self.accelerator.device else torch.device(self.config.train_device)

train_progress = self.model.train_progress

Expand Down Expand Up @@ -565,9 +581,9 @@ def sample_commands_fun():

loss = loss / self.config.gradient_accumulation_steps
if scaler:
scaler.scale(loss).backward()
self.accelerator.backward(scaler.scale(loss))
else:
loss.backward()
self.accelerator.backward(loss)

has_gradient = True
accumulated_loss += loss.item()
Expand Down
2 changes: 1 addition & 1 deletion modules/ui/OptimizerParamsWindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def create_dynamic_ui(
'r': {'title': 'R', 'tooltip': 'EMA factor.', 'type': 'float'},
'adanorm': {'title': 'AdaNorm', 'tooltip': 'Whether to use the AdaNorm variant', 'type': 'bool'},
'adam_debias': {'title': 'Adam Debias', 'tooltip': 'Only correct the denominator to avoid inflating step sizes early in training.', 'type': 'bool'},

'model_sharding': {'title': 'Model Sharding', 'tooltip': 'Whether to use model sharding for distributed training.', 'type': 'bool'},
}
# @formatter:on

Expand Down
1 change: 1 addition & 0 deletions modules/ui/TopBar.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __create_training_method(self):
]
elif self.train_config.model_type.is_flux():
values = [
("Fine Tune", TrainingMethod.FINE_TUNE),
("LoRA", TrainingMethod.LORA),
("Embedding", TrainingMethod.EMBEDDING),
]
Expand Down
2 changes: 2 additions & 0 deletions modules/util/config/TrainConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class TrainOptimizerConfig(BaseConfig):
r: float
adanorm: bool
adam_debias: bool
model_sharding: bool

def __init__(self, data: list[(str, Any, type, bool)]):
super(TrainOptimizerConfig, self).__init__(data)
Expand Down Expand Up @@ -156,6 +157,7 @@ def default_values():
data.append(("r", None, float, True))
data.append(("adanorm", False, bool, False))
data.append(("adam_debias", False, bool, False))
data.append(("model_sharding", False, bool, False))

return TrainOptimizerConfig(data)

Expand Down
Loading