diff --git a/docs/api.rst b/docs/api.rst index 12088711..3af23cb1 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -79,6 +79,6 @@ Cross-validation Plotting ------- +-------- -.. automodule:: mpol.plot \ No newline at end of file +.. automodule:: mpol.plot diff --git a/docs/ci-tutorials/initializedirtyimage.md b/docs/ci-tutorials/initializedirtyimage.md index 276da74b..d5411f25 100644 --- a/docs/ci-tutorials/initializedirtyimage.md +++ b/docs/ci-tutorials/initializedirtyimage.md @@ -205,7 +205,7 @@ rml = precomposed.SimpleNet(coords=coords) rml.state_dict() # the now uninitialized parameters of the model (the ones we started with) ``` -Here you can clearly see the ``state_dict`` is in its original state, before the training loop changed the paramters through the optimization function. Loading our saved dirty image state into the model is as simple as +Here you can clearly see the ``state_dict`` is in its original state, before the training loop changed the parameters through the optimization function. Loading our saved dirty image state into the model is as simple as ```{code-cell} rml.load_state_dict(torch.load("dirty_image_model.pt")) diff --git a/src/mpol/fourier.py b/src/mpol/fourier.py index a2243321..313b643b 100644 --- a/src/mpol/fourier.py +++ b/src/mpol/fourier.py @@ -20,9 +20,10 @@ class FourierCube(nn.Module): cell_size (float): the width of an image-plane pixel [arcseconds] npix (int): the number of pixels per image side coords (GridCoords): an object already instantiated from the GridCoords class. If providing this, cannot provide ``cell_size`` or ``npix``. + persistent_vis (Boolean): should the visibility cube be stored as part of the modules `state_dict`? If `True`, the state of the UV grid will be stored. It is recommended to use `False` for most applications, since the visibility cube will rarely be a direct parameter of the model. """ - def __init__(self, cell_size=None, npix=None, coords=None): + def __init__(self, cell_size=None, npix=None, coords=None, persistent_vis=False): super().__init__() # we don't want to bother with the nchan argument here, so @@ -41,8 +42,7 @@ def __init__(self, cell_size=None, npix=None, coords=None): self.coords = GridCoords(cell_size=cell_size, npix=npix) - self.register_buffer("vis", None) - + self.register_buffer("vis", None, persistent=persistent_vis) def forward(self, cube): """ @@ -62,7 +62,7 @@ def forward(self, cube): # since it needs to correct for the spacing of the input grid. # See MPoL documentation and/or TMS Eqn A8.18 for more information. self.vis = self.coords.cell_size**2 * torch.fft.fftn(cube, dim=(1, 2)) - + return self.vis @property diff --git a/test/images_test.py b/test/images_test.py index b92a9963..b0d988e9 100644 --- a/test/images_test.py +++ b/test/images_test.py @@ -14,7 +14,7 @@ def test_odd_npix(): images.BaseCube.from_image_properties(npix=853, nchan=30, cell_size=0.015) with pytest.raises(ValueError, match=expected_error_message): - images.ImageCube.from_image_properteis(npix=853, nchan=30, cell_size=0.015) + images.ImageCube.from_image_properties(npix=853, nchan=30, cell_size=0.015) def test_negative_cell_size(): @@ -24,11 +24,11 @@ def test_negative_cell_size(): images.BaseCube.from_image_properties(npix=800, nchan=30, cell_size=-0.015) with pytest.raises(ValueError, match=expected_error_message): - images.ImageCube.from_image_properteis(npix=800, nchan=30, cell_size=-0.015) + images.ImageCube.from_image_properties(npix=800, nchan=30, cell_size=-0.015) def test_single_chan(): - im = images.ImageCube.from_image_properteis(cell_size=0.015, npix=800) + im = images.ImageCube.from_image_properties(cell_size=0.015, npix=800) assert im.nchan == 1