-
Notifications
You must be signed in to change notification settings - Fork 7
/
update_mem.py
68 lines (53 loc) · 2.28 KB
/
update_mem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import time
import torch
import argparse
### from Detectron2 ###
from configs.defaults import _C
### from MiB/PLOP ###
import utils.tasks as tasks
from models.cayley_rot import Cayley_Rot
def main(args):
    """Rotate a stored class memory with learned rotation matrices and save the result.

    Steps:
    1. Build the config from ``args.config_file`` plus ``args.opts`` overrides.
    2. Load the memory tensor from ``./checkpoints/<MODEL.WEIGHTS tagged with _M<mem_size>>``.
    3. Load a ``Cayley_Rot`` state dict from ``./checkpoints/ROT_<SEED>_<ov|dis>_<TASK>_<STEP>_last.pt``.
    4. For each of the first ``num_cls`` classes ``c``, compute
       ``(R_c @ memory[c].T).T``, where ``R_c = model.get_matrix(c)``.
       This applies ``R_c`` to every memory vector (row) of that class.
    5. Save the stacked result (on CPU) to ``cfg.save_name``.

    Args:
        args: namespace from ``get_args()`` providing ``config_file``, ``mem_size``,
            ``gpu_id`` and ``opts``.

    Raises:
        ValueError: if the loaded memory has fewer class entries than ``num_cls``.
    """
    # Fall back to CPU instead of failing later on ``.to(device)`` when CUDA is absent.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu_id}")
    else:
        print("CUDA not available; falling back to CPU.")
        device = torch.device("cpu")

    cfg = _C.clone()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # NOTE(review): assumes config paths look like "configs/<dataset>/<file>",
    # so the second "/"-separated component is the dataset name -- confirm.
    cfg.dataset = args.config_file.split("/")[1]
    cfg.mem_size = args.mem_size
    # NOTE(review): str.replace rewrites *every* ".pt" occurrence in the path. It is kept
    # as-is because other scripts may rely on the same naming; confirm before switching
    # to os.path.splitext.
    cfg.mem_name = os.path.join(
        "./checkpoints", cfg.MODEL.WEIGHTS.replace(".pt", f"_M{args.mem_size}.pt")
    )
    setting = "ov" if cfg.OVERLAP else "dis"
    cfg.matrices = f"ROT_{cfg.SEED}_{setting}_{cfg.TASK}_{cfg.STEP}_last.pt"

    num_classes = tasks.get_per_task_classes(cfg.dataset, cfg.TASK, cfg.STEP)
    # Sum of the first STEP entries (presumably per-task class counts).
    num_cls = sum(num_classes[:cfg.STEP])
    cfg.save_name = os.path.join(
        "./checkpoints", cfg.matrices.replace(".pt", f"_C{num_cls}M{cfg.mem_size}.pt")
    )
    cfg.freeze()

    print(f"Loading Memory from {cfg.mem_name} ...")
    # NOTE: torch.load unpickles its input; only point it at trusted checkpoint files.
    memory = torch.load(cfg.mem_name, map_location='cpu').to(device)
    print(f"Memory Size: {memory.shape} (num_cls, num_mem, num_dim)\n")
    if memory.shape[0] < num_cls:
        raise ValueError(
            f"Memory at {cfg.mem_name} holds {memory.shape[0]} class entries, "
            f"but {num_cls} are required."
        )

    print(f"Loading Rotation Matrices from {cfg.matrices} ...\n")
    model = Cayley_Rot(num_cls).to(device)
    chkpt = torch.load(os.path.join("./checkpoints", cfg.matrices), map_location='cpu')
    model.load_state_dict(chkpt)

    with torch.no_grad():
        # (R_c @ M_c^T)^T: each row of memory[c] is multiplied by R_c.
        rotated = [
            torch.matmul(model.get_matrix(c), memory[c].t()).t()  # (num_mem, num_dim)
            for c in range(num_cls)
        ]
        new = torch.stack(rotated, dim=0).detach().cpu()  # (num_cls, num_mem, num_dim)

    print(f"Saving New Memory @ {cfg.save_name} ...\n")
    torch.save(new, cfg.save_name)
def get_args(argv=None):
    """Parse command-line arguments for the memory-update script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the original behaviour.

    Returns:
        argparse.Namespace with ``config_file``, ``mem_size``, ``gpu_id`` and ``opts``.
    """
    parser = argparse.ArgumentParser()
    # Both are required. main() cannot load a config without a path, and a missing
    # memory size would silently produce a "..._MNone.pt" checkpoint path.
    parser.add_argument("--config-file", required=True, help="path to the config file")
    parser.add_argument("--mem-size", type=int, required=True, help="size of memory")
    parser.add_argument("--gpu-id", type=int, default=0, help="GPU Index (0 or 1)")
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Time the whole run: argument parsing plus the memory update.
    t0 = time.time()
    main(get_args())
    elapsed = time.time() - t0
    print('TOTAL TIME (sec): ', elapsed)