diff --git a/modules/model/FluxModel.py b/modules/model/FluxModel.py
index ac7cf1b1..42e24824 100644
--- a/modules/model/FluxModel.py
+++ b/modules/model/FluxModel.py
@@ -224,8 +224,8 @@ def encode_text(
         tokenizer_output = self.tokenizer_2(
             text,
             # padding='max_length',
-            truncation=True,
-            max_length=4096,
+            truncation=False,
+            max_length=99999999,
             return_tensors="pt",
         )
         tokens_2 = tokenizer_output.input_ids.to(self.text_encoder_2.device)
diff --git a/modules/model/util/clip_util.py b/modules/model/util/clip_util.py
index c6695ba9..e2f5c3ac 100644
--- a/modules/model/util/clip_util.py
+++ b/modules/model/util/clip_util.py
@@ -24,8 +24,8 @@ def encode_clip(
         return None, None

     chunks = [tokens[:, i:i + chunk_length] for i in range(0, tokens.shape[1], chunk_length)]
-    chunk_embeddings = [] if add_output else None
-    pooled_outputs = [] if add_pooled_output else None
+    chunk_embeddings = []
+    pooled_outputs = []

     for i, chunk in enumerate(chunks):
         if chunk.numel() == 0:
@@ -59,22 +59,27 @@ def encode_clip(

         if add_output:
             embedding = outputs.hidden_states[default_layer - layer_skip]
             chunk_embeddings.append(embedding)
-        if add_pooled_output:
-            if hasattr(text_encoder_output, "text_embeds"):
-                pooled_outputs.append(text_encoder_output.text_embeds)
-            if hasattr(text_encoder_output, "pooler_output"):
-                pooled_outputs.append(text_encoder_output.pooler_output)
+        if hasattr(outputs, "text_embeds"):
+            pooled_outputs.append(outputs.text_embeds)
+        elif hasattr(outputs, "pooler_output"):
+            pooled_outputs.append(outputs.pooler_output)

-    if chunk_embeddings is not None and len(chunk_embeddings) > max_embeddings_multiples:
-        chunk_embeddings = chunk_embeddings[:max_embeddings_multiples]
-    if pooled_outputs is not None and len(pooled_outputs) > max_embeddings_multiples:
-        pooled_outputs = pooled_outputs[:max_embeddings_multiples]
-    text_encoder_output = torch.cat(chunk_embeddings, dim=1) if chunk_embeddings is not None else None
-    pooled_text_encoder_output = pooled_outputs[0] if pooled_outputs else None
+    if add_output:
+        if chunk_embeddings and len(chunk_embeddings) > max_embeddings_multiples:
+            chunk_embeddings = chunk_embeddings[:max_embeddings_multiples]
+        text_encoder_output = torch.cat(chunk_embeddings, dim=1)
+        if add_layer_norm:
+            final_layer_norm = text_encoder.text_model.final_layer_norm
+            text_encoder_output = final_layer_norm(text_encoder_output)
+    else:
+        text_encoder_output = None

-    if add_layer_norm and text_encoder_output is not None:
-        final_layer_norm = text_encoder.text_model.final_layer_norm
-        text_encoder_output = final_layer_norm(text_encoder_output)
+    if add_pooled_output:
+        if pooled_outputs and len(pooled_outputs) > max_embeddings_multiples:
+            pooled_outputs = pooled_outputs[:max_embeddings_multiples]
+        pooled_text_encoder_output = pooled_outputs[0] if pooled_outputs else None
+    else:
+        pooled_text_encoder_output = None

-    return text_encoder_output, pooled_text_encoder_output
+    return text_encoder_output, pooled_text_encoder_output
\ No newline at end of file