-
Notifications
You must be signed in to change notification settings - Fork 6
/
train.py
316 lines (290 loc) · 11.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""
Basic Usage:
torchrun --nproc_per_node=<N> train.py
"""
import argparse
import copy
import os
import time
from typing import Optional, Union
import datasets
import omegaconf as oc
import torch
import torch.distributed as dist
import tqdm
import transformers
import wandb
from text_sed import diffusion, layers, slurm, utils
def _unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Return the module wrapped by a (Distributed)DataParallel container, else `model`."""
    if isinstance(
        model, (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    ):
        return model.module
    return model


def _build_dataloaders(
    config: oc.DictConfig, tokenizer: transformers.PreTrainedTokenizer
) -> dict:
    """Load the train split (and the valid split, if configured) as infinite dataloaders.

    Returns:
        A dict with a ``"train"`` dataloader and, when ``config.data.valid_kwargs``
        is truthy, a ``"valid"`` dataloader.
    """
    logger.info("📦 Loading dataset...")
    text_datasets = {"train": datasets.load_dataset(**config.data.train_kwargs)}
    if config.data.valid_kwargs:
        text_datasets["valid"] = datasets.load_dataset(**config.data.valid_kwargs)

    logger.info("📦 Loading dataloaders...")
    batch_sizes = {"train": config.train.batch_size}
    if "valid" in text_datasets:
        batch_sizes["valid"] = config.valid.batch_size
    return {
        split: utils.text_dataloader(
            dataset=dataset,
            tokenizer=tokenizer,
            per_gpu_batch_size=batch_sizes[split],
            max_seq_len=config.model.seq_len,
            num_workers=config.data.num_preprocess_workers,
            use_infinite_sampler=True,
            text_attr=config.data.text_attr,
        )
        for split, dataset in text_datasets.items()
    }


def _evaluate(
    model: torch.nn.Module,
    valid_iter,
    *,
    step: int,
    tracker: Optional["wandb.Run"],
    device: Union[torch.device, str],
) -> None:
    """Run one forward pass on a validation batch and log its stats under ``valid/``."""
    logger.info(
        "📊 Evaluating... "
        "WARNING: Evaluation is slow! Run evaluations on checkpoints instead."
    )
    model.eval()
    # TODO: The `BatchSampler` + `DataLoader` prepends an extra dimension to the data.
    valid_inputs = next(valid_iter)["input_ids"][0].to(device)
    with torch.no_grad():
        # Only the main process evaluates, so bypass the DDP wrapper: its forward may
        # issue collectives (e.g. buffer broadcasts) that the other ranks never join.
        _, valid_stats = _unwrap_model(model)(valid_inputs)
    tracker.log({f"valid/{k}": v for k, v in valid_stats.items()}, step=step)
    model.train()


def _generate_samples(
    model_ema: torch.nn.Module,
    tokenizer: transformers.PreTrainedTokenizer,
    config: oc.DictConfig,
    device: Union[torch.device, str],
) -> None:
    """Sample ``config.train.num_samples`` sequences from the EMA model and log them."""
    logger.info("💬 Generating samples...")
    model_ema.eval()
    shape = (
        config.train.num_samples,
        config.model.seq_len,
        # NOTE(review): `embed_dim` is not a parameter; it is the module-level global
        # assigned in the `__main__` block, so this only resolves when run as a script.
        utils.default(config.model.bottleneck_dim, embed_dim),
    )
    start_time = time.perf_counter()
    batched_tokens = _unwrap_model(model_ema).generate(
        shape=shape,
        num_steps=config.model.num_gen_steps,
        sampler=diffusion.get_sampler(config.model.sampler),
        time_delta=config.model.time_delta,
        guide_scale=config.model.guide_scale,
        use_clamp=False,
        device=device,
    )
    end_time = time.perf_counter()
    samples = tokenizer.batch_decode(batched_tokens, skip_special_tokens=True)
    sample_log = "💬 Generating samples..." + "".join(
        f"\n➜ {sample}" for sample in samples
    )
    logger.info(sample_log)
    logger.info(f"🕒 Generation took {end_time - start_time:.2f} seconds.")
    model_ema.train()


def _save_checkpoint(
    config: oc.DictConfig,
    model: torch.nn.Module,
    model_ema: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    scaler: torch.cuda.amp.GradScaler,
    step: int,
) -> None:
    """Write model/EMA/optimizer/scheduler/scaler state to ``<output_dir>/step_<step>.pth``."""
    logger.info(f"💾 Saving checkpoint for step {step}")
    checkpoint = {
        # Unwrapped state dicts have no DDP `module.` key prefix.
        "model": _unwrap_model(model).state_dict(),
        "model_ema": _unwrap_model(model_ema).state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "scaler": scaler.state_dict(),
        "step": step,
        "config": config,
    }
    path = os.path.join(config.output_dir, f"step_{step}.pth")
    torch.save(checkpoint, path)


def train(
    config: oc.DictConfig,
    model: torch.nn.Module,
    model_ema: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    *,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    scaler: torch.cuda.amp.GradScaler,
    tokenizer: transformers.PreTrainedTokenizer,
    step_state: int = 0,
    tracker: Optional["wandb.Run"] = None,
    device: Union[torch.device, str] = "cuda:0",
):
    """Run the diffusion-LM training loop from ``step_state`` up to ``config.train.total_steps``.

    Each step does a mixed-precision forward/backward, gradient clipping, an optimizer
    and LR-scheduler step, and a periodic EMA update. On the main process it also
    periodically logs stats, evaluates, generates samples from the EMA model and
    saves checkpoints.

    NOTE(review): relies on the module-level ``logger`` (and ``embed_dim`` for
    sampling) created in the ``__main__`` block.

    Args:
        config: Full experiment config (``data``, ``train``, ``valid``, ``model``,
            ``optimizer`` sections are read).
        model: Model being trained (typically DDP-wrapped); its forward returns
            ``(loss, stats)``.
        model_ema: EMA copy of ``model``, used for sampling.
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Stepped once per training step.
        scaler: AMP gradient scaler.
        tokenizer: Used to build dataloaders and decode generated samples.
        step_state: Number of steps already completed (non-zero when resuming).
        tracker: Experiment tracker; must be set on the main process.
        device: Device that input batches are moved to.
    """
    dataloaders = _build_dataloaders(config, tokenizer)
    train_iter = iter(dataloaders["train"])
    valid_iter = iter(dataloaders["valid"]) if "valid" in dataloaders else None

    logger.info("⏳ Begin model training...")
    model.train()
    total_steps = config.train.total_steps
    # `step` is 1-based (the number of completed steps after this iteration). `total`
    # is passed explicitly because `trange(a, b)` would infer `b - a`, which is wrong
    # when resuming with `initial=step_state`.
    for step in tqdm.trange(
        step_state + 1,
        total_steps + 1,
        initial=step_state,
        total=total_steps,
        disable=not utils.is_main_process(),
    ):
        batch = next(train_iter)
        # TODO: The `BatchSampler` + `DataLoader` prepends an extra dimension to the data.
        input_ids = batch["input_ids"][0].to(device)
        attention_mask = None  # batch["attention_mask"][0].to(device)
        with torch.amp.autocast(
            device_type="cuda", dtype=utils.get_dtype(config.train.dtype)
        ):
            loss, stats = model(input_ids, attention_mask=attention_mask)
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=config.optimizer.max_grad_norm
        )
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        if step % config.model.ema_every == 0:
            utils.ema_update(model, model_ema, config.model.ema_decay)
        optimizer.zero_grad(set_to_none=True)

        is_main = utils.is_main_process()

        # Log training stats
        if step % config.train.log_every == 0 and is_main:
            # Only the first param group's learning rate is logged.
            stats["learning_rate"] = lr_scheduler.get_last_lr()[0]
            tracker.log({f"train/{k}": v for k, v in stats.items()}, step=step)
            info = f"🎛 Step: {step}/{total_steps} "
            info += f"𑗔 Loss: {loss:.5f} "
            info += f"𑗔 MSE Loss: {stats['loss_mse']:.5f} "
            info += f"𑗔 Recon Loss: {stats['loss_recon']:.5f} "
            info += f"𑗔 LR: {stats['learning_rate']:.6f}"
            logger.info(info)

        # Evaluate and log the validation stats
        if valid_iter is not None and step % config.train.eval_every == 0 and is_main:
            _evaluate(model, valid_iter, step=step, tracker=tracker, device=device)

        # Generate samples
        if step % config.train.sample_every == 0 and is_main:
            _generate_samples(model_ema, tokenizer, config, input_ids.device)

        # Save checkpoints
        if step % config.train.save_every == 0 and is_main:
            _save_checkpoint(
                config, model, model_ema, optimizer, lr_scheduler, scaler, step
            )
if __name__ == "__main__":
    # NOTE: this stays as top-level statements on purpose: `train` reads the
    # module-level globals `logger` and `embed_dim` assigned below.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--name", type=str)
    parser.add_argument("--global_rank", type=int)
    parser.add_argument("--local_rank", type=int)
    parser.add_argument("--world_size", type=int)
    # A master address is a hostname/IP, so it must be parsed as a string; the
    # default of -1 (presumably "unset" for `slurm.init_distributed_mode`) is kept.
    parser.add_argument("--master_addr", type=str, default=-1)
    parser.add_argument("--master_port", type=int, default=-1)
    args = parser.parse_args()

    slurm.init_distributed_mode(args)
    slurm.init_signal_handler()

    config = oc.OmegaConf.load(args.config)
    if args.name is not None:
        oc.OmegaConf.update(config, "name", args.name)
    else:
        # Add timestamp to checkpoint dir name
        # TODO: There's probably a better way to do this...
        # NOTE(review): each rank computes its own timestamp; if ranks straddle a
        # clock tick they may pick different output dirs — consider broadcasting it.
        oc.OmegaConf.update(config, "output_dir", f"{config.output_dir}-{utils.get_timestamp()}")
    os.makedirs(config.output_dir, exist_ok=True)
    if dist.is_initialized():
        dist.barrier()

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    # Set up logging
    logger = utils.init_logger(config.output_dir)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    logger.info(f"🖥 Device Count: {world_size}")
    logger.info(f"🎚 Config: {config}")
    tracker = None
    if utils.is_main_process():
        wandb.finish()
        wandb_id = (
            wandb.util.generate_id()
            if config.logging.wandb_id is None
            else config.logging.wandb_id
        )
        tracker = wandb.init(
            project=config.logging.wandb_project,
            entity=config.logging.wandb_entity,
            name=f"{config.name}-{wandb_id}",
            config=utils.flatten_dict(oc.OmegaConf.to_container(config)),
            mode=utils.default(config.logging.get("wandb_mode", None), "disabled"),
            id=wandb_id,
        )

    utils.set_seed(config.seed, use_device_specific_seeds=True)

    # Initialize tokenizer
    logger.info("⏳ Loading tokenizer...")
    # Turn off HuggingFace tokenizer parallelism warnings
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config.model.embed_model_name,
        use_fast=config.data.use_fast_tokenizer,
        use_auth_token=config.data.use_auth_token,
    )

    # Initialize model and optimizer
    # `embed_dim` is also read as a global by `train` when generating samples.
    embed_mat, embed_dim = layers.auto_extract_embed_mat(config.model.embed_model_name)
    inner_model = layers.MaskConditionalTransformer(
        embed_dim=utils.default(config.model.bottleneck_dim, embed_dim),
        model_dim=config.model.model_dim,
        max_seq_len=config.model.seq_len,
        head_dim=config.model.head_dim,
        num_heads=config.model.num_heads,
        use_abs_pos=config.model.use_abs_pos,
        use_rotary=config.model.use_rotary,
    )
    model = diffusion.TextSed(
        model=inner_model,
        embed_mat=embed_mat,
        noise_schedule=diffusion.get_noise_schedule(config.model.noise_schedule),
        bottleneck_dim=config.model.bottleneck_dim,
        mask_type=config.model.mask_type,
        max_num_spans=config.model.max_num_spans,
    )
    optimizer = torch.optim.AdamW(
        utils.get_grouped_params(
            model,
            config.optimizer.weight_decay,
            # (sic) keyword spelling is presumably what `utils.get_grouped_params`
            # declares — do not "fix" it here without changing that signature.
            exlcuded_modules=(
                torch.nn.LayerNorm,
                torch.nn.Embedding,
            ),
        ),
        lr=config.optimizer.lr,
        weight_decay=config.optimizer.weight_decay,
        betas=tuple(config.optimizer.betas),
        eps=config.optimizer.eps,
    )
    lr_scheduler = transformers.get_scheduler(
        name=config.optimizer.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.optimizer.warmup_steps,
        num_training_steps=config.train.total_steps,
    )
    scaler = torch.cuda.amp.GradScaler(enabled=config.train.use_amp)
    logger.info(f"🏘 Inner Model: {inner_model}")
    logger.info(f"👾 Parameter Count: ~{format(utils.param_count(model), ',')}")

    if torch.cuda.is_available():
        model.cuda()
    if dist.is_initialized():
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
        dist.barrier()
    # Init the model EMA after DDP to avoid state dict key mismatches during updates
    model_ema = copy.deepcopy(model)

    # State dicts are saved/loaded on the unwrapped modules (no DDP `module.` prefix).
    wrapper_types = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    raw_model = model.module if isinstance(model, wrapper_types) else model
    raw_model_ema = model_ema.module if isinstance(model_ema, wrapper_types) else model_ema

    # Load checkpoints if resuming training
    if config.train.checkpoint_path is not None:
        logger.info(f"⏳ Loading checkpoint from {config.train.checkpoint_path}")
        # Load onto CPU so every rank does not materialize the tensors on the GPU they
        # were saved from; `load_state_dict` copies them onto each parameter's device.
        # `weights_only=False` is required because the checkpoint also pickles the
        # OmegaConf config — only load checkpoints from trusted sources.
        checkpoint = torch.load(
            config.train.checkpoint_path, map_location="cpu", weights_only=False
        )
        raw_model.load_state_dict(checkpoint["model"], strict=True)
        raw_model_ema.load_state_dict(checkpoint["model_ema"], strict=True)
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        scaler.load_state_dict(checkpoint["scaler"])
        step_state = checkpoint["step"]
    else:
        step_state = 0

    # Feed batches to the same device `model.cuda()` placed the model on, instead of
    # the `train` default of "cuda:0" on every rank.
    train_device = (
        torch.device("cuda", torch.cuda.current_device())
        if torch.cuda.is_available()
        else "cpu"
    )

    logger.info("🏁 Starting training...")
    train(
        config,
        model,
        model_ema,
        optimizer,
        lr_scheduler=lr_scheduler,
        scaler=scaler,
        tokenizer=tokenizer,
        step_state=step_state,
        tracker=tracker,
        device=train_device,
    )