# train.py (forked from karpathy/nanoGPT)
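
# Example CLI usage (Fire exposes GPTConfig fields as flags; the exact names depend
# on your config definition, so adjust as needed):
#   python train.py --backend=jax --n_epoch=1 --do_wandb=False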
import os
import fire
from config import GPTConfig
from dataset import load_data


def train(**kwargs):
    # --- CONFIG ---
    config = GPTConfig(**kwargs)

    # Imports here to allow command args to set the backend
    os.environ["KERAS_BACKEND"] = config.backend
    import keras as K
    from model import GPT
    from callback import AddLRCallback, EvaluateCallback, WandbCallback

    # --- TRAINING GLOBALS ---
    if config.do_mixed_precision:
        dtype_to_mixed = {
            "float16": "mixed_float16",
            "bfloat16": "mixed_bfloat16",
        }
        K.mixed_precision.set_global_policy(dtype_to_mixed[config.mixed_precision_dtype])
    if config.fixed_seed:
        import tensorflow as tf
        K.utils.set_random_seed(1337)
        tf.config.experimental.enable_op_determinism()

    # --- WANDB ---
    if config.do_wandb:
        import wandb
        wandb.init(project=config.wandb_project, name=config.wandb_run_name, config=config)

    # --- LOAD DATA ---
    train_dataset, val_dataset, n_step_train, n_step_val = load_data(config)

    # --- LOAD MODEL ---
    model = GPT(config)

    # --- PREPARE TRAINING ---
    total_steps = n_step_train * config.n_epoch
    warmup_steps = max(int(total_steps * config.warmup_ratio), 1)
    decay_steps = total_steps - warmup_steps
    print(f"Epoch steps: {n_step_train}. Total steps: {total_steps}. "
          f"Warmup steps: {warmup_steps}. Decay steps: {decay_steps}.")
    tok_per_step = config.batch_size * config.block_size
    print(f"Step tokens: {tok_per_step}. Epoch tokens: {tok_per_step * n_step_train}. "
          f"Total tokens: {tok_per_step * total_steps}")
    if config.do_lr_decay:
        init_lr = config.lr / warmup_steps
        learning_rate = K.optimizers.schedules.CosineDecay(
            initial_learning_rate=init_lr,
            warmup_target=config.lr,
            warmup_steps=warmup_steps,
            alpha=config.min_lr,
            decay_steps=decay_steps,
        )
    else:
        learning_rate = config.lr
    optimizer = K.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=config.weight_decay,
        beta_1=config.beta1,
        beta_2=config.beta2,
        global_clipnorm=config.grad_clip,
    )
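    # The model supplies the variables that skip weight decay
    # (presumably biases and LayerNorm parameters, following the nanoGPT recipe).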
    optimizer.exclude_from_weight_decay(model.get_list_exclude_from_weight_decay())
    model.build(input_shape=(config.batch_size, config.block_size))
    model.compile(
        optimizer=optimizer,
        loss=K.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[K.metrics.SparseCategoricalAccuracy(name='acc')],
        jit_compile=True,
    )
    if config.verbose > 10:
        model.summary()
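
    # --- CALLBACKS ---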
    my_callbacks = []
    if config.backend != "jax":
        my_callbacks.append(AddLRCallback(optimizer))  # Workaround: the LR reads back as 0 on jax
    if config.do_eval_every > 0:
        my_callbacks.append(EvaluateCallback(config, val_dataset, n_step_val))
    if config.do_wandb:
        my_callbacks.append(WandbCallback(n_step_train))

    # --- TRAIN ---
    history = model.fit(
        train_dataset,
        steps_per_epoch=n_step_train,
        epochs=config.n_epoch,
        validation_data=val_dataset if config.do_eval_epoch else None,
        validation_steps=n_step_val if config.do_eval_epoch else None,
        callbacks=my_callbacks,
        verbose=1,
    )
    # print(history.history)
    # print(model.evaluate(val_dataset, steps=n_step_val))

    # --- SAVE ---
    if config.do_save_model:
        os.makedirs(config.out_dir, exist_ok=True)
        model.save(os.path.join(config.out_dir, f"{config.out_name}.keras"))
    return model, history, config


def main(**kwargs):
    # Fire prints whatever its target returns, so this wrapper keeps the CLI quiet
    # while still letting train() return the model (e.g. when run from a notebook).
    train(**kwargs)


if __name__ == "__main__":
    fire.Fire(main)
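
# For programmatic use (e.g. in a notebook), call train() directly and keep the
# returned objects, for example:
#   model, history, config = train(backend="jax", n_epoch=1)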