model.py
import json
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TrainingArguments, Trainer
import torch
from torch.utils.data import Dataset
import os


class ArticlesDataset(Dataset):
    def __init__(self, tokenizer, folder_path, max_length):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        self.labels = []
        # Read each file in the folder
        for filename in os.listdir(folder_path):
            if filename.endswith('.txt'):  # Check if the file is a text file
                with open(os.path.join(folder_path, filename), 'r') as file:
                    lines = file.readlines()
                    for line in lines:
                        abstract = line.strip()  # Remove leading/trailing whitespace
                        encodings_dict = tokenizer(abstract, truncation=True, max_length=max_length, padding="max_length")
                        self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
                        self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
                        # For language modeling, labels are usually the input IDs themselves
                        self.labels.append(torch.tensor(encodings_dict['input_ids']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attn_masks[idx],
            'labels': self.labels[idx]  # Adding labels
        }


# Load the tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2-large')

# Initialize dataset with the folder path
folder_path = 'data'  # Path to the folder containing the text files
dataset = ArticlesDataset(tokenizer, folder_path, max_length=512)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=2,
    logging_dir='./logs',
    logging_steps=10,
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

# Train the model
trainer.train()

# Save the model
model_save_path = "./models/3_trained_gpt2_model"
model.save_pretrained(model_save_path)

# Save the tokenizer
tokenizer_save_path = "./models/3_trained_gpt2_tokenizer"
tokenizer.save_pretrained(tokenizer_save_path)

print(f"Model saved to {model_save_path}")
print(f"Tokenizer saved to {tokenizer_save_path}")
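# Optional sanity check (illustrative, not part of the original script): every item
# is a dict of fixed-length tensors, so the first example should hold 512 token IDs.
# print(len(dataset), dataset[0]['input_ids'].shape)  # e.g. torch.Size([512])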