Skip to content

Commit

Permalink
Added PVSlicedData, which is a DerivedData sub-class, and make it fun…
Browse files Browse the repository at this point in the history
…ctional
  • Loading branch information
astrofrog committed Jan 31, 2020
1 parent 74c511d commit b6b7b6e
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 101 deletions.
2 changes: 1 addition & 1 deletion glue/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def compute_histogram(self, cids, weights=None, range=None, bins=None, log=None,
raise NotImplementedError()

def compute_fixed_resolution_buffer(self, bounds, target_data=None, target_cid=None,
subset_state=None, broadcast=True):
subset_state=None, broadcast=True, cache_id=None):
"""
Get a fixed-resolution buffer.
Expand Down
140 changes: 140 additions & 0 deletions glue/plugins/tools/pv_slicer/pv_sliced_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import numpy as np

from glue.core.data_derived import DerivedData

__all__ = ['PVSlicedData']


def sample_points(x, y, spacing=1):

# Code adapted from pvextractor

# Find the distance interval between all pairs of points
dx = np.diff(x)
dy = np.diff(y)
dd = np.hypot(dx, dy)

# Find the total displacement along the broken curve
d = np.hstack([0., np.cumsum(dd)])

# Figure out the number of points to sample, and stop short of the
# last point.
n_points = int(np.floor(d[-1] / spacing))

if n_points == 0:
raise ValueError("Path is shorter than spacing")

d_sampled = np.linspace(0., n_points * spacing, n_points + 1)

x_sampled = np.interp(d_sampled, d, x)
y_sampled = np.interp(d_sampled, d, y)

return x_sampled, y_sampled


class PVSlicedData(DerivedData):
"""
A dataset where two dimensions have been replaced with one using a path.
The extra dimension is added as the last dimension
"""

def __init__(self, original_data, cid_x, x, cid_y, y, label=''):
super(DerivedData, self).__init__()
self.original_data = original_data
self.cid_x = cid_x
self.cid_y = cid_y
self.set_xy(x, y)
self.sliced_dims = (cid_x.axis, cid_y.axis)
self._label = label

def set_xy(self, x, y):
x, y = sample_points(x, y)
self.x = x
self.y = y

@property
def label(self):
return self._label

def _without_sliced(self, iterable):
return [x for ix, x in enumerate(iterable) if ix not in self.sliced_dims]

@property
def shape(self):
return self._without_sliced(self.original_data.shape) + [len(self.x)]

@property
def main_components(self):
return self.original_data.main_components

def get_kind(self, cid):
return self.original_data.get_kind(cid)

def get_data(self, cid, view=None):

if cid in self.pixel_component_ids:
return super().get_data(cid, view)

pix_coords = []

advanced_indexing = view is not None and isinstance(view[0], np.ndarray)

idim_current = -1

for idim in range(self.original_data.ndim):

if idim == self.cid_x.axis:
pix = self.x
idim_current = self.ndim - 1
elif idim == self.cid_y.axis:
pix = self.y
idim_current = self.ndim - 1
else:
pix = np.arange(self.original_data.shape[idim])
idim_current += 1

if view is not None and len(view) > idim_current:
pix = pix[view[idim_current]]
print("DONE")

print(idim, idim_current, pix.shape)

pix_coords.append(pix)

if not advanced_indexing:
pix_coords = np.meshgrid(*pix_coords, indexing='ij', copy=False)

print(pix_coords[0].shape)

shape = pix_coords[0].shape

keep = np.ones(shape, dtype=bool)
for idim in range(self.original_data.ndim):
keep &= (pix_coords[idim] >= 0) & (pix_coords[idim] < self.original_data.shape[idim])

pix_coords = [x[keep].astype(int) for x in pix_coords]

result = np.zeros(shape)

result[keep] = self.original_data.get_data(cid, view=pix_coords)

return result

def get_mask(self, subset_state, view=None):
# Optimize by getting pixel coordinates of original data in new
# frame of reference and getting the mask for these indices
if view is None:
view = Ellipsis
return self.callable(self.original_data.get_mask(subset_state))[view]

def compute_statistic(self, *args, **kwargs):
return self.original_data.compute_statistic(*args, **kwargs)

def compute_histogram(self, *args, **kwargs):
return self.original_data.compute_histogram(*args, **kwargs)

def compute_fixed_resolution_buffer(self, *args, **kwargs):
from glue.core.fixed_resolution_buffer import compute_fixed_resolution_buffer
print(args, kwargs)
return compute_fixed_resolution_buffer(self, *args, **kwargs)
117 changes: 17 additions & 100 deletions glue/plugins/tools/pv_slicer/qt/pv_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@
from glue.config import viewer_tool
from glue.viewers.matplotlib.toolbar_mode import ToolbarModeBase
from glue.viewers.image.qt import ImageViewer


class PVSliceData(Data):
parent_data = None
parent_data_x = None
parent_data_y = None
parent_viewer = None
from glue.plugins.tools.pv_slicer.pv_sliced_data import PVSlicedData


@viewer_tool
Expand Down Expand Up @@ -49,28 +43,14 @@ def _extract_callback(self, mode):

vx, vy = mode.roi().to_polygon()

pv_slice, x, y, wcs = _slice_from_path(vx, vy, self.viewer.state.reference_data,
self.viewer.state.layers[0].attribute,
self.viewer.state.wcsaxes_slice[::-1])

xlabel = "Position along path"
if wcs is None:
ylabel = "Cube slice index"
else:
ylabel = _slice_label(self.viewer.state.reference_data,
self.viewer.state.wcsaxes_slice[::-1])

wcs.wcs.ctype = [xlabel, ylabel]

data = PVSliceData(label=self.viewer.state.reference_data.label + " [slice]")
data.coords = coordinates_from_wcs(wcs)
data[self.viewer.state.layers[0].attribute] = pv_slice

selected = self.viewer.session.application.selected_layers()

if len(selected) == 1 and isinstance(selected[0], PVSliceData):
selected[0].update_values_from_data(data)
if len(selected) == 1 and isinstance(selected[0], PVSlicedData):
data = selected[0]
data.original_data = self.viewer.state.reference_data
data.x_att = self.viewer.state.x_att
data.y_att = self.viewer.state.y_att
data.set_xy(vx, vy)
open_viewer = True
for tab in self.viewer.session.application.viewers:
for viewer in tab:
Expand All @@ -80,15 +60,14 @@ def _extract_callback(self, mode):
if not open_viewer:
break
else:
data = PVSlicedData(self.viewer.state.reference_data,
self.viewer.state.x_att, vx,
self.viewer.state.y_att, vy,
label=self.viewer.state.reference_data.label + " [slice]")
data.parent_viewer = self.viewer
self.viewer.session.data_collection.append(data)
open_viewer = True

# TODO: use weak references
data.parent_data = self.viewer.state.reference_data
data.parent_data_x = x
data.parent_data_y = y
data.parent_viewer = self.viewer

if open_viewer:
viewer = self.viewer.session.application.new_data_viewer(ImageViewer, data=data)

Expand All @@ -114,22 +93,25 @@ def __init__(self, *args, **kwargs):
self.viewer.state.add_callback('reference_data', self._on_reference_data_change)

def _on_reference_data_change(self, reference_data):
self.enabled = isinstance(reference_data, PVSliceData)
self.enabled = isinstance(reference_data, PVSlicedData)
self.data = reference_data

def _on_move(self, mode):

# Find position of click in the image viewer
xdata, ydata = self._event_xdata, self._event_ydata

if xdata is None or ydata is None:
return

# TODO: Make this robust in case the axes have been swapped

# Find position slice where cursor is
ind = int(round(np.clip(xdata, 0, self.data.shape[1] - 1)))

# Find pixel coordinate in input image for this slice
x = self.data.parent_data_x[ind]
y = self.data.parent_data_y[ind]
x = self.data.x[ind]
y = self.data.y[ind]

# The 3-rd coordinate in the input WCS is simply the second
# coordinate in the PV slice.
Expand All @@ -142,71 +124,6 @@ def _on_move(self, mode):
self.data.parent_viewer.state.slices = tuple(s)


def _slice_from_path(x, y, data, attribute, slc):
"""
Extract a PV-like slice from a cube
:param x: An array of x values to extract (pixel units)
:param y: An array of y values to extract (pixel units)
:param data: :class:`~glue.core.data.Data`
:param attribute: :claass:`~glue.core.data.Component`
:param slc: orientation of the image widget that `pts` are defined on
:returns: (slice, x, y)
slice is a 2D Numpy array, corresponding to a "PV ribbon"
cutout from the cube
x and y are the resampled points along which the
ribbon is extracted
:note: For >3D cubes, the "V-axis" of the PV slice is the longest
cube axis ignoring the x/y axes of `slc`
"""
from glue.external.pvextractor import Path, extract_pv_slice
p = Path(list(zip(x, y)))

cube = data[attribute]
dims = list(range(data.ndim))
s = list(slc)
ind = _slice_index(data, slc)

cube_wcs = getattr(data.coords, 'wcs', None)

# transpose cube to (z, y, x, <whatever>)
def _swap(x, s, i, j):
x[i], x[j] = x[j], x[i]
s[i], s[j] = s[j], s[i]

_swap(dims, s, ind, 0)
_swap(dims, s, s.index('y'), 1)
_swap(dims, s, s.index('x'), 2)

cube = cube.transpose(dims)

if cube_wcs is not None:
cube_wcs = cube_wcs.sub([data.ndim - nx for nx in dims[::-1]])

# slice down from >3D to 3D if needed
s = tuple([slice(None)] * 3 + [slc[d] for d in dims[3:]])
cube = cube[s]

# sample cube
spacing = 1 # pixel
x, y = [np.round(_x).astype(int) for _x in p.sample_points(spacing)]

from astropy.wcs import WCS

try:
result = extract_pv_slice(cube, path=p, wcs=cube_wcs, order=0)
wcs = WCS(result.header)
except Exception: # sometimes pvextractor complains due to wcs. Try to recover
result = extract_pv_slice(cube, path=p, wcs=None, order=0)
wcs = None

data = result.data

return data, x, y, wcs


def _slice_index(data, slc):
"""
The axis over which to extract PV slices
Expand Down

0 comments on commit b6b7b6e

Please sign in to comment.