-
Notifications
You must be signed in to change notification settings - Fork 1
/
PSPModule.py
25 lines (21 loc) · 860 Bytes
/
PSPModule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from torch import nn
from torch.nn import functional as F
class PSPModule(nn.Module):
    """Multi-scale adaptive average pooling, flattened and concatenated.

    Each entry in ``sizes`` becomes one adaptive average-pooling stage whose
    output grid has ``size`` cells along every spatial axis. ``forward`` runs
    every stage on the same input, flattens the pooled spatial axes, and
    concatenates the results along the last axis, producing a tensor of shape
    ``(N, C, sum(size ** dimension for size in sizes))``.

    Args:
        sizes: Output grid size of each pooling stage. The PSPNet paper uses
            bin sizes (1, 2, 3, 6); this module defaults to (1, 3, 6, 8).
        dimension: Number of spatial axes of the input (1, 2 or 3), selecting
            ``AdaptiveAvgPool1d`` / ``2d`` / ``3d`` respectively.

    Raises:
        ValueError: If ``dimension`` is not 1, 2 or 3.
    """

    def __init__(self, sizes=(1, 3, 6, 8), dimension=2):
        super().__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(size, dimension) for size in sizes]
        )

    def _make_stage(self, size, dimension=2):
        """Build one adaptive average-pooling layer with ``size`` cells per axis.

        Raises:
            ValueError: If ``dimension`` is not 1, 2 or 3. Previously an
                unsupported value fell through every branch and failed later
                with an ``UnboundLocalError``.
        """
        if dimension == 1:
            return nn.AdaptiveAvgPool1d(output_size=size)
        if dimension == 2:
            return nn.AdaptiveAvgPool2d(output_size=(size, size))
        if dimension == 3:
            return nn.AdaptiveAvgPool3d(output_size=(size, size, size))
        raise ValueError(f"dimension must be 1, 2 or 3, got {dimension!r}")

    def forward(self, feats):
        """Pool ``feats`` at every scale and concatenate the flattened results.

        Args:
            feats: Batched input of shape ``(N, C, *spatial)``, where
                ``len(spatial)`` is expected to equal the ``dimension`` the
                module was built with.

        Returns:
            Tensor of shape ``(N, C, K)``, where ``K`` is the total number of
            pooled cells across all stages.
        """
        # flatten(start_dim=2) collapses however many spatial axes the pooled
        # output has. This keeps 1-D and 3-D modes working; the old 4-value
        # unpack of feats.size() accepted only 4-D input.
        priors = [stage(feats).flatten(start_dim=2) for stage in self.stages]
        return torch.cat(priors, dim=-1)