diff --git a/README.md b/README.md index c0ac8c60..e69de29b 100644 --- a/README.md +++ b/README.md @@ -1,84 +0,0 @@ -# bioimage_embed: Autoencoders for Biological Image Data - -bioimage_embed is an all-in-one Python package designed to cater to the needs of computational biologists, data scientists, and researchers working on biological image data. With specialized functions to handle, preprocess, and visualize microscopy datasets, this tool is tailored to streamline the embedding process for biological imagery. - -[![Build Status](https://img.shields.io/badge/build-passing-green.svg)](https://github.com/ctr26/bioimage_embed) -[![Python Version](https://img.shields.io/badge/python-3.7+-blue.svg)](https://github.com/ctr26/bioimage_embed) -[![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/ctr26/bioimage_embed) - ---- - -## Features - -- Seamless loading of microscopy datasets, compatible with the BioImage Data Resource and Cell Image Library. -- Built-in preprocessing functions to ensure your images are primed for encoding. -- Visual tools to dive deep into the encoding and decoding processes of your autoencoders. - ---- - -## Installation - -To get started with bioimage_embed, you can install it directly via pip or from the GitHub repository. - -### From PyPI: - -```bash -pip install bioimage_embed -``` - -### From GitHub: - -```bash -pip install git+https://github.com/ctr26/bioimage_embed -``` - ---- - -## Usage - -### 1. Basic Installation: - -```bash -pip install -e . -``` - -### 2. Command Line Interface (CLI): - -To get a list of all commands and functions: - -```bash -bioimage_embed --help -``` - -OR - -```bash -bie --help -``` - - -### 3. Developer Installation: - -For those intending to contribute or looking for a deeper dive into the codebase, we use `poetry` to manage our dependencies and virtual environments: - -```bash -poetry env use python -poetry install -poetry shell -``` - ---- - -## Support & Contribution - -For any issues, please refer to our [issues page](https://github.com/ctr26/bioimage_embed/issues). Contributions are more than welcome! Please submit pull requests to the master branch. - ---- - -## License - -bioimage_embed is licensed under the MIT License. Please refer to the [LICENSE](https://github.com/ctr26/bioimage_embed/LICENSE) for more details. - ---- - -Happy Embedding! 🧬🔬 diff --git a/bioimage_embed/lightning/dataloader.py b/bioimage_embed/lightning/dataloader.py deleted file mode 100644 index 8b596b65..00000000 --- a/bioimage_embed/lightning/dataloader.py +++ /dev/null @@ -1,209 +0,0 @@ -from torch.utils.data import DataLoader, WeightedRandomSampler -import pytorch_lightning as pl -import torch -from torch.utils.data import Dataset, random_split -from typing import Tuple -from functools import partial -import numpy as np - - -class StratifiedSampler(WeightedRandomSampler): - def __init__(self, dataset, replacement=True): - # Get the labels (targets) from the dataset - self.targets = np.array([dataset[i][1] for i in range(len(dataset))]) - - # Count the occurrences of each class - class_counts = np.bincount(self.targets) - - # Calculate the weight for each class (inverse of frequency) - class_weights = 1.0 / class_counts - - # Assign weights to each sample based on its class - sample_weights = class_weights[self.targets] - - # Initialize the parent class (WeightedRandomSampler) with sample weights - super().__init__( - weights=sample_weights, - num_samples=len(self.targets), - replacement=replacement, - ) - - -# https://stackoverflow.com/questions/74931838/cant-pickle-local-object-evaluationloop-advance-locals-batch-to-device-pyto -class Collator: - def collate_filter_for_none(self, batch): - """ - Collate function that filters out None values from the batch. - - Args: - batch: The batch to be filtered. - - Returns: - The filtered batch. - """ - batch = list(filter(lambda x: x is not None, batch)) - return torch.utils.data.dataloader.default_collate(batch) - - def __call__(self, incoming): - # do stuff with incoming - return self.collate_filter_for_none(incoming) - - -class DataModule(pl.LightningDataModule): - """ - A PyTorch Lightning DataModule for handling dataset loading and splitting. - - Attributes: - dataset: The dataset to be handled. - batch_size: The size of each batch. - num_workers: The number of workers for data loading. - pin_memory: Whether to use pinned memory for data loading. - drop_last: Whether to drop the last incomplete batch. - collate_fn: The function to use for collating data into batches. - """ - - def __init__( - self, - dataset: Dataset, - batch_size: int = 32, - num_workers: int = 4, - pin_memory: bool = False, - drop_last: bool = False, - # sampler=None, - sampler=StratifiedSampler, - ): - """ - Initializes the DataModule with the given dataset and parameters. - - Args: - dataset: The dataset to be handled. - batch_size: The size of each batch. Default is 32. - num_workers: The number of workers for data loading. Default is 4. - pin_memory: Whether to use pinned memory for data loading. Default is False. - drop_last: Whether to drop the last incomplete batch. Default is False. - collate_fn: The function to use for collating data into batches. Default is None. - """ - super().__init__() - self.dataset = dataset - self.collator = Collator() - self.sampler = sampler - self.dataloader = partial( - DataLoader, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - drop_last=drop_last, - collate_fn=self.collator, - ) - - self.train_dataset = None - self.val_dataset = None - self.test_dataset = None - self.setup() - - def setup(self, stage=None): - """ - Sets up the datasets by splitting the main dataset into train, validation, and test sets. - - Args: - stage: The stage of the setup. Default is None. - """ - ( - self.train_dataset, - self.val_dataset, - self.test_dataset, - ) = self.splitting(self.dataset) - - def splitting( - self, dataset: Dataset, split_train=0.8, split_val=0.1, seed=42 - ) -> Tuple[Dataset, Dataset, Dataset]: - """ - Splits the dataset into train, validation, and test sets. - - Args: - dataset: The dataset to be split. - split_train: The proportion of the dataset to be used for training. Default is 0.8. - split_val: The proportion of the dataset to be used for validation. Default is 0.1. - seed: The random seed for splitting the dataset. Default is 42. - - Returns: - A tuple containing the train, validation, and test datasets. - """ - dataset_size = len(dataset) - indices = list(range(dataset_size)) - train_size = int(split_train * dataset_size) - val_size = int(split_val * dataset_size) - test_size = dataset_size - train_size - val_size - - if train_size + val_size + test_size != dataset_size: - raise ValueError( - "The splitting ratios do not add up to the length of the dataset" - ) - - torch.manual_seed(seed) - train_indices, val_indices, test_indices = random_split( - indices, [train_size, val_size, test_size] - ) - - train_dataset = torch.utils.data.Subset(dataset, train_indices) - val_dataset = torch.utils.data.Subset(dataset, val_indices) - test_dataset = torch.utils.data.Subset(dataset, test_indices) - - return train_dataset, val_dataset, test_dataset - - def get_dataset(self): - return self.dataset - - def train_dataloader(self): - return self.init_dataloader( - self.train_dataset, - shuffle=False, - sampler=self.sampler(self.train_dataset), - ) - - def val_dataloader(self): - return self.init_dataloader(self.val_dataset, shuffle=False) - - def test_dataloader(self): - return self.init_dataloader(self.test_dataset, shuffle=False) - - def predict_dataloader(self): - return self.init_dataloader(self.dataset, shuffle=False) - - def init_dataloader(self, dataset, shuffle=False, sampler=None): - """ - Initializes a dataloader for the given dataset. - - Args: - dataset: The dataset to be loaded. - shuffle: Whether to shuffle the dataset. Default is False. - - Returns: - The dataloader for the dataset. - """ - return ( - self.dataloader( - dataset, - shuffle=shuffle, - sampler=sampler, - ) - if dataset - else None - ) - - -def valid_indices(dataset): - valid_indices = [] - # Iterate through the dataset and apply the transform to each image - for idx in range(len(dataset)): - try: - image, label = dataset[idx] - # If the transform works without errors, add the index to the list of valid indices - valid_indices.append(idx) - except Exception as e: - # A better way to do with would be with batch collation - print(f"Error occurred for image {idx}: {e}") - - # Create a Subset using the valid indices - subset = torch.utils.data.Subset(dataset, valid_indices) - return subset diff --git a/bioimage_embed/lightning/tests/test_contrastive_learning.py b/bioimage_embed/lightning/tests/test_contrastive_learning.py deleted file mode 100644 index cb805b58..00000000 --- a/bioimage_embed/lightning/tests/test_contrastive_learning.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -import pytest -from ..torch import create_label_based_pairs, compute_contrastive_loss - -torch.manual_seed(42) - - -@pytest.fixture -def batch_size(): - return 32 - - -@pytest.fixture -def latent_dim(): - return 16 - - -@pytest.fixture(params=[1, 2, 4]) -def classes(request): - return request.param - - -@pytest.fixture() -def labels(batch_size, classes): - return torch.randint(0, classes + 1, (batch_size, 1)) - - -@pytest.fixture -def features(batch_size, latent_dim): - return torch.rand(batch_size, latent_dim) - - -def test_create_label_based_pairs_single_label(features, batch_size): - # All labels are the same - labels = torch.zeros(batch_size, 1) - input_pairs, target_pairs = create_label_based_pairs(features, labels) - assert ( - input_pairs.numel() == 0 and target_pairs.numel() == 0 - ), "Should return empty tensors for single label" - - -def test_create_label_based_pairs_no_pair(features, labels): - # Case when each label only has one sample - input_pairs, target_pairs = create_label_based_pairs(features, labels) - if torch.unique(labels).size(0) == labels.size(0): - assert ( - input_pairs.numel() == 0 and target_pairs.numel() == 0 - ), "Should return empty tensors when no pairs are available" - - -def test_create_label_based_pairs_multiple_pairs(features, labels): - # Check if multiple pairs can be created - input_pairs, target_pairs = create_label_based_pairs(features, labels) - if torch.unique(labels).size(0) > 1: - assert input_pairs.size(0) == target_pairs.size( - 0 - ), "Input and target pairs should have the same number of pairs" - assert input_pairs.size(1) == features.size( - 1 - ), "Pair feature size should match the input feature size" - - -def test_compute_contrastive_loss_single_label(features, batch_size): - # All labels are the same - labels = torch.zeros(batch_size, 1) - loss = compute_contrastive_loss(features, labels) - assert ( - loss.item() == 0.0 - ), "Loss should be zero when all labels are the same (no valid pairs)" - - -def test_compute_contrastive_loss_no_valid_pairs(features, labels): - # Case when each label only has one sample - if torch.unique(labels).size(0) == labels.size(0): # All labels are unique - loss = compute_contrastive_loss(features, labels) - assert ( - loss.item() == 0.0 - ), "Loss should be zero when no valid pairs are available" - - -def test_compute_contrastive_loss_valid_pairs(features, labels): - # Case where valid pairs exist - if torch.unique(labels).size(0) < labels.size(0): # There are valid pairs - loss = compute_contrastive_loss(features, labels) - assert loss.item() > 0.0, "Loss should be non-zero when valid pairs exist" diff --git a/bioimage_embed/lightning/torch.py b/bioimage_embed/lightning/torch.py deleted file mode 100644 index 44ee2d56..00000000 --- a/bioimage_embed/lightning/torch.py +++ /dev/null @@ -1,365 +0,0 @@ -import torchvision -import torch -import pytorch_lightning as pl -from timm import optim, scheduler -from types import SimpleNamespace -import argparse -from transformers.utils import ModelOutput -import torch.nn.functional as F -from monai import losses - -""" -x_recon -> output of the model -z -> latent space -data -> input to the model -target -> target for supervised learning -recon_loss -> reconstruction loss -loss -> total loss -variational_loss -> loss - recon_loss -""" - - -class AutoEncoder(pl.LightningModule): - args = argparse.Namespace( - opt="adamw", - weight_decay=0.001, - momentum=0.9, - sched="cosine", - epochs=50, - lr=1e-4, - min_lr=1e-6, - t_initial=10, - t_mul=2, - lr_min=None, - decay_rate=0.1, - warmup_lr=1e-6, - warmup_lr_init=1e-6, - warmup_epochs=5, - cycle_limit=None, - t_in_epochs=False, - noisy=False, - noise_std=0.1, - noise_pct=0.67, - noise_seed=None, - cooldown_epochs=5, - warmup_t=0, - ) - - def __init__(self, model, args=SimpleNamespace()): - super().__init__() - self.model = model - self.model = self.model.to(self.device) - # Flatten hparams - self.encoder = self.model.encoder - self.decoder = self.model.decoder - if args: - self.args = SimpleNamespace(**{**vars(args), **vars(self.args)}) - self.save_hyperparameters(vars(self.args)) - # TODO update all models to use this for export to onxx - # self.example_input_array = torch.randn(1, *self.model.input_dim) - # self.model.train() - - def forward(self, x: torch.Tensor) -> ModelOutput: - """ - Forward pass of the model - Pythae models take in ModelOutput objects, and return ModelOutput objects so that we can pass in and return multiple tensors - """ - return self.model(ModelOutput(data=x.float())) - - def predict_step( - self, batch: tuple, batch_idx: int, dataloader_idx=0 - ) -> ModelOutput: - return self.batch_to_tensor(batch) - - def batch_to_tensor(self, batch) -> ModelOutput: - """ - This takes in a batch and returns a ModelOutput object. - Lightning batches are x,y pairs of tensors, but we only need the x tensor for the model. - x is fed into the self.forward method - """ - x, y = self.batch_to_xy(batch) - model_output = self.forward(x) - model_output.data = x - model_output.target = y - return model_output - - def embedding(self, model_output: ModelOutput) -> torch.Tensor: - return model_output.z.view(model_output.z.shape[0], -1) - - def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor: - self.model.train() - model_output = self.eval_step(batch, batch_idx) - self.log_dict( - { - "loss/train": model_output.loss, - "mse/train": F.mse_loss(model_output.recon_x, model_output.data), - "recon_loss/train": model_output.recon_loss, - "variational_loss/train": model_output.loss - model_output.recon_loss, - }, - # on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) - if isinstance(self.logger, pl.loggers.TensorBoardLogger): - self.log_tensorboard(model_output, model_output.data) - return model_output.loss - - def validation_step(self, batch, batch_idx): - model_output = self.eval_step(batch, batch_idx) - self.log_dict( - { - "loss/val": model_output.loss, - "mse/val": F.mse_loss(model_output.recon_x, model_output.data), - "recon_loss/val": model_output.recon_loss, - "variational_loss/val": model_output.loss - model_output.recon_loss, - } - ) - return model_output.loss - - def test_step(self, batch, batch_idx): - # x, y = batch - model_output = self.eval_step(batch, batch_idx) - self.log_dict( - { - "loss/test": model_output.loss, - "mse/test": F.mse_loss(model_output.recon_x, model_output.data), - "recon_loss/test": model_output.recon_loss, - "variational_loss/test": model_output.loss - model_output.recon_loss, - } - ) - return model_output.loss - - def batch_to_xy(self, batch): - """ - Fangless function to be overloaded later - """ - x, y = batch - return x, y - - def eval_step(self, batch, batch_idx): - """ - This function should be overloaded in the child class to implement the evaluation logic. - """ - return self.predict_step(batch, batch_idx) - - # def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): - # # Implement your own logic for updating the lr scheduler - # # This method will be called at each training step - # # Update the lr scheduler based on the provided arguments - # # You can access the lr scheduler using `self.lr_schedulers()` - - # # Example: - # for lr_scheduler in self.lr_schedulers(): - # lr_scheduler.step() - - def timm_optimizers(self, model): - optimizer = optim.create_optimizer(self.args, model.parameters()) - lr_scheduler = scheduler.create_scheduler(self.args, optimizer)[0] - return optimizer, lr_scheduler - - def timm_to_lightning(self, optimizer, lr_scheduler): - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "interval": "step", # or 'epoch' for step vs epoch training, respectively - }, - } - - def configure_optimizers(self): - # optimizer = optim.create_optimizer(self.args, self.model.parameters()) - # lr_scheduler = scheduler.create_scheduler(self.args, optimizer)[0] - optimizer, lr_scheduler = self.timm_optimizers(self.model) - return self.timm_to_lightning(optimizer, lr_scheduler) - - def lr_scheduler_step(self, scheduler, optimizer_idx, metric): - scheduler.step(epoch=self.current_epoch, metric=metric) - - def log_wandb(self): - pass - - def log_tensorboard(self, model_output, x): - # Optionally you can add more logging, for example, visualizations: - self.logger.experiment.add_image( - "test_input", - torchvision.utils.make_grid(model_output.data), - self.global_step, - ) - self.logger.experiment.add_image( - "test_output", - torchvision.utils.make_grid(model_output.recon_x), - self.global_step, - ) - - -class AE(AutoEncoder): - pass - - -class AutoEncoderUnsupervised(AutoEncoder): - pass - - -class AEUnsupervised(AutoEncoder): - pass - - -""" -This function generates positive pairs of feature vectors (`input_pairs` and `target_pairs`) -based on the class labels provided. - -For each unique class in the labels: -- It selects all samples from the same class and creates pairs of feature vectors. -- Only pairs within the same class are generated, no cross-class pairs. -- If there is only one sample in a class, no pairs are created for that class. - -The resulting `input_pairs` and `target_pairs`: -- `input_pairs`: Feature vectors of the first sample in each pair. -- `target_pairs`: Feature vectors of the second sample in each pair. - -### Example 1: Two Classes -Suppose `X` (features) and `y` (labels) are: -X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]] -y = [[0], [0], [1], [1]] - -For class 0: -- Input pair: [1.0, 2.0, 3.0] -- Target pair: [4.0, 5.0, 6.0] - -For class 1: -- Input pair: [7.0, 8.0, 9.0] -- Target pair: [10.0, 11.0, 12.0] - -Final pairs: -input_pairs = [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]] -target_pairs = [[4.0, 5.0, 6.0], [10.0, 11.0, 12.0]] - -### Example 2: Multiple Classes -Suppose `X` and `y` have three classes: -X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0], - [13.0, 14.0, 15.0], [16.0, 17.0, 18.0]] -y = [[0], [0], [1], [1], [2], [2]] - -For class 0: -- Input pair: [1.0, 2.0, 3.0] -- Target pair: [4.0, 5.0, 6.0] - -For class 1: -- Input pair: [7.0, 8.0, 9.0] -- Target pair: [10.0, 11.0, 12.0] - -For class 2: -- Input pair: [13.0, 14.0, 15.0] -- Target pair: [16.0, 17.0, 18.0] - -Final pairs: -input_pairs = [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0], [13.0, 14.0, 15.0]] -target_pairs = [[4.0, 5.0, 6.0], [10.0, 11.0, 12.0], [16.0, 17.0, 18.0]] - -This is used in contrastive learning settings like SimCLR, where pairs from the same class -are treated as positive examples to learn class-consistent embeddings. -""" - - -def create_label_based_pairs( - features: torch.Tensor, labels: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Create positive pairs based on labels for contrastive learning (SimCLR/MoCo). - - Args: - features: Tensor of shape (b, latent_dim) - labels: Tensor of shape (b, 1) - - Returns: - tuple of two tensors, each of shape (n, latent_dim), where n is the number of pairs - """ - labels = labels.squeeze() # Convert (b, 1) to (b,) - unique_labels = torch.unique(labels) - - # If there's only one unique label or no samples, return empty tensors - if len(unique_labels) == 1 or features.size(0) == 1: - return torch.empty(0, features.size(1)), torch.empty(0, features.size(1)) - - positive_pairs = [] - - for label in unique_labels: - mask = labels == label - class_samples = features[mask] - if class_samples.size(0) > 1: # Need at least 2 samples to form pairs - # Generate all possible pairs of samples within this class - num_samples = class_samples.size(0) - pairs = torch.combinations(torch.arange(num_samples), r=2) - positive_pairs.append( - (class_samples[pairs[:, 0]], class_samples[pairs[:, 1]]) - ) - - # If no valid pairs were found, return empty tensors - if not positive_pairs: - return torch.empty(0, features.size(1)), torch.empty(0, features.size(1)) - - # Concatenate all positive pairs across classes - input_pairs = torch.cat([pair[0] for pair in positive_pairs]) - target_pairs = torch.cat([pair[1] for pair in positive_pairs]) - - return input_pairs, target_pairs - - -def compute_contrastive_loss( - X: torch.Tensor, y: torch.Tensor, criterion=losses.ContrastiveLoss() -): - """ - Wrapper function that computes contrastive loss using the MONAI ContrastiveLoss function. - - Args: - - X (torch.Tensor): The feature tensor of shape (batch_size, latent_dim). - - y (torch.Tensor): The label tensor of shape (batch_size, 1). - - contrastive_criterion: The criterion to compute contrastive loss. If None, defaults to monai.losses.ContrastiveLoss. - - Returns: - - loss (torch.Tensor): The computed contrastive loss. - """ - - # Create positive pairs from X and y - input_pairs, target_pairs = create_label_based_pairs(X, y) - - # If no pairs are created, return zero loss - if input_pairs.numel() == 0 or target_pairs.numel() == 0: - return torch.tensor(0.0, device=X.device) - - # Compute the contrastive loss - contrastive_loss = criterion(input_pairs, target_pairs) - - return contrastive_loss - - -class AutoEncoderSupervised(AutoEncoder): - criterion = losses.ContrastiveLoss() - - def eval_step(self, batch, batch_idx): - # x, y = batch - # TODO check this - # Scale is used as the rest of the loss functions are sums rather than means, which may mean we need to scale up the contrastive loss - model_output = self.predict_step(batch, batch_idx) - scale = torch.prod(torch.tensor(model_output.z.shape[1:])) - if model_output.target.unique().size(0) == 1: - return model_output - contrastive_loss = compute_contrastive_loss( - # Belt and braces on this view - model_output.z.view(-1, self.model.latent_dim), - model_output.target, - criterion=self.criterion, - ) - model_output.contrastive_loss = scale * contrastive_loss - model_output.loss += model_output.contrastive_loss - return model_output - - -class AESupervised(AutoEncoderSupervised): - pass - - -class NDAutoEncoder(AESupervised): - def batch_to_xy(self, batch): - x, y = super().batch_to_xy(batch) diff --git a/bioimage_embed/models/bolts/vae.py b/bioimage_embed/models/bolts/vae.py deleted file mode 100644 index 5107c147..00000000 --- a/bioimage_embed/models/bolts/vae.py +++ /dev/null @@ -1,137 +0,0 @@ -from torch import nn -from transformers.utils import ModelOutput -from pythae.models.nn import BaseDecoder, BaseEncoder - - -from pythae import models -from pythae.models import VAEConfig -from pl_bolts.models import autoencoders as ae - - -def count_params(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -class ResNet50VAEEncoder(BaseEncoder): - enc_out_dim = 2048 - - def __init__( - self, model_config: VAEConfig, first_conv=False, maxpool1=False, **kwargs - ): - super(ResNet50VAEEncoder, self).__init__() - - # input_height = model_config.input_dim[-2] - latent_dim = model_config.latent_dim - - self.encoder = ae.resnet50_encoder(first_conv, maxpool1) - self.embedding = nn.Linear(self.enc_out_dim, latent_dim) - self.log_var = nn.Linear(self.enc_out_dim, latent_dim) - # self.fc1 = nn.Linear(512, latent_dim) - # self._adaptive_pool = nn.AdaptiveAvgPool2d((embedding_dim, embedding_dim)) - - def forward(self, x): - x = self.encoder(x) - # x = self.fc1(x) - return ModelOutput(embedding=self.embedding(x), log_covariance=self.log_var(x)) - - -class ResNet50VAEDecoder(BaseDecoder): - enc_out_dim = 512 - - def __init__( - self, model_config: VAEConfig, first_conv=False, maxpool1=False, **kwargs - ): - super(ResNet50VAEDecoder, self).__init__() - latent_dim = model_config.latent_dim - input_height = model_config.input_dim[-2] - self.embedding = nn.Linear(latent_dim, self.enc_out_dim) - self.decoder = ae.resnet50_decoder( - self.enc_out_dim, input_height, first_conv, maxpool1 - ) - - def forward(self, x): - x = self.embedding(x) - x = self.decoder(x) - return ModelOutput(reconstruction=x) - - -class ResNet18VAEEncoder(BaseEncoder): - enc_out_dim = 512 - - def __init__( - self, model_config: VAEConfig, first_conv=False, maxpool1=False, **kwargs - ): - super(ResNet18VAEEncoder, self).__init__() - - # input_height = model_config.input_dim[-2] - latent_dim = model_config.latent_dim - - self.encoder = ae.resnet18_encoder(first_conv, maxpool1) - self.embedding = nn.Linear(self.enc_out_dim, latent_dim) - self.log_var = nn.Linear(self.enc_out_dim, latent_dim) - - def forward(self, x): - x = self.encoder(x) - # x = self.fc1(x) - return ModelOutput(embedding=self.embedding(x), log_covariance=self.log_var(x)) - - -class ResNet18VAEDecoder(BaseDecoder): - enc_out_dim = 512 - - def __init__( - self, model_config: VAEConfig, first_conv=False, maxpool1=False, **kwargs - ): - super(ResNet18VAEDecoder, self).__init__() - latent_dim = model_config.latent_dim - input_height = model_config.input_dim[-2] - self.decoder = ae.resnet18_decoder( - self.enc_out_dim, input_height, first_conv, maxpool1 - ) - self.embedding = nn.Linear(latent_dim, self.enc_out_dim) - - def forward(self, x): - x = self.embedding(x) - x = self.decoder(x) - return ModelOutput(reconstruction=x) - - -class VAEPythaeWrapper(models.VAE): - def __init__( - self, - model_config, - input_height, - enc_type="resnet18", - enc_out_dim=512, - first_conv=False, - maxpool1=False, - kl_coeff=0.1, - encoder=None, - decoder=None, - ): - super(models.BaseAE, self).__init__() - self.model_name = "VAE_bolt" - self.model_config = model_config - self.model = ae.VAE( - input_height=input_height, - enc_type=enc_type, - enc_out_dim=enc_out_dim, - first_conv=first_conv, - maxpool1=maxpool1, - kl_coeff=kl_coeff, - latent_dim=model_config.latent_dim, - ) - self.encoder = self.model.encoder - self.decoder = self.model.decoder - self.input_dim = self.model_config.input_dim - self.latent_dim = self.model_config.latent_dim - - def forward(self, x, epoch=None): - # return ModelOutput(x=x,recon_x=x,z=x,loss=1) - # # Forward pass logic - x = x["data"] - # x_recon = self.model(x) - z, recon_x, p, q = self.model._run_step(x) - loss, logs = self.model.step((x, x), batch_idx=epoch) - # recon_loss = self.model.reconstruction_loss(x, recon_x) - return ModelOutput(recon_x=recon_x, z=z, logs=logs, loss=loss, recon_loss=loss) diff --git a/bioimage_embed/models/bolts/vqvae.py b/bioimage_embed/models/bolts/vqvae.py deleted file mode 100644 index 05ce664f..00000000 --- a/bioimage_embed/models/bolts/vqvae.py +++ /dev/null @@ -1,136 +0,0 @@ -from torch import nn -from transformers.utils import ModelOutput -from pythae.models.nn import BaseDecoder, BaseEncoder -from pythae.models import VAEConfig - -from pl_bolts.models import autoencoders as ae - - -class BaseResNetVQVAEEncoder(BaseEncoder): - def __init__( - self, - model_config: VAEConfig, - resnet_encoder, - enc_out_dim, - first_conv=False, - maxpool1=False, - **kwargs, - ): - super(BaseResNetVQVAEEncoder, self).__init__() - self.input_dim = model_config.input_dim - self.input_height = model_config.input_dim[-2] - self.latent_dim = model_config.latent_dim - self.enc_out_dim = enc_out_dim - - self.encoder = resnet_encoder(first_conv, maxpool1) - # self.embedding = nn.Linear(self.enc_out_dim, self.latent_dim) - # self.log_var = nn.Linear(self.enc_out_dim, self.latent_dim) - self.prequantized = nn.Conv2d(self.enc_out_dim, self.latent_dim, 1, 1) - - def forward(self, inputs): - x = self.encoder(inputs) - # log_covariance = self.log_var(x) - x = x.view(-1, self.enc_out_dim, 1, 1) - embedding = self.prequantized(x) - embedding = embedding.view(-1, self.latent_dim) - return ModelOutput(embedding=embedding) - # return ModelOutput(embedding=embedding, log_covariance=log_covariance) - - -class ResNet50VQVAEEncoder(BaseResNetVQVAEEncoder): - enc_out_dim = 2048 - - def __init__( - self, - model_config: VAEConfig, - first_conv=False, - maxpool1=False, - **kwargs, - ): - super(ResNet50VQVAEEncoder, self).__init__( - model_config, - ae.resnet50_encoder, - self.enc_out_dim, - first_conv, - maxpool1, - **kwargs, - ) - - -class ResNet18VQVAEEncoder(BaseResNetVQVAEEncoder): - enc_out_dim = 512 - - def __init__( - self, - model_config: VAEConfig, - first_conv=False, - maxpool1=False, - **kwargs, - ): - super(ResNet18VQVAEEncoder, self).__init__( - model_config, - ae.resnet18_encoder, - self.enc_out_dim, - first_conv, - maxpool1, - **kwargs, - ) - - -class BaseResNetVQVAEDecoder(BaseDecoder): - def __init__( - self, - model_config: VAEConfig, - resnet_decoder, - first_conv=False, - maxpool1=False, - **kwargs, - ): - super(BaseResNetVQVAEDecoder, self).__init__() - self.model_config = model_config - self.latent_dim = model_config.latent_dim - self.input_height = model_config.input_dim[-2] - # self.postquantized = nn.Conv2d(self.enc_out_dim, self.latent_dim, 1, 1) - self.postquantized = nn.Conv2d(self.latent_dim, self.enc_out_dim, 1, 1) - self.decoder = resnet_decoder( - self.enc_out_dim, self.input_height, first_conv, maxpool1 - ) - # Activation layer might be useful here - # https://github.com/AntixK/PyTorch-VAE/blob/a6896b944c918dd7030e7d795a8c13e5c6345ec7/models/vq_vae.py#L166 - - def forward(self, x): - x = x.view(-1, self.latent_dim, 1, 1) - x = self.postquantized(x) - x = x.view(-1, self.enc_out_dim) - x = self.decoder(x) - return ModelOutput(reconstruction=x) - - -class ResNet50VQVAEDecoder(BaseResNetVQVAEDecoder): - enc_out_dim = 512 - - def __init__( - self, - model_config: VAEConfig, - first_conv=False, - maxpool1=False, - **kwargs, - ): - super(ResNet50VQVAEDecoder, self).__init__( - model_config, ae.resnet50_decoder, first_conv, maxpool1, **kwargs - ) - - -class ResNet18VQVAEDecoder(BaseResNetVQVAEDecoder): - enc_out_dim = 512 - - def __init__( - self, - model_config: VAEConfig, - first_conv=False, - maxpool1=False, - **kwargs, - ): - super(ResNet18VQVAEDecoder, self).__init__( - model_config, ae.resnet18_decoder, first_conv, maxpool1, **kwargs - ) diff --git a/bioimage_embed/models/pythae/legacy/vq_vae.py b/bioimage_embed/models/pythae/legacy/vq_vae.py deleted file mode 100644 index eba2b526..00000000 --- a/bioimage_embed/models/pythae/legacy/vq_vae.py +++ /dev/null @@ -1,234 +0,0 @@ -# TODO make this a relative import - -import torch -from torch import nn -from torch.nn import functional as F - -from pythae import models - - -# from pythae.models import VQVAE, Encoder, Decoder -from transformers.utils import ModelOutput - -# import VQVAE, Encoder, Decoder -from pythae.models.nn import BaseDecoder, BaseEncoder -from ...nets.resnet import ResnetDecoder, ResnetEncoder -from ....models.legacy import vq_vae - - -from pythae.models import VQVAEConfig, VAEConfig - - -class Encoder(BaseEncoder): - def __init__( - self, - model_config, - num_hiddens, - num_residual_hiddens, - num_residual_layers, - ): - super(Encoder, self).__init__() - - self.model = ResnetEncoder( - in_channels=model_config.input_dim[0], - num_hiddens=num_hiddens, - num_residual_hiddens=num_residual_hiddens, - num_residual_layers=num_residual_layers, - ) - - -class VAEEncoder(Encoder): - def forward(self, x): - return ModelOutput(embedding=self.model(x["data"])) - - -class VQVAEEncoder(Encoder): - def forward(self, x): - return ModelOutput(pre_quantized=self.model(x["data"])) - - -class VAEDecoder(BaseDecoder): - def __init__( - self, - model_config, - num_hiddens, - num_residual_hiddens, - num_residual_layers, - ): - super(VAEDecoder, self).__init__() - self.model = ResnetDecoder( - in_channels=model_config.latent_dim, - out_channels=model_config.input_dim[0], - num_hiddens=num_hiddens, - num_residual_layers=num_residual_layers, - num_residual_hiddens=num_residual_hiddens, - ) - - def forward(self, x): - reconstruction = self.model(x["embedding"]) - return ModelOutput(reconstruction=reconstruction) - - -def count_params(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - -class VQVAE(models.VQVAE): - def __init__( - self, - model_config: VQVAEConfig, - depth, - encoder=None, - decoder=None, - strict_latent_size=True, - ): - super(models.BaseAE, self).__init__() - # super(nn.Module) - # input_dim (tuple) – The input_data dimension. - - self.model_name = "VQVAE" - self.model_config = model_config - - if self.model_config.decay > 0.0: - self.model_config.use_ema = True - - self.strict_latent_size = strict_latent_size - self.model = vq_vae.VQ_VAE( - channels=model_config.input_dim[0], - embedding_dim=model_config.latent_dim, - num_hiddens=model_config.latent_dim, - num_residual_layers=depth, - ) - self.encoder = self.model._encoder - self.decoder = self.model._decoder - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.latent_dim = model_config.latent_dim - self.input_dim = model_config.input_dim - - # This isn't completely necessary for training I don't think - # self._set_quantizer(model_config) - - def forward(self, x, epoch=None): - # loss, x_recon, perplexity = self.model.forward(x["data"]) - z = self.model.encoder(x["data"]) - z = self.model._pre_vq_conv(z) - proper_shape = z.shape - - if self.strict_latent_size: - z = self.avgpool(z) - # Features need to be in the right order for the quantizer - z = z.permute(0, 2, 3, 1) - - loss, quantized, perplexity, encodings = self.model._vq_vae(z) - z = quantized.flatten(1) - if self.strict_latent_size: - quantized = quantized.permute(0, 3, 1, 2) - quantized = quantized.expand(-1, *proper_shape[-3:]) - - x_recon = self.model._decoder(quantized) - # return loss, x_recon, perplexity - - legacy_loss_dict = self.model.loss_function( - loss, - x_recon, - perplexity, - vq_loss=loss, - perplexity=perplexity, - recons=x_recon, - input=x["data"], - ) - # This matches how pythae returns the loss - - indices = (encodings == 1).nonzero(as_tuple=True) - - recon_loss = F.mse_loss(x_recon, x["data"], reduction="sum") - mse_loss = F.mse_loss(x_recon, x["data"], reduction="mean") - - variational_loss = loss - mse_loss - - pythae_loss_dict = { - "recon_loss": mse_loss, - "vq_loss": variational_loss, - # TODO check this proppperppply - "loss": recon_loss + variational_loss, - "recon_x": x_recon, - "z": z, - "quantized_indices": indices[0], - "indices": indices, - } - return ModelOutput(**{**legacy_loss_dict, **pythae_loss_dict}) - - -class VAE(models.VAE): - def __init__( - self, - model_config: VAEConfig, - num_hiddens=64, - num_residual_hiddens=18, - num_residual_layers=2, - encoder=None, - decoder=None, - ): - super(models.BaseAE, self).__init__() - # super(nn.Module) - # input_dim (tuple) – The input_data dimension. - - self.model_name = "VAE" - self.model_config = model_config - self.encoder = VAEEncoder( - model_config, - num_hiddens=num_hiddens, - num_residual_hiddens=num_residual_hiddens, - num_residual_layers=num_residual_layers, - ) - self.decoder = VAEDecoder( - model_config, - num_hiddens=num_hiddens, - num_residual_hiddens=num_residual_hiddens, - num_residual_layers=num_residual_layers, - ) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(num_hiddens, model_config.latent_dim * 2) - self.latent_dim = model_config.latent_dim - self.input_dim = model_config.input_dim - - # shape is (batch_size, model_config.num_hiddens, 1, 1) - - def reparameterize(self, mu, log_var): - std = torch.exp(0.5 * log_var) - eps = torch.randn_like(std) - return mu + eps * std - - def forward(self, x, epoch=None): - h = self.encoder(x)["embedding"] - # pre_encode_size = torch.tensor(x["data"].shape[-2:]) - # scale = torch.floor_divide(torch.tensor(x["data"].shape[-2:]),torch.tensor(h.shape[-2:])) - pre_encode_size = torch.tensor(h.shape[-2:]) - h = self.avgpool(h) - post_encode_size = torch.tensor(h.shape[-2:]) - scale = torch.div(pre_encode_size, post_encode_size, rounding_mode="trunc") - h = torch.flatten(h, 1) - h = self.fc(h) - mu, log_var = torch.split(h, h.size(1) // 2, dim=1) - z = self.reparameterize(mu, log_var) - # x_recon = self.decoder(z.view(z.size(0), z.size(1), 1, 1)) - embedding = z.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, *scale.tolist()) - x_recon = self.decoder({"embedding": embedding})["reconstruction"] - # return x_recon, mu, log_var - - loss_dict = self.loss_function(x_recon, x["data"], mu, log_var) - # recon_loss = F.mse_loss(x_recon, x["data"], reduction="sum") - return ModelOutput(recon_x=x_recon, z=z, **loss_dict) - - def loss_function(self, recons, input, mu, log_var): - recons_loss = F.mse_loss(recons, input) - kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) - loss = recons_loss + kld_loss - return { - "loss": loss, - "recon_loss": recons_loss, - "variational_loss": kld_loss, - } - - -# Resnet50_VQVAE = partial(VQVAE,num_hidden_residuals=50) diff --git a/bioimage_embed/tests/conftest.py b/bioimage_embed/tests/conftest.py deleted file mode 100644 index ea8f482a..00000000 --- a/bioimage_embed/tests/conftest.py +++ /dev/null @@ -1,69 +0,0 @@ -# from bioimage_embed import config -from torchvision.datasets import FakeData -import pytest -from torchvision import transforms - -# from bioimage_embed.bie import BioImageEmbed -from .. import config -from ..bie import BioImageEmbed - -from hydra import initialize, compose - - -@pytest.fixture -def input_dim(): - return [3, 224, 224] - - -@pytest.fixture -def dataset(input_dim): - transform = transforms.ToTensor() - return FakeData(size=64, image_size=input_dim, num_classes=2, transform=transform) - - -@pytest.fixture -def lite_model(): - return "dummy_model" - - -@pytest.fixture -def bie(cfg): - return BioImageEmbed(cfg) - - -@pytest.fixture -def hydra_cfg(): - with initialize(config_path="."): - cfg = compose(config_name="config") - return cfg - - -@pytest.fixture -def cfg_recipe(lite_model): - return config.Recipe(model=lite_model) - - -@pytest.fixture -def cfg_trainer(): - return config.Trainer(max_epochs=1, max_steps=1, fast_dev_run=True) - - -@pytest.fixture -def cfg_dataloader(dataset): - return config.DataLoader(dataset=dataset, num_workers=0) - - -# TODO double check this is sensible -@pytest.fixture -def cfg(cfg_recipe, cfg_trainer, cfg_dataloader): - cfg = config.Config( - recipe=cfg_recipe, trainer=cfg_trainer, dataloader=cfg_dataloader - ) - return cfg - # This is an alternative way to create a config object but it is less flexible and if the config object is changed in the future, this will break, i.e validation is not guaranteed - - # cfg.dataloader.num_workers = 0 # This avoids processes being forked - # cfg.trainer.max_epochs = 1 - # cfg.trainer.max_steps = 1 - # cfg.trainer.fast_dev_run = True - # cfg.recipe.model = model diff --git a/bioimage_embed/tests/test_cli.py b/bioimage_embed/tests/test_cli.py deleted file mode 100644 index b38fd109..00000000 --- a/bioimage_embed/tests/test_cli.py +++ /dev/null @@ -1,107 +0,0 @@ -import pytest -from .. import cli -from pathlib import Path -from typer.testing import CliRunner - -runner = CliRunner() - - -@pytest.fixture -def config_dir(): - return "test_conf" - - -@pytest.fixture -def config_file(): - return "config.yaml" - - -@pytest.fixture -def config_path(config_dir, config_file): - return Path(config_dir).joinpath(config_file) - - -@pytest.fixture -def config_directory_setup(config_dir, config_file, config_path): - if config_path.is_file(): - config_path.unlink() - config_dir = config_path.parent - config_dir.mkdir(parents=True, exist_ok=True) - if config_path.is_file(): - config_path.unlink() - if config_dir.is_dir(): - config_dir.rmdir() - - yield config_dir, config_file, config_path - - -def test_write_default_config_file( - config_path, config_dir, config_file, config_directory_setup -): - # config_path, config_file = config_directory_setup - cli.write_default_config_file(config_path) - assert config_path.is_file(), "Default config file was not created" - - -def test_get_default_config(cfg): - assert cfg is not None, "Default config should not be None" - # Further assertions can be added to check specific config properties - - -# def test_main_with_default_config( -# cfg, config_path, config_dir, config_file, config_directory_setup -# ): -# test_get_default_config - -# # cli.main(config_dir=config_dir, config_file=config_file, job_name="test_app") - - -# @pytest.mark.skip("Computationally heavy") -# def test_hydra(): -# # bie_train model.model="resnet50_vqvae" dataset._target_="bioimage_embed.datasets.FakeImageFolder" -# input_dim = [3, 224, 224] -# cfg = config.Config() -# # cfg.dataloader.dataset._target_ = "bioimage_embed.datasets.FakeImageFolder" -# # cfg.dataloader.dataset.image_size = input_dim -# cfg.recipe.model = "dummy_model" -# cfg.recipe.max_epochs = 1 - - -# def test_cli(): -# # This test checks if the CLI correctly handles the dataset target input -# result = runner.invoke(app, ["bie_train", "--dataset-target", "bioimage_embed.datasets.FakeImageFolder"]) -# assert result.exit_code == 0 -# assert "Dataset target set to: bioimage_embed.datasets.FakeImageFolder" in result.stdout - -# result = runner.invoke(app, ["main", "+dataset.root=data", "--config_dir", "tests/sample_conf", "--config_file", "sample_config.yaml"]) -# def test_init_hydra_with_default_values(): -# config = init_hydra() -# assert config is not None, "Config should not be None" - -# def test_init_hydra_with_custom_values(): -# config_dir = "custom_conf" -# config_file = "custom_config.yaml" -# job_name = "custom_job" -# config = init_hydra(config_dir, config_file, job_name) -# assert config is not None, "Config should not be None" -# assert config["config_dir"] == config_dir, "Config directory should match" -# assert config["config_file"] == config_file, "Config file should match" -# assert config["job_name"] == job_name, "Job name should match" - - -def test_init_hydra_with_invalid_config_dir(): - with pytest.raises(Exception): - cli.init_hydra(config_dir="invalid_dir") - - -def test_init_hydra_with_invalid_config_file(): - with pytest.raises(Exception): - cli.init_hydra(config_file="invalid_config.yaml") - - -def test_train(cfg): - cli.train(cfg) - - -def test_check(cfg): - cli.check(cfg) diff --git a/bioimage_embed/tests/test_config.py b/bioimage_embed/tests/test_config.py deleted file mode 100644 index 79699c00..00000000 --- a/bioimage_embed/tests/test_config.py +++ /dev/null @@ -1,42 +0,0 @@ -from .. import config -import pytest -from hydra.utils import instantiate - -schema_map = config.__schemas__ -schemas = list(schema_map.values()) - - -@pytest.mark.parametrize("Schema", schemas) -def test_schema(Schema): - Schema() - - -@pytest.mark.parametrize("Schema", schemas) -def test_instantiate(Schema): - schema = config.resolve_schema(Schema()) - obj = instantiate(schema) - assert obj is not None, "obj should not be None" - - -def test_config(cfg): - assert cfg is not None, "Config should not be None" - - -def test_config_instantiate(cfg): - assert instantiate(cfg) is not None, "Config should not be None" - - -def test_resolve(bie): - assert bie.resolve() - - -def test_model_check(bie): - bie.model_check() - - -def test_train_check(bie): - bie.trainer_check() - - -def test_bie_train(bie): - bie.train() diff --git a/bioimage_embed/tests/test_lightning.py b/bioimage_embed/tests/test_lightning.py deleted file mode 100644 index 2fbf751d..00000000 --- a/bioimage_embed/tests/test_lightning.py +++ /dev/null @@ -1,326 +0,0 @@ -import pytest -import torch -import pytorch_lightning as pl -from bioimage_embed.models import __all_models__ -from bioimage_embed.lightning import ( - DataModule, - AESupervised, - AEUnsupervised, -) -from bioimage_embed.models import create_model -from bioimage_embed import config -from hydra.utils import instantiate -from torchvision.datasets import FakeData -import numpy as np -from torch.utils.data import DataLoader, TensorDataset -from bioimage_embed.lightning.dataloader import StratifiedSampler - - -torch.manual_seed(42) - - -@pytest.fixture() -def transform(): - return instantiate(config.Transform()) - - -@pytest.fixture(params=[1, 2, 16]) -def classes(request): - return request.param - - -@pytest.fixture(params=__all_models__) -def model_name(request): - return request.param - - -@pytest.fixture() -def image_dim(): - return (224, 224) - - -@pytest.fixture() -def channel_dim(): - return 3 - - -@pytest.fixture() -def samples(): - return 32 - - -@pytest.fixture(params=[16]) -def latent_dim(request): - return request.param - - -@pytest.fixture( - params=[ - 4, - ] -) -def batch_size(request): - return request.param - - -@pytest.fixture() -def pretrained(): - return True - - -@pytest.fixture() -def progress(): - return True - - -# TODO put this in a conftest.py file -@pytest.fixture -def model(model_name, image_dim, channel_dim, latent_dim, pretrained, progress): - input_dim = (channel_dim, *image_dim) - return create_model( - model_name, - input_dim, - latent_dim, - pretrained, - progress, - ) - - -@pytest.fixture() -def dummy_model(channel_dim, image_dim, latent_dim): - return create_model( - "dummy_model", - input_dim=(channel_dim, *image_dim), - latent_dim=latent_dim, - pretrained=False, - progress=False, - ) - - -@pytest.fixture() -def input_dim(image_dim, channel_dim): - return (channel_dim, *image_dim) - - -@pytest.fixture() -def data(input_dim): - return torch.rand(*input_dim) - - -@pytest.fixture() -def dataset(samples, input_dim, transform, classes=2): - # x = torch.rand(samples, *input_dim) - # y = torch.torch.randint(classes - 1, (samples,)) - # return TensorDataset(x, y) - return FakeData( - size=samples, - image_size=input_dim, - num_classes=classes, - transform=transform, - ) - - -@pytest.fixture(params=[AESupervised, AEUnsupervised]) -def lit_model_wrapper(request): - return request.param - - -# @pytest.mark.skip(reason="Dictionaries not allowed") -# def test_export_onxx(data, lit_model): -# return lit_model.to_onnx("model.onnx", data) - - -@pytest.fixture() -def datamodule(dataset, batch_size): - return DataModule( - dataset, - batch_size=batch_size, - # shuffle=True, - num_workers=0, # This avoids processes being forked - pin_memory=False, - ) - - -@pytest.fixture() -def trainer(): - return pl.Trainer( - # max_steps=1, - max_epochs=2, - ) - - -@pytest.fixture() -def model_torchscript(lit_model): - return lit_model.to_torchscript() - - -@pytest.fixture() -def lit_dummy_model(lit_model_wrapper, dummy_model): - return lit_model_wrapper(dummy_model) - - -@pytest.fixture() -def lit_model(lit_model_wrapper, model): - return lit_model_wrapper(model) - - -def test_trainer_test(trainer, lit_model, datamodule): - return trainer.test(lit_model, datamodule) - - -def test_trainer_dummy_model_fit(trainer, lit_dummy_model, datamodule): - return trainer.fit(lit_dummy_model, datamodule) - - -@pytest.mark.skip(reason="Expensive") -def test_trainer_fit(trainer, lit_model, datamodule): - return trainer.fit(lit_model, datamodule) - - -@pytest.mark.skip(reason="needs batched data") -def test_dataset_trainer(trainer, lit_model, dataset): - return trainer.test(lit_model, dataset) - - -def test_model_properties(model): - assert model.encoder is not None - assert model.decoder is not None - assert model.latent_dim is not None - assert model.input_dim is not None - assert model.model_name is not None - assert model.model_config is not None - - -def test_trainer_predict(trainer, lit_model, datamodule): - batch_size = datamodule.predict_dataloader().batch_size - latent_dim = lit_model.model.latent_dim - predictions = trainer.predict(lit_model, datamodule) - assert predictions is not None - assert len(predictions[0].z.flatten()) == batch_size * latent_dim - # TODO prefer - # assert list(predictions[0].z.shape) == [batch_size,latent_dim] - # assert len(list(predictions[0].z.shape)) == 2 - - -# Has to be a list not a tuple -def test_export_onnx(lit_model, data): - example_input = data.unsqueeze(0) - return lit_model.to_onnx("model.onnx", example_input, export_params=True) - - -@pytest.mark.skip(reason="models cant take in variable length args and kwargs") -def test_export_jit(model_torchscript): - return model_torchscript.save("model.pt") - - -@pytest.mark.skip(reason="models cant take in variable length args and kwargs") -def test_jit_save(model_torchscript): - return torch.jit.save(model_torchscript, "model.pt", method="script") - - -@pytest.fixture(params=[128]) -def large_batch(request): - return request.param - - -@pytest.fixture -def large_data(input_dim, large_batch): - return torch.empty(large_batch**2, *input_dim) - - -@pytest.fixture(params=[1, 8]) -def many_classes(request): - return request.param - - -@pytest.fixture -def imbalanced_dataset(large_data, many_classes): - """ - Return 'classes' and an imbalanced distribution 'p'. 'classes' can be any length. - The distribution 'p' must sum to 1. - """ - data, classes = large_data, many_classes - samples = len(data) - # if classes == 0: - # return TensorDataset(data,[None] * samples) - p = 2 ** np.arange(1, classes + 1) - - p = p / p.sum() # Normalize to sum to 1 - - labels = np.random.choice(a=np.arange(classes), size=(samples,), p=p) - # Set the dataset's targets - # dataset.targets = torch.tensor(labels) - dataset = TensorDataset(data, torch.tensor(labels)) - return dataset - - -@pytest.fixture(params=[16]) -def batch_split(request): - return request.param - - -@pytest.fixture() -def stratified_dataloader(imbalanced_dataset): - dataset = imbalanced_dataset - samples = len(dataset) - - return DataLoader( - dataset, - batch_size=int(np.sqrt(samples)), - sampler=StratifiedSampler(dataset), - num_workers=0, - drop_last=True, - ) - - -def test_stratified_sampler(stratified_dataloader): - # Unpack the dataloader - dataloader = stratified_dataloader - - # Collect all sampled labels - all_labels = [] - for inputs, labels in dataloader: - all_labels.extend(labels.numpy()) - - # Convert to NumPy array - all_labels = np.array(all_labels) - - # Calculate the number of classes - num_classes = len(np.unique(all_labels)) - - # Calculate the sampled label distribution - sampled_counts = np.bincount(all_labels, minlength=num_classes) - sampled_distribution = sampled_counts / len(all_labels) - - # Expected proportion for each class (uniform distribution) - expected_proportion = num_classes * [1.0 / num_classes] - # Assert that the sampled distribution is close to the expected proportions - assert np.allclose( - sampled_distribution, expected_proportion, atol=0.05 - ), f"Sampled distribution {sampled_distribution} does not match expected {expected_proportion}" - - -def test_sanity_check_stratified(): - # Create an imbalanced dataset (e.g., class 0 has more samples than class 1) - labels = [0] * 80 + [1] * 20 - dataset = [(data, label) for data, label in zip(range(100), labels)] - - # Initialize the sampler - sampler = StratifiedSampler(dataset) - - # Create a DataLoader - dataloader = DataLoader(dataset, batch_size=10, sampler=sampler) - - # Collect sampled labels - sampled_labels = [] - for _, label_batch in dataloader: - sampled_labels.extend(label_batch.numpy()) - - # Analyze the distribution - sampled_labels = np.array(sampled_labels) - sampled_counts = np.bincount(sampled_labels) - sampled_distribution = sampled_counts / len(sampled_labels) - - print("Sampled Label Distribution:") - for i, proportion in enumerate(sampled_distribution): - print(f"Class {i}: {proportion*100:.2f}%") diff --git a/notebooks/_shape_embed.ipynb b/notebooks/_shape_embed.ipynb new file mode 100644 index 00000000..bbda3773 --- /dev/null +++ b/notebooks/_shape_embed.ipynb @@ -0,0 +1,570 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "a6f5a045", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "import pyefd\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import cross_validate, KFold, train_test_split\n", + "from sklearn.metrics import make_scorer\n", + "import pandas as pd\n", + "from sklearn import metrics\n", + "from pathlib import Path\n", + "import umap\n", + "from torch.autograd import Variable\n", + "from types import SimpleNamespace\n", + "import numpy as np\n", + "import logging\n", + "from skimage import measure\n", + "import umap.plot\n", + "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "\n", + "# Deal with the filesystem\n", + "import torch.multiprocessing\n", + "\n", + "torch.multiprocessing.set_sharing_strategy(\"file_system\")\n", + "\n", + "from shape_embed import shapes\n", + "import bioimage_embed\n", + "\n", + "# Note - you must have torchvision installed for this example\n", + "\n", + "from pytorch_lightning import loggers as pl_loggers\n", + "from torchvision import transforms\n", + "from bioimage_embed.lightning import DataModule\n", + "\n", + "from torchvision import datasets\n", + "from bioimage_embed.shapes.transforms import (\n", + " ImageToCoords,\n", + " CropCentroidPipeline,\n", + " DistogramToCoords,\n", + " MaskToDistogramPipeline,\n", + ")\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "def scoring_df(X, y):\n", + " # Split the data into training and test sets\n", + " X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y\n", + " )\n", + " # Define a dictionary of metrics\n", + " scoring = {\n", + " \"accuracy\": make_scorer(metrics.accuracy_score),\n", + " \"precision\": make_scorer(metrics.precision_score, average=\"macro\"),\n", + " \"recall\": make_scorer(metrics.recall_score, average=\"macro\"),\n", + " \"f1\": make_scorer(metrics.f1_score, average=\"macro\"),\n", + " }\n", + "\n", + " # Create a random forest classifier\n", + " clf = RandomForestClassifier()\n", + "\n", + " # Specify the number of folds\n", + " k_folds = 10\n", + "\n", + " # Perform k-fold cross-validation\n", + " cv_results = cross_validate(\n", + " estimator=clf,\n", + " X=X,\n", + " y=y,\n", + " cv=KFold(n_splits=k_folds),\n", + " scoring=scoring,\n", + " n_jobs=-1,\n", + " return_train_score=False,\n", + " )\n", + "\n", + " # Put the results into a DataFrame\n", + " return pd.DataFrame(cv_results)\n", + "\n", + "\n", + "def shape_embed_process():\n", + " # Setting the font size\n", + "\n", + " # rc(\"text\", usetex=True)\n", + " width = 3.45\n", + " height = width / 1.618\n", + " plt.rcParams[\"figure.figsize\"] = [width, height]\n", + "\n", + " sns.set(\n", + " style=\"white\",\n", + " context=\"notebook\",\n", + " rc={\"figure.figsize\": (width, height)},\n", + " )\n", + "\n", + " # matplotlib.use(\"TkAgg\")\n", + " interp_size = 128 * 2\n", + " max_epochs = 100\n", + " window_size = 128 * 2\n", + "\n", + " params = {\n", + " \"model\": \"resnet18_vqvae_legacy\",\n", + " \"epochs\": 75,\n", + " \"batch_size\": 3,\n", + " \"num_workers\": 2**4,\n", + " \"input_dim\": (3, interp_size, interp_size),\n", + " \"latent_dim\": (interp_size) * 8,\n", + " \"num_embeddings\": 16,\n", + " \"num_hiddens\": 16,\n", + " \"pretrained\": True,\n", + " \"commitment_cost\": 0.25,\n", + " \"decay\": 0.99,\n", + " \"loss_weights\": [1, 1, 1, 1],\n", + " }\n", + "\n", + " optimizer_params = {\n", + " \"opt\": \"LAMB\",\n", + " \"lr\": 0.001,\n", + " \"weight_decay\": 0.0001,\n", + " \"momentum\": 0.9,\n", + " }\n", + "\n", + " lr_scheduler_params = {\n", + " \"sched\": \"cosine\",\n", + " \"min_lr\": 1e-4,\n", + " \"warmup_epochs\": 5,\n", + " \"warmup_lr\": 1e-6,\n", + " \"cooldown_epochs\": 10,\n", + " \"t_max\": 50,\n", + " \"cycle_momentum\": False,\n", + " }\n", + "\n", + " args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params)\n", + "\n", + " dataset_path = \"bbbc010/BBBC010_v1_foreground_eachworm\"\n", + " # dataset_path = \"vampire/mefs/data/processed/Control\"\n", + " # dataset_path = \"shape_embed_data/data/vampire/torchvision/Control/\"\n", + " # dataset_path = \"vampire/torchvision/Control\"\n", + " # dataset = \"bbbc010\"\n", + "\n", + " # train_data_path = f\"scripts/shapes/data/{dataset_path}\"\n", + " train_data_path = f\"data/{dataset_path}\"\n", + " metadata = lambda x: f\"results/{dataset_path}_{args.model}/{x}\"\n", + "\n", + " path = Path(metadata(\"\"))\n", + " path.mkdir(parents=True, exist_ok=True)\n", + " model_dir = f\"models/{dataset_path}_{args.model}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf031300", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "\n", + " transform_crop = CropCentroidPipeline(window_size)\n", + " transform_dist = MaskToDistogramPipeline(\n", + " window_size, interp_size, matrix_normalised=False\n", + " )\n", + " transform_mdscoords = DistogramToCoords(window_size)\n", + " transform_coords = ImageToCoords(window_size)\n", + "\n", + " transform_mask_to_gray = transforms.Compose([transforms.Grayscale(1)])\n", + "\n", + " transform_mask_to_crop = transforms.Compose(\n", + " [\n", + " # transforms.ToTensor(),\n", + " transform_mask_to_gray,\n", + " transform_crop,\n", + " ]\n", + " )\n", + "\n", + " transform_mask_to_dist = transforms.Compose(\n", + " [\n", + " transform_mask_to_crop,\n", + " transform_dist,\n", + " ]\n", + " )\n", + " transform_mask_to_coords = transforms.Compose(\n", + " [\n", + " transform_mask_to_crop,\n", + " transform_coords,\n", + " ]\n", + " )\n", + "\n", + " transforms_dict = {\n", + " \"none\": transform_mask_to_gray,\n", + " \"transform_crop\": transform_mask_to_crop,\n", + " \"transform_dist\": transform_mask_to_dist,\n", + " \"transform_coords\": transform_mask_to_coords,\n", + " }\n", + "\n", + " train_data = {\n", + " key: datasets.ImageFolder(train_data_path, transform=value)\n", + " for key, value in transforms_dict.items()\n", + " }\n", + "\n", + " for key, value in train_data.items():\n", + " print(key, len(value))\n", + " plt.imshow(train_data[key][0][0], cmap=\"gray\")\n", + " plt.imsave(metadata(f\"{key}.png\"), train_data[key][0][0], cmap=\"gray\")\n", + " # plt.show()\n", + " plt.close()\n", + "\n", + " # plt.scatter(*train_data[\"transform_coords\"][0][0])\n", + " # plt.savefig(metadata(f\"transform_coords.png\"))\n", + " # plt.show()\n", + "\n", + " # plt.imshow(train_data[\"transform_crop\"][0][0], cmap=\"gray\")\n", + " # plt.scatter(*train_data[\"transform_coords\"][0][0],c=np.arange(interp_size), cmap='rainbow', s=1)\n", + " # plt.show()\n", + " # plt.savefig(metadata(f\"transform_coords.png\"))\n", + "\n", + " # Retrieve the coordinates and cropped image\n", + " coords = train_data[\"transform_coords\"][0][0]\n", + " crop_image = train_data[\"transform_crop\"][0][0]\n", + "\n", + " fig = plt.figure(frameon=True)\n", + " ax = plt.Axes(fig, [0, 0, 1, 1])\n", + " ax.set_axis_off()\n", + " fig.add_axes(ax)\n", + "\n", + " # Display the cropped image using grayscale colormap\n", + " plt.imshow(crop_image, cmap=\"gray_r\")\n", + "\n", + " # Scatter plot with smaller point size\n", + " plt.scatter(*coords, c=np.arange(interp_size), cmap=\"rainbow\", s=2)\n", + "\n", + " # Save the plot as an image without border and coordinate axes\n", + " plt.savefig(metadata(\"transform_coords.png\"), bbox_inches=\"tight\", pad_inches=0)\n", + "\n", + " # Close the plot\n", + " plt.close()\n", + " # import albumentations as A" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5650bea0", + "metadata": {}, + "outputs": [], + "source": [ + " gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))\n", + " transform = transforms.Compose(\n", + " [transform_mask_to_dist, transforms.ToTensor(), gray2rgb]\n", + " )\n", + "\n", + " dataset = datasets.ImageFolder(train_data_path, transform=transform)\n", + "\n", + " valid_indices = []\n", + " # Iterate through the dataset and apply the transform to each image\n", + " for idx in range(len(dataset)):\n", + " try:\n", + " image, label = dataset[idx]\n", + " # If the transform works without errors, add the index to the list of valid indices\n", + " valid_indices.append(idx)\n", + " except Exception as e:\n", + " # A better way to do with would be with batch collation\n", + " print(f\"Error occurred for image {idx}: {e}\")\n", + "\n", + " # Create a Subset using the valid indices\n", + " dataset = torch.utils.data.Subset(dataset, valid_indices)\n", + " dataloader = DataModule(\n", + " dataset,\n", + " batch_size=args.batch_size,\n", + " shuffle=True,\n", + " num_workers=args.num_workers,\n", + " )\n", + "\n", + " model = bioimage_embed.models.create_model(**vars(args))\n", + " logger.info(model)\n", + "\n", + " # lit_model = shapes.MaskEmbedLatentAugment(model, args)\n", + " lit_model = shapes.MaskEmbed(model, args)\n", + " test_data = dataset[0][0].unsqueeze(0)\n", + " # test_lit_data = 2*(dataset[0][0].unsqueeze(0).repeat_interleave(3, dim=1),)\n", + " test_output = lit_model.forward((test_data,))\n", + "\n", + " dataloader.setup()\n", + " model.eval()\n", + " # Model\n", + " lit_model.eval()\n", + "\n", + " logger.info(f\"Saving model to {model_dir}\")\n", + "\n", + " model_dir = f\"my_models/{dataset_path}_{model._get_name()}_{lit_model._get_name()}\"\n", + " Path(f\"{model_dir}/\").mkdir(parents=True, exist_ok=True)\n", + "\n", + " tb_logger = pl_loggers.TensorBoardLogger(\n", + " \"logs/\",\n", + " name=f\"{dataset_path}_{args.model}_{model._get_name()}_{lit_model._get_name()}\",\n", + " )\n", + "\n", + " checkpoint_callback = ModelCheckpoint(dirpath=f\"{model_dir}/\", save_last=True)\n", + "\n", + " trainer = pl.Trainer(\n", + " logger=tb_logger,\n", + " gradient_clip_val=0.5,\n", + " enable_checkpointing=True,\n", + " devices=1,\n", + " accelerator=\"gpu\",\n", + " precision=16, # Use mixed precision\n", + " accumulate_grad_batches=4,\n", + " callbacks=[checkpoint_callback],\n", + " min_epochs=50,\n", + " max_epochs=args.epochs,\n", + " )\n", + " # # %%\n", + "\n", + " testing = trainer.test(lit_model, datamodule=dataloader)\n", + "\n", + " try:\n", + " trainer.fit(\n", + " lit_model, datamodule=dataloader, ckpt_path=f\"{model_dir}/last.ckpt\"\n", + " )\n", + " except:\n", + " trainer.fit(lit_model, datamodule=dataloader)\n", + "\n", + " logger.info(f\"Saving model to {model_dir}\")\n", + " try:\n", + " example_input = Variable(torch.rand(2, 1, *args.input_dim))\n", + " torch.onnx.export(lit_model, example_input, f\"{model_dir}/model.onnx\")\n", + " torch.jit.save(lit_model.to_torchscript(), f\"{model_dir}/model.pt\")\n", + "\n", + " except:\n", + " logger.info(\"Model \")\n", + "\n", + " validation = trainer.validate(lit_model, datamodule=dataloader)\n", + " # testing = trainer.test(lit_model, datamodule=dataloader)\n", + "\n", + " example_input = Variable(torch.rand(1, *args.input_dim))\n", + " logger.info(f\"Saving model to {model_dir}\")\n", + " torch.jit.save(lit_model.to_torchscript(), f\"{model_dir}/model.pt\")\n", + " torch.onnx.export(lit_model, example_input, f\"{model_dir}/model.onnx\")\n", + "\n", + " # Inference\n", + "\n", + " dataloader = DataModule(\n", + " dataset,\n", + " batch_size=1,\n", + " shuffle=False,\n", + " num_workers=args.num_workers,\n", + " # Transform is commented here to avoid augmentations in real data\n", + " # HOWEVER, applying a the transform multiple times and averaging the results might produce better latent embeddings\n", + " # transform=transform,\n", + " )\n", + " dataloader.setup()\n", + "\n", + " predictions = trainer.predict(lit_model, datamodule=dataloader)\n", + " latent_space = torch.stack([d[\"z\"].flatten() for d in predictions])\n", + " scalings = torch.stack([d[\"scalings\"].flatten() for d in predictions])\n", + "\n", + " idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()}\n", + "\n", + " y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()])\n", + "\n", + " y_partial = y.copy()\n", + " indices = np.random.choice(y.size, int(0.3 * y.size), replace=False)\n", + " y_partial[indices] = -1\n", + " y_blind = -1 * np.ones_like(y)\n", + " umap_labels = y_blind\n", + " classes = np.array([idx_to_class[i] for i in y])\n", + "\n", + " n_components = 64 # Number of UMAP components\n", + " component_names = [f\"umap{i}\" for i in range(n_components)] # List of column names\n", + "\n", + " logger.info(\"UMAP fitting\")\n", + " mapper = umap.UMAP(n_components=64, random_state=42).fit(\n", + " latent_space.numpy(), y=umap_labels\n", + " )\n", + "\n", + " logger.info(\"UMAP transforming\")\n", + " semi_supervised_latent = mapper.transform(latent_space.numpy())\n", + "\n", + " df = pd.DataFrame(semi_supervised_latent, columns=component_names)\n", + " df[\"Class\"] = y\n", + " # Map numeric classes to their labels\n", + " idx_to_class = {0: \"alive\", 1: \"dead\"}\n", + " df[\"Class\"] = df[\"Class\"].map(idx_to_class)\n", + " df[\"Scale\"] = scalings\n", + " df = df.set_index(\"Class\")\n", + " df_shape_embed = df.copy()\n", + "\n", + " ax = sns.relplot(\n", + " data=df,\n", + " x=\"umap0\",\n", + " y=\"umap1\",\n", + " hue=\"Class\",\n", + " palette=\"deep\",\n", + " alpha=0.5,\n", + " edgecolor=None,\n", + " s=5,\n", + " height=height,\n", + " aspect=0.5 * width / height,\n", + " )\n", + "\n", + " sns.move_legend(\n", + " ax,\n", + " \"upper center\",\n", + " )\n", + " ax.set(xlabel=None, ylabel=None)\n", + " sns.despine(left=True, bottom=True)\n", + " plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)\n", + " plt.tight_layout()\n", + " plt.savefig(metadata(\"umap_no_axes.pdf\"))\n", + " # plt.show()\n", + " plt.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55fb1f0c", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + " X = df_shape_embed.to_numpy()\n", + " y = df_shape_embed.index.values\n", + "\n", + " properties = [\n", + " \"area\",\n", + " \"perimeter\",\n", + " \"centroid\",\n", + " \"major_axis_length\",\n", + " \"minor_axis_length\",\n", + " \"orientation\",\n", + " ]\n", + " dfs = []\n", + " for i, data in enumerate(train_data[\"transform_crop\"]):\n", + " X, y = data\n", + " # Do regionprops here\n", + " # Calculate shape summary statistics using regionprops\n", + " # We're considering that the mask has only one object, thus we take the first element [0]\n", + " # props = regionprops(np.array(X).astype(int))[0]\n", + " props_table = measure.regionprops_table(\n", + " np.array(X).astype(int), properties=properties\n", + " )\n", + "\n", + " # Store shape properties in a dataframe\n", + " df = pd.DataFrame(props_table)\n", + "\n", + " # Assuming the class or label is contained in 'y' variable\n", + " df[\"class\"] = y\n", + " df.set_index(\"class\", inplace=True)\n", + " dfs.append(df)\n", + "\n", + " df_regionprops = pd.concat(dfs)\n", + "\n", + " # Assuming 'dataset_contour' is your DataLoader for the dataset\n", + " dfs = []\n", + " for i, data in enumerate(train_data[\"transform_coords\"]):\n", + " # Convert the tensor to a numpy array\n", + " X, y = data\n", + "\n", + " # Feed it to PyEFD's calculate_efd function\n", + " coeffs = pyefd.elliptic_fourier_descriptors(X, order=10, normalize=False)\n", + " # coeffs_df = pd.DataFrame({'class': [y], 'norm_coeffs': [norm_coeffs.flatten().tolist()]})\n", + "\n", + " norm_coeffs = pyefd.normalize_efd(coeffs)\n", + " df = pd.DataFrame(\n", + " {\n", + " \"norm_coeffs\": norm_coeffs.flatten().tolist(),\n", + " \"coeffs\": coeffs.flatten().tolist(),\n", + " }\n", + " ).T.rename_axis(\"coeffs\")\n", + " df[\"class\"] = y\n", + " df.set_index(\"class\", inplace=True, append=True)\n", + " dfs.append(df)\n", + "\n", + " df_pyefd = pd.concat(dfs)\n", + "\n", + " trials = [\n", + " {\n", + " \"name\": \"mask_embed\",\n", + " \"features\": df_shape_embed.to_numpy(),\n", + " \"labels\": df_shape_embed.index,\n", + " },\n", + " {\n", + " \"name\": \"fourier_coeffs\",\n", + " \"features\": df_pyefd.xs(\"coeffs\", level=\"coeffs\"),\n", + " \"labels\": df_pyefd.xs(\"coeffs\", level=\"coeffs\").index,\n", + " },\n", + " # {\"name\": \"fourier_norm_coeffs\",\n", + " # \"features\": df_pyefd.xs(\"norm_coeffs\", level=\"coeffs\"),\n", + " # \"labels\": df_pyefd.xs(\"norm_coeffs\", level=\"coeffs\").index\n", + " # }\n", + " {\n", + " \"name\": \"regionprops\",\n", + " \"features\": df_regionprops,\n", + " \"labels\": df_regionprops.index,\n", + " },\n", + " ]\n", + "\n", + " trial_df = pd.DataFrame()\n", + " for trial in trials:\n", + " X = trial[\"features\"]\n", + " y = trial[\"labels\"]\n", + " trial[\"score_df\"] = scoring_df(X, y)\n", + " trial[\"score_df\"][\"trial\"] = trial[\"name\"]\n", + " print(trial[\"score_df\"])\n", + " trial[\"score_df\"].to_csv(metadata(f\"{trial['name']}_score_df.csv\"))\n", + " trial_df = pd.concat([trial_df, trial[\"score_df\"]])\n", + " trial_df = trial_df.drop([\"fit_time\", \"score_time\"], axis=1)\n", + "\n", + " trial_df.to_csv(metadata(\"trial_df.csv\"))\n", + " trial_df.groupby(\"trial\").mean().to_csv(metadata(\"trial_df_mean.csv\"))\n", + " trial_df.plot(kind=\"bar\")\n", + "\n", + " melted_df = trial_df.melt(id_vars=\"trial\", var_name=\"Metric\", value_name=\"Score\")\n", + " # fig, ax = plt.subplots(figsize=(width, height))\n", + " ax = sns.catplot(\n", + " data=melted_df,\n", + " kind=\"bar\",\n", + " x=\"trial\",\n", + " hue=\"Metric\",\n", + " y=\"Score\",\n", + " errorbar=\"se\",\n", + " height=height,\n", + " aspect=width * 2**0.5 / height,\n", + " )\n", + " # ax.xtick_params(labelrotation=45)\n", + " # plt.legend(loc='lower center', bbox_to_anchor=(1, 1))\n", + " # sns.move_legend(ax, \"lower center\", bbox_to_anchor=(1, 1))\n", + " # ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n", + " # plt.tight_layout()\n", + " plt.savefig(metadata(\"trials_barplot.pdf\"))\n", + " plt.close()\n", + "\n", + " avs = (\n", + " melted_df.set_index([\"trial\", \"Metric\"])\n", + " .xs(\"test_f1\", level=\"Metric\", drop_level=False)\n", + " .groupby(\"trial\")\n", + " .mean()\n", + " )\n", + " print(avs)\n", + " # tikzplotlib.save(metadata(f\"trials_barplot.tikz\"))\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " shape_embed_process()" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/shape_embed.ipynb b/notebooks/shape_embed.ipynb new file mode 100644 index 00000000..4a352198 --- /dev/null +++ b/notebooks/shape_embed.ipynb @@ -0,0 +1,633 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6c47a14", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "import pyefd\n", + "from sklearn.discriminant_analysis import StandardScaler\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import (\n", + " cross_validate,\n", + " KFold,\n", + " train_test_split,\n", + ")\n", + "from sklearn.metrics import make_scorer\n", + "import pandas as pd\n", + "from sklearn import metrics\n", + "import matplotlib as mpl\n", + "from pathlib import Path\n", + "from sklearn.pipeline import Pipeline\n", + "from torch.autograd import Variable\n", + "from types import SimpleNamespace\n", + "import numpy as np\n", + "from skimage import measure\n", + "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "from types import SimpleNamespace\n", + "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", + "from umap import UMAP\n", + "# Deal with the filesystem\n", + "import torch.multiprocessing\n", + "import logging\n", + "from tqdm import tqdm\n", + "\n", + "logging.basicConfig(level=logging.INFO)\n", + "\n", + "torch.multiprocessing.set_sharing_strategy(\"file_system\")\n", + "\n", + "from shape_embed import shapes\n", + "import bioimage_embed\n", + "from pytorch_lightning import loggers as pl_loggers\n", + "from torchvision import transforms\n", + "from bioimage_embed.lightning import DataModule\n", + "\n", + "from torchvision import datasets\n", + "from shape_embed.shapes.transforms import (\n", + " ImageToCoords,\n", + " CropCentroidPipeline,\n", + " DistogramToCoords,\n", + " RotateIndexingClockwise,\n", + " CoordsToDistogram,\n", + " AsymmetricDistogramToCoordsPipeline,\n", + ")\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from matplotlib import rc\n", + "\n", + "import pickle\n", + "import base64\n", + "import hashlib\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "# Seed everything\n", + "np.random.seed(42)\n", + "pl.seed_everything(42)\n", + "\n", + "\n", + "def hashing_fn(args):\n", + " serialized_args = pickle.dumps(vars(args))\n", + " hash_object = hashlib.sha256(serialized_args)\n", + " hashed_string = base64.urlsafe_b64encode(hash_object.digest()).decode()\n", + " return hashed_string\n", + "\n", + "\n", + "def umap_plot(df, metadata, width=3.45, height=3.45 / 1.618):\n", + " umap_reducer = UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)\n", + " mask = np.random.rand(len(df)) < 0.7\n", + "\n", + " semi_labels = df[\"Class\"].copy()\n", + " semi_labels[~mask] = -1 # Assuming -1 indicates unknown label for semi-supervision\n", + "\n", + " umap_embedding = umap_reducer.fit_transform(df, y=semi_labels)\n", + "\n", + " ax = sns.relplot(\n", + " data=pd.DataFrame(umap_embedding, columns=[\"umap0\", \"umap1\"]),\n", + " x=\"umap0\",\n", + " y=\"umap1\",\n", + " hue=\"Class\",\n", + " palette=\"deep\",\n", + " alpha=0.5,\n", + " edgecolor=None,\n", + " s=5,\n", + " height=height,\n", + " aspect=0.5 * width / height,\n", + " )\n", + "\n", + " sns.move_legend(\n", + " ax,\n", + " \"upper center\",\n", + " )\n", + " ax.set(xlabel=None, ylabel=None)\n", + " sns.despine(left=True, bottom=True)\n", + " plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)\n", + " plt.tight_layout()\n", + " plt.savefig(metadata(\"umap_no_axes.pdf\"))\n", + " # plt.show()\n", + " plt.close()\n", + "\n", + "\n", + "def scoring_df(X, y):\n", + " # Split the data into training and test sets\n", + " X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y\n", + " )\n", + " # Define a dictionary of metrics\n", + " scoring = {\n", + " \"accuracy\": make_scorer(metrics.balanced_accuracy_score),\n", + " \"precision\": make_scorer(metrics.precision_score, average=\"macro\"),\n", + " \"recall\": make_scorer(metrics.recall_score, average=\"macro\"),\n", + " \"f1\": make_scorer(metrics.f1_score, average=\"macro\"),\n", + " \"roc_auc\": make_scorer(metrics.roc_auc_score, average=\"macro\"),\n", + " }\n", + "\n", + " # Create a random forest classifier\n", + " pipeline = Pipeline(\n", + " [\n", + " (\"scaler\", StandardScaler()),\n", + " # (\"pca\", PCA(n_components=0.95, whiten=True, random_state=42)),\n", + " (\"clf\", RandomForestClassifier()),\n", + " # (\"clf\", DummyClassifier()),\n", + " ]\n", + " )\n", + "\n", + " # Specify the number of folds\n", + " k_folds = 5\n", + "\n", + " # Perform k-fold cross-validation\n", + " cv_results = cross_validate(\n", + " estimator=pipeline,\n", + " X=X,\n", + " y=y,\n", + " cv=StratifiedKFold(n_splits=k_folds),\n", + " scoring=scoring,\n", + " n_jobs=-1,\n", + " return_train_score=False,\n", + " )\n", + "\n", + " # Put the results into a DataFrame\n", + " return pd.DataFrame(cv_results)\n", + "\n", + "\n", + "def shape_embed_process():\n", + " # Setting the font size\n", + " mpl.rcParams[\"font.size\"] = 10\n", + "\n", + " # rc(\"text\", usetex=True)\n", + " rc(\"font\", **{\"family\": \"sans-serif\", \"sans-serif\": [\"Arial\"]})\n", + " width = 3.45\n", + " height = width / 1.618\n", + " plt.rcParams[\"figure.figsize\"] = [width, height]\n", + "\n", + " sns.set(\n", + " style=\"white\",\n", + " context=\"notebook\",\n", + " rc={\"figure.figsize\": (width, height)},\n", + " )\n", + "\n", + " # matplotlib.use(\"TkAgg\")\n", + " interp_size = 128 * 2\n", + " max_epochs = 100\n", + " window_size = 128 * 2\n", + "\n", + " params = {\n", + " \"model\": \"resnet50_vqvae\",\n", + " \"epochs\": 250,\n", + " \"batch_size\": 4,\n", + " \"num_workers\": 2**4,\n", + " \"input_dim\": (3, interp_size, interp_size),\n", + " \"latent_dim\": interp_size,\n", + " \"num_embeddings\": interp_size,\n", + " \"num_hiddens\": interp_size,\n", + " \"pretrained\": True,\n", + " \"commitment_cost\": 0.25,\n", + " \"decay\": 0.99,\n", + " \"frobenius_norm\": False,\n", + " # dataset = \"bbbc010/BBBC010_v1_foreground_eachworm\"\n", + " # dataset = \"vampire/mefs/data/processed/Control\"\n", + " \"dataset\": \"synthcellshapes_dataset\",\n", + " }\n", + "\n", + " optimizer_params = {\n", + " \"opt\": \"AdamW\",\n", + " \"lr\": 0.001,\n", + " \"weight_decay\": 0.0001,\n", + " \"momentum\": 0.9,\n", + " }\n", + "\n", + " lr_scheduler_params = {\n", + " \"sched\": \"cosine\",\n", + " \"min_lr\": 1e-4,\n", + " \"warmup_epochs\": 5,\n", + " \"warmup_lr\": 1e-6,\n", + " \"cooldown_epochs\": 10,\n", + " \"t_max\": 50,\n", + " \"cycle_momentum\": False,\n", + " }\n", + "\n", + " args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params)\n", + "\n", + " dataset_path = args.dataset\n", + "\n", + " # train_data_path = f\"scripts/shapes/data/{dataset_path}\"\n", + " train_data_path = f\"data/{dataset_path}\"\n", + " metadata = lambda x: f\"results/{dataset_path}_{args.model}/{x}\"\n", + "\n", + " path = Path(metadata(\"\"))\n", + " path.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29c6dc6b", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "\n", + " transform_crop = CropCentroidPipeline(window_size)\n", + " # transform_dist = MaskToDistogramPipeline(\n", + " # window_size, interp_size, matrix_normalised=False\n", + " # )\n", + " transform_coord_to_dist = CoordsToDistogram(interp_size, matrix_normalised=False)\n", + " transform_mdscoords = DistogramToCoords(window_size)\n", + " transform_coords = ImageToCoords(window_size)\n", + "\n", + " transform_mask_to_gray = transforms.Compose([transforms.Grayscale(1)])\n", + "\n", + " transform_mask_to_crop = transforms.Compose(\n", + " [\n", + " # transforms.ToTensor(),\n", + " transform_mask_to_gray,\n", + " transform_crop,\n", + " ]\n", + " )\n", + "\n", + " transform_mask_to_coords = transforms.Compose(\n", + " [\n", + " transform_mask_to_crop,\n", + " transform_coords,\n", + " ]\n", + " )\n", + "\n", + " transform_mask_to_dist = transforms.Compose(\n", + " [\n", + " transform_mask_to_coords,\n", + " transform_coord_to_dist,\n", + " ]\n", + " )\n", + "\n", + " gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))\n", + " transform = transforms.Compose(\n", + " [\n", + " transform_mask_to_dist,\n", + " transforms.ToTensor(),\n", + " RotateIndexingClockwise(p=1),\n", + " gray2rgb,\n", + " ]\n", + " )\n", + "\n", + " transforms_dict = {\n", + " \"none\": transform_mask_to_gray,\n", + " \"transform_crop\": transform_mask_to_crop,\n", + " \"transform_dist\": transform_mask_to_dist,\n", + " \"transform_coords\": transform_mask_to_coords,\n", + " }\n", + "\n", + " # Apply transform to find which images don't work\n", + " dataset = datasets.ImageFolder(train_data_path, transform=transform)\n", + "\n", + " valid_indices = []\n", + " # Iterate through the dataset and apply the transform to each image\n", + " for idx in range(len(dataset)):\n", + " try:\n", + " image, label = dataset[idx]\n", + " # If the transform works without errors, add the index to the list of valid indices\n", + " valid_indices.append(idx)\n", + " except Exception as e:\n", + " # A better way to do with would be with batch collation\n", + " logger.warning(f\"Error occurred for image {idx}: {e}\")\n", + "\n", + " train_data = {\n", + " key: torch.utils.data.Subset(\n", + " datasets.ImageFolder(train_data_path, transform=value),\n", + " valid_indices,\n", + " )\n", + " for key, value in transforms_dict.items()\n", + " }\n", + "\n", + " dataset = torch.utils.data.Subset(\n", + " datasets.ImageFolder(train_data_path, transform=transform),\n", + " valid_indices,\n", + " )\n", + "\n", + " for key, value in train_data.items():\n", + " logger.info(key, len(value))\n", + " plt.imshow(np.array(train_data[key][0][0]), cmap=\"gray\")\n", + " plt.imsave(metadata(f\"{key}.png\"), train_data[key][0][0], cmap=\"gray\")\n", + " # plt.show()\n", + " plt.close()\n", + "\n", + " # Retrieve the coordinates and cropped image\n", + " coords = train_data[\"transform_coords\"][0][0]\n", + " crop_image = train_data[\"transform_crop\"][0][0]\n", + "\n", + " fig = plt.figure(frameon=True)\n", + " ax = plt.Axes(fig, [0, 0, 1, 1])\n", + " ax.set_axis_off()\n", + " fig.add_axes(ax)\n", + "\n", + " # Display the cropped image using grayscale colormap\n", + " plt.imshow(crop_image, cmap=\"gray_r\")\n", + "\n", + " # Scatter plot with smaller point size\n", + " plt.scatter(*coords, c=np.arange(interp_size), cmap=\"rainbow\", s=2)\n", + "\n", + " # Save the plot as an image without border and coordinate axes\n", + " plt.savefig(metadata(\"transform_coords.png\"), bbox_inches=\"tight\", pad_inches=0)\n", + "\n", + " # Close the plot\n", + " plt.close()\n", + "\n", + " # Create a Subset using the valid indices\n", + " dataloader = DataModule(\n", + " dataset,\n", + " batch_size=args.batch_size,\n", + " shuffle=True,\n", + " num_workers=args.num_workers,\n", + " )\n", + "\n", + " model = bioimage_embed.models.create_model(\n", + " model=args.model,\n", + " input_dim=args.input_dim,\n", + " latent_dim=args.latent_dim,\n", + " pretrained=args.pretrained,\n", + " )\n", + "\n", + " # lit_model = shapes.MaskEmbedLatentAugment(model, args)\n", + " lit_model = shapes.MaskEmbed(model, args)\n", + " test_data = dataset[0][0].unsqueeze(0)\n", + " # test_lit_data = 2*(dataset[0][0].unsqueeze(0).repeat_interleave(3, dim=1),)\n", + " test_output = lit_model.forward((test_data,))\n", + "\n", + " dataloader.setup()\n", + " model.eval()\n", + "\n", + " model_dir = f\"checkpoints/{hashing_fn(args)}\"\n", + "\n", + " tb_logger = pl_loggers.TensorBoardLogger(\"logs/\")\n", + " wandb = pl_loggers.WandbLogger(project=\"bioimage-embed\", name=\"shapes\")\n", + "\n", + " Path(f\"{model_dir}/\").mkdir(parents=True, exist_ok=True)\n", + "\n", + " checkpoint_callback = ModelCheckpoint(\n", + " dirpath=f\"{model_dir}/\",\n", + " save_last=True,\n", + " save_top_k=1,\n", + " monitor=\"loss/val\",\n", + " mode=\"min\",\n", + " )\n", + " wandb.watch(lit_model, log=\"all\")\n", + "\n", + " trainer = pl.Trainer(\n", + " logger=[wandb, tb_logger],\n", + " gradient_clip_val=0.5,\n", + " enable_checkpointing=True,\n", + " devices=1,\n", + " accelerator=\"gpu\",\n", + " accumulate_grad_batches=4,\n", + " callbacks=[checkpoint_callback],\n", + " min_epochs=50,\n", + " max_epochs=args.epochs,\n", + " # callbacks=[EarlyStopping(monitor=\"loss/val\", mode=\"min\")],\n", + " log_every_n_steps=1,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ba8a17d", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + " # Determine the checkpoint path for resuming\n", + " last_checkpoint_path = f\"{model_dir}/last.ckpt\"\n", + " best_checkpoint_path = checkpoint_callback.best_model_path\n", + "\n", + " # Check if a last checkpoint exists to resume from\n", + " if os.path.isfile(last_checkpoint_path):\n", + " resume_checkpoint = last_checkpoint_path\n", + " elif best_checkpoint_path and os.path.isfile(best_checkpoint_path):\n", + " resume_checkpoint = best_checkpoint_path\n", + " else:\n", + " resume_checkpoint = None\n", + "\n", + " trainer.fit(lit_model, datamodule=dataloader, ckpt_path=resume_checkpoint)\n", + "\n", + " lit_model.eval()\n", + "\n", + " validation = trainer.validate(lit_model, datamodule=dataloader)\n", + " # testing = trainer.test(lit_model, datamodule=dataloader)\n", + " example_input = Variable(torch.rand(1, *args.input_dim))\n", + "\n", + " # torch.jit.save(lit_model.to_torchscript(), f\"{model_dir}/model.pt\")\n", + " # torch.onnx.export(lit_model, example_input, f\"{model_dir}/model.onnx\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a480dd1", + "metadata": {}, + "outputs": [], + "source": [ + " # Inference on full dataset\n", + " dataloader = DataModule(\n", + " dataset,\n", + " batch_size=1,\n", + " shuffle=False,\n", + " num_workers=args.num_workers,\n", + " # Transform is commented here to avoid augmentations in real data\n", + " # HOWEVER, applying the transform multiple times and averaging the results might produce better latent embeddings\n", + " # transform=transform,\n", + " )\n", + " dataloader.setup()\n", + "\n", + " predictions = trainer.predict(lit_model, datamodule=dataloader)\n", + "\n", + " test_dist_pred = predictions[0].out.recon_x\n", + " plt.imsave(metadata(\"test_dist_pred.png\"), test_dist_pred.mean(axis=(0, 1)))\n", + " plt.close()\n", + "\n", + " test_dist_in = predictions[0].x.data\n", + " plt.imsave(metadata(\"test_dist_in.png\"), test_dist_in.mean(axis=(0, 1)))\n", + " plt.close()\n", + "\n", + " test_pred_coords = AsymmetricDistogramToCoordsPipeline(window_size=window_size)(\n", + " np.array(test_dist_pred[:, 0, :, :].unsqueeze(dim=0))\n", + " )\n", + "\n", + " plt.scatter(*test_pred_coords[0, 0].T)\n", + " # Save the plot as an image without border and coordinate axes\n", + " plt.savefig(metadata(\"test_pred_coords.png\"), bbox_inches=\"tight\", pad_inches=0)\n", + " plt.close()\n", + "\n", + " # Use the namespace variables\n", + " latent_space = torch.stack([d.out.z.flatten() for d in predictions])\n", + " scalings = torch.stack([d.x.scalings.flatten() for d in predictions])\n", + " idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()}\n", + " y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()])\n", + "\n", + " y_partial = y.copy()\n", + " indices = np.random.choice(y.size, int(0.3 * y.size), replace=False)\n", + " y_partial[indices] = -1\n", + " y_blind = -1 * np.ones_like(y)\n", + "\n", + " df = pd.DataFrame(latent_space.numpy())\n", + " df[\"Class\"] = pd.Series(y).map(idx_to_class).astype(\"category\")\n", + " df[\"Scale\"] = scalings[:, 0].squeeze()\n", + " df = df.set_index(\"Class\")\n", + " df_shape_embed = df.copy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3892a443", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + " X = df_shape_embed.to_numpy()\n", + " y = df_shape_embed.index\n", + "\n", + " properties = [\n", + " \"area\",\n", + " \"perimeter\",\n", + " \"centroid\",\n", + " \"major_axis_length\",\n", + " \"minor_axis_length\",\n", + " \"orientation\",\n", + " ]\n", + " dfs = []\n", + " # Distance matrix data\n", + " for i, data in enumerate(tqdm(train_data[\"transform_crop\"])):\n", + " X, y = data\n", + " # Do regionprops here\n", + " # Calculate shape summary statistics using regionprops\n", + " # We're considering that the mask has only one object, so we take the first element [0]\n", + " # props = regionprops(np.array(X).astype(int))[0]\n", + " props_table = measure.regionprops_table(\n", + " np.array(X).astype(int), properties=properties\n", + " )\n", + "\n", + " # Store shape properties in a dataframe\n", + " df = pd.DataFrame(props_table)\n", + "\n", + " # Assuming the class or label is contained in 'y' variable\n", + " df[\"class\"] = y\n", + " df.set_index(\"class\", inplace=True)\n", + " dfs.append(df)\n", + "\n", + " df_regionprops = pd.concat(dfs)\n", + "\n", + " dfs = []\n", + " for i, data in enumerate(tqdm(train_data[\"transform_coords\"])):\n", + " # Convert the tensor to a numpy array\n", + " X, y = data\n", + "\n", + " # Feed it to PyEFD's calculate_efd function\n", + " coeffs = pyefd.elliptic_fourier_descriptors(X, order=10, normalize=False)\n", + " # coeffs_df = pd.DataFrame({'class': [y], 'norm_coeffs': [norm_coeffs.flatten().tolist()]})\n", + "\n", + " norm_coeffs = pyefd.normalize_efd(coeffs)\n", + " df = pd.DataFrame(\n", + " {\n", + " \"norm_coeffs\": norm_coeffs.flatten().tolist(),\n", + " \"coeffs\": coeffs.flatten().tolist(),\n", + " }\n", + " ).T.rename_axis(\"coeffs\")\n", + " df[\"class\"] = y\n", + " df.set_index(\"class\", inplace=True, append=True)\n", + " dfs.append(df)\n", + "\n", + " df_pyefd = pd.concat(dfs)\n", + "\n", + " trials = [\n", + " {\n", + " \"name\": \"mask_embed\",\n", + " \"features\": df_shape_embed.to_numpy(),\n", + " \"labels\": df_shape_embed.index,\n", + " },\n", + " {\n", + " \"name\": \"fourier_coeffs\",\n", + " \"features\": df_pyefd.xs(\"coeffs\", level=\"coeffs\"),\n", + " \"labels\": df_pyefd.xs(\"coeffs\", level=\"coeffs\").index,\n", + " },\n", + " # {\"name\": \"fourier_norm_coeffs\",\n", + " # \"features\": df_pyefd.xs(\"norm_coeffs\", level=\"coeffs\"),\n", + " # \"labels\": df_pyefd.xs(\"norm_coeffs\", level=\"coeffs\").index\n", + " # }\n", + " {\n", + " \"name\": \"regionprops\",\n", + " \"features\": df_regionprops,\n", + " \"labels\": df_regionprops.index,\n", + " },\n", + " ]\n", + "\n", + " trial_df = pd.DataFrame()\n", + " for trial in trials:\n", + " X = trial[\"features\"]\n", + " y = trial[\"labels\"]\n", + " trial[\"score_df\"] = scoring_df(X, y)\n", + " trial[\"score_df\"][\"trial\"] = trial[\"name\"]\n", + " logger.info(trial[\"score_df\"])\n", + " trial[\"score_df\"].to_csv(metadata(f\"{trial['name']}_score_df.csv\"))\n", + " trial_df = pd.concat([trial_df, trial[\"score_df\"]])\n", + " trial_df = trial_df.drop([\"fit_time\", \"score_time\"], axis=1)\n", + "\n", + " trial_df.to_csv(metadata(\"trial_df.csv\"))\n", + " trial_df.groupby(\"trial\").mean().to_csv(metadata(\"trial_df_mean.csv\"))\n", + " trial_df.plot(kind=\"bar\")\n", + "\n", + " avg = trial_df.groupby(\"trial\").mean()\n", + " logger.info(avg)\n", + " avg.to_latex(metadata(\"trial_df.tex\"))\n", + "\n", + " melted_df = trial_df.melt(id_vars=\"trial\", var_name=\"Metric\", value_name=\"Score\")\n", + " # fig, ax = plt.subplots(figsize=(width, height))\n", + " ax = sns.catplot(\n", + " data=melted_df,\n", + " kind=\"bar\",\n", + " x=\"trial\",\n", + " hue=\"Metric\",\n", + " y=\"Score\",\n", + " errorbar=\"se\",\n", + " height=height,\n", + " aspect=width * 2**0.5 / height,\n", + " )\n", + " # ax.xtick_params(labelrotation=45)\n", + " # plt.legend(loc='lower center', bbox_to_anchor=(1, 1))\n", + " # sns.move_legend(ax, \"lower center\", bbox_to_anchor=(1, 1))\n", + " # ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n", + " # plt.tight_layout()\n", + " plt.savefig(metadata(\"trials_barplot.pdf\"))\n", + " plt.close()\n", + "\n", + " avs = (\n", + " melted_df.set_index([\"trial\", \"Metric\"])\n", + " .xs(\"test_f1\", level=\"Metric\", drop_level=False)\n", + " .groupby(\"trial\")\n", + " .mean()\n", + " )\n", + " logger.info(avs)\n", + " # tikzplotlib.save(metadata(f\"trials_barplot.tikz\"))\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " shape_embed_process()" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/_shape_embed.py b/scripts/_shape_embed.py index 5e8b11b9..755492e3 100644 --- a/scripts/_shape_embed.py +++ b/scripts/_shape_embed.py @@ -146,7 +146,7 @@ def shape_embed_process(): path = Path(metadata("")) path.mkdir(parents=True, exist_ok=True) model_dir = f"models/{dataset_path}_{args.model}" - # %% +# %% transform_crop = CropCentroidPipeline(window_size) transform_dist = MaskToDistogramPipeline( @@ -394,7 +394,7 @@ def shape_embed_process(): # plt.show() plt.close() - # %% +# %% X = df_shape_embed.to_numpy() y = df_shape_embed.index.values diff --git a/scripts/shape_embed.py b/scripts/shape_embed.py index f1064f7c..dad19802 100644 --- a/scripts/shape_embed.py +++ b/scripts/shape_embed.py @@ -213,7 +213,7 @@ def shape_embed_process(): path = Path(metadata("")) path.mkdir(parents=True, exist_ok=True) - # %% +# %% transform_crop = CropCentroidPipeline(window_size) # transform_dist = MaskToDistogramPipeline( @@ -372,7 +372,7 @@ def shape_embed_process(): # callbacks=[EarlyStopping(monitor="loss/val", mode="min")], log_every_n_steps=1, ) - # %% +# %% # Determine the checkpoint path for resuming last_checkpoint_path = f"{model_dir}/last.ckpt" @@ -446,7 +446,7 @@ def shape_embed_process(): df = df.set_index("Class") df_shape_embed = df.copy() - # %% +# %% X = df_shape_embed.to_numpy() y = df_shape_embed.index