Pre-tokenization option is added to Tatoeba dataset
bgunyel committed Dec 15, 2024
1 parent 6779531 commit 3d5e496
Showing 2 changed files with 119 additions and 55 deletions.
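For context on the change below: when is_pre_tokenized is True, __init__ batch-encodes every sentence once and builds fixed-length id rows of the form BOS, token ids shifted right by one, EOS, then padding. A minimal sketch of that construction, using made-up token ids and a toy length of 8 in place of the real tokenizers and model_max_length:

import numpy as np

# Toy stand-ins; the real code takes these ids from the tokenizers.
bos_id, eos_id, pad_id = 1, 2, 0

# Pretend batch_encode_plus returned this (no special tokens, padded to length 8).
input_ids = np.array([[11, 12, 13, 0, 0, 0, 0, 0]])
attention_mask = np.array([[1, 1, 1, 0, 0, 0, 0, 0]])

ids = np.ones_like(input_ids) * pad_id        # start from an all-pad matrix
ids[:, 0] = bos_id                            # BOS in the first slot
ids[:, 1:] = input_ids[:, :-1]                # token ids shifted right by one
counts = attention_mask.sum(axis=1) + 1       # slot right after the last real token
ids[np.arange(len(counts)), counts] = eos_id  # write EOS into that slot
print(ids)                                    # [[ 1 11 12 13  2  0  0  0]]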
168 changes: 115 additions & 53 deletions source/ml/datasets/tatoeba.py
@@ -2,6 +2,7 @@
import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -16,19 +17,73 @@ def __init__(self,
source_tokenizer: [PreTrainedTokenizer, PreTrainedTokenizerFast],
target_tokenizer: [PreTrainedTokenizer, PreTrainedTokenizerFast],
bos_token: str,
eos_token: str) -> None:
eos_token: str,
is_pre_tokenized: bool = False) -> None:
super().__init__()
self.is_pre_tokenized = is_pre_tokenized
self.source_sentences = source_sentences
self.target_sentences = target_sentences
self.source_language = source_language
self.target_language = target_language
self.source_tokenizer = source_tokenizer
self.target_tokenizer = target_tokenizer

(self.source_tokenizer_bos_token_id,
self.source_tokenizer_eos_token_id) = self.source_tokenizer.convert_tokens_to_ids([bos_token, eos_token])
(self.target_tokenizer_bos_token_id,
self.target_tokenizer_eos_token_id) = self.target_tokenizer.convert_tokens_to_ids([bos_token, eos_token])
(
self.source_tokenizer_bos_token_id,
self.source_tokenizer_eos_token_id,
self.source_tokenizer_pad_token_id
) = self.source_tokenizer.convert_tokens_to_ids([bos_token,
eos_token,
source_tokenizer.special_tokens_map['pad_token']])

(
self.target_tokenizer_bos_token_id,
self.target_tokenizer_eos_token_id,
self.target_tokenizer_pad_token_id
) = self.target_tokenizer.convert_tokens_to_ids([bos_token,
eos_token,
target_tokenizer.special_tokens_map['pad_token']])

self.source_ids = None
self.target_ids = None

if self.is_pre_tokenized:

source_encodings = self.source_tokenizer.batch_encode_plus(
source_sentences,
max_length=self.source_tokenizer.model_max_length,
add_special_tokens=False,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='np'
)

target_encodings = self.target_tokenizer.batch_encode_plus(
target_sentences,
max_length=self.target_tokenizer.model_max_length,
add_special_tokens=False,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='np'
)

self.source_ids = np.ones_like(source_encodings['input_ids']) * self.source_tokenizer_pad_token_id
self.source_ids[:, 0] = self.source_tokenizer_bos_token_id
self.source_ids[:, 1:] = source_encodings['input_ids'][:, :-1]
source_attention_counts = np.sum(source_encodings['attention_mask'], axis=1) + 1
self.source_ids[np.arange(len(source_attention_counts)), source_attention_counts] = self.source_tokenizer_eos_token_id

self.target_ids = np.ones_like(target_encodings['input_ids']) * self.target_tokenizer_pad_token_id
self.target_ids[:, 0] = self.target_tokenizer_bos_token_id
self.target_ids[:, 1:] = target_encodings['input_ids'][:, :-1]
target_attention_counts = np.sum(target_encodings['attention_mask'], axis=1) + 1
self.target_ids[np.arange(len(target_attention_counts)), target_attention_counts] = self.target_tokenizer_eos_token_id


def get_info(self):
Expand All @@ -39,53 +94,58 @@ def __len__(self) -> int:

def __getitem__(self, idx):

source_sentence = self.source_sentences[idx]
target_sentence = self.target_sentences[idx]

source_encoding = self.source_tokenizer.encode_plus(
text=source_sentence,
max_length=self.source_tokenizer.model_max_length,
add_special_tokens=False,
padding='max_length',
truncation=True,
return_attention_mask=True
)
target_encoding = self.target_tokenizer.encode_plus(
text=target_sentence,
max_length=self.target_tokenizer.model_max_length,
add_special_tokens=False,
padding='max_length',
truncation=True,
return_attention_mask=True
)

source_attention_count = sum(source_encoding['attention_mask'])
target_attention_count = sum(target_encoding['attention_mask'])

source_ids = [0] * self.source_tokenizer.model_max_length
source_ids[0] = self.source_tokenizer_bos_token_id
source_ids[1:source_attention_count+1] = source_encoding['input_ids'][:source_attention_count]
source_ids[source_attention_count + 1] = self.source_tokenizer_eos_token_id

target_ids = [0] * self.target_tokenizer.model_max_length
target_ids[0] = self.target_tokenizer_bos_token_id
target_ids[1:target_attention_count + 1] = target_encoding['input_ids'][:target_attention_count]
target_ids[target_attention_count + 1] = self.target_tokenizer_eos_token_id

out = {
'source': {
'language': self.source_language,
'text': source_sentence,
'input_ids': torch.tensor(data = source_ids, dtype=torch.long),
'attention_count': torch.tensor(data = source_attention_count + 2, dtype=torch.long)
},
'target': {
'language': self.target_language,
'text': target_sentence,
'input_ids': torch.tensor(data = target_ids, dtype=torch.long ),
'attention_count': torch.tensor(data = target_attention_count + 2, dtype=torch.long)
if self.is_pre_tokenized:
out = {
'source': {
'input_ids': torch.tensor(data=self.source_ids[idx, :], dtype=torch.long),
},
'target': {
'input_ids': torch.tensor(data=self.target_ids[idx, :], dtype=torch.long),
}
}
else:

source_sentence = self.source_sentences[idx]
target_sentence = self.target_sentences[idx]

source_encoding = self.source_tokenizer.encode_plus(
text=source_sentence,
max_length=self.source_tokenizer.model_max_length,
add_special_tokens=False,
padding='max_length',
truncation=True,
return_attention_mask=True
)
target_encoding = self.target_tokenizer.encode_plus(
text=target_sentence,
max_length=self.target_tokenizer.model_max_length,
add_special_tokens=False,
padding='max_length',
truncation=True,
return_attention_mask=True
)

source_attention_count = sum(source_encoding['attention_mask'])
target_attention_count = sum(target_encoding['attention_mask'])

source_ids = [0] * self.source_tokenizer.model_max_length
source_ids[0] = self.source_tokenizer_bos_token_id
source_ids[1:source_attention_count+1] = source_encoding['input_ids'][:source_attention_count]
source_ids[source_attention_count + 1] = self.source_tokenizer_eos_token_id

target_ids = [0] * self.target_tokenizer.model_max_length
target_ids[0] = self.target_tokenizer_bos_token_id
target_ids[1:target_attention_count + 1] = target_encoding['input_ids'][:target_attention_count]
target_ids[target_attention_count + 1] = self.target_tokenizer_eos_token_id

out = {
'source': {
'input_ids': torch.tensor(data = source_ids, dtype=torch.long),
},
'target': {
'input_ids': torch.tensor(data = target_ids, dtype=torch.long ),
}
}
}

return out

@@ -99,7 +159,8 @@ def build_dataset(cls,
source_tokenizer: [PreTrainedTokenizer, PreTrainedTokenizerFast],
target_tokenizer: [PreTrainedTokenizer, PreTrainedTokenizerFast],
bos_token: str,
eos_token: str) -> Tatoeba:
eos_token: str,
is_pre_tokenized: bool = False) -> Tatoeba:

if dataset_split in ['valid', 'validation']:
dataset_split = 'dev'
@@ -123,4 +184,5 @@ def build_dataset(cls,
source_tokenizer=source_tokenizer,
target_tokenizer=target_tokenizer,
bos_token=bos_token,
eos_token=eos_token)
eos_token=eos_token,
is_pre_tokenized=is_pre_tokenized)
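A hedged usage sketch of the new flag. The import path is inferred from the file location, bert-base-uncased with its [CLS]/[SEP]/[PAD] tokens stands in for the real source and target tokenizers, and the keyword names follow the assignments visible in the diff:

from transformers import AutoTokenizer

from source.ml.datasets.tatoeba import Tatoeba  # assumed import path

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # illustration only

dataset = Tatoeba(source_sentences=['Hello world.'],
                  target_sentences=['Hallo Welt.'],
                  source_language='eng',
                  target_language='deu',
                  source_tokenizer=tokenizer,
                  target_tokenizer=tokenizer,
                  bos_token='[CLS]',
                  eos_token='[SEP]',
                  is_pre_tokenized=True)

item = dataset[0]
# Each item is a dict with 'source' and 'target' entries, each holding a
# fixed-length input_ids tensor (BOS id, token ids, EOS id, then padding).
print(item['source']['input_ids'].shape)  # torch.Size([512]) for bert-base-uncased
print(item['target']['input_ids'][:6])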
6 changes: 4 additions & 2 deletions source/ml/models/machine_translation/trainer.py
@@ -61,7 +61,8 @@ def prepare_data(self):
source_tokenizer=self.source_tokenizer,
target_tokenizer=self.target_tokenizer,
bos_token=self.model_config.bos_token,
eos_token=self.model_config.eos_token)
eos_token=self.model_config.eos_token,
is_pre_tokenized=True)

valid_data = Tatoeba.build_dataset(dataset_folder=os.path.join(settings.DATA_FOLDER, 'tatoeba'),
source_language=self.train_config.source_language,
@@ -70,7 +71,8 @@
source_tokenizer=self.source_tokenizer,
target_tokenizer=self.target_tokenizer,
bos_token=self.model_config.bos_token,
eos_token=self.model_config.eos_token)
eos_token=self.model_config.eos_token,
is_pre_tokenized=True)

return train_data, valid_data
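Because both datasets yield fixed-length input_ids tensors, the default collate function can batch them directly. A sketch of consuming the result (train_data is assumed to come from prepare_data above); the trade-off of is_pre_tokenized=True is extra memory and construction time in exchange for __getitem__ becoming a plain array lookup:

from torch.utils.data import DataLoader

# train_data is assumed to come from prepare_data() above.
loader = DataLoader(train_data, batch_size=32, shuffle=True)
batch = next(iter(loader))
src_ids = batch['source']['input_ids']  # shape: (32, source model_max_length)
tgt_ids = batch['target']['input_ids']  # shape: (32, target model_max_length)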

