Which module in the paper does this piece of code correspond to, and why does it involve relative position encoding?
import torch
import torch.nn as nn
from einops import rearrange


class ConvRelPosEnc(nn.Module):
    """Convolutional relative position encoding."""

    def __init__(self, Ch, h, window):
        """Initialization.

        Ch: Channels per head.
        h: Number of heads.
        window: Window size(s) in convolutional relative positional encoding.
            It can have two forms:
            1. An integer window size, which assigns the same window size to
               all attention heads in ConvRelPosEnc.
            2. A dict mapping window size to the number of attention heads in
               that split (e.g. {window size 1: #heads in split 1,
               window size 2: #heads in split 2}), which applies a different
               window size to each attention head split.
        """
        super().__init__()
        if isinstance(window, int):
            # Use the same window size for all attention heads.
            window = {window: h}
            self.window = window
        elif isinstance(window, dict):
            self.window = window
        else:
            raise ValueError("window must be an int or a dict")

        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1  # Use dilation=1 by default.
            # Padding that keeps the spatial resolution unchanged.
            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
            cur_conv = nn.Conv2d(
                cur_head_split * Ch,
                cur_head_split * Ch,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split * Ch,  # Depthwise convolution.
            )
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        self.channel_splits = [x * Ch for x in self.head_splits]

    def forward(self, q, v, size):
        """Forward function."""
        B, h, N, Ch = q.shape
        H, W = size

        # No CLS token is used, so q and v contain image tokens only.
        q_img = q
        v_img = v

        # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].
        v_img = rearrange(v_img, "B h (H W) Ch -> B (h Ch) H W", H=H, W=W)
        # Split channels into the per-window-size head splits.
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)
        conv_v_img_list = [
            conv(x) for conv, x in zip(self.conv_list, v_img_list)
        ]
        conv_v_img = torch.cat(conv_v_img_list, dim=1)
        # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
        conv_v_img = rearrange(conv_v_img, "B (h Ch) H W -> B h (H W) Ch", h=h)

        # Element-wise product with the query gives the relative position
        # encoding term EV_hat.
        EV_hat_img = q_img * conv_v_img
        EV_hat = EV_hat_img
        return EV_hat
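For reference, here is a minimal usage sketch of the class above. The batch size, head count, channels per head, spatial size, and the window-to-head-split mapping are arbitrary values chosen for illustration, not taken from the paper's configuration; the only constraint is that the head splits in the window dict sum to h.

# Minimal usage sketch (illustrative shapes, not the paper's configuration).
import torch

Ch, h = 8, 3                      # channels per head, number of heads
H, W = 14, 14                     # spatial size of the token grid
# 2 heads use a 3x3 window, 1 head uses a 5x5 window (splits sum to h=3).
crpe = ConvRelPosEnc(Ch=Ch, h=h, window={3: 2, 5: 1})

q = torch.randn(2, h, H * W, Ch)  # queries: [B, h, N, Ch]
v = torch.randn(2, h, H * W, Ch)  # values:  [B, h, N, Ch]
out = crpe(q, v, size=(H, W))
print(out.shape)                  # torch.Size([2, 3, 196, 8])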