diff --git a/glue/core/data.py b/glue/core/data.py index 4615a3b89..57a3650b5 100644 --- a/glue/core/data.py +++ b/glue/core/data.py @@ -514,7 +514,7 @@ def compute_histogram(self, cids, weights=None, range=None, bins=None, log=None, raise NotImplementedError() def compute_fixed_resolution_buffer(self, bounds, target_data=None, target_cid=None, - subset_state=None, broadcast=True): + subset_state=None, broadcast=True, cache_id=None): """ Get a fixed-resolution buffer. diff --git a/glue/plugins/tools/pv_slicer/pv_sliced_data.py b/glue/plugins/tools/pv_slicer/pv_sliced_data.py new file mode 100644 index 000000000..7e6ff3b25 --- /dev/null +++ b/glue/plugins/tools/pv_slicer/pv_sliced_data.py @@ -0,0 +1,140 @@ +import numpy as np + +from glue.core.data_derived import DerivedData + +__all__ = ['PVSlicedData'] + + +def sample_points(x, y, spacing=1): + + # Code adapted from pvextractor + + # Find the distance interval between all pairs of points + dx = np.diff(x) + dy = np.diff(y) + dd = np.hypot(dx, dy) + + # Find the total displacement along the broken curve + d = np.hstack([0., np.cumsum(dd)]) + + # Figure out the number of points to sample, and stop short of the + # last point. + n_points = int(np.floor(d[-1] / spacing)) + + if n_points == 0: + raise ValueError("Path is shorter than spacing") + + d_sampled = np.linspace(0., n_points * spacing, n_points + 1) + + x_sampled = np.interp(d_sampled, d, x) + y_sampled = np.interp(d_sampled, d, y) + + return x_sampled, y_sampled + + +class PVSlicedData(DerivedData): + """ + A dataset where two dimensions have been replaced with one using a path. + + The extra dimension is added as the last dimension + """ + + def __init__(self, original_data, cid_x, x, cid_y, y, label=''): + super(DerivedData, self).__init__() + self.original_data = original_data + self.cid_x = cid_x + self.cid_y = cid_y + self.set_xy(x, y) + self.sliced_dims = (cid_x.axis, cid_y.axis) + self._label = label + + def set_xy(self, x, y): + x, y = sample_points(x, y) + self.x = x + self.y = y + + @property + def label(self): + return self._label + + def _without_sliced(self, iterable): + return [x for ix, x in enumerate(iterable) if ix not in self.sliced_dims] + + @property + def shape(self): + return self._without_sliced(self.original_data.shape) + [len(self.x)] + + @property + def main_components(self): + return self.original_data.main_components + + def get_kind(self, cid): + return self.original_data.get_kind(cid) + + def get_data(self, cid, view=None): + + if cid in self.pixel_component_ids: + return super().get_data(cid, view) + + pix_coords = [] + + advanced_indexing = view is not None and isinstance(view[0], np.ndarray) + + idim_current = -1 + + for idim in range(self.original_data.ndim): + + if idim == self.cid_x.axis: + pix = self.x + idim_current = self.ndim - 1 + elif idim == self.cid_y.axis: + pix = self.y + idim_current = self.ndim - 1 + else: + pix = np.arange(self.original_data.shape[idim]) + idim_current += 1 + + if view is not None and len(view) > idim_current: + pix = pix[view[idim_current]] + print("DONE") + + print(idim, idim_current, pix.shape) + + pix_coords.append(pix) + + if not advanced_indexing: + pix_coords = np.meshgrid(*pix_coords, indexing='ij', copy=False) + + print(pix_coords[0].shape) + + shape = pix_coords[0].shape + + keep = np.ones(shape, dtype=bool) + for idim in range(self.original_data.ndim): + keep &= (pix_coords[idim] >= 0) & (pix_coords[idim] < self.original_data.shape[idim]) + + pix_coords = [x[keep].astype(int) for x in pix_coords] + + result = np.zeros(shape) + + result[keep] = self.original_data.get_data(cid, view=pix_coords) + + return result + + def get_mask(self, subset_state, view=None): + # Optimize by getting pixel coordinates of original data in new + # frame of reference and getting the mask for these indices + if view is None: + view = Ellipsis + return self.callable(self.original_data.get_mask(subset_state))[view] + + def compute_statistic(self, *args, **kwargs): + return self.original_data.compute_statistic(*args, **kwargs) + + def compute_histogram(self, *args, **kwargs): + return self.original_data.compute_histogram(*args, **kwargs) + + def compute_fixed_resolution_buffer(self, *args, **kwargs): + from glue.core.fixed_resolution_buffer import compute_fixed_resolution_buffer + print(args, kwargs) + return compute_fixed_resolution_buffer(self, *args, **kwargs) diff --git a/glue/plugins/tools/pv_slicer/qt/pv_slicer.py b/glue/plugins/tools/pv_slicer/qt/pv_slicer.py index 04d2e17c6..727fcfd7b 100644 --- a/glue/plugins/tools/pv_slicer/qt/pv_slicer.py +++ b/glue/plugins/tools/pv_slicer/qt/pv_slicer.py @@ -7,13 +7,7 @@ from glue.config import viewer_tool from glue.viewers.matplotlib.toolbar_mode import ToolbarModeBase from glue.viewers.image.qt import ImageViewer - - -class PVSliceData(Data): - parent_data = None - parent_data_x = None - parent_data_y = None - parent_viewer = None +from glue.plugins.tools.pv_slicer.pv_sliced_data import PVSlicedData @viewer_tool @@ -49,28 +43,14 @@ def _extract_callback(self, mode): vx, vy = mode.roi().to_polygon() - pv_slice, x, y, wcs = _slice_from_path(vx, vy, self.viewer.state.reference_data, - self.viewer.state.layers[0].attribute, - self.viewer.state.wcsaxes_slice[::-1]) - - xlabel = "Position along path" - if wcs is None: - ylabel = "Cube slice index" - else: - ylabel = _slice_label(self.viewer.state.reference_data, - self.viewer.state.wcsaxes_slice[::-1]) - - wcs.wcs.ctype = [xlabel, ylabel] - - data = PVSliceData(label=self.viewer.state.reference_data.label + " [slice]") - data.coords = coordinates_from_wcs(wcs) - data[self.viewer.state.layers[0].attribute] = pv_slice - selected = self.viewer.session.application.selected_layers() - if len(selected) == 1 and isinstance(selected[0], PVSliceData): - selected[0].update_values_from_data(data) + if len(selected) == 1 and isinstance(selected[0], PVSlicedData): data = selected[0] + data.original_data = self.viewer.state.reference_data + data.x_att = self.viewer.state.x_att + data.y_att = self.viewer.state.y_att + data.set_xy(vx, vy) open_viewer = True for tab in self.viewer.session.application.viewers: for viewer in tab: @@ -80,15 +60,14 @@ def _extract_callback(self, mode): if not open_viewer: break else: + data = PVSlicedData(self.viewer.state.reference_data, + self.viewer.state.x_att, vx, + self.viewer.state.y_att, vy, + label=self.viewer.state.reference_data.label + " [slice]") + data.parent_viewer = self.viewer self.viewer.session.data_collection.append(data) open_viewer = True - # TODO: use weak references - data.parent_data = self.viewer.state.reference_data - data.parent_data_x = x - data.parent_data_y = y - data.parent_viewer = self.viewer - if open_viewer: viewer = self.viewer.session.application.new_data_viewer(ImageViewer, data=data) @@ -114,7 +93,7 @@ def __init__(self, *args, **kwargs): self.viewer.state.add_callback('reference_data', self._on_reference_data_change) def _on_reference_data_change(self, reference_data): - self.enabled = isinstance(reference_data, PVSliceData) + self.enabled = isinstance(reference_data, PVSlicedData) self.data = reference_data def _on_move(self, mode): @@ -122,14 +101,17 @@ def _on_move(self, mode): # Find position of click in the image viewer xdata, ydata = self._event_xdata, self._event_ydata + if xdata is None or ydata is None: + return + # TODO: Make this robust in case the axes have been swapped # Find position slice where cursor is ind = int(round(np.clip(xdata, 0, self.data.shape[1] - 1))) # Find pixel coordinate in input image for this slice - x = self.data.parent_data_x[ind] - y = self.data.parent_data_y[ind] + x = self.data.x[ind] + y = self.data.y[ind] # The 3-rd coordinate in the input WCS is simply the second # coordinate in the PV slice. @@ -142,71 +124,6 @@ def _on_move(self, mode): self.data.parent_viewer.state.slices = tuple(s) -def _slice_from_path(x, y, data, attribute, slc): - """ - Extract a PV-like slice from a cube - - :param x: An array of x values to extract (pixel units) - :param y: An array of y values to extract (pixel units) - :param data: :class:`~glue.core.data.Data` - :param attribute: :claass:`~glue.core.data.Component` - :param slc: orientation of the image widget that `pts` are defined on - - :returns: (slice, x, y) - slice is a 2D Numpy array, corresponding to a "PV ribbon" - cutout from the cube - x and y are the resampled points along which the - ribbon is extracted - - :note: For >3D cubes, the "V-axis" of the PV slice is the longest - cube axis ignoring the x/y axes of `slc` - """ - from glue.external.pvextractor import Path, extract_pv_slice - p = Path(list(zip(x, y))) - - cube = data[attribute] - dims = list(range(data.ndim)) - s = list(slc) - ind = _slice_index(data, slc) - - cube_wcs = getattr(data.coords, 'wcs', None) - - # transpose cube to (z, y, x, ) - def _swap(x, s, i, j): - x[i], x[j] = x[j], x[i] - s[i], s[j] = s[j], s[i] - - _swap(dims, s, ind, 0) - _swap(dims, s, s.index('y'), 1) - _swap(dims, s, s.index('x'), 2) - - cube = cube.transpose(dims) - - if cube_wcs is not None: - cube_wcs = cube_wcs.sub([data.ndim - nx for nx in dims[::-1]]) - - # slice down from >3D to 3D if needed - s = tuple([slice(None)] * 3 + [slc[d] for d in dims[3:]]) - cube = cube[s] - - # sample cube - spacing = 1 # pixel - x, y = [np.round(_x).astype(int) for _x in p.sample_points(spacing)] - - from astropy.wcs import WCS - - try: - result = extract_pv_slice(cube, path=p, wcs=cube_wcs, order=0) - wcs = WCS(result.header) - except Exception: # sometimes pvextractor complains due to wcs. Try to recover - result = extract_pv_slice(cube, path=p, wcs=None, order=0) - wcs = None - - data = result.data - - return data, x, y, wcs - - def _slice_index(data, slc): """ The axis over which to extract PV slices