-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathaggregate_utils.py
94 lines (85 loc) · 2.66 KB
/
aggregate_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import time
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.nn import Parameter
import rasterio as rio
def aggregate(data, scale):
    """Downsample a 2-D array by averaging the positive cells of each block.

    Reference (pure NumPy, loop-based) implementation. The input is split
    into non-overlapping ``step x step`` blocks with ``step = int(1 / scale)``.
    Each output cell is the sum of *all* values in its block divided by the
    number of values in that block that are strictly greater than zero.

    Args:
        data: 2-D NumPy array of shape ``(rows, cols)``.
        scale: Downsampling factor, e.g. ``0.25`` aggregates 4x4 blocks.

    Returns:
        Float array of shape ``(rows // step, cols // step)``. Trailing rows
        and columns that do not fill a complete block are dropped, which
        matches a strided convolution without padding.

    Raises:
        ValueError: If ``scale`` is greater than 1 (block size would be 0).
    """
    step = int(1 / scale)
    if step < 1:
        raise ValueError("scale must be in (0, 1], got %r" % (scale,))

    r, c = data.shape
    # Size the output from the number of *complete* blocks along each axis.
    # (The column count must come from ``c``, not ``r``, so that non-square
    # inputs work, and partial edge blocks must not index past the output.)
    nr, nc = r // step, c // step
    res = np.zeros((nr, nc))
    data = data.astype('float')

    for m in range(nr):
        i = m * step
        for n in range(nc):
            j = n * step
            patch = data[i:i + step, j:j + step]
            patch_area = (patch > 0).sum()
            # The small epsilon keeps blocks without positive cells finite.
            res[m, n] = patch.sum() / (patch_area + 1e-6)
    return res
def aggregate_torch(data, scale):
    """Block-aggregate ``data`` with a strided all-ones convolution.

    Each output cell is the sum of a ``step x step`` block
    (``step = int(1 / scale)``) divided by the number of cells in that block
    whose value is ``>= 0``. Blocks do not overlap and no padding is used, so
    trailing rows/columns that do not fill a block are dropped.

    Args:
        data: Tensor laid out as ``(N, 1, H, W)`` (or unbatched ``(1, H, W)``).
            A plain 2-D ``(H, W)`` tensor or array-like is also accepted and
            is treated as a single one-channel image. Non-floating inputs are
            converted to float32.
        scale: Downsampling factor, e.g. ``0.25`` aggregates 4x4 blocks.

    Returns:
        The aggregated tensor with all singleton dimensions squeezed out,
        on the same device and with the same floating dtype as the input.

    Raises:
        ValueError: If ``scale`` is greater than 1 (block size would be 0).
    """
    step = int(1 / scale)
    if step < 1:
        raise ValueError("scale must be in (0, 1], got %r" % (scale,))

    data = torch.as_tensor(data)
    if not data.is_floating_point():
        data = data.float()
    if data.dim() == 2:
        data = data[None, None]

    # A fixed all-ones kernel turns the convolution into a block sum. Building
    # it from the input's dtype/device avoids the float32/CPU mismatch that a
    # freshly constructed Conv2d module would have with float64 or CUDA input.
    kernel = torch.ones((1, 1, step, step), dtype=data.dtype, device=data.device)

    s1 = torch.nn.functional.conv2d(data, kernel, stride=step)
    data_area = (data >= 0).to(data.dtype)  # mask threshold is 0.0 (was 1.0)
    s2 = torch.nn.functional.conv2d(data_area, kernel, stride=step)

    # Epsilon guards against division by zero for blocks with no valid cell.
    res = s1 / (s2 + 1e-10)
    return res.squeeze()
def aggregate_torch_gpu(data, scale, device='cuda'):
    """Block-aggregate ``data`` on ``device`` with a strided all-ones convolution.

    Each output cell is the sum of a ``step x step`` block
    (``step = int(1 / scale)``) divided by the number of cells in that block
    whose value is strictly greater than ``1.0``. Blocks do not overlap and no
    padding is used.

    Args:
        data: Tensor laid out as ``(N, 1, H, W)`` (or unbatched ``(1, H, W)``).
            A plain 2-D ``(H, W)`` tensor or array-like is also accepted and
            is reshaped to ``(1, 1, H, W)``. Non-floating inputs are converted
            to float32. The data is moved to ``device`` if it is not already
            there.
        scale: Downsampling factor, e.g. ``0.25`` aggregates 4x4 blocks.
        device: Device on which the convolution is executed.

    Returns:
        The aggregated tensor on ``device``; unlike ``aggregate_torch`` the
        result is *not* squeezed.

    Raises:
        ValueError: If ``scale`` is greater than 1 (block size would be 0).
    """
    step = int(1 / scale)
    if step < 1:
        raise ValueError("scale must be in (0, 1], got %r" % (scale,))

    # Input and kernel must live on the same device; moving is a no-op when
    # the caller already placed the tensor there.
    data = torch.as_tensor(data).to(device)
    if not data.is_floating_point():
        data = data.float()
    if data.dim() == 2:
        data = data[None, None]

    # All-ones kernel => the convolution is a plain block sum.
    kernel = torch.ones((1, 1, step, step), dtype=data.dtype, device=data.device)

    s1 = torch.nn.functional.conv2d(data, kernel, stride=step)
    data_area = (data > 1.0).to(data.dtype)
    s2 = torch.nn.functional.conv2d(data_area, kernel, stride=step)

    # Epsilon guards against division by zero for blocks with no valid cell.
    res = s1 / (s2 + 1e-10)
    return res
if __name__ == "__main__":
    # Ad-hoc benchmark: time the NumPy loop against the torch convolution on
    # band 1 of a single raster tile.
    iname = 'Beijing_47.tif'
    datapath = os.path.join(r'D:\data\Landcover\samples62\bh', iname)
    respath = os.path.join('tmp', iname)  # only used by the disabled export below
    scale = 0.25

    with rio.open(datapath, 'r') as src:
        data = src.read(1)

        t0 = time.perf_counter()
        res1 = aggregate(data, scale)
        t1 = time.perf_counter()
        print('%.6f' % (t1 - t0))

        # The torch path convolves a (1, 1, H, W) float tensor, so convert the
        # 2-D raster band before starting the timer.
        h, w = data.shape
        data_t = torch.from_numpy(data.astype(np.float32)).reshape(1, 1, h, w)

        t0 = time.perf_counter()
        res2 = aggregate_torch(data_t, scale)
        t1 = time.perf_counter()
        print('%.6f' % (t1 - t0))

        # profile = src.profile
        # profile.update(dtype=np.float32, count=2)
        # with rio.open(respath, 'w', **profile) as dst:
        #     dst.write(res1, 1)
        #     dst.write(res2, 2)

    # res1 = aggregate(data, scale)
    # res2 = aggregate_torch(data, scale)
    #
    # diff = (res1-res2)
    # print(diff.min(), diff.max())

    # plt.subplot(1,3,1)
    # plt.imshow(data)
    # plt.subplot(1, 3, 2)
    # plt.imshow(res1)
    # plt.subplot(1, 3, 3)
    # plt.imshow(res2)
    # plt.show()