Skip to content

Commit

Permalink
Fixes for when no background is present. (#152)
Browse files Browse the repository at this point in the history
* Fixes for when no background is present.

* Update setup.
  • Loading branch information
ebezzam authored Oct 2, 2024
1 parent 3ea395f commit 99d6341
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
18 changes: 12 additions & 6 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,20 +1565,26 @@ def _get_images_pair(self, idx):

if len(background_np.shape) == 2:
warnings.warn(f"Converting background[{idx}] to RGB")
background_np = np.stack([background_np] * 3, axis=2) if not None else None
background_np = (
np.stack([background_np] * 3, axis=2) if background_np is not None else None
)
elif len(background_np.shape) == 3:
pass

# convert to float
if lensless_np.dtype == np.uint8:
lensless_np = lensless_np.astype(np.float32) / 255
lensed_np = lensed_np.astype(np.float32) / 255
background_np = background_np.astype(np.float32) / 255 if not None else None
background_np = (
background_np.astype(np.float32) / 255 if background_np is not None else None
)
else:
# 16 bit
lensless_np = lensless_np.astype(np.float32) / 65535
lensed_np = lensed_np.astype(np.float32) / 65535
background_np = background_np.astype(np.float32) / 65535 if not None else None
background_np = (
background_np.astype(np.float32) / 65535 if background_np is not None else None
)

# downsample if necessary
if self.downsample_lensless != 1.0:
Expand All @@ -1591,13 +1597,13 @@ def _get_images_pair(self, idx):
factor=1 / self.downsample_lensless,
interpolation=cv2.INTER_NEAREST,
)
if not None
if background_np is not None
else None
)

lensless = lensless_np
lensed = lensed_np
background = background_np if not None else None
background = background_np if background_np is not None else None

if self.simulator is not None:
# convert to torch
Expand Down Expand Up @@ -1640,7 +1646,7 @@ def __getitem__(self, idx):
# to torch
lensless = torch.from_numpy(lensless)
lensed = torch.from_numpy(lensed)
background = torch.from_numpy(background) if not None else None
background = torch.from_numpy(background) if background is not None else None
# If [H, W, C] -> [D, H, W, C]
if len(lensless.shape) == 3:
lensless = lensless.unsqueeze(0)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
],
python_requires=">=3.8.1, <=3.11.9",
python_requires=">=3.8.1, <3.12",
install_requires=[
"opencv-python>=4.5.1.48",
"numpy==1.26.4; python_version=='3.11'",
Expand Down

0 comments on commit 99d6341

Please sign in to comment.