From 1857ebe4b59c00a6ee44e11640cedce904fa67e9 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Fri, 25 Oct 2024 21:48:50 -0700
Subject: [PATCH] move change to decode_attention_mask too

---
 sharktank/sharktank/layers/causal_llm.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py
index 7213b76fc..7a09995a8 100644
--- a/sharktank/sharktank/layers/causal_llm.py
+++ b/sharktank/sharktank/layers/causal_llm.py
@@ -95,10 +95,9 @@ def input_mask(
 
     def decode_attention_mask(self, boolean_input_mask: torch.Tensor):
         dtype = self.attention_dtype
-        numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype)
-        numeric_mask.masked_fill_(
-            boolean_input_mask, self._maximally_negative_value(dtype)
-        )
+        numeric_mask = torch.where(
+            boolean_input_mask, self._maximally_negative_value(dtype), 0
+        ).to(dtype)
         return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)
 
     def attention_mask(
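
Note (not part of the patch): the hunk above replaces an in-place zeros_like + masked_fill_ construction with a single functional torch.where call, which is generally friendlier to tracing and export since it avoids in-place mutation. A minimal equivalence sketch follows; the mask shape, dtype, and the stand-in for self._maximally_negative_value are illustrative assumptions, not taken from the repository.

import torch

# Illustrative sketch only; shape and dtype are assumed for demonstration.
dtype = torch.float16
boolean_input_mask = torch.tensor([[False, False, True, True]])
neg = torch.finfo(dtype).min  # stand-in for self._maximally_negative_value(dtype)

# Old formulation: allocate zeros, then fill masked positions in place.
old = torch.zeros_like(boolean_input_mask, dtype=dtype)
old.masked_fill_(boolean_input_mask, neg)

# New formulation: one functional op, no in-place mutation.
new = torch.where(boolean_input_mask, neg, 0).to(dtype)

assert torch.equal(old, new)
print(new.unsqueeze(1).unsqueeze(1).shape)  # torch.Size([1, 1, 1, 4]), broadcastable over heads/queries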