# main.py
# Forked from Pawandeep-prog/finetuned-gpt2-convai.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from ChatData import ChatData
from torch.optim import Adam
from torch.utils.data import DataLoader
import tqdm
import torch
def train(chatData, model, optim, epochs=12):
    """Fine-tune ``model`` on ``chatData`` with a language-modelling loss.

    Args:
        chatData: Iterable yielding ``(X, a)`` tensor pairs. ``X`` is fed to
            the model as the input ids and ``a`` as its attention mask.
        model: Model whose forward call accepts ``attention_mask`` and
            ``labels`` and returns an object exposing ``.loss``.
        optim: Optimizer bound to the model's parameters.
        epochs: Number of full passes over ``chatData`` (default 12).

    Side effects:
        After every epoch the model's ``state_dict`` is written to
        ``model_state.pt`` and a sample reply from ``infer`` is printed.
        Relies on the module-level ``device``.
    """
    for _ in tqdm.tqdm(range(epochs)):
        for X, a in chatData:
            X = X.to(device)
            a = a.to(device)
            # Positions the attention mask marks as padding (0) must not
            # contribute to the loss; -100 is the label value the Hugging Face
            # loss ignores. Passing X itself as labels would train the model
            # to predict pad tokens.
            labels = X.masked_fill(a == 0, -100)
            optim.zero_grad()
            loss = model(X, attention_mask=a, labels=labels).loss
            loss.backward()
            optim.step()
        # Checkpoint once per epoch and print a sample reply as a sanity check.
        torch.save(model.state_dict(), "model_state.pt")
        print(infer("hello how are you"))
def infer(inp):
    """Generate a reply to ``inp`` with the module-level model.

    The text is wrapped in the prompt template
    ``"<startofstring> " + inp + " <bot>: "``, tokenized, and passed to
    ``model.generate``.

    Args:
        inp: The user's message as a plain string.

    Returns:
        The decoded first generated sequence as a string. Special tokens are
        not stripped, so the prompt and any marker tokens are included.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.
    Generation length is left at the library default.
    """
    prompt = "<startofstring> " + inp + " <bot>: "
    encoded = tokenizer(prompt, return_tensors="pt")
    X = encoded["input_ids"].to(device)
    a = encoded["attention_mask"].to(device)

    # Generate in eval mode so dropout does not perturb the output, then put
    # the model back in whatever mode the caller had it in (train() calls this
    # between epochs).
    was_training = model.training
    model.eval()
    try:
        output = model.generate(
            X,
            attention_mask=a,
            # Use the tokens registered on the tokenizer so generation pads
            # with <pad> and can stop at <endofstring> instead of falling
            # back to the checkpoint's original end-of-text id.
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    finally:
        model.train(was_training)

    return tokenizer.decode(output[0])
# Pick the best available backend: CUDA first, then Apple MPS, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Extend the stock GPT-2 vocabulary with the markers used by the prompt
# template in infer(): padding, start/end of a sample, and the bot-turn tag.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>",
                              "bos_token": "<startofstring>",
                              "eos_token": "<endofstring>"})
tokenizer.add_tokens(["<bot>:"])

model = GPT2LMHeadModel.from_pretrained("gpt2")
# The embedding matrix must grow to cover the tokens added above.
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)

# print(tokenizer.decode(model.generate(**tokenizer("hey i was good at basketball but ",
#                          return_tensors="pt"))[0]))

chatData = ChatData("./chat_data.json", tokenizer)
chatData = DataLoader(chatData, batch_size=64)

model.train()

optim = Adam(model.parameters(), lr=1e-3)

print("training .... ")
train(chatData, model, optim)

# Training is finished: leave train mode so dropout is disabled for the
# interactive session below.
model.eval()

print("infer from model : ")
while True:
    try:
        inp = input()
    except (EOFError, KeyboardInterrupt):
        # End of input (Ctrl-D / closed pipe) or Ctrl-C: exit without a traceback.
        break
    print(infer(inp))