import argparse
import os

from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling, EsmConfig, EsmForMaskedLM, Trainer, TrainingArguments

from src.tokenization import train_tokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "true"

parser = argparse.ArgumentParser(description="Pretrain a model")
parser.add_argument("--tokenizer", default="bpe", type=str, choices=["char", "bpe"], help="Tokenizer to use")
parser.add_argument("--vocab_size", default=5000, type=int, help="Vocabulary size")
parser.add_argument("--n_layers", type=int, default=6, help="Number of layers in the model")
parser.add_argument("--n_dims", type=int, default=320, help="Hidden dimension of the model")
parser.add_argument("--n_heads", type=int, default=16, help="Number of attention heads in the model")
parser.add_argument("--epochs", type=int, default=30, help="Number of epochs to train")
parser.add_argument("--batch_size", type=int, default=200, help="Batch size")
parser.add_argument("--output_dir", type=str, default="output", help="Output directory for model and checkpoints")
parser.add_argument(
    "--token_output_file",
    type=str,
    default="output/tokenizer/tokenizer.json",
    help="Where the tokenizer JSON file will be saved",
)
config = parser.parse_args()
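
# Example invocation (values are illustrative, taken from the defaults above; adjust
# paths and hyperparameters for your setup):
#   python pretrain.py --tokenizer bpe --vocab_size 5000 --n_layers 6 --n_dims 320 \
#       --n_heads 16 --epochs 30 --batch_size 200 --output_dir output
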
## Load the dataset ##
dataset = load_dataset("text", data_files={"train": "/scratch/pretrain/swissprot.txt"})
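# The "text" loading script yields one training example per line of the input file, so
# the corpus is assumed to contain one (protein) sequence per line.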

### Train the tokenizer ###
# The tokenizer type is selected via --tokenizer (default: bpe).
tokenizer = train_tokenizer(
    dataset=dataset,
    tokenization_type=config.tokenizer,
    vocab_size=config.vocab_size,
    output_file_directory=config.token_output_file,
)

# Add padding token to the tokenizer if not already present
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

### Tokenize the dataset ###
def tokenize_function(examples: dict) -> dict:
    max_length = 250  # Define the maximum length for the sequences
    tokens = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)
    return tokens

dataset = dataset.map(tokenize_function, batched=False, num_proc=16, remove_columns=["text"])
print(dataset)
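# After mapping, each example holds the tokenizer outputs (e.g. input_ids and
# attention_mask); the raw "text" column has been removed.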

### Setup Model ###
esm_config = EsmConfig(
    vocab_size=config.vocab_size,
    num_hidden_layers=config.n_layers,
    hidden_size=config.n_dims,
    num_attention_heads=config.n_heads,
    pad_token_id=tokenizer.pad_token_id,
    max_position_embeddings=256,
)
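# Note: max_position_embeddings (256) is chosen to comfortably exceed the 250-token
# max_length used during tokenization, leaving some headroom for special tokens.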
model = EsmForMaskedLM(
    config=esm_config,
).to("cuda")

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
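# With mlm=True, the collator dynamically selects ~15% of tokens in each batch for the
# masked-language-modeling objective (by default 80% become the mask token, 10% a random
# token, and 10% are left unchanged) and builds the matching labels.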

training_args = TrainingArguments(
    output_dir=config.output_dir,
    num_train_epochs=config.epochs,
    per_device_train_batch_size=config.batch_size,
    bf16=True,
)
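# Note: bf16=True assumes hardware with bfloat16 support (e.g. NVIDIA Ampere or newer);
# on other hardware, fp16 or full precision would be needed instead.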

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
)

trainer.train()
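
# Optional follow-up (not part of the original script; a sketch assuming train_tokenizer
# returns a Hugging Face fast tokenizer): persist the final model and tokenizer so they
# can later be reloaded with from_pretrained():
#   trainer.save_model(config.output_dir)
#   tokenizer.save_pretrained(config.output_dir)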