From 41d33359c8b538734c03fa3219e5f32c4d2af70d Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Fri, 16 Jul 2021 16:29:52 -0500 Subject: [PATCH] WIP in case my local computer combusts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I promise to revise history soon 😬 --- astrowidgets/bqplot.py | 482 ++++++++++++++++++++++++++ astrowidgets/tests/test_bqplot_api.py | 292 ++++++++++++++++ 2 files changed, 774 insertions(+) create mode 100644 astrowidgets/bqplot.py create mode 100644 astrowidgets/tests/test_bqplot_api.py diff --git a/astrowidgets/bqplot.py b/astrowidgets/bqplot.py new file mode 100644 index 0000000..5b68af0 --- /dev/null +++ b/astrowidgets/bqplot.py @@ -0,0 +1,482 @@ +import numpy as np + +from astropy.coordinates import SkyCoord +from astropy.io import fits +from astropy.nddata import CCDData +from astropy import units as u +import astropy.visualization as apviz + +from bqplot import Figure, LinearScale, Axis, ColorScale, PanZoom +from bqplot_image_gl import ImageGL +from bqplot_image_gl.interacts import (MouseInteraction, + keyboard_events, mouse_events) + +import ipywidgets as ipw + +from matplotlib import cm as cmp +from matplotlib import pyplot +from matplotlib.colors import to_hex + +import traitlets as trait + +# Allowed locations for cursor display +ALLOWED_CURSOR_LOCATIONS = ['top', 'bottom', None] + +# List of marker names that are for internal use only +RESERVED_MARKER_SET_NAMES = ['all'] + + +class _AstroImage(ipw.VBox): + """ + Encapsulate an image as a bqplot figure inside a box. + + bqplot is involved for its pan/zoom capabilities, and it presents as + a box to obscure the usual bqplot properties and methods. + """ + def __init__(self, image_data=None, + display_width=500, + viewer_aspect_ratio=1.0): + super().__init__() + + self._viewer_aspect_ratio = viewer_aspect_ratio + + self._display_width = display_width + self._display_height = self._viewer_aspect_ratio * self._display_width + + + layout = ipw.Layout(width=f'{self._display_width}px', + height=f'{self._display_height}px', + justify_content='center') + + self._figure_layout = layout + + scale_x = LinearScale(min=0, max=1, #self._image_shape[1], + allow_padding=False) + scale_y = LinearScale(min=0, max=1, #self._image_shape[0], + allow_padding=False) + self._scales = {'x': scale_x, 'y': scale_y} + axis_x = Axis(scale=scale_x, visible=False) + axis_y = Axis(scale=scale_y, orientation='vertical', visible=False) + scales_image = {'x': scale_x, 'y': scale_y, + 'image': ColorScale(max=1.114, min=2902, + scheme='Greys')} + + self._figure = Figure(scales=self._scales, axes=[axis_x, axis_y], + fig_margin=dict(top=0, left=0, + right=0, bottom=0), + layout=layout) + + self._image = ImageGL(scales=scales_image) + + self._figure.marks = (self._image, ) + + panzoom = PanZoom(scales={'x': [scales_image['x']], + 'y': [scales_image['y']]}) + interaction = MouseInteraction(x_scale=scales_image['x'], + y_scale=scales_image['y'], + move_throttle=70, next=panzoom, + events=keyboard_events + mouse_events) + + self._figure.interaction = interaction + + if image_data: + self.set_data(image_data, reset_view=True) + + self.children = (self._figure, ) + + @property + def data_aspect_ratio(self): + """ + Aspect ratio of the image data, horizontal size over vertical size. + """ + return self._image_shape[0] / self._image_shape[1] + + def reset_scale_to_fit_image(self): + wide = self.data_aspect_ratio < 1 + tall = self.data_aspect_ratio > 1 + square = self.data_aspect_ratio == 1 + + if wide: + self._scales['x'].min = 0 + self._scales['x'].max = self._image_shape[1] + self._set_scale_aspect_ratio_to_match_viewer() + elif tall or square: + self._scales['y'].min = 0 + self._scales['y'].max = self._image_shape[0] + self._set_scale_aspect_ratio_to_match_viewer(reset_scale='x') + + # Great, now let's center + self.center = (self._image_shape[1]/2, + self._image_shape[0]/2) + + + def _set_scale_aspect_ratio_to_match_viewer(self, + reset_scale='y'): + # Set the scales so that they match the aspect ratio + # of the viewer, preserving the current image center. + width_x, width_y = self.scale_widths + frozen_width = dict(y=width_x, x=width_y) + scale_aspect = width_x / width_y + figure_x = float(self._figure.layout.width[:-2]) + figure_y = float(self._figure.layout.height[:-2]) + figure_aspect = figure_x / figure_y + current_center = self.center + if abs(figure_aspect - scale_aspect) > 1e-4: + # Make the scale aspect ratio match the + # figure layout aspect ratio + if reset_scale == 'y': + scale_factor = 1/ figure_aspect + else: + scale_factor = figure_aspect + + self._scales[reset_scale].min = 0 + self._scales[reset_scale].max = frozen_width[reset_scale] * scale_factor + self.center = current_center + + def set_data(self, image_data, reset_view=True): + self._image_shape = image_data.shape + + if reset_view: + self.reset_scale_to_fit_image() + + # Set the image data and map it to the bqplot figure so that + # cursor location corresponds to the underlying array index. + self._image.image = image_data + self._image.x = [0, self._image_shape[1]] + self._image.y = [0, self._image_shape[0]] + + @property + def center(self): + """ + Center of current view in pixels in x, y. + """ + x_center = (self._scales['x'].min + self._scales['x'].max) / 2 + y_center = (self._scales['y'].min + self._scales['y'].max) / 2 + return (x_center, y_center) + + @property + def scale_widths(self): + width_x = self._scales['x'].max - self._scales['x'].min + width_y = self._scales['y'].max - self._scales['y'].min + return (width_x, width_y) + + @center.setter + def center(self, value): + x_c, y_c = value + + width_x, width_y = self.scale_widths + self._scales['x'].max = x_c + width_x / 2 + self._scales['x'].min = x_c - width_x / 2 + self._scales['y'].max = y_c + width_y / 2 + self._scales['y'].min = y_c - width_y / 2 + + def set_color(self, colors): + # colors here means a list of hex colors + self._image.scales['image'].colors = colors + + def save_png(self, filename): + self._figure.save_png(filename) + + def save_svg(self, filename): + self._figure.save_svg(filename) + + +def bqcolors(colormap, reverse=False): + # bqplot-image-gl has 256 levels + LEVELS = 256 + + # Make a matplotlib colormap object + mpl = cmp.get_cmap(colormap, LEVELS) + + # Get RGBA colors + mpl_colors = mpl(np.linspace(0, 1, LEVELS)) + + # Convert RGBA to hex + bq_colors = [to_hex(mpl_colors[i, :]) for i in range(LEVELS)] + + if reverse: + bq_colors = bq_colors[::-1] + + return bq_colors + + +class MarkerTableManager: + def __init__(self): + pass + + def add_markers(self, table, x_colname='x', y_colname='y', + skycoord_colname='coord', use_skycoord=False, + marker_name=None): + + # For now we always convert marker locations to pixels; see + # comment below. + coord_type = 'data' + + if marker_name is None: + marker_name = self._default_mark_tag_name + + self.validate_marker_name(marker_name) + self._marktags.add(marker_name) + + # Extract coordinates from table. + # They are always arrays, not scalar. + if use_skycoord: + image = self._viewer.get_image() + if image is None: + raise ValueError('Cannot get image from viewer') + if image.wcs.wcs is None: + raise ValueError( + 'Image has no valid WCS, ' + 'try again with use_skycoord=False') + coord_val = table[skycoord_colname] + # TODO: Maybe switch back to letting Ginga handle conversion + # to pixel coordinates. + # Convert to pixels here (instead of in Ginga) because conversion + # in Ginga was reportedly very slow. + coord_x, coord_y = image.wcs.wcs.all_world2pix( + coord_val.ra.deg, coord_val.dec.deg, 0) + # In the event a *single* marker has been added, coord_x and coord_y + # will be scalars. Make them arrays always. + if np.ndim(coord_x) == 0: + coord_x = np.array([coord_x]) + coord_y = np.array([coord_y]) + else: # Use X,Y + coord_x = table[x_colname].data + coord_y = table[y_colname].data + # Convert data coordinates from 1-indexed to 0-indexed + if self._pixel_offset != 0: + # Don't use the in-place operator -= here that modifies + # the input table. + coord_x = coord_x - self._pixel_offset + coord_y = coord_y - self._pixel_offset + + # Prepare canvas and retain existing marks + try: + c_mark = self._viewer.canvas.get_object_by_tag(marker_name) + except Exception: + objs = [] + else: + objs = c_mark.objects + self._viewer.canvas.delete_object_by_tag(marker_name) + + # TODO: Test to see if we can mix WCS and data on the same canvas + objs += [self._marker(x=x, y=y, coord=coord_type) + for x, y in zip(coord_x, coord_y)] + self._viewer.canvas.add(self.dc.CompoundObject(*objs), tag=marker_name) + + +""" +next(iter(imviz.app._viewer_store.values())).figure +""" +STRETCHES = dict( + linear=apviz.LinearStretch, + sqrt=apviz.SqrtStretch, + histeq=apviz.HistEqStretch, + log=apviz.LogStretch + # ... +) + + +class ImageWidget(ipw.VBox): + click_center = trait.Bool(default_value=False).tag(sync=True) + click_drag = trait.Bool(default_value=False).tag(sync=True) + scroll_pan = trait.Bool(default_value=False).tag(sync=True) + image_width = trait.Int(help="Width of the image (not viewer)").tag(sync=True) + image_height = trait.Int(help="Height of the image (not viewer)").tag(sync=True) + zoom_level = trait.Float(help="Current zoom of the view").tag(sync=True) + marker = trait.Any(help="Markers").tag(sync=True) + cuts = trait.Any(help="Cut levels", allow_none=True).tag(sync=True) + stretch = trait.Unicode(help='Stretch algorithm name', allow_none=True).tag(sync=True) + + def __init__(self, *args, image_width=500, image_height=500): + super().__init__(*args) + self.image_width = image_width + self.image_height = image_height + viewer_aspect = self.image_width / self.image_height + self._astro_im = _AstroImage(display_width=self.image_width, + viewer_aspect_ratio=viewer_aspect) + self._interval = None + self._stretch = None + self._colormap = 'Grays' + + def _interval_and_stretch(self): + """ + Stretch and normalize the data before sending to the viewer. + """ + interval = self._get_interval() + intervaled = interval(self._data) + + stretch = self._get_stretch() + if stretch: + stretched = stretch(intervaled) + else: + stretched = intervaled + + return stretched + + def _send_data(self): + self._astro_im.set_data(self._interval_and_stretch(), + reset_view=False) + + def _get_interval(self): + if self._interval is None: + return apviz.MinMaxInterval() + else: + return self._interval + + def _get_stretch(self): + return self._stretch + + @trait.validate('stretch') + def _validate_stretch(self, proposal): + proposed_stretch = proposal['value'] + if (proposed_stretch not in STRETCHES.keys() and + proposed_stretch is not None): + + raise ValueError(f'{proposed_stretch} is not a valid value. ' + 'The stretch must be None or ' + 'one of these values: ' + f'{sorted(STRETCHES.keys())}') + + return proposed_stretch + + @trait.observe('stretch') + def _observe_stretch(self, change): + self._stretch = STRETCHES[change['new']] if change['new'] else None + + @trait.validate('cuts') + def _validate_cuts(self, proposal): + # Allow these: + # - a two-item thing (tuple, list, whatever) + # - an Astropy interval + # - None + proposed_cuts = proposal['value'] + + bad_value_error = (f"{proposed_cuts} is not a valid value. " + "cuts must be either None, " + "an astropy interval, or list/tuple " + "of length 2.") + + if ((proposed_cuts is None) or + isinstance(proposed_cuts, apviz.BaseInterval)): + return proposed_cuts + else: + try: + length = len(proposed_cuts) + assert length == 2 + # Tests expect this to be a tuple... + proposed_cuts = tuple(proposed_cuts) + except (TypeError, AssertionError): + raise ValueError(bad_value_error) + + return proposed_cuts + + @trait.observe('cuts') + def _observe_cuts(self, change): + # This needs to handle only the case when the cuts is a + # tuple/list of length 2. That is interpreted as a ManualInterval. + cuts = change['new'] + if cuts is not None: + if not isinstance(cuts, apviz.BaseInterval): + self._interval = apviz.ManualInterval(*cuts) + else: + self._interval = cuts + + # The methods, grouped loosely by purpose + + # Methods for loading data + def load_fits(self, file_name_or_HDU): + if isinstance(file_name_or_HDU, str): + ccd = CCDData.read(file) + elif isinstance(file_name_or_HDU, + (fits.ImageHDU, fits.CompImageHDU, fits.PrimaryHDU)): + try: + ccd_unit = u.Unit(file_name_or_HDU.header['bunit']) + except (KeyError, ValueError): + ccd_unit = u.dimensionless_unscaled + ccd = CCDData(file_name_or_HDU.data, + header=file_name_or_HDU.header, + unit=ccd_unit) + else: + raise ValueError(f'{file_name_or_HDU} is an invalid value. It must' + ' be a string or an astropy.io.fits HDU.') + + self._ccd = ccd + self._data = ccd.data + self._send_data() + + def load_array(self, array): + self._data = array + self._send_data() + + def load_nddata(self, data): + self._ccd = data + self._data = self._ccd.data + self._send_data() + + # Saving contents of the view and accessing the view + def save(self, filename): + if filename.endswith('.png'): + self._astro_im.save_png(filename) + elif filename.endswith('.svg'): + self._astro_im.save_svg(filename) + else: + raise NotImplementedError('Saving is not implemented for that' + 'file type. Use .png or .svg') + + def set_colormap(self, cmap_name, reverse=False): + self._astro_im.set_color(bqcolors(cmap_name, reverse=reverse)) + self._colormap = cmap_name + + @property + def colormap_options(self): + return pyplot.colormaps() + + # # Marker-related methods + # @abstractmethod + # def start_marking(self): + # raise NotImplementedError + + # @abstractmethod + # def stop_marking(self): + # raise NotImplementedError + + # @abstractmethod + # def add_markers(self): + # raise NotImplementedError + + # @abstractmethod + # def get_markers(self): + # raise NotImplementedError + + # @abstractmethod + # def remove_markers(self): + # raise NotImplementedError + + # @abstractmethod + # def get_all_markers(self): + # raise NotImplementedError + + # @abstractmethod + # def get_markers_by_name(self, marker_name=None): + # raise NotImplementedError + + # Methods that modify the view + def center_on(self, point): + if isinstance(point, SkyCoord): + if self._wcs is None: + raise ValueError('The image must have a WCS to be able ' + 'to center on a coordinate.') + pixel = self._wcs.world_to_pixel(point) + else: + pixel = point + + self._astro_im.center = pixel + + # @abstractmethod + # def offset_to(self): + # raise NotImplementedError + + # @abstractmethod + # def zoom(self): + # raise NotImplementedError diff --git a/astrowidgets/tests/test_bqplot_api.py b/astrowidgets/tests/test_bqplot_api.py new file mode 100644 index 0000000..c330efd --- /dev/null +++ b/astrowidgets/tests/test_bqplot_api.py @@ -0,0 +1,292 @@ +import numpy as np + +import pytest + +from astropy.io import fits +from astropy.nddata import NDData +from astropy.table import Table +from astropy.visualization import BaseStretch, AsymmetricPercentileInterval + +from astrowidgets.bqplot import ImageWidget, ALLOWED_CURSOR_LOCATIONS +from astrowidgets.interface_definition import ImageViewerInterface + + +def test_consistent_interface(): + iw = ImageWidget() + assert isinstance(iw, ImageViewerInterface) + + +def test_load_fits(): + image = ImageWidget() + data = np.random.random([100, 100]) + hdu = fits.PrimaryHDU(data=data) + image.load_fits(hdu) + + +def test_load_nddata(): + image = ImageWidget() + data = np.random.random([100, 100]) + nddata = NDData(data) + image.load_nddata(nddata) + + +def test_load_array(): + image = ImageWidget() + data = np.random.random([100, 100]) + image.load_array(data) + + +def test_center_on(): + image = ImageWidget() + x = 10 + y = 10 + image.center_on((x, y)) + + +def test_offset_to(): + image = ImageWidget() + dx = 10 + dy = 10 + image.offset_to(dx, dy) + + +def test_zoom_level(): + image = ImageWidget() + image.zoom_level = 5 + assert image.zoom_level == 5 + + +def test_zoom(): + image = ImageWidget() + image.zoom_level = 3 + val = 2 + image.zoom(val) + assert image.zoom_level == 6 + + +@pytest.mark.xfail(reason='Not implemented yet') +def test_select_points(): + image = ImageWidget() + image.select_points() + + +def test_get_selection(): + image = ImageWidget() + marks = image.get_markers() + assert isinstance(marks, Table) or marks is None + + +def test_stop_marking(): + image = ImageWidget() + # This is not much of a test... + image.stop_marking(clear_markers=True) + assert image.get_markers() is None + assert image.is_marking is False + + +def test_is_marking(): + image = ImageWidget() + assert image.is_marking in [True, False] + with pytest.raises(AttributeError): + image.is_marking = True + + +def test_start_marking(): + image = ImageWidget() + + # Setting these to check that start_marking affects them. + image.click_center = True + assert image.click_center + image.scroll_pan = False + assert not image.scroll_pan + + marker_style = {'color': 'yellow', 'radius': 10, 'type': 'cross'} + image.start_marking(marker_name='something', + marker=marker_style) + assert image.is_marking + assert image.marker == marker_style + assert not image.click_center + assert not image.click_drag + + # scroll_pan better activate when marking otherwise there is + # no way to pan while interactively marking + assert image.scroll_pan + + # Make sure that when we stop_marking we get our old + # controls back. + image.stop_marking() + assert image.click_center + assert not image.scroll_pan + + # Make sure that click_drag is restored as expected + image.click_drag = True + image.start_marking() + assert not image.click_drag + image.stop_marking() + assert image.click_drag + + +def test_add_markers(): + image = ImageWidget() + table = Table(data=np.random.randint(0, 100, [5, 2]), + names=['x', 'y'], dtype=('int', 'int')) + image.add_markers(table, x_colname='x', y_colname='y', + skycoord_colname='coord') + + +def test_set_markers(): + image = ImageWidget() + image.marker = {'color': 'yellow', 'radius': 10, 'type': 'cross'} + assert 'cross' in str(image.marker) + assert 'yellow' in str(image.marker) + assert '10' in str(image.marker) + + +def test_reset_markers(): + image = ImageWidget() + # First test: this shouldn't raise any errors + # (it also doesn't *do* anything...) + image.reset_markers() + assert image.get_markers() is None + table = Table(data=np.random.randint(0, 100, [5, 2]), + names=['x', 'y'], dtype=('int', 'int')) + image.add_markers(table, x_colname='x', y_colname='y', + skycoord_colname='coord', marker_name='test') + image.add_markers(table, x_colname='x', y_colname='y', + skycoord_colname='coord', marker_name='test2') + image.reset_markers() + with pytest.raises(ValueError): + image.get_markers(marker_name='test') + with pytest.raises(ValueError): + image.get_markers(marker_name='test2') + + +def test_remove_markers(): + image = ImageWidget() + # Add a tag name... + image._marktags.add(image._default_mark_tag_name) + with pytest.raises(ValueError) as e: + image.remove_markers('arf') + assert 'arf' in str(e.value) + + +def test_stretch(): + image = ImageWidget() + with pytest.raises(ValueError) as e: + image.stretch = 'not a valid value' + assert 'must be one of' in str(e.value) + + image.stretch = 'log' + assert isinstance(image.stretch, (BaseStretch, str)) + + +def test_cuts(): + image = ImageWidget() + + # An invalid string should raise an error + with pytest.raises(ValueError) as e: + image.cuts = 'not a valid value' + assert 'must be one of' in str(e.value) + + # Setting cuts to something with incorrect length + # should raise an error. + with pytest.raises(ValueError) as e: + image.cuts = (1, 10, 100) + assert 'length 2' in str(e.value) + + # These ought to succeed + + # ⚠️ clarify this + # image.cuts = 'histogram' + # assert image.cuts == (0.0, 0.0) + + image.cuts = [10, 100] + assert image.cuts == (10, 100) + + # This should work without error + image.cuts = AsymmetricPercentileInterval(1, 99.5) + + +def test_colormap(): + image = ImageWidget() + cmap_desired = 'viridis' + cmap_list = image.colormap_options + assert len(cmap_list) > 0 and cmap_desired in cmap_list + + image.set_colormap(cmap_desired) + + +def test_cursor(): + image = ImageWidget() + assert image.cursor in ALLOWED_CURSOR_LOCATIONS + with pytest.raises(ValueError): + image.cursor = 'not a valid option' + image.cursor = 'bottom' + assert image.cursor == 'bottom' + + +def test_click_drag(): + image = ImageWidget() + # Set this to ensure that click_drag turns it off + image._click_center = True + + # Make sure that setting click_drag to False does not turn off + # click_center. + + image.click_drag = False + assert image.click_center + + image.click_drag = True + + assert not image.click_center + + # If is_marking is true then trying to click_drag + # should fail. + image._is_marking = True + with pytest.raises(ValueError) as e: + image.click_drag = True + assert 'Interactive marking' in str(e.value) + + +def test_click_center(): + image = ImageWidget() + assert (image.click_center is True) or (image.click_center is False) + + # Set click_drag True and check that click_center affects it appropriately + image.click_drag = True + + image.click_center = False + assert image.click_drag + + image.click_center = True + assert not image.click_drag + + image.start_marking() + # If marking is in progress then setting click center should fail + with pytest.raises(ValueError) as e: + image.click_center = True + assert 'Cannot set' in str(e.value) + + # setting to False is fine though so no error is expected here + image.click_center = False + + +def test_scroll_pan(): + image = ImageWidget() + + # Make sure scroll_pan is actually settable + for val in [True, False]: + image.scroll_pan = val + assert image.scroll_pan is val + + +def test_save(): + image = ImageWidget() + filename = 'woot.png' + image.save(filename) + + +def test_width_height(): + image = ImageWidget(image_width=250, image_height=100) + assert image.image_width == 250 + assert image.image_height == 100