-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcollator.py
64 lines (50 loc) · 2.05 KB
/
collator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from speaker_speech import dataset_creator
import torch
from model import model, processor
@dataclass
class TTSDataCollatorWithPadding:
    """Collate individual text-to-speech examples into one padded batch.

    Each feature passed to the collator is a mapping that provides the keys
    ``"input_ids"``, ``"labels"`` and ``"speaker_embeddings"``.
    """

    # Object whose ``pad`` method batches and pads the inputs and the labels.
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Pad ``features`` into a batch and attach the speaker embeddings.

        Args:
            features: Examples to batch; each one must contain the keys
                ``"input_ids"``, ``"labels"`` and ``"speaker_embeddings"``.

        Returns:
            The batch produced by ``self.processor.pad`` with
            ``"decoder_attention_mask"`` removed, padded label positions set
            to -100, and ``"speaker_embeddings"`` added as a tensor.
        """
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        label_features = [{"input_values": feature["labels"]} for feature in features]
        speaker_features = [feature["speaker_embeddings"] for feature in features]

        # Collate the inputs and targets into a batch. Use the processor this
        # collator was constructed with, not the module-level one, so that the
        # ``processor`` field is actually honoured.
        batch = self.processor.pad(
            input_ids=input_ids, labels=label_features, return_tensors="pt"
        )

        # Replace padding with -100 so those positions are ignored by the loss.
        batch["labels"] = batch["labels"].masked_fill(
            batch.decoder_attention_mask.unsqueeze(-1).ne(1), -100
        )

        # Not used during fine-tuning.
        del batch["decoder_attention_mask"]

        # Round target lengths down to a multiple of the reduction factor and
        # truncate the labels along dimension 1 to the longest rounded length.
        reduction_factor = model.config.reduction_factor
        if reduction_factor > 1:
            target_lengths = [
                len(feature["input_values"]) for feature in label_features
            ]
            max_length = max(
                length - length % reduction_factor for length in target_lengths
            )
            batch["labels"] = batch["labels"][:, :max_length]

        # Also add in the speaker embeddings.
        batch["speaker_embeddings"] = torch.tensor(speaker_features)

        return batch
# Module-level collator instance built around the imported processor.
data_collator = TTSDataCollatorWithPadding(processor=processor)


if __name__ == "__main__":
    # Smoke run: build the dataset, split it with test_size=0.1, and collate
    # three training examples into a single batch.
    dataset = dataset_creator().train_test_split(test_size=0.1)
    features = [dataset["train"][index] for index in (0, 1, 20)]
    batch = data_collator(features)