Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

removed persistent buffer and fixed some typos. #150

Merged
merged 3 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ Cross-validation


Plotting
------
--------

.. automodule:: mpol.plot
.. automodule:: mpol.plot
2 changes: 1 addition & 1 deletion docs/ci-tutorials/initializedirtyimage.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ rml = precomposed.SimpleNet(coords=coords)
rml.state_dict() # the now uninitialized parameters of the model (the ones we started with)
```

Here you can clearly see the ``state_dict`` is in its original state, before the training loop changed the paramters through the optimization function. Loading our saved dirty image state into the model is as simple as
Here you can clearly see the ``state_dict`` is in its original state, before the training loop changed the parameters through the optimization function. Loading our saved dirty image state into the model is as simple as

```{code-cell}
rml.load_state_dict(torch.load("dirty_image_model.pt"))
Expand Down
8 changes: 4 additions & 4 deletions src/mpol/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ class FourierCube(nn.Module):
cell_size (float): the width of an image-plane pixel [arcseconds]
npix (int): the number of pixels per image side
coords (GridCoords): an object already instantiated from the GridCoords class. If providing this, cannot provide ``cell_size`` or ``npix``.
persistent_vis (Boolean): should the visibility cube be stored as part of the modules `state_dict`? If `True`, the state of the UV grid will be stored. It is recommended to use `False` for most applications, since the visibility cube will rarely be a direct parameter of the model.
"""

def __init__(self, cell_size=None, npix=None, coords=None):
def __init__(self, cell_size=None, npix=None, coords=None, persistent_vis=False):
super().__init__()

# we don't want to bother with the nchan argument here, so
Expand All @@ -41,8 +42,7 @@ def __init__(self, cell_size=None, npix=None, coords=None):

self.coords = GridCoords(cell_size=cell_size, npix=npix)

self.register_buffer("vis", None)

self.register_buffer("vis", None, persistent=persistent_vis)

def forward(self, cube):
"""
Expand All @@ -62,7 +62,7 @@ def forward(self, cube):
# since it needs to correct for the spacing of the input grid.
# See MPoL documentation and/or TMS Eqn A8.18 for more information.
self.vis = self.coords.cell_size**2 * torch.fft.fftn(cube, dim=(1, 2))

return self.vis

@property
Expand Down
6 changes: 3 additions & 3 deletions test/images_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_odd_npix():
images.BaseCube.from_image_properties(npix=853, nchan=30, cell_size=0.015)

with pytest.raises(ValueError, match=expected_error_message):
images.ImageCube.from_image_properteis(npix=853, nchan=30, cell_size=0.015)
images.ImageCube.from_image_properties(npix=853, nchan=30, cell_size=0.015)


def test_negative_cell_size():
Expand All @@ -24,11 +24,11 @@ def test_negative_cell_size():
images.BaseCube.from_image_properties(npix=800, nchan=30, cell_size=-0.015)

with pytest.raises(ValueError, match=expected_error_message):
images.ImageCube.from_image_properteis(npix=800, nchan=30, cell_size=-0.015)
images.ImageCube.from_image_properties(npix=800, nchan=30, cell_size=-0.015)


def test_single_chan():
im = images.ImageCube.from_image_properteis(cell_size=0.015, npix=800)
im = images.ImageCube.from_image_properties(cell_size=0.015, npix=800)
assert im.nchan == 1


Expand Down