Skip to content

Commit

Permalink
Addition of booleans is currently wrong in iree-compile
Browse files Browse the repository at this point in the history
Addition of booleans is performing xor which breaks causal mapping.
  • Loading branch information
rsuderman committed Oct 28, 2024
1 parent ffb146b commit 1268ee5
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions sharktank/sharktank/layers/causal_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ def attention_mask(
dtype = self.attention_dtype
_, batch_seq_len = input_mask.shape
causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len]
boolean_mask = causal_mask + input_mask[:, None, None, :]
numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype)
numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype))
boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :])
numeric_mask = torch.where(
boolean_mask, self._maximally_negative_value(dtype), 0
).to(dtype)
return numeric_mask.to(self.device)

def extract_tokens_from_logits(
Expand Down

0 comments on commit 1268ee5

Please sign in to comment.