Skip to content

Commit

Permalink
tests: add test for lightning
Browse files Browse the repository at this point in the history
Add test for `lightning`, based on their "Getting started" example
with autoencoder trained on MNIST dataset. Requires `torchivsion`
for dataset.
  • Loading branch information
rokm committed Dec 23, 2023
1 parent bb294d3 commit 2afbbcf
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions src/_pyinstaller_hooks_contrib/tests/test_deep_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,73 @@ def test_timm_model_creation(pyi_builder):
model = timm.create_model("resnet50d", pretrained=False)
print(model)
""")


@importorskip('lightning')
@importorskip('torchvision')
@importorskip('torch')
@onedir_only
def test_lightning_mnist_autoencoder(pyi_builder):
pyi_builder.test_source("""
import os
import torch
import torchvision
import lightning
class LitAutoEncoder(lightning.LightningModule):
def __init__(self):
super().__init__()
self.encoder = torch.nn.Sequential(
torch.nn.Linear(28 * 28, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 3),
)
self.decoder = torch.nn.Sequential(
torch.nn.Linear(3, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 28 * 28),
)
def forward(self, x):
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = torch.nn.functional.mse_loss(x_hat, x)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
lr=1e-3,
)
return optimizer
# Dataset
dataset = torchvision.datasets.MNIST(
os.path.dirname(__file__),
download=True,
transform=torchvision.transforms.ToTensor(),
)
dataset_size = len(dataset)
num_samples = 100
train, val = torch.utils.data.random_split(
dataset,
[num_samples, dataset_size - num_samples],
)
# Train
autoencoder = LitAutoEncoder()
trainer = lightning.Trainer(max_epochs=1, logger=False)
trainer.fit(
autoencoder,
torch.utils.data.DataLoader(train),
)
""")

0 comments on commit 2afbbcf

Please sign in to comment.