Skip to content

Commit

Permalink
Models: fix preprocessing transforms (microsoft#1166)
Browse files Browse the repository at this point in the history
* Models: fix preprocessing transforms

* Fix normalization of SeCo std dev

* black

* Fix SeCo transforms

* Add comment explaining source of transforms
  • Loading branch information
adamjstewart authored Mar 23, 2023
1 parent 9f2c44e commit cd1d921
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 82 deletions.
10 changes: 10 additions & 0 deletions tests/models/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def test_resnet(self) -> None:
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet18(weights=mocked_weights)

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta["in_chans"]
sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)}
mocked_weights.transforms(sample)

@pytest.mark.slow
def test_resnet_download(self, weights: WeightsEnum) -> None:
resnet18(weights=weights)
Expand Down Expand Up @@ -75,6 +80,11 @@ def test_resnet(self) -> None:
def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
resnet50(weights=mocked_weights)

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta["in_chans"]
sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)}
mocked_weights.transforms(sample)

@pytest.mark.slow
def test_resnet_download(self, weights: WeightsEnum) -> None:
resnet50(weights=weights)
5 changes: 5 additions & 0 deletions tests/models/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def test_vit(self) -> None:
def test_vit_weights(self, mocked_weights: WeightsEnum) -> None:
vit_small_patch16_224(weights=mocked_weights)

def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta["in_chans"]
sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)}
mocked_weights.transforms(sample)

@pytest.mark.slow
def test_vit_download(self, weights: WeightsEnum) -> None:
vit_small_patch16_224(weights=weights)
65 changes: 23 additions & 42 deletions torchgeo/datamodules/seco.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from typing import Any

import kornia.augmentation as K
import torch
from einops import repeat

from ..datasets import SeasonalContrastS2
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand All @@ -18,40 +20,6 @@ class SeasonalContrastS2DataModule(NonGeoDataModule):
.. versionadded:: 0.5
"""

# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/bigearthnet_dataset.py#L13 # noqa: E501
mean = torch.tensor(
[
340.76769064,
429.9430203,
614.21682446,
590.23569706,
950.68368468,
1792.46290469,
2075.46795189,
2218.94553375,
2266.46036911,
2246.0605464,
1594.42694882,
1009.32729131,
]
)
std = 2 * torch.tensor(
[
554.81258967,
572.41639287,
582.87945694,
675.88746967,
729.89827633,
1096.01480586,
1273.45393088,
1365.45589904,
1356.13789355,
1302.3292881,
1079.19066363,
818.86747235,
]
)

def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
Expand All @@ -63,17 +31,30 @@ def __init__(
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.SeasonalContrastS2`.
"""
bands = kwargs.get("bands", SeasonalContrastS2.rgb_bands)
all_bands = SeasonalContrastS2.all_bands
indices = [all_bands.index(band) for band in bands]
self.mean = self.mean[indices]
self.std = self.std[indices]
super().__init__(SeasonalContrastS2, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", SeasonalContrastS2.rgb_bands)
seasons = kwargs.get("seasons", 1)
self.mean = repeat(self.mean, "c -> (t c)", t=seasons)
self.std = repeat(self.std, "c -> (t c)", t=seasons)

super().__init__(SeasonalContrastS2, batch_size, num_workers, **kwargs)
# Normalization only available for RGB dataset, defined here:
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
if bands == SeasonalContrastS2.rgb_bands:
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])

_min = repeat(_min, "c -> (t c)", t=seasons)
_max = repeat(_max, "c -> (t c)", t=seasons)
_mean = repeat(_mean, "c -> (t c)", t=seasons)
_std = repeat(_std, "c -> (t c)", t=seasons)

self.aug = AugmentationSequential(
K.Normalize(mean=_min, std=_max - _min),
K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
K.Normalize(mean=_mean, std=_std),
data_keys=["image"],
)

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
53 changes: 14 additions & 39 deletions torchgeo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,28 @@
__all__ = ["ResNet50_Weights", "ResNet18_Weights"]


# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=0, std=10000),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)),
data_keys=["image"],
)

# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/bigearthnet_dataset.py#L13 # noqa: E501
# Normalization only available for RGB dataset, defined here:
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])
_seco_transforms = AugmentationSequential(
K.Resize(128),
K.Normalize(
mean=torch.Tensor(
[
340.76769064,
429.9430203,
614.21682446,
590.23569706,
950.68368468,
1792.46290469,
2075.46795189,
2218.94553375,
2266.46036911,
2246.0605464,
1594.42694882,
1009.32729131,
]
),
std=torch.Tensor(
[
554.81258967,
572.41639287,
582.87945694,
675.88746967,
729.89827633,
1096.01480586,
1273.45393088,
1365.45589904,
1356.13789355,
1302.3292881,
1079.19066363,
818.86747235,
]
),
),
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=_min, std=_max - _min),
K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
K.Normalize(mean=_mean, std=_std),
data_keys=["image"],
)

Expand Down
3 changes: 2 additions & 1 deletion torchgeo/models/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import kornia.augmentation as K
import timm
import torch
from timm.models.vision_transformer import VisionTransformer
from torchvision.models._api import Weights, WeightsEnum

Expand All @@ -20,7 +21,7 @@
_zhu_xlab_transforms = AugmentationSequential(
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=0, std=10000),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)),
data_keys=["image"],
)

Expand Down

0 comments on commit cd1d921

Please sign in to comment.