From 99d63419a34a57f79def45333e20d0e0d3ab2577 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 2 Oct 2024 05:30:38 -0700 Subject: [PATCH] Fixes for when no background is present. (#152) * Fixes for when no background is present. * Update setup. --- lensless/utils/dataset.py | 18 ++++++++++++------ setup.py | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 34c5aa78..725f418b 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1565,7 +1565,9 @@ def _get_images_pair(self, idx): if len(background_np.shape) == 2: warnings.warn(f"Converting background[{idx}] to RGB") - background_np = np.stack([background_np] * 3, axis=2) if not None else None + background_np = ( + np.stack([background_np] * 3, axis=2) if background_np is not None else None + ) elif len(background_np.shape) == 3: pass @@ -1573,12 +1575,16 @@ def _get_images_pair(self, idx): if lensless_np.dtype == np.uint8: lensless_np = lensless_np.astype(np.float32) / 255 lensed_np = lensed_np.astype(np.float32) / 255 - background_np = background_np.astype(np.float32) / 255 if not None else None + background_np = ( + background_np.astype(np.float32) / 255 if background_np is not None else None + ) else: # 16 bit lensless_np = lensless_np.astype(np.float32) / 65535 lensed_np = lensed_np.astype(np.float32) / 65535 - background_np = background_np.astype(np.float32) / 65535 if not None else None + background_np = ( + background_np.astype(np.float32) / 65535 if background_np is not None else None + ) # downsample if necessary if self.downsample_lensless != 1.0: @@ -1591,13 +1597,13 @@ def _get_images_pair(self, idx): factor=1 / self.downsample_lensless, interpolation=cv2.INTER_NEAREST, ) - if not None + if background_np is not None else None ) lensless = lensless_np lensed = lensed_np - background = background_np if not None else None + background = background_np if background_np is not None else None if self.simulator is not None: # convert to torch @@ -1640,7 +1646,7 @@ def __getitem__(self, idx): # to torch lensless = torch.from_numpy(lensless) lensed = torch.from_numpy(lensed) - background = torch.from_numpy(background) if not None else None + background = torch.from_numpy(background) if background is not None else None # If [H, W, C] -> [D, H, W, C] if len(lensless.shape) == 3: lensless = lensless.unsqueeze(0) diff --git a/setup.py b/setup.py index 23bb2269..11260697 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], - python_requires=">=3.8.1, <=3.11.9", + python_requires=">=3.8.1, <3.12", install_requires=[ "opencv-python>=4.5.1.48", "numpy==1.26.4; python_version=='3.11'",