diff --git a/src/continuiti/networks/attention.py b/src/continuiti/networks/attention.py
index 85421129..82e24dee 100644
--- a/src/continuiti/networks/attention.py
+++ b/src/continuiti/networks/attention.py
@@ -16,13 +16,10 @@ class Attention(nn.Module):
     components is designated using "soft" weights. These weights are assigned
     according to specific algorithms (e.g. scaled-dot-product attention).
 
-    Args:
-        dropout_p: dropout probability.
     """
 
-    def __init__(self, dropout_p: float = 0.0):
+    def __init__(self):
         super().__init__()
-        self.dropout_p = dropout_p
 
     @abstractmethod
     def forward(
diff --git a/src/continuiti/networks/scaled_dot_product_attention.py b/src/continuiti/networks/scaled_dot_product_attention.py
index a1d77cde..64f60bd3 100644
--- a/src/continuiti/networks/scaled_dot_product_attention.py
+++ b/src/continuiti/networks/scaled_dot_product_attention.py
@@ -19,7 +19,8 @@ class ScaledDotProductAttention(Attention):
     """
 
    def __init__(self, dropout_p: float = 0.0):
-        super().__init__(dropout_p)
+        super().__init__()
+        self.dropout_p = dropout_p
 
     def forward(
         self,
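
The change above removes dropout_p from the abstract Attention base class and makes ScaledDotProductAttention own that attribute itself, so the base-class contract no longer forces a dropout concept onto implementations that do not use one. The following sketch is illustrative only: the forward signature and the use of torch.nn.functional.scaled_dot_product_attention as the backend are assumptions, not taken from the diff.

# Minimal sketch of how the two classes fit together after this change
# (illustrative assumptions: forward signature, torch SDPA backend).
from abc import abstractmethod

import torch
import torch.nn as nn


class Attention(nn.Module):
    """Abstract base class; carries no dropout state after this change."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, query, key, value, attn_mask=None):
        ...


class ScaledDotProductAttention(Attention):
    """Concrete implementation that owns its own dropout probability."""

    def __init__(self, dropout_p: float = 0.0):
        super().__init__()  # base class no longer accepts dropout_p
        self.dropout_p = dropout_p

    def forward(self, query, key, value, attn_mask=None):
        # Apply dropout only while the module is in training mode.
        return nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=self.dropout_p if self.training else 0.0,
        )


# Usage: only subclasses that actually use dropout need a dropout_p argument.
attn = ScaledDotProductAttention(dropout_p=0.1)
q = k = v = torch.rand(2, 4, 8)
out = attn(q, k, v)  # same shape as the inputs: (2, 4, 8)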