-
-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP in case my local computer combusts
I promise to revise history soon 😬
- Loading branch information
Showing
1 changed file
with
268 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
import numpy as np | ||
from bqplot import Figure, LinearScale, Axis, ColorScale, PanZoom | ||
from bqplot_image_gl import ImageGL | ||
from bqplot_image_gl.interacts import (MouseInteraction, | ||
keyboard_events, mouse_events) | ||
|
||
import ipywidgets as ipw | ||
import traitlets as trait | ||
|
||
# Allowed locations for cursor display | ||
ALLOWED_CURSOR_LOCATIONS = ['top', 'bottom', None] | ||
|
||
# List of marker names that are for internal use only | ||
RESERVED_MARKER_SET_NAMES = ['all'] | ||
|
||
|
||
class _AstroImage(ipw.VBox): | ||
""" | ||
Encapsulate an image as a bqplot figure inside a box. | ||
bqplot is involved for its pan/zoom capabilities, and it presents as | ||
a box to obscure the usual bqplot properties and methods. | ||
""" | ||
def __init__(self, image_data=None, | ||
display_width=500, | ||
viewer_aspect_ratio=1.0): | ||
super().__init__() | ||
|
||
self._viewer_aspect_ratio = viewer_aspect_ratio | ||
|
||
self._display_width = display_width | ||
self._display_height = self._viewer_aspect_ratio * self._display_width | ||
|
||
|
||
layout = ipw.Layout(width=f'{self._display_width}px', | ||
height=f'{self._display_height}px', | ||
justify_content='center') | ||
|
||
self._figure_layout = layout | ||
|
||
scale_x = LinearScale(min=0, max=1, #self._image_shape[1], | ||
allow_padding=False) | ||
scale_y = LinearScale(min=0, max=1, #self._image_shape[0], | ||
allow_padding=False) | ||
self._scales = {'x': scale_x, 'y': scale_y} | ||
axis_x = Axis(scale=scale_x, visible=False) | ||
axis_y = Axis(scale=scale_y, orientation='vertical', visible=False) | ||
scales_image = {'x': scale_x, 'y': scale_y, | ||
'image': ColorScale(max=1.114, min=2902, | ||
scheme='Greys')} | ||
|
||
self._figure = Figure(scales=self._scales, axes=[axis_x, axis_y], | ||
fig_margin=dict(top=0, left=0, | ||
right=0, bottom=0), | ||
layout=layout) | ||
|
||
self._image = ImageGL(scales=scales_image) | ||
|
||
self._figure.marks = (self._image, ) | ||
|
||
panzoom = PanZoom(scales={'x': [scales_image['x']], | ||
'y': [scales_image['y']]}) | ||
interaction = MouseInteraction(x_scale=scales_image['x'], | ||
y_scale=scales_image['y'], | ||
move_throttle=70, next=panzoom, | ||
events=keyboard_events + mouse_events) | ||
|
||
self._figure.interaction = interaction | ||
|
||
if image_data: | ||
self.set_data(image_data, reset_view=True) | ||
|
||
self.children = (self._figure, ) | ||
|
||
@property | ||
def data_aspect_ratio(self): | ||
""" | ||
Aspect ratio of the image data, horizontal size over vertical size. | ||
""" | ||
return self._image_shape[0] / self._image_shape[1] | ||
|
||
def reset_scale_to_fit_image(self): | ||
wide = self.data_aspect_ratio < 1 | ||
tall = self.data_aspect_ratio > 1 | ||
square = self.data_aspect_ratio == 1 | ||
|
||
if wide: | ||
self._scales['x'].min = 0 | ||
self._scales['x'].max = self._image_shape[1] | ||
self._set_scale_aspect_ratio_to_match_viewer() | ||
elif tall or square: | ||
self._scales['y'].min = 0 | ||
self._scales['y'].max = self._image_shape[0] | ||
self._set_scale_aspect_ratio_to_match_viewer(reset_scale='x') | ||
|
||
# Great, now let's center | ||
self.center = (self._image_shape[1]/2, | ||
self._image_shape[0]/2) | ||
|
||
|
||
def _set_scale_aspect_ratio_to_match_viewer(self, | ||
reset_scale='y'): | ||
# Set the scales so that they match the aspect ratio | ||
# of the viewer, preserving the current image center. | ||
width_x, width_y = self.scale_widths | ||
frozen_width = dict(y=width_x, x=width_y) | ||
scale_aspect = width_x / width_y | ||
figure_x = float(self._figure.layout.width[:-2]) | ||
figure_y = float(self._figure.layout.height[:-2]) | ||
figure_aspect = figure_x / figure_y | ||
current_center = self.center | ||
if abs(figure_aspect - scale_aspect) > 1e-4: | ||
# Make the scale aspect ratio match the | ||
# figure layout aspect ratio | ||
if reset_scale == 'y': | ||
scale_factor = 1/ figure_aspect | ||
else: | ||
scale_factor = figure_aspect | ||
|
||
self._scales[reset_scale].min = 0 | ||
self._scales[reset_scale].max = frozen_width[reset_scale] * scale_factor | ||
self.center = current_center | ||
|
||
def set_data(self, image_data, reset_view=True): | ||
self._image_shape = image_data.shape | ||
|
||
if reset_view: | ||
self.reset_scale_to_fit_image() | ||
|
||
# Set the image data and map it to the bqplot figure so that | ||
# cursor location corresponds to the underlying array index. | ||
self._image.image = image_data | ||
self._image.x = [0, self._image_shape[1]] | ||
self._image.y = [0, self._image_shape[0]] | ||
|
||
@property | ||
def center(self): | ||
""" | ||
Center of current view in pixels in x, y. | ||
""" | ||
x_center = (self._scales['x'].min + self._scales['x'].max) / 2 | ||
y_center = (self._scales['y'].min + self._scales['y'].max) / 2 | ||
return (x_center, y_center) | ||
|
||
@property | ||
def scale_widths(self): | ||
width_x = self._scales['x'].max - self._scales['x'].min | ||
width_y = self._scales['y'].max - self._scales['y'].min | ||
return (width_x, width_y) | ||
|
||
@center.setter | ||
def center(self, value): | ||
x_c, y_c = value | ||
|
||
width_x, width_y = self.scale_widths | ||
self._scales['x'].max = x_c + width_x / 2 | ||
self._scales['x'].min = x_c - width_x / 2 | ||
self._scales['y'].max = y_c + width_y / 2 | ||
self._scales['y'].min = y_c - width_y / 2 | ||
|
||
def set_color(self, colors): | ||
# colors here means a list of hex colors | ||
self._image.scales['image'].colors = colors | ||
|
||
|
||
def bqcolors(colormap, reverse=False): | ||
from matplotlib import cm as cmp | ||
from matplotlib.colors import to_hex | ||
|
||
# bqplot-image-gl has 256 levels | ||
LEVELS = 256 | ||
|
||
# Make a matplotlib colormap object | ||
mpl = cmp.get_cmap(colormap, LEVELS) | ||
|
||
# Get RGBA colors | ||
mpl_colors = mpl(np.linspace(0, 1, LEVELS)) | ||
|
||
# Convert RGBA to hex | ||
bq_colors = [to_hex(mpl_colors[i, :]) for i in range(LEVELS)] | ||
|
||
if reverse: | ||
bq_colors = bq_colors[::-1] | ||
|
||
return bq_colors | ||
|
||
""" | ||
next(iter(imviz.app._viewer_store.values())).figure | ||
""" | ||
|
||
|
||
class ImageWidget(ipw.VBox): | ||
click_center = trait.Bool(default_value=False).tag(sync=True) | ||
click_drag = trait.Bool(default_value=False).tag(sync=True) | ||
scroll_pan = trait.Bool(default_value=False).tag(sync=True) | ||
image_width = trait.Int(help="Width of the image (not viewer)").tag(sync=True) | ||
image_height = trait.Int(help="Height of the image (not viewer)").tag(sync=True) | ||
zoom_level = trait.Float(help="Current zoom of the view").tag(sync=True) | ||
marker = trait.Any(help="Markers").tag(sync=True) | ||
cuts = trait.Tuple(help="Cut levels").tag(sync=True) | ||
stretch = trait.Unicode(help='Stretch algorithm name').tag(sync=True) | ||
|
||
def __init__(self, *args, image_width=500, image_height=500): | ||
super().__init__(*args) | ||
self.image_width = image_width | ||
self.image_height = image_height | ||
self._astro_im = _AstroImage() | ||
|
||
# The methods, grouped loosely by purpose | ||
|
||
# Methods for loading data | ||
# @abstractmethod | ||
# def load_fits(self, file): | ||
# raise NotImplementedError | ||
|
||
def load_array(self, array): | ||
raise NotImplementedError | ||
|
||
# @abstractmethod | ||
# def load_nddata(self, data): | ||
# raise NotImplementedError | ||
|
||
# Saving contents of the view and accessing the view | ||
@abstractmethod | ||
def save(self, filename): | ||
raise NotImplementedError | ||
|
||
# Marker-related methods | ||
@abstractmethod | ||
def start_marking(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def stop_marking(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def add_markers(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def get_markers(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def remove_markers(self): | ||
raise NotImplementedError | ||
|
||
# @abstractmethod | ||
# def get_all_markers(self): | ||
# raise NotImplementedError | ||
|
||
# @abstractmethod | ||
# def get_markers_by_name(self, marker_name=None): | ||
# raise NotImplementedError | ||
|
||
# Methods that modify the view | ||
@abstractmethod | ||
def center_on(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def offset_to(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def zoom(self): | ||
raise NotImplementedError |