Fix Bark batching feature (huggingface#27271)
* fix bark batching

* make style

* add tests and make style
ylacombe authored Nov 7, 2023
1 parent 8f840ed commit ac5d4cf
Showing 2 changed files with 82 additions and 14 deletions.
65 changes: 51 additions & 14 deletions src/transformers/models/bark/modeling_bark.py
@@ -909,8 +909,9 @@ def generate(
coarse_generation_config: BarkCoarseGenerationConfig = None,
codebook_size: int = 1024,
history_prompt: Optional[Dict[str, torch.Tensor]] = None,
return_output_lengths: Optional[bool] = None,
**kwargs,
) -> torch.LongTensor:
) -> Union[torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]]:
"""
Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
prompt.
@@ -926,8 +927,14 @@ def generate(
Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
history_prompt (`Optional[Dict[str,torch.Tensor]]`, *optional*):
Optional `Bark` speaker prompt.
return_output_lengths (`bool`, *optional*):
Whether or not to return the output lengths. Useful when batching.
Returns:
torch.LongTensor: Output coarse acoustics tokens.
By default:
torch.LongTensor: Output coarse acoustics tokens.
If `return_output_lengths=True`:
`Tuple(torch.Tensor, torch.Tensor)`: The output coarse acoustics tokens, and the length of each sample
of the batch.
"""

if semantic_generation_config is None:
@@ -954,13 +961,13 @@ def generate(
)
max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))

# beware, depends on the seq_len of the longest sequence of the batch.
# Also, the seq_len might be one token too long because of an added
# pad_token as compared to Bark original implementation.
max_generated_len = np.floor(
semantic_output.shape[1] * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
output_lengths = torch.floor(
output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
)
max_generated_len = int(round(max_generated_len * coarse_generation_config.n_coarse_codebooks))
output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()

max_generated_len = torch.max(output_lengths).item()

batch_size = semantic_output.shape[0]

@@ -1026,6 +1033,9 @@ def generate(

coarse_output = x_coarse[:, len_coarse_history:]

if return_output_lengths:
return coarse_output, output_lengths

return coarse_output
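
The core of the batching fix is that the number of coarse tokens to generate is now derived per sample from the count of non-pad semantic tokens, rather than from the padded sequence length of the whole batch. Below is a minimal sketch of that length computation, with made-up config values (pad token id, semantic-to-coarse ratio, codebook count) standing in for the real generation configs:

```python
import torch

# Toy stand-ins for values that really come from the Bark generation configs.
coarse_semantic_pad_token = 10_000
semantic_to_coarse_ratio = 3.0  # coarse_rate_hz * n_coarse_codebooks / semantic_rate_hz
n_coarse_codebooks = 2

# Padded batch of semantic tokens (batch_size=2); the second sample is shorter.
semantic_output = torch.tensor(
    [
        [1, 2, 3, 4, 5, 6],
        [7, 8, 9, coarse_semantic_pad_token, coarse_semantic_pad_token, coarse_semantic_pad_token],
    ]
)

# Count non-pad tokens per sample, then convert that semantic length into a
# coarse length, rounding the same way the patch does.
output_lengths = (semantic_output != coarse_semantic_pad_token).sum(1)
output_lengths = torch.floor(output_lengths * semantic_to_coarse_ratio / n_coarse_codebooks)
output_lengths = torch.round(output_lengths * n_coarse_codebooks).int()

# The generation loop now runs up to the longest per-sample length instead of a
# single length derived from the padded sequence dimension.
max_generated_len = torch.max(output_lengths).item()
print(output_lengths.tolist(), max_generated_len)  # [18, 8] 18 with the toy numbers above
```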


@@ -1502,13 +1512,21 @@ def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
# We'll offload the last model manually.
self.codec_model_hook = hook

def codec_decode(self, fine_output):
def codec_decode(self, fine_output, output_lengths=None):
"""Turn quantized audio codes into audio array using encodec."""

fine_output = fine_output.transpose(0, 1)
emb = self.codec_model.quantizer.decode(fine_output)
out = self.codec_model.decoder(emb)
audio_arr = out.squeeze(1) # squeeze the codebook dimension

if output_lengths is not None:
# encodec uses LSTMs which behave differently with appended padding
# decoding with encodec takes around 0.1% of the total generation time
# to keep generation quality, we break batching
out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)]
audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out]
else:
out = self.codec_model.decoder(emb)
audio_arr = out.squeeze(1) # squeeze the codebook dimension

return audio_arr
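
The second half of the fix trims each sample to its true length before it reaches the EnCodec decoder, because the decoder's LSTM layers produce different outputs when trailing padding is fed through them. Below is a minimal sketch of the per-sample decoding pattern, with a toy stand-in for `self.codec_model.decoder`:

```python
import torch
from torch import nn

# Toy stand-in for self.codec_model.decoder; the real EnCodec decoder contains
# LSTM layers whose outputs change when trailing padding is passed through them,
# which is why the batch is split and trimmed here.
decoder = nn.Conv1d(in_channels=8, out_channels=1, kernel_size=1)

emb = torch.randn(2, 8, 120)              # (batch_size, channels, padded_frames)
output_lengths = torch.tensor([120, 90])  # true frame count per sample

# Decode each sample on its own, trimmed to its real length, so padding never
# reaches the decoder.
out = [sample[:, :length].unsqueeze(0) for sample, length in zip(emb, output_lengths)]
audio_arr = [decoder(sample).squeeze() for sample in out]
print([tuple(a.shape) for a in audio_arr])  # waveforms of different lengths: (120,) and (90,)
```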

@@ -1517,6 +1535,7 @@ def generate(
self,
input_ids: Optional[torch.Tensor] = None,
history_prompt: Optional[Dict[str, torch.Tensor]] = None,
return_output_lengths: Optional[bool] = None,
**kwargs,
) -> torch.LongTensor:
"""
@@ -1535,9 +1554,15 @@ def generate(
semantic, coarse and fine respectively. It has the priority over the keywords without a prefix.
This means you can, for example, specify a generation strategy for all sub-models except one.
return_output_lengths (`bool`, *optional*):
Whether or not to return the waveform lengths. Useful when batching.
Returns:
torch.LongTensor: Output generated audio.
By default:
- **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
When `return_output_lengths=True`:
Returns a tuple made of:
- **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
- **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch
Example:
```python
@@ -1603,9 +1628,16 @@ def generate(
semantic_generation_config=semantic_generation_config,
coarse_generation_config=coarse_generation_config,
codebook_size=self.generation_config.codebook_size,
return_output_lengths=return_output_lengths,
**kwargs_coarse,
)

output_lengths = None
if return_output_lengths:
coarse_output, output_lengths = coarse_output
# (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len)
output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks

# 3. "generate" from the fine model
output = self.fine_acoustics.generate(
coarse_output,
@@ -1625,10 +1657,15 @@ def generate(
self.codec_model = self.codec_model.to(self.device)

# 4. Decode the output and generate audio array
audio = self.codec_decode(output)
audio = self.codec_decode(output, output_lengths)

if getattr(self, "codec_model_hook", None) is not None:
# Offload codec_model to CPU
self.codec_model_hook.offload()

if return_output_lengths:
output_lengths = [len(sample) for sample in audio]
audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
return audio, output_lengths

return audio
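
Taken together, passing `return_output_lengths=True` through `BarkModel.generate` yields a zero-padded waveform batch plus one length per sample. A hedged usage sketch follows; the checkpoint, texts, and voice preset are illustrative (the texts mirror the test below):

```python
from transformers import AutoProcessor, BarkModel

# Illustrative checkpoint and inputs; any Bark checkpoint and voice preset work the same way.
processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

inputs = processor(
    ["I love HuggingFace", "In the light of the moon, a little egg lay on a leaf"],
    voice_preset="en_speaker_6",
)

# With return_output_lengths=True, generate returns the padded waveform batch
# together with one length per sample, so each waveform can be trimmed before use.
audio, lengths = model.generate(**inputs, return_output_lengths=True)
waveforms = [audio[i, : lengths[i]].cpu().numpy() for i in range(audio.shape[0])]
```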
31 changes: 31 additions & 0 deletions tests/models/bark/test_modeling_bark.py
@@ -1067,6 +1067,37 @@ def test_generate_end_to_end_with_args(self):
self.model.generate(**input_ids, do_sample=True, temperature=0.6, penalty_alpha=0.6)
self.model.generate(**input_ids, do_sample=True, temperature=0.6, num_beams=4)

@slow
def test_generate_batching(self):
args = {"do_sample": False, "temperature": None}

s1 = "I love HuggingFace"
s2 = "In the light of the moon, a little egg lay on a leaf"
voice_preset = "en_speaker_6"
input_ids = self.processor([s1, s2], voice_preset=voice_preset).to(torch_device)

# generate in batch
outputs, audio_lengths = self.model.generate(**input_ids, **args, return_output_lengths=True)

# generate one-by-one
s1 = self.processor(s1, voice_preset=voice_preset).to(torch_device)
s2 = self.processor(s2, voice_preset=voice_preset).to(torch_device)
output1 = self.model.generate(**s1, **args)
output2 = self.model.generate(**s2, **args)

# up until the coarse acoustic model (included), results are the same
# the fine acoustic model introduces small differences
# first verify if same length (should be the same because it's decided in the coarse model)
self.assertEqual(tuple(audio_lengths), (output1.shape[1], output2.shape[1]))

# then assert almost equal
self.assertTrue(torch.allclose(outputs[0, : audio_lengths[0]], output1.squeeze(), atol=2e-3))
self.assertTrue(torch.allclose(outputs[1, : audio_lengths[1]], output2.squeeze(), atol=2e-3))

# now test single input with return_output_lengths = True
outputs, _ = self.model.generate(**s1, **args, return_output_lengths=True)
self.assertTrue((outputs == output1).all().item())

@slow
def test_generate_end_to_end_with_sub_models_args(self):
input_ids = self.inputs
