Tokenizer code is moved to clip_util.py
celll1 committed Sep 4, 2024
1 parent 4b18e79 commit 0d72d61
Showing 5 changed files with 98 additions and 149 deletions.
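
In short: the long-prompt handling that StableDiffusionModel and StableDiffusionXLModel previously carried in private __process_tokens helpers — split the token sequence into 75-token chunks, wrap each chunk in BOS/EOS, pad the tail with EOS, and concatenate up to three chunk embeddings — is consolidated into encode_clip in modules/model/util/clip_util.py, and the model classes now just call that helper. A minimal sketch of the chunking step follows; it assumes CLIP's usual BOS/EOS token ids and illustrates the idea only, it is not the project's code:

```python
import torch

def chunk_token_ids(tokens: torch.Tensor, chunk_length: int = 75,
                    bos_id: int = 49406, eos_id: int = 49407) -> list[torch.Tensor]:
    """Split (batch, seq) token ids into 75-token chunks, wrap each chunk in
    BOS/EOS, and pad the last chunk with EOS up to chunk_length + 2."""
    chunks = []
    for start in range(0, tokens.shape[1], chunk_length):
        chunk = tokens[:, start:start + chunk_length]
        bos = torch.full((chunk.shape[0], 1), bos_id, dtype=chunk.dtype)
        eos = torch.full((chunk.shape[0], 1), eos_id, dtype=chunk.dtype)
        chunk = torch.cat([bos, chunk, eos], dim=1)
        if chunk.shape[1] < chunk_length + 2:  # only the last chunk needs padding
            pad = torch.full((chunk.shape[0], chunk_length + 2 - chunk.shape[1]),
                             eos_id, dtype=chunk.dtype)
            chunk = torch.cat([chunk, pad], dim=1)
        chunks.append(chunk)
    return chunks

# A 180-token prompt becomes three (batch, 77) chunks that are encoded one at a time.
example = torch.randint(0, 49000, (1, 180))
print([tuple(c.shape) for c in chunk_token_ids(example)])
```

Each chunk is then encoded separately and the hidden states are concatenated along the sequence axis, capped at max_embeddings_multiples (3) chunks.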
11 changes: 5 additions & 6 deletions modules/model/FluxModel.py
@@ -214,19 +214,18 @@ def encode_text(
if tokens_1 is None and text is not None and self.tokenizer_1 is not None:
tokenizer_output = self.tokenizer_1(
text,
padding='max_length',
truncation=True,
max_length=77,
# padding='max_length',
truncation=False,
return_tensors="pt",
)
tokens_1 = tokenizer_output.input_ids.to(self.text_encoder_1.device)

if tokens_2 is None and text is not None and self.tokenizer_2 is not None:
tokenizer_output = self.tokenizer_2(
text,
padding='max_length',
# padding='max_length',
truncation=True,
max_length=77,
max_length=4096,
return_tensors="pt",
)
tokens_2 = tokenizer_output.input_ids.to(self.text_encoder_2.device)
@@ -241,7 +240,7 @@ def encode_text(
text_encoder_output=None,
add_pooled_output=True,
pooled_text_encoder_output=pooled_text_encoder_1_output,
use_attention_mask=False,
use_attention_mask=True,
)
if pooled_text_encoder_1_output is None:
pooled_text_encoder_1_output = torch.zeros(
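The FluxModel.py hunks above drop the hard 77-token cap on the first (CLIP) tokenizer, raise the second tokenizer's limit to 4096 (the T5 encoder in Flux pipelines), and switch use_attention_mask to True for the first text encoder. A small, self-contained illustration of what removing truncation changes for a long prompt — the checkpoint name is the stock CLIP tokenizer, used here purely for demonstration:

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a detailed photograph of a castle on a hill at sunset " * 15  # well past 77 tokens

old_style = tokenizer(prompt, truncation=True, max_length=77, return_tensors="pt")
new_style = tokenizer(prompt, truncation=False, return_tensors="pt")

print(old_style.input_ids.shape)  # torch.Size([1, 77]): prompt cut at the CLIP window
print(new_style.input_ids.shape)  # longer than 77: the full prompt survives for chunked encoding
```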
64 changes: 12 additions & 52 deletions modules/model/StableDiffusionModel.py
@@ -222,64 +222,24 @@ def encode_text(
text_encoder_layer_skip: int = 0,
text_encoder_output: Tensor | None = None,
):
chunk_length = 75
max_embeddings_multiples = 3

def __process_tokens(tokens):
if tokens is None or tokens.numel() == 0:
return None

chunks = [tokens[:, i:i + chunk_length] for i in range(0, tokens.shape[1], chunk_length)]
chunk_embeddings = []

for chunk in chunks:
if chunk.numel() == 0:
continue

if chunk.shape[1] < chunk_length:
padding = torch.full((chunk.shape[0], chunk_length - chunk.shape[1]), self.tokenizer.eos_token_id, dtype=chunk.dtype, device=chunk.device)
chunk = torch.cat([chunk, padding], dim=1)

bos_tokens = torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id, dtype=chunk.dtype, device=chunk.device)
eos_tokens = torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id, dtype=chunk.dtype, device=chunk.device)
chunk = torch.cat([bos_tokens, chunk, eos_tokens], dim=1)

with self.autocast_context:
embedding, _ = encode_clip(
text_encoder=self.text_encoder,
tokens=chunk,
default_layer=-1,
layer_skip=text_encoder_layer_skip,
text_encoder_output=None,
add_pooled_output=False,
use_attention_mask=False,
add_layer_norm=True,
)

chunk_embeddings.append(embedding)

if not chunk_embeddings:
return None

if len(chunk_embeddings) > max_embeddings_multiples:
chunk_embeddings = chunk_embeddings[:max_embeddings_multiples]

combined_embedding = torch.cat(chunk_embeddings, dim=1)

return combined_embedding

if tokens is None:
tokenizer_output = self.tokenizer(
text,
padding='max_length',
padding="max_length",
truncation=False,
return_tensors="pt",
)
tokens = tokenizer_output.input_ids.to(self.text_encoder.device)

text_encoder_output = __process_tokens(tokens)

if text_encoder_output is None:
print("Text encoder output is None. Check your input text or tokens.")
text_encoder_output, _ = encode_clip(
text_encoder=self.text_encoder,
tokens=tokens,
default_layer=-1,
layer_skip=text_encoder_layer_skip,
text_encoder_output=text_encoder_output,
add_pooled_output=False,
use_attention_mask=True,
add_layer_norm=True,
)

return text_encoder_output
return text_encoder_output
106 changes: 29 additions & 77 deletions modules/model/StableDiffusionXLModel.py
@@ -203,95 +203,47 @@ def encode_text(
text_encoder_2_output: Tensor = None,
pooled_text_encoder_2_output: Tensor = None,
):
chunk_length = 75
max_embeddings_multiples = 3

def __process_tokens(tokens, tokenizer, text_encoder, layer_skip):
if tokens is None or tokens.numel() == 0:
return None, None

chunks = [tokens[:, i:i + chunk_length] for i in range(0, tokens.shape[1], chunk_length)]
chunk_embeddings = []
pooled_outputs = []
attention_masks = []

for i, chunk in enumerate(chunks):
if chunk.numel() == 0:
continue

# Create attention mask (1 for non-masked, 0 for masked)
attention_mask = torch.ones_like(chunk, dtype=torch.bool)

# First, add BOS and EOS tokens
bos_tokens = torch.full((chunk.shape[0], 1), tokenizer.bos_token_id, dtype=chunk.dtype, device=chunk.device)
eos_tokens = torch.full((chunk.shape[0], 1), tokenizer.eos_token_id, dtype=chunk.dtype, device=chunk.device)
chunk = torch.cat([bos_tokens, chunk, eos_tokens], dim=1)
attention_mask = torch.cat([torch.zeros_like(bos_tokens, dtype=torch.bool) if i > 0 else torch.ones_like(bos_tokens, dtype=torch.bool),
attention_mask,
torch.zeros_like(eos_tokens, dtype=torch.bool) if i < len(chunks) - 1 else torch.ones_like(eos_tokens, dtype=torch.bool)],
dim=1)

# Fill with padding
if chunk.shape[1] < chunk_length + 2: # +2 is for BOS and EOS
padding = torch.full((chunk.shape[0], chunk_length + 2 - chunk.shape[1]), tokenizer.eos_token_id, dtype=chunk.dtype, device=chunk.device)
chunk = torch.cat([chunk, padding], dim=1)
attention_mask = torch.cat([attention_mask, torch.zeros_like(padding, dtype=torch.bool)], dim=1)

attention_masks.append(attention_mask)

with self.autocast_context:
outputs = text_encoder(
chunk,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
)
embedding = outputs.hidden_states[-(2 + layer_skip)]
if hasattr(outputs, 'text_embeds'):
pooled_outputs.append(outputs.text_embeds)

chunk_embeddings.append(embedding)

if not chunk_embeddings:
return None, None

if len(chunk_embeddings) > max_embeddings_multiples:
chunk_embeddings = chunk_embeddings[:max_embeddings_multiples]
attention_masks = attention_masks[:max_embeddings_multiples]
if pooled_outputs:
pooled_outputs = pooled_outputs[:max_embeddings_multiples]

combined_embedding = torch.cat(chunk_embeddings, dim=1)
# combined_attention_mask = torch.cat(attention_masks, dim=1)
pooled_output = pooled_outputs[0] if pooled_outputs else None

return combined_embedding, pooled_output

if tokens_1 is None and text is not None:
tokens_1 = self.tokenizer_1(
tokenizer_output = self.tokenizer_1(
text,
padding='max_length',
truncation=False,
return_tensors="pt",
).input_ids.to(self.text_encoder_1.device)
)
tokens_1 = tokenizer_output.input_ids.to(self.text_encoder_1.device)

if tokens_2 is None and text is not None:
tokens_2 = self.tokenizer_2(
tokenizer_output = self.tokenizer_2(
text,
padding='max_length',
truncation=False,
return_tensors="pt",
).input_ids.to(self.text_encoder_2.device)

if text_encoder_1_output is None:
text_encoder_1_output, _ = __process_tokens(tokens_1, self.tokenizer_1, self.text_encoder_1, text_encoder_1_layer_skip)
)
tokens_2 = tokenizer_output.input_ids.to(self.text_encoder_2.device)

if text_encoder_2_output is None or pooled_text_encoder_2_output is None:
text_encoder_2_output, pooled_text_encoder_2_output = __process_tokens(tokens_2, self.tokenizer_2, self.text_encoder_2, text_encoder_2_layer_skip)
text_encoder_1_output, _ = encode_clip(
text_encoder=self.text_encoder_1,
tokens=tokens_1,
default_layer=-2,
layer_skip=text_encoder_1_layer_skip,
text_encoder_output=text_encoder_1_output,
add_pooled_output=False,
use_attention_mask=True,
add_layer_norm=False,
)

if text_encoder_1_output is None or text_encoder_2_output is None:
print("Both text encoder outputs are None. Check your input text or tokens.")
text_encoder_2_output, pooled_text_encoder_2_output = encode_clip(
text_encoder=self.text_encoder_2,
tokens=tokens_2,
default_layer=-2,
layer_skip=text_encoder_2_layer_skip,
text_encoder_output=text_encoder_2_output,
add_pooled_output=True,
pooled_text_encoder_output=pooled_text_encoder_2_output,
use_attention_mask=True,
add_layer_norm=False,
)

text_encoder_output = torch.cat([text_encoder_1_output, text_encoder_2_output], dim=-1)
text_encoder_output = torch.concat([text_encoder_1_output, text_encoder_2_output], dim=-1)

return text_encoder_output, pooled_text_encoder_2_output
return text_encoder_output, pooled_text_encoder_2_output
63 changes: 50 additions & 13 deletions modules/model/util/clip_util.py
@@ -1,4 +1,5 @@
from torch import Tensor
import torch

from transformers import CLIPTextModel, CLIPTextModelWithProjection

@@ -16,28 +17,64 @@ def encode_clip(
attention_mask: Tensor | None = None,
add_layer_norm: bool = True,
) -> tuple[Tensor, Tensor]:
if (add_output and text_encoder_output is None) \
or (add_pooled_output and pooled_text_encoder_output is None) \
and text_encoder is not None:
chunk_length = 75
max_embeddings_multiples = 3

text_encoder_output = text_encoder(
tokens,
attention_mask=attention_mask if use_attention_mask else None,
if tokens is None or tokens.numel() == 0:
return None, None

chunks = [tokens[:, i:i + chunk_length] for i in range(0, tokens.shape[1], chunk_length)]
chunk_embeddings = [] if add_output else None
pooled_outputs = [] if add_pooled_output else None

for i, chunk in enumerate(chunks):
if chunk.numel() == 0:
continue

# Create attention mask (1 for non-masked, 0 for masked)
chunk_attention_mask = torch.ones_like(chunk, dtype=torch.bool)

# First, add BOS and EOS tokens
bos_tokens = torch.full((chunk.shape[0], 1), text_encoder.config.bos_token_id, dtype=chunk.dtype, device=chunk.device)
eos_tokens = torch.full((chunk.shape[0], 1), text_encoder.config.eos_token_id, dtype=chunk.dtype, device=chunk.device)
chunk = torch.cat([bos_tokens, chunk, eos_tokens], dim=1)
chunk_attention_mask = torch.cat([torch.zeros_like(bos_tokens, dtype=torch.bool) if i > 0 else torch.ones_like(bos_tokens, dtype=torch.bool),
chunk_attention_mask,
torch.zeros_like(eos_tokens, dtype=torch.bool) if i < len(chunks) - 1 else torch.ones_like(eos_tokens, dtype=torch.bool)],
dim=1)

# Fill with padding
if chunk.shape[1] < chunk_length + 2: # +2 is for BOS and EOS
padding = torch.full((chunk.shape[0], chunk_length + 2 - chunk.shape[1]), text_encoder.config.eos_token_id, dtype=chunk.dtype, device=chunk.device)
chunk = torch.cat([chunk, padding], dim=1)
chunk_attention_mask = torch.cat([chunk_attention_mask, torch.zeros_like(padding, dtype=torch.bool)], dim=1)

outputs = text_encoder(
chunk,
attention_mask=chunk_attention_mask if use_attention_mask else None,
return_dict=True,
output_hidden_states=True,
)

if add_output:
embedding = outputs.hidden_states[default_layer - layer_skip]
chunk_embeddings.append(embedding)

pooled_text_encoder_output = None
if add_pooled_output:
if hasattr(text_encoder_output, "text_embeds"):
pooled_text_encoder_output = text_encoder_output.text_embeds
pooled_outputs.append(text_encoder_output.text_embeds)
if hasattr(text_encoder_output, "pooler_output"):
pooled_text_encoder_output = text_encoder_output.pooler_output
pooled_outputs.append(text_encoder_output.pooler_output)

text_encoder_output = text_encoder_output.hidden_states[default_layer - layer_skip] if add_output else None
if chunk_embeddings is not None and len(chunk_embeddings) > max_embeddings_multiples:
chunk_embeddings = chunk_embeddings[:max_embeddings_multiples]
if pooled_outputs is not None and len(pooled_outputs) > max_embeddings_multiples:
pooled_outputs = pooled_outputs[:max_embeddings_multiples]
text_encoder_output = torch.cat(chunk_embeddings, dim=1) if chunk_embeddings is not None else None
pooled_text_encoder_output = pooled_outputs[0] if pooled_outputs else None

if add_layer_norm and text_encoder_output is not None:
final_layer_norm = text_encoder.text_model.final_layer_norm
text_encoder_output = final_layer_norm(text_encoder_output)
if add_layer_norm and text_encoder_output is not None:
final_layer_norm = text_encoder.text_model.final_layer_norm
text_encoder_output = final_layer_norm(text_encoder_output)

return text_encoder_output, pooled_text_encoder_output
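
For reference, a sketch of how the consolidated helper is called after this commit, mirroring the SDXL call site above. The tokenizer and text-encoder setup is illustrative (a stock CLIP checkpoint); only the encode_clip keyword arguments reflect the call sites in this diff:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from modules.model.util.clip_util import encode_clip

# Illustrative setup; in the project these live on the model classes.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(
    "a very long prompt that runs well past the usual 77-token window " * 5,
    truncation=False,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    embedding, pooled = encode_clip(
        text_encoder=text_encoder,
        tokens=tokens,
        default_layer=-2,
        layer_skip=0,
        text_encoder_output=None,
        add_pooled_output=False,
        use_attention_mask=True,
        add_layer_norm=False,
    )

# `embedding` concatenates the per-chunk hidden states along the sequence axis
# (up to max_embeddings_multiples chunks); `pooled` is None because
# add_pooled_output=False.
```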
3 changes: 2 additions & 1 deletion modules/modelLoader/FluxFineTuneModelLoader.py
@@ -23,7 +23,8 @@ def _default_model_spec_name(
) -> str | None:
match model_type:
case ModelType.FLUX_DEV_1:
return "resources/sd_model_spec/flux_dev_1.0.json"
# return "resources/sd_model_spec/flux_dev_1.0.json"
return None
case _:
return None
