-
Notifications
You must be signed in to change notification settings - Fork 428
/
basic_var.py
174 lines (141 loc) · 8.94 KB
/
basic_var.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.helpers import DropPath, drop_path
# this file only provides the 3 blocks used in VAR transformer
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']
# Optional accelerated kernels. Every name starts out as None and is only
# rebound when the package providing it can be imported; the classes below
# test these names against None to decide which code path to take.
dropout_add_layer_norm = None
fused_mlp_func = None
memory_efficient_attention = None
flash_attn_func = None

# Fused operators from flash-attn. Both imports share one try block, so if the
# second import fails the first name stays bound and fused_mlp_func stays None.
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError:
    pass

# Faster attention implementations (each one is optional on its own).
try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    pass

try:
    from flash_attn import flash_attn_func  # qkv: BLHc, ret: BLHcq
except ImportError:
    pass
def naive_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
    """Scaled dot-product attention written with plain tensor ops.

    Used as ``slow_attn`` when torch's fused ``scaled_dot_product_attention``
    is missing or does not accept the ``scale`` keyword.

    Args:
        query, key, value: tensors laid out as BHLc (the last two dims are
            sequence length and per-head channels).
        scale: factor applied to ``query`` before the q @ k^T product.
        attn_mask: optional mask broadcastable to the BHLL score tensor.
            A floating mask is added to the scores; a boolean mask keeps the
            positions that are True and sets the others to -inf, following
            the convention of torch's ``scaled_dot_product_attention``.
        dropout_p: dropout probability on the attention weights. Dropout is
            applied whenever this is > 0 (``F.dropout`` defaults to
            training=True), so callers pass 0.0 outside of training.

    Returns:
        The attention output, BHLc.
    """
    attn = query.mul(scale) @ key.transpose(-2, -1)  # BHLc @ BHcL => BHLL
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn.masked_fill_(~attn_mask, float('-inf'))
        else:
            attn.add_(attn_mask)
    attn = attn.softmax(dim=-1)
    if dropout_p > 0:
        # Not in-place: softmax keeps its output for the backward pass, and
        # overwriting it makes autograd raise during backward.
        attn = F.dropout(attn, p=dropout_p)
    return attn @ value


def _sdpa_accepts_scale(sdpa) -> bool:
    """Return True if ``sdpa`` can be called with the ``scale`` keyword.

    torch 2.0 ships ``scaled_dot_product_attention`` without ``scale`` (it
    appears in the 2.1 signature); calling it with that keyword raises
    TypeError. A single 1-element probe call at import time detects this.
    """
    probe = torch.zeros(1, 1, 1, 1)
    try:
        sdpa(probe, probe, probe, scale=1.0)
    except TypeError:
        return False
    return True


# slow_attn(query, key, value, scale=..., attn_mask=..., dropout_p=...); q, k, v: BHLc
try:
    from torch.nn.functional import scaled_dot_product_attention as slow_attn
except ImportError:
    slow_attn = naive_attn
else:
    if not _sdpa_accepts_scale(slow_attn):
        slow_attn = naive_attn
class FFN(nn.Module):
    """Feed-forward block: Linear -> GELU (tanh approximation) -> Linear -> dropout.

    When ``fused_if_available`` is true the module-level ``fused_mlp_func`` is
    stored on the instance; if it is not None, ``forward`` routes through it
    using the weights and biases of ``fc1``/``fc2`` instead of calling the
    layers one by one.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: size of the output's last dimension; falls back to
            ``in_features`` when falsy.
        drop: dropout probability applied to the output (in-place dropout when
            > 0, identity otherwise).
        fused_if_available: whether to pick up ``fused_mlp_func``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
        super().__init__()
        if fused_if_available:
            self.fused_mlp_func = fused_mlp_func
        else:
            self.fused_mlp_func = None

        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features)
        if drop > 0:
            self.drop = nn.Dropout(drop, inplace=True)
        else:
            self.drop = nn.Identity()

    def forward(self, x):
        """Apply the MLP to ``x`` and return the (optionally dropped-out) result."""
        if self.fused_mlp_func is None:
            hidden = self.act(self.fc1(x))
            return self.drop(self.fc2(hidden))

        fused_out = self.fused_mlp_func(
            x=x,
            weight1=self.fc1.weight,
            weight2=self.fc2.weight,
            bias1=self.fc1.bias,
            bias2=self.fc2.bias,
            activation='gelu_approx',
            save_pre_act=self.training,
            return_residual=False,
            checkpoint_lvl=0,
            heuristic=0,
            process_group=None,
        )
        return self.drop(fused_out)

    def extra_repr(self) -> str:
        """Report whether the fused MLP kernel is in use."""
        fused = self.fused_mlp_func is not None
        return f'fused_mlp_func={fused}'
class SelfAttention(nn.Module):
    """Multi-head self-attention with q/v biases, optional L2-normalised q/k and a KV cache.

    The q/k/v projection is one bias-free ``nn.Linear``; the bias actually used
    is the concatenation of a learnable q bias, a zero buffer for k and a
    learnable v bias. Three backends are supported, tried in this order:
    ``flash_attn_func``, xformers ``memory_efficient_attention`` and
    ``slow_attn``.

    Args:
        block_idx: index of the enclosing block (stored only).
        embed_dim: model width C; must be divisible by ``num_heads``.
        num_heads: number of attention heads H.
        attn_drop: dropout probability on attention weights (training only).
        proj_drop: dropout probability after the output projection.
        attn_l2_norm: if true, q and k are L2-normalised along the channel
            dim and q is multiplied by a learnable per-head factor (stored in
            log space, clamped to at most 100 after exponentiation); the
            softmax scale is then 1.
        flash_if_available: allow the flash-attn / xformers backends when
            their functions were imported.
    """

    def __init__(
        self, block_idx, embed_dim=768, num_heads=12,
        attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx = block_idx
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # =64

        self.attn_l2_norm = attn_l2_norm
        if self.attn_l2_norm:
            self.scale = 1
            log_init = torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log()
            self.scale_mul_1H11 = nn.Parameter(log_init, requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 0.25 / math.sqrt(self.head_dim)

        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(embed_dim))
        self.v_bias = nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        if proj_drop > 0:
            self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        else:
            self.proj_drop = nn.Identity()
        self.attn_drop: float = attn_drop
        self.using_flash = flash_if_available and flash_attn_func is not None
        self.using_xform = flash_if_available and memory_efficient_attention is not None

        # only used during inference
        self.caching = False
        self.cached_k = None
        self.cached_v = None

    def kv_caching(self, enable: bool):
        """Turn the KV cache on or off; either way the cached keys/values are cleared."""
        self.caching = enable
        self.cached_k = None
        self.cached_v = None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, attn_bias):
        """Attend over ``x`` (BLC) with an optional additive ``attn_bias``; returns BLC."""
        B, L, C = x.shape
        qkv_bias = torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))
        qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=qkv_bias)
        qkv = qkv.view(B, L, 3, self.num_heads, self.head_dim)  # BL3Hc
        main_type = qkv.dtype

        # flash-attn is only taken without a bias and for non-float32 activations
        using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
        seq_first = using_flash or self.using_xform
        if seq_first:
            q, k, v = qkv.unbind(dim=2)  # each BLHc
            dim_cat = 1
        else:
            q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)  # each BHLc
            dim_cat = 2

        if self.attn_l2_norm:
            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
            if seq_first:
                scale_mul = scale_mul.transpose(1, 2)  # 1H11 to 11H1
            q = F.normalize(q, dim=-1).mul(scale_mul)
            k = F.normalize(k, dim=-1)

        if self.caching:
            if self.cached_k is None:
                self.cached_k = k
                self.cached_v = v
            else:
                # grow the cache along the sequence dim and attend over all of it
                k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat)
                v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)

        dropout_p = self.attn_drop if self.training else 0.0
        if using_flash:
            oup = flash_attn_func(
                q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type),
                dropout_p=dropout_p, softmax_scale=self.scale,
            ).view(B, L, C)
        elif self.using_xform:
            if attn_bias is None:
                xform_bias = None
            else:
                xform_bias = attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1)
            oup = memory_efficient_attention(
                q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type),
                attn_bias=xform_bias, p=dropout_p, scale=self.scale,
            ).view(B, L, C)
        else:
            oup = slow_attn(
                query=q, key=k, value=v, scale=self.scale,
                attn_mask=attn_bias, dropout_p=dropout_p,
            ).transpose(1, 2).reshape(B, L, C)  # BHLc => BLHc => BLC

        return self.proj_drop(self.proj(oup))

    def extra_repr(self) -> str:
        """Summarise which attention backend flags and normalisation are enabled."""
        return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
class AdaLNSelfAttn(nn.Module):
    """Transformer block (self-attention + FFN) modulated by adaptive layer norm.

    Six modulation tensors (gamma1, gamma2, scale1, scale2, shift1, shift2) are
    derived from the condition. With ``shared_aln`` the condition already holds
    the six slices and a learnable offset ``ada_gss`` is added to it; otherwise
    they are produced from ``cond_BD`` by SiLU followed by a Linear layer.

    Args:
        block_idx: index of this block (stored and passed to the attention).
        last_drop_p: stored on the instance; not used inside this class.
        embed_dim: model width C.
        cond_dim: condition width D.
        shared_aln: choose between the two modulation sources described above.
        norm_layer: factory called as ``norm_layer(embed_dim, elementwise_affine=False)``.
        num_heads: number of attention heads.
        mlp_ratio: FFN hidden width as a multiple of ``embed_dim`` (rounded).
        drop: dropout for the attention projection and the FFN output.
        attn_drop: dropout on attention weights.
        drop_path: DropPath probability on both residual branches (identity when 0).
        attn_l2_norm: forwarded to ``SelfAttention``.
        flash_if_available: forwarded to ``SelfAttention``.
        fused_if_available: forwarded to ``FFN``.
    """

    def __init__(
        self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
        num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
        flash_if_available=False, fused_if_available=True,
    ):
        super().__init__()
        self.block_idx = block_idx
        self.last_drop_p = last_drop_p
        self.C = embed_dim
        self.D = cond_dim

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.attn = SelfAttention(
            block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads,
            attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm,
            flash_if_available=flash_if_available,
        )
        self.ffn = FFN(
            in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio),
            drop=drop, fused_if_available=fused_if_available,
        )
        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)

        self.shared_aln = shared_aln
        if self.shared_aln:
            gss_init = torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5
            self.ada_gss = nn.Parameter(gss_init)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)

        self.fused_add_norm_fn = None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, cond_BD, attn_bias):  # C: embed_dim, D: cond_dim
        """Run the attention and FFN residual branches on ``x`` under the given condition."""
        if self.shared_aln:
            modulation = self.ada_gss + cond_BD  # 116C + B16C
        else:
            modulation = self.ada_lin(cond_BD).view(-1, 1, 6, self.C)
        gamma1, gamma2, scale1, scale2, shift1, shift2 = modulation.unbind(2)  # 6 x B1C

        attn_in = self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1)
        attn_out = self.attn(attn_in, attn_bias=attn_bias).mul_(gamma1)
        x = x + self.drop_path(attn_out)

        ffn_in = self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2)
        # this mul(gamma2) cannot be in-placed when FusedMLP is used
        ffn_out = self.ffn(ffn_in).mul(gamma2)
        x = x + self.drop_path(ffn_out)
        return x

    def extra_repr(self) -> str:
        """Report which modulation source this block uses."""
        return f'shared_aln={self.shared_aln}'
class AdaLNBeforeHead(nn.Module):
    """Final adaptive layer norm applied before the output head.

    Normalises the input without learnable affine parameters, then applies a
    per-sample scale and shift computed from the condition by SiLU + Linear.

    Args:
        C: embed_dim, the size of the last dimension of the input.
        D: cond_dim, the size of the condition vector.
        norm_layer: factory called as ``norm_layer(C, elementwise_affine=False)``.
    """

    def __init__(self, C, D, norm_layer):  # C: embed_dim, D: cond_dim
        super().__init__()
        self.C = C
        self.D = D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))

    def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
        """Return ``norm(x_BLC) * (1 + scale) + shift`` with scale/shift derived from ``cond_BD``."""
        scale_shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C)  # B, 1, 2, C
        scale, shift = scale_shift.unbind(2)  # each B1C
        normed = self.ln_wo_grad(x_BLC)
        return normed.mul(scale.add(1)).add_(shift)