From 2fbbcf5007509c66b02924ce6dcff66f58e7f58c Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Tue, 30 Jul 2024 16:00:13 +0200 Subject: [PATCH] Fix M4T for ASR pipeline (#32296) * tentative fix * do the same for M4T --- src/transformers/models/seamless_m4t/modeling_seamless_m4t.py | 1 + .../models/seamless_m4t_v2/modeling_seamless_m4t_v2.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 0fe1e9f7efa..a79d1d4cf2b 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -3154,6 +3154,7 @@ def generate( """ text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + input_features = input_features if input_features is not None else kwargs.pop("inputs") if tgt_lang is not None: inputs = kwargs.get("input_embeds") if input_features is None else input_features inputs = ( diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 88a8ab466b2..a53f544bb34 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -3422,6 +3422,7 @@ def generate( """ text_decoder_input_ids = kwargs.pop("decoder_input_ids", None) # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. + input_features = input_features if input_features is not None else kwargs.pop("inputs") if tgt_lang is not None: inputs = kwargs.get("input_embeds") if input_features is None else input_features inputs = (