modified training scripts
CrohnEngineer committed Dec 1, 2023
1 parent 0ad757d commit 68cd281
Showing 4 changed files with 66 additions and 37 deletions.
21 changes: 21 additions & 0 deletions README.md
@@ -58,4 +58,25 @@ The dataset is composed of 2 folders:

To train the model, you first have to divide the dataset into training, validation, and test splits.
You can do this by running the [`notebooks/Training dataset creation.ipynb`](notebooks/Training%20dataset%20creation.ipynb) notebook.
**Please note** that these splits and patches are the ones used in the paper; you can create your own by modifying the notebook.

If you want to inspect the raw products, a starting point is the [Raw satellite products processing](notebooks/Raw%20satellite%20products%20processing.ipynb) notebook.

# The whole pipeline
## Normalization strategies
All the normalization strategies used in the paper are provided as classes in the [`isplutils/data.py`](isplutils/data.py) file.
Please note that for the `MinPMax` strategy we used the [RobustScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html) implementation from `sklearn`.
Statistics are learned from the training set and then applied to the validation and test sets, as sketched below.
We provide the scalers used in the paper, one for each satellite product, inside the folders of `pristine_images/full_res_products`.
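
A minimal sketch of this fit-on-train, apply-everywhere workflow; the percentile range, array shapes, and scaler file name below are illustrative assumptions, and the actual class interface lives in `isplutils/data.py`:

```python
import numpy as np
from joblib import load  # the provided scalers are assumed to be joblib-serialized
from sklearn.preprocessing import RobustScaler

# Fit the scaler on training pixels only (hypothetical 99th-percentile variant)
train_pixels = np.random.rand(10_000, 3)  # placeholder for flattened training patches
scaler = RobustScaler(quantile_range=(1.0, 99.0))
scaler.fit(train_pixels)

# Apply the same learned statistics to validation/test data
val_pixels = np.random.rand(500, 3)
val_scaled = scaler.transform(val_pixels)

# Or load one of the scalers shipped with the repo (hypothetical file name,
# check the folders inside data/pristine_images/full_res_products)
# scaler = load('data/pristine_images/full_res_products/<product>/scaler.joblib')
```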

## Model training
The `train_fe.py` script takes care of training the models.
You can find the network definition in the [`isplutils/network.py`](isplutils/network.py) file.
All the hyperparameters for training are listed in the training script.
To replicate the models used in the paper, run the [train_all.sh](bash_scripts/train_all.sh) bash script.

## Model evaluation
The `data/spliced_images` folder contains the two datasets used in the paper, i.e.:
1. `Standard Generated Dataset (SGD)`: images generated by simply normalizing the dynamics between 0 and 1 using maximum scaling;
2. `Histogram Equalized Generated Dataset (HEGD)`: images generated by equalizing the histogram of the images towards a uniform distribution.
Inside each folder there is a Pandas DataFrame containing info on the images; a rough sketch of the two normalizations is given below.
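
For reference, a rough sketch of what the two normalizations look like on a single multi-band array. The exact generation pipeline is the repo's; only the `skimage.exposure.equalize_hist` call also appears in `isplutils/data.py`, and the input below is a placeholder:

```python
import numpy as np
import skimage.exposure

def max_scaling(img: np.ndarray) -> np.ndarray:
    # SGD-style: squeeze the dynamics into [0, 1] by dividing by the maximum
    return img / img.max()

def uniform_equalization(img: np.ndarray) -> np.ndarray:
    # HEGD-style: equalize the histogram towards a uniform distribution
    return skimage.exposure.equalize_hist(img)

img = np.random.rand(256, 256, 3) * 4096  # placeholder for a raw satellite patch
sgd_img = max_scaling(img)
hegd_img = uniform_equalization(img)
```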
42 changes: 42 additions & 0 deletions bash_scripts/train_all.sh
@@ -0,0 +1,42 @@
#!/usr/bin/env bash

# USER PARAMETERS (put your device configuration params here)
DEVICE=0
TRAIN_DIR=../data/pristine_images/train_patches # path to training patches (YOUR PATH MAY BE DIFFERENT IF YOU SAVED THE SPLITS IN ANOTHER FOLDER, CHECK IT)
VAL_DIR=../data/pristine_images/val_patches # path to validation patches (YOUR PATH MAY BE DIFFERENT IF YOU SAVED THE SPLITS IN ANOTHER FOLDER, CHECK IT)

echo ""
echo "-------------------------------------------------"
echo "| Train with MinPMax 99th percentile threshold |"
echo "-------------------------------------------------"
python ../train.py --gpu $DEVICE --batch_size 10 --num_iteration 128 --learning_rate 0.0001 \
--epochs 500 --train_dir $TRAIN_DIR --val_dir $VAL_DIR --num_tiles_peracq 200 --batch_num_tiles_peracq 10 \
--batch_num_pos_pertile 6 --scaler_type 99th_percentile --mean_robust_scaling --input_fp_channels 3 \
--output_fp_channels 3

echo ""
echo "-------------------------------------------------"
echo "| Train with MinPMax 95th percentile threshold |"
echo "-------------------------------------------------"
python ../train.py --gpu $DEVICE --batch_size 10 --num_iteration 128 --learning_rate 0.0001 \
--epochs 500 --train_dir $TRAIN_DIR --val_dir $VAL_DIR --num_tiles_peracq 200 --batch_num_tiles_peracq 10 \
--batch_num_pos_pertile 6 --scaler_type 95th_percentile --mean_robust_scaling --input_fp_channels 3 \
--output_fp_channels 3

echo ""
echo "-------------------------------------------------"
echo "| Train with MaxAbs scaling |"
echo "-------------------------------------------------"
python ../train.py --gpu $DEVICE --batch_size 10 --num_iteration 128 --learning_rate 0.0001 \
--epochs 500 --train_dir $TRAIN_DIR --val_dir $VAL_DIR --num_tiles_peracq 200 --batch_num_tiles_peracq 10 \
--batch_num_pos_pertile 6 --scaler_type sat_tiles_scaler --input_norm max_scaling --input_fp_channels 3 \
--output_fp_channels 3

echo ""
echo "-------------------------------------------------"
echo "| Train with HistogramEqualization scaling |"
echo "-------------------------------------------------"
python ../train.py --gpu $DEVICE --batch_size 10 --num_iteration 128 --learning_rate 0.0001 \
--epochs 500 --train_dir $TRAIN_DIR --val_dir $VAL_DIR --num_tiles_peracq 200 --batch_num_tiles_peracq 10 \
--batch_num_pos_pertile 6 --scaler_type sat_tiles_scaler --input_norm uniform_scaling --input_fp_channels 3 \
--output_fp_channels 3
29 changes: 0 additions & 29 deletions isplutils/data.py
@@ -21,7 +21,6 @@
from joblib import load
import os
from abc import ABC, abstractmethod
import imgaug.augmenters as iaa
import skimage.exposure
from typing import List

@@ -354,24 +353,6 @@ def __init__(self, batch_size, patch_size, data_dir, num_iteration, num_pos_pert
self.input_norm = input_norm
self.he_norm = he_norm
self.gray_scale = gray_scale
self.p_aug = p_aug
if p_aug > 0:
# Let's keep contrast augmentation aside for a moment
# self.augs = iaa.Sometimes(p_aug, iaa.OneOf([iaa.SigmoidContrast(gain=(5, 20), cutoff=(0.25, 0.75)),
# iaa.LogContrast(gain=(0.6, 1.4)),
# iaa.LinearContrast((0.4, 1.6))]))
            # Let's try adding uniform equalization
# self.augs = iaa.Sometimes(p_aug,
# iaa.Lambda(lambda x, random_state, parents, hooks:
# [SatTilesScaler().normalize_product(image, 'uniform_scaling',
# self.mean_scaling_strategy)
# for image in x],
# lambda x, random_state, parents, hooks: x))
self.augs = iaa.Sometimes(p_aug,
iaa.Lambda(lambda x, random_state, parents, hooks:
[skimage.exposure.equalize_hist(image, nbins=50)
for image in x],
lambda x, random_state, parents, hooks: x))

# Load the data (first, set the random split for loading the patches)
self.split_seed = np.random.random_integers(42) if split_seed is None else split_seed
@@ -523,16 +504,6 @@ def __init__(self, batch_size, patch_size, data_dir, num_iteration, num_pos_pert
self.input_norm = input_norm
self.he_norm = he_norm
self.gray_scale = gray_scale
self.p_aug = p_aug
if p_aug > 0:
# Let's keep contrast augmentation aside for a moment
# self.augs = iaa.Sometimes(p_aug, iaa.OneOf([iaa.SigmoidContrast(gain=(5, 20), cutoff=(0.25, 0.75)),
# iaa.LogContrast(gain=(0.6, 1.4)),
# iaa.LinearContrast((0.4, 1.6))]))
            # Let's try adding uniform equalization
self.augs = iaa.Sometimes(p_aug, iaa.Lambda(
lambda x: SatTilesScaler().normalize_product(x, 'uniform_scaling', self.mean_scaling_strategy),
None))

# Load the data (first, set the random split for loading the patches)
self.split_seed = np.random.random_integers(42) if split_seed is None else split_seed
11 changes: 3 additions & 8 deletions train_fe.py
@@ -141,8 +141,6 @@ def main(config: argparse.Namespace):
norm = config.input_norm
pos_const = config.pos_const
output_fp_channels = config.output_fp_channels
separable_fp = config.separable_fp
depthwise_fp = config.depthwise_fp

# --- Instantiate the DataGenerators --- #
train_data_generator = DBLDataGenerator(batch_size=train_batch_size, patch_size=patch_size,
@@ -153,8 +151,7 @@ def main(config: argparse.Namespace):
batch_num_tiles_peracq=batch_num_tiles_peracq,
scaler_type=config.scaler_type,
mean_scaling_strategy=config.mean_robust_scaling,
input_norm=norm,
p_aug=config.p_aug)
input_norm=norm)
valid_data_generator = DBLDataGenerator(batch_size=train_batch_size, patch_size=patch_size,
data_dir=config.val_dir, split_seed=config.split_seed,
num_iteration=config.num_iteration,
@@ -163,8 +160,7 @@ def main(config: argparse.Namespace):
batch_num_tiles_peracq=batch_num_tiles_peracq,
scaler_type=config.scaler_type,
mean_scaling_strategy=config.mean_robust_scaling,
input_norm=norm,
p_aug=config.p_aug)
input_norm=norm)

# --- TRAINING --- #
print('Starting training')
@@ -223,8 +219,7 @@ def main(config: argparse.Namespace):
parser.add_argument('--scaler_type', type=str, help='Choose the scaler for the data. Choices are: '
'99th percentile robust scaler;'
'95th percentile robust scaler;'
'Maximum scaling using each band statistics;'
'input norm scaling with histogram equalization.',
'Maximum scaling using each band statistics.',
default='99th_percentile', choices=['99th_percentile', '95th_percentile', 'sat_max',
'sat_tiles_scaler'])
parser.add_argument('--mean_robust_scaling', action='store_true',
