Skip to content

Commit

Permalink
enable clone viewer for image/TPF viewer (#101)
Browse files Browse the repository at this point in the history
* enable clone viewer tool in image viewer

(currently results in traceback, as shown in regression test)

* include cube viewer in viewer creator

* generalize clone viewer logic to work for multliple viewer types

and to also adopt layer plot options

* remove cube from viewer creator when no cube data remaining
  • Loading branch information
kecnry authored Apr 4, 2024
1 parent 74e88e9 commit 2a8a2f9
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 42 deletions.
7 changes: 7 additions & 0 deletions lcviz/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ def default_time_viewer(self):
raise ValueError("no time viewers exist")
return tvs[0].user_api

@property
def _has_cube_data(self):
for data in self.app.data_collection:
if data.ndim == 3:
return True
return False

@property
def _tray_tools(self):
"""
Expand Down
23 changes: 19 additions & 4 deletions lcviz/plugins/viewer_creator/viewer_creator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from glue.core.message import (DataCollectionAddMessage,
DataCollectionDeleteMessage)
from jdaviz.configs.default.plugins import ViewerCreator
from jdaviz.core.events import NewViewerMessage
from jdaviz.core.registries import tool_registry
from lcviz.events import EphemerisComponentChangedMessage
from lcviz.viewers import TimeScatterView
from lcviz.viewers import TimeScatterView, CubeView

__all__ = ['ViewerCreator']

Expand All @@ -12,8 +14,11 @@ class ViewerCreator(ViewerCreator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.hub.subscribe(self, EphemerisComponentChangedMessage,
handler=self._rebuild_available_viewers)
for msg in (EphemerisComponentChangedMessage,
DataCollectionAddMessage,
DataCollectionDeleteMessage):
self.hub.subscribe(self, msg,
handler=lambda x: self._rebuild_available_viewers())
self._rebuild_available_viewers()

def _rebuild_available_viewers(self, *args):
Expand All @@ -24,13 +29,18 @@ def _rebuild_available_viewers(self, *args):
if self.app._jdaviz_helper is not None:
phase_viewers = [{'name': f'lcviz-phase-viewer:{e}', 'label': f'flux-vs-phase:{e}'}
for e in self.app._jdaviz_helper.plugins['Ephemeris'].component.choices] # noqa
if self.app._jdaviz_helper._has_cube_data:
cube_viewers = [{'name': 'lcviz-cube-viewer', 'label': 'image'}]
else:
cube_viewers = []
else:
phase_viewers = [{'name': 'lcviz-phase-viewer:default',
'label': 'flux-vs-phase:default'}]
cube_viewers = []

self.viewer_types = [v for v in self.viewer_types if v['name'].startswith('lcviz')
and not v['label'].startswith('flux-vs-phase')
and not v['label'] == 'cube'] + phase_viewers
and not v['label'] in ('cube', 'image')] + phase_viewers + cube_viewers
self.send_state('viewer_types')

def vue_create_viewer(self, name):
Expand All @@ -45,5 +55,10 @@ def vue_create_viewer(self, name):
self.app._on_new_viewer(NewViewerMessage(TimeScatterView, data=None, sender=self.app),
vid=viewer_id, name=viewer_id)
return
if name in ('image', 'lcviz-cube-viewer'):
viewer_id = self.app._jdaviz_helper._get_clone_viewer_reference('image')
self.app._on_new_viewer(NewViewerMessage(CubeView, data=None, sender=self.app),
vid=viewer_id, name=viewer_id)
return

super().vue_create_viewer(name)
17 changes: 17 additions & 0 deletions lcviz/tests/test_tray_viewer_creator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import pytest


@pytest.mark.remote_data
def test_tray_viewer_creator(helper, light_curve_like_kepler_quarter):
# additional coverage in test_plugin_ephemeris
helper.load_data(light_curve_like_kepler_quarter)
Expand All @@ -7,3 +11,16 @@ def test_tray_viewer_creator(helper, light_curve_like_kepler_quarter):
assert len(vc.viewer_types) == 2 # time and default phase
vc.vue_create_viewer('flux-vs-time')
assert len(helper.viewers) == 2

# TODO: replace with test fixture
from lightkurve import search_targetpixelfile
tpf = search_targetpixelfile("KIC 001429092",
mission="Kepler",
cadence="long",
quarter=10).download()
helper.load_data(tpf)
assert len(helper.viewers) == 3 # image viewer added by default

assert len(vc.viewer_types) == 3 # time, default phase, cube
vc.vue_create_viewer('image')
assert len(helper.viewers) == 4
14 changes: 14 additions & 0 deletions lcviz/tests/test_viewers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest


def test_reset_limits(helper, light_curve_like_kepler_quarter):
helper.load_data(light_curve_like_kepler_quarter)
Expand All @@ -19,6 +21,7 @@ def test_reset_limits(helper, light_curve_like_kepler_quarter):
assert tv.state.y_min == orig_ylims[0]


@pytest.mark.remote_data
def test_clone(helper, light_curve_like_kepler_quarter):
helper.load_data(light_curve_like_kepler_quarter)

Expand All @@ -27,3 +30,14 @@ def test_clone(helper, light_curve_like_kepler_quarter):

new_viewer = def_viewer._obj.clone_viewer()
assert helper._get_clone_viewer_reference(new_viewer._obj.reference) == 'flux-vs-time[2]'

# TODO: replace with test fixture
from lightkurve import search_targetpixelfile
tpf = search_targetpixelfile("KIC 001429092",
mission="Kepler",
cadence="long",
quarter=10).download()
helper.load_data(tpf)
im_viewer = helper.viewers['image']
assert helper._get_clone_viewer_reference(im_viewer._obj.reference) == 'image[1]'
im_viewer._obj.clone_viewer()
55 changes: 17 additions & 38 deletions lcviz/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,34 @@ def _get_clone_viewer_reference(self):
return name

def clone_viewer(self):
name = self._get_clone_viewer_reference()
name = self.jdaviz_helper._get_clone_viewer_reference(self.reference)

self.jdaviz_app._on_new_viewer(NewViewerMessage(self.__class__,
data=None,
sender=self.jdaviz_app),
vid=name, name=name)

this_viewer_item = self.jdaviz_app._get_viewer_item(self.reference)
this_state = self.state.as_dict()
for data in self.jdaviz_app.data_collection:
data_id = self.jdaviz_app._data_id_from_label(data.label)
visible = this_viewer_item['selected_data_items'].get(data_id, 'hidden')
self.jdaviz_app.set_data_visibility(name, data.label, visible == 'visible')
for data_id, visible in this_viewer_item['selected_data_items'].items():
data_label = data_label = self.jdaviz_app._get_data_item_by_id(data_id)['name']
self.jdaviz_app.set_data_visibility(name, data_label, visible == 'visible')
# TODO: don't revert color when adding same data to a new viewer
# (same happens when creating a phase-viewer from ephemeris plugin)

new_viewer = self.jdaviz_helper.viewers[name]._obj
for k, v in this_state.items():
new_viewer = self.jdaviz_app.get_viewer(name)
if hasattr(self, 'ephemeris_component'):
new_viewer._ephemeris_component = self._ephemeris_component
for k, v in self.state.as_dict().items():
if k in ('layers',):
continue
setattr(new_viewer.state, k, v)

for this_layer_state, new_layer_state in zip(self.state.layers, new_viewer.state.layers):
for k, v in this_layer_state.as_dict().items():
if k in ('layer',):
continue
setattr(new_layer_state, k, v)

return new_viewer.user_api


Expand Down Expand Up @@ -253,33 +259,6 @@ def apply_roi(self, roi, use_current=False):

super().apply_roi(roi, use_current=use_current)

def clone_viewer(self):
name = self.jdaviz_helper._get_clone_viewer_reference(self.reference)

self.jdaviz_app._on_new_viewer(NewViewerMessage(self.__class__,
data=None,
sender=self.jdaviz_app),
vid=name, name=name)

this_viewer_item = self.jdaviz_app._get_viewer_item(self.reference)
this_state = self.state.as_dict()
for data in self.jdaviz_app.data_collection:
data_id = self.jdaviz_app._data_id_from_label(data.label)
visible = this_viewer_item['selected_data_items'].get(data_id, 'hidden')
self.jdaviz_app.set_data_visibility(name, data.label, visible == 'visible')
# TODO: don't revert color when adding same data to a new viewer
# (same happens when creating a phase-viewer from ephemeris plugin)

new_viewer = self.jdaviz_app.get_viewer(name)
if hasattr(self, 'ephemeris_component'):
new_viewer._ephemeris_component = self._ephemeris_component
for k, v in this_state.items():
if k in ('layers',):
continue
setattr(new_viewer.state, k, v)

return new_viewer.user_api


@viewer_registry("lcviz-phase-viewer", label="flux-vs-phase")
class PhaseScatterView(TimeScatterView):
Expand Down Expand Up @@ -320,7 +299,7 @@ class CubeView(CloneViewerMixin, CubevizImageView, WithSliceSelection):
['jdaviz:boxzoom'],
['jdaviz:panzoom'],
['bqplot:rectangle'],
['jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
['lcviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
]
# TODO: can we vary this default_class based on Kepler vs TESS, etc?
# see https://github.com/spacetelescope/lcviz/pull/81#discussion_r1469721009
Expand All @@ -337,8 +316,8 @@ def __init__(self, *args, **kwargs):
# Hide axes by default
self.state.show_axes = False

# TODO: refactor upstream so lcviz can inherit cubeviewer methods/setup with jdaviz-specific
# logic:
# TODO: refactor upstream so lcviz can inherit cubeviewer methods/setup without
# jdaviz-specific logic:
# * _default_spectrum_viewer_reference_name
# * _default_flux_viewer_reference_name
# * _default_uncert_viewer_reference_name
Expand Down

0 comments on commit 2a8a2f9

Please sign in to comment.