-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathmain_training.py
958 lines (797 loc) · 31.4 KB
/
main_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
import argparse
import os
import time
import colorama
import torch
import torch
import torch.nn as nn
from colorama import Fore
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
StateDictType,
)
import model_checkpointing
import torch.distributed as dist
import environment
from contextlib import contextmanager
# Gate for bf16 mixed precision: fsdp_main only adopts cfg.mp_policy when this is truthy.
# NOTE(review): this binds `environment.verify_bfloat_support` WITHOUT calling it. That is only
# correct if it is a precomputed boolean; if it is a function, the object is always truthy and
# the bf16 check never fails — confirm against environment.py.
bf16_ready = environment.verify_bfloat_support
from torch.utils.data import DistributedSampler
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
colorama.init(autoreset=True) # reset after every line
import performance
import contextlib
# Shared no-op context manager, used in place of init_empty_weights() when deferred init is off.
_none_context = contextlib.nullcontext()
# add DDP support
from torch.nn.parallel import DistributedDataParallel as DDP
# optimizer packages (bitsandbytes, optimizers, dadaptation, ...) are imported lazily at the
# point of use, depending on cfg.optimizer
@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """
    Context manager under which models are initialized with all parameters on the meta device.

    This creates an "empty" model whose parameters have shapes and dtypes but no storage.
    It is useful when simply initializing the model would exceed available RAM. It mirrors
    ``accelerate.init_empty_weights``.

    Args:
        include_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to also put all buffers on the meta device while initializing.

    Example:

    ```python
    import torch.nn as nn

    # Initialize a model with 100 billion parameters in no time and without using any RAM.
    with init_empty_weights():
        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
    ```

    Note:
        Any model created under this context manager has no weights, so you cannot do
        something like `model.to(some_device)` with it. Its parameters must be materialized
        first; this file passes ``my_init_fn`` to FSDP as ``param_init_fn`` for that purpose.

    The original ``requires_grad`` flag of every parameter is preserved, so frozen
    parameters stay frozen. ``nn.Module.register_parameter`` (and ``register_buffer`` when
    patched) are always restored on exit, even if model construction raises.
    """
    old_register_parameter = nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            # Copy, so we never write into the live parameter's instance __dict__.
            kwargs = dict(module._parameters[name].__dict__)
            # nn.Parameter defaults to requires_grad=True; carry the original flag over
            # so frozen parameters are not silently made trainable.
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
                module._parameters[name].to(torch.device("meta")), **kwargs
            )

    def register_empty_buffer(module, name, buffer, *args, **kwargs):
        # *args/**kwargs forward the optional `persistent` flag of register_buffer.
        old_register_buffer(module, name, buffer, *args, **kwargs)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(torch.device("meta"))

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
@torch.no_grad()
def my_init_fn(module: nn.Module):
    """Give every meta-device parameter real (uninitialized) CUDA storage.

    Intended as FSDP's ``param_init_fn``. Each direct parameter of every submodule that
    is still on the meta device and has not already been flattened by FSDP is replaced
    by a fresh ``nn.Parameter`` of the same shape and dtype on CUDA. The new storage
    comes from ``torch.empty_like``, so its values are arbitrary until loaded or
    initialized elsewhere.
    """
    cuda = torch.device("cuda")
    for owner in module.modules():
        # Collect first, then replace, so the parameter dict is not mutated mid-iteration.
        pending = [
            (pname, p)
            for pname, p in owner.named_parameters(recurse=False)
            if not _is_fsdp_flattened(p) and p.is_meta
        ]
        for pname, p in pending:
            # nn.init.uniform_(...) was deliberately left disabled here
            setattr(owner, pname, nn.Parameter(torch.empty_like(p, device=cuda)))
def print_model(model, file_name, rank):
    """Write the printed representation of ``model`` to ``file_name`` on global rank 0.

    Used to inspect how the model was wrapped. Any other rank returns immediately
    without touching the filesystem. The file is overwritten and encoded as UTF-8.

    Args:
        model: any object; its ``str()`` form is written.
        file_name: destination path.
        rank: global rank of the calling process.
    """
    if rank != 0:
        return
    # the `with` block closes the file; no explicit close() needed
    with open(file_name, "w", encoding="utf-8") as external_file:
        print(f"model wrapping = \n{model}\n\n", file=external_file)
def print_memory_summary(prefix, device):
    """On global rank 0, print peak CUDA memory usage for ``device``, then reset the peaks.

    The rank is read from the ``RANK`` environment variable (set by torchrun). An unset
    ``RANK`` is treated as rank 0 rather than raising ``TypeError``. All three figures
    come from the same ``device``; values are floor-divided by 1e9 and labelled "GB".

    Args:
        prefix: label printed at the start of the line.
        device: CUDA device (index, string such as "cuda", or torch.device).
    """
    rank = int(os.getenv("RANK", "0"))
    if rank == 0:
        # query the requested device, consistent with max_memory_* below
        peak_memory_active = torch.cuda.memory_stats(device).get(
            "active_bytes.all.peak", 0
        )
        print(
            f"{prefix}, GPU peak memory allocation: {torch.cuda.max_memory_allocated(device) // 1e9}GB, "
            f"GPU peak memory reserved: {torch.cuda.max_memory_reserved(device) // 1e9}GB, "
            f"GPU peak memory active: {peak_memory_active // 1e9}GB"
        )
        # start a fresh peak window for the next summary
        torch.cuda.reset_peak_memory_stats(device)
def setup():
    """Create the default process group on the NCCL backend.

    No rank/world-size/address arguments are passed: torchrun has already exported
    them to the environment, which the default rendezvous reads.
    """
    backend = "nccl"
    dist.init_process_group(backend=backend)
def setup_environ_flags(cfg, rank):
    """Export debugging-related environment variables for this process.

    - Always sets ``TORCH_SHOW_CPP_STACKTRACES=1``.
    - If ``cfg.nccl_debug_handler`` is set, enables NCCL async error handling under both
      the legacy ``NCCL_ASYNC_ERROR_HANDLING`` name and ``TORCH_NCCL_ASYNC_ERROR_HANDLING``,
      the name newer PyTorch releases read.
    - If ``cfg.distributed_debug`` is set, sets ``TORCH_DISTRIBUTED_DEBUG=DETAIL``; rank 0
      prints a notice.

    These flags should be exported before the process group is created.

    Args:
        cfg: config exposing ``nccl_debug_handler`` and ``distributed_debug``.
        rank: global rank of this process.
    """
    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
    if cfg.nccl_debug_handler:
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    if cfg.distributed_debug:
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
        if rank == 0:
            print("--> running with torch dist debug set to detail")
def cleanup():
    """Synchronize all ranks, then tear down the default process group.

    The barrier comes first so that no rank destroys the group while another
    rank is still issuing collectives on it.
    """
    for teardown_step in (dist.barrier, dist.destroy_process_group):
        teardown_step()
def clear_gpu_cache(rank=None):
    """Release unused cached CUDA allocator blocks held by this process.

    Every rank calls ``torch.cuda.empty_cache()``; only rank 0 prints a notice,
    so the log gets a single line rather than one per rank.
    """
    announce = rank == 0
    if announce:
        print("clearing gpu cache for all ranks")
    torch.cuda.empty_cache()
def setup_tasks(rank, world_size, cfg):
    """Run the basic per-process setup steps before any model or data work.

    Debug/NCCL environment flags are exported *before* the process group is created.
    NCCL async error handling is configured when the process group is constructed,
    so exporting the flag afterwards would not affect the already-initialised group.

    Args:
        rank: global rank of this process.
        world_size: total number of processes (not used here).
        cfg: training config, forwarded to ``setup_environ_flags``.
    """
    setup_environ_flags(cfg, rank)
    setup()
    clear_gpu_cache(rank)  # need to call torch set device first?
    # set_printing()
def zero_print(rank, x):
    """Print ``x`` only when ``rank`` is 0; all other ranks stay silent."""
    if rank != 0:
        return
    print(x)
# ------ helpers for fsdp_main -----------------


def _apply_tensor_parallel(model, world_size, rank, model_parallel_size=2):
    """Tensor-parallelize every entry of ``model.blocks`` in place; return the 2-D mesh.

    The mesh is laid out as ``[dp, tp]``: shape ``(world_size // model_parallel_size,
    model_parallel_size)``. Dimension 1 holds the ``model_parallel_size`` ranks of each
    tensor-parallel group (``tp_mesh_dim=1``), and dimension 0 is the data-parallel
    dimension used for FSDP. ``world_size`` must be divisible by ``model_parallel_size``.

    Each block's ``attn.qkv`` and ``mlp.linear1`` are column-sharded, and
    ``attn.out_proj`` and ``mlp.linear2`` are row-sharded.

    Raises:
        RuntimeError: if 2D TP/FSDP support cannot be initialised, the run is not
            distributed, or a block cannot be parallelized.
    """
    print("Tensor Parallel activated - init start\n")
    from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp

    tp_available = False
    try:
        from torch.distributed._tensor import DeviceMesh
        from torch.distributed.tensor.parallel import (
            ColwiseParallel,
            RowwiseParallel,
            parallelize_module,
        )

        # need to setup hooks for TP
        tp_available = enable_2d_with_fsdp()
    except Exception as e:
        print(f"Exception during TP init - {e=}\n")
    if not tp_available:
        raise RuntimeError("fsdp did not init")
    print(f"tp_initialized - rank {rank}\n")

    if int(os.environ.get("RANK", -1)) == -1:  # verify distributed run
        raise RuntimeError(
            "this config assumes setup for Tensor Parallel - distributed not ready here."
        )

    # 2-D mesh is [dp, tp]: each row is one tensor-parallel group of size model_parallel_size
    twod_mesh = DeviceMesh(
        device_type="cuda",
        mesh=torch.arange(0, world_size).view(-1, model_parallel_size),
    )
    if rank == 0:
        print(f"{twod_mesh=}")

    for i, block in enumerate(model.get_submodule("blocks")):
        if rank == 0:
            print(f"\nparallelization of block {i}")
        try:
            # parallelize_module shards the block's submodules in place
            parallelize_module(
                module=block,
                device_mesh=twod_mesh,
                parallelize_plan={
                    "attn.qkv": ColwiseParallel(),
                    "attn.out_proj": RowwiseParallel(),
                    "mlp.linear1": ColwiseParallel(),
                    "mlp.linear2": RowwiseParallel(),
                },
                tp_mesh_dim=1,
            )
        except Exception as e:
            raise RuntimeError(f"failed to TP block {i}") from e
    return twod_mesh


def _build_optimizer(model, cfg, rank, lr=9e-4, weight_decay=0.002):
    """Create the optimizer named by ``cfg.optimizer``.

    Supported names are "int8", "AnyPrecision", "dadapt_adanip", "dadapt_adam" and "AdamW".
    - "int8" uses bitsandbytes ``Adam8bit``.
    - The two D-Adaptation variants use ``lr=1.0``.
    - AdamW uses a fixed ``lr=0.0005`` rather than ``lr``.

    Third-party packages are imported only for the selected optimizer.

    Raises:
        ValueError: for an unrecognised ``cfg.optimizer``.
    """
    name = cfg.optimizer
    if name == "int8":
        import bitsandbytes as bnb

        optimizer = bnb.optim.Adam8bit(
            model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=False
        )
        if rank == 0:
            print("Running with 8 bit optimizer")
    elif name == "AnyPrecision":
        import optimizers

        optimizer = optimizers.AnyPrecisionAdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            momentum_dtype=cfg.ap_momentum_dtype,
            variance_dtype=cfg.ap_variance_dtype,
            use_kahan_summation=cfg.ap_use_kahan_summation,
        )
        if rank == 0:
            print(
                f"Running with AnyPrecision Optimizer, momo={cfg.ap_momentum_dtype}, var = {cfg.ap_variance_dtype}, kahan summation = {cfg.ap_use_kahan_summation}"
            )
    elif name == "dadapt_adanip":
        from adanip_exp import DAdaptAdanIP

        optimizer = DAdaptAdanIP(
            model.parameters(),
            lr=1.0,
            weight_decay=weight_decay,
        )
        if rank == 0:
            print("Running with DAdapt AdanIP optimizer")
    elif name == "dadapt_adam":
        from dadaptation import DAdaptAdam

        # use the class imported for this branch (DAdaptAdanIP is not in scope here)
        optimizer = DAdaptAdam(
            model.parameters(),
            lr=1.0,
            weight_decay=weight_decay,
        )
        if rank == 0:
            print("Running with DAdapt optimizer")
    elif name == "AdamW":
        use_fused_optimizer = cfg.use_fused_optimizer
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=0.0005,
            weight_decay=weight_decay,
            fused=use_fused_optimizer,
        )
        if rank == 0:
            print(
                f"Running with AdamW optimizer, with fusion set to {use_fused_optimizer}"
            )
    else:
        # fail here instead of passing None into the LR scheduler
        raise ValueError(
            f"unsupported optimizer {name!r}; expected one of "
            "'int8', 'AnyPrecision', 'dadapt_adanip', 'dadapt_adam', 'AdamW'"
        )
    return optimizer


def _stable_step_average(durations, warmup_steps):
    """Average ``durations`` after skipping the first ``warmup_steps`` entries.

    Returns:
        ``(count, average rounded to 4 places)``, or ``(0, None)`` when no entries
        remain after warmup (avoids dividing by zero or a negative count).
    """
    measured = durations[warmup_steps:]
    if not measured:
        return 0, None
    return len(measured), round(sum(measured) / len(measured), 4)


# ------ main code loop -----------------
def fsdp_main():
    """Run one rank's share of training; torchrun launches one process per rank.

    Steps:
    1. Read the config from the model-specific ``config`` module selected in ``__main__``.
    2. Build datasets and distributed samplers.
    3. Build the model, optionally with deferred (meta-device) init or from timm.
    4. Optionally apply 2D tensor parallelism.
    5. Wrap the model in DDP or FSDP.
    6. Create the optimizer and a linear warmup LR scheduler.
    7. Run ``cfg.num_epochs`` epochs of ``config.train``, with optional
       ``config.validation`` and checkpointing.
    8. Print a run summary on local rank 0.

    Requires the torchrun environment variables ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE``.
    """
    import collections
    from functools import partial

    cfg = config.train_config()  # loads from defaults

    # torchrun specific
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # wrapper to avoid cluttering with if rank==0...
    def rank_print(msg):
        if rank == 0:
            print(f"{msg}")

    torch.cuda.manual_seed(cfg.seed + local_rank)
    torch.manual_seed(cfg.seed + local_rank)

    if rank == 0:
        print(f"--> World Size = {world_size}\n")
        print(f"--> Device_count = {torch.cuda.device_count()}")
        print(f"--> running with these defaults {cfg}")

    setup_tasks(rank, world_size, cfg)

    if torch.distributed.is_initialized():
        torch.cuda.set_device(local_rank)

    _zero_print = partial(zero_print, local_rank)

    # setup memory tracking for perf (local rank 0 only)
    memmax = performance.Memory_Maximizer() if local_rank == 0 else None

    # ==== use new transformer wrapper
    my_auto_wrap_policy = config.get_policy()
    rank_print(f"wrapping policy is {my_auto_wrap_policy}")

    # ---- datasets
    # Real-dataset flags only apply to non-synthetic runs; older configs may lack them.
    use_pokemon = use_beans = use_food = False
    use_label_singular = False
    if not cfg.use_synthetic_data:
        use_pokemon = getattr(cfg, "use_pokemon_dataset", False)
        use_beans = getattr(cfg, "use_beans_dataset", False)
        use_food = getattr(cfg, "use_food", False)
    if sum(bool(flag) for flag in (use_pokemon, use_beans, use_food)) > 1:
        raise ValueError("multiple datasets enabled.")

    val_dataset = None
    if use_pokemon:
        dataset, val_dataset = config.get_pokemon_dataset()
    elif use_beans:
        dataset, val_dataset = config.get_beans_dataset()
    elif use_food:
        dataset, val_dataset = config.get_universal_dataset()
        use_label_singular = True
    else:
        dataset = config.get_dataset()

    # metric history is kept on global rank 0 only
    _stats = None
    if not cfg.use_synthetic_data and rank == 0:
        _stats = collections.defaultdict(list)
        _stats["best_accuracy"] = 0.00

    # samplers ----
    train_sampler = DistributedSampler(
        dataset, rank=dist.get_rank(), num_replicas=dist.get_world_size(), shuffle=True
    )
    val_sampler = None
    if cfg.run_validation:
        if val_dataset is None:
            # NOTE(review): falls back to the same config.get_dataset() call used for
            # training (a train=False variant was once intended) — confirm.
            val_dataset = config.get_dataset()
        val_sampler = DistributedSampler(
            val_dataset, rank=dist.get_rank(), num_replicas=dist.get_world_size()
        )

    if local_rank == 0:
        print(f"\n--> Prepping {cfg.model_name} model ...\n")
    print(f"stats is ready....? {_stats=}, {local_rank=}, {rank=}")

    # --- build model
    use_timm = getattr(cfg, "use_timm", False)  # older configs lack the timm flag
    if not use_timm:
        print("***** building the model ******")
        use_deferred_init = getattr(cfg, "use_deferred_init", False)
        # use the guarded flag so configs without use_deferred_init still work
        with init_empty_weights() if use_deferred_init else _none_context:
            _zero_print(f"using deferred? {use_deferred_init}")
            use_fused_attention = cfg.use_fused_attention
            use_parallel = getattr(cfg, "use_parallel_attention", False)
            use_mqa = getattr(cfg, "use_multi_query_attention", False)
            print(f"**** Use MQA = {use_mqa}")
            if use_parallel:
                model = config.build_model(
                    cfg.model_name,
                    use_parallel_attention=use_parallel,
                    use_fused_attention=use_fused_attention,
                    use_multi_query_attention=use_mqa,
                )
            else:
                model = config.build_model(
                    cfg.model_name,
                    use_parallel_attention=False,
                    use_fused_attention=use_fused_attention,
                )
            print_memory_summary("vit", "cuda")
            time.sleep(2)
        # TODO - we used to run HF checkpointing generically...adding this for now.
        if cfg.hf_t5_checkpointing:
            model.decoder.gradient_checkpointing = True
            model.encoder.gradient_checkpointing = True
    else:
        # if you are here and this import fails - run:
        # git clone https://github.com/huggingface/pytorch-image-models.git
        # and then in the cloned main dir, run 'python setup.py develop'
        import timm

        model = timm.create_model(
            cfg.model_name,
            pretrained=False,
            act_layer=nn.GELU,
            qk_norm=True,
            num_classes=cfg.num_categories,
        )

    num_params = None
    if local_rank == 0:
        print(f"--> {cfg.model_name} built.")
        num_params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"built model with {num_params}M params")

    # ---- precision
    mp_policy = None
    if cfg.use_mixed_precision and bf16_ready:
        mp_policy = cfg.mp_policy
        rank_print("bf16 check passed")
        rank_print(f"\n--> Running with mixed precision {cfg.mp_policy} policy")
    elif cfg.use_mixed_precision:
        # only warn when mixed precision was actually requested
        rank_print("--> Warning - bf16 support not available. Using fp32")

    # if not using mixed precision, turn on TF32 for matmul?
    if not cfg.use_mixed_precision and cfg.use_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        rank_print("--> TF32 support for matmul enabled. ")

    init_start = time.perf_counter()

    # preload checkpoint if desired (before wrapping)
    if cfg.load_model_checkpoint:
        if cfg.checkpoint_type == StateDictType.FULL_STATE_DICT:
            model_checkpointing.load_model_checkpoint(model, rank, cfg)
        elif cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
            # NOTE(review): LOCAL_STATE_DICT is loaded here AND again after wrapping
            # below; confirm whether this pre-wrap load is intended.
            model_checkpointing.load_distributed_model_checkpoint(model, rank, cfg)

    prefetch_policy = cfg.backward_prefetch
    if rank == 0:
        print(f"backward prefetch set to {prefetch_policy}")
        print(f"sharding set to {cfg.sharding_strategy}")
        print(f"--> Batch Size = {cfg.batch_size_training}")

    # model weights to BF16? (pure bf16 replaces any mixed-precision policy)
    if cfg.model_weights_bf16:
        model = model.to(torch.bfloat16)
        mp_policy = None
        rank_print("--> Model converted to BF16.\nRunning in ** PURE ** BFloat mode")

    # ----- Add 2D Tensor Parallel if activated (in config)
    process_group_fsdp = None
    if cfg.use_tp:
        twod_mesh = _apply_tensor_parallel(model, world_size, rank)
        # FSDP shards over the data-parallel (first) mesh dimension
        process_group_fsdp = twod_mesh.get_dim_groups()[0]

    # ----- main FSDP or DDP init -----------
    if cfg.use_ddp:
        model.to("cuda")
        model = DDP(
            model,
            device_ids=[local_rank],
            bucket_cap_mb=cfg.ddp_bucket_size,
            gradient_as_bucket_view=cfg.ddp_use_gradient_view,
        )
    else:
        # ----- FSDP Init --------------------
        from torch.distributed.fsdp import ShardingStrategy as sharding
        from sharding_groups_helper import create_device_mesh

        _global_device_mesh = None
        # handle scaling groups
        if cfg.sharding_strategy in (
            sharding.HYBRID_SHARD,
            sharding._HYBRID_SHARD_ZERO2,
        ):
            sharding_group_size = None
            replica_group_size = None
            if cfg.sharding_group_size:
                rank_print(f"{cfg.sharding_group_size=}")
                sharding_group_size = cfg.sharding_group_size
            if cfg.replica_group_size:
                replica_group_size = cfg.replica_group_size
                rank_print(f"{replica_group_size=}")
            _global_device_mesh = create_device_mesh(
                replica_group_size, sharding_group_size
            )
            rank_print(f"{_global_device_mesh=}")

        model = FSDP(
            model,
            device_mesh=_global_device_mesh,
            process_group=process_group_fsdp,
            auto_wrap_policy=my_auto_wrap_policy,
            mixed_precision=mp_policy,
            backward_prefetch=prefetch_policy,
            sharding_strategy=cfg.sharding_strategy,
            device_id=torch.cuda.current_device(),
            forward_prefetch=cfg.forward_prefetch,
            use_orig_params=cfg.use_orig_params,
            limit_all_gathers=cfg.limit_all_gathers,
            param_init_fn=my_init_fn,
        )
        print_memory_summary("vit", "cuda")
        time.sleep(2)

        if (
            cfg.load_model_checkpoint
            and cfg.checkpoint_type == StateDictType.SHARDED_STATE_DICT
        ):
            model_checkpointing.load_model_sharded(model, rank, cfg)

        if cfg.fsdp_activation_checkpointing:
            config.fsdp_checkpointing(model)
            rank_print("--> FSDP activation checkpointing in use")

    if cfg.use_torch_compile:
        model = torch._dynamo.optimize("inductor")(model)
        rank_print("--> Torch.compile in use")

    # print sharding plan?
    if rank == 0 and cfg.print_sharding_plan:
        print(model)

    # postload checkpoint if desired
    if (
        cfg.load_model_checkpoint
        and cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT
    ):
        model_checkpointing.load_distributed_model_checkpoint(model, rank, cfg)

    if local_rank == 0:
        init_time = time.perf_counter() - init_start
        print(f"local rank {local_rank} init time = {init_time}")

    # data loader -------------
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size_training,
        num_workers=cfg.num_workers_dataloader,
        pin_memory=False,
        sampler=train_sampler,
    )
    val_loader = None
    if cfg.run_validation:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=cfg.val_batch_size,
            num_workers=cfg.num_workers_dataloader,
            pin_memory=False,
            sampler=val_sampler,
        )

    # memory and timing tracking (local rank 0 only)
    if local_rank == 0:
        memmax.start()
        tracking_duration = []
    else:
        tracking_duration = None

    # optimizer ----------
    optimizer = _build_optimizer(model, cfg, rank)

    # linear warmup
    from torch.optim.lr_scheduler import LinearLR

    warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=50)

    # logged metrics go to a timestamped stats file name
    _metric_logger = None
    if cfg.run_validation:
        from metric_logging.metric_logger import get_date_time

        _metric_logger = "stats_" + get_date_time() + ".txt"

    # load optimizer checkpoint
    if cfg.load_optimizer:
        model_checkpointing.load_optimizer_checkpoint(model, optimizer, rank, cfg)

    total_steps = None
    if cfg.total_steps_to_run:
        total_steps = cfg.total_steps_to_run - 1  # fix off by one for step count

    @contextlib.contextmanager
    def maybe_run_profiler(cfg, *args, **kwargs):
        """Yield a torch profiler when ``cfg.run_profiler`` is set, otherwise ``None``."""
        if not cfg.run_profiler:
            yield None
            return
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                cfg.profile_folder
            ),
            profile_memory=True,
            with_stack=False,
            record_shapes=True,
        ) as profiler:
            yield profiler

    if cfg.run_profiler:
        print(f"Profiling active. Traces will be saved at {cfg.profile_folder}")

    with maybe_run_profiler(cfg) as torch_profiler:
        for i in range(cfg.num_epochs):
            if rank == 0:
                print(f"Epoch: {i} starting...")
                if not cfg.use_synthetic_data:
                    assert _stats is not None, "missing stats in main"
            config.train(
                model,
                data_loader,
                torch_profiler,
                optimizer,
                memmax,
                local_rank,
                tracking_duration,
                total_steps,
                use_synthetic_data=cfg.use_synthetic_data,
                use_label_singular=use_label_singular,
                stats=_stats,
                lr_scheduler=warmup_scheduler,
            )
            # a fixed step budget is a benchmark run: stop after one train pass
            if cfg.total_steps_to_run is not None:
                break

            if cfg.run_validation:
                if rank == 0:
                    assert _stats is not None, "no stats in main"
                with torch.no_grad():
                    config.validation(
                        model,
                        local_rank,
                        rank,
                        val_loader,
                        world_size,
                        stats=_stats,
                        use_label_singular=use_label_singular,
                        metric_logger=_metric_logger,
                    )

            # checkpointing for model and optimizer
            # NOTE(review): epoch is passed as 1 on every epoch — confirm whether
            # checkpoints are meant to overwrite each other.
            if cfg.save_model_checkpoint:
                if cfg.checkpoint_type == StateDictType.FULL_STATE_DICT:
                    model_checkpointing.save_model_checkpoint(
                        model, optimizer, rank, cfg, epoch=1
                    )
                elif cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
                    print("Saving Model via Distributed Checkpoint")
                    model_checkpointing.save_distributed_model_checkpoint(
                        model, rank, cfg
                    )
                elif cfg.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                    model_checkpointing.save_model_sharded(model, rank, cfg)
            if cfg.save_optimizer:
                model_checkpointing.save_optimizer_checkpoint(
                    model, optimizer, rank, cfg, epoch=1
                )

    # memory summary
    print(f"** exit loop - rank {local_rank} reporting....")
    if local_rank == 0:
        memmax.stop()  # stop and display info

        if _stats:
            total_loss_curve = _stats["loss"]
            total_acc_curve = _stats["accuracy"]
            training_loss_curve = _stats["training_loss"]
            if cfg.print_training_loss_data:
                print("Training loss data")
                for loss in training_loss_curve:
                    print(f"{loss}")

            print("\nValidation loss data")
            for loss in total_loss_curve:
                print(f"{loss}")

            print("\nAccuracy validation")
            for accuracy in total_acc_curve:
                print(f"{accuracy}")

            best_val_acc = 0
            if total_acc_curve:
                best_val_acc = 100 * float(max(total_acc_curve))
            print(Fore.GREEN + f"\n--> Highest Val Accuracy = {best_val_acc}\n")

            if cfg.total_steps_to_run is None:
                warmup_steps = cfg.warmup_steps
                total_steps_measured, stable_avg = _stable_step_average(
                    _stats["training_iter_time"], warmup_steps
                )
                if stable_avg is not None:
                    print(
                        Fore.GREEN
                        + f"\n--> Step avg speed (in seconds) based on {total_steps_measured} steps: {stable_avg}\nexcluding {warmup_steps} steps as warmup"
                    )

        if cfg.total_steps_to_run is not None:
            total_steps_measured, stable_avg = _stable_step_average(
                tracking_duration, cfg.warmup_steps
            )
            if stable_avg is not None:
                print(
                    Fore.GREEN
                    + f"\n--> Step avg speed based on {total_steps_measured} steps: {stable_avg} seconds"
                )

        if getattr(cfg, "use_deferred_init", False):
            print(
                Fore.LIGHTBLUE_EX
                + f"\n ==>> This run used deferred init! \nIf you are training and seeing no/poor training results, \n pls set this to False in the config file.**\n"
            )

        training_framework = "DDP" if cfg.use_ddp else "FSDP"
        print(Fore.GREEN + f"\nDist Training Framework used = {training_framework}\n")
        if cfg.use_ddp:
            print(
                f"DDP settings: \nddp_bucket_size={cfg.ddp_bucket_size},\nddp_use_gradient_view={cfg.ddp_use_gradient_view}\n"
            )
        print(f"This was run with TensorParallel? = {cfg.use_tp}\n")
        try:
            print(f"Run with Parallel Attention? {cfg.use_parallel_attention}")
            print(f"Run with MQA? {cfg.use_multi_query_attention}\n")
        except AttributeError:
            pass  # older configs lack the attention flags
        print(f"Batch size used = {cfg.batch_size_training}\n")
        if not cfg.use_ddp:
            print(
                f"FSDP Activation Checkpointing? = {cfg.fsdp_activation_checkpointing}"
            )
        if cfg.hf_t5_checkpointing:
            print(f"HF Activation Checkpointing? = {cfg.hf_t5_checkpointing}")

        print(Fore.LIGHTBLUE_EX + f"\n--> Model Size = {num_params} M Params\n")
        if cfg.print_memory_summary:
            print(
                f"\nCUDA Memory Summary After Training:\n {torch.cuda.memory_summary()}"
            )

    cleanup()
def parse_args(argv=None):
    """Parse the command line for the training launcher.

    Args:
        argv: argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace whose ``model`` is one of
        "deepvit", "t5", "regnet", "vitbase", "vitsmart" (default "deepvit").
    """
    parser = argparse.ArgumentParser(description="PyTorch experiments with FSDP")
    parser.add_argument(
        "--model",
        default="deepvit",
        metavar="string",
        choices=["deepvit", "t5", "regnet", "vitbase", "vitsmart"],
        # %(default)s keeps the help text in sync with the real default
        help="choose model to run, available: `deepvit`, `t5`, `regnet`, `vitbase`, `vitsmart` (default: %(default)s)",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    import importlib

    # CLI model name -> config module supplying train_config(), build_model(), etc.
    _CONFIG_MODULES = {
        "deepvit": "config.deepvit_config",
        "t5": "config.t5_config",
        "regnet": "config.regnet_config",
        "vitbase": "config.vit_base_config",
        "vitsmart": "config.vit_smart_config",
    }

    args = parse_args()
    print(f"******* loading model {args.model=}")
    # argparse `choices` already rejects unknown names; unlike an assert, this
    # check also survives `python -O`.
    if args.model not in _CONFIG_MODULES:
        raise SystemExit(f"unsupported model {args.model!r}")
    # Bound at module scope so fsdp_main() sees it as the global `config`.
    config = importlib.import_module(_CONFIG_MODULES[args.model])
    fsdp_main()