Skip to content

Commit

Permalink
rework flatten plugin to use flux columns
Browse files Browse the repository at this point in the history
  • Loading branch information
kecnry committed Dec 29, 2023
1 parent 861134a commit bd9ba95
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 140 deletions.
117 changes: 117 additions & 0 deletions lcviz/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from astropy import units as u
from ipyvuetify import VuetifyTemplate
from glue.core import HubListener
from traitlets import List, Unicode

from jdaviz.core.template_mixin import SelectPluginComponent

from lcviz.events import FluxOriginChangedMessage

__all__ = ['FluxOriginSelect', 'FluxOriginSelectMixin']


class FluxOriginSelect(SelectPluginComponent):
def __init__(self, plugin, items, selected, dataset):
super().__init__(plugin,
items=items,
selected=selected,
dataset=dataset)

self.add_observe(selected, self._on_change_selected)
self.add_observe(self.dataset._plugin_traitlets['selected'],
self._on_change_dataset)

# sync between instances in different plugins
self.hub.subscribe(self, FluxOriginChangedMessage,
handler=self._on_flux_origin_changed_msg)

def _on_change_dataset(self, *args):
def _include_col(lk_obj, col):
if col == 'flux' and lk_obj.meta.get('FLUX_ORIGIN') != 'flux':
# this is the currently active column (and should be copied elsewhere unless)
return False
if col in ('time', 'cadn', 'cadenceno', 'quality'):
return False
if col.startswith('phase:'):
# internal jdaviz ephemeris phase columns
return False
if col.startswith('time'):
return False
if col.startswith('centroid'):
return False
if col.startswith('cbv'):
# cotrending basis vector
return False
if col.endswith('_err'):
return False
if col.endswith('quality'):
return False
# TODO: need to think about flatten losing units in the flux column
return lk_obj[col].unit != u.pix

lk_obj = self.dataset.selected_obj
if lk_obj is None:
return
self.choices = [col for col in lk_obj.columns if _include_col(lk_obj, col)]
flux_origin = lk_obj.meta.get('FLUX_ORIGIN')
if flux_origin in self.choices:
self.selected = flux_origin
else:
self.selected = ''

def _on_flux_origin_changed_msg(self, msg):
if msg.dataset != self.dataset.selected:
return

# need to clear the cache due to the change in metadata made to the data-collection entry
self.dataset._clear_cache('selected_obj', 'selected_dc_item')
self._on_change_dataset()
self.selected = msg.flux_origin

def _on_change_selected(self, *args):
if self.selected == '':
return

dc_item = self.dataset.selected_dc_item
old_flux_origin = dc_item.meta.get('FLUX_ORIGIN')
if self.selected == old_flux_origin:
# nothing to do here!
return

# instead of using lightkurve's select_flux and having to reparse the data entry, we'll
# manipulate the arrays in the data-collection directly, and modify FLUX_ORIGIN so that
# exporting back to a lightkurve object works as expected
self.app._jdaviz_helper._set_data_component(dc_item, 'flux', dc_item[self.selected])
self.app._jdaviz_helper._set_data_component(dc_item, 'flux_err', dc_item[self.selected+"_err"]) # noqa
dc_item.meta['FLUX_ORIGIN'] = self.selected

self.hub.broadcast(FluxOriginChangedMessage(dataset=self.dataset.selected,
flux_origin=self.selected,
sender=self))

def add_new_flux_column(self, flux, flux_err, label, selected=False):
dc_item = self.dataset.selected_dc_item
self.app._jdaviz_helper._set_data_component(dc_item,
label,
flux)
self.app._jdaviz_helper._set_data_component(dc_item,
f"{label}_err",
flux_err)

# broadcast so all instances update to get the new column and selection (if applicable)
self.hub.broadcast(FluxOriginChangedMessage(dataset=self.dataset.selected,
flux_origin=label if selected else self.selected,
sender=self))


class FluxOriginSelectMixin(VuetifyTemplate, HubListener):
flux_origin_items = List().tag(sync=True)
flux_origin_selected = Unicode().tag(sync=True)
# assumes DatasetSelectMixin is also used (DatasetSelectMixin must appear after in inheritance)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.flux_origin = FluxOriginSelect(self,
'flux_origin_items',
'flux_origin_selected',
dataset='dataset')
12 changes: 11 additions & 1 deletion lcviz/events.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from glue.core.message import Message

__all__ = ['EphemerisComponentChangedMessage',
'EphemerisChangedMessage']
'EphemerisChangedMessage',
'FluxOriginChangedMessage']


class EphemerisComponentChangedMessage(Message):
Expand All @@ -27,3 +28,12 @@ class EphemerisChangedMessage(Message):
in the ephemeris plugin"""
def __init__(self, ephemeris_label, *args, **kwargs):
self.ephemeris_label = ephemeris_label


class FluxOriginChangedMessage(Message):
"""Message emitted by the FluxOriginSelect component when the selection has been changed.
To subscribe to a change for a particular dataset, consider using FluxOriginSelect directly
and observing the traitlet, rather than subscribing to this message"""
def __init__(self, dataset, flux_origin, *args, **kwargs):
self.dataset = dataset
self.flux_origin = flux_origin
73 changes: 49 additions & 24 deletions lcviz/plugins/flatten/flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from jdaviz.core.events import ViewerAddedMessage
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import (PluginTemplateMixin,
DatasetSelectMixin, AddResultsMixin,
DatasetSelectMixin,
AutoTextField,
skip_if_no_updates_since_last_active,
with_spinner)
from jdaviz.core.user_api import PluginUserApi

from lcviz.components import FluxOriginSelectMixin
from lcviz.marks import LivePreviewTrend, LivePreviewFlattened
from lcviz.utils import data_not_folded
from lcviz.viewers import TimeScatterView, PhaseScatterView
Expand All @@ -21,7 +23,7 @@


@tray_registry('flatten', label="Flatten")
class Flatten(PluginTemplateMixin, DatasetSelectMixin, AddResultsMixin):
class Flatten(PluginTemplateMixin, FluxOriginSelectMixin, DatasetSelectMixin):
"""
See the :ref:`Flatten Plugin Documentation <flatten>` for more details.
Expand All @@ -32,24 +34,24 @@ class Flatten(PluginTemplateMixin, DatasetSelectMixin, AddResultsMixin):
Whether to show the live-preview of the (unnormalized) flattened light curve
* ``show_trend_preview`` : bool
Whether to show the live-preview of the trend curve used to flatten the light curve
* ``default_to_overwrite``
* ``dataset`` (:class:`~jdaviz.core.template_mixin.DatasetSelect`):
Dataset to flatten.
* ``add_results`` (:class:`~jdaviz.core.template_mixin.AddResults`)
* ``window_length``
* ``polyorder``
* ``break_tolerance``
* ``niters``
* ``sigma``
* ``unnormalize``
* ``flux_label`` (:class:`~jdaviz.core.template_mixin.AutoTextField`):
Label for the resulting flux column added to ``dataset`` and automatically selected as the new
flux origin.
* :meth:`flatten`
"""
template_file = __file__, "flatten.vue"
uses_active_status = Bool(True).tag(sync=True)

show_live_preview = Bool(True).tag(sync=True)
show_trend_preview = Bool(True).tag(sync=True)
default_to_overwrite = Bool(True).tag(sync=True)
flatten_err = Unicode().tag(sync=True)

window_length = IntHandleEmpty(101).tag(sync=True)
Expand All @@ -59,25 +61,37 @@ class Flatten(PluginTemplateMixin, DatasetSelectMixin, AddResultsMixin):
sigma = FloatHandleEmpty(3).tag(sync=True)
unnormalize = Bool(False).tag(sync=True)

flux_label_label = Unicode().tag(sync=True)
flux_label_default = Unicode().tag(sync=True)
flux_label_auto = Bool(True).tag(sync=True)
flux_label_invalid_msg = Unicode('').tag(sync=True)
flux_label_overwrite = Bool(False).tag(sync=True)

last_live_time = Float(0).tag(sync=True)
previews_temp_disable = Bool(False).tag(sync=True)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.flux_label = AutoTextField(self, 'flux_label_label',
'flux_label_default', 'flux_label_auto',
'flux_label_invalid_msg')

# do not support flattening data in phase-space
self.dataset.add_filter(data_not_folded)

# marks do not exist for the new viewer, so force another update to compute and draw
# those marks
self.hub.subscribe(self, ViewerAddedMessage, handler=lambda _: self._live_update())

self._set_default_label()

@property
def user_api(self):
expose = ['show_live_preview', 'show_trend_preview', 'default_to_overwrite',
'dataset', 'add_results',
expose = ['show_live_preview', 'show_trend_preview',
'dataset',
'window_length', 'polyorder', 'break_tolerance',
'niters', 'sigma', 'unnormalize', 'flatten']
'niters', 'sigma', 'unnormalize', 'flux_label', 'flatten']
return PluginUserApi(self, expose=expose)

@property
Expand Down Expand Up @@ -108,18 +122,26 @@ def marks(self):

return trend_marks, flattened_marks

@observe('default_to_overwrite', 'dataset_selected')
def _set_default_results_label(self, event={}):
@observe('dataset_selected', 'flux_origin_selected')
def _set_default_label(self, event={}):
'''Generate a label and set the results field to that value'''
if not hasattr(self, 'dataset'): # pragma: no cover
return

self.add_results.label_whitelist_overwrite = [self.dataset_selected]

if self.default_to_overwrite:
self.results_label_default = self.dataset_selected
# TODO: have an option to create new data entry and drop other columns?
# (or should that just go through future data cloning)
self.flux_label.default = f"{self.flux_origin_selected}_flattened"

@observe('flux_label_label', 'dataset')
def _update_label_valid(self, event={}):
if self.flux_label.value in self.flux_origin.choices:
self.flux_label.invalid_msg = ''
self.flux_label_overwrite = True
elif self.flux_label.value in getattr(self.dataset.selected_obj, 'columns', []):
self.flux_label.invalid_msg = 'name already in use'
else:
self.results_label_default = f"{self.dataset_selected} (flattened)"
self.flux_label.invalid_msg = ''
self.flux_label_overwrite = False

@with_spinner()
def flatten(self, add_data=True):
Expand All @@ -129,8 +151,8 @@ def flatten(self, add_data=True):
Parameters
----------
add_data : bool
Whether to add the resulting trace to the application, according to the options
defined in the plugin.
Whether to add the resulting light curve as a flux column and select that as the new
flux origin for that data entry.
Returns
-------
Expand All @@ -157,9 +179,13 @@ def flatten(self, add_data=True):
output_lc.meta['NORMALIZED'] = False

if add_data:
# add data to the collection/viewer
# add data as a new flux and corresponding err columns in the existing data entry
# and select as flux origin
data = _data_with_reftime(self.app, output_lc)
self.add_results.add_results_from_plugin(data)
self.flux_origin.add_new_flux_column(flux=data['flux'],
flux_err=data['flux_err'],
label=self.flux_label.value,
selected=True)

return output_lc, trend_lc

Expand All @@ -186,13 +212,16 @@ def _toggle_marks(self, event={}):
# then the marks themselves need to be updated
self._live_update(event)

@observe('dataset_selected',
@observe('dataset_selected', 'flux_origin_selected',
'window_length', 'polyorder', 'break_tolerance',
'niters', 'sigma', 'previews_temp_disable')
@skip_if_no_updates_since_last_active()
def _live_update(self, event={}):
if self.previews_temp_disable:
return
if self.dataset_selected == '' or self.flux_origin_selected == '':
self._clear_marks()
return

start = time()
try:
Expand Down Expand Up @@ -232,7 +261,3 @@ def vue_apply(self, *args, **kwargs):
self.flatten_err = str(e)
else:
self.flatten_err = ''
if self.add_results.label_overwrite:
# then this will change the input data without triggering a
# change to dataset_selected
self._live_update()
42 changes: 19 additions & 23 deletions lcviz/plugins/flatten/flatten.vue
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,6 @@
persistent-hint
></v-switch>
</v-row>
<v-row>
<v-switch
v-model="default_to_overwrite"
label="Overwrite by default"
hint="Whether the output label should default to overwriting the input data."
persistent-hint
></v-switch>
</v-row>
</v-expansion-panel-content>
</v-expansion-panel>
</v-expansion-panels>
Expand Down Expand Up @@ -138,6 +130,14 @@
<v-alert type="warning">Live preview is unnormalized, but flattening will normalize.</v-alert>
</v-row>

<plugin-auto-label
:value.sync="flux_label_label"
:default="flux_label_default"
:auto.sync="flux_label_auto"
:invalid_msg="flux_label_invalid_msg"
hint="Label for flux column."
></plugin-auto-label>

<v-alert v-if="previews_temp_disable && (show_live_preview || show_trend_preview)" type='warning' style="margin-left: -12px; margin-right: -12px">
Live-updating is temporarily disabled (last update took {{last_live_time}}s)
<v-row justify='center'>
Expand All @@ -156,21 +156,17 @@
</v-row>
</v-alert>

<plugin-add-results
:label.sync="results_label"
:label_default="results_label_default"
:label_auto.sync="results_label_auto"
:label_invalid_msg="results_label_invalid_msg"
:label_overwrite="results_label_overwrite"
label_hint="Label for the flattened data."
:add_to_viewer_items="add_to_viewer_items"
:add_to_viewer_selected.sync="add_to_viewer_selected"
action_label="Flatten"
action_tooltip="Flatten data"
:action_disabled="flatten_err.length > 0"
:action_spinner="spinner"
@click:action="apply"
></plugin-add-results>
<v-row justify="end">
<j-tooltip tooltipcontent="Flatten and select new column as flux origin">
<plugin-action-button
:spinner="spinner"
:disabled="flux_label_invalid_msg.length > 0"
:results_isolated_to_plugin="false"
@click="apply">
Flatten{{flux_label_overwrite ? ' (Overwrite)' : ''}}
</plugin-action-button>
</j-tooltip>
</v-row>

<v-row v-if="flatten_err">
<span class="v-messages v-messages__message text--secondary">
Expand Down
Loading

0 comments on commit bd9ba95

Please sign in to comment.