
Commit

fix tests
JakobEliasWagner authored and samuelburbulla committed Jun 20, 2024
1 parent eb143d3 commit 4ab4258
Showing 3 changed files with 11 additions and 11 deletions.
Empty file.
@@ -2,22 +2,22 @@
 import torch
 import torch.nn as nn
 
-from continuiti.networks import MultiHead, ScaledDotProduct
+from continuiti.networks import MultiHeadAttention, ScaledDotProductAttention
 
 
 @pytest.fixture(scope="session")
 def some_multi_head_attn():
-    return MultiHead(
+    return MultiHeadAttention(
         hidden_dim=32,
         n_heads=4,
-        attention=ScaledDotProduct(dropout_p=0.25),
+        attention=ScaledDotProductAttention(dropout_p=0.25),
         bias=True,
     )
 
 
 class TestMultiHeadAttention:
     def test_can_initialize(self, some_multi_head_attn):
-        assert isinstance(some_multi_head_attn, MultiHead)
+        assert isinstance(some_multi_head_attn, MultiHeadAttention)
 
     def test_output_shape(self, some_multi_head_attn):
         batch_size = 3
@@ -42,7 +42,7 @@ def test_output_shape(self, some_multi_head_attn):
 
     def test_attention_correct(self):
        """Edge case testing for correctness."""
-        m_attn = MultiHead(4, 4, bias=False)
+        m_attn = MultiHeadAttention(4, 4, bias=False)
 
         batch_size = 3
         hidden_dim = 4
@@ -51,10 +51,10 @@ def test_attention_correct(self):
 
         query = torch.rand(batch_size, query_size, hidden_dim)
         key = torch.rand(batch_size, key_val_size, hidden_dim)
-        torch.rand(batch_size, key_val_size, hidden_dim)
+        value = torch.zeros(batch_size, key_val_size, hidden_dim)
 
         # V = 0 -> attn score == 0
-        out = m_attn(query, key, torch.zeros(batch_size, key_val_size, hidden_dim))
+        out = m_attn(query, key, value)
         assert torch.allclose(out, torch.zeros(out.shape))
 
     def test_gradient_flow(self, some_multi_head_attn):
@@ -84,10 +84,10 @@ def test_equal_to_torch(self):
         v = torch.rand(batch_size, source_length, embedding_dim)
 
         gt_attn = nn.MultiheadAttention(embedding_dim, heads, batch_first=True)
-        attn = MultiHead(
+        attn = MultiHeadAttention(
             hidden_dim=embedding_dim,
             n_heads=heads,
-            attention=ScaledDotProduct(dropout_p=0.0),
+            attention=ScaledDotProductAttention(dropout_p=0.0),
             bias=True,
         )
 
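For reference, a minimal usage sketch of the renamed classes, assembled from the pieces visible in this diff: the import path, constructor arguments, call signature, and (batch, sequence, hidden_dim) tensor layout are taken from the test code above, while the expected output shape is an assumption based on test_equal_to_torch comparing against torch.nn.MultiheadAttention with batch_first=True.

import torch

from continuiti.networks import MultiHeadAttention, ScaledDotProductAttention

# Same constructor arguments as the fixture above.
attn = MultiHeadAttention(
    hidden_dim=32,
    n_heads=4,
    attention=ScaledDotProductAttention(dropout_p=0.25),
    bias=True,
)

# Tensors follow the (batch, sequence, hidden_dim) layout used throughout the tests.
query = torch.rand(3, 5, 32)
key = torch.rand(3, 7, 32)
value = torch.rand(3, 7, 32)

out = attn(query, key, value)
# Assumption: the output keeps the query's shape, as torch.nn.MultiheadAttention
# with batch_first=True would (cf. test_equal_to_torch).
print(out.shape)  # expected: torch.Size([3, 5, 32])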
@@ -1,7 +1,7 @@
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
-from continuiti.networks.attention import ScaledDotProduct
+from continuiti.networks import ScaledDotProductAttention
 
 
 def test_forward_correct():
@@ -14,7 +14,7 @@ def test_forward_correct():
     key = torch.rand(batch_size, key_val_size, hidden_dim)
     value = torch.rand(batch_size, key_val_size, hidden_dim)
 
-    attn = ScaledDotProduct()
+    attn = ScaledDotProductAttention()
 
     out = attn(query, key, value)
     gt_out = scaled_dot_product_attention(query, key, value)
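The same equivalence check as a stand-alone snippet: the imports and call signature come from the hunk above, the tensor sizes are illustrative (the values used by the test are not shown in this hunk), and the closing allclose check is an assumption about what the truncated end of test_forward_correct asserts.

import torch
from torch.nn.functional import scaled_dot_product_attention

from continuiti.networks import ScaledDotProductAttention

# Illustrative sizes; the actual values in the test are outside this hunk.
batch_size, query_size, key_val_size, hidden_dim = 3, 5, 7, 11

query = torch.rand(batch_size, query_size, hidden_dim)
key = torch.rand(batch_size, key_val_size, hidden_dim)
value = torch.rand(batch_size, key_val_size, hidden_dim)

attn = ScaledDotProductAttention()

# With default settings the wrapper is expected to match torch's functional implementation.
out = attn(query, key, value)
gt_out = scaled_dot_product_attention(query, key, value)
assert torch.allclose(out, gt_out)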
