diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 373fbc01..ffd62d91 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -37,6 +37,7 @@ def load_image( shape=None, dtype=None, normalize=True, + bgr_input=True, ): """ Load image as numpy array. @@ -151,7 +152,7 @@ def load_image( ) else: - if len(img.shape) == 3: + if len(img.shape) == 3 and bgr_input: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) original_dtype = img.dtype @@ -223,6 +224,7 @@ def load_psf( single_psf=False, shape=None, use_3d=False, + bgr_input=True, ): """ Load and process PSF for analysis or for reconstruction. @@ -296,6 +298,7 @@ def load_psf( blue_gain=blue_gain, red_gain=red_gain, nbits_out=nbits_out, + bgr_input=bgr_input, ) original_dtype = psf.dtype @@ -391,6 +394,7 @@ def load_data( torch=False, torch_device="cpu", normalize=False, + bgr_input=True, ): """ Load data for image reconstruction. @@ -471,6 +475,7 @@ def load_data( single_psf=single_psf, shape=shape, use_3d=use_3d, + bgr_input=bgr_input, ) # load and process raw measurement @@ -485,6 +490,7 @@ def load_data( return_float=return_float, shape=shape, normalize=normalize, + bgr_input=bgr_input, ) if data.shape != psf.shape: