# trainer.py (forked from karpathy/minGPT)
"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""
import time
from collections import defaultdict
import torch
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import CfgNode as CN
class Trainer:

    @staticmethod
    def get_default_config():
        C = CN()
        # device to train on
        C.device = 'auto'
        # dataloader parameters
        C.num_workers = 4
        # training loop parameters
        C.max_iters = None
        C.batch_size = 64
        # optimizer parameters
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1  # only applied on matmul weights
        C.grad_norm_clip = 1.0
        return C
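
    # A hedged usage sketch (the values below are illustrative, not
    # recommendations): the typical pattern is to fetch the defaults,
    # override a few fields, then construct the Trainer:
    #
    #     config = Trainer.get_default_config()
    #     config.max_iters = 2000   # finite run instead of the open-ended default
    #     config.batch_size = 32
    #     trainer = Trainer(config, model, train_dataset)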
    def __init__(self, config, model, train_dataset):
        self.config = config
        self.model = model
        self.optimizer = None
        self.train_dataset = train_dataset
        self.callbacks = defaultdict(list)

        # determine the device we'll train on
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        print("running on device", self.device)

        # variables that will be assigned to the trainer instance later, for logging etc.
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0
    def add_callback(self, onevent: str, callback):
        # Add a new callback function to be triggered on a specific event.
        # 'onevent' is a string identifying the event (e.g., 'on_batch_end').
        # 'callback' is the function to be called when the event occurs.
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        # Set (replace) the callback function for a specific event.
        # This replaces any existing callbacks for the event with the new callback.
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        # Trigger all callback functions associated with a specific event.
        # This method is called internally to execute callbacks at appropriate times.
        for callback in self.callbacks.get(onevent, []):
            callback(self)
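
    # A hedged sketch of the callback hook in use; the print format and the
    # logging interval are illustrative only:
    #
    #     def batch_end_callback(trainer):
    #         if trainer.iter_num % 100 == 0:
    #             print(f"iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
    #
    #     trainer.set_callback('on_batch_end', batch_end_callback)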
    def run(self):
        model, config = self.model, self.config

        # setup the optimizer
        self.optimizer = model.configure_optimizers(config)

        # setup the dataloader
        train_loader = DataLoader(
            self.train_dataset,
            sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)),
            shuffle=False,
            pin_memory=True,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
        )
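        # note: the sampler above draws with replacement and a practically
        # unbounded num_samples, so the loader is effectively infinite and
        # epoch boundaries are meaningless; termination is governed solely
        # by config.max_iters at the bottom of the loop.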
        model.train()
        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)
        while True:

            # fetch the next batch (x, y) and re-init iterator if needed
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            batch = [t.to(self.device) for t in batch]
            x, y = batch

            # forward the model
            logits, self.loss = model(x, y)

            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            self.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            self.optimizer.step()

            self.trigger_callbacks('on_batch_end')
            self.iter_num += 1
            tnow = time.time()
            self.iter_dt = tnow - self.iter_time
            self.iter_time = tnow

            # termination conditions
            if config.max_iters is not None and self.iter_num >= config.max_iters:
                break
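

if __name__ == '__main__':
    # A minimal, hypothetical usage sketch, not part of the original file.
    # It assumes the mingpt package (karpathy/minGPT) is importable;
    # RandomTokenDataset is an illustrative stand-in for a real dataset.
    from torch.utils.data import Dataset
    from mingpt.model import GPT

    class RandomTokenDataset(Dataset):
        # hypothetical dataset: yields (x, y) pairs of random token ids,
        # where y is x shifted one position to the left
        def __init__(self, vocab_size=64, block_size=16, length=1024):
            self.vocab_size = vocab_size
            self.block_size = block_size
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            seq = torch.randint(self.vocab_size, (self.block_size + 1,))
            return seq[:-1], seq[1:]

    dataset = RandomTokenDataset()

    model_config = GPT.get_default_config()
    model_config.model_type = 'gpt-nano'
    model_config.vocab_size = dataset.vocab_size
    model_config.block_size = dataset.block_size
    model = GPT(model_config)

    train_config = Trainer.get_default_config()
    train_config.max_iters = 10    # keep the demo short
    train_config.num_workers = 0   # avoid dataloader multiprocessing for a toy run

    trainer = Trainer(train_config, model, dataset)

    def batch_end_callback(trainer):
        print(f"iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

    trainer.set_callback('on_batch_end', batch_end_callback)
    trainer.run()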