diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index a20d502d..f0d258ba 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -40,7 +40,7 @@ def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs): """ super().__init__() self._mask = torch.nn.Parameter(initial_mask) - self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs) + self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr) self.train_mask_vals = True self._counter = 0 diff --git a/mask_requirements.txt b/mask_requirements.txt index 9e9c28a4..1c0c0b93 100644 --- a/mask_requirements.txt +++ b/mask_requirements.txt @@ -1,3 +1,4 @@ sympy>=1.11.1 perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016 -waveprop>=0.0.8 \ No newline at end of file +waveprop>=0.0.8 +slm_controller @ git+https://github.com/ebezzam/slm-controller.git \ No newline at end of file