
Significant Performance Gap Between MaskFormer and Mask2Former Despite Identical Training Code #35738

olmobaldoni opened this issue on Jan 16, 2025
System Info

Description

I observed a significant performance difference between MaskFormer and Mask2Former when training both models on my dataset for instance segmentation. The training code is identical except for the model-specific configurations. Below, I outline my setup, preprocessing steps, results, and relevant code to help pinpoint any potential issues.

Dataset

  • Task: Instance segmentation.
  • Format: each annotation is an RGB image whose channels encode the labels (see the sketch after this list):
    • The R channel contains the semantic class labels.
    • The G channel contains the instance IDs for each object.
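
For reference, here is a minimal sketch of how one such annotation decodes into a class map, an instance map, and an instance-to-class mapping (the file name is hypothetical; the same logic appears in the dataset code under Reproduction):

import numpy as np
from PIL import Image

# Hypothetical annotation file following the channel layout described above
annotation = np.array(Image.open("annotation.png").convert("RGB"))

class_id_map = annotation[:, :, 0]  # R channel: semantic class of each pixel
instance_seg = annotation[:, :, 1]  # G channel: instance ID of each pixel

# Map every instance ID to the semantic class of its pixels
inst2class = {}
for label in np.unique(class_id_map):
    instance_ids = np.unique(instance_seg[class_id_map == label])
    inst2class.update({i: label for i in instance_ids})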

Preprocessing

For both models, I used the following preprocessing configuration. The only difference lies in the model type and the specific pre-trained weights used.

For both MaskFormer and Mask2Former, I set:

  • do_reduce_labels=True
  • ignore_index=255

The purpose of do_reduce_labels=True is to ensure that class indices start from 0 and are incremented sequentially: it shifts all class indices by -1, as described in the Hugging Face documentation. The value 255 for ignore_index ensures that pixels labeled as background are ignored during loss computation.
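
Concretely, the image processor was instantiated as below (this is the MaskFormer variant taken from the reproduction code; the Mask2Former one only swaps the checkpoint name):

from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained(
    "facebook/maskformer-swin-small-coco",
    do_reduce_labels=True,  # shift class indices by -1 so they start from 0
    ignore_index=255,       # background pixels labeled 255 are ignored in the loss
    do_resize=False,        # resizing, rescaling, and normalization are handled
    do_rescale=False,       # by the albumentations pipeline instead
    do_normalize=False,
)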


Results

Both models were trained for 20 epochs with the same hyperparameters:

  • Learning rate: 5e-5
  • Optimizer: Adam

Test Image

[image]

Here are the results:

MaskFormer:

───────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           1.0081120729446411
        test_map           0.038004860281944275
       test_map_50          0.06367719173431396
       test_map_75         0.040859635919332504
     test_map_large         0.5004204511642456
     test_map_medium        0.04175732284784317
     test_map_small        0.007470746990293264
       test_mar_1           0.01011560671031475
       test_mar_10          0.05838150158524513
      test_mar_100          0.06329479813575745
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Test Image Result with MaskFormer:

[image]


Mask2Former:

───────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           15.374979972839355
        test_map            0.44928184151649475
       test_map_50          0.6224347949028015
       test_map_75          0.5011898279190063
     test_map_large         0.8390558958053589
     test_map_medium        0.6270320415496826
     test_map_small         0.32075226306915283
       test_mar_1           0.03526011481881142
       test_mar_10          0.24104046821594238
      test_mar_100          0.5274566411972046
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Test Image Result with Mask2Former:

[image]



Observations

As you can see, the performance gap between the two models is substantial, despite identical training setups and preprocessing pipelines. Mask2Former achieves significantly better performance in terms of mAP and other metrics, while MaskFormer struggles to achieve meaningful results.

Any insights or suggestions would be greatly appreciated. Thank you!


Who can help?

@amyeroberts, @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Relevant Code

train_maskformer.py
import os
import logging
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, MLFlowLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import torch
import numpy as np
from src.dataset_singlegpu import SegmentationDataModule  # Replace with src.dataset_distributed for multi-GPU training
from src.maskformer_singlegpu import MaskFormer  # Replace with src.maskformer_distributed for multi-GPU training

torch.backends.cudnn.benchmark = True
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)

def setup_logging():
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)

    file_handler = logging.FileHandler("instance_segmentation/maskformer/train.log")
    file_handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(stream_handler)
    root_logger.addHandler(file_handler)

    logger = logging.getLogger(__name__)
    return logger

def main():
    logger = setup_logging()

    dataset_dir = "data/instance_segmentation"  # Path to dataset directory

    dm = SegmentationDataModule(dataset_dir=dataset_dir, batch_size=2, num_workers=4)

    model = MaskFormer()

    # Create logs dir if it doesn't exist
    os.makedirs("logs/instance_segmentation", exist_ok=True)

    loggers = [
        CSVLogger(save_dir="logs/csv_logs", name="maskformer"),
        TensorBoardLogger(save_dir="logs/tb_logs", name="maskformer"),
        MLFlowLogger(
            experiment_name="maskformer",
            tracking_uri="file:logs/mlflow_logs",
        ),
    ]

    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/maskformer",
        monitor="val_map",
        filename="maskformer-{epoch:02d}-{val_map:.2f}",
        save_top_k=1,
        mode="max",
        save_last=True,
        verbose=True,
    )

    callbacks = [
        EarlyStopping(monitor="val_map", mode="max", patience=10),
        checkpoint_callback,
    ]

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=[0],  # Use [0] for single GPU, [0, 1] for multiple GPUs
        logger=loggers,
        callbacks=callbacks,
        min_epochs=1,
        max_epochs=20,
        precision="32-true",
        num_sanity_val_steps=0,
    )

    # Start training (resume if a checkpoint exists)
    if os.path.exists("checkpoints/maskformer/last.ckpt"):
        trainer.fit(
            model,
            dm,
            ckpt_path="checkpoints/maskformer/last.ckpt",
        )
    else:
        trainer.fit(model, dm)

    best_model_path = checkpoint_callback.best_model_path
    logger.info(f"Best model saved at: {best_model_path}")
    best_model = MaskFormer.load_from_checkpoint(checkpoint_path=best_model_path)

    trainer.test(best_model, dm)

if __name__ == "__main__":
    main()
train_mask2former.py
import os
import logging
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, MLFlowLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import torch
import numpy as np
from src.dataset_singlegpu import SegmentationDataModule  # Replace with src.dataset_distributed for multi-GPU training
from src.mask2former_singlegpu import Mask2Former  # Replace with src.mask2former_distributed for multi-GPU training

torch.backends.cudnn.benchmark = True
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)

def setup_logging():
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)

    file_handler = logging.FileHandler("logs/instance_segmentation/mask2former/train.log")
    file_handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(stream_handler)
    root_logger.addHandler(file_handler)

    logger = logging.getLogger(__name__)
    return logger

def main():
    logger = setup_logging()

    dataset_dir = "data/instance_segmentation"  # Path to dataset directory

    dm = SegmentationDataModule(dataset_dir=dataset_dir, batch_size=2, num_workers=4)

    model = Mask2Former()

    # Create logs dir if it doesn't exist
    os.makedirs("logs/instance_segmentation", exist_ok=True)

    loggers = [
        CSVLogger(save_dir="logs/csv_logs", name="mask2former"),
        TensorBoardLogger(save_dir="logs/tb_logs", name="mask2former"),
        MLFlowLogger(
            experiment_name="mask2former",
            tracking_uri="file:logs/mlflow_logs",
        ),
    ]

    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/mask2former",
        monitor="val_map",
        filename="mask2former-{epoch:02d}-{val_map:.2f}",
        save_top_k=1,
        mode="max",
        save_last=True,
        verbose=True,
    )

    callbacks = [
        EarlyStopping(monitor="val_map", mode="max", patience=10),
        checkpoint_callback,
    ]

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=[0],  # Use [0] for single GPU, [0, 1] for multiple GPUs
        logger=loggers,
        callbacks=callbacks,
        min_epochs=1,
        max_epochs=20,
        precision="32-true",
        num_sanity_val_steps=0,
    )

    # Start training (resume if a checkpoint exists)
    if os.path.exists("checkpoints/mask2former/last.ckpt"):
        trainer.fit(
            model,
            dm,
            ckpt_path="checkpoints/mask2former/last.ckpt",
        )
    else:
        trainer.fit(model, dm)

    best_model_path = checkpoint_callback.best_model_path
    logger.info(f"Best model saved at: {best_model_path}")
    best_model = Mask2Former.load_from_checkpoint(checkpoint_path=best_model_path)

    trainer.test(best_model, dm)

if __name__ == "__main__":
    main()
maskformer.py
import logging
import torch
import pytorch_lightning as pl
from transformers import AutoImageProcessor
from transformers import MaskFormerForInstanceSegmentation
from torchmetrics.detection import MeanAveragePrecision

logger = logging.getLogger(__name__)


class MaskFormer(pl.LightningModule):
    def __init__(self, learning_rate=5e-5):
        super().__init__()
        self.lr = learning_rate
        self.mAP = MeanAveragePrecision(iou_type="segm", class_metrics=True)
        # self.id2label = {0: "background", 1: "unhealty"}
        self.id2label = {0: "unhealty"}
        self.label2id = {v: int(k) for k, v in self.id2label.items()}
        self.processor = AutoImageProcessor.from_pretrained(
            "facebook/maskformer-swin-small-coco",
            do_reduce_labels=True,
            ignore_index=255,
            do_resize=False,
            do_rescale=False,
            do_normalize=False,
        )
        self.model = self.setup_model()
        self.save_hyperparameters()

    def setup_model(self):
        model = MaskFormerForInstanceSegmentation.from_pretrained(
            "facebook/maskformer-swin-small-coco",
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True,
        )
        model.train()
        return model

    def forward(self, pixel_values, mask_labels=None, class_labels=None):
        return self.model(
            pixel_values=pixel_values,
            mask_labels=mask_labels,
            class_labels=class_labels,
        )

    def compute_metrics(self, outputs, batch):
        # For metric computation we need to provide:
        # - targets as a list of dictionaries with keys "masks", "labels"
        # - predictions as a list of dictionaries with keys "masks", "labels", "scores"
        targets = []
        for masks, labels in zip(batch["mask_labels"], batch["class_labels"]):
            target = {
                "masks": masks.to("cuda").to(torch.bool),
                "labels": labels.to("cuda"),
            }
            targets.append(target)

        threshold = 0.5
        target_sizes = [
            (image.shape[1], image.shape[2]) for image in batch["pixel_values"]
        ]

        processed_outputs = self.processor.post_process_instance_segmentation(
            outputs=outputs,
            threshold=threshold,
            target_sizes=target_sizes,
            return_binary_maps=True,
        )

        preds = []
        for output, target_size in zip(processed_outputs, target_sizes):
            if output["segments_info"]:
                pred = {
                    "masks": output["segmentation"].to(dtype=torch.bool, device="cuda"),
                    "labels": torch.tensor(
                        [x["label_id"] for x in output["segments_info"]]
                    ).to("cuda"),
                    "scores": torch.tensor(
                        [x["score"] for x in output["segments_info"]]
                    ).to("cuda"),
                }
            else:
                # for void predictions, we need to provide empty tensors
                pred = {
                    "masks": torch.zeros([0, *target_size], dtype=torch.bool).to(
                        "cuda"
                    ),
                    "labels": torch.tensor([]).to("cuda"),
                    "scores": torch.tensor([]).to("cuda"),
                }
            preds.append(pred)

        return preds, targets

    def training_step(self, batch, batch_idx):
        outputs = self(
            pixel_values=batch["pixel_values"],
            mask_labels=batch["mask_labels"],
            class_labels=batch["class_labels"],
        )
        loss = outputs.loss
        self.log(
            "train_loss",
            loss,
            batch_size=batch["pixel_values"].shape[0],
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(
            pixel_values=batch["pixel_values"],
            mask_labels=batch["mask_labels"],
            class_labels=batch["class_labels"],
        )
        loss = outputs.loss
        self.log(
            "val_loss",
            loss,
            batch_size=batch["pixel_values"].shape[0],
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )
        preds, targets = self.compute_metrics(outputs, batch)
        self.mAP.update(preds, targets)
        return loss

    def on_validation_epoch_end(self):
        result = self.mAP.compute()
        self.log("val_map", result["map"])
        self.log("val_map_50", result["map_50"])
        self.log("val_map_75", result["map_75"])
        self.log("val_map_small", result["map_small"])
        self.log("val_map_medium", result["map_medium"])
        self.log("val_map_large", result["map_large"])
        self.log("val_mar_1", result["mar_1"])
        self.log("val_mar_10", result["mar_10"])
        self.log("val_mar_100", result["mar_100"])
        self.mAP.reset()

    def test_step(self, batch, batch_idx):
        outputs = self(
            pixel_values=batch["pixel_values"],
            mask_labels=batch["mask_labels"],
            class_labels=batch["class_labels"],
        )
        loss = outputs.loss
        self.log(
            "test_loss",
            loss,
            batch_size=batch["pixel_values"].shape[0],
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )
        preds, targets = self.compute_metrics(outputs, batch)
        self.mAP.update(preds, targets)
        return loss

    def on_test_epoch_end(self):
        result = self.mAP.compute()
        self.log("test_map", result["map"])
        self.log("test_map_50", result["map_50"])
        self.log("test_map_75", result["map_75"])
        self.log("test_map_small", result["map_small"])
        self.log("test_map_medium", result["map_medium"])
        self.log("test_map_large", result["map_large"])
        self.log("test_mar_1", result["mar_1"])
        self.log("test_mar_10", result["mar_10"])
        self.log("test_mar_100", result["mar_100"])
        self.mAP.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.lr,
        )
        return optimizer
mask2former.py
import logging
import torch
import pytorch_lightning as pl
from transformers import AutoImageProcessor
from transformers import Mask2FormerForUniversalSegmentation
from torchmetrics.detection import MeanAveragePrecision

logger = logging.getLogger(__name__)


class Mask2Former(pl.LightningModule):
    def __init__(self, learning_rate=5e-5):
        super().__init__()
        self.lr = learning_rate
        self.mAP = MeanAveragePrecision(iou_type="segm", class_metrics=True)
        # self.id2label = {0: "background", 1: "unhealty"}
        self.id2label = {0: "unhealty"}
        self.label2id = {v: int(k) for k, v in self.id2label.items()}
        self.processor = AutoImageProcessor.from_pretrained(
            "facebook/mask2former-swin-small-coco-instance",
            do_reduce_labels=True,
            ignore_index=255,
            do_resize=False,
            do_rescale=False,
            do_normalize=False,
        )
        self.model = self.setup_model()
        self.save_hyperparameters()

    def setup_model(self):
        model = Mask2FormerForUniversalSegmentation.from_pretrained(
            "facebook/mask2former-swin-small-coco-instance",
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True,
        )
        model.train()
        return model

    def forward(self, pixel_values, mask_labels=None, class_labels=None):
        return self.model(
            pixel_values=pixel_values,
            mask_labels=mask_labels,
            class_labels=class_labels,
        )

    def compute_metrics(self, outputs, batch):
        # For metric computation we need to provide:
        # - targets as a list of dictionaries with keys "masks", "labels"
        # - predictions as a list of dictionaries with keys "masks", "labels", "scores"
        targets = []
        for masks, labels in zip(batch["mask_labels"], batch["class_labels"]):
            target = {
                "masks": masks.to("cuda").to(dtype=torch.bool, device="cuda"),
                "labels": labels.to("cuda"),
            }
            targets.append(target)

        threshold = 0.5
        target_sizes = [
            (image.shape[1], image.shape[2]) for image in batch["pixel_values"]
        ]

        processed_outputs = self.processor.post_process_instance_segmentation(
            outputs=outputs,
            threshold=threshold,
            target_sizes=target_sizes,
            return_binary_maps=True,
        )

        # TODO: remove detach
        # detached_outputs = [
        #     {
        #         "segmentation": output["segmentation"].to("cpu"),
        #         "segments_info": output["segments_info"],
        #     }
        #     for output in processed_outputs
        # ]

        preds = []
        for output, target_size in zip(processed_outputs, target_sizes):
            if output["segments_info"]:
                pred = {
                    "masks": output["segmentation"].to(dtype=torch.bool, device="cuda"),
                    "labels": torch.tensor(
                        [x["label_id"] for x in output["segments_info"]], device="cuda"
                    ),
                    "scores": torch.tensor(
                        [x["score"] for x in output["segments_info"]], device="cuda"
                    ),
                }
            else:
                # for void predictions, we need to provide empty tensors
                pred = {
                    "masks": torch.zeros(
                        [0, *target_size], dtype=torch.bool, device="cuda"
                    ),
                    "labels": torch.tensor([], device="cuda"),
                    "scores": torch.tensor([], device="cuda"),
                }
            preds.append(pred)

        return preds, targets

    def training_step(self, batch, batch_idx):
        outputs = self(
            pixel_values=batch["pixel_values"],
            mask_labels=batch["mask_labels"],
            class_labels=batch["class_labels"],
        )
        loss = outputs.loss
        self.log(
            "train_loss",
            loss,
            batch_size=batch["pixel_values"].shape[0],
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(
            pixel_values=batch["pixel_values"],
            mask_labels=batch["mask_labels"],
            class_labels=batch["class_labels"],
        )
        loss = outputs.loss
        self.log(
            "val_loss",
            loss,
            batch_size=batch["pixel_values"].shape[0],
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )
        preds, targets = self.compute_metrics(outputs, batch)
        self.mAP.update(preds, targets)
        return loss

    def on_validation_epoch_end(self):
        result = self.mAP.compute()
        self.log("val_map", result["map"])
        self.log("val_map_50", result["map_50"])
        self.log("val_map_75", result["map_75"])
        self.log("val_map_small", result["map_small"])
        self.log("val_map_medium", result["map_medium"])
        self.log("val_map_large", result["map_large"])
        self.log("val_mar_1", result["mar_1"])
        self.log("val_mar_10", result["mar_10"])
        self.log("val_mar_100", result["mar_100"])
        self.mAP.reset()

    def test_step(self, batch, batch_idx):
        outputs = self(
            pixel_values=batch["pixel_values"],
            mask_labels=batch["mask_labels"],
            class_labels=batch["class_labels"],
        )
        loss = outputs.loss
        self.log(
            "test_loss",
            loss,
            batch_size=batch["pixel_values"].shape[0],
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )
        preds, targets = self.compute_metrics(outputs, batch)
        self.mAP.update(preds, targets)
        return loss

    def on_test_epoch_end(self):
        result = self.mAP.compute()
        self.log("test_map", result["map"])
        self.log("test_map_50", result["map_50"])
        self.log("test_map_75", result["map_75"])
        self.log("test_map_small", result["map_small"])
        self.log("test_map_medium", result["map_medium"])
        self.log("test_map_large", result["map_large"])
        self.log("test_mar_1", result["mar_1"])
        self.log("test_mar_10", result["mar_10"])
        self.log("test_mar_100", result["mar_100"])
        self.mAP.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.lr,
        )
        return optimizer
dataset_maskformer.py
import os
import logging
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import datasets
from transformers import AutoImageProcessor
import albumentations as A

logger = logging.getLogger(__name__)


class SegmentationDataset(Dataset):
    def __init__(self, dataset: datasets.Dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        self.processor = AutoImageProcessor.from_pretrained(
            "facebook/maskformer-swin-small-coco",
            do_reduce_labels=True,
            ignore_index=255,
            do_resize=False,
            do_rescale=False,
            do_normalize=False,
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = np.array(self.dataset[idx]["image"].convert("RGB"))
        mask = np.array(self.dataset[idx]["annotation"].convert("RGB"))

        class_id_map = mask[:, :, 0]
        instance_seg = mask[:, :, 1]

        class_labels = np.unique(class_id_map)

        inst2class = {}
        for label in class_labels:
            instance_ids = np.unique(instance_seg[class_id_map == label])
            inst2class.update({i: label for i in instance_ids})

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=instance_seg)
            image = augmentations["image"]
            instance_seg = augmentations["mask"]

        inputs = self.processor(
            [image],
            [instance_seg],
            instance_id_to_semantic_id=inst2class,
            return_tensors="pt",
        )

        return {
            "pixel_values": inputs.pixel_values[0],
            "mask_labels": inputs.mask_labels[0],
            "class_labels": inputs.class_labels[0],
        }


class SegmentationDataModule(pl.LightningDataModule):
    def __init__(self, dataset_dir, batch_size, num_workers):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = None

        # ImageNet mean and std
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        self.train_transform = A.Compose(
            [
                A.Resize(height=512, width=512),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                A.Normalize(mean=self.mean, std=self.std),
            ]
        )

        self.val_test_transform = A.Compose(
            [
                A.Resize(height=512, width=512),
                A.Normalize(mean=self.mean, std=self.std),
            ]
        )

    def prepare_data(self):
        if os.path.exists(self.dataset_dir):
            if os.path.isdir(self.dataset_dir):
                try:
                    self.dataset = datasets.load_from_disk(self.dataset_dir)
                    logger.info(f"Loaded dataset from disk: {self.dataset_dir}")
                except Exception as e:
                    logger.info(f"Failed to load dataset from disk: {e}")

    def collate_fn(self, examples):
        batch = {}
        batch["pixel_values"] = torch.stack(
            [example["pixel_values"] for example in examples]
        )
        batch["class_labels"] = [example["class_labels"] for example in examples]
        batch["mask_labels"] = [example["mask_labels"] for example in examples]
        return batch

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            logger.info("Setting up training dataset")
            self.train_dataset = SegmentationDataset(
                dataset=self.dataset["train"], transform=self.train_transform
            )
            logger.info("Setting up validation dataset")
            self.val_dataset = SegmentationDataset(
                dataset=self.dataset["validation"], transform=self.val_test_transform
            )
        if stage == "test" or stage is None:
            logger.info("Setting up test dataset")
            self.test_dataset = SegmentationDataset(
                dataset=self.dataset["test"], transform=self.val_test_transform
            )

    def train_dataloader(self):
        logger.info("Creating training DataLoader")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        logger.info("Creating val DataLoader")
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        logger.info("Creating test DataLoader")
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=self.collate_fn,
        )
dataset_mask2former.py
import os
import logging
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import datasets
from transformers import AutoImageProcessor
import albumentations as A

logger = logging.getLogger(__name__)


class SegmentationDataset(Dataset):
    def __init__(self, dataset: datasets.Dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        self.processor = AutoImageProcessor.from_pretrained(
            "facebook/mask2former-swin-small-coco-instance",
            do_reduce_labels=True,
            ignore_index=255,
            do_resize=False,
            do_rescale=False,
            do_normalize=False,
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = np.array(self.dataset[idx]["image"].convert("RGB"))
        mask = np.array(self.dataset[idx]["annotation"].convert("RGB"))

        class_id_map = mask[:, :, 0]
        instance_seg = mask[:, :, 1]

        class_labels = np.unique(class_id_map)

        inst2class = {}
        for label in class_labels:
            instance_ids = np.unique(instance_seg[class_id_map == label])
            inst2class.update({i: label for i in instance_ids})

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=instance_seg)
            image = augmentations["image"]
            instance_seg = augmentations["mask"]

        inputs = self.processor(
            [image],
            [instance_seg],
            instance_id_to_semantic_id=inst2class,
            return_tensors="pt",
        )

        return {
            "pixel_values": inputs.pixel_values[0],
            "mask_labels": inputs.mask_labels[0],
            "class_labels": inputs.class_labels[0],
        }


class SegmentationDataModule(pl.LightningDataModule):
    def __init__(self, dataset_dir, batch_size, num_workers):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = None

        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        self.train_transform = A.Compose(
            [
                A.Resize(height=512, width=512),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                A.Normalize(mean=self.mean, std=self.std),
            ]
        )

        self.val_test_transform = A.Compose(
            [
                A.Resize(height=512, width=512),
                A.Normalize(mean=self.mean, std=self.std),
            ]
        )

    def prepare_data(self):
        if os.path.exists(self.dataset_dir):
            if os.path.isdir(self.dataset_dir):
                try:
                    self.dataset = datasets.load_from_disk(self.dataset_dir)
                    logger.info(f"Loaded dataset from disk: {self.dataset_dir}")
                except Exception as e:
                    logger.info(f"Failed to load dataset from disk: {e}")

    def collate_fn(self, examples):
        batch = {}
        batch["pixel_values"] = torch.stack(
            [example["pixel_values"] for example in examples]
        )
        batch["class_labels"] = [example["class_labels"] for example in examples]
        batch["mask_labels"] = [example["mask_labels"] for example in examples]
        return batch

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            logger.info("Setting up training dataset")
            self.train_dataset = SegmentationDataset(
                dataset=self.dataset["train"], transform=self.train_transform
            )
            logger.info("Setting up validation dataset")
            self.val_dataset = SegmentationDataset(
                dataset=self.dataset["validation"], transform=self.val_test_transform
            )
        if stage == "test" or stage is None:
            logger.info("Setting up test dataset")
            self.test_dataset = SegmentationDataset(
                dataset=self.dataset["test"], transform=self.val_test_transform
            )

    def train_dataloader(self):
        logger.info("Creating training DataLoader")
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        logger.info("Creating val DataLoader")
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        logger.info("Creating test DataLoader")
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=self.collate_fn,
        )
requirements.txt
albumentations>=2.0.0
datasets>=3.2.0
deepspeed>=0.16.2
evaluate>=0.4.3
google-cloud-storage>=2.18.2
ijson>=3.3.0
ipykernel>=6.29.5
ipywidgets>=8.1.5
label-studio-sdk>=1.0.7
lxml>=5.3.0
mlflow>=2.19.0
opencv-python>=4.10.0.84
pillow>=11.0.0
protobuf>=5.28.3
pytorch-lightning>=2.5.0.post0
pyyaml>=6.0.2
ruff>=0.8.4
tensorboard>=2.18.0
torchmetrics[detection]>=1.6.0
torchvision>=0.20.1
transformers>=4.48.0

Expected behavior

MaskFormer and Mask2Former should achieve comparable results when trained with identical data, preprocessing, and hyperparameters.
