forked from ppwwyyxx/RAM-multiprocess-dataloader
-
Notifications
You must be signed in to change notification settings - Fork 0
/
common.py
107 lines (90 loc) · 2.87 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import annotations
from collections import defaultdict
import pickle
import sys
import torch
import json
from typing import Any
from tabulate import tabulate
import os
import time
import psutil
def get_mem_info(pid: int) -> dict[str, int]:
    """Aggregate memory counters for process *pid* across all its memory maps.

    Returns a dict with summed 'rss', 'pss', 'uss' (private clean+dirty),
    'shared' (shared clean+dirty), and 'shared_file' (shared pages of
    file-backed mappings only, identified by a path starting with '/').
    """
    totals = defaultdict(int)
    for mapping in psutil.Process(pid).memory_maps():
        shared = mapping.shared_clean + mapping.shared_dirty
        totals['rss'] += mapping.rss
        totals['pss'] += mapping.pss
        totals['uss'] += mapping.private_clean + mapping.private_dirty
        totals['shared'] += shared
        # Anonymous mappings have no filesystem path; only count file-backed ones here.
        if mapping.path.startswith('/'):
            totals['shared_file'] += shared
    return totals
class MemoryMonitor():
    """Samples and pretty-prints memory stats (rss/pss/uss/shared) for a set of PIDs."""

    def __init__(self, pids: list[int] | None = None):
        """Monitor the given PIDs; defaults to the current process only.

        Fix: the annotation previously read ``list[int] = None``, which is not a
        valid type for the default — ``None`` must be part of the annotation.
        """
        if pids is None:
            pids = [os.getpid()]
        self.pids = pids

    def add_pid(self, pid: int):
        """Start monitoring one additional process; *pid* must not already be tracked."""
        assert pid not in self.pids
        self.pids.append(pid)

    def _refresh(self):
        # Re-sample memory stats for every monitored PID and cache them on self.data.
        self.data = {pid: get_mem_info(pid) for pid in self.pids}
        return self.data

    def table(self) -> str:
        """Return a tabulated snapshot (one row per PID) of current memory usage."""
        self._refresh()
        rows = []
        keys = list(list(self.data.values())[0].keys())
        # Monotonic timestamp wrapped to at most 5 digits to keep the column narrow.
        now = str(int(time.perf_counter() % 1e5))
        for pid, data in self.data.items():
            rows.append((now, str(pid)) + tuple(self.format(data[k]) for k in keys))
        return tabulate(rows, headers=["time", "PID"] + keys)

    def str(self):
        """Return a one-line-per-PID summary of current memory usage."""
        self._refresh()
        keys = list(list(self.data.values())[0].keys())
        res = []
        for pid in self.pids:
            s = f"PID={pid}"
            for k in keys:
                v = self.format(self.data[pid][k])
                s += f", {k}={v}"
            res.append(s)
        return "\n".join(res)

    @staticmethod
    def format(size: int) -> str:
        """Render a byte count human-readably, e.g. 2048 -> '2.0K'."""
        for unit in ('', 'K', 'M', 'G'):
            if size < 1024:
                break
            size /= 1024.0
        return "%.1f%s" % (size, unit)
def create_coco() -> list[Any]:
    """Load the COCO train2017 instance annotations from the working directory.

    Download the file from
    https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json
    """
    with open("instances_train2017.json") as fp:
        coco = json.load(fp)
    return coco["annotations"]
def read_sample(x):
    """Serialize *x* the way a dataloader worker would read it.

    Reading every element of the object increments its refcount, which is
    the copy-on-write behavior these experiments are designed to measure.
    """
    if sys.version_info < (3, 10, 6):
        # Older pickle did not increment refcounts — a bug fixed in
        # https://github.com/python/cpython/pull/92931 — so fall back to msgpack.
        import msgpack
        return msgpack.dumps(x)
    return pickle.dumps(x)
class DatasetFromList(torch.utils.data.Dataset):
    """Trivial map-style dataset that serves the elements of an in-memory list."""

    def __init__(self, lst):
        # Keep a reference to the backing list; no copy is made.
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, idx: int):
        return self.lst[idx]
if __name__ == "__main__":
    # Experiment: compare the memory footprint of the raw COCO annotation list
    # against its numpy-serialized form (see serialize.py).
    from serialize import NumpySerializedList

    monitor = MemoryMonitor()
    print("Initial", monitor.str())

    lst = create_coco()
    print("JSON", monitor.str())

    lst = NumpySerializedList(lst)
    print("Serialized", monitor.str())

    # Drop the serialized list and force a collection before the final reading.
    del lst
    import gc
    gc.collect()
    print("End", monitor.str())