forked from THUDM/GLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_utils.py
388 lines (342 loc) · 16.8 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
import deepspeed
import torch
#from apex.optimizers import FusedAdam as Adam
from torch.optim import AdamW as Adam
from torch import distributed as dist
import mpu
from fp16 import FP16_Module, FP16_Optimizer, DynamicLossScaler
from learning_rates import AnnealingLR
from model import GLMModel, glm_get_params_for_weight_decay_optimization
from model import GLMForMultiTokenCloze, GLMForMultiTokenClozeFast, GLMForSingleTokenCloze, GLMForSequenceClassification
from model import PyTorchDistributedDataParallel as TorchDDP, DistributedDataParallel as LocalDDP
from model.modeling_bert import BertForMultipleChoice, BertForSequenceClassification
from utils import print_rank_0, get_checkpoint_name, get_checkpoint_iteration
def load_pretrained(model, checkpoint_path, args, task_tokens=None):
    """Load pretrained weights from ``checkpoint_path`` into ``model`` (in place).

    The model is first unwrapped (DeepSpeed engine when ``args.deepspeed``, then
    ``TorchDDP``, ``FP16_Module`` and any ``.model`` task-head wrapper) so that
    the checkpoint's ``sd['module']`` keys match the bare network's parameter
    names. For ``args.block_lm`` runs, a position / block-position embedding
    table that is shorter in the checkpoint than in the model is extended: the
    checkpoint rows are copied over the model's current rows, and the remaining
    model rows are kept. Loading uses ``strict=False``; missing and unexpected
    keys are reported through ``print_rank_0``.

    Args:
        model: the (possibly wrapped) model to load into.
        checkpoint_path: location handed to ``get_checkpoint_iteration``.
        args: run arguments; reads ``deepspeed``, ``block_lm``,
            ``continuous_prompt`` and ``prompt_init``.
        task_tokens: forwarded to ``prompt_spell.init_embedding`` when
            continuous prompts are initialised from the word embeddings.
    """
    load_dir, tag, release, _ = get_checkpoint_iteration(checkpoint_path)
    checkpoint_name = get_checkpoint_name(load_dir, tag, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading pretrained model {}'.format(
            torch.distributed.get_rank(), checkpoint_name))
    # Load the checkpoint on CPU; load_state_dict copies the tensors into the model.
    sd = torch.load(checkpoint_name, map_location='cpu')
    # Peel off wrappers so parameter names line up with the checkpoint keys.
    if args.deepspeed:
        model = model.module
    if isinstance(model, TorchDDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    if hasattr(model, "model"):
        model = model.model

    if args.block_lm:
        checkpoint_sd = sd['module']
        model_sd = model.state_dict()
        embedding_tables = (
            ("transformer.position_embeddings.weight", "position embedding"),
            ("transformer.block_position_embeddings.weight", "block position embedding"),
        )
        for key, label in embedding_tables:
            # Each table is gated on its own key being present in both the
            # checkpoint and the model.
            if key not in checkpoint_sd or key not in model_sd:
                continue
            state_weights = checkpoint_sd[key]
            target_weights = model_sd[key]
            original_length = state_weights.shape[0]
            target_length = target_weights.shape[0]
            if original_length < target_length:
                # Start from the model's current table and overwrite the leading
                # rows with the checkpoint values, so the shapes match on load.
                new_weights = target_weights.clone()
                new_weights[:original_length] = state_weights
                checkpoint_sd[key] = new_weights
                print_rank_0(f"Extend {label} to {target_length}")

    missing_keys, unexpected_keys = model.load_state_dict(sd['module'], strict=False)
    if missing_keys or unexpected_keys:
        print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")
    if args.continuous_prompt and args.prompt_init:
        model.prompt_spell.init_embedding(model.word_embeddings.weight.data, task_tokens)
def get_model(args, model_type=None, multi_token=True, num_labels=None, spell_length=None):
    """Build the network described by ``args`` and place it on the current GPU.

    With ``args.pretrained_bert`` a BERT model is loaded through
    ``from_pretrained`` (only 'multiple_choice' and 'classification' are
    supported). Otherwise a ``GLMModel`` backbone is built and, depending on
    ``model_type``, wrapped in a cloze or sequence-classification head;
    'generation' and ``None`` keep the bare backbone.

    The result is then optionally cast to half precision, moved to
    ``torch.cuda.current_device()``, wrapped in ``FP16_Module`` for fp16 runs
    and, when training without DeepSpeed, wrapped in the data-parallel
    container selected by ``args.DDP_impl``.

    Raises:
        NotImplementedError: if ``model_type`` has no matching builder.
    """
    print_rank_0('building GLM model ...')

    if args.pretrained_bert:
        if model_type == "multiple_choice":
            bert_cls, head_kwargs = BertForMultipleChoice, {}
        elif model_type == "classification":
            bert_cls, head_kwargs = BertForSequenceClassification, {"num_labels": num_labels}
        else:
            raise NotImplementedError
        model = bert_cls.from_pretrained(args.tokenizer_model_type,
                                         cache_dir=args.cache_dir,
                                         fp32_layernorm=args.fp32_layernorm,
                                         fp32_embedding=args.fp32_embedding,
                                         layernorm_epsilon=args.layernorm_epsilon,
                                         **head_kwargs)
    else:
        is_choice_or_cls = model_type == "multiple_choice" or model_type == "classification"
        # Choice / classification tasks only keep output_predict under cloze evaluation.
        output_predict = not is_choice_or_cls or bool(args.cloze_eval)
        # Any task head gathers the output; only the raw LM keeps it model-parallel.
        parallel_output = model_type is None
        if spell_length is not None:
            print_rank_0(f"Continuous spell length {spell_length}")
        model = GLMModel(num_layers=args.num_layers,
                         vocab_size=args.vocab_size,
                         hidden_size=args.hidden_size,
                         num_attention_heads=args.num_attention_heads,
                         embedding_dropout_prob=args.hidden_dropout,
                         attention_dropout_prob=args.attention_dropout,
                         output_dropout_prob=args.hidden_dropout,
                         max_sequence_length=args.max_position_embeddings,
                         max_memory_length=args.mem_length,
                         checkpoint_activations=args.checkpoint_activations,
                         checkpoint_num_layers=args.checkpoint_num_layers,
                         parallel_output=parallel_output,
                         relative_encoding=args.transformer_xl,
                         block_position_encoding=args.block_lm and not args.masked_lm,
                         output_predict=output_predict,
                         spell_length=spell_length,
                         spell_func=args.prompt_func,
                         attention_scale=args.attention_scale)
        if args.freeze_transformer:
            model.freeze_transformer(tune_prefix_layers=args.tune_prefix_layers)

        # Attach the task head.
        if model_type == 'multiple_choice' and args.cloze_eval and multi_token:
            cloze_cls = GLMForMultiTokenClozeFast if args.fast_decode else GLMForMultiTokenCloze
            model = cloze_cls(model, length_penalty=args.length_penalty)
        elif model_type == 'multiple_choice' and args.cloze_eval:
            model = GLMForSingleTokenCloze(model, take_softmax=args.adapet)
        elif model_type in ('multiple_choice', 'classification'):
            model = GLMForSequenceClassification(model, args.hidden_size, args.output_dropout, args.pool_token,
                                                 num_class=num_labels)
        elif model_type not in (None, 'generation'):
            raise NotImplementedError(model_type)

    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        param_count = sum(p.nelement() for p in model.parameters())
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, param_count), flush=True)

    # Cast before moving to the GPU so models too large for fp32 on device still fit.
    if args.fp16:
        model.half()
    model.cuda(torch.cuda.current_device())
    if args.fp16:
        model = FP16_Module(model)

    # Data-parallel wrapping only when training without DeepSpeed.
    if not args.deepspeed and (args.train_iters or args.epochs):
        if args.DDP_impl == 'torch':
            device = torch.cuda.current_device()
            model = TorchDDP(model, device_ids=[device], output_device=device,
                             process_group=mpu.get_data_parallel_group())
        elif args.DDP_impl == 'local':
            model = LocalDDP(model)
        else:
            print_rank_0("Skip DDP model")
    return model
def get_optimizer_param_groups(model):
    """Return the weight-decay / no-decay parameter groups of ``model``.

    Any ``LocalDDP``, ``TorchDDP`` or ``FP16_Module`` wrappers are stripped
    first, so grouping runs on the underlying network. Every parameter in the
    returned groups is guaranteed to carry a ``model_parallel`` attribute; it
    is set to ``False`` when the layer did not define one.
    """
    wrapper_types = (LocalDDP, TorchDDP, FP16_Module)
    inner = model
    while isinstance(inner, wrapper_types):
        inner = inner.module
    groups = glm_get_params_for_weight_decay_optimization(inner)
    for param in (p for group in groups for p in group['params']):
        if not hasattr(param, 'model_parallel'):
            param.model_parallel = False
    return groups
def get_optimizer(param_groups, args):
    """Build the (non-DeepSpeed) optimizer for ``param_groups``.

    Selection:
      * ``args.cpu_optimizer``: ``torch.optim.AdamW`` when ``args.cpu_torch_adam``
        is set, otherwise DeepSpeed's ``DeepSpeedCPUAdam``; only ``lr`` and
        ``weight_decay`` are passed.
      * ``args.optimizer == 'adam'``: the ``Adam`` class bound by this module's
        imports, with the configured betas and eps.
      * ``args.optimizer == 'adafactor'``: ``transformers.Adafactor`` with a
        fixed learning rate (``relative_step=False``, ``warmup_init=False``).
    With ``args.fp16`` the result is wrapped in ``FP16_Optimizer`` (static or
    dynamic loss scaling as configured).

    Raises:
        NotImplementedError: if ``args.deepspeed`` is set (DeepSpeed creates its
            own optimizer via ``deepspeed.initialize``) or ``args.optimizer`` is
            not a supported name.
    """
    # Fail fast: under DeepSpeed an optimizer built here would never be used.
    if getattr(args, "deepspeed", False):
        raise NotImplementedError("get_optimizer() is not used with DeepSpeed; "
                                  "the optimizer is created by deepspeed.initialize")

    if args.cpu_optimizer:
        # Decoupled weight decay (AdamW-style), matching the GPU Adam path.
        if args.cpu_torch_adam:
            optimizer_cls = torch.optim.AdamW
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            optimizer_cls = DeepSpeedCPUAdam
        optimizer = optimizer_cls(param_groups, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps)
    elif args.optimizer == 'adafactor':
        from transformers import Adafactor
        optimizer = Adafactor(param_groups, lr=args.lr, relative_step=False, warmup_init=False)
    else:
        raise NotImplementedError(f"Unsupported optimizer: {args.optimizer!r}")
    print(f'Optimizer = {optimizer.__class__.__name__}')

    # Wrap into fp16 optimizer (master fp32 weights + loss scaling).
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={
                                       'scale_window': args.loss_scale_window,
                                       'min_scale': args.min_scale,
                                       'delayed_shift': args.hysteresis})
    return optimizer
def get_learning_rate_scheduler(optimizer, args):
    """Create the ``AnnealingLR`` schedule that drives ``optimizer``.

    The schedule length is ``args.lr_decay_iters`` when given, otherwise
    ``args.train_iters``. For fine-tuning it is divided by
    ``args.gradient_accumulation_steps``, then clamped to at least 1. The
    first ``args.warmup`` fraction of that length is warmup; the remainder is
    passed to ``AnnealingLR`` as ``num_iters``.
    """
    total_iters = args.train_iters if args.lr_decay_iters is None else args.lr_decay_iters
    if args.finetune:
        # train_step only calls lr_scheduler.step() once per accumulation cycle;
        # presumably fine-tuning counts iterations per micro-batch — confirm.
        total_iters //= args.gradient_accumulation_steps
    total_iters = max(1, total_iters)
    warmup_iters = args.warmup * total_iters
    return AnnealingLR(optimizer,
                       start_lr=args.lr,
                       warmup_iter=warmup_iters,
                       num_iters=total_iters - warmup_iters,
                       decay_style=args.lr_decay_style,
                       last_iter=-1,
                       decay_ratio=args.lr_decay_ratio)
def setup_model_and_optimizer(args, model_type=None, multi_token=True, num_labels=None, spell_length=None):
    """Build the model and, when training is configured, its optimizer and LR schedule.

    Returns:
        ``(model, optimizer, lr_scheduler)``. ``optimizer`` and ``lr_scheduler``
        are ``None`` when the training condition below is false. Under
        ``args.deepspeed`` the returned model and optimizer are the ones
        produced by ``deepspeed.initialize``.
    """
    model = get_model(args, model_type=model_type, multi_token=multi_token, num_labels=num_labels,
                      spell_length=spell_length)
    param_groups = get_optimizer_param_groups(model)

    # NOTE(review): Python precedence binds `and` tighter than `or`, so the
    # epochs/train_iters check only applies to the data_dir case. The explicit
    # parentheses below reproduce that existing behavior; confirm it is intended.
    has_training = (args.train_data is not None
                    or (args.data_dir is not None and (args.epochs > 0 or args.train_iters > 0)))
    if not has_training:
        return model, None, None

    if args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")
        model, optimizer, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=param_groups,
            args=args,
            mpu=mpu,
            dist_init_required=False,
        )
    else:
        optimizer = get_optimizer(param_groups, args)
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)
    return model, optimizer, lr_scheduler
def backward_step(optimizer, model, lm_loss, args, timers):
    """Backpropagate ``lm_loss`` and post-process the resulting gradients.

    Runs the backward pass (through the DeepSpeed engine, the FP16 optimizer,
    or plain autograd). It all-reduces gradients only when the local DDP
    wrapper is in use. Outside DeepSpeed it then updates the fp16 master
    gradients and applies gradient clipping.

    Args:
        optimizer: the ``FP16_Optimizer`` when ``args.fp16`` (non-DeepSpeed);
            otherwise unused here.
        model: the training model; must provide ``backward`` under DeepSpeed
            and ``allreduce_params`` when ``args.DDP_impl == 'local'``.
        lm_loss: loss tensor for this micro-batch.
        args: reads ``deepspeed``, ``fp16``, ``DDP_impl``, ``fp32_allreduce``
            and ``clip_grad``.
        timers: callable returning a named timer. The ``'allreduce'`` timer is
            started/stopped around the reduction, or reset when none is done.

    Returns:
        ``lm_loss`` unchanged.
    """
    # Backward pass.
    if args.deepspeed:
        model.backward(lm_loss)
    elif args.fp16:
        # Master grads are updated explicitly below, after the all-reduce.
        optimizer.backward(lm_loss, update_master_grads=False)
    else:
        lm_loss.backward()

    if not args.deepspeed and args.DDP_impl == 'local':
        timers('allreduce').start()
        model.allreduce_params(reduce_after=False, fp32_allreduce=args.fp32_allreduce)
        timers('allreduce').stop()
    else:
        # DeepSpeed and torch DDP reduce gradients during backward. For any other
        # DDP_impl, get_model() leaves the model unwrapped ("Skip DDP model"), so
        # there is no allreduce_params to call. Reset the timer to keep logs consistent.
        timers('allreduce').reset()

    if not args.deepspeed:
        if args.fp16:
            optimizer.update_master_grads()
        # Clipping gradients helps prevent the exploding gradient.
        # NOTE(review): train_step calls this once per micro-batch, so with
        # gradient accumulation the fp32 path clips partially accumulated grads.
        if args.clip_grad > 0:
            if args.fp16:
                optimizer.clip_master_grads(args.clip_grad)
            else:
                mpu.clip_grad_norm(model.parameters(), args.clip_grad)
    return lm_loss
def see_memory_usage(message, force=False):
    """Print CUDA memory statistics (in GiB) on global rank 0.

    This is a no-op unless ``force`` is true. When it does run, every rank must
    call it, because it begins with ``dist.barrier()``.

    Args:
        message: header line printed before the statistics.
        force: enables the report; the default ``False`` skips everything.
    """
    if not force:
        return
    dist.barrier()
    if dist.get_rank() != 0:
        return
    gib = 1024 * 1024 * 1024
    print(message)
    print("Memory Allocated ", torch.cuda.memory_allocated() / gib, "GigaBytes")
    print("Max Memory Allocated ", torch.cuda.max_memory_allocated() / gib, "GigaBytes")
    # memory_cached / max_memory_cached are deprecated aliases of the *_reserved APIs.
    print("Cache Allocated ", torch.cuda.memory_reserved() / gib, "GigaBytes")
    print("Max cache Allocated ", torch.cuda.max_memory_reserved() / gib, "GigaBytes")
    print(" ")
def train_step(data_iterator, model, optimizer, lr_scheduler, args, timers, forward_step_func, mems=None,
               single_step=False):
    """Run forward/backward passes until one optimizer step completes.

    Micro-batches are drawn through ``forward_step_func`` until one of these holds:
      * without DeepSpeed, ``args.gradient_accumulation_steps`` finite-loss
        micro-batches have been processed and ``optimizer.step()`` ran;
      * under DeepSpeed, the engine reports an accumulation boundary;
      * ``single_step`` is true (stop after one micro-batch regardless).
    A micro-batch whose data-parallel-reduced loss is inf/NaN gets no
    backward pass and resets ``mems`` to ``[]``.

    Args:
        data_iterator: passed through to ``forward_step_func``.
        model, optimizer, lr_scheduler: training objects (see
            ``setup_model_and_optimizer``).
        args: run arguments.
        timers: callable returning named timers.
        forward_step_func: called as ``f(data_iterator, model, args, timers, mems)``
            and expected to return ``(loss, mems, _)``.
        mems: memory list carried between calls; ``None`` means empty.
        single_step: stop after one micro-batch even if no optimizer step ran.

    Returns:
        ``(lm_loss_total, skipped_iter, mems)``:
          * ``lm_loss_total`` sums the reduced losses. Without DeepSpeed each
            loss is pre-divided by the accumulation steps; under DeepSpeed the
            sum is divided by the number of counted micro-batches.
          * ``skipped_iter`` is ``1`` if the update was skipped because the
            fp16 optimizer overflowed, else ``0``.
          * ``mems`` is the updated memory list.
    """
    lm_loss_total, count = 0.0, 0
    mems = [] if mems is None else mems
    if not args.deepspeed:
        optimizer.zero_grad()
    # NOTE(review): without single_step, a persistently non-finite loss keeps
    # this loop drawing micro-batches indefinitely — consider a retry limit.
    while True:
        skipped_iter, complete = 0, False
        # Forward model for one step.
        timers('forward').start()
        lm_loss, mems, _ = forward_step_func(data_iterator, model, args, timers, mems)
        timers('forward').stop()
        if not args.deepspeed:
            # Out-of-place: never mutate the tensor returned by forward_step_func,
            # which autograd or the caller may still reference.
            lm_loss = lm_loss / args.gradient_accumulation_steps

        # Sum over the data-parallel group, then divide by world_size / model_parallel_size
        # (the data-parallel size, assuming world = dp * mp).
        reduced_loss = lm_loss.detach().clone().view(1)
        torch.distributed.all_reduce(reduced_loss.data, group=mpu.get_data_parallel_group())
        reduced_loss.data = reduced_loss.data / (args.world_size / args.model_parallel_size)

        if not DynamicLossScaler._has_inf_or_nan(reduced_loss):
            lm_loss_total += reduced_loss
            count += 1

            # Calculate gradients, reduce across processes, and clip.
            timers('backward').start()
            backward_step(optimizer, model, lm_loss, args, timers)
            timers('backward').stop()

            # Update parameters.
            timers('optimizer').start()
            if args.deepspeed:
                if model.is_gradient_accumulation_boundary():
                    model.step()
                    complete = True
                    if not (args.fp16 and optimizer.overflow):
                        lr_scheduler.step()
                    else:
                        skipped_iter = 1
                else:
                    model.step()
            else:
                if count == args.gradient_accumulation_steps:
                    optimizer.step()
                    complete = True
                    # Update learning rate unless the fp16 step was skipped on overflow.
                    if not (args.fp16 and optimizer.overflow):
                        lr_scheduler.step()
                    else:
                        skipped_iter = 1
            timers('optimizer').stop()
            if complete:
                break
        else:
            print_rank_0("Found NaN loss, skip backward")
            del lm_loss, reduced_loss
            # Discard the memory state produced alongside the bad loss.
            mems = []
        if single_step:
            break
    if args.deepspeed and count > 0:
        # Only the non-DeepSpeed path pre-divides the loss, so average here.
        # The count guard avoids ZeroDivisionError when single_step hit a NaN loss.
        lm_loss_total = lm_loss_total / count
    return lm_loss_total, skipped_iter, mems