diff --git a/tests/nn/attentions/test_xc_attention.py b/tests/nn/attentions/test_xc_attention.py
index d5558996..e6c4948f 100644
--- a/tests/nn/attentions/test_xc_attention.py
+++ b/tests/nn/attentions/test_xc_attention.py
@@ -1,25 +1,27 @@
+""" Test cases for the XCAttention class. """
 import torch
 import pytest
 from torch import nn
+
 from zeta.nn.attention.xc_attention import XCAttention
 
 
-# Fixture to create an instance of the XCAttention class
 @pytest.fixture
 def xc_attention_model():
-    model = XCAttention(dim=256, cond_dim=64, heads=8)
+    """ Fixture to create an instance of the XCAttention class. """
+    model = XCAttention(dim=256, cond_dim=64, heads=8, dropout=0.1)
     return model
 
 
-# Test case to check if XCAttention initializes correctly
 def test_xc_attention_initialization(xc_attention_model):
+    """ Test case to check if XCAttention initializes correctly. """
     assert isinstance(xc_attention_model, XCAttention)
     assert isinstance(xc_attention_model.norm, nn.LayerNorm)
     assert isinstance(xc_attention_model.to_qkv, nn.Sequential)
 
 
-# Test case to check if XCAttention handles forward pass correctly
 def test_xc_attention_forward_pass(xc_attention_model):
+    """ Test case to check if XCAttention handles forward pass correctly. """
     x = torch.randn(1, 256, 16, 16)
     cond = torch.randn(1, 64)
 
@@ -28,8 +30,8 @@ def test_xc_attention_forward_pass(xc_attention_model):
     assert isinstance(output, torch.Tensor)
 
 
-# Test case to check if XCAttention handles forward pass without conditioning
 def test_xc_attention_forward_pass_without_cond(xc_attention_model):
+    """ Test case to check if XCAttention handles forward pass without conditioning. """
     x = torch.randn(1, 256, 16, 16)
 
     output = xc_attention_model(x)
@@ -37,16 +39,16 @@ def test_xc_attention_forward_pass_without_cond(xc_attention_model):
     assert isinstance(output, torch.Tensor)
 
 
-# Test case to check if XCAttention raises an error when forwarding with invalid inputs
 def test_xc_attention_forward_with_invalid_inputs(xc_attention_model):
+    """ Test case to check if XCAttention raises an error when forwarding with invalid inputs. """
     with pytest.raises(Exception):
         x = torch.randn(1, 256, 16, 16)
         cond = torch.randn(1, 128)  # Mismatched conditioning dimension
         xc_attention_model(x, cond)
 
 
-# Test case to check if XCAttention handles different head configurations correctly
 def test_xc_attention_with_different_heads():
+    """ Test case to check if XCAttention handles different head configurations correctly. """
     head_configs = [4, 8, 12]
 
     for heads in head_configs:
@@ -58,8 +60,8 @@ def test_xc_attention_with_different_heads():
         )
 
 
-# Test case to check if XCAttention handles different input dimensions correctly
 def test_xc_attention_with_different_input_dims():
+    """ Test case to check if XCAttention handles different input dimensions correctly. """
     input_dims = [128, 256, 512]
 
     for dim in input_dims:
@@ -68,8 +70,8 @@ def test_xc_attention_with_different_input_dims():
         assert model.to_qkv[0].in_features == dim
 
 
-# Test case to check if XCAttention handles different conditioning dimensions correctly
 def test_xc_attention_with_different_cond_dims():
+    """ Test case to check if XCAttention handles different conditioning dimensions correctly. """
     cond_dims = [32, 64, 128]
 
     for cond_dim in cond_dims:
@@ -78,13 +80,13 @@ def test_xc_attention_with_different_cond_dims():
         assert model.film[0].in_features == cond_dim * 2
 
 
-# Test case to check if XCAttention handles negative input dimensions correctly
 def test_xc_attention_negative_input_dim():
+    """ Test case to check if XCAttention handles negative input dimensions correctly. """
     with pytest.raises(ValueError):
         XCAttention(dim=-256, cond_dim=64, heads=8)
 
 
-# Test case to check if XCAttention handles negative conditioning dimensions correctly
 def test_xc_attention_negative_cond_dim():
+    """ Test case to check if XCAttention handles negative conditioning dimensions correctly. """
     with pytest.raises(ValueError):
         XCAttention(dim=256, cond_dim=-64, heads=8)