Lint fix
pomonam committed Dec 9, 2023
1 parent ecf8220 commit 2908077
Showing 1 changed file with 46 additions and 45 deletions.
@@ -50,9 +50,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
     x = nn.gelu(x)
 
     if self.use_glu:
-      y = nn.Dense(
-          self.mlp_dim,
-          **inits)(x)
+      y = nn.Dense(self.mlp_dim, **inits)(x)
       x = x * y
 
     x = nn.Dropout(rate=self.dropout_rate)(x, train)
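
For reference, the collapsed line is the gating step of the GLU-style MLP: a second Dense projection of the block input multiplies the GELU branch elementwise. A minimal standalone sketch of that pattern, assuming plain Dense layers and a final down-projection (the class name GatedMlp is illustrative; this file's module additionally threads **inits and dropout):

import flax.linen as nn

class GatedMlp(nn.Module):
  """Minimal sketch of a GELU-gated (GLU-style) MLP; names are illustrative."""
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    d = x.shape[-1]
    h = nn.Dense(self.mlp_dim)(x)     # up-projection
    h = nn.gelu(h)                    # nonlinearity on one branch
    gate = nn.Dense(self.mlp_dim)(x)  # second projection acts as the gate
    h = h * gate                      # elementwise gating, as in the hunk above
    return nn.Dense(d)(h)             # project back to the input width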
@@ -71,41 +69,45 @@ class Encoder1DBlock(nn.Module):
   @nn.compact
   def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
     if not self.use_post_layer_norm:
-      y = nn.LayerNorm(name='LayerNorm_0')(x)
-      y = nn.SelfAttention(
-          num_heads=self.num_heads,
-          kernel_init=nn.initializers.xavier_uniform(),
-          deterministic=train,
-          name='MultiHeadDotProductAttention_1')(
-              y)
-      y = nn.Dropout(rate=self.dropout_rate)(y, train)
-      x = x + y
-
-      y = nn.LayerNorm(name='LayerNorm_2')(x)
-      y = MlpBlock(
-          mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate,
-          name='MlpBlock_3')(y, train)
-      y = nn.Dropout(rate=self.dropout_rate)(y, train)
-      x = x + y
+      y = nn.LayerNorm(name='LayerNorm_0')(x)
+      y = nn.SelfAttention(
+          num_heads=self.num_heads,
+          kernel_init=nn.initializers.xavier_uniform(),
+          deterministic=train,
+          name='MultiHeadDotProductAttention_1')(
+              y)
+      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+      x = x + y
+
+      y = nn.LayerNorm(name='LayerNorm_2')(x)
+      y = MlpBlock(
+          mlp_dim=self.mlp_dim,
+          use_glu=self.use_glu,
+          dropout_rate=self.dropout_rate,
+          name='MlpBlock_3')(y, train)
+      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+      x = x + y
     else:
-      y = x
-      y = nn.SelfAttention(
-          num_heads=self.num_heads,
-          kernel_init=nn.initializers.xavier_uniform(),
-          deterministic=train,
-          name='MultiHeadDotProductAttention_1')(
-              y)
-      y = nn.Dropout(rate=self.dropout_rate)(y, train)
-      x = x + y
-      x = nn.LayerNorm(name='LayerNorm_0')(x)
-
-      y = x
-      y = MlpBlock(
-          mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate,
-          name='MlpBlock_3')(y, train)
-      y = nn.Dropout(rate=self.dropout_rate)(y, train)
-      x = x + y
-      x = nn.LayerNorm(name='LayerNorm_2')(x)
+      y = x
+      y = nn.SelfAttention(
+          num_heads=self.num_heads,
+          kernel_init=nn.initializers.xavier_uniform(),
+          deterministic=train,
+          name='MultiHeadDotProductAttention_1')(
+              y)
+      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+      x = x + y
+      x = nn.LayerNorm(name='LayerNorm_0')(x)
+
+      y = x
+      y = MlpBlock(
+          mlp_dim=self.mlp_dim,
+          use_glu=self.use_glu,
+          dropout_rate=self.dropout_rate,
+          name='MlpBlock_3')(y, train)
+      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+      x = x + y
+      x = nn.LayerNorm(name='LayerNorm_2')(x)
 
     return x
 
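
The two branches of Encoder1DBlock differ only in where LayerNorm sits relative to the residual additions. A schematic comparison of the two orderings, with plain functions standing in for the attention and MLP sub-blocks (dropout omitted; names are illustrative, not this module's API):

def pre_ln_block(x, attn, mlp, norm1, norm2):
  # Pre-LayerNorm: normalize the input to each sub-block, then add to the residual stream.
  x = x + attn(norm1(x))
  x = x + mlp(norm2(x))
  return x

def post_ln_block(x, attn, mlp, norm1, norm2):
  # Post-LayerNorm: run the sub-block on the raw input, add, then normalize the sum.
  x = norm1(x + attn(x))
  x = norm2(x + mlp(x))
  return x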
@@ -141,12 +143,13 @@ class MAPHead(nn.Module):
   """Multihead Attention Pooling."""
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim
   num_heads: int = 12
+
   @nn.compact
   def __call__(self, x):
     n, _, d = x.shape
     probe = self.param('probe',
-                       nn.initializers.xavier_uniform(),
-                       (1, 1, d), x.dtype)
+                       nn.initializers.xavier_uniform(), (1, 1, d),
+                       x.dtype)
     probe = jnp.tile(probe, [n, 1, 1])
 
     x = nn.MultiHeadDotProductAttention(
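
The reflowed probe parameter is the learned query that MAPHead uses for attention pooling: the (1, 1, d) probe is tiled across the batch and attends over the token sequence, collapsing each example to a single width-d summary. A single-head sketch of that pooling step in plain jax.numpy (the function name and shapes are illustrative, not this module's API):

import jax.numpy as jnp
from jax.nn import softmax

def attention_pool(tokens, probe):
  # tokens: (n, seq, d) encoder outputs; probe: (1, 1, d) learned query.
  q = jnp.tile(probe, (tokens.shape[0], 1, 1))      # (n, 1, d)
  scores = q @ jnp.swapaxes(tokens, -1, -2)         # (n, 1, seq)
  weights = softmax(scores / jnp.sqrt(tokens.shape[-1]), axis=-1)
  return weights @ tokens                           # (n, 1, d) pooled summary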
@@ -171,9 +174,9 @@ class ViT(nn.Module):
   dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
   reinit: Optional[Sequence[str]] = None
   head_zeroinit: bool = True
-  use_glu: bool = False,
-  use_post_layer_norm: bool = False,
-  use_map: bool = False,
+  use_glu: bool = False
+  use_post_layer_norm: bool = False
+  use_map: bool = False
 
   def get_posemb(self,
                  seqshape: tuple,
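
Removing the trailing commas here is more than formatting: on Python 3.8+ an annotated class attribute whose assignment ends in a comma is bound to a one-element tuple, so the old defaults were (False,), which is truthy, rather than the boolean False. A tiny illustration of the pitfall (the class name Flags is hypothetical):

class Flags:
  use_glu: bool = False,   # trailing comma: the default is the tuple (False,)
  use_map: bool = False    # no trailing comma: the default is the boolean False

print(Flags.use_glu)         # (False,)
print(bool(Flags.use_glu))   # True -- truthy, so a check like `if self.use_glu:` would fire
print(Flags.use_map)         # False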
@@ -214,9 +217,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
             x, train=not train)
 
     if self.use_map:
-      x = MAPHead(num_heads=self.num_heads,
-                  mlp_dim=self.mlp_dim
-                  )(x)
+      x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
     else:
       x = jnp.mean(x, axis=1)
 
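
In context, the encoder output x has shape (batch, tokens, width); with use_map the learned-probe MAPHead above pools it, otherwise the tokens are simply averaged. A shape-level sketch of the averaging branch shown in the hunk (the concrete sizes are illustrative only):

import jax.numpy as jnp

x = jnp.ones((8, 196, 384))        # (batch, tokens, width): illustrative encoder output
mean_pooled = jnp.mean(x, axis=1)  # (8, 384): the use_map=False branch above
# With use_map=True, MAPHead instead attends over the tokens with its learned
# probe, producing one width-sized summary per example before the head.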
