diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index 9b3870e825050f..2c04f15c045696 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -909,8 +909,9 @@ def generate(
         coarse_generation_config: BarkCoarseGenerationConfig = None,
         codebook_size: int = 1024,
         history_prompt: Optional[Dict[str, torch.Tensor]] = None,
+        return_output_lengths: Optional[bool] = None,
         **kwargs,
-    ) -> torch.LongTensor:
+    ) -> Union[torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]]:
         """
         Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
         prompt.
@@ -926,8 +927,14 @@ def generate(
                 Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
             history_prompt (`Optional[Dict[str,torch.Tensor]]`, *optional*):
                 Optional `Bark` speaker prompt.
+            return_output_lengths (`bool`, *optional*):
+                Whether or not to return the output lengths. Useful when batching.
         Returns:
-            torch.LongTensor: Output coarse acoustics tokens.
+            By default:
+                torch.LongTensor: Output coarse acoustics tokens.
+            If `return_output_lengths=True`:
+                `Tuple(torch.LongTensor, torch.LongTensor)`: The output coarse acoustics tokens, and the length of
+                each sample of the batch.
         """

         if semantic_generation_config is None:
@@ -954,13 +961,13 @@ def generate(
         )
         max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))

-        # beware, depends on the seq_len of the longest sequence of the batch.
-        # Also, the seq_len might be one token too long because of an added
-        # pad_token as compared to Bark original implementation.
-        max_generated_len = np.floor(
-            semantic_output.shape[1] * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
+        output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
+        output_lengths = torch.floor(
+            output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
         )
-        max_generated_len = int(round(max_generated_len * coarse_generation_config.n_coarse_codebooks))
+        output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()
+
+        max_generated_len = torch.max(output_lengths).item()

         batch_size = semantic_output.shape[0]

@@ -1026,6 +1033,9 @@ def generate(

         coarse_output = x_coarse[:, len_coarse_history:]

+        if return_output_lengths:
+            return coarse_output, output_lengths
+
         return coarse_output


@@ -1502,13 +1512,21 @@ def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
         # We'll offload the last model manually.
         self.codec_model_hook = hook

-    def codec_decode(self, fine_output):
+    def codec_decode(self, fine_output, output_lengths=None):
         """Turn quantized audio codes into audio array using encodec."""
         fine_output = fine_output.transpose(0, 1)
         emb = self.codec_model.quantizer.decode(fine_output)
-        out = self.codec_model.decoder(emb)
-        audio_arr = out.squeeze(1)  # squeeze the codebook dimension
+
+        if output_lengths is not None:
+            # encodec uses LSTMs, which behave differently with appended padding.
+            # Since decoding with encodec takes around 0.1% of the total generation time,
+            # we break batching here to keep generation quality.
+            out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)]
+            audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out]
+        else:
+            out = self.codec_model.decoder(emb)
+            audio_arr = out.squeeze(1)  # squeeze the codebook dimension

         return audio_arr

@@ -1517,6 +1535,7 @@ def generate(
         self,
         input_ids: Optional[torch.Tensor] = None,
         history_prompt: Optional[Dict[str, torch.Tensor]] = None,
+        return_output_lengths: Optional[bool] = None,
         **kwargs,
     ) -> torch.LongTensor:
         """
@@ -1535,9 +1554,15 @@ def generate(
                 semantic, coarse and fine respectively. It has the priority over the keywords without a prefix.
                 This means you can, for example, specify a generation strategy for all sub-models except one.
+            return_output_lengths (`bool`, *optional*):
+                Whether or not to return the waveform lengths. Useful when batching.
         Returns:
-            torch.LongTensor: Output generated audio.
-
+            By default:
+                - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
+            When `return_output_lengths=True`:
+                Returns a tuple made of:
+                - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
+                - **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch.
         Example:

         ```python
@@ -1603,9 +1628,16 @@ def generate(
             semantic_generation_config=semantic_generation_config,
             coarse_generation_config=coarse_generation_config,
             codebook_size=self.generation_config.codebook_size,
+            return_output_lengths=return_output_lengths,
             **kwargs_coarse,
         )

+        output_lengths = None
+        if return_output_lengths:
+            coarse_output, output_lengths = coarse_output
+            # lengths in coarse tokens (seq_len * n_coarse_codebooks) -> lengths in frames (seq_len)
+            output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks
+
         # 3. "generate" from the fine model
         output = self.fine_acoustics.generate(
             coarse_output,
@@ -1625,10 +1657,15 @@ def generate(
             self.codec_model = self.codec_model.to(self.device)

         # 4. Decode the output and generate audio array
-        audio = self.codec_decode(output)
+        audio = self.codec_decode(output, output_lengths)

         if getattr(self, "codec_model_hook", None) is not None:
             # Offload codec_model to CPU
             self.codec_model_hook.offload()

+        if return_output_lengths:
+            output_lengths = [len(sample) for sample in audio]
+            audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
+            return audio, output_lengths
+
         return audio
diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 4186a72628e34e..bf13203ecd405a 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -1067,6 +1067,37 @@ def test_generate_end_to_end_with_args(self):
         self.model.generate(**input_ids, do_sample=True, temperature=0.6, penalty_alpha=0.6)
         self.model.generate(**input_ids, do_sample=True, temperature=0.6, num_beams=4)

+    @slow
+    def test_generate_batching(self):
+        args = {"do_sample": False, "temperature": None}
+
+        s1 = "I love HuggingFace"
+        s2 = "In the light of the moon, a little egg lay on a leaf"
+        voice_preset = "en_speaker_6"
+        input_ids = self.processor([s1, s2], voice_preset=voice_preset).to(torch_device)
+
+        # generate in batch
+        outputs, audio_lengths = self.model.generate(**input_ids, **args, return_output_lengths=True)
+
+        # generate one-by-one
+        s1 = self.processor(s1, voice_preset=voice_preset).to(torch_device)
+        s2 = self.processor(s2, voice_preset=voice_preset).to(torch_device)
+        output1 = self.model.generate(**s1, **args)
+        output2 = self.model.generate(**s2, **args)
+
+        # up until the coarse acoustic model (included), results are the same
+        # the fine acoustic model introduces small differences
+        # first verify that the lengths match (they should, since length is decided in the coarse model)
+        self.assertEqual(tuple(audio_lengths), (output1.shape[1], output2.shape[1]))
+
+        # then assert that the waveforms are almost equal
+        self.assertTrue(torch.allclose(outputs[0, : audio_lengths[0]], output1.squeeze(), atol=2e-3))
+        self.assertTrue(torch.allclose(outputs[1, : audio_lengths[1]], output2.squeeze(), atol=2e-3))
+
+        # now test single input with return_output_lengths=True
+        outputs, _ = self.model.generate(**s1, **args, return_output_lengths=True)
+        self.assertTrue((outputs == output1).all().item())
+
     @slow
     def test_generate_end_to_end_with_sub_models_args(self):
         input_ids = self.inputs
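
For context, a minimal usage sketch of the `return_output_lengths` flag this diff introduces. The checkpoint `"suno/bark-small"` and voice preset `"v2/en_speaker_6"` are existing Bark assets on the Hugging Face Hub; any Bark checkpoint and preset should work, and device placement is omitted for brevity.

```python
# Usage sketch for the new `return_output_lengths` flag added in this diff.
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

sentences = [
    "I love HuggingFace",
    "In the light of the moon, a little egg lay on a leaf",
]
inputs = processor(sentences, voice_preset="v2/en_speaker_6")

# The batched waveforms come back right-padded to the longest sample;
# output_lengths holds each sample's true length so padding can be stripped.
audio, output_lengths = model.generate(**inputs, return_output_lengths=True)
clean_waveforms = [wav[:length] for wav, length in zip(audio, output_lengths)]
```

As the new `test_generate_batching` test checks, batched generation trimmed to `output_lengths` should match per-sample generation up to the small numerical differences introduced by the fine acoustics model.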