diff --git a/onmt/modules/multi_headed_attn.py b/onmt/modules/multi_headed_attn.py index 2510afef83..9ff2ad2277 100644 --- a/onmt/modules/multi_headed_attn.py +++ b/onmt/modules/multi_headed_attn.py @@ -483,6 +483,12 @@ def forward( if sliding_window > 0 and key.size(2) > sliding_window: key = key[:, :, 1:, :] value = value[:, :, 1:, :] + if key_pad_mask is not None and step == 0: + x = key_pad_mask + x = x.expand(-1, self.head_count // self.parallel_gpu, -1) + x = x.unsqueeze(3) + x = x.expand(-1, -1, -1, 128) + value = value.masked_fill(x, 0) self.layer_cache[1]["keys"] = key self.layer_cache[1]["values"] = value