Fix a few things in text enc code for models with no eos token.
comfyanonymous committed Dec 11, 2024
1 parent (1c8d11e), commit 44db978
Showing 1 changed file with 15 additions and 4 deletions.
comfy/sd1_clip.py: 19 changes (15 additions & 4 deletions)
Expand Up @@ -199,11 +199,18 @@ def forward(self, tokens):
attention_mask = None
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
attention_mask = torch.zeros_like(tokens)
end_token = self.special_tokens.get("end", -1)
end_token = self.special_tokens.get("end", None)
if end_token is None:
cmp_token = self.special_tokens.get("pad", -1)
else:
cmp_token = end_token

for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == end_token:
if tokens[x, y] == cmp_token:
if end_token is None:
attention_mask[x, y] = 0
break

attention_mask_model = None
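Note: with this hunk, a model that defines no "end" special token falls back to comparing against its "pad" token, and the matched pad position is zeroed rather than kept in the mask. A minimal standalone sketch of the resulting behavior (the special_tokens dict and token ids below are made up for illustration, not taken from the commit):

import torch

# Hypothetical model with a "pad" token but no "end" token (assumed for illustration).
special_tokens = {"start": 49406, "pad": 0}
tokens = torch.tensor([[49406, 320, 1125, 0, 0]])  # start, two word tokens, padding

attention_mask = torch.zeros_like(tokens)
end_token = special_tokens.get("end", None)
cmp_token = special_tokens.get("pad", -1) if end_token is None else end_token

for x in range(attention_mask.shape[0]):
    for y in range(attention_mask.shape[1]):
        attention_mask[x, y] = 1
        if tokens[x, y] == cmp_token:
            if end_token is None:
                attention_mask[x, y] = 0  # pad carries no content, so mask it out
            break

print(attention_mask)  # tensor([[1, 1, 1, 0, 0]])

When an end token does exist, the old behavior is preserved: the end token's position is set to 1 before the loop breaks, so it stays attended.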
@@ -522,10 +529,14 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch
             is_large = len(t_group) >= self.max_word_length
+            if self.end_token is not None:
+                has_end_token = 1
+            else:
+                has_end_token = 0

             while len(t_group) > 0:
-                if len(t_group) + len(batch) > self.max_length - 1:
-                    remaining_length = self.max_length - len(batch) - 1
+                if len(t_group) + len(batch) > self.max_length - has_end_token:
+                    remaining_length = self.max_length - len(batch) - has_end_token
                     #break word in two and add end token
                     if is_large:
                         batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
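Note: the tokenizer hunk stops reserving a trailing slot for the end token when the model has none, so such models get one extra usable position per batch. A toy check of the capacity arithmetic (max_length and the token counts below are assumed values for illustration):

# Capacity check with and without an end token (toy numbers).
max_length = 77      # e.g. CLIP's context length
len_batch = 70       # tokens already queued in the current batch
len_t_group = 10     # tokens in the word group being placed

for end_token in (49407, None):  # with an EOS id, and without one
    has_end_token = 1 if end_token is not None else 0
    if len_t_group + len_batch > max_length - has_end_token:
        remaining_length = max_length - len_batch - has_end_token
        print(end_token, remaining_length)  # 49407 -> 6 free slots; None -> 7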
