Skip to content

Commit

Permalink
Fix cube fitting compatibility with unit conversion (#3190)
Browse files Browse the repository at this point in the history
* Fix NaN handling in cube fitting and initial fixes for unit conversion/model fitting interaction

* Remove debugging prints, add comment for context

* Codestyle, changelog

* Reestimate parameters when cube fitting is toggled

* Only reestimate if spectral y type isn't SB

* Codestyle

* Changelog

* Fix initializing linear component for cube fit

* Handle linear component estimation for cube case

* Only reshape here in 3D case

* Skip tests that need #3156

* Respect selected display units when initializing model components

* Use app._get_display_unit instead of relying on Unit Conversion

* Add equivalency here

* Only reestimate here if sb unit != spectral_y unit

* Don't automatically reestimate when toggling cube fit, make the user do it

* Remove print

* Check for warning in test after cube toggle

Fix test

Codestyle

* Add test for cube fitting after flux unit change

* Add a to-do about a test for unit conversion with equivalency

* Fix failing test

* Back to parallel processing post-debugging
  • Loading branch information
rosteen authored Sep 20, 2024
1 parent d49e3b5 commit 91c6ff3
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 30 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ New Features

- Added flux/surface brightness translation and surface brightness
unit conversion in Cubeviz and Specviz. [#2781, #2940, #3088, #3111, #3113, #3129,
#3139, #3149, #3155, #3178, #3185, #3187]
#3139, #3149, #3155, #3178, #3185, #3187, #3190]

- Plugin tray is now open by default. [#2892]

Expand Down
3 changes: 3 additions & 0 deletions jdaviz/configs/default/plugins/model_fitting/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def initialize(self, instance, x, y):
instance : `~astropy.modeling.Model`
The initialized model.
"""
if y.ndim == 3:
# For cube fitting, need to collapse before this calculation
y = np.nanmean(y, axis=(0, 1))
slope, intercept = np.polynomial.Polynomial.fit(x.value.flatten(), y.value.flatten(), 1)

instance.slope.value = slope
Expand Down
112 changes: 83 additions & 29 deletions jdaviz/configs/default/plugins/model_fitting/model_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,6 @@ def _default_flux_viewer_reference_name(self):
self.app._jdaviz_helper, '_default_flux_viewer_reference_name', 'flux-viewer'
)

@observe('cube_fit')
def _cube_fit_changed(self, msg={}):
if self.cube_fit:
self.dataset.add_filter('is_flux_cube')
self.dataset.remove_filter('layer_in_spectrum_viewer')
else:
self.dataset.add_filter('layer_in_spectrum_viewer')
self.dataset.remove_filter('is_flux_cube')
self.dataset._clear_cache()

@property
def user_api(self):
expose = ['dataset']
Expand Down Expand Up @@ -309,6 +299,36 @@ def _warn_if_no_equation(self):
else:
return False

def _update_viewer_filters(self, event={}):
if event.get('new', self.cube_fit):
# only want image viewers in the options
self.add_results.viewer.filters = ['is_image_viewer']
else:
# only want spectral viewers in the options
self.add_results.viewer.filters = ['is_spectrum_viewer']

@observe('cube_fit')
def _cube_fit_changed(self, event={}):
self._update_viewer_filters(event=event)

sb_unit = self.app._get_display_unit('sb')
spectral_y_unit = self.app._get_display_unit('spectral_y')
if event.get('new'):
self._units['y'] = sb_unit
self.dataset.add_filter('is_flux_cube')
self.dataset.remove_filter('layer_in_spectrum_viewer')
else:
self._units['y'] = spectral_y_unit
self.dataset.add_filter('layer_in_spectrum_viewer')
self.dataset.remove_filter('is_flux_cube')

self.dataset._clear_cache()
if sb_unit != spectral_y_unit:
# We make the user hit the reestimate button themselves
for model_index, comp_model in enumerate(self.component_models):
self.component_models[model_index]["compat_display_units"] = False
self.send_state('component_models')

@observe("dataset_selected")
def _dataset_selected_changed(self, event=None):
"""
Expand All @@ -335,11 +355,6 @@ def _dataset_selected_changed(self, event=None):
# Replace NaNs from collapsed Spectrum1D in Cubeviz
# (won't affect calculations because these locations are masked)
selected_spec.flux[np.isnan(selected_spec.flux)] = 0.0
# TODO: can we simplify this logic?
self._units["x"] = str(
selected_spec.spectral_axis.unit)
self._units["y"] = str(
selected_spec.flux.unit)

def _default_comp_label(self, model, poly_order=None):
abbrevs = {'BlackBody': 'BB', 'PowerLaw': 'PL', 'Lorentz1D': 'Lo'}
Expand Down Expand Up @@ -454,6 +469,16 @@ def _initialize_model_component(self, model_comp, comp_label, poly_order=None):
"parameters": [], "model_kwargs": {}}
model_cls = MODELS[model_comp]

# Need to set the units the first time we initialize a model component, after this
# we listen for display unit changes
if (self._units is None or self._units == {} or 'x' not in self._units or
'y' not in self._units):
self._units['x'] = self.app._get_display_unit('spectral')
if self.cube_fit:
self._units['y'] = self.app._get_display_unit('sb')
else:
self._units['y'] = self.app._get_display_unit('spectral_y')

if model_comp == "Polynomial1D":
# self.poly_order is the value in the widget for creating
# the new model component. We need to store that with the
Expand Down Expand Up @@ -482,18 +507,43 @@ def _initialize_model_component(self, model_comp, comp_label, poly_order=None):

initial_values[param_name] = initial_val

masked_spectrum = self._apply_subset_masks(self.dataset.selected_spectrum,
self.spectral_subset)
if self.cube_fit:
# We need to input the whole cube when initializing the model so the units are correct.
if self.dataset_selected in self.app.data_collection.labels:
data = self.app.data_collection[self.dataset_selected].get_object(statistic=None)
else: # User selected some subset from spectrum viewer, just use original cube
data = self.app.data_collection[0].get_object(statistic=None)
masked_spectrum = self._apply_subset_masks(data, self.spectral_subset)
else:
masked_spectrum = self._apply_subset_masks(self.dataset.selected_spectrum,
self.spectral_subset)
mask = masked_spectrum.mask
if mask is not None:
if mask.ndim == 3:
spectral_mask = mask.all(axis=(0, 1))
else:
spectral_mask = mask
init_x = masked_spectrum.spectral_axis[~spectral_mask]
orig_flux_shape = masked_spectrum.flux.shape
init_y = masked_spectrum.flux[~mask]
if mask.ndim == 3:
init_y = init_y.reshape(orig_flux_shape[0],
orig_flux_shape[1],
len(init_x))
else:
init_x = masked_spectrum.spectral_axis
init_y = masked_spectrum.flux

init_y = init_y.to(self._units['y'], u.spectral_density(init_x))

initialized_model = initialize(
MODELS[model_comp](name=comp_label,
**initial_values,
**new_model.get("model_kwargs", {})),
masked_spectrum.spectral_axis[~mask] if mask is not None else masked_spectrum.spectral_axis, # noqa
masked_spectrum.flux[~mask] if mask is not None else masked_spectrum.flux)
init_x, init_y)

# need to loop over parameters again as the initializer may have overridden
# the original default value.
# the original default value. However, if we toggled cube_fit, we may need to override
for param_name in get_model_parameters(model_cls, new_model["model_kwargs"]):
param_quant = getattr(initialized_model, param_name)
new_model["parameters"].append({"name": param_name,
Expand Down Expand Up @@ -535,6 +585,15 @@ def _on_global_display_unit_changed(self, msg):
else:
return

if axis == 'y' and self.cube_fit:
# The units have to be in surface brightness for a cube fit.
uc = self.app._jdaviz_helper.plugins['Unit Conversion']

if msg.unit != uc._obj.sb_unit_selected:
self._units[axis] = uc._obj.sb_unit_selected
self._check_model_component_compat([axis], [u.Unit(uc._obj.sb_unit_selected)])
return

# update internal tracking of current units
self._units[axis] = str(msg.unit)

Expand Down Expand Up @@ -753,15 +812,6 @@ def _set_default_results_label(self, event={}):
def _set_residuals_label_default(self, event={}):
self.residuals_label_default = self.results_label+" residuals"

@observe("cube_fit")
def _update_viewer_filters(self, event={}):
if event.get('new', self.cube_fit):
# only want image viewers in the options
self.add_results.viewer.filters = ['is_image_viewer']
else:
# only want spectral viewers in the options
self.add_results.viewer.filters = ['is_spectrum_viewer']

@with_spinner()
def calculate_fit(self, add_data=True):
"""
Expand Down Expand Up @@ -907,6 +957,10 @@ def _fit_model_to_cube(self, add_data):
else:
spec = data.get_object(cls=Spectrum1D, statistic=None)

sb_unit = self.app._get_display_unit('sb')
if spec.flux.unit != sb_unit:
spec = spec.with_flux_unit(sb_unit)

snackbar_message = SnackbarMessage(
"Fitting model to cube...",
loading=True, sender=self)
Expand Down
47 changes: 47 additions & 0 deletions jdaviz/configs/default/plugins/model_fitting/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_model_ids(cubeviz_helper, spectral_cube_wcs):
plugin.vue_add_model({})


@pytest.mark.skip(reason="Needs #3156 after merging #3190")
def test_parameter_retrieval(cubeviz_helper, spectral_cube_wcs):
flux = np.ones((3, 4, 5))
flux[2, 2, :] = [1, 2, 3, 4, 5]
Expand Down Expand Up @@ -381,6 +382,7 @@ def test_incompatible_units(specviz_helper, spectrum1d):
mf.calculate_fit(add_data=True)


@pytest.mark.skip(reason="Needs #3156 after merging #3190")
def test_cube_fit_with_nans(cubeviz_helper):
flux = np.ones((7, 8, 9)) * u.nJy
flux[:, :, 0] = np.nan
Expand All @@ -396,7 +398,12 @@ def test_cube_fit_with_nans(cubeviz_helper):
result = cubeviz_helper.app.data_collection['model']
assert np.all(result.get_component("flux").data == 1)

# Switch back to non-cube fit, check that units are marked incompatible
mf.cube_fit = False
assert mf._obj.component_models[0]['compat_display_units'] is False


@pytest.mark.skip(reason="Needs #3156 after merging #3190")
def test_cube_fit_with_subset_and_nans(cubeviz_helper):
# Also test with existing mask
flux = np.ones((7, 8, 9)) * u.nJy
Expand All @@ -417,3 +424,43 @@ def test_cube_fit_with_subset_and_nans(cubeviz_helper):
mf.calculate_fit()
result = cubeviz_helper.app.data_collection['model']
assert np.all(result.get_component("flux").data == 1)


def test_cube_fit_after_unit_change(cubeviz_helper, spectrum1d_cube_fluxunit_jy_per_steradian):
cubeviz_helper.load_data(spectrum1d_cube_fluxunit_jy_per_steradian, data_label="test")

uc = cubeviz_helper.plugins['Unit Conversion']
mf = cubeviz_helper.plugins['Model Fitting']
uc.flux_unit = "MJy"
mf.cube_fit = True

mf.create_model_component("Const1D")
# Check that the parameter is using the current units when initialized
assert mf._obj.component_models[0]['parameters'][0]['unit'] == 'MJy / sr'

with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
mf.calculate_fit()

expected_result_slice = np.array([[9.00e-05, 9.50e-05, 1.00e-04, 1.05e-04],
[9.10e-05, 9.60e-05, 1.01e-04, 1.06e-04],
[9.20e-05, 9.70e-05, 1.02e-04, 1.07e-04],
[9.30e-05, 9.80e-05, 1.03e-04, 1.08e-04],
[9.40e-05, 9.90e-05, 1.04e-04, 1.09e-04]])

model_flux = cubeviz_helper.app.data_collection[-1].get_component('flux')
assert model_flux.units == 'MJy / sr'
assert np.allclose(model_flux.data[:, :, 1], expected_result_slice)

# Switch back to Jy, see that the component didn't change but the output does
uc.flux_unit = 'Jy'
assert mf._obj.component_models[0]['parameters'][0]['unit'] == 'MJy / sr'
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='Model is linear in parameters.*')
mf.calculate_fit()

model_flux = cubeviz_helper.app.data_collection[-1].get_component('flux')
assert model_flux.units == 'Jy / sr'
assert np.allclose(model_flux.data[:, :, 1], expected_result_slice * 1e6)

# ToDo: Add a test for a unit change that needs an equivalency
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def test_register_model_uncertainty_is_none(specviz_helper, spectrum1d):
assert np.allclose(param["std"], expected_uncertainties[param["name"]], rtol=0.01)


@pytest.mark.skip(reason="Needs #3156 after merging #3190")
def test_register_cube_model(cubeviz_helper, spectrum1d_cube):
with warnings.catch_warnings():
warnings.simplefilter('ignore')
Expand All @@ -155,6 +156,7 @@ def test_register_cube_model(cubeviz_helper, spectrum1d_cube):
assert test_label in cubeviz_helper.app.data_collection


@pytest.mark.skip(reason="Needs #3156 after merging #3190")
def test_fit_cube_no_wcs(cubeviz_helper):
# This is like when user do something to a cube outside of Jdaviz
# and then load it back into a new instance of Cubeviz for further analysis.
Expand All @@ -163,6 +165,8 @@ def test_fit_cube_no_wcs(cubeviz_helper):
mf = cubeviz_helper.plugins['Model Fitting']
mf.create_model_component('Linear1D')
mf.cube_fit = True
# Need to manually reestimate the parameters to update the units
mf.reestimate_model_parameters()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Model is linear in parameters.*")
fitted_model, output_cube = mf.calculate_fit(add_data=True)
Expand Down

0 comments on commit 91c6ff3

Please sign in to comment.