Commit a4b140a
move change to decode_attention_mask too
rsuderman committed Oct 26, 2024
1 parent e16d483 commit a4b140a
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions sharktank/sharktank/layers/causal_llm.py
@@ -95,10 +95,9 @@ def input_mask(
 
     def decode_attention_mask(self, boolean_input_mask: torch.Tensor):
         dtype = self.attention_dtype
-        numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype)
-        numeric_mask.masked_fill_(
-            boolean_input_mask, self._maximally_negative_value(dtype)
-        )
+        numeric_mask = torch.where(
+            boolean_input_mask, self._maximally_negative_value(dtype), 0
+        ).to(dtype)
         return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)
 
     def attention_mask(
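For context, the commit replaces an in-place masked_fill_ on a freshly allocated zeros tensor with a single functional torch.where call. Below is a minimal standalone sketch (not part of the commit) showing that the two forms produce the same mask; maximally_negative_value is a hypothetical stand-in for the model's self._maximally_negative_value helper, whose real implementation may differ.

import torch

def maximally_negative_value(dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical stand-in for self._maximally_negative_value(dtype):
    # the most negative finite value representable in `dtype`.
    return torch.tensor(torch.finfo(dtype).min, dtype=dtype)

boolean_input_mask = torch.tensor([[False, False, True]])  # True = masked out
dtype = torch.float16

# Old form: allocate zeros, then mutate them in place.
old = torch.zeros_like(boolean_input_mask, dtype=dtype)
old.masked_fill_(boolean_input_mask, maximally_negative_value(dtype))

# New form: one functional op, no in-place mutation.
new = torch.where(
    boolean_input_mask, maximally_negative_value(dtype), 0
).to(dtype)

assert torch.equal(old, new)

The functional form avoids in-place mutation, which tends to be friendlier to tracing and export pipelines, though the commit itself does not state its motivation.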
