From 6982333f6e02da20a2d772659098bcc6276376c4 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Fri, 5 Apr 2024 15:35:32 -0400 Subject: [PATCH] Feature: TPF support (#82) * Adding TPF translator and parser (#75) * TPF viewer (#81) * make use of upstream refactor to override indices in lcviz (#83) * fix creating phase-viewer when TPF is loaded (#86) * Time Selector (adapted version of cubeviz's slice) plugin (#85) * enable clone viewer for image/TPF viewer (#101) --------- Co-authored-by: Brett M. Morris --- CHANGES.rst | 2 + docs/plugins.rst | 34 ++ docs/reference/api_plugins.rst | 3 + lcviz/helper.py | 10 +- lcviz/marks.py | 12 +- lcviz/parsers.py | 51 ++- lcviz/plugins/__init__.py | 1 + lcviz/plugins/coords_info/coords_info.py | 30 +- lcviz/plugins/ephemeris/ephemeris.py | 3 + lcviz/plugins/time_selector/__init__.py | 1 + lcviz/plugins/time_selector/time_selector.py | 82 +++++ .../plugins/viewer_creator/viewer_creator.py | 24 +- lcviz/tests/test_parser.py | 13 +- lcviz/tests/test_tray_viewer_creator.py | 17 + lcviz/tests/test_viewers.py | 14 + lcviz/utils.py | 300 +++++++++++++++++- lcviz/viewers.py | 168 ++++++++-- 17 files changed, 700 insertions(+), 65 deletions(-) create mode 100644 lcviz/plugins/time_selector/__init__.py create mode 100644 lcviz/plugins/time_selector/time_selector.py diff --git a/CHANGES.rst b/CHANGES.rst index 9a2b8fad..8b6883fc 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,8 @@ 0.4.0 - unreleased ------------------ +* Support loading, viewing, and slicing through TPF data cubes. [#82] + 0.3.0 - (04-05-2024) -------------------- diff --git a/docs/plugins.rst b/docs/plugins.rst index aae63e16..3981ec1d 100644 --- a/docs/plugins.rst +++ b/docs/plugins.rst @@ -179,6 +179,40 @@ visible when the plugin is opened. Jdaviz documentation on the Markers plugin. +.. _time-selector: + +Time Selector +============== + +The time selector plugin allows defining the time indicated in all light curve viewers +(time and phase viewers) as well as the time at which all image cubes are displayed. + + +.. admonition:: User API Example + :class: dropdown + + See the :class:`~lcviz.plugins.time_selector.time_selector.TimeSelector` user API documentation for more details. + + .. code-block:: python + + from lcviz import LCviz + lc = search_lightcurve("HAT-P-11", mission="Kepler", + cadence="long", quarter=10).download().flatten() + lcviz = LCviz() + lcviz.load_data(lc) + lcviz.show() + + ts = lcviz.plugins['Time Selector'] + ts.open_in_tray() + + +.. seealso:: + + :ref:`Jdaviz Slice Plugin ` + Jdaviz documentation on the Slice plugin. + + + .. _flatten: Flatten diff --git a/docs/reference/api_plugins.rst b/docs/reference/api_plugins.rst index 56149ce7..788ad683 100644 --- a/docs/reference/api_plugins.rst +++ b/docs/reference/api_plugins.rst @@ -29,3 +29,6 @@ Plugins API .. automodapi:: lcviz.plugins.subset_plugin.subset_plugin :no-inheritance-diagram: + +.. automodapi:: lcviz.plugins.time_selector.time_selector + :no-inheritance-diagram: diff --git a/lcviz/helper.py b/lcviz/helper.py index ca34d032..79f9c25a 100644 --- a/lcviz/helper.py +++ b/lcviz/helper.py @@ -69,7 +69,8 @@ class LCviz(ConfigHelper): 'toolbar': ['g-data-tools', 'g-subset-tools', 'g-viewer-creator', 'lcviz-coords-info'], 'tray': ['lcviz-metadata-viewer', 'flux-column', 'lcviz-plot-options', 'lcviz-subset-plugin', - 'lcviz-markers', 'flatten', 'frequency-analysis', 'ephemeris', + 'lcviz-markers', 'time-selector', + 'flatten', 'frequency-analysis', 'ephemeris', 'binning', 'lcviz-export'], 'viewer_area': [{'container': 'col', 'children': [{'container': 'row', @@ -150,6 +151,13 @@ def default_time_viewer(self): raise ValueError("no time viewers exist") return tvs[0].user_api + @property + def _has_cube_data(self): + for data in self.app.data_collection: + if data.ndim == 3: + return True + return False + @property def _tray_tools(self): """ diff --git a/lcviz/marks.py b/lcviz/marks.py index b5c89da5..28d3b736 100644 --- a/lcviz/marks.py +++ b/lcviz/marks.py @@ -1,11 +1,21 @@ +from astropy import units as u import numpy as np -from jdaviz.core.marks import PluginLine, PluginScatter +from jdaviz.core.marks import PluginLine, PluginScatter, SliceIndicatorMarks from lcviz.viewers import PhaseScatterView __all__ = ['LivePreviewTrend', 'LivePreviewFlattened', 'LivePreviewBinning'] +def _slice_indicator_get_slice_axis(self, data): + if hasattr(data, 'time'): + return data.time.value * u.d + return [] * u.dimensionless_unscaled + + +SliceIndicatorMarks._get_slice_axis = _slice_indicator_get_slice_axis + + class WithoutPhaseSupport: def update_ty(self, times, y): self.times = np.asarray(times) diff --git a/lcviz/parsers.py b/lcviz/parsers.py index 70d8673d..c29a41e4 100644 --- a/lcviz/parsers.py +++ b/lcviz/parsers.py @@ -10,6 +10,13 @@ @data_parser_registry("light_curve_parser") def light_curve_parser(app, file_obj, data_label=None, show_in_viewer=True, **kwargs): + # load a LightCurve or TargetPixelFile object: + cls_with_translator = ( + lightkurve.LightCurve, + lightkurve.targetpixelfile.KeplerTargetPixelFile, + lightkurve.targetpixelfile.TessTargetPixelFile + ) + # load local FITS file from disk by its path: if isinstance(file_obj, str) and os.path.exists(file_obj): if data_label is None: @@ -18,8 +25,7 @@ def light_curve_parser(app, file_obj, data_label=None, show_in_viewer=True, **kw # read the light curve: light_curve = lightkurve.read(file_obj) - # load a LightCurve object: - elif isinstance(file_obj, lightkurve.LightCurve): + elif isinstance(file_obj, cls_with_translator): light_curve = file_obj # make a data label: @@ -30,7 +36,12 @@ def light_curve_parser(app, file_obj, data_label=None, show_in_viewer=True, **kw # handle flux_origin default flux_origin = light_curve.meta.get('FLUX_ORIGIN', None) # i.e. PDCSAP or SAP - if flux_origin == 'flux' or (flux_origin is None and 'flux' in light_curve.columns): + if isinstance(light_curve, lightkurve.targetpixelfile.TargetPixelFile): + new_data_label += '[TPF]' + elif flux_origin is not None: + new_data_label += f'[{flux_origin}]' + + if flux_origin == 'flux' or (flux_origin is None and 'flux' in getattr(light_curve, 'columns', [])): # noqa # then make a copy of this column so it won't be lost when changing with the flux_column # plugin light_curve['flux:orig'] = light_curve['flux'] @@ -41,13 +52,33 @@ def light_curve_parser(app, file_obj, data_label=None, show_in_viewer=True, **kw data = _data_with_reftime(app, light_curve) app.add_data(data, new_data_label) - if show_in_viewer: - # add to any known time/phase viewers - for viewer_id, viewer in app._viewer_store.items(): - if isinstance(viewer, TimeScatterView): - app.add_data_to_viewer(viewer_id, new_data_label) - elif isinstance(viewer, PhaseScatterView): - app.add_data_to_viewer(viewer_id, new_data_label) + if isinstance(light_curve, lightkurve.targetpixelfile.TargetPixelFile): + # ensure an image/cube/TPF viewer exists + # TODO: move this to an event listener on add_data so that we can also remove when empty? + from jdaviz.core.events import NewViewerMessage + from lcviz.viewers import CubeView + if show_in_viewer: + found_viewer = False + for viewer_id, viewer in app._viewer_store.items(): + if isinstance(viewer, CubeView): + app.add_data_to_viewer(viewer_id, new_data_label) + found_viewer = True + if not found_viewer: + app._on_new_viewer(NewViewerMessage(CubeView, data=None, sender=app), + vid='image', name='image') + app.add_data_to_viewer('image', new_data_label) + + else: + if show_in_viewer: + for viewer_id, viewer in app._viewer_store.items(): + if isinstance(viewer, (TimeScatterView, PhaseScatterView)): + app.add_data_to_viewer(viewer_id, new_data_label) + + # add to any known phase viewers + ephem_plugin = app._jdaviz_helper.plugins.get('Ephemeris', None) + if ephem_plugin is not None: + for viewer in ephem_plugin._obj._get_phase_viewers(): + app.add_data_to_viewer(viewer.reference, new_data_label) def _data_with_reftime(app, light_curve): diff --git a/lcviz/plugins/__init__.py b/lcviz/plugins/__init__.py index 9fbeef06..08d8668a 100644 --- a/lcviz/plugins/__init__.py +++ b/lcviz/plugins/__init__.py @@ -8,6 +8,7 @@ from .flux_column.flux_column import * # noqa from .frequency_analysis.frequency_analysis import * # noqa from .markers.markers import * # noqa +from .time_selector.time_selector import * # noqa from .metadata_viewer.metadata_viewer import * # noqa from .plot_options.plot_options import * # noqa from .subset_plugin.subset_plugin import * # noqa diff --git a/lcviz/plugins/coords_info/coords_info.py b/lcviz/plugins/coords_info/coords_info.py index 156c8dc8..6e442b39 100644 --- a/lcviz/plugins/coords_info/coords_info.py +++ b/lcviz/plugins/coords_info/coords_info.py @@ -5,14 +5,14 @@ from jdaviz.core.events import ViewerRenamedMessage from jdaviz.core.registries import tool_registry -from lcviz.viewers import TimeScatterView, PhaseScatterView +from lcviz.viewers import TimeScatterView, PhaseScatterView, CubeView __all__ = ['CoordsInfo'] @tool_registry('lcviz-coords-info') class CoordsInfo(CoordsInfo): - _supported_viewer_classes = (TimeScatterView, PhaseScatterView) + _supported_viewer_classes = (TimeScatterView, PhaseScatterView, CubeView) _viewer_classes_with_marker = (TimeScatterView, PhaseScatterView) def __init__(self, *args, **kwargs): @@ -25,12 +25,19 @@ def __init__(self, *args, **kwargs): def _viewer_renamed(self, msg): self._marks[msg.new_viewer_ref] = self._marks.pop(msg.old_viewer_ref) - def update_display(self, viewer, x, y): - self._dict = {} + def _image_shape_inds(self, image): + if image.ndim == 3: + # exception to the upstream cubeviz case of (0, 1) + return (2, 1) + return super()._image_shape_inds(image) - if not len(viewer.state.layers): - return + def _get_cube_value(self, image, arr, x, y, viewer): + if image.ndim == 3: + # exception to the upstream cubeviz case of x, y, slice + return arr[viewer.state.slices[0], int(round(y)), int(round(x))] + return super()._get_cube_value(image, arr, x, y, viewer) + def _lc_viewer_update(self, viewer, x, y): is_phase = isinstance(viewer, PhaseScatterView) # TODO: update with display_unit when supported in lcviz x_unit = '' if is_phase else str(viewer.time_unit) @@ -138,3 +145,14 @@ def _cursor_fallback(): self.marks[viewer._reference_id].update_xy([closest_x], [closest_y]) # noqa self.marks[viewer._reference_id].visible = True + + def update_display(self, viewer, x, y): + self._dict = {} + + if not len(viewer.state.layers): + return + + if isinstance(viewer, (TimeScatterView, PhaseScatterView)): + self._lc_viewer_update(viewer, x, y) + elif isinstance(viewer, CubeView): + self._image_viewer_update(viewer, x, y) diff --git a/lcviz/plugins/ephemeris/ephemeris.py b/lcviz/plugins/ephemeris/ephemeris.py index 2f263ac0..8bc5dcdb 100644 --- a/lcviz/plugins/ephemeris/ephemeris.py +++ b/lcviz/plugins/ephemeris/ephemeris.py @@ -307,6 +307,9 @@ def create_phase_viewer(self, ephem_component=None): # set default data visibility time_viewer_item = self.app._get_viewer_item(self.app._jdaviz_helper.default_time_viewer._obj.reference) # noqa for data in dc: + if data.ndim > 1: + # skip image/cube entries + continue data_id = self.app._data_id_from_label(data.label) visible = time_viewer_item['selected_data_items'].get(data_id, 'hidden') self.app.set_data_visibility(phase_viewer_id, data.label, visible == 'visible') diff --git a/lcviz/plugins/time_selector/__init__.py b/lcviz/plugins/time_selector/__init__.py new file mode 100644 index 00000000..0983e6d5 --- /dev/null +++ b/lcviz/plugins/time_selector/__init__.py @@ -0,0 +1 @@ +from .time_selector import * # noqa diff --git a/lcviz/plugins/time_selector/time_selector.py b/lcviz/plugins/time_selector/time_selector.py new file mode 100644 index 00000000..da1f40d8 --- /dev/null +++ b/lcviz/plugins/time_selector/time_selector.py @@ -0,0 +1,82 @@ +from jdaviz.configs.cubeviz.plugins import Slice +from jdaviz.core.registries import tray_registry + +from lcviz.events import EphemerisChangedMessage +from lcviz.viewers import CubeView, PhaseScatterView + +__all__ = ['TimeSelector'] + + +@tray_registry('time-selector', label="Time Selector") +class TimeSelector(Slice): + """ + See the :ref:`Time Selector Plugin Documentation ` for more details. + + Only the following attributes and methods are available through the + :ref:`public plugin API `: + + * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.show` + * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.open_in_tray` + * :meth:`~jdaviz.core.template_mixin.PluginTemplateMixin.close_in_tray` + * ``value`` Time of the indicator. When setting this directly, it will + update automatically to the value corresponding to the nearest slice, if ``snap_to_slice`` is + enabled and a cube is loaded. + * ``show_indicator`` + Whether to show indicator in spectral viewer when slice tool is inactive. + * ``show_value`` + Whether to show slice value in label to right of indicator. + * ``snap_to_slice`` + Whether the indicator (and ``value``) should snap to the value of the nearest slice in the + cube (if one exists). + """ + _cube_viewer_cls = CubeView + _cube_viewer_default_label = 'image' + + def __init__(self, *args, **kwargs): + """ + + """ + super().__init__(*args, **kwargs) + self.docs_link = f"https://lcviz.readthedocs.io/en/{self.vdocs}/plugins.html#time-selector" + self.docs_description = "Select time to sync across all viewers (as an indicator in all time/phase viewers or to select the active slice in any image/cube viewers). The slice can also be changed interactively in any time viewer by activating the slice tool." # noqa + self.value_label = 'Time' + self.value_unit = 'd' + self.allow_disable_snapping = True + + self.session.hub.subscribe(self, EphemerisChangedMessage, + handler=self._on_ephemeris_changed) + + @property + def slice_axis(self): + # global display unit "axis" corresponding to the slice axis + return 'time' + + @property + def valid_slice_att_names(self): + return ["time", "dt"] + + @property + def user_api(self): + api = super().user_api + # can be removed after deprecated upstream attributes for wavelength/wavelength_value + # are removed in the lowest supported version of jdaviz + api._expose = [e for e in api._expose if e not in ('slice', 'wavelength', + 'wavelength_value', 'show_wavelength')] + return api + + def _on_select_slice_message(self, msg): + viewer = msg.sender.viewer + if isinstance(viewer, PhaseScatterView): + prev_phase = viewer.times_to_phases(self.value) + new_phase = msg.value + self.value = self.value + (new_phase - prev_phase) * viewer.ephemeris.get('period', 1.0) + else: + super()._on_select_slice_message(msg) + + def _on_ephemeris_changed(self, msg): + for viewer in self.slice_indicator_viewers: + if not isinstance(viewer, PhaseScatterView): + continue + if viewer._ephemeris_component != msg.ephemeris_label: + continue + viewer._set_slice_indicator_value(self.value) diff --git a/lcviz/plugins/viewer_creator/viewer_creator.py b/lcviz/plugins/viewer_creator/viewer_creator.py index 69b29298..2c401f1c 100644 --- a/lcviz/plugins/viewer_creator/viewer_creator.py +++ b/lcviz/plugins/viewer_creator/viewer_creator.py @@ -1,8 +1,10 @@ +from glue.core.message import (DataCollectionAddMessage, + DataCollectionDeleteMessage) from jdaviz.configs.default.plugins import ViewerCreator from jdaviz.core.events import NewViewerMessage from jdaviz.core.registries import tool_registry from lcviz.events import EphemerisComponentChangedMessage -from lcviz.viewers import TimeScatterView +from lcviz.viewers import TimeScatterView, CubeView __all__ = ['ViewerCreator'] @@ -13,8 +15,11 @@ class ViewerCreator(ViewerCreator): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.hub.subscribe(self, EphemerisComponentChangedMessage, - handler=self._rebuild_available_viewers) + for msg in (EphemerisComponentChangedMessage, + DataCollectionAddMessage, + DataCollectionDeleteMessage): + self.hub.subscribe(self, msg, + handler=lambda x: self._rebuild_available_viewers()) self._rebuild_available_viewers() def _rebuild_available_viewers(self, *args): @@ -25,12 +30,18 @@ def _rebuild_available_viewers(self, *args): if self.app._jdaviz_helper is not None: phase_viewers = [{'name': f'lcviz-phase-viewer:{e}', 'label': f'flux-vs-phase:{e}'} for e in self.app._jdaviz_helper.plugins['Ephemeris'].component.choices] # noqa + if self.app._jdaviz_helper._has_cube_data: + cube_viewers = [{'name': 'lcviz-cube-viewer', 'label': 'image'}] + else: + cube_viewers = [] else: phase_viewers = [{'name': 'lcviz-phase-viewer:default', 'label': 'flux-vs-phase:default'}] + cube_viewers = [] self.viewer_types = [v for v in self.viewer_types if v['name'].startswith('lcviz') - and not v['label'].startswith('flux-vs-phase')] + phase_viewers + and not v['label'].startswith('flux-vs-phase') + and not v['label'] in ('cube', 'image')] + phase_viewers + cube_viewers self.send_state('viewer_types') def vue_create_viewer(self, name): @@ -45,5 +56,10 @@ def vue_create_viewer(self, name): self.app._on_new_viewer(NewViewerMessage(TimeScatterView, data=None, sender=self.app), vid=viewer_id, name=viewer_id) return + if name in ('image', 'lcviz-cube-viewer'): + viewer_id = self.app._jdaviz_helper._get_clone_viewer_reference('image') + self.app._on_new_viewer(NewViewerMessage(CubeView, data=None, sender=self.app), + vid=viewer_id, name=viewer_id) + return super().vue_create_viewer(name) diff --git a/lcviz/tests/test_parser.py b/lcviz/tests/test_parser.py index e3c4cd20..a37b343e 100644 --- a/lcviz/tests/test_parser.py +++ b/lcviz/tests/test_parser.py @@ -3,7 +3,7 @@ from glue.core.roi import XRangeROI, YRangeROI from astropy.time import Time from astropy.utils.data import download_file -from lightkurve import LightCurve +from lightkurve import LightCurve, KeplerTargetPixelFile, search_targetpixelfile from lightkurve.io import kepler import astropy.units as u @@ -50,6 +50,17 @@ def test_kepler_via_mast_preparsed(helper): assert flux.unit.is_equivalent(u.electron / u.s) +@pytest.mark.remote_data +def test_kepler_tpf_via_lightkurve(helper): + tpf = search_targetpixelfile("KIC 001429092", + mission="Kepler", + cadence="long", + quarter=10).download() + helper.load_data(tpf) + assert helper.get_data().shape == (4447, 4, 6) # (time, x, y) + assert helper.app.data_collection[0].get_object(cls=KeplerTargetPixelFile).shape == (4447, 4, 6) + + def test_synthetic_lc(helper): time = Time(np.linspace(2460050, 2460060), format='jd') flux = np.ones(len(time)) * u.electron / u.s diff --git a/lcviz/tests/test_tray_viewer_creator.py b/lcviz/tests/test_tray_viewer_creator.py index b48fc5b2..26d85fe7 100644 --- a/lcviz/tests/test_tray_viewer_creator.py +++ b/lcviz/tests/test_tray_viewer_creator.py @@ -1,3 +1,7 @@ +import pytest + + +@pytest.mark.remote_data def test_tray_viewer_creator(helper, light_curve_like_kepler_quarter): # additional coverage in test_plugin_ephemeris helper.load_data(light_curve_like_kepler_quarter) @@ -7,3 +11,16 @@ def test_tray_viewer_creator(helper, light_curve_like_kepler_quarter): assert len(vc.viewer_types) == 2 # time and default phase vc.vue_create_viewer('flux-vs-time') assert len(helper.viewers) == 2 + + # TODO: replace with test fixture + from lightkurve import search_targetpixelfile + tpf = search_targetpixelfile("KIC 001429092", + mission="Kepler", + cadence="long", + quarter=10).download() + helper.load_data(tpf) + assert len(helper.viewers) == 3 # image viewer added by default + + assert len(vc.viewer_types) == 3 # time, default phase, cube + vc.vue_create_viewer('image') + assert len(helper.viewers) == 4 diff --git a/lcviz/tests/test_viewers.py b/lcviz/tests/test_viewers.py index e162ebdd..2f98ac7a 100644 --- a/lcviz/tests/test_viewers.py +++ b/lcviz/tests/test_viewers.py @@ -1,3 +1,5 @@ +import pytest + def test_reset_limits(helper, light_curve_like_kepler_quarter): helper.load_data(light_curve_like_kepler_quarter) @@ -19,6 +21,7 @@ def test_reset_limits(helper, light_curve_like_kepler_quarter): assert tv.state.y_min == orig_ylims[0] +@pytest.mark.remote_data def test_clone(helper, light_curve_like_kepler_quarter): helper.load_data(light_curve_like_kepler_quarter) @@ -27,3 +30,14 @@ def test_clone(helper, light_curve_like_kepler_quarter): new_viewer = def_viewer._obj.clone_viewer() assert helper._get_clone_viewer_reference(new_viewer._obj.reference) == 'flux-vs-time[2]' + + # TODO: replace with test fixture + from lightkurve import search_targetpixelfile + tpf = search_targetpixelfile("KIC 001429092", + mission="Kepler", + cadence="long", + quarter=10).download() + helper.load_data(tpf) + im_viewer = helper.viewers['image'] + assert helper._get_clone_viewer_reference(im_viewer._obj.reference) == 'image[1]' + im_viewer._obj.clone_viewer() diff --git a/lcviz/utils.py b/lcviz/utils.py index 2020fc81..a61dc1d9 100644 --- a/lcviz/utils.py +++ b/lcviz/utils.py @@ -1,21 +1,32 @@ from glue.config import data_translator from glue.core import Data, Subset from ipyvue import watch +import warnings import os from glue.core.coordinates import Coordinates from glue.core.component_id import ComponentID import numpy as np from scipy.interpolate import interp1d -from astropy import units as u -from astropy.table import QTable -from astropy.time import Time from lightkurve import ( LightCurve, KeplerLightCurve, TessLightCurve, FoldedLightCurve ) +from lightkurve.targetpixelfile import ( + KeplerTargetPixelFile, TessTargetPixelFile, TargetPixelFileFactory +) +from lightkurve.utils import KeplerQualityFlags, TessQualityFlags + +from astropy import units as u +from astropy.table import QTable +from astropy.time import Time +from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper +from astropy.wcs.wcsapi import HighLevelWCSMixin + +__all__ = ['TimeCoordinates', 'LightCurveHandler', 'data_not_folded', 'enable_hot_reloading'] -__all__ = ['TimeCoordinates', 'LightCurveHandler', 'data_not_folded'] + +component_ids = {'dt': ComponentID('dt')} class TimeCoordinates(Coordinates): @@ -60,12 +71,105 @@ def pixel_to_world_values(self, *pixel): )(pixel[0]) -__all__ = ['LightCurveHandler', 'enable_hot_reloading'] +class PaddedTimeWCS(BaseWCSWrapper, HighLevelWCSMixin): + + # Spectrum1D can use a 1D spectral WCS even for n-dimensional + # datasets while glue always needs the dimensionality to match, + # so this class pads the WCS so that it is n-dimensional. + + # NOTE: This class could be updated to use CompoundLowLevelWCS from NDCube. + + def __init__(self, wcs, times, ndim=3, reference_time=None, unit=u.d): + self.temporal_wcs = TimeCoordinates( + times, reference_time=reference_time, unit=unit + ) + self.spatial_wcs = wcs + self.flux_ndim = ndim + self.spatial_keys = [f"spatial{i}" for i in range(0, self.flux_ndim-1)] + + @property + def time_axis(self): + return self.temporal_wcs.time_axis + + @property + def pixel_n_dim(self): + return self.flux_ndim + + @property + def world_n_dim(self): + return self.flux_ndim + + @property + def world_axis_physical_types(self): + return [self.temporal_wcs.world_axis_physical_types[0], *[None]*(self.flux_ndim-1)] + + @property + def world_axis_units(self): + return (self.temporal_wcs.world_axis_units[0], *[None]*(self.flux_ndim-1)) + + def pixel_to_world_values(self, *pixel_arrays): + # The ravel and reshape are needed because of + # https://github.com/astropy/astropy/issues/12154 + px = np.array(pixel_arrays[0]) + world_arrays = [self.temporal_wcs.pixel_to_world_values(px.ravel()).reshape(px.shape), + *pixel_arrays[1:]] + return tuple(world_arrays) + + def world_to_pixel_values(self, *world_arrays): + # The ravel and reshape are needed because of + # https://github.com/astropy/astropy/issues/12154 + wx = np.array(world_arrays[0]) + pixel_arrays = [self.temporal_wcs.world_to_pixel_values(wx.ravel()).reshape(wx.shape), + *world_arrays[1:]] + return tuple(pixel_arrays) + + @property + def world_axis_object_components(self): + return [self.temporal_wcs.world_axis_object_components[0], + *[(key, 'value', 'value') for key in self.spatial_keys]] + + @property + def world_axis_object_classes(self): + spectral_key = self.temporal_wcs.world_axis_object_components[0][0] + obj_classes = {spectral_key: self.temporal_wcs.world_axis_object_classes[spectral_key]} + for key in self.spatial_keys: + obj_classes[key] = (u.Quantity, (), {'unit': u.pixel}) + + return obj_classes + + @property + def pixel_shape(self): + return None + + @property + def pixel_bounds(self): + return None + + @property + def pixel_axis_names(self): + return tuple([self.temporal_wcs.pixel_axis_names[0], *self.spatial_keys]) + + @property + def world_axis_names(self): + if self.flux_ndim == 2: + names = ['Offset'] + else: + names = [f"Offset{i}" for i in range(0, self.flux_ndim-1)] + + return ({}.get(self.temporal_wcs.world_axis_physical_types[0], ''), + *names) + + @property + def axis_correlation_matrix(self): + return np.identity(self.flux_ndim).astype('bool') + + @property + def serialized_classes(self): + return False @data_translator(LightCurve) class LightCurveHandler: - lc_component_ids = {} def to_data(self, obj, reference_time=None): is_folded = isinstance(obj, FoldedLightCurve) @@ -80,7 +184,7 @@ def to_data(self, obj, reference_time=None): data.meta.update( {"reference_time": time_coord.reference_time} ) - data['dt'] = (obj.time - time_coord.reference_time).to(time_coord.unit) + data[component_ids['dt']] = (obj.time - time_coord.reference_time).to(time_coord.unit) data.get_component('dt').units = str(time_coord.unit) # LightCurve is a subclass of astropy TimeSeries, so @@ -94,9 +198,9 @@ def to_data(self, obj, reference_time=None): continue component_label = f'phase:{ephem_comp}' - if component_label not in self.lc_component_ids: - self.lc_component_ids[component_label] = ComponentID(component_label) - cid = self.lc_component_ids[component_label] + if component_label not in component_ids: + component_ids[component_label] = ComponentID(component_label) + cid = component_ids[component_label] data[cid] = component_data if hasattr(component_data, 'unit'): @@ -186,6 +290,172 @@ def to_object(self, data_or_subset): return LightCurve(table, **kwargs) +class TPFHandler: + quality_flag_cls = None + tpf_attrs = ['flux', 'flux_bkg', 'flux_bkg_err', 'flux_err'] + meta_attrs = [ + 'cadenceno', + 'campaign', + 'channel', + 'column', + 'dec', + 'hdu', + 'mission', + 'module', + 'nan_time_mask', + 'obsmode', + 'output', + 'pipeline_mask', + 'pos_corr1', + 'pos_corr2', + 'quality', + 'quarter', + 'ra', + 'row', + 'shape', + 'wcs' + ] + + def to_data(self, obj, reference_time=None, unit=u.d): + coords = PaddedTimeWCS(obj.wcs, obj.time, reference_time=reference_time, unit=unit) + data = Data(coords=coords) + + flux_shape = obj.flux.shape + + if hasattr(obj, 'label'): + data.label = obj.label + + data.meta.update(obj.meta) + data.meta.update( + {"reference_time": coords.temporal_wcs.reference_time} + ) + + data[component_ids['dt']] = np.broadcast_to( + ( + obj.time - coords.temporal_wcs.reference_time + ).to(coords.temporal_wcs.unit)[:, None, None], flux_shape + ) + data.get_component('dt').units = str(coords.temporal_wcs.unit) + + # LightCurve is a subclass of astropy TimeSeries, so + # collect all other columns in the TimeSeries: + for component_label in self.tpf_attrs: + + component_data = getattr(obj, component_label) + if component_label not in component_ids: + component_ids[component_label] = ComponentID(component_label) + cid = component_ids[component_label] + + data[cid] = component_data + if hasattr(component_data, 'unit'): + try: + data.get_component(cid).units = str(component_data.unit) + except KeyError: # pragma: no cover + continue + + data.meta.update({'uncertainty_type': 'std'}) + + for attr in self.meta_attrs: + value = getattr(obj, attr, None) + data.meta.update({attr: value}) + + # if the anticipated x and y axes are the first two components in the + # Data object, the viewer will load those components correctly before + # you hit the call to `viewer.set_plot_axes`: + reordered_components = {comp.label: comp for comp in data.components} + dt_comp = reordered_components.pop('dt') + flux_comp = reordered_components.pop('flux') + data.reorder_components( + [dt_comp, flux_comp] + + list(reordered_components.values()) + ) + + return data + + def to_object(self, data_or_subset): + """ + Convert a glue Data object to a lightkurve.KeplerTargetPixelFile object. + + Parameters + ---------- + data_or_subset : `glue.core.data.Data` or `glue.core.subset.Subset` + The data to convert to a KeplerTargetPixelFile object + attribute : `glue.core.component_id.ComponentID` + The attribute to use for the KeplerTargetPixelFile data + """ + + if isinstance(data_or_subset, Subset): + data = data_or_subset.data + subset_state = data_or_subset.subset_state + else: + data = data_or_subset + subset_state = None + + # Copy over metadata + + meta = data.meta.copy() + for attr in self.meta_attrs: + # these attrs don't belong in the lightkurve object's meta: + meta.pop(attr) + + # extract a Time object out of the TimeCoordinates object: + time = data.coords.time_axis + + if subset_state is None: + # pass through mask of all True's if no glue subset is chosen + glue_mask = None + else: + # get the subset mask from glue: + glue_mask = data.get_mask(subset_state=subset_state) + # apply the subset mask to the time array: + time = time[glue_mask] + + attrs_to_save = {'meta': meta, 'time': time} + + component_ids = data.main_components + + # we already handled time separately above, and `dt` is only used internally + # in LCviz, so let's skip those IDs below: + skip_components = [id for id in component_ids if id.label in ['time', 'dt']] + for skip_comp in skip_components: + component_ids.remove(skip_comp) + + for component_id in component_ids: + if component_id.label in attrs_to_save: + # avoid duplicate column + continue + component = data.get_component(component_id) + values = component.data + if glue_mask is not None: + values = values[glue_mask] + + if component_id.label not in attrs_to_save: + attrs_to_save[component_id.label] = values + + tpf_factory = TargetPixelFileFactory(*data.shape) + + for attr, values in attrs_to_save.items(): + if attr == 'time': + values = values.value + setattr(tpf_factory, attr, values) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', + message='Could not detect filetype as TESSTargetPixelFile or KeplerTargetPixelFile, returning generic TargetPixelFile instead.') # noqa + tpf = tpf_factory.get_tpf() + + for attr in self.meta_attrs: + # if this attribute exists and can be set: + if hasattr(tpf, attr) and getattr(getattr(tpf, attr), 'fset', None) is not None: + setattr(tpf, attr, data.meta[attr]) + + tpf.quality_mask = self.quality_flag_cls.create_quality_mask( + quality_array=tpf.hdu[1].data["QUALITY"], bitmask=tpf.quality_bitmask + ) + + return tpf + + def enable_hot_reloading(watch_jdaviz=True): """ Use ``watchdog`` to perform hot reloading. @@ -218,6 +488,16 @@ class TessLightCurveHandler(LightCurveHandler): pass +@data_translator(KeplerTargetPixelFile) +class KeplerTPFHandler(TPFHandler): + quality_flag_cls = KeplerQualityFlags + + +@data_translator(TessTargetPixelFile) +class TessTPFHandler(TPFHandler): + quality_flag_cls = TessQualityFlags + + # plugin component filters def data_not_folded(data): return data.meta.get('_LCVIZ_EPHEMERIS', None) is None diff --git a/lcviz/viewers.py b/lcviz/viewers.py index fa853c74..642e9201 100644 --- a/lcviz/viewers.py +++ b/lcviz/viewers.py @@ -12,6 +12,8 @@ from jdaviz.core.events import NewViewerMessage from jdaviz.core.registries import viewer_registry +from jdaviz.configs.cubeviz.plugins.viewers import (CubevizImageView, + WithSliceIndicator, WithSliceSelection) from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin from jdaviz.configs.specviz.plugins.viewers import SpecvizProfileView @@ -19,18 +21,60 @@ from lightkurve import LightCurve +__all__ = ['TimeScatterView', 'PhaseScatterView', 'CubeView'] -__all__ = ['TimeScatterView', 'PhaseScatterView'] + +class CloneViewerMixin: + def _get_clone_viewer_reference(self): + base_name = self.reference.split("[")[0] + name = base_name + ind = 0 + while name in self.jdaviz_helper.viewers.keys(): + ind += 1 + name = f"{base_name}[{ind}]" + return name + + def clone_viewer(self): + name = self.jdaviz_helper._get_clone_viewer_reference(self.reference) + + self.jdaviz_app._on_new_viewer(NewViewerMessage(self.__class__, + data=None, + sender=self.jdaviz_app), + vid=name, name=name) + + this_viewer_item = self.jdaviz_app._get_viewer_item(self.reference) + for data_id, visible in this_viewer_item['selected_data_items'].items(): + data_label = data_label = self.jdaviz_app._get_data_item_by_id(data_id)['name'] + self.jdaviz_app.set_data_visibility(name, data_label, visible == 'visible') + # TODO: don't revert color when adding same data to a new viewer + # (same happens when creating a phase-viewer from ephemeris plugin) + + new_viewer = self.jdaviz_app.get_viewer(name) + if hasattr(self, 'ephemeris_component'): + new_viewer._ephemeris_component = self._ephemeris_component + for k, v in self.state.as_dict().items(): + if k in ('layers',): + continue + setattr(new_viewer.state, k, v) + + for this_layer_state, new_layer_state in zip(self.state.layers, new_viewer.state.layers): + for k, v in this_layer_state.as_dict().items(): + if k in ('layer',): + continue + setattr(new_layer_state, k, v) + + return new_viewer.user_api @viewer_registry("lcviz-time-viewer", label="flux-vs-time") -class TimeScatterView(JdavizViewerMixin, BqplotScatterView): +class TimeScatterView(JdavizViewerMixin, CloneViewerMixin, WithSliceIndicator, BqplotScatterView): # categories: zoom resets, zoom, pan, subset, select tools, shortcuts tools_nested = [ ['jdaviz:homezoom', 'jdaviz:prevzoom'], ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'], ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'], ['bqplot:xrange', 'bqplot:yrange', 'bqplot:rectangle'], + ['jdaviz:selectslice'], ['lcviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] ] default_class = LightCurve @@ -43,8 +87,7 @@ def __init__(self, *args, **kwargs): self.display_mask = False self.time_unit = kwargs.get('time_unit', u.d) - self._subscribe_to_layers_update() - self.initialize_toolbar() + self.initialize_toolbar(default_tool_priority=['jdaviz:selectslice']) self._subscribe_to_layers_update() # hack to inherit a small subset of methods from SpecvizProfileView # TODO: refactor jdaviz so these can be included in some mixin @@ -54,6 +97,12 @@ def __init__(self, *args, **kwargs): self._clean_error = lambda: SpecvizProfileView._clean_error(self) self.density_map = kwargs.get('density_map', False) + @property + def slice_component_label(self): + # label of the component in the lightcurves corresponding to the slice axis + # calling data_collection_item.get_component(slice_component_label) must work + return 'dt' + def data(self, cls=None): data = [] @@ -108,7 +157,7 @@ def set_plot_axes(self): self._set_plot_y_axes(dc, component_labels, light_curve) def _set_plot_x_axes(self, dc, component_labels, light_curve): - self.state.x_att = dc[0].components[component_labels.index('World 0')] + self.state.x_att = dc[0].components[component_labels.index('dt')] x_unit = self.time_unit reference_time = light_curve.meta.get('reference_time', None) @@ -210,33 +259,6 @@ def apply_roi(self, roi, use_current=False): super().apply_roi(roi, use_current=use_current) - def clone_viewer(self): - name = self.jdaviz_helper._get_clone_viewer_reference(self.reference) - - self.jdaviz_app._on_new_viewer(NewViewerMessage(self.__class__, - data=None, - sender=self.jdaviz_app), - vid=name, name=name) - - this_viewer_item = self.jdaviz_app._get_viewer_item(self.reference) - this_state = self.state.as_dict() - for data in self.jdaviz_app.data_collection: - data_id = self.jdaviz_app._data_id_from_label(data.label) - visible = this_viewer_item['selected_data_items'].get(data_id, 'hidden') - self.jdaviz_app.set_data_visibility(name, data.label, visible == 'visible') - # TODO: don't revert color when adding same data to a new viewer - # (same happens when creating a phase-viewer from ephemeris plugin) - - new_viewer = self.jdaviz_app.get_viewer(name) - if hasattr(self, 'ephemeris_component'): - new_viewer._ephemeris_component = self._ephemeris_component - for k, v in this_state.items(): - if k in ('layers',): - continue - setattr(new_viewer.state, k, v) - - return new_viewer.user_api - @viewer_registry("lcviz-phase-viewer", label="flux-vs-phase") class PhaseScatterView(TimeScatterView): @@ -244,6 +266,13 @@ def __init__(self, *args, **kwargs): self._ephemeris_component = 'default' super().__init__(*args, **kwargs) + @property + def ephemeris(self): + ephem = self.jdaviz_helper.plugins.get('Ephemeris', None) + if ephem is None: + raise ValueError("must have ephemeris plugin loaded to access ephemeris") + return ephem.ephemerides.get(self._ephemeris_component) + def _set_plot_x_axes(self, dc, component_labels, light_curve): # setting of y_att will be handled by ephemeris plugin self.state.x_att = dc[0].components[component_labels.index(f'phase:{self._ephemeris_component}')] # noqa @@ -256,3 +285,78 @@ def times_to_phases(self, times): raise ValueError("must have ephemeris plugin loaded to convert") return ephem.times_to_phases(times, ephem_component=self._ephemeris_component) + + def _set_slice_indicator_value(self, value): + # NOTE: on first call, this will initialize the indicator itself + self.slice_indicator.value = self.times_to_phases(value) + + +@viewer_registry("lcviz-cube-viewer", label="cube") +class CubeView(CloneViewerMixin, CubevizImageView, WithSliceSelection): + # categories: zoom resets, zoom, pan, subset, select tools, shortcuts + tools_nested = [ + ['jdaviz:homezoom', 'jdaviz:prevzoom'], + ['jdaviz:boxzoom'], + ['jdaviz:panzoom'], + ['bqplot:rectangle'], + ['lcviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] + ] + # TODO: can we vary this default_class based on Kepler vs TESS, etc? + # see https://github.com/spacetelescope/lcviz/pull/81#discussion_r1469721009 + default_class = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.display_mask = False + self.time_unit = kwargs.get('time_unit', u.d) + self.initialize_toolbar() + self._subscribe_to_layers_update() + + # Hide axes by default + self.state.show_axes = False + + # TODO: refactor upstream so lcviz can inherit cubeviewer methods/setup without + # jdaviz-specific logic: + # * _default_spectrum_viewer_reference_name + # * _default_flux_viewer_reference_name + # * _default_uncert_viewer_reference_name + + @property + def slice_component_label(self): + # label of the component in the cubes corresponding to the slice axis + # calling data_collection_item.get_component(slice_component_label) on any + # input cube-data must work + return 'dt' + + @property + def slice_index(self): + # index in viewer.slices corresponding to the slice axis + return 0 + + def _initial_x_axis(self, *args): + # Make sure that the x_att/y_att is correct on data load + # called via a callback set upstream in CubevizImageView when reference_data is changed + ref_data = self.state.reference_data + if ref_data is not None: + self.state.x_att = ref_data.id['Pixel Axis 2 [x]'] + self.state.y_att = ref_data.id['Pixel Axis 1 [y]'] + + def _on_layers_update(self, layers=None): + super()._on_layers_update(layers=layers) + ref_data = self.state.reference_data + if ref_data is None: + return + flux_comp = ref_data.id['flux'] + for layer in self.state.layers: + if hasattr(layer, 'attribute') and layer.attribute != flux_comp: + layer.attribute = flux_comp + + def data(self, cls=None): + # TODO: generalize upstream in jdaviz. + # This method is generalized from + # jdaviz/configs/cubeviz/plugins/viewers.py + return [layer_state.layer + for layer_state in self.state.layers + if hasattr(layer_state, 'layer') and + isinstance(layer_state.layer, BaseData)]