Merge pull request #135 from evelynmitchell/master
kyegomez authored Feb 5, 2024
2 parents b496122 + 8749c3d commit dfb0914
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions tests/nn/attentions/test_xc_attention.py
@@ -1,25 +1,27 @@
+""" Test cases for the XCAttention class. """
+
 import torch
 import pytest
 from torch import nn
 
 from zeta.nn.attention.xc_attention import XCAttention
 
 
-# Fixture to create an instance of the XCAttention class
 @pytest.fixture
 def xc_attention_model():
-    model = XCAttention(dim=256, cond_dim=64, heads=8)
+    """ Fixture to create an instance of the XCAttention class. """
+    model = XCAttention(dim=256, cond_dim=64, heads=8, dropout=0.1)
     return model
 
 
-# Test case to check if XCAttention initializes correctly
 def test_xc_attention_initialization(xc_attention_model):
+    """ Test case to check if XCAttention initializes correctly. """
     assert isinstance(xc_attention_model, XCAttention)
     assert isinstance(xc_attention_model.norm, nn.LayerNorm)
     assert isinstance(xc_attention_model.to_qkv, nn.Sequential)
 
 
-# Test case to check if XCAttention handles forward pass correctly
 def test_xc_attention_forward_pass(xc_attention_model):
+    """ Test case to check if XCAttention handles forward pass correctly. """
     x = torch.randn(1, 256, 16, 16)
     cond = torch.randn(1, 64)
@@ -28,25 +30,25 @@ def test_xc_attention_forward_pass(xc_attention_model):
     assert isinstance(output, torch.Tensor)
 
 
-# Test case to check if XCAttention handles forward pass without conditioning
 def test_xc_attention_forward_pass_without_cond(xc_attention_model):
+    """ Test case to check if XCAttention handles forward pass without conditioning. """
     x = torch.randn(1, 256, 16, 16)
 
     output = xc_attention_model(x)
 
     assert isinstance(output, torch.Tensor)
 
 
-# Test case to check if XCAttention raises an error when forwarding with invalid inputs
 def test_xc_attention_forward_with_invalid_inputs(xc_attention_model):
+    """ Test case to check if XCAttention raises an error when forwarding with invalid inputs. """
     with pytest.raises(Exception):
         x = torch.randn(1, 256, 16, 16)
         cond = torch.randn(1, 128)  # Mismatched conditioning dimension
         xc_attention_model(x, cond)
 
 
-# Test case to check if XCAttention handles different head configurations correctly
 def test_xc_attention_with_different_heads():
+    """ Test case to check if XCAttention handles different head configurations correctly. """
     head_configs = [4, 8, 12]
 
     for heads in head_configs:
@@ -58,8 +60,8 @@ def test_xc_attention_with_different_heads():
         )
 
 
-# Test case to check if XCAttention handles different input dimensions correctly
 def test_xc_attention_with_different_input_dims():
+    """ Test case to check if XCAttention handles different input dimensions correctly. """
     input_dims = [128, 256, 512]
 
     for dim in input_dims:
@@ -68,8 +70,8 @@ def test_xc_attention_with_different_input_dims():
         assert model.to_qkv[0].in_features == dim
 
 
-# Test case to check if XCAttention handles different conditioning dimensions correctly
 def test_xc_attention_with_different_cond_dims():
+    """ Test case to check if XCAttention handles different conditioning dimensions correctly. """
     cond_dims = [32, 64, 128]
 
     for cond_dim in cond_dims:
@@ -78,13 +80,13 @@ def test_xc_attention_with_different_cond_dims():
         assert model.film[0].in_features == cond_dim * 2
 
 
-# Test case to check if XCAttention handles negative input dimensions correctly
 def test_xc_attention_negative_input_dim():
+    """ Test case to check if XCAttention handles negative input dimensions correctly. """
     with pytest.raises(ValueError):
         XCAttention(dim=-256, cond_dim=64, heads=8)
 
 
-# Test case to check if XCAttention handles negative conditioning dimensions correctly
 def test_xc_attention_negative_cond_dim():
+    """ Test case to check if XCAttention handles negative conditioning dimensions correctly. """
     with pytest.raises(ValueError):
         XCAttention(dim=256, cond_dim=-64, heads=8)

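For readers without the zeta source at hand, the sketch below illustrates the interface these tests exercise. It is a hypothetical stand-in inferred only from the assertions above (the constructor arguments, the norm, to_qkv, and film attributes, ValueError on negative dimensions, and a forward pass over NCHW feature maps with an optional conditioning vector); the internals of the real zeta.nn.attention.xc_attention.XCAttention may differ.

# Hypothetical sketch of an XCAttention-compatible module. Assumes einops is
# available; class name and internals are illustrative, not zeta's code.
from typing import Optional

import torch
from torch import nn
from einops import rearrange


class XCAttentionSketch(nn.Module):
    def __init__(self, dim: int, cond_dim: int, heads: int, dropout: float = 0.0):
        super().__init__()
        if dim <= 0 or cond_dim <= 0:
            # Matches the ValueError the negative-dimension tests expect.
            raise ValueError("dim and cond_dim must be positive")
        self.heads = heads
        self.norm = nn.LayerNorm(dim)  # tested: isinstance(..., nn.LayerNorm)
        self.to_qkv = nn.Sequential(  # tested: to_qkv[0].in_features == dim
            nn.Linear(dim, dim * 3, bias=False)
        )
        self.film = nn.Sequential(  # tested: film[0].in_features == cond_dim * 2
            nn.Linear(cond_dim * 2, dim * 2)
        )
        self.attn_dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        _, _, h, w = x.shape
        # Flatten the spatial grid into a token sequence and normalize.
        tokens = self.norm(rearrange(x, "b c h w -> b (h w) c"))
        if cond is not None:
            # FiLM-style conditioning; doubling cond matches the
            # cond_dim * 2 input width the tests assert.
            scale, shift = self.film(torch.cat((cond, cond), dim=-1)).chunk(2, dim=-1)
            tokens = tokens * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
        # Standard multi-head self-attention over the token sequence.
        q, k, v = self.to_qkv(tokens).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=self.heads) for t in (q, k, v))
        attn = self.attn_dropout((q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5).softmax(dim=-1))
        out = rearrange(attn @ v, "b h n d -> b n (h d)")
        return rearrange(self.to_out(out), "b (h w) c -> b c h w", h=h, w=w)

With this stand-in, the fixture's call XCAttention(dim=256, cond_dim=64, heads=8, dropout=0.1) constructs, and model(torch.randn(1, 256, 16, 16), torch.randn(1, 64)) returns a (1, 256, 16, 16) tensor, which is the behavior the forward-pass tests check. The suite itself runs with pytest tests/nn/attentions/test_xc_attention.py.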