tests for adaptive param + others
Kye committed Oct 9, 2023
1 parent 6b344cb commit 9f5e935
Showing 6 changed files with 122 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/modules/adaptive_param.py
@@ -0,0 +1,25 @@
import pytest
import torch
from torch import nn
from zeta.nn.modules.adaptive_parameter_list import AdaptiveParameterList

def test_adaptiveparameterlist_initialization():
    model = AdaptiveParameterList([nn.Parameter(torch.randn(10, 10))])
    assert isinstance(model, AdaptiveParameterList)
    assert len(model) == 1

def test_adaptiveparameterlist_adapt():
    model = AdaptiveParameterList([nn.Parameter(torch.randn(10, 10))])
    # Snapshot the parameter before adapting; a fresh torch.randn tensor
    # would never match the adapted values.
    original = model[0].clone()
    model.adapt({0: lambda x: x * 0.9})
    assert torch.allclose(model[0], original * 0.9, atol=1e-4)

@pytest.mark.parametrize("adaptation_functions", [lambda x: x * 0.9])
def test_adaptiveparameterlist_adapt_edge_cases(adaptation_functions):
    # A bare callable instead of a {index: callable} dict should be rejected.
    model = AdaptiveParameterList([nn.Parameter(torch.randn(10, 10))])
    with pytest.raises(Exception):
        model.adapt(adaptation_functions)

def test_adaptiveparameterlist_adapt_invalid_dimensions():
    # An adaptation that changes the parameter's shape should be rejected.
    model = AdaptiveParameterList([nn.Parameter(torch.randn(10, 10))])
    with pytest.raises(Exception):
        model.adapt({0: lambda x: x.view(-1)})
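
For reference, the tests above assume an adapt API along these lines. This is a minimal hypothetical sketch, not zeta's actual implementation; the class name and error messages are illustrative:

import torch
from torch import nn

class AdaptiveParameterListSketch(nn.ParameterList):
    def adapt(self, adaptation_functions):
        # A bare callable (see the edge-case test) should be rejected.
        if not isinstance(adaptation_functions, dict):
            raise ValueError("adapt expects a {index: callable} dict")
        for i, fn in adaptation_functions.items():
            new_value = fn(self[i])
            # Shape must be preserved (see the invalid-dimensions test).
            if new_value.shape != self[i].shape:
                raise ValueError("adapted parameter must keep its shape")
            # Detach so the adapted value becomes a fresh leaf parameter.
            self[i] = nn.Parameter(new_value.detach())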
37 changes: 37 additions & 0 deletions tests/modules/dynamic_module.py
@@ -0,0 +1,37 @@
import pytest
import torch
from torch import nn
from zeta.nn.modules.dynamic_module import DynamicModule

def test_dynamicmodule_initialization():
    model = DynamicModule()
    assert isinstance(model, DynamicModule)
    # nn.ModuleDict does not define value equality, so check emptiness instead.
    assert len(model.module_dict) == 0
    assert model.forward_method is None

def test_dynamicmodule_add_remove_module():
    model = DynamicModule()
    model.add('linear', nn.Linear(10, 10))
    assert 'linear' in model.module_dict
    model.remove('linear')
    assert 'linear' not in model.module_dict

def test_dynamicmodule_forward():
    model = DynamicModule()
    model.add('linear', nn.Linear(10, 10))
    x = torch.randn(1, 10)
    output = model(x)
    assert output.shape == (1, 10)

@pytest.mark.parametrize("name", ['linear'])
def test_dynamicmodule_add_module_edge_cases(name):
    # Adding a module under an already-registered name should raise.
    model = DynamicModule()
    model.add(name, nn.Linear(10, 10))
    with pytest.raises(Exception):
        model.add(name, nn.Linear(10, 10))

@pytest.mark.parametrize("name", ['linear'])
def test_dynamicmodule_remove_module_edge_cases(name):
    # Removing a module that was never added should raise.
    model = DynamicModule()
    with pytest.raises(Exception):
        model.remove(name)
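
The DynamicModule tests encode a small registry-style API: a named module dict, duplicate-safe add, checked remove, and a forward pass that chains the registered modules. A hypothetical sketch of the behavior they rely on (names and error messages here are illustrative, not necessarily zeta's):

import torch
from torch import nn

class DynamicModuleSketch(nn.Module):
    def __init__(self, forward_method=None):
        super().__init__()
        self.module_dict = nn.ModuleDict()
        self.forward_method = forward_method

    def add(self, name, module):
        # Duplicate names are rejected (see the add edge-case test).
        if name in self.module_dict:
            raise ValueError(f"module '{name}' already exists")
        self.module_dict[name] = module

    def remove(self, name):
        # Unknown names are rejected (see the remove edge-case test).
        if name not in self.module_dict:
            raise ValueError(f"no module named '{name}'")
        del self.module_dict[name]

    def forward(self, x):
        # With no custom forward_method, chain registered modules in order.
        if self.forward_method is not None:
            return self.forward_method(self.module_dict, x)
        for module in self.module_dict.values():
            x = module(x)
        return x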
Empty file added tests/modules/feedforward.py
Empty file added tests/modules/mbconv.py
30 changes: 30 additions & 0 deletions tests/modules/mlp.py
@@ -0,0 +1,30 @@
import pytest
import torch
from torch import nn
from zeta.nn.modules.mlp import MLP

def test_mlp_initialization():
    model = MLP(dim_in=256, dim_out=10)
    assert isinstance(model, MLP)
    assert len(model.net) == 3
    assert isinstance(model.net[0], nn.Sequential)
    assert isinstance(model.net[1], nn.Sequential)
    assert isinstance(model.net[2], nn.Linear)

def test_mlp_forward():
    model = MLP(dim_in=256, dim_out=10)
    x = torch.randn(32, 256)
    output = model(x)
    assert output.shape == (32, 10)

@pytest.mark.parametrize("dim_in", [0])
def test_mlp_forward_edge_cases(dim_in):
    # A zero-width input should fail at construction or during forward,
    # so both steps sit inside the raises block.
    with pytest.raises(Exception):
        model = MLP(dim_in=dim_in, dim_out=10)
        x = torch.randn(32, dim_in)
        model(x)

def test_mlp_forward_invalid_dimensions():
    # Input feature size (128) does not match dim_in (256).
    model = MLP(dim_in=256, dim_out=10)
    x = torch.randn(32, 128)
    with pytest.raises(Exception):
        model(x)
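
The initialization test pins down a three-stage net: two nn.Sequential blocks followed by a final nn.Linear. A hypothetical sketch consistent with those assertions; the hidden width and activation are assumptions, and the real zeta MLP likely takes more configuration:

import torch
from torch import nn

class MLPSketch(nn.Module):
    def __init__(self, dim_in, dim_out, dim_hidden=512):
        super().__init__()
        # net[0] and net[1] are Sequential stages, net[2] the output Linear,
        # matching the isinstance checks in test_mlp_initialization.
        self.net = nn.Sequential(
            nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.SiLU()),
            nn.Sequential(nn.Linear(dim_hidden, dim_hidden), nn.SiLU()),
            nn.Linear(dim_hidden, dim_out),
        )

    def forward(self, x):
        # (batch, dim_in) -> (batch, dim_out)
        return self.net(x)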
30 changes: 30 additions & 0 deletions tests/nn/embeddings/abc_pos_emb.py
@@ -0,0 +1,30 @@
import pytest
import torch
from zeta.nn.embeddings.abc_pos_emb import AbsolutePositionalEmbedding

def test_absolutepositionalembedding_initialization():
    model = AbsolutePositionalEmbedding(dim=512, max_seq_len=1000)
    assert isinstance(model, AbsolutePositionalEmbedding)
    assert model.scale == 512**-0.5
    assert model.max_seq_len == 1000
    assert model.l2norm_embed is False
    assert model.emb.weight.shape == (1000, 512)

def test_absolutepositionalembedding_forward():
    model = AbsolutePositionalEmbedding(dim=512, max_seq_len=1000)
    x = torch.randn(1, 10, 512)
    output = model(x)
    # One embedding per position, not per batch element.
    assert output.shape == (10, 512)

@pytest.mark.parametrize("seq_len", [1001])
def test_absolutepositionalembedding_forward_edge_cases(seq_len):
    # Sequences longer than max_seq_len should be rejected.
    model = AbsolutePositionalEmbedding(dim=512, max_seq_len=1000)
    x = torch.randn(1, seq_len, 512)
    with pytest.raises(Exception):
        model(x)

def test_absolutepositionalembedding_forward_invalid_dimensions():
    # Feature dimension (256) does not match the embedding dim (512).
    model = AbsolutePositionalEmbedding(dim=512, max_seq_len=1000)
    x = torch.randn(1, 10, 256)
    with pytest.raises(Exception):
        model(x)
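
These assertions pin down the module's contract: scale = dim ** -0.5, a (max_seq_len, dim) embedding table, and a forward pass that returns one embedding per position rather than per batch element. A hypothetical sketch consistent with them (not zeta's actual code):

import torch
from torch import nn

class AbsolutePositionalEmbeddingSketch(nn.Module):
    def __init__(self, dim, max_seq_len, l2norm_embed=False):
        super().__init__()
        self.scale = dim ** -0.5 if not l2norm_embed else 1.0
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        seq_len = x.shape[1]
        # Sequences longer than max_seq_len are rejected (edge-case test).
        assert seq_len <= self.max_seq_len, "sequence exceeds max_seq_len"
        pos = torch.arange(seq_len, device=x.device)
        # Returns (seq_len, dim), broadcastable over the batch dimension.
        return self.emb(pos) * self.scale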
