diff --git a/models/med.py b/models/med.py
index 7b00a354..d960323e 100644
--- a/models/med.py
+++ b/models/med.py
@@ -929,7 +929,7 @@ def forward(
             cross_attentions=outputs.cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **kwargs):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -943,8 +943,8 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "past_key_values": past,
-            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
-            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+            "encoder_hidden_states": kwargs.get("encoder_hidden_states", None),
+            "encoder_attention_mask": kwargs.get("encoder_attention_mask", None),
             "is_decoder": True,
         }
 
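Why the rename matters (a hedged reading, not stated in the diff itself): the change from `**model_kwargs` to `**kwargs` appears to target the kwargs validation that newer `transformers` releases run inside `GenerationMixin.generate()`. There, `_validate_model_kwargs` inspects the signature of `prepare_inputs_for_generation`, and in the releases this patch seems aimed at, extra keyword arguments such as `encoder_hidden_states` and `encoder_attention_mask` are only let through when the catch-all parameter is literally named `kwargs`; with `**model_kwargs` they were rejected as unused model kwargs. The sketch below paraphrases that check with `inspect`; the `before`/`after` functions are hypothetical stand-ins for the two signatures, not code from this repo.

```python
import inspect

# Hypothetical stand-ins for the old and new signatures in models/med.py.
def before(self, input_ids, past=None, attention_mask=None, **model_kwargs): ...
def after(self, input_ids, past=None, attention_mask=None, **kwargs): ...

for fn in (before, after):
    params = set(inspect.signature(fn).parameters)
    # Paraphrase (assumption) of the transformers validation step: extra
    # generate() kwargs are accepted only if the prepared-inputs signature
    # contains a var-keyword parameter literally named "kwargs".
    accepted = "kwargs" in params
    print(f"{fn.__name__}: extra generate() kwargs accepted -> {accepted}")
    # before: False  -> encoder_hidden_states would be flagged as unused
    # after:  True   -> encoder_hidden_states passes through to the decoder
```

Since Python itself ignores the name of a var-keyword parameter at call time, the rename is behavior-neutral for direct callers; it only changes what signature inspection sees.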