Skip to content

Commit

Permalink
Fix SeCo transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Mar 18, 2023
1 parent 9b5f617 commit 397265a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 50 deletions.
65 changes: 23 additions & 42 deletions torchgeo/datamodules/seco.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from typing import Any

import kornia.augmentation as K
import torch
from einops import repeat

from ..datasets import SeasonalContrastS2
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand All @@ -18,40 +20,6 @@ class SeasonalContrastS2DataModule(NonGeoDataModule):
.. versionadded:: 0.5
"""

# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/bigearthnet_dataset.py#L13 # noqa: E501
mean = torch.tensor(
[
340.76769064,
429.9430203,
614.21682446,
590.23569706,
950.68368468,
1792.46290469,
2075.46795189,
2218.94553375,
2266.46036911,
2246.0605464,
1594.42694882,
1009.32729131,
]
)
std = 2 * torch.tensor(
[
554.81258967,
572.41639287,
582.87945694,
675.88746967,
729.89827633,
1096.01480586,
1273.45393088,
1365.45589904,
1356.13789355,
1302.3292881,
1079.19066363,
818.86747235,
]
)

def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
Expand All @@ -63,17 +31,30 @@ def __init__(
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.SeasonalContrastS2`.
"""
bands = kwargs.get("bands", SeasonalContrastS2.rgb_bands)
all_bands = SeasonalContrastS2.all_bands
indices = [all_bands.index(band) for band in bands]
self.mean = self.mean[indices]
self.std = self.std[indices]
super().__init__(SeasonalContrastS2, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", SeasonalContrastS2.rgb_bands)
seasons = kwargs.get("seasons", 1)
self.mean = repeat(self.mean, "c -> (t c)", t=seasons)
self.std = repeat(self.std, "c -> (t c)", t=seasons)

super().__init__(SeasonalContrastS2, batch_size, num_workers, **kwargs)
# Normalization only available for RGB dataset
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
if bands == SeasonalContrastS2.rgb_bands:
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])

_min = repeat(_min, "c -> (t c)", t=seasons)
_max = repeat(_max, "c -> (t c)", t=seasons)
_mean = repeat(_mean, "c -> (t c)", t=seasons)
_std = repeat(_std, "c -> (t c)", t=seasons)

self.aug = AugmentationSequential(
K.Normalize(mean=_min, std=_max - _min),
K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
K.Normalize(mean=_mean, std=_std),
data_keys=["image"],
)

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down
20 changes: 12 additions & 8 deletions torchgeo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
__all__ = ["ResNet50_Weights", "ResNet18_Weights"]


# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = AugmentationSequential(
K.Resize(256),
Expand All @@ -26,13 +26,17 @@
data_keys=["image"],
)

# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/bigearthnet_dataset.py#L13 # noqa: E501
# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501
_min = torch.tensor([3, 2, 0])
_max = torch.tensor([88, 103, 129])
_mean = torch.tensor([0.485, 0.456, 0.406])
_std = torch.tensor([0.229, 0.224, 0.225])
_seco_transforms = AugmentationSequential(
K.Resize(128),
K.Normalize(
mean=torch.tensor([590.23569706, 614.21682446, 429.9430203]),
std=2 * torch.tensor([675.88746967, 582.87945694, 572.41639287]),
),
K.Resize(256),
K.CenterCrop(224),
K.Normalize(mean=_min, std=_max - _min),
K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
K.Normalize(mean=_mean, std=_std),
data_keys=["image"],
)

Expand Down

0 comments on commit 397265a

Please sign in to comment.