From 429888e7940cd088edc8f592ca8da7e500480992 Mon Sep 17 00:00:00 2001 From: "Brett M. Morris" Date: Tue, 19 Dec 2023 19:09:11 -0500 Subject: [PATCH] adding TPF translator to the parser --- lcviz/parsers.py | 4 ++++ lcviz/utils.py | 37 ++++++++++++++++++++++--------------- lcviz/viewers.py | 2 +- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/lcviz/parsers.py b/lcviz/parsers.py index e9268c83..62d35586 100644 --- a/lcviz/parsers.py +++ b/lcviz/parsers.py @@ -22,6 +22,10 @@ def light_curve_parser(app, file_obj, data_label=None, show_in_viewer=True, **kw elif isinstance(file_obj, lightkurve.LightCurve): light_curve = file_obj + # load a LightCurve object: + elif isinstance(file_obj, lightkurve.targetpixelfile.KeplerTargetPixelFile): + light_curve = file_obj + # make a data label: if data_label is not None: new_data_label = f'{data_label}' diff --git a/lcviz/utils.py b/lcviz/utils.py index 79c26457..7ede32cf 100644 --- a/lcviz/utils.py +++ b/lcviz/utils.py @@ -23,6 +23,9 @@ __all__ = ['TimeCoordinates', 'LightCurveHandler', 'data_not_folded', 'enable_hot_reloading'] +component_ids = {'dt': ComponentID('dt')} + + class TimeCoordinates(Coordinates): """ This is a sub-class of Coordinates that is intended for a time axis @@ -73,12 +76,18 @@ class PaddedTimeWCS(BaseWCSWrapper, HighLevelWCSMixin): # NOTE: This class could be updated to use CompoundLowLevelWCS from NDCube. - def __init__(self, wcs, times, ndim=3, reference_time=None): - self.temporal_wcs = TimeCoordinates(times, reference_time=reference_time) + def __init__(self, wcs, times, ndim=3, reference_time=None, unit=u.d): + self.temporal_wcs = TimeCoordinates( + times, reference_time=reference_time, unit=unit + ) self.spatial_wcs = wcs self.flux_ndim = ndim self.spatial_keys = [f"spatial{i}" for i in range(0, self.flux_ndim-1)] + @property + def time_axis(self): + return self.temporal_wcs.time_axis + @property def pixel_n_dim(self): return self.flux_ndim @@ -158,7 +167,6 @@ def serialized_classes(self): @data_translator(LightCurve) class LightCurveHandler: - lc_component_ids = {} def to_data(self, obj, reference_time=None): is_folded = isinstance(obj, FoldedLightCurve) @@ -173,7 +181,7 @@ def to_data(self, obj, reference_time=None): data.meta.update( {"reference_time": time_coord.reference_time} ) - data['dt'] = (obj.time - time_coord.reference_time).to(time_coord.unit) + data[component_ids['dt']] = (obj.time - time_coord.reference_time).to(time_coord.unit) data.get_component('dt').units = str(time_coord.unit) # LightCurve is a subclass of astropy TimeSeries, so @@ -187,9 +195,9 @@ def to_data(self, obj, reference_time=None): continue component_label = f'phase:{ephem_comp}' - if component_label not in self.lc_component_ids: - self.lc_component_ids[component_label] = ComponentID(component_label) - cid = self.lc_component_ids[component_label] + if component_label not in component_ids: + component_ids[component_label] = ComponentID(component_label) + cid = component_ids[component_label] data[cid] = component_data if hasattr(component_data, 'unit'): @@ -281,7 +289,6 @@ def to_object(self, data_or_subset): @data_translator(KeplerTargetPixelFile) class KeplerTPFHandler: - lc_component_ids = {} tpf_attrs = ['flux', 'flux_bkg', 'flux_bkg_err', 'flux_err'] meta_attrs = [ 'cadenceno', @@ -306,8 +313,8 @@ class KeplerTPFHandler: 'wcs' ] - def to_data(self, obj, reference_time=None): - coords = PaddedTimeWCS(obj.wcs, obj.time, reference_time=reference_time) + def to_data(self, obj, reference_time=None, unit=u.d): + coords = PaddedTimeWCS(obj.wcs, obj.time, reference_time=reference_time, unit=unit) data = Data(coords=coords) flux_shape = obj.flux.shape @@ -320,7 +327,7 @@ def to_data(self, obj, reference_time=None): {"reference_time": coords.temporal_wcs.reference_time} ) - data['dt'] = np.broadcast_to( + data[component_ids['dt']] = np.broadcast_to( ( obj.time - coords.temporal_wcs.reference_time ).to(coords.temporal_wcs.unit)[:, None, None], flux_shape @@ -332,9 +339,9 @@ def to_data(self, obj, reference_time=None): for component_label in self.tpf_attrs: component_data = getattr(obj, component_label) - if component_label not in self.lc_component_ids: - self.lc_component_ids[component_label] = ComponentID(component_label) - cid = self.lc_component_ids[component_label] + if component_label not in component_ids: + component_ids[component_label] = ComponentID(component_label) + cid = component_ids[component_label] data[cid] = component_data if hasattr(component_data, 'unit'): @@ -389,7 +396,7 @@ def to_object(self, data_or_subset): meta.pop(attr) # extract a Time object out of the TimeCoordinates object: - time = data.coords.temporal_wcs.time_axis + time = data.coords.time_axis if subset_state is None: # pass through mask of all True's if no glue subset is chosen diff --git a/lcviz/viewers.py b/lcviz/viewers.py index ee70f3b6..76d5f6d5 100644 --- a/lcviz/viewers.py +++ b/lcviz/viewers.py @@ -108,7 +108,7 @@ def set_plot_axes(self): self._set_plot_y_axes(dc, component_labels, light_curve) def _set_plot_x_axes(self, dc, component_labels, light_curve): - self.state.x_att = dc[0].components[component_labels.index('World 0')] + self.state.x_att = dc[0].components[component_labels.index('dt')] x_unit = self.time_unit reference_time = light_curve.meta.get('reference_time', None)