diff --git a/.gitignore b/.gitignore
index 894a44cc..ace2a4e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -102,3 +102,7 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+checkpoints
+dataset
+result
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..c4640da0
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "MBIPED"]
+	path = MBIPED
+	url = git@github.com:xavysp/MBIPED.git
diff --git a/MBIPED b/MBIPED
new file mode 160000
index 00000000..91473d4a
--- /dev/null
+++ b/MBIPED
@@ -0,0 +1 @@
+Subproject commit 91473d4ac10b36ae784e194d3e73b1497b7a0457
diff --git a/README.md b/README.md
index e40e337e..352fc3c4 100644
--- a/README.md
+++ b/README.md
@@ -52,10 +52,11 @@ Dexined version on TF 2 is not ready
 * [Kornia](https://kornia.github.io/)
 * Other package like Numpy, h5py, PIL, json.
 
-Once the packages are installed, clone this repo as follow:
+Once the packages are installed, clone this repo as follows and initialize the MBIPED submodule:
 
     git clone https://github.com/xavysp/DexiNed.git
     cd DexiNed
+    git submodule update --init
 
 ## Project Architecture
 
@@ -76,19 +77,19 @@ Once the packages are installed, clone this repo as follow:
 ├── model.py # DexiNed class in pythorch
 ```
 
-Before to start please check dataset.py, from the first line of code you can see the datasets used for training/testing. The main.py, line 194, call the data for the training or testing, see the example of the code below:
+Before starting, please check `datasets.py`; from its first lines of code you can see the datasets used for training/testing. In `main.py`, the `parse_args()` function selects the data for training or testing. See the example of the code below:
 
 ```
 parser = argparse.ArgumentParser(description='DexiNed trainer.')
 parser.add_argument('--choose_test_data',
                     type=int,
                     default=1,
                     help='Already set the dataset for testing choice: 0 - 8')
-# ----------- test -------0--
+
+...
 
 TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data] # max 8
 test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
 test_dir = test_inf['data_dir']
-is_testing = True# current test -352-SM-NewGT-2AugmenPublish
 
 # Training settings
 TRAIN_DATA = DATASET_NAMES[0] # BIPED=0
@@ -96,8 +97,24 @@ Before to start please check dataset.py, from the first line of code you can see
 train_dir = train_inf['data_dir']
 ```
+
+The datasets listed below must be downloaded before training can be performed. In the current configuration of `datasets.py`, standard datasets must be stored in a directory named `dataset` under the root directory of this repository, while custom datasets must be stored in the `data` directory under the same root.
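+
+For example, with the standard configuration the layout would look as follows (an illustrative sketch only; the internal structure of each dataset folder follows its own distribution, e.g. MBIPED for BIPED):
+
+```
+DexiNed/
+├── checkpoints/   # downloaded or trained model weights
+├── data/          # custom test images (CLASSIC)
+├── dataset/       # standard datasets, e.g. dataset/BIPED, dataset/BSDS
+└── result/        # generated outputs
+```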
+
+## Train
+
+To train the model, call the `main.py` script with the `--is_training` flag, along with any other desired flags. (All available flags are defined in the `parse_args()` function of `main.py`.) If this flag is not set, the program runs in testing mode. Assuming no other options are selected, the command is:
+
+`python main.py --is_training`
+
+Training with the BIPED dataset works only with the augmented version of the dataset, which is generated by the MBIPED project/submodule. To generate the augmented dataset, edit `MBIPED/main.py` so that `BIPED_main_dir` in the `main()` function points to the directory where you store the BIPED dataset; for the standard configuration this should be `dataset`. After this is completed, call the MBIPED main script from the DexiNed root directory:
+
+`python MBIPED/main.py`
+
+This will generate the augmented dataset.
+
+Note: This is a long process that currently cannot be paused and restarted. If the process fails for any reason, delete the augmented image files and rerun the script.
+
 ## Test
 
-As previously mentioned, the datasets.py has, among other things, the whole datasets configurations used in DexiNed for testing and training:
+As previously mentioned, the `datasets.py` file contains, among other things, all of the dataset configurations used in DexiNed for testing and training:
 
 ```
 DATASET_NAMES = [
     'BIPED',
@@ -111,19 +128,13 @@ DATASET_NAMES = [
     'CLASSIC'
 ]
 ```
-For example, if want to test your own dataset or image choose "CLASSIC" and save your test data in "data" dir.
-Before test the DexiNed model, it is necesarry to download the checkpoint here [Checkpoint Pytorch](https://drive.google.com/file/d/1V56vGTsu7GYiQouCIKvTWl5UKCZ6yCNu/view?usp=sharing) and save this file into the DexiNed folder like: checkpoints/BIPED/10/(here the checkpoints from Drive), then run as follow:
+For example, if you want to test your own dataset or images, choose `CLASSIC` and save your test data in the `data` dir.
+Before testing a pretrained version of the DexiNed model, it is necessary to download the checkpoint here [Checkpoint Pytorch](https://drive.google.com/file/d/1V56vGTsu7GYiQouCIKvTWl5UKCZ6yCNu/view?usp=sharing) and save it into the DexiNed folder like: checkpoints/BIPED/10/(here the checkpoints from Drive), then run as follows:
 
 ```
 python main.py --choose_test_data=-1
 ```
-Make sure that in main.py the test setting be as:
-```parser.add_argument('--is_testing', default=True, help='Script in testing mode.')```
-DexiNed downsample the input image till 16 scales, please make sure that, in dataset_info fucn (datasets.py), the image width and height be multiple of 16, like 512, 960, and etc. **In the Checkpoint from Drive you will find the last trained checkpoint, which has been trained in the last version of BIPED dataset that will be updated soon in Kaggle **
-
-## Train
+Be sure not to set the `--is_training` flag when calling the main script.
 
-    python main.py
-Make sure that in main.py the train setting be as:
-```parser.add_argument('--is_testing', default=False, help='Script in testing mode.')```
+DexiNed downsamples the input images by a factor of 16, so please make sure that, in the `dataset_info()` function (`datasets.py`), the image width and height are multiples of 16 (e.g. 512 or 960); a small helper is sketched below. **The checkpoint from Drive is the last trained checkpoint, trained on the latest version of the BIPED dataset, which will be updated soon on Kaggle.**
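+
+The snippet below is only an illustration (not part of this repo) of how to round a desired size up to the nearest multiple of 16 before setting `img_width`/`img_height`:
+
+```
+def to_multiple_of_16(x: int) -> int:
+    # Round up to the nearest multiple of 16, as DexiNed requires.
+    return ((x + 15) // 16) * 16
+
+print(to_multiple_of_16(481))  # 496
+print(to_multiple_of_16(500))  # 512
+```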
 
 # Datasets
 
@@ -164,7 +175,7 @@ After WACV20, the BIPED images have been checked again and added more annotation
 
 # Citation
 
-If you like DexiNed, why not starring the project on GitHub!
+If you like DexiNed, why not star the project on GitHub!
 
 [![GitHub stars](https://img.shields.io/github/stars/xavysp/DexiNed.svg?style=social&label=Star&maxAge=3600)](https://GitHub.com/xavysp/DexiNed/stargazers/)
diff --git a/datasets.py b/datasets.py
index 43986eb9..5d427339 100644
--- a/datasets.py
+++ b/datasets.py
@@ -30,7 +30,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 512, #481
         'train_list': 'train_pair.lst',
         'test_list': 'test_pair.lst',
-        'data_dir': '/opt/dataset/BSDS', # mean_rgb
+        'data_dir': 'dataset/BSDS', # mean_rgb
         'yita': 0.5
     },
     'BRIND': {
@@ -38,7 +38,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 512, # 481
         'train_list': 'train_pair2.lst',
         'test_list': 'test_pair.lst',
-        'data_dir': '/opt/dataset/BRIND', # mean_rgb
+        'data_dir': 'dataset/BRIND', # mean_rgb
         'yita': 0.5
     },
     'BSDS300': {
@@ -46,7 +46,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 512, #481
         'test_list': 'test_pair.lst',
         'train_list': None,
-        'data_dir': '/opt/dataset/BSDS300', # NIR
+        'data_dir': 'dataset/BSDS300', # NIR
         'yita': 0.5
     },
     'PASCAL': {
@@ -54,7 +54,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 512, #500
         'test_list': 'test_pair.lst',
         'train_list': None,
-        'data_dir': '/opt/dataset/PASCAL', # mean_rgb
+        'data_dir': 'dataset/PASCAL', # mean_rgb
         'yita': 0.3
     },
     'CID': {
@@ -62,7 +62,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 512,
         'test_list': 'test_pair.lst',
         'train_list': None,
-        'data_dir': '/opt/dataset/CID', # mean_rgb
+        'data_dir': 'dataset/CID', # mean_rgb
         'yita': 0.3
     },
     'NYUD': {
@@ -70,7 +70,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 560,#560
         'test_list': 'test_pair.lst',
         'train_list': None,
-        'data_dir': '/opt/dataset/NYUD', # mean_rgb
+        'data_dir': 'dataset/NYUD', # mean_rgb
         'yita': 0.5
     },
     'MDBD': {
@@ -78,7 +78,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 1280,
         'test_list': 'test_pair.lst',
         'train_list': 'train_pair.lst',
-        'data_dir': '/opt/dataset/MDBD', # mean_rgb
+        'data_dir': 'dataset/MDBD', # mean_rgb
         'yita': 0.3
     },
     'BIPED': {
@@ -86,7 +86,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 1280, # 1280 5 1920
         'test_list': 'test_pair.lst',
         'train_list': 'train_rgb.lst',
-        'data_dir': '/opt/dataset/BIPED', # mean_rgb
+        'data_dir': 'dataset/BIPED', # mean_rgb
         'yita': 0.5
     },
     'CLASSIC': {
@@ -102,7 +102,7 @@ def dataset_info(dataset_name, is_linux=True):
         'img_width': 480,# 360
         'test_list': 'test_pair.lst',
         'train_list': None,
-        'data_dir': '/opt/dataset/DCD', # mean_rgb
+        'data_dir': 'dataset/DCD', # mean_rgb
         'yita': 0.2
     }
     }
@@ -112,39 +112,39 @@ def dataset_info(dataset_name, is_linux=True):
                  'img_width': 512, # 481
                  'test_list': 'test_pair.lst',
                  'train_list': 'train_pair.lst',
-                 'data_dir': 'C:/Users/xavysp/dataset/BSDS', # mean_rgb
+                 'data_dir': 'dataset/BSDS', # mean_rgb
                  'yita': 0.5},
         'BSDS300': {'img_height': 512, # 321
                     'img_width': 512, # 481
                     'test_list': 'test_pair.lst',
-                    'data_dir': 'C:/Users/xavysp/dataset/BSDS300', # NIR
+                    'data_dir': 'dataset/BSDS300', # NIR
                     'yita': 0.5},
         'PASCAL': {'img_height': 375,
                    'img_width': 500,
                    'test_list': 'test_pair.lst',
-                   'data_dir': 'C:/Users/xavysp/dataset/PASCAL', # mean_rgb
+                   'data_dir': 'dataset/PASCAL', # mean_rgb
                    'yita': 0.3},
         'CID': {'img_height': 512,
                 'img_width': 512,
                 'test_list': 'test_pair.lst',
-                'data_dir': 'C:/Users/xavysp/dataset/CID', # mean_rgb
+                'data_dir': 'dataset/CID', # mean_rgb
                 'yita': 0.3},
         'NYUD': {'img_height': 425,
                  'img_width': 560,
                  'test_list': 'test_pair.lst',
-                 'data_dir': 'C:/Users/xavysp/dataset/NYUD', # mean_rgb
+                 'data_dir': 'dataset/NYUD', # mean_rgb
                  'yita': 0.5},
         'MDBD': {'img_height': 720,
                  'img_width': 1280,
                  'test_list': 'test_pair.lst',
                  'train_list': 'train_pair.lst',
-                 'data_dir': 'C:/Users/xavysp/dataset/MDBD', # mean_rgb
+                 'data_dir': 'dataset/MDBD', # mean_rgb
                  'yita': 0.3},
         'BIPED': {'img_height': 720, # 720
                   'img_width': 1280, # 1280
                   'test_list': 'test_pair.lst',
                   'train_list': 'train_rgb.lst',
-                  'data_dir': 'C:/Users/xavysp/dataset/BIPED', # WIN: '../.../dataset/BIPED/edges'
+                  'data_dir': 'dataset/BIPED', # WIN: '../.../dataset/BIPED/edges'
                   'yita': 0.5},
         'CLASSIC': {'img_height': 512,
                     'img_width': 512,
@@ -155,7 +155,7 @@ def dataset_info(dataset_name, is_linux=True):
         'DCD': {'img_height': 240,
                 'img_width': 360,
                 'test_list': 'test_pair.lst',
-                'data_dir': 'C:/Users/xavysp/dataset/DCD', # mean_rgb
+                'data_dir': 'dataset/DCD', # mean_rgb
                 'yita': 0.2}
     }
     return config[dataset_name]
diff --git a/main.py b/main.py
index 732f6bd9..f84d91dc 100644
--- a/main.py
+++ b/main.py
@@ -4,6 +4,7 @@
 import argparse
 import os
 import time, platform
+from dataclasses import dataclass
 
 import cv2
 import torch.optim as optim
@@ -15,6 +16,38 @@
 from utils import (image_normalization, save_image_batch_to_disk,
                    visualize_result,count_parameters)
 
+@dataclass
+class Args:
+    choose_test_data: int
+    input_dir: str
+    input_val_dir: str
+    output_dir: str
+    train_data: str
+    test_data: str
+    test_list: str
+    train_list: str
+    is_testing: bool
+    double_img: bool
+    resume: bool
+    checkpoint_data: str
+    test_img_width: int
+    test_img_height: int
+    res_dir: str
+    log_interval_vis: int
+    epochs: int
+    lr: float
+    wd: float
+    adjust_lr: list[int]
+    batch_size: int
+    workers: int
+    tensorboard: bool
+    img_width: int
+    img_height: int
+    channel_swap: list[int]
+    crop_img: bool
+    mean_pixel_values: list[float]
+
+
 IS_LINUX = True if platform.system()=="Linux" else False
 
 def train_one_epoch(epoch, dataloader, model, criterion, optimizer, device,
                     log_interval_vis, tb_writer, args=None):
@@ -199,28 +232,12 @@ def parse_args():
                         type=int,
                         default=-1,
                         help='Already set the dataset for testing choice: 0 - 8')
-    # ----------- test -------0--
-
-
-    TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data] # max 8
-    test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
-    test_dir = test_inf['data_dir']
-    is_testing =True# current test -352-SM-NewGT-2AugmenPublish
-
-    # Training settings
-    TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, MDBD=6
-    train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
-    train_dir = train_inf['data_dir']
-
-    # Data parameters
     parser.add_argument('--input_dir',
                         type=str,
-                        default=train_dir,
                         help='the path to the directory with the input data.')
     parser.add_argument('--input_val_dir',
                         type=str,
-                        default=test_inf['data_dir'],
                         help='the path to the directory with the input data for validation.')
     parser.add_argument('--output_dir',
                         type=str,
@@ -229,31 +246,25 @@ def parse_args():
     parser.add_argument('--train_data',
                         type=str,
                         choices=DATASET_NAMES,
-                        default=TRAIN_DATA,
                         help='Name of the dataset.')
     parser.add_argument('--test_data',
                         type=str,
                         choices=DATASET_NAMES,
-                        default=TEST_DATA,
                         help='Name of the dataset.')
     parser.add_argument('--test_list',
                         type=str,
-                        default=test_inf['test_list'],
                         help='Dataset sample indices list.')
     parser.add_argument('--train_list',
                         type=str,
-                        default=train_inf['train_list'],
                         help='Dataset sample indices list.')
-    parser.add_argument('--is_testing',type=bool,
-                        default=is_testing,
-                        help='Script in testing mode.')
+    parser.add_argument('--is_training',
+                        action='store_true',
+                        help='Script in training mode.')
     parser.add_argument('--double_img',
-                        type=bool,
-                        default=False,
+                        action='store_true',
                         help='True: use same 2 imgs changing channels') # Just for test
     parser.add_argument('--resume',
-                        type=bool,
-                        default=False,
+                        action='store_true',
                         help='use previous trained data') # Just for test
     parser.add_argument('--checkpoint_data',
                         type=str,
@@ -261,11 +272,9 @@ def parse_args():
                         help='Checkpoint path from which to restore model weights from.')
     parser.add_argument('--test_img_width',
                         type=int,
-                        default=test_inf['img_width'],
                         help='Image width for testing.')
     parser.add_argument('--test_img_height',
                         type=int,
-                        default=test_inf['img_height'],
                         help='Image height for testing.')
     parser.add_argument('--res_dir',
                         type=str,
@@ -303,8 +312,8 @@ def parse_args():
                         default=16,
                         type=int,
                         help='The number of workers for the dataloaders.')
-    parser.add_argument('--tensorboard',type=bool,
-                        default=True,
-                        help='Use Tensorboard for logging.'),
+    parser.add_argument('--no_tensorboard',
+                        action='store_true',
+                        help='Disable Tensorboard logging.')
     parser.add_argument('--img_width',
                         type=int,
@@ -317,15 +326,94 @@ def parse_args():
     parser.add_argument('--channel_swap',
                         default=[2, 1, 0],
                         type=int)
-    parser.add_argument('--crop_img',
-                        default=True,
-                        type=bool,
-                        help='If true crop training images, else resize images to match image width and height.')
+    parser.add_argument('--no_crop_img',
+                        action='store_true',
+                        help='If set, resize training images to match image width and height instead of cropping.')
     parser.add_argument('--mean_pixel_values',
                         default=[103.939,116.779,123.68, 137.86],
                         type=float)  # [103.939,116.779,123.68] [104.00699, 116.66877, 122.67892]
+
+    args = parser.parse_args()
-    return args
+
+    # Grab all information from args with literal default values
+    choose_test_data = args.choose_test_data
+    output_dir = args.output_dir
+    is_testing = not args.is_training
+    double_img = args.double_img
+    resume = args.resume
+    checkpoint_data = args.checkpoint_data
+    res_dir = args.res_dir
+    log_interval_vis = args.log_interval_vis
+    epochs = args.epochs
+    lr = args.lr
+    wd = args.wd
+    adjust_lr = args.adjust_lr
+    batch_size = args.batch_size
+    workers = args.workers
+    tensorboard = not args.no_tensorboard
+    img_width = args.img_width
+    img_height = args.img_height
+    channel_swap = args.channel_swap
+    crop_img = not args.no_crop_img
+    mean_pixel_values = args.mean_pixel_values
+
+    # Grab non-literal/dependent default values
+    TEST_DATA = DATASET_NAMES[choose_test_data] # max 8
+    test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
+
+    # Training settings
+    TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, MDBD=6
+    train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
+
+    # Fall back to dataset-dependent defaults for any flag left unset
+    input_dir = args.input_dir if args.input_dir is not None else train_inf['data_dir']
+    input_val_dir = args.input_val_dir if args.input_val_dir is not None else test_inf['data_dir']
+    train_data = args.train_data if args.train_data is not None else TRAIN_DATA
+    test_data = args.test_data if args.test_data is not None else TEST_DATA
+    test_list = args.test_list if args.test_list is not None else test_inf['test_list']
+    train_list = args.train_list if args.train_list is not None else train_inf['train_list']
+    test_img_width = args.test_img_width if args.test_img_width is not None else test_inf['img_width']
+    test_img_height = args.test_img_height if args.test_img_height is not None else test_inf['img_height']
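+
+    # Example invocations, for illustration (flags defined above; the
+    # dataset-dependent defaults are fully resolved at this point):
+    #   python main.py --is_training            # train on BIPED (TRAIN_DATA)
+    #   python main.py --choose_test_data=-1    # test on CLASSIC images in data/
+    #   python main.py --is_training --resume   # resume from checkpoint_data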
+    arg_struct = Args(choose_test_data, input_dir, input_val_dir, output_dir,
+                      train_data, test_data, test_list, train_list, is_testing,
+                      double_img, resume, checkpoint_data, test_img_width,
+                      test_img_height, res_dir, log_interval_vis, epochs, lr,
+                      wd, adjust_lr, batch_size, workers, tensorboard,
+                      img_width, img_height, channel_swap, crop_img,
+                      mean_pixel_values)
+
+    print(arg_struct)
+    return arg_struct
 
 
 def main(args):