From 1268ee567c07ff98d2271eafeb7d5a3c2670e344 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 25 Oct 2024 20:49:54 -0700 Subject: [PATCH 1/2] Addition of booleans is currently wrong in iree-compile Addition of booleans is performing xor which breaks causal mapping. --- sharktank/sharktank/layers/causal_llm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 63c58e860..7213b76fc 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -127,9 +127,10 @@ def attention_mask( dtype = self.attention_dtype _, batch_seq_len = input_mask.shape causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len] - boolean_mask = causal_mask + input_mask[:, None, None, :] - numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype) - numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype)) + boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :]) + numeric_mask = torch.where( + boolean_mask, self._maximally_negative_value(dtype), 0 + ).to(dtype) return numeric_mask.to(self.device) def extract_tokens_from_logits( From 1857ebe4b59c00a6ee44e11640cedce904fa67e9 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 25 Oct 2024 21:48:50 -0700 Subject: [PATCH 2/2] move change to decode_attention_mask too --- sharktank/sharktank/layers/causal_llm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 7213b76fc..7a09995a8 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -95,10 +95,9 @@ def input_mask( def decode_attention_mask(self, boolean_input_mask: torch.Tensor): dtype = self.attention_dtype - numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype) - numeric_mask.masked_fill_( - boolean_input_mask, self._maximally_negative_value(dtype) - ) + numeric_mask = torch.where( + boolean_input_mask, self._maximally_negative_value(dtype), 0 + ).to(dtype) return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device) def attention_mask(