-
Notifications
You must be signed in to change notification settings - Fork 0
/
__init__.py
54 lines (45 loc) · 1.46 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from torch import Tensor
from ..base import BaseModel
from .blocks import (
make_output_conv,
make_refinenets,
)
from .decoder import make_decoder
from .encoder import make_encoder
class Vit(BaseModel):
    """Vision-transformer style network assembled from project factory helpers.

    Four sub-modules are built in ``__init__``:

    * ``encoder``     -- from ``make_encoder`` (embedding, image size, input
      channels, readout op, attention hooks, dropout rates, extra kwargs);
    * ``decoder``     -- from ``make_decoder(embed_dim, blocks, decoder_drop_rate)``;
    * ``refinenets``  -- from ``make_refinenets(embed_dim, blocks, use_bn)``;
    * ``output_conv`` -- from ``make_output_conv(embed_dim)``.

    ``forward`` feeds the decoder's outputs, last to first, through the
    refinement stages and then through ``output_conv``.
    """

    def __init__(
        self,
        embed_dim: int = 192,
        img_size: int = 256,
        in_chans: int = 5,
        blocks: int = 4,
        readout_op: str = "ignore",
        use_bn: bool = True,
        enable_attention_hooks: bool = False,
        encoder_attn_drop_rate: float = 0.1,
        encoder_proj_drop_rate: float = 0.1,
        decoder_drop_rate: float = 0.1,
        **kwargs,
    ) -> None:
        """Build the encoder, decoder, refinement stages and output head.

        Args:
            embed_dim: Width passed to every factory (encoder, decoder,
                refinenets, output conv).
            img_size: Forwarded to ``make_encoder`` as ``img_size``.
            in_chans: Forwarded to ``make_encoder`` as ``in_chans``.
            blocks: Forwarded to both ``make_decoder`` and ``make_refinenets``.
            readout_op: Forwarded to ``make_encoder`` as ``readout_op``.
            use_bn: Forwarded to ``make_refinenets``.
            enable_attention_hooks: Forwarded to ``make_encoder``.
            encoder_attn_drop_rate: Forwarded to ``make_encoder`` as
                ``attn_drop_rate``.
            encoder_proj_drop_rate: Forwarded to ``make_encoder`` as
                ``proj_drop_rate``.
            decoder_drop_rate: Forwarded to ``make_decoder``.
            **kwargs: Any further keyword arguments, passed through to
                ``make_encoder`` only.
        """
        super().__init__()

        # Keep this creation order unchanged: assuming BaseModel is a
        # torch nn.Module (not visible here -- TODO confirm), it fixes the
        # state_dict key order and the RNG draws used for weight init.
        self.encoder = make_encoder(
            img_size=img_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            readout_op=readout_op,
            enable_attention_hooks=enable_attention_hooks,
            attn_drop_rate=encoder_attn_drop_rate,
            proj_drop_rate=encoder_proj_drop_rate,
            **kwargs,
        )
        self.decoder = make_decoder(embed_dim, blocks, decoder_drop_rate)
        self.refinenets = make_refinenets(embed_dim, blocks, use_bn)
        self.output_conv = make_output_conv(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x``, decode it, fuse the decoded features, and project.

        The decoded features are visited from last to first. The first
        refinement stage is called with that feature alone; every later
        stage is called with the running fused result and the next feature.

        Args:
            x: Tensor handed directly to the encoder. Its layout is not
                checked here; presumably ``(batch, in_chans, H, W)`` --
                TODO confirm against ``make_encoder``.

        Returns:
            The ``output_conv`` result with every size-1 dimension removed.
        """
        features = self.decoder(self.encoder(x))

        fused = None
        # zip() stops at the shorter input, so any surplus decoded features
        # or refinement stages are silently skipped.
        for feature, refinenet in zip(reversed(features), self.refinenets):
            if fused is None:
                fused = refinenet(feature)
            else:
                fused = refinenet(fused, feature)

        # NOTE(review): squeeze() with no dim drops *every* size-1 axis, so a
        # batch of one also loses its batch axis. If only the channel axis
        # should go, squeeze that dim explicitly -- TODO confirm with callers.
        # If `features` is empty, `fused` is still None at this point.
        return self.output_conv(fused).squeeze()