sdp_utils.py
import torch
import smdistributed.modelparallel.torch as smp  # SageMaker model parallelism library (provides smp.rank())


def build_param_id_to_offset(param_groups):
    """Map each parameter id to its (offset, size) within its group's flat local shard."""
    param_id_to_offset = []
    for group in param_groups:
        offset = 0
        group_offsets = {}
        for p in group["params"]:
            size = p.ds_tensor.ds_numel  # number of elements of this parameter held on this rank
            group_offsets[id(p)] = (offset, size)
            offset += size
        param_id_to_offset.append(group_offsets)
    return param_id_to_offset


def build_param_id_to_buffer(optimizer, param_id_to_offset):
    """Map each parameter id to its slice of the optimizer's flat fp32 partition."""
    param_id_to_buffer = {}
    for i, _group in enumerate(optimizer.param_groups):
        for _id, (offset, sz) in param_id_to_offset[i].items():
            buf = optimizer.fp32_partitioned_groups_flat[i].narrow(0, offset, sz)
            param_id_to_buffer[_id] = buf
    return param_id_to_buffer


def log_param_norms(model, optimizer, param_id_to_buffer):
    """Print, from rank 0, the global L2 norm of each parameter's fp32 optimizer copy and its model-weight shard."""
    weight_norms = {}
    other_norms = {}
    for name, param in model.named_parameters():
        buf = param_id_to_buffer[id(param)]
        # Squared local norms; summing them across the shard group gives the global squared norm.
        param_norm = torch.linalg.norm(buf) ** 2
        other_norm = torch.linalg.norm(param.ds_tensor.data) ** 2
        torch.distributed.all_reduce(param_norm, group=optimizer.ds_param_shard_group)
        torch.distributed.all_reduce(other_norm, group=optimizer.ds_param_shard_group)
        weight_norms[name] = torch.sqrt(param_norm).item()
        other_norms[name] = torch.sqrt(other_norm).item()
        if smp.rank() == 0:
            print(f"{name}: {weight_norms[name]} {other_norms[name]}")