fix: Update anomalib, fix classification model checkpoint load (#120)
* build: Upgrade version, add anomalib rc

* feat: Add a way to load pretrained checkpoints

* build: Update anomalib version

* build: Remove prepatch tag

* docs: Update changelog
lorenzomammana authored Jun 5, 2024
1 parent 7c4aa31 commit 922621d
Showing 5 changed files with 35 additions and 8 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,13 @@
 # Changelog
 All notable changes to this project will be documented in this file.
 
+### [2.1.9]
+
+#### Updated
+
+- Update anomalib to v0.7.0+obx.1.3.3
+- Update network builders to support loading model checkpoints from disk
+
 ### [2.1.8]
 
 #### Added
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "quadra"
-version = "2.1.8"
+version = "2.1.9"
 description = "Deep Learning experiment orchestration library"
 authors = [
     "Federico Belotti <[email protected]>",
@@ -77,7 +77,7 @@ h5py = "~3.8"
 timm = "0.9.12"
 # Right now only this ref supports timm 0.9.12
 segmentation_models_pytorch = { git = "https://github.com/qubvel/segmentation_models.pytorch", rev = "7b381f899ed472a477a89d381689caf535b5d0a6" }
-anomalib = { git = "https://github.com/orobix/anomalib.git", tag = "v0.7.0+obx.1.3.2" }
+anomalib = { git = "https://github.com/orobix/anomalib.git", tag = "v0.7.0+obx.1.3.3" }
 xxhash = "~3.2"
 torchinfo = "~1.8"
 typing_extensions = { version = "4.11.0", python = "<3.10" }
2 changes: 1 addition & 1 deletion quadra/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.1.8"
+__version__ = "2.1.9"
 
 
 def get_version():
22 changes: 21 additions & 1 deletion quadra/models/classification/backbones.py
@@ -4,10 +4,14 @@
 
 import timm
 import torch
+from timm.models.helpers import load_checkpoint
 from torch import nn
 from torchvision import models
 
 from quadra.models.classification.base import BaseNetworkBuilder
+from quadra.utils.logger import get_logger
+
+log = get_logger(__name__)
 
 
 class TorchHubNetworkBuilder(BaseNetworkBuilder):
@@ -22,6 +26,7 @@ class TorchHubNetworkBuilder(BaseNetworkBuilder):
         freeze: Whether to freeze the feature extractor. Defaults to True.
         hyperspherical: Whether to map features to an hypersphere. Defaults to False.
         flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
         **torch_hub_kwargs: Additional arguments to pass to torch.hub.load
     """
 
@@ -35,12 +40,17 @@ def __init__(
         freeze: bool = True,
         hyperspherical: bool = False,
         flatten_features: bool = True,
+        checkpoint_path: str | None = None,
         **torch_hub_kwargs: Any,
     ):
         self.pretrained = pretrained
         features_extractor = torch.hub.load(
             repo_or_dir=repo_or_dir, model=model_name, pretrained=self.pretrained, **torch_hub_kwargs
         )
+        if checkpoint_path:
+            log.info("Loading checkpoint from %s", checkpoint_path)
+            load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)
+
         super().__init__(
             features_extractor=features_extractor,
             pre_classifier=pre_classifier,
@@ -62,6 +72,7 @@ class TorchVisionNetworkBuilder(BaseNetworkBuilder):
         freeze: Whether to freeze the feature extractor. Defaults to True.
         hyperspherical: Whether to map features to an hypersphere. Defaults to False.
         flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
         **torchvision_kwargs: Additional arguments to pass to the model function.
     """
 
@@ -74,11 +85,16 @@ def __init__(
         freeze: bool = True,
         hyperspherical: bool = False,
         flatten_features: bool = True,
+        checkpoint_path: str | None = None,
         **torchvision_kwargs: Any,
     ):
         self.pretrained = pretrained
         model_function = models.__dict__[model_name]
         features_extractor = model_function(pretrained=self.pretrained, progress=True, **torchvision_kwargs)
+        if checkpoint_path:
+            log.info("Loading checkpoint from %s", checkpoint_path)
+            load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)
+
         # Remove classifier
         features_extractor.classifier = nn.Identity()
         super().__init__(
@@ -102,6 +118,7 @@ class TimmNetworkBuilder(BaseNetworkBuilder):
         freeze: Whether to freeze the feature extractor. Defaults to True.
         hyperspherical: Whether to map features to an hypersphere. Defaults to False.
         flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
         **timm_kwargs: Additional arguments to pass to timm.create_model
     """
 
@@ -114,10 +131,13 @@ def __init__(
         freeze: bool = True,
         hyperspherical: bool = False,
         flatten_features: bool = True,
+        checkpoint_path: str | None = None,
         **timm_kwargs: Any,
     ):
         self.pretrained = pretrained
-        features_extractor = timm.create_model(model_name, pretrained=self.pretrained, num_classes=0, **timm_kwargs)
+        features_extractor = timm.create_model(
+            model_name, pretrained=self.pretrained, num_classes=0, checkpoint_path=checkpoint_path, **timm_kwargs
+        )
 
         super().__init__(
             features_extractor=features_extractor,
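
For context, here is a minimal usage sketch of the new checkpoint_path argument. The builder name and parameter names come from the diff above; the architecture name, the checkpoint path, and the assumption that the remaining constructor arguments keep their defaults are illustrative, not part of this commit.

# Hypothetical usage of the checkpoint_path argument added in this commit.
# "resnet18" and the .pth path below are placeholder values.
from quadra.models.classification.backbones import TimmNetworkBuilder

backbone = TimmNetworkBuilder(
    model_name="resnet18",  # any architecture known to timm
    pretrained=False,  # skip downloading hub weights when loading from disk
    freeze=True,  # keep the feature extractor frozen (the default)
    checkpoint_path="checkpoints/finetuned_resnet18.pth",  # forwarded to timm.create_model
)

Note the design split visible in the diff: TimmNetworkBuilder delegates checkpoint loading to timm.create_model via its checkpoint_path argument, while TorchHubNetworkBuilder and TorchVisionNetworkBuilder construct the model first and then apply timm's load_checkpoint helper to it.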
