
Commit

rm dropout p from attention base class
JakobEliasWagner authored and samuelburbulla committed Jun 20, 2024
1 parent 93af51c commit eb143d3
Showing 2 changed files with 3 additions and 5 deletions.
5 changes: 1 addition & 4 deletions src/continuiti/networks/attention.py
@@ -16,13 +16,10 @@ class Attention(nn.Module):
     components is designated using "soft" weights. These weights are assigned according to specific algorithms (e.g.
     scaled-dot-product attention).
 
-    Args:
-        dropout_p: dropout probability.
     """
 
-    def __init__(self, dropout_p: float = 0.0):
+    def __init__(self):
         super().__init__()
-        self.dropout_p = dropout_p
 
     @abstractmethod
     def forward(
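
After this change, the Attention base class reduces to roughly the sketch below. This is a paraphrase assembled from the diff context, not a copy of the file: the imports and the full forward signature (query/key/value tensors with an optional attention mask) are assumptions, since the hunk only shows the opening lines of forward.

from abc import abstractmethod
from typing import Optional

import torch
import torch.nn as nn


class Attention(nn.Module):
    """Base class for attention implementations.

    After this commit the base class carries no dropout state; each
    concrete implementation manages its own hyperparameters.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,  # assumed signature, not shown in the hunk
    ) -> torch.Tensor:
        """Return the attention-weighted combination of the value tensor."""
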
3 changes: 2 additions & 1 deletion src/continuiti/networks/scaled_dot_product_attention.py
@@ -19,7 +19,8 @@ class ScaledDotProductAttention(Attention):
     """
 
     def __init__(self, dropout_p: float = 0.0):
-        super().__init__(dropout_p)
+        super().__init__()
+        self.dropout_p = dropout_p
 
     def forward(
         self,
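
For comparison, here is a minimal sketch of ScaledDotProductAttention after this commit. Only the __init__ body matches the diff verbatim; the forward pass wrapping torch.nn.functional.scaled_dot_product_attention and the training-mode dropout guard are assumptions suggested by the class name, not taken from the repository.

import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):  # subclasses Attention in the repository
    """Scaled dot-product attention that owns its dropout probability."""

    def __init__(self, dropout_p: float = 0.0):
        super().__init__()
        self.dropout_p = dropout_p  # now stored on the subclass, not the base class

    def forward(self, query, key, value, attn_mask=None):
        # Assumption: dropout is applied only while the module is in training mode.
        dropout = self.dropout_p if self.training else 0.0
        return nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout
        )


# Example usage (hypothetical shapes: batch, heads, sequence length, head dim):
# attn = ScaledDotProductAttention(dropout_p=0.1)
# q = k = v = torch.rand(2, 4, 10, 8)
# out = attn(q, k, v)

With this layout, dropout_p is a concern of the concrete attention implementation only, and the Attention base class carries no state of its own.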
