update minigpt
Elvin-Ma committed Jul 2, 2023
1 parent 5d804b6 commit 6e6622e
Showing 2 changed files with 110 additions and 0 deletions.
10 changes: 10 additions & 0 deletions 14-model_learning/miniGPT/README.md
@@ -0,0 +1,10 @@
# install
```shell
git clone https://github.com/karpathy/minGPT.git
#commit id: 37baab71b9abea1b76ab957409a1cc2fbfba8a26
cd minGPT
pip install -e .
```

# run
```shell
python chatgpt_demo.py
```
100 changes: 100 additions & 0 deletions 14-model_learning/miniGPT/chatgpt_demo.py
@@ -0,0 +1,100 @@
import torch
import pickle
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from mingpt.utils import set_seed
from mingpt.model import GPT
from mingpt.trainer import Trainer

set_seed(3407)

class SortDataset(Dataset):
"""
Dataset for the Sort problem. E.g. for problem length 6:
Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
Which will feed into the transformer concatenated as:
input: 0 0 2 1 0 1 0 0 0 1 1
output: I I I I I 0 0 0 1 1 2
where I is "ignore", as the transformer is reading the input sequence
"""

def __init__(self, split, length=6, num_digits=3):
assert split in {'train', 'test'}
self.split = split
self.length = length
self.num_digits = num_digits

def __len__(self):
        return 10000 # nominal dataset size; fresh examples are generated on the fly in __getitem__

def get_vocab_size(self):
return self.num_digits

def get_block_size(self):
# the length of the sequence that will feed into transformer,
# containing concatenated input and the output, but -1 because
# the transformer starts making predictions at the last input element
return self.length * 2 - 1
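        # e.g. for length=6: 6 input tokens + 6 output tokens = 12, minus 1 -> block size 11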

def __getitem__(self, idx):

# use rejection sampling to generate an input example from the desired split
while True:
# generate some random integers
inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time, try to boost the number of examples that
            # have a large number of repeats, as this is what the model seems
            # to struggle with later in training, and they are kind of rare
if torch.rand(1).item() < 0.5:
if inp.unique().nelement() > self.length // 2:
                    # too many unique digits; re-sample
continue
# figure out if this generated example is train or test based on its hash
h = hash(pickle.dumps(inp.tolist()))
inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test
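            # note: hash() of bytes is salted per process (PYTHONHASHSEED), so this
            # split is only reproducible across runs if that env var is pinned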
if inp_split == self.split:
break # ok

# solve the task: i.e. sort
sol = torch.sort(inp)[0]

# concatenate the problem specification and the solution
cat = torch.cat((inp, sol), dim=0)

# the inputs to the transformer will be the offset sequence
x = cat[:-1].clone()
y = cat[1:].clone()
# we only want to predict at output locations, mask out the loss at the input locations
y[:self.length-1] = -1
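        # e.g. for inp = [0, 0, 2, 1, 0, 1], sol = [0, 0, 0, 1, 1, 2]:
        #   x = [ 0,  0,  2,  1,  0, 1, 0, 0, 0, 1, 1]
        #   y = [-1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 2]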
return x, y

if __name__ == "__main__":

train_dataset = SortDataset('train')
test_dataset = SortDataset('test')
# x, y = train_dataset[0]
# for a, b in zip(x,y):
# print(int(a),int(b))

# create a GPT instance
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model = GPT(model_config)

# train config
train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
train_config.max_iters = 2000
train_config.num_workers = 0
trainer = Trainer(train_config, model, train_dataset)

def batch_end_callback(trainer):
if trainer.iter_num % 100 == 0:
print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
trainer.set_callback('on_batch_end', batch_end_callback)

trainer.run()
print("chatgpt_demo")
