Commit

push attention config
Kye committed Jul 10, 2023
1 parent 08bce75 commit fdf5a5a
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion zeta/utils/attention/main.py
@@ -17,7 +17,7 @@
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
-from zeta.utils.attention.attend import Attend
+from zeta.utils.attention.attend import Attend, Intermediates
 
 from abc import ABC, abstractmethod
 import bitsandbytes as bnb
@@ -293,6 +293,24 @@ def dropout_seq(seq, mask, dropout):
 
 
 
+class GRUGating(nn.Module):
+    def __init__(self, dim, scale_residual = False, **kwargs):
+        super().__init__()
+        self.gru = nn.GRUCell(dim, dim)
+        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
+
+    def forward(self, x, residual):
+        if exists(self.residual_scale):
+            residual = residual * self.residual_scale
+
+        gated_output = self.gru(
+            rearrange(x, 'b n d -> (b n) d'),
+            rearrange(residual, 'b n d -> (b n) d')
+        )
+
+        return gated_output.reshape_as(x)
+
+
 class Residual(nn.Module):
     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
         super().__init__()
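For context, here is a minimal, self-contained sketch of how a GRU-based gate like the `GRUGating` module added in this commit can replace a plain additive residual connection around a sub-layer. The `exists` helper is inlined, and the `nn.MultiheadAttention` sub-layer, the shapes, and the `__main__` driver are illustrative stand-ins, not part of this commit or of zeta's own attention classes.

```python
# Sketch only: mirrors the GRUGating module added in this commit, with the
# `exists` helper inlined. The attention sub-layer below is a generic
# placeholder, not zeta's Attend class.
import torch
import torch.nn as nn
from einops import rearrange


def exists(val):
    return val is not None


class GRUGating(nn.Module):
    def __init__(self, dim, scale_residual = False, **kwargs):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        # optional learned per-channel scale on the residual branch
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        # flatten (batch, seq) so the GRU cell gates each token independently:
        # x acts as the GRU "input", the residual as the previous hidden state
        gated_output = self.gru(
            rearrange(x, 'b n d -> (b n) d'),
            rearrange(residual, 'b n d -> (b n) d')
        )

        return gated_output.reshape_as(x)


if __name__ == '__main__':
    dim = 64
    gate = GRUGating(dim, scale_residual = True)
    attn = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)  # placeholder sub-layer

    residual = torch.randn(2, 16, dim)           # (batch, seq, dim) input to the block
    out, _ = attn(residual, residual, residual)  # sub-layer output
    x = gate(out, residual)                      # gated merge instead of `out + residual`
    print(x.shape)                               # torch.Size([2, 16, 64])
```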
