clean up
Kye committed Jul 10, 2023
1 parent f11a2e1 commit 12cf070
Showing 2 changed files with 4 additions and 2 deletions.
LongNet/attend.py (3 additions, 1 deletion)
@@ -49,6 +49,7 @@ def __init__(
         self,
         *,
         dropout = 0.,
+        attention_dropout=0.,
         causal = False,
         heads = None,
         talking_heads = False,
@@ -63,7 +64,8 @@ def __init__(
         self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax

         self.dropout = dropout
-        self.attn_dropout = nn.Dropout(dropout)
+
+        self.attn_dropout = nn.Dropout(attention_dropout) # modify this line

         # talking heads
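The effect of this change is that dropout on the attention weights is now controlled by its own attention_dropout argument instead of reusing the module-level dropout rate. A minimal, self-contained sketch of the pattern (the TinyAttend class below is hypothetical, not LongNet's actual Attend module):

    import torch
    import torch.nn as nn

    # Hypothetical sketch: `attention_dropout` governs dropout on the
    # softmaxed attention weights, leaving `dropout` free for other uses
    # such as output dropout.
    class TinyAttend(nn.Module):
        def __init__(self, dropout=0., attention_dropout=0.):
            super().__init__()
            self.dropout = dropout                             # stored rate, e.g. for outputs
            self.attn_dropout = nn.Dropout(attention_dropout)  # applied to attention weights

        def forward(self, q, k, v):
            # q, k, v: (batch, heads, seq, dim)
            scale = q.shape[-1] ** -0.5
            sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
            attn = sim.softmax(dim=-1)
            attn = self.attn_dropout(attn)  # randomly zeroes attention weights in training
            return torch.einsum('b h i j, b h j d -> b h i d', attn, v)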
LongNet/attention.py (1 addition, 1 deletion)
@@ -62,7 +62,7 @@ def __init__(self, d_model, num_heads, dilation_rate, segment_size, dropout=0.0,
         self.relative_bias = RelativePositionBias(num_buckets=32, max_distance=128, n_heads=num_heads)

         #head offsets
-        self.head_offsets = nn.Parameer(torch.randn(num_heads, d_model))
+        self.head_offsets = nn.Parameter(torch.randn(num_heads, d_model))

     def get_mask(self, i, j):
         return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 2)
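Besides the nn.Parameter typo fix, the context lines show how get_mask builds its attention mask. A standalone reproduction of that expression (device is pinned to CPU here for the demo; in the source it appears to be a module-level global):

    import torch

    device = torch.device('cpu')  # assumption: `device` is a module-level global in the source

    def get_mask(i, j):
        # True on and above diagonal (j - i + 2): for a square mask, keys more
        # than one position ahead of the query are masked out.
        return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 2)

    print(get_mask(4, 4).int())
    # tensor([[0, 0, 1, 1],
    #         [0, 0, 0, 1],
    #         [0, 0, 0, 0],
    #         [0, 0, 0, 0]], dtype=torch.int32)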
