diff --git a/zeta/utils/attention/main.py b/zeta/utils/attention/main.py
index 6734e031..59d90ce5 100644
--- a/zeta/utils/attention/main.py
+++ b/zeta/utils/attention/main.py
@@ -17,7 +17,7 @@
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange

-from zeta.utils.attention.attend import Attend
+from zeta.utils.attention.attend import Attend, Intermediates

 from abc import ABC, abstractmethod
 import bitsandbytes as bnb
@@ -293,6 +293,24 @@ def dropout_seq(seq, mask, dropout):


+class GRUGating(nn.Module):
+    def __init__(self, dim, scale_residual = False, **kwargs):
+        super().__init__()
+        self.gru = nn.GRUCell(dim, dim)
+        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
+
+    def forward(self, x, residual):
+        if exists(self.residual_scale):
+            residual = residual * self.residual_scale
+
+        gated_output = self.gru(
+            rearrange(x, 'b n d -> (b n) d'),
+            rearrange(residual, 'b n d -> (b n) d')
+        )
+
+        return gated_output.reshape_as(x)
+
+
 class Residual(nn.Module):
     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
         super().__init__()
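
For context, a minimal usage sketch (not part of the diff) of how the new `GRUGating` block could be wired in place of `Residual`, gating a branch output into the residual stream in the spirit of GTrXL-style gated residuals. The import path simply mirrors the file touched above; `dim`, the tensor shapes, and `scale_residual = True` are illustrative assumptions rather than part of the change.

```python
# Minimal sketch, assuming GRUGating is importable from the module modified above.
import torch

from zeta.utils.attention.main import GRUGating

dim = 512                                      # illustrative model dimension
gate = GRUGating(dim, scale_residual = True)   # illustrative config

residual = torch.randn(2, 128, dim)    # residual stream: (batch, seq_len, dim)
branch_out = torch.randn(2, 128, dim)  # output of an attention / feedforward branch

# The GRU cell takes the branch output as input and the (optionally scaled)
# residual as hidden state, producing a learned gated mix of the two.
out = gate(branch_out, residual)
assert out.shape == residual.shape
```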