Skip to content

Commit

Permalink
Fix: clip_util.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
celll1 committed Sep 5, 2024
1 parent 0d72d61 commit e157d75
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
4 changes: 2 additions & 2 deletions modules/model/FluxModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ def encode_text(
tokenizer_output = self.tokenizer_2(
text,
# padding='max_length',
truncation=True,
max_length=4096,
truncation=False,
max_length=99999999,
return_tensors="pt",
)
tokens_2 = tokenizer_output.input_ids.to(self.text_encoder_2.device)
Expand Down
39 changes: 22 additions & 17 deletions modules/model/util/clip_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def encode_clip(
return None, None

chunks = [tokens[:, i:i + chunk_length] for i in range(0, tokens.shape[1], chunk_length)]
chunk_embeddings = [] if add_output else None
pooled_outputs = [] if add_pooled_output else None
chunk_embeddings = []
pooled_outputs = []

for i, chunk in enumerate(chunks):
if chunk.numel() == 0:
Expand Down Expand Up @@ -59,22 +59,27 @@ def encode_clip(
if add_output:
embedding = outputs.hidden_states[default_layer - layer_skip]
chunk_embeddings.append(embedding)

if add_pooled_output:
if hasattr(text_encoder_output, "text_embeds"):
pooled_outputs.append(text_encoder_output.text_embeds)
if hasattr(text_encoder_output, "pooler_output"):
pooled_outputs.append(text_encoder_output.pooler_output)
if hasattr(outputs, "text_embeds"):
pooled_outputs.append(outputs.text_embeds)
elif hasattr(outputs, "pooler_output"):
pooled_outputs.append(outputs.pooler_output)

if chunk_embeddings is not None and len(chunk_embeddings) > max_embeddings_multiples:
chunk_embeddings = chunk_embeddings[:max_embeddings_multiples]
if pooled_outputs is not None and len(pooled_outputs) > max_embeddings_multiples:
pooled_outputs = pooled_outputs[:max_embeddings_multiples]
text_encoder_output = torch.cat(chunk_embeddings, dim=1) if chunk_embeddings is not None else None
pooled_text_encoder_output = pooled_outputs[0] if pooled_outputs else None
if add_output:
if chunk_embeddings and len(chunk_embeddings) > max_embeddings_multiples:
chunk_embeddings = chunk_embeddings[:max_embeddings_multiples]
text_encoder_output = torch.cat(chunk_embeddings, dim=1)
if add_layer_norm:
final_layer_norm = text_encoder.text_model.final_layer_norm
text_encoder_output = final_layer_norm(text_encoder_output)
else:
text_encoder_output = None

if add_layer_norm and text_encoder_output is not None:
final_layer_norm = text_encoder.text_model.final_layer_norm
text_encoder_output = final_layer_norm(text_encoder_output)
if add_pooled_output:
if pooled_outputs and len(pooled_outputs) > max_embeddings_multiples:
pooled_outputs = pooled_outputs[:max_embeddings_multiples]
pooled_text_encoder_output = pooled_outputs[0] if pooled_outputs else None
else:
pooled_text_encoder_output = None

return text_encoder_output, pooled_text_encoder_output
return text_encoder_output, pooled_text_encoder_output

0 comments on commit e157d75

Please sign in to comment.