smoke tests for linear autoencoders
SamChou05 committed Dec 31, 2024
1 parent a5ae6d1 commit e1840dd
Showing 2 changed files with 135 additions and 12 deletions.
19 changes: 7 additions & 12 deletions afqinsight/nn/pt_models.py
@@ -542,9 +542,6 @@ def cnn_resnet_pt(input_shape, n_classes):
     return cnn_resnet_Model
 
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
 class VariationalEncoder(nn.Module):
     def __init__(self, input_shape, latent_dims):
         super(VariationalEncoder, self).__init__()
@@ -553,9 +550,11 @@ def __init__(self, input_shape, latent_dims):
         self.linear3 = nn.Linear(500, latent_dims)
         self.activation = nn.ReLU()
 
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         self.N = torch.distributions.Normal(0, 1)
-        self.N.loc = self.N.loc.to(device)
-        self.N.scale = self.N.scale.to(device)
+        self.N.loc = self.N.loc.to(self.device)
+        self.N.scale = self.N.scale.to(self.device)
         self.kl = 0
 
     def forward(self, x):
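
The encoder's forward body is collapsed in this view. Given the `self.N` standard normal (whose `loc` and `scale` this hunk moves to the model's device) and the `self.kl` accumulator that `fit` adds to the loss, the method presumably implements the usual reparameterization trick. A minimal self-contained sketch of that structure, not the actual implementation: the layer sizes, the flatten step, and the roles of `linear2`/`linear3` as mean and std heads are assumptions.

import torch
import torch.nn as nn


class ReparamEncoderSketch(nn.Module):
    """Hypothetical stand-in illustrating the reparameterization trick."""

    def __init__(self, input_shape, latent_dims, hidden=500):
        super().__init__()
        self.linear1 = nn.Linear(input_shape, hidden)
        self.linear2 = nn.Linear(hidden, latent_dims)  # assumed mean head
        self.linear3 = nn.Linear(hidden, latent_dims)  # assumed std head
        self.activation = nn.ReLU()
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        h = self.activation(self.linear1(x))
        mu = self.linear2(h)
        sigma = torch.exp(self.linear3(h))  # exp keeps the std positive
        z = mu + sigma * self.N.sample(mu.shape)  # differentiable in mu, sigma
        # KL divergence between N(mu, sigma) and N(0, 1), consumed by fit()
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1 / 2).sum()
        return z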
@@ -607,8 +606,6 @@ def __init__(self, input_shape, latent_dims):
         self.decoder = Decoder(input_shape, latent_dims)
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # is this useful?
-        self.to(self.device)
 
     def forward(self, x):
         z = self.encoder(x)
@@ -620,7 +617,7 @@ def fit(self, data, epochs=20, lr=0.001):
         running_loss = 0
         items = 0
         for x, _ in data:
-            x = x.to(device)  # GPU
+            x = x.to(self.device)  # move the batch to the model's device
             opt.zero_grad()
             x_hat = self(x)
             loss = ((x - x_hat) ** 2).sum() + self.encoder.kl
@@ -630,9 +627,8 @@
             opt.step()
         print(f"Epoch {epoch+1}, Loss: {running_loss/items:.2f}")
 
-    # what to do here
     def transform(self, x):
-        return self.encoder(x)
+        return self.forward(x)
 
     def fit_transform(self, data, epochs=20):
         self.fit(data, epochs)
@@ -669,9 +665,8 @@ def fit(self, data, epochs=20, lr=0.001):

         return self
 
-    # what to do here
     def transform(self, x):
-        return self.encoder(x)
+        return self.forward(x)
 
     def fit_transform(self, data, epochs=20):
         self.fit(data, epochs)
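
With this refactor, each model resolves its own device at construction time instead of relying on a module-level `device` global. A minimal usage sketch, using the 48 x 100 input shape the tests below assume; the random tensors are illustrative only, and the explicit `model.to(model.device)` call is a suggested safeguard, since `fit` moves batches to `self.device` but the constructor no longer moves the parameters there:

import torch
from torch.utils.data import DataLoader, TensorDataset

from afqinsight.nn.pt_models import VariationalAutoencoder

# Illustrative stand-in for real tract profiles: 16 samples of shape (48, 100).
x = torch.randn(16, 48, 100)
y = torch.zeros(16)  # fit() unpacks (x, _) pairs, so targets are ignored
loader = DataLoader(TensorDataset(x, y), batch_size=4)

model = VariationalAutoencoder(input_shape=48 * 100, latent_dims=10)
model.to(model.device)  # keep parameters on the device fit() moves batches to
model.fit(loader, epochs=1, lr=0.001)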
128 changes: 128 additions & 0 deletions afqinsight/nn/tests/test_autoencoders.py
@@ -0,0 +1,128 @@
import pytest
import torch

from afqinsight import AFQDataset
from afqinsight.nn.pt_models import Autoencoder, VariationalAutoencoder
from afqinsight.nn.utils import prep_pytorch_data


@pytest.fixture
def device():
    """Fixture to choose the computing device (CPU fallback on MPS-only machines)."""
    if torch.backends.mps.is_available():
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.fixture
def dataset():
    """Fixture to load the AFQ dataset."""
    return AFQDataset.from_study("hbn")


@pytest.fixture
def data_loaders(dataset):
    """Fixture to prepare PyTorch datasets and data loaders."""
    torch_dataset, train_loader, test_loader, val_loader = prep_pytorch_data(dataset)
    return torch_dataset, train_loader, test_loader, val_loader


@pytest.fixture
def data_shapes(data_loaders):
    """Fixture to compute shapes for input and target tensors."""
    torch_dataset = data_loaders[0]
    gt_shape = torch_dataset[0][1].size()[0]
    sequence_length = torch_dataset[0][0].size()[0]  # 48
    in_channels = torch_dataset[0][0].size()[1]  # 100
    return gt_shape, sequence_length, in_channels


@pytest.mark.parametrize("latent_dims", [2, 10])
def test_autoencoder_forward(data_loaders, latent_dims, data_shapes):
    """
    Smoke test to check if the linear Autoencoder forward pass works
    without raising an exception and returns the expected shape.
    """
    torch_dataset, train_loader, test_loader, val_loader = data_loaders
    gt_shape, sequence_length, in_channels = data_shapes

    # Define input_shape = 48 * 100 = 4800
    model = Autoencoder(
        input_shape=sequence_length * in_channels, latent_dims=latent_dims
    )
    model.eval()  # forward-pass check only, no training

    # Retrieve a single batch
    data_iter = iter(test_loader)
    x, _ = next(data_iter)

    # Forward pass
    with torch.no_grad():
        output = model(x)

    # Check output shape
    # The decoder is expected to return shape: (batch_size, 48, 100)
    expected_shape = (x.size(0), sequence_length, in_channels)
    assert output.shape == expected_shape, (
        f"Expected output shape {expected_shape}, " f"but got {output.shape}."
    )


@pytest.mark.parametrize("latent_dims", [2, 10])
def test_variational_autoencoder_forward(data_loaders, latent_dims, data_shapes):
    """
    Smoke test to check if the linear VariationalAutoencoder forward pass
    works without throwing exceptions and returns the expected shape.
    """
    torch_dataset, train_loader, test_loader, val_loader = data_loaders
    gt_shape, sequence_length, in_channels = data_shapes

    model = VariationalAutoencoder(
        input_shape=sequence_length * in_channels, latent_dims=latent_dims
    )
    model.eval()

    data_iter = iter(test_loader)
    x, _ = next(data_iter)

    with torch.no_grad():
        output = model(x)

    # Check if shape matches (batch_size, 48, 100)
    expected_shape = (x.size(0), sequence_length, in_channels)
    assert output.shape == expected_shape, (
        f"Expected output shape {expected_shape}, " f"but got {output.shape}."
    )


def test_autoencoder_train_loop(data_loaders, data_shapes):
    """
    Simple smoke test for the training loop of the linear Autoencoder,
    checking that it runs without exceptions.
    """
    torch_dataset, train_loader, test_loader, val_loader = data_loaders
    gt_shape, sequence_length, in_channels = data_shapes

    model = Autoencoder(input_shape=sequence_length * in_channels, latent_dims=10)
    model.train()

    # Fit the model on the test split for 1 epoch
    # This doesn't guarantee correctness, just that the loop runs
    model.fit(test_loader, epochs=1, lr=0.001)


def test_variational_autoencoder_train_loop(data_loaders, data_shapes):
    """
    Simple smoke test for the training loop of the linear VariationalAutoencoder,
    checking that it runs without exceptions.
    """
    torch_dataset, train_loader, test_loader, val_loader = data_loaders
    gt_shape, sequence_length, in_channels = data_shapes

    model = VariationalAutoencoder(
        input_shape=sequence_length * in_channels, latent_dims=10
    )
    model.train()

    # Fit the model on the test split for 1 epoch
    model.fit(test_loader, epochs=1, lr=0.001)
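
The fixtures above pull the HBN study via `AFQDataset.from_study`, so these tests require a dataset download. A network-free variant of the same forward-pass check might look like the following sketch, which assumes the (48, 100) sample shape reported by the `data_shapes` fixture and uses random tensors in place of real tract profiles:

import torch
from torch.utils.data import DataLoader, TensorDataset

from afqinsight.nn.pt_models import Autoencoder


def test_autoencoder_forward_synthetic():
    """Forward-pass smoke test on random tensors (no dataset download)."""
    x = torch.randn(8, 48, 100)  # batch of fake tract profiles
    y = torch.zeros(8)
    loader = DataLoader(TensorDataset(x, y), batch_size=4)

    model = Autoencoder(input_shape=48 * 100, latent_dims=10)
    model.eval()

    batch, _ = next(iter(loader))
    with torch.no_grad():
        out = model(batch)

    assert out.shape == (batch.size(0), 48, 100)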
