memory_tracker.py
import os

import psutil
import smdistributed.modelparallel.torch as smp
import torch

# py3nvml is optional; without it, only torch-level GPU memory stats are reported.
try:
    from py3nvml import py3nvml
except ImportError:
    py3nvml = None
# Bits per element for each torch dtype, used to estimate memory held by CPU tensors.
dtype_to_bit = {
    torch.float32: 32,
    torch.float64: 64,
    torch.float16: 16,
    torch.bfloat16: 16,
    torch.uint8: 8,
    torch.int8: 8,
    torch.int16: 16,
    torch.int32: 32,
    torch.int64: 64,
    torch.bool: 1,
}
# Baseline CPU memory of this process, captured at import time so later
# measurements can be reported relative to it.
process = psutil.Process(os.getpid())
base_mem_usage = process.memory_info().data
last_mem_usage = base_mem_usage
def memory_status(msg="", reset_max=True, sync=True):
    """Print per-rank GPU memory usage, optionally resetting the peak counters."""
    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()

    if sync:
        torch.cuda.synchronize()

    # Only the first rank in each reduced-data-parallel group reports.
    if rdp_rank != 0:
        return

    if py3nvml is not None:
        # Device-wide usage as seen by the driver (includes other processes).
        py3nvml.nvmlInit()
        handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank)
        info = py3nvml.nvmlDeviceGetMemoryInfo(handle)
        total_used = info.used / 1024**3
        total_used_str = f"Totally used GPU memory: {total_used}"
    else:
        total_used_str = ""

    alloced = torch.cuda.memory_allocated(device=local_rank)
    max_alloced = torch.cuda.max_memory_allocated(device=local_rank)
    cached = torch.cuda.memory_reserved(device=local_rank)
    max_cached = torch.cuda.max_memory_reserved(device=local_rank)

    # convert to GB for printing
    alloced /= 1024**3
    cached /= 1024**3
    max_alloced /= 1024**3
    max_cached /= 1024**3

    print(
        f"[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}",
        f"device={local_rank} "
        f"alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} "
        f"cache {cached:0.4f} max_cached {max_cached:0.4f} "
        f"{total_used_str}",
    )

    if reset_max:
        # Reset peak counters so the next call reports peaks since this point.
        torch.cuda.reset_max_memory_cached()
        torch.cuda.reset_max_memory_allocated()

    if py3nvml is not None:
        py3nvml.nvmlShutdown()
def memory_status_cpu(msg=""):
    """Print per-rank CPU memory usage, including memory held by live CPU tensors."""
    import gc

    global last_mem_usage
    global base_mem_usage

    rdp_rank = smp.rdp_rank()

    # Collect several times so that objects freed by breaking reference cycles
    # are released before measuring.
    gc.collect()
    gc.collect()
    gc.collect()

    # Sum the memory held by CPU tensors still tracked by the garbage collector.
    objects = gc.get_objects()
    tensors = [obj for obj in objects if isinstance(obj, torch.Tensor) and not obj.is_cuda]
    torch_usage = 0
    for t in tensors:
        torch_usage += t.numel() * dtype_to_bit[t.dtype]
    torch_usage /= 8  # dtype_to_bit is in bits per element; convert to bytes

    # total_usage = psutil.virtual_memory()[3]  # This would get the total usage for all processes
    current_usage = process.memory_info().data
    total_usage = current_usage - base_mem_usage
    usage_change = current_usage - last_mem_usage
    last_mem_usage = current_usage

    # convert to GB for printing
    torch_usage /= 1024**3
    total_usage /= 1024**3
    usage_change /= 1024**3
    base_usage = base_mem_usage / 1024**3

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    local_rank = smp.local_rank()

    # Only the first rank in each reduced-data-parallel group reports.
    if rdp_rank != 0:
        return

    print(
        f"[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}",
        f"device={local_rank} "
        f"torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} "
        f"change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}",
    )
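
# Illustrative usage (a sketch, not part of the original module): in an
# smdistributed.modelparallel training script these helpers would typically be
# called around memory-heavy phases; `model` and `batch` below are hypothetical.
#
#     memory_status(msg="before forward")
#     loss = model(batch)
#     memory_status(msg="after forward")
#     memory_status_cpu(msg="after data loading")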