-
Notifications
You must be signed in to change notification settings - Fork 5
/
utils.py
121 lines (98 loc) · 3.48 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import sys
import logging
import csv
from pathlib import Path
import rasterio
rasterio.log.setLevel(logging.ERROR)
import torch
import torch.nn as nn
def make_tuple(x):
    """Expand a scalar or a one-item list into a pair.

    An ``int`` becomes ``(x, x)`` and a list holding exactly one item
    becomes ``(item, item)``. Every other value is returned unchanged.
    """
    is_scalar = isinstance(x, int)
    is_singleton_list = isinstance(x, list) and len(x) == 1
    if not (is_scalar or is_singleton_list):
        return x
    item = x if is_scalar else x[0]
    return (item, item)
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes:
        val: The most recent value passed to :meth:`update`.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of ``n`` over all updates.
        avg: ``sum / count``, or ``0.0`` while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: an update with n=0 (e.g. an empty batch) on a
        # fresh meter would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0.0
def get_logger(logpath=None):
    """Return the root logger at INFO level, attaching handlers on first use.

    Handlers are attached only when the root logger has none yet: a file
    handler for ``logpath`` (when it is not ``None``) followed by a handler
    that writes to ``sys.stdout``. Both emit the bare message text.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.handlers:
        # Already configured; avoid stacking duplicate handlers.
        return root

    pending = []
    if logpath is not None:
        pending.append(logging.FileHandler(logpath))
    pending.append(logging.StreamHandler(sys.stdout))
    for handler in pending:
        handler.setFormatter(logging.Formatter('%(message)s'))
        root.addHandler(handler)
    return root
def save_checkpoint(model, optimizer, path):
    """Serialize model (and optionally optimizer) state to ``path``.

    Args:
        model: Module whose ``state_dict()`` is stored under the
            ``'state_dict'`` key. An ``nn.DataParallel`` wrapper is
            unwrapped first and the inner module's state is saved.
        optimizer: When truthy, its ``state_dict()`` is stored under the
            ``'optim_dict'`` key.
        path: Destination file, given as a ``str`` or ``pathlib.Path``.
            An existing file at this location is removed before writing.
    """
    # Normalise up front so plain string paths work as well as Path objects.
    path = Path(path)
    if path.exists():
        path.unlink()
    if isinstance(model, nn.DataParallel):
        model = model.module
    state = {'state_dict': model.state_dict()}
    if optimizer:
        state['optim_dict'] = optimizer.state_dict()
    torch.save(state, str(path.resolve()))
def load_checkpoint(checkpoint, model, optimizer=None, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        checkpoint: Checkpoint file, given as a ``str`` or ``pathlib.Path``.
        model: Module that receives the ``'state_dict'`` entry. An
            ``nn.DataParallel`` wrapper is unwrapped first.
        optimizer: When truthy, receives the ``'optim_dict'`` entry.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The loaded checkpoint object.

    Raises:
        FileNotFoundError: If ``checkpoint`` does not exist.
        KeyError: If an expected entry is missing from the checkpoint.
    """
    # Normalise up front so plain string paths work as well as Path objects.
    checkpoint = Path(checkpoint)
    if not checkpoint.exists():
        raise FileNotFoundError(f"File doesn't exist {checkpoint}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state = torch.load(checkpoint, map_location=map_location)
    if isinstance(model, nn.DataParallel):
        model = model.module
    model.load_state_dict(state['state_dict'])
    if optimizer:
        optimizer.load_state_dict(state['optim_dict'])
    return state
def log_csv(filepath, values, header=None, multirows=False):
    """Append one or more rows to a CSV file, creating it if necessary.

    Args:
        filepath: Target CSV file, given as a ``str`` or ``pathlib.Path``.
        values: A single row (an iterable of fields), or an iterable of
            rows when ``multirows`` is true.
        header: Optional header row. It is written only when the file is
            missing or has zero size, so repeated calls do not duplicate it.
        multirows: Treat ``values`` as several rows instead of one.
    """
    filepath = Path(filepath)
    needs_header = not filepath.exists() or filepath.stat().st_size == 0
    # newline='' is required by the csv module; without it the writer's
    # '\r\n' terminators are translated again on some platforms.
    with open(filepath, 'a', newline='') as file:
        writer = csv.writer(file)
        if needs_header and header:
            writer.writerow(header)
        if multirows:
            writer.writerows(values)
        else:
            writer.writerow(values)
def load_pretrained(model, pretrained, requires_grad=False, map_location=None):
    """Copy matching weights from a pretrained checkpoint into ``model``.

    Only checkpoint entries whose key also exists in the model's state dict
    are used; all other model entries keep their current values.

    Args:
        model: Target module. An ``nn.DataParallel`` wrapper is unwrapped
            first.
        pretrained: Checkpoint file containing a ``'state_dict'`` entry.
        requires_grad: When false (the default), every parameter of the
            model is frozen after loading, not only the ones that were
            overwritten. When true, ``requires_grad`` flags are left as-is.
        map_location: Forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
            load GPU-saved weights on a host without that device.
    """
    if isinstance(model, nn.DataParallel):
        model = model.module
    model_dict = model.state_dict()
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(pretrained, map_location=map_location)
    # Keep only the entries the model actually has, then overlay them on the
    # model's own state so the result is complete for load_state_dict.
    matching = {
        key: value
        for key, value in checkpoint['state_dict'].items()
        if key in model_dict
    }
    model_dict.update(matching)
    model.load_state_dict(model_dict)
    if not requires_grad:
        for param in model.parameters():
            param.requires_grad = False
def save_array_as_tif(matrix, path, profile=None, prototype=None):
    """Write a 2-D or 3-D array to a raster file via rasterio.

    Args:
        matrix: Array with ``ndim`` 2 (written as band 1) or ``ndim`` 3
            (slice ``i`` along the first axis is written as band ``i + 1``).
        path: Destination handed to ``rasterio.open`` in write mode.
        profile: Keyword arguments for ``rasterio.open``. Replaced by the
            prototype's profile when ``prototype`` is given.
        prototype: Optional existing raster whose ``profile`` is reused.

    Raises:
        ValueError: If ``matrix`` is neither 2-D nor 3-D, or if neither
            ``profile`` nor ``prototype`` is supplied.
    """
    # Validate explicitly: an assert would vanish under ``python -O``.
    if matrix.ndim not in (2, 3):
        raise ValueError(
            f'matrix must be 2-D or 3-D, got {matrix.ndim} dimension(s)')
    if prototype:
        with rasterio.open(str(prototype)) as src:
            profile = src.profile
    if profile is None:
        raise ValueError('either profile or prototype must be provided')
    with rasterio.open(path, mode='w', **profile) as dst:
        if matrix.ndim == 3:
            # rasterio band indexes are 1-based.
            for band, layer in enumerate(matrix, start=1):
                dst.write(layer, band)
        else:
            dst.write(matrix, 1)