Showing 2 changed files with 110 additions and 0 deletions.
@@ -0,0 +1,10 @@
# install
```shell
git clone https://github.com/karpathy/minGPT.git
# commit id: 37baab71b9abea1b76ab957409a1cc2fbfba8a26
cd minGPT
pip install -e .
```

# run
python chatgpt_demo.py
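
If the editable install succeeded, a quick import check like the sketch below should run cleanly. It only uses names that also appear in the demo script further down (`set_seed`, `GPT`, `Trainer`), so treat it as an optional sanity check rather than part of the demo itself.

```python
# optional sanity check (sketch): verify minGPT is importable after `pip install -e .`
from mingpt.utils import set_seed
from mingpt.model import GPT
from mingpt.trainer import Trainer

set_seed(3407)
print(GPT.get_default_config())      # default model config
print(Trainer.get_default_config())  # default training config
```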
@@ -0,0 +1,100 @@
import torch
import pickle
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from mingpt.utils import set_seed
from mingpt.model import GPT
from mingpt.trainer import Trainer

set_seed(3407)
class SortDataset(Dataset):
    """
    Dataset for the Sort problem. E.g. for problem length 6:
    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:  0 0 2 1 0 1 0 0 0 1 1
    output: I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits

    def __len__(self):
        return 10000 # ...

    def get_vocab_size(self):
        return self.num_digits

    def get_block_size(self):
        # the length of the sequence that will feed into the transformer,
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
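        # e.g. with the default length=6 this is 6*2 - 1 = 11 tokens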
        return self.length * 2 - 1

    def __getitem__(self, idx):

        # use rejection sampling to generate an input example from the desired split
        while True:
            # generate some random integers
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rare
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unique digits, re-sample
                    continue
            # figure out if this generated example is train or test based on its hash
            h = hash(pickle.dumps(inp.tolist()))
            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test
            if inp_split == self.split:
                break # ok

        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = -1
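        # e.g. for the docstring's example inp = [0,0,2,1,0,1], sol = [0,0,0,1,1,2]:
        #   x = [0, 0, 2, 1, 0, 1, 0, 0, 0, 1, 1]
        #   y = [-1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 2]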
        return x, y


if __name__ == "__main__":

    train_dataset = SortDataset('train')
    test_dataset = SortDataset('test')
    # x, y = train_dataset[0]
    # for a, b in zip(x, y):
    #     print(int(a), int(b))

    # create a GPT instance
    model_config = GPT.get_default_config()
    model_config.model_type = 'gpt-nano'
    model_config.vocab_size = train_dataset.get_vocab_size()
    model_config.block_size = train_dataset.get_block_size()
    model = GPT(model_config)
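    # with the values above this is a tiny model: model_type='gpt-nano', vocab_size=3, block_size=11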

    # train config
    train_config = Trainer.get_default_config()
    train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
    train_config.max_iters = 2000
    train_config.num_workers = 0
    trainer = Trainer(train_config, model, train_dataset)

    def batch_end_callback(trainer):
        if trainer.iter_num % 100 == 0:
            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
    trainer.set_callback('on_batch_end', batch_end_callback)

    trainer.run()
    print("chatgpt_demo")