Skip to content

Commit

Permalink
WIP in case my local computer combusts
Browse files Browse the repository at this point in the history
I promise to revise history soon 😬
  • Loading branch information
mwcraig committed Jul 16, 2021
1 parent ab66d3f commit edf546e
Showing 1 changed file with 268 additions and 0 deletions.
268 changes: 268 additions & 0 deletions astrowidgets/bqplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
import numpy as np
from bqplot import Figure, LinearScale, Axis, ColorScale, PanZoom
from bqplot_image_gl import ImageGL
from bqplot_image_gl.interacts import (MouseInteraction,
keyboard_events, mouse_events)

import ipywidgets as ipw
import traitlets as trait

# Allowed locations for cursor display
ALLOWED_CURSOR_LOCATIONS = ['top', 'bottom', None]

# List of marker names that are for internal use only
RESERVED_MARKER_SET_NAMES = ['all']


class _AstroImage(ipw.VBox):
"""
Encapsulate an image as a bqplot figure inside a box.
bqplot is involved for its pan/zoom capabilities, and it presents as
a box to obscure the usual bqplot properties and methods.
"""
def __init__(self, image_data=None,
display_width=500,
viewer_aspect_ratio=1.0):
super().__init__()

self._viewer_aspect_ratio = viewer_aspect_ratio

self._display_width = display_width
self._display_height = self._viewer_aspect_ratio * self._display_width


layout = ipw.Layout(width=f'{self._display_width}px',
height=f'{self._display_height}px',
justify_content='center')

self._figure_layout = layout

scale_x = LinearScale(min=0, max=1, #self._image_shape[1],
allow_padding=False)
scale_y = LinearScale(min=0, max=1, #self._image_shape[0],
allow_padding=False)
self._scales = {'x': scale_x, 'y': scale_y}
axis_x = Axis(scale=scale_x, visible=False)
axis_y = Axis(scale=scale_y, orientation='vertical', visible=False)
scales_image = {'x': scale_x, 'y': scale_y,
'image': ColorScale(max=1.114, min=2902,
scheme='Greys')}

self._figure = Figure(scales=self._scales, axes=[axis_x, axis_y],
fig_margin=dict(top=0, left=0,
right=0, bottom=0),
layout=layout)

self._image = ImageGL(scales=scales_image)

self._figure.marks = (self._image, )

panzoom = PanZoom(scales={'x': [scales_image['x']],
'y': [scales_image['y']]})
interaction = MouseInteraction(x_scale=scales_image['x'],
y_scale=scales_image['y'],
move_throttle=70, next=panzoom,
events=keyboard_events + mouse_events)

self._figure.interaction = interaction

if image_data:
self.set_data(image_data, reset_view=True)

self.children = (self._figure, )

@property
def data_aspect_ratio(self):
"""
Aspect ratio of the image data, horizontal size over vertical size.
"""
return self._image_shape[0] / self._image_shape[1]

def reset_scale_to_fit_image(self):
wide = self.data_aspect_ratio < 1
tall = self.data_aspect_ratio > 1
square = self.data_aspect_ratio == 1

if wide:
self._scales['x'].min = 0
self._scales['x'].max = self._image_shape[1]
self._set_scale_aspect_ratio_to_match_viewer()
elif tall or square:
self._scales['y'].min = 0
self._scales['y'].max = self._image_shape[0]
self._set_scale_aspect_ratio_to_match_viewer(reset_scale='x')

# Great, now let's center
self.center = (self._image_shape[1]/2,
self._image_shape[0]/2)


def _set_scale_aspect_ratio_to_match_viewer(self,
reset_scale='y'):
# Set the scales so that they match the aspect ratio
# of the viewer, preserving the current image center.
width_x, width_y = self.scale_widths
frozen_width = dict(y=width_x, x=width_y)
scale_aspect = width_x / width_y
figure_x = float(self._figure.layout.width[:-2])
figure_y = float(self._figure.layout.height[:-2])
figure_aspect = figure_x / figure_y
current_center = self.center
if abs(figure_aspect - scale_aspect) > 1e-4:
# Make the scale aspect ratio match the
# figure layout aspect ratio
if reset_scale == 'y':
scale_factor = 1/ figure_aspect
else:
scale_factor = figure_aspect

self._scales[reset_scale].min = 0
self._scales[reset_scale].max = frozen_width[reset_scale] * scale_factor
self.center = current_center

def set_data(self, image_data, reset_view=True):
self._image_shape = image_data.shape

if reset_view:
self.reset_scale_to_fit_image()

# Set the image data and map it to the bqplot figure so that
# cursor location corresponds to the underlying array index.
self._image.image = image_data
self._image.x = [0, self._image_shape[1]]
self._image.y = [0, self._image_shape[0]]

@property
def center(self):
"""
Center of current view in pixels in x, y.
"""
x_center = (self._scales['x'].min + self._scales['x'].max) / 2
y_center = (self._scales['y'].min + self._scales['y'].max) / 2
return (x_center, y_center)

@property
def scale_widths(self):
width_x = self._scales['x'].max - self._scales['x'].min
width_y = self._scales['y'].max - self._scales['y'].min
return (width_x, width_y)

@center.setter
def center(self, value):
x_c, y_c = value

width_x, width_y = self.scale_widths
self._scales['x'].max = x_c + width_x / 2
self._scales['x'].min = x_c - width_x / 2
self._scales['y'].max = y_c + width_y / 2
self._scales['y'].min = y_c - width_y / 2

def set_color(self, colors):
# colors here means a list of hex colors
self._image.scales['image'].colors = colors


def bqcolors(colormap, reverse=False):
from matplotlib import cm as cmp
from matplotlib.colors import to_hex

# bqplot-image-gl has 256 levels
LEVELS = 256

# Make a matplotlib colormap object
mpl = cmp.get_cmap(colormap, LEVELS)

# Get RGBA colors
mpl_colors = mpl(np.linspace(0, 1, LEVELS))

# Convert RGBA to hex
bq_colors = [to_hex(mpl_colors[i, :]) for i in range(LEVELS)]

if reverse:
bq_colors = bq_colors[::-1]

return bq_colors

"""
next(iter(imviz.app._viewer_store.values())).figure
"""


class ImageWidget(ipw.VBox):
click_center = trait.Bool(default_value=False).tag(sync=True)
click_drag = trait.Bool(default_value=False).tag(sync=True)
scroll_pan = trait.Bool(default_value=False).tag(sync=True)
image_width = trait.Int(help="Width of the image (not viewer)").tag(sync=True)
image_height = trait.Int(help="Height of the image (not viewer)").tag(sync=True)
zoom_level = trait.Float(help="Current zoom of the view").tag(sync=True)
marker = trait.Any(help="Markers").tag(sync=True)
cuts = trait.Tuple(help="Cut levels").tag(sync=True)
stretch = trait.Unicode(help='Stretch algorithm name').tag(sync=True)

def __init__(self, *args, image_width=500, image_height=500):
super().__init__(*args)
self.image_width = image_width
self.image_height = image_height
self._astro_im = _AstroImage()

# The methods, grouped loosely by purpose

# Methods for loading data
# @abstractmethod
# def load_fits(self, file):
# raise NotImplementedError

def load_array(self, array):
raise NotImplementedError

# @abstractmethod
# def load_nddata(self, data):
# raise NotImplementedError

# Saving contents of the view and accessing the view
@abstractmethod
def save(self, filename):
raise NotImplementedError

# Marker-related methods
@abstractmethod
def start_marking(self):
raise NotImplementedError

@abstractmethod
def stop_marking(self):
raise NotImplementedError

@abstractmethod
def add_markers(self):
raise NotImplementedError

@abstractmethod
def get_markers(self):
raise NotImplementedError

@abstractmethod
def remove_markers(self):
raise NotImplementedError

# @abstractmethod
# def get_all_markers(self):
# raise NotImplementedError

# @abstractmethod
# def get_markers_by_name(self, marker_name=None):
# raise NotImplementedError

# Methods that modify the view
@abstractmethod
def center_on(self):
raise NotImplementedError

@abstractmethod
def offset_to(self):
raise NotImplementedError

@abstractmethod
def zoom(self):
raise NotImplementedError

0 comments on commit edf546e

Please sign in to comment.