From 1268ee567c07ff98d2271eafeb7d5a3c2670e344 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 25 Oct 2024 20:49:54 -0700 Subject: [PATCH] Addition of booleans is currently wrong in iree-compile Addition of booleans is performing xor which breaks causal mapping. --- sharktank/sharktank/layers/causal_llm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 63c58e860..7213b76fc 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -127,9 +127,10 @@ def attention_mask( dtype = self.attention_dtype _, batch_seq_len = input_mask.shape causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len] - boolean_mask = causal_mask + input_mask[:, None, None, :] - numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype) - numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype)) + boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :]) + numeric_mask = torch.where( + boolean_mask, self._maximally_negative_value(dtype), 0 + ).to(dtype) return numeric_mask.to(self.device) def extract_tokens_from_logits(