
This repository is a PyTorch version of the paper "EWT: Efficient Wavelet-Transformer for Single Image Denoising"

MIVRC/EWT-PyTorch

# EWT-PyTorch

This repository is an official PyTorch implementation of the paper "EWT: Efficient Wavelet-Transformer for Single Image Denoising" from NN 2024.

Transformer-based image denoising methods have achieved encouraging results in the past year. However, Transformers must use linear operations to model long-range dependencies, which greatly increases model inference time and consumes a large amount of GPU memory. Compared with convolutional neural network-based methods, current Transformer-based image denoising methods cannot strike a balance between performance and resource consumption. In this paper, we propose an Efficient Wavelet Transformer (EWT) for image denoising. Specifically, we use the Discrete Wavelet Transform (DWT) and the Inverse Wavelet Transform (IWT) for downsampling and upsampling, respectively. This design fully preserves image features while reducing the image resolution, which greatly reduces the device resources consumed by the Transformer model. Furthermore, we propose a novel Dual-Stream Feature Extraction Block (DFEB) to extract image features at different levels, which further reduces model inference time and GPU memory usage. Experiments show that our method speeds up the original Transformer by more than 80%, reduces GPU memory usage by more than 60%, and achieves excellent denoising results. All code will be made public.
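The DWT/IWT resampling idea can be illustrated with a minimal sketch. It assumes a single-level Haar wavelet with the four sub-bands (LL, HL, LH, HH) stacked along the channel axis, a layout common in wavelet-based CNNs; whether EWT uses exactly this wavelet, scaling, or layout is an assumption. It is written in NumPy for brevity, and `haar_dwt`/`haar_iwt` are hypothetical names. It shows why this downsampling loses nothing: the spatial size halves, the channel count quadruples, and the inverse recovers the input exactly.

```python
import numpy as np


def haar_dwt(x):
    """Single-level 2D Haar DWT of a (C, H, W) array with even H and W.

    Returns a (4*C, H/2, W/2) array: LL, HL, LH, HH sub-bands stacked on axis 0.
    The 1/2 scaling makes the transform orthonormal (energy is preserved).
    """
    even_rows = x[:, 0::2, :] / 2
    odd_rows = x[:, 1::2, :] / 2
    a = even_rows[:, :, 0::2]  # even row, even column
    b = odd_rows[:, :, 0::2]   # odd row, even column
    c = even_rows[:, :, 1::2]  # even row, odd column
    d = odd_rows[:, :, 1::2]   # odd row, odd column
    ll = a + b + c + d         # low-frequency approximation
    hl = -a - b + c + d        # horizontal detail
    lh = -a + b - c + d        # vertical detail
    hh = a - b - c + d         # diagonal detail
    return np.concatenate([ll, hl, lh, hh], axis=0)


def haar_iwt(y):
    """Inverse of haar_dwt: (4*C, H/2, W/2) -> (C, H, W)."""
    ll, hl, lh, hh = np.split(y, 4, axis=0)
    c, h, w = ll.shape
    x = np.zeros((c, 2 * h, 2 * w), dtype=y.dtype)
    x[:, 0::2, 0::2] = (ll - hl - lh + hh) / 2
    x[:, 1::2, 0::2] = (ll - hl + lh - hh) / 2
    x[:, 0::2, 1::2] = (ll + hl - lh - hh) / 2
    x[:, 1::2, 1::2] = (ll + hl + lh + hh) / 2
    return x


# Usage: a random 3-channel 64x64 "image".
rng = np.random.default_rng(0)
img = rng.random((3, 64, 64))
coeffs = haar_dwt(img)
print(coeffs.shape)                        # (12, 32, 32)
print(np.allclose(haar_iwt(coeffs), img))  # True
```

In a PyTorch model the same slicing works on `(N, C, H, W)` tensors by indexing the last two dimensions, so the Transformer blocks can operate on the quarter-resolution sub-bands.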

We provide scripts for reproducing all the results from our paper. You can train your model from scratch, or use a pre-trained model to denoise your images.

## Dependencies

  • Python 3.8
  • PyTorch >= 1.7.1
  • numpy
  • scikit-image (`skimage`)
  • imageio
  • matplotlib
  • tqdm

## Dataset

We use the DIV2K dataset as the clean images for training. Please download it from the official DIV2K website or from SNU_CVLab, and put all clean images into `dataset/DIV2K/DIV2K_train_HR`. The noisy images are generated with `Matlab/generate_noise.m`; put them into `dataset/DIV2K/DIV2K_train_LR_bicubic/x1`.
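For readers without MATLAB, a minimal Python sketch of synthetic noise generation follows. It assumes the noise is additive white Gaussian noise with standard deviation `sigma` on the 0-255 scale, the usual setting for this kind of benchmark; the exact model, noise levels, and seeding in `Matlab/generate_noise.m` are not stated here, so treat this as an approximation rather than a replacement. `add_gaussian_noise` is a hypothetical helper.

```python
import numpy as np


def add_gaussian_noise(clean, sigma, seed=None):
    """Add white Gaussian noise with standard deviation `sigma` (0-255 units).

    The result is rounded and clipped back to uint8 so it can be saved as PNG.
    """
    rng = np.random.default_rng(seed)
    noisy = clean.astype(np.float64) + rng.normal(0.0, sigma, size=clean.shape)
    return np.clip(np.round(noisy), 0, 255).astype(np.uint8)


# Usage on a flat gray image; sigma = 25 is a common benchmark noise level.
clean = np.full((64, 64, 3), 128, dtype=np.uint8)
noisy = add_gaussian_noise(clean, sigma=25, seed=0)
print(noisy.dtype, noisy.shape)  # uint8 (64, 64, 3)
```

Saving the result (for example with `imageio.imwrite`) into the `x1` folder should mirror the MATLAB workflow. If the data loader follows the EDSR-style naming convention, noisy files may need an `x1` suffix such as `0001x1.png`; this is an assumption worth checking against the loader code.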

When testing, put the clean and noisy images of the test set into `dataset/DIV2K/DIV2K_train_HR` and `dataset/DIV2K/DIV2K_train_LR_bicubic/x1`, respectively.

## Training

Use the `--ext sep_reset` argument the first time you run the code.

On subsequent runs, you can skip the image-decoding step and reuse the saved binaries with the `--ext sep` argument.
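The difference between `--ext sep_reset` and `--ext sep` amounts to a decode-once cache, sketched below under stated assumptions. This is a hypothetical illustration (the name `load_cached`, the `.npy` format, and the cache layout are all assumptions); the repository's actual binary format and cache location may differ. With `reset=True` the image is always decoded and the cache is overwritten; otherwise an existing cache is loaded instead of decoding again.

```python
import os
import tempfile

import numpy as np


def load_cached(image_path, decode, cache_dir, reset=False):
    """Return the decoded image for `image_path`, caching it as a .npy file.

    decode: callable that turns an image file into a numpy array
            (a real pipeline might use imageio.imread here).
    reset:  True mimics `--ext sep_reset` (always decode, overwrite the cache);
            False mimics `--ext sep` (reuse the cache if it already exists).
    """
    os.makedirs(cache_dir, exist_ok=True)
    stem = os.path.splitext(os.path.basename(image_path))[0]
    cache_path = os.path.join(cache_dir, stem + ".npy")
    if reset or not os.path.exists(cache_path):
        img = decode(image_path)  # slow path: decode the compressed image
        np.save(cache_path, img)  # store the raw array for later runs
        return img
    return np.load(cache_path)    # fast path: load the saved binary


# Usage: a stand-in decoder that records how often it is called.
calls = []


def fake_decode(path):
    calls.append(path)
    return np.zeros((4, 4, 3), dtype=np.uint8)


with tempfile.TemporaryDirectory() as cache:
    first = load_cached("0001.png", fake_decode, cache, reset=True)  # decodes, writes cache
    second = load_cached("0001.png", fake_decode, cache)             # served from cache

print(len(calls))  # 1
```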

```bash
# train
python main.py --scale 1 --patch_size 176 --save ewt --ext sep_reset
```
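Training with `--patch_size 176` and `--scale 1` presumably means that each step crops the same 176x176 window from a noisy image and its clean counterpart, since both have the same resolution. The sketch below illustrates that paired random crop; it is an assumption about what the option controls, and `random_paired_crop` is a hypothetical name rather than the repository's function.

```python
import numpy as np


def random_paired_crop(noisy, clean, patch_size=176, rng=None):
    """Crop one random patch_size x patch_size window from an aligned noisy/clean pair.

    Both arrays are H x W x C. With --scale 1 the two images have the same size,
    so a single (top, left) offset is shared by the pair.
    """
    if noisy.shape[:2] != clean.shape[:2]:
        raise ValueError("noisy and clean images must have the same height and width")
    h, w = clean.shape[:2]
    if h < patch_size or w < patch_size:
        raise ValueError("image is smaller than patch_size")
    rng = np.random.default_rng() if rng is None else rng
    top = int(rng.integers(0, h - patch_size + 1))   # offset in [0, h - patch_size]
    left = int(rng.integers(0, w - patch_size + 1))  # offset in [0, w - patch_size]
    window = (slice(top, top + patch_size), slice(left, left + patch_size))
    return noisy[window], clean[window]


# Usage: the "noisy" image is the clean one offset by +1, so aligned crops differ by exactly 1.
rng = np.random.default_rng(0)
clean = rng.integers(0, 255, size=(256, 320, 3)).astype(np.int16)
noisy = clean + 1
noisy_patch, clean_patch = random_paired_crop(noisy, clean, patch_size=176, rng=rng)
print(noisy_patch.shape)  # (176, 176, 3)
```

Sharing one offset keeps every noisy pixel aligned with its clean target, which is what a pixel-wise denoising loss requires.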

## Testing

All pre-trained models should be put into `experiment/` first.

```bash
# test
python main.py --data_test DIV2K --data_range 1-24 --scale 1 --pre_train your_path/EWT/experiment/model_name/model/model_best.pt --test_only --save_results --ext sep_reset
```

After the above command finishes, a directory named `test` will be created in `experiment/`, where you can view the denoised images.
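To compare the saved denoised images against the clean references, peak signal-to-noise ratio (PSNR) is the usual metric for denoising. The sketch below is a plain RGB PSNR on the 0-255 scale; the repository's test script may compute it differently (for example on the Y channel or with border cropping), so numbers need not match exactly. `psnr` is a hypothetical helper.

```python
import numpy as np


def psnr(clean, denoised, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two images of the same shape."""
    diff = clean.astype(np.float64) - denoised.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)


# Usage: a constant error of 5 gray levels gives MSE = 25.
clean = np.full((8, 8, 3), 100, dtype=np.uint8)
denoised = np.full((8, 8, 3), 105, dtype=np.uint8)
print(round(psnr(clean, denoised), 2))  # 34.15
```

Since scikit-image is already a dependency, `skimage.metrics.peak_signal_noise_ratio(clean, denoised, data_range=255)` should give the same value for uint8 inputs.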

## Performance
