-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Clean up training with simulated dataset.
- Loading branch information
Showing
12 changed files
with
224 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,8 +6,8 @@ | |
# Eric BEZZAM [[email protected]] | ||
# ############################################################################# | ||
|
||
import numpy as np | ||
from waveprop.simulation import FarFieldSimulator as FarFieldSimulator_wp | ||
import torch | ||
|
||
|
||
class FarFieldSimulator(FarFieldSimulator_wp): | ||
|
@@ -34,7 +34,7 @@ def __init__( | |
""" | ||
Parameters | ||
---------- | ||
psf : np.ndarray, optional. | ||
psf : np.ndarray or torch.Tensor, optional. | ||
Point spread function. If not provided, return image at object plane. | ||
object_height : float or (float, float) | ||
Height of object in meters. Or range of values to randomly sample from. | ||
|
@@ -58,9 +58,15 @@ def __init__( | |
Whether to quantize image, by default True. | ||
""" | ||
|
||
if psf is not None: | ||
# convert HWC to CHW | ||
psf = psf.squeeze().movedim(-1, 0) | ||
assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" | ||
|
||
if torch.is_tensor(psf): | ||
# drop depth dimension, and convert HWC to CHW | ||
psf = psf[0].movedim(-1, 0) | ||
assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" | ||
else: | ||
psf = psf[0] | ||
assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" | ||
|
||
super().__init__( | ||
object_height, | ||
|
@@ -78,6 +84,13 @@ def __init__( | |
**kwargs | ||
) | ||
|
||
if self.is_torch: | ||
assert self.psf.shape[0] == 1 or self.psf.shape[0] == 3, "PSF must have 1 or 3 channels" | ||
else: | ||
assert ( | ||
self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3 | ||
), "PSF must have 1 or 3 channels" | ||
|
||
# save all the parameters in a dict | ||
self.params = { | ||
"object_height": object_height, | ||
|
@@ -94,7 +107,15 @@ def __init__( | |
} | ||
self.params.update(kwargs) | ||
|
||
def set_psf(self, psf): | ||
def get_psf(self): | ||
if self.is_torch: | ||
# convert CHW to HWC | ||
return self.psf.movedim(0, -1).unsqueeze(0) | ||
else: | ||
return self.psf[None, ...] | ||
|
||
# needs different name from parent class | ||
def set_point_spread_function(self, psf): | ||
""" | ||
Set point spread function. | ||
|
@@ -103,19 +124,32 @@ def set_psf(self, psf): | |
psf : np.ndarray or torch.Tensor | ||
Point spread function. | ||
""" | ||
psf = psf.squeeze().movedim(-1, 0) | ||
assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)" | ||
|
||
if torch.is_tensor(psf): | ||
# convert HWC to CHW | ||
psf = psf[0].movedim(-1, 0) | ||
assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels" | ||
else: | ||
psf = psf[0] | ||
assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels" | ||
|
||
return super().set_psf(psf) | ||
|
||
def propagate(self, obj, return_object_plane=False): | ||
def propagate_image(self, obj, return_object_plane=False): | ||
""" | ||
Parameters | ||
---------- | ||
obj : np.ndarray or torch.Tensor | ||
Single image to propagate at format HWC. | ||
Single image to propagate of format HWC. | ||
return_object_plane : bool, optional | ||
Whether to return object plane, by default False. | ||
""" | ||
|
||
assert obj.shape[-1] == 1 or obj.shape[-1] == 3, "Image must have 1 or 3 channels" | ||
|
||
if self.is_torch: | ||
# channel in first dimension as expected by waveprop for pytorch | ||
obj = obj.moveaxis(-1, 0) | ||
res = super().propagate(obj, return_object_plane) | ||
if isinstance(res, tuple): | ||
|
@@ -124,10 +158,6 @@ def propagate(self, obj, return_object_plane=False): | |
res = res.moveaxis(-3, -1) | ||
return res | ||
else: | ||
obj = np.moveaxis(obj, -1, 0) | ||
# TODO: not tested, but normally don't need to move dimensions for numpy | ||
res = super().propagate(obj, return_object_plane) | ||
if isinstance(res, tuple): | ||
res = np.moveaxis(res[0], -3, -1), np.moveaxis(res[1], -3, -1) | ||
else: | ||
res = np.moveaxis(res, -3, -1) | ||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
sympy>=1.11.1 | ||
perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 | ||
waveprop>=0.0.7 | ||
waveprop>=0.0.8 |
Oops, something went wrong.