Skip to content

Commit

Permalink
feat: add tests for reference image dataloaders
Browse files Browse the repository at this point in the history
  • Loading branch information
royale authored and pnsuau committed Aug 25, 2023
1 parent 7749a08 commit 4f0500d
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 0 deletions.
34 changes: 34 additions & 0 deletions scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,40 @@ if [ $OUT != 0 ]; then
exit 1
fi

####### mask ref test
echo "Running mask ref training tests"
URL=https://joligen.com/datasets/viton_mask_ref_mini.zip
ZIP_FILE=$DIR/viton_mask_ref_mini.zip
TARGET_MASK_REF_DIR=$DIR/viton_mask_ref_mini
wget -N $URL -O $ZIP_FILE
mkdir $TARGET_MASK_REF_DIR
unzip $ZIP_FILE -d $DIR
rm $ZIP_FILE

python3 -m pytest -p no:cacheprovider -s "${current_dir}/../tests/test_run_mask_ref.py" --dataroot "$TARGET_MASK_REF_DIR"
OUT=$?

if [ $OUT != 0 ]; then
exit 1
fi

####### mask ref online test
echo "Running mask ref online training tests"
URL=https://joligen.com/datasets/viton_bbox_ref_mini.zip
ZIP_FILE=$DIR/viton_bbox_ref_mini.zip
TARGET_MASK_ONLINE_REF_DIR=$DIR/viton_bbox_ref_mini
wget -N $URL -O $ZIP_FILE
mkdir $TARGET_MASK_ONLINE_REF_DIR
unzip $ZIP_FILE -d $DIR
rm $ZIP_FILE

python3 -m pytest -p no:cacheprovider -s "${current_dir}/../tests/test_run_mask_online_ref.py" --dataroot "$TARGET_MASK_ONLINE_REF_DIR"
OUT=$?

if [ $OUT != 0 ]; then
exit 1
fi

echo "Deleting target dir $DIR"
rm -rf $DIR/*

Expand Down
59 changes: 59 additions & 0 deletions tests/test_run_mask_online_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
import torch.multiprocessing as mp
import sys
from itertools import product

sys.path.append(sys.path[0] + "/..")
import train
from options.train_options import TrainOptions
from data import create_dataset

json_like_dict = {
"name": "joligen_utest_mask_online_ref",
"output_display_env": "joligen_utest_mask_online_ref",
"output_display_id": 0,
"gpu_ids": "0",
"data_load_size": 128,
"data_crop_size": 128,
"train_n_epochs": 1,
"train_n_epochs_decay": 0,
"data_max_dataset_size": 10,
"data_relative_paths": True,
"train_G_ema": True,
"dataaug_no_rotate": True,
"G_unet_mha_num_head_channels": 16,
"G_unet_mha_channel_mults": [1, 2],
"G_nblocks": 1,
"G_padding_type": "reflect",
"data_online_creation_rand_mask_A": True,
"f_s_semantic_nclasses": 100,
"model_type": "palette",
"G_netG": "unet_mha",
}

models_datasets = [
["palette", "self_supervised_labeled_mask_online_ref"],
["cut", "unaligned_labeled_mask_online_ref"],
]
conditionings = [
"alg_palette_conditioning",
"alg_palette_cond_image_creation",
]

product_list = product(
models_datasets,
conditionings,
)


def test_mask_online_ref(dataroot):
json_like_dict["dataroot"] = dataroot
json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1])

for (model, dataset), conditioning in product_list:
json_like_dict_c = json_like_dict.copy()
json_like_dict_c["data_dataset_mode"] = dataset
json_like_dict_c["model_type"] = model
json_like_dict_c[conditioning] = "ref"
opt = TrainOptions().parse_json(json_like_dict_c, save_config=True)
train.launch_training(opt)
59 changes: 59 additions & 0 deletions tests/test_run_mask_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest
import torch.multiprocessing as mp
import sys
from itertools import product

sys.path.append(sys.path[0] + "/..")
import train
from options.train_options import TrainOptions
from data import create_dataset

json_like_dict = {
"name": "joligen_utest_mask_ref",
"output_display_env": "joligen_utest_mask_ref",
"output_display_id": 0,
"gpu_ids": "0",
"data_load_size": 128,
"data_crop_size": 128,
"train_n_epochs": 1,
"train_n_epochs_decay": 0,
"data_max_dataset_size": 10,
"data_relative_paths": True,
"train_G_ema": True,
"dataaug_no_rotate": True,
"G_unet_mha_num_head_channels": 16,
"G_unet_mha_channel_mults": [1, 2],
"G_nblocks": 1,
"G_padding_type": "reflect",
"data_online_creation_rand_mask_A": True,
"f_s_semantic_nclasses": 100,
"model_type": "palette",
"G_netG": "unet_mha",
}

models_datasets = [
["palette", "self_supervised_labeled_mask_ref"],
["cut", "unaligned_labeled_mask_ref"],
]
conditionings = [
"alg_palette_conditioning",
"alg_palette_cond_image_creation",
]

product_list = product(
models_datasets,
conditionings,
)


def test_mask_ref(dataroot):
json_like_dict["dataroot"] = dataroot
json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1])

for (model, dataset), conditioning in product_list:
json_like_dict_c = json_like_dict.copy()
json_like_dict_c["data_dataset_mode"] = dataset
json_like_dict_c["model_type"] = model
json_like_dict_c[conditioning] = "ref"
opt = TrainOptions().parse_json(json_like_dict_c, save_config=True)
train.launch_training(opt)

0 comments on commit 4f0500d

Please sign in to comment.