From 6e6622e907954f28d39b63d797178affb4a7023c Mon Sep 17 00:00:00 2001
From: Elvin-Ma
Date: Sun, 2 Jul 2023 13:11:12 +0800
Subject: [PATCH] update minigpt

---
 14-model_learning/miniGPT/README.md       |  10 +++
 14-model_learning/miniGPT/chatgpt_demo.py | 100 ++++++++++++++++++++++
 2 files changed, 110 insertions(+)
 create mode 100644 14-model_learning/miniGPT/README.md
 create mode 100644 14-model_learning/miniGPT/chatgpt_demo.py

diff --git a/14-model_learning/miniGPT/README.md b/14-model_learning/miniGPT/README.md
new file mode 100644
index 0000000..54b10e8
--- /dev/null
+++ b/14-model_learning/miniGPT/README.md
@@ -0,0 +1,10 @@
+# install
+```shell
+git clone https://github.com/karpathy/minGPT.git
+#commit id: 37baab71b9abea1b76ab957409a1cc2fbfba8a26
+cd minGPT
+pip install -e .
+```
+
+# run
+python chatgpt_demo.py
diff --git a/14-model_learning/miniGPT/chatgpt_demo.py b/14-model_learning/miniGPT/chatgpt_demo.py
new file mode 100644
index 0000000..8cb0ed2
--- /dev/null
+++ b/14-model_learning/miniGPT/chatgpt_demo.py
@@ -0,0 +1,100 @@
+import torch
+import pickle
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import DataLoader
+
+from mingpt.utils import set_seed
+from mingpt.model import GPT
+from mingpt.trainer import Trainer
+
+set_seed(3407)
+
+class SortDataset(Dataset):
+    """
+    Dataset for the Sort problem. E.g. for problem length 6:
+    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
+    Which will feed into the transformer concatenated as:
+    input:  0 0 2 1 0 1 0 0 0 1 1
+    output: I I I I I 0 0 0 1 1 2
+    where I is "ignore", as the transformer is reading the input sequence
+    """
+
+    def __init__(self, split, length=6, num_digits=3):
+        assert split in {'train', 'test'}
+        self.split = split
+        self.length = length
+        self.num_digits = num_digits
+
+    def __len__(self):
+        return 10000 # ...
+
+    def get_vocab_size(self):
+        return self.num_digits
+
+    def get_block_size(self):
+        # the length of the sequence that will feed into the transformer,
+        # containing the concatenated input and output, but -1 because
+        # the transformer starts making predictions at the last input element
+        return self.length * 2 - 1
+
+    def __getitem__(self, idx):
+
+        # use rejection sampling to generate an input example from the desired split
+        while True:
+            # generate some random integers
+            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
+            # half of the time let's try to boost the number of examples that
+            # have a large number of repeats, as this is what the model seems to struggle
+            # with later in training, and they are kind of rare
+            if torch.rand(1).item() < 0.5:
+                if inp.unique().nelement() > self.length // 2:
+                    # too many unique digits, re-sample
+                    continue
+            # figure out if this generated example is train or test based on its hash
+            h = hash(pickle.dumps(inp.tolist()))
+            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test
+            if inp_split == self.split:
+                break # ok
+
+        # solve the task: i.e. sort
+        sol = torch.sort(inp)[0]
+
+        # concatenate the problem specification and the solution
+        cat = torch.cat((inp, sol), dim=0)
+
+        # the inputs to the transformer will be the offset sequence
+        x = cat[:-1].clone()
+        y = cat[1:].clone()
+        # we only want to predict at output locations, mask out the loss at the input locations
+        y[:self.length-1] = -1
+        return x, y
+
+if __name__ == "__main__":
+
+    train_dataset = SortDataset('train')
+    test_dataset = SortDataset('test')
+    # x, y = train_dataset[0]
+    # for a, b in zip(x, y):
+    #     print(int(a), int(b))
+
+    # create a GPT instance
+    model_config = GPT.get_default_config()
+    model_config.model_type = 'gpt-nano'
+    model_config.vocab_size = train_dataset.get_vocab_size()
+    model_config.block_size = train_dataset.get_block_size()
+    model = GPT(model_config)
+
+    # train config
+    train_config = Trainer.get_default_config()
+    train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
+    train_config.max_iters = 2000
+    train_config.num_workers = 0
+    trainer = Trainer(train_config, model, train_dataset)
+
+    def batch_end_callback(trainer):
+        if trainer.iter_num % 100 == 0:
+            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
+    trainer.set_callback('on_batch_end', batch_end_callback)
+
+    trainer.run()
+    print("chatgpt_demo")
\ No newline at end of file
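Not part of the patch: a minimal sketch of how one might sanity-check the trained model after `trainer.run()` returns. It assumes the `GPT.generate(idx, max_new_tokens, ...)` helper and the `trainer.device` attribute present in the pinned minGPT commit; we feed the 6 input digits from the docstring example and greedily decode 6 more tokens, which a trained model should return sorted.

```python
# sketch only, not part of the patch; assumes minGPT's GPT.generate and trainer.device
model.eval()
with torch.no_grad():
    inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)
    # greedy decoding: generate returns the prompt with length new tokens appended
    cat = model.generate(inp, train_dataset.length, do_sample=False)
    sol = cat[:, train_dataset.length:]  # drop the echoed input prefix
    print('input :', inp[0].tolist())
    print('model :', sol[0].tolist())
    print('sorted:', torch.sort(inp[0])[0].tolist())
```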