diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py
index 7213b76fc..7a09995a8 100644
--- a/sharktank/sharktank/layers/causal_llm.py
+++ b/sharktank/sharktank/layers/causal_llm.py
@@ -95,10 +95,9 @@ def input_mask(
 
     def decode_attention_mask(self, boolean_input_mask: torch.Tensor):
         dtype = self.attention_dtype
-        numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype)
-        numeric_mask.masked_fill_(
-            boolean_input_mask, self._maximally_negative_value(dtype)
-        )
+        numeric_mask = torch.where(
+            boolean_input_mask, self._maximally_negative_value(dtype), 0
+        ).to(dtype)
         return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)
 
     def attention_mask(
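
The change replaces an in-place `masked_fill_` on a zero-initialized tensor with a single out-of-place `torch.where`. Below is a minimal standalone sketch (not part of the diff) checking that both forms produce the same numeric mask; it assumes `_maximally_negative_value(dtype)` behaves like `torch.finfo(dtype).min`, which is an approximation for illustration only.

```python
import torch

dtype = torch.float16
boolean_input_mask = torch.tensor([[False, False, True, True]])
neg = torch.finfo(dtype).min  # stand-in for self._maximally_negative_value(dtype)

# Old approach: allocate zeros, then fill masked positions in place.
old_mask = torch.zeros_like(boolean_input_mask, dtype=dtype)
old_mask.masked_fill_(boolean_input_mask, neg)

# New approach: a single torch.where, then cast to the attention dtype.
new_mask = torch.where(boolean_input_mask, neg, 0).to(dtype)

assert torch.equal(old_mask, new_mask)
```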