Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API changes to marker methods #1

Open
wants to merge 3 commits into
base: move-to-protocol
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 89 additions & 73 deletions astrowidgets/ginga.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Module containing core functionality of ``astrowidgets``."""
"""The ``astrowidgets.ginga`` module contains a widget implemented with the
Ginga backend.

# STDLIB
For this to work, ``astrowidgets`` must be installed along with the optional
dependencies specified for the Ginga backend; e.g.,::

pip install 'astrowidgets[ginga]'

"""
import functools
import warnings

Expand Down Expand Up @@ -244,13 +250,6 @@ def _mouse_click_cb(self, viewer, event, data_x, data_y):
print('Centered on X={} Y={}'.format(data_x + self._pixel_offset,
data_y + self._pixel_offset))

# def _repr_html_(self):
# """
# Show widget in Jupyter notebook.
# """
# from IPython.display import display
# return display(self._widget)

def load_fits(self, fitsorfn, numhdu=None, memmap=None):
"""
Load a FITS file into the viewer.
Expand Down Expand Up @@ -431,6 +430,21 @@ def start_marking(self, marker_name=None,
"""
Start marking, with option to name this set of markers or
to specify the marker style.

This disables `click_center` and `click_drag`, but enables `scroll_pan`.

Parameters
----------
marker_name : str or `None`, optional
Marker name to use. This is useful if you want to set different
groups of markers. If given, this cannot be already defined in
``RESERVED_MARKER_SET_NAMES`` attribute. If not given, an internal
default is used.

marker : dict or `None`, optional
Set the marker properties; see `marker`. If not given, the current
setting is used.

"""
self._cached_state = dict(click_center=self.click_center,
click_drag=self.click_drag,
Expand All @@ -457,9 +471,9 @@ def stop_marking(self, clear_markers=False):
Parameters
----------
clear_markers : bool, optional
If ``clear_markers`` is `False`, existing markers are
retained until :meth:`reset_markers` is called.
Otherwise, they are erased.
If `False`, existing markers are retained until
:meth:`remove_all_markers` is called.
Otherwise, they are all erased.
"""
if self.is_marking:
self._is_marking = False
Expand All @@ -468,7 +482,7 @@ def stop_marking(self, clear_markers=False):
self.scroll_pan = self._cached_state['scroll_pan']
self._cached_state = {}
if clear_markers:
self.reset_markers()
self.remove_all_markers()

@property
def marker(self):
Expand Down Expand Up @@ -512,9 +526,19 @@ def marker(self, val):
# Only set this once we have successfully created a marker
self._marker_dict = val

def get_markers(self, x_colname='x', y_colname='y',
skycoord_colname='coord',
marker_name=None):
def get_marker_names(self):
"""Return a list of used marker names.

Returns
-------
names : list of str
Sorted list of marker names.

"""
return sorted(self._marktags)

def get_markers_by_name(self, marker_name, x_colname='x', y_colname='y',
skycoord_colname='coord'):
"""
Return the locations of existing markers.

Expand All @@ -536,44 +560,6 @@ def get_markers(self, x_colname='x', y_colname='y',
Table of markers, if any, or ``None``.

"""
if marker_name is None:
marker_name = self._default_mark_tag_name

if marker_name == 'all':
# If it wasn't for the fact that SKyCoord columns can't
# be stacked this would all fit nicely into a list
# comprehension. But they can't, so we delete the
# SkyCoord column if it is present, then add it
# back after we have stacked.
coordinates = []
tables = []
for name in self._marktags:
table = self.get_markers(x_colname=x_colname,
y_colname=y_colname,
skycoord_colname=skycoord_colname,
marker_name=name)
if table is None:
# No markers by this name, skip it
continue

try:
coordinates.extend(c for c in table[skycoord_colname])
except KeyError:
pass
else:
del table[skycoord_colname]
tables.append(table)

if len(tables) == 0:
return None

stacked = vstack(tables, join_type='exact')

if coordinates:
stacked[skycoord_colname] = SkyCoord(coordinates)

return stacked

# We should always allow the default name. The case
# where that table is empty will be handled in a moment.
if (marker_name not in self._marktags
Expand All @@ -583,9 +569,9 @@ def get_markers(self, x_colname='x', y_colname='y',
try:
c_mark = self._viewer.canvas.get_object_by_tag(marker_name)
except Exception:
# No markers in this table. Issue a warning and continue
warnings.warn(f"Marker set named '{marker_name}' is empty",
category=UserWarning)
# No markers in this table. Issue a warning and continue.
# Test wants this outside of logger, so...
warnings.warn(f"Marker set named '{marker_name}' is empty", UserWarning)
return None

image = self._viewer.get_image()
Expand All @@ -604,10 +590,9 @@ def get_markers(self, x_colname='x', y_colname='y',
xy_col.append([obj.x, obj.y])
if include_skycoord:
radec_col.append([np.nan, np.nan])
elif not include_skycoord: # marker in WCS but image has none
self.logger.warning(
'Skipping ({},{}); image has no WCS'.format(obj.x, obj.y))
else: # wcs
elif not include_skycoord: # Marker in WCS but image has none
self.logger.warning(f'Skipping ({obj.x},{obj.y}); image has no WCS')
else: # WCS
xy_col.append([np.nan, np.nan])
radec_col.append([obj.x, obj.y])

Expand All @@ -630,10 +615,6 @@ def get_markers(self, x_colname='x', y_colname='y',

sky_col = SkyCoord(radec_col[:, 0], radec_col[:, 1], unit='deg')

# Convert X,Y from 0-indexed to 1-indexed
if self._pixel_offset != 0:
xy_col += self._pixel_offset

# Build table
if include_skycoord:
markers_table = Table(
Expand All @@ -646,6 +627,44 @@ def get_markers(self, x_colname='x', y_colname='y',
markers_table['marker name'] = marker_name
return markers_table

def get_all_markers(self, x_colname='x', y_colname='y', skycoord_colname='coord'):
"""Run :meth:`get_markers_by_name` for all markers."""

# If it wasn't for the fact that SkyCoord columns can't
# be stacked this would all fit nicely into a list
# comprehension. But they can't, so we delete the
# SkyCoord column if it is present, then add it
# back after we have stacked.
coordinates = []
tables = []
for name in self._marktags:
table = self.get_markers_by_name(
name, x_colname=x_colname, y_colname=y_colname,
skycoord_colname=skycoord_colname)
if table is None:
continue # No markers by this name, skip it

if skycoord_colname in table.colnames:
coordinates.extend(c for c in table[skycoord_colname])
del table[skycoord_colname]

tables.append(table)

if len(tables) == 0:
return None

stacked = vstack(tables, join_type='exact')

if coordinates:
n_rows = len(stacked)
n_coo = len(coordinates)
if n_coo != n_rows: # This guards against Table auto-broadcast
raise ValueError(f'Expects {n_rows} coordinates but found {n_coo},'
'some markers may be corrupted')
stacked[skycoord_colname] = SkyCoord(coordinates)

return stacked

def _validate_marker_name(self, marker_name):
"""
Raise an error if the marker_name is not allowed.
Expand Down Expand Up @@ -750,7 +769,7 @@ def add_markers(self, table, x_colname='x', y_colname='y',
self._viewer.canvas.add(self.dc.CompoundObject(*objs),
tag=marker_name)

def remove_markers(self, marker_name=None):
def remove_markers_by_name(self, marker_name):
"""
Remove some but not all of the markers by name used when
adding the markers
Expand Down Expand Up @@ -786,15 +805,12 @@ def remove_markers(self, marker_name=None):
else:
self._marktags.remove(marker_name)

def reset_markers(self):
"""
Delete all markers.
"""

def remove_all_markers(self):
"""Delete all markers using :meth:`remove_markers_by_name`."""
# Grab the entire list of marker names before iterating
# otherwise what we are iterating over changes.
for marker_name in list(self._marktags):
self.remove_markers(marker_name)
for marker_name in self.get_marker_names():
self.remove_markers_by_name(marker_name)

@property
def stretch_options(self):
Expand Down
27 changes: 11 additions & 16 deletions astrowidgets/interface_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,30 +71,25 @@ def add_markers(self, table, x_colname='x', y_colname='y',
marker_name=None):
raise NotImplementedError

# @abstractmethod
# def remove_all_markers(self):
# raise NotImplementedError

@abstractmethod
def reset_markers(self):
def remove_all_markers(self):
raise NotImplementedError

# @abstractmethod
# def remove_markers_by_name(self, marker_name=None):
# raise NotImplementedError

@abstractmethod
def remove_markers(self, marker_name=None):
def remove_markers_by_name(self,
marker_name):
raise NotImplementedError

# @abstractmethod
# def get_all_markers(self):
# raise NotImplementedError
@abstractmethod
def get_all_markers(self):
raise NotImplementedError

@abstractmethod
def get_markers(self, x_colname='x', y_colname='y',
skycoord_colname='coord',
marker_name=None):
def get_markers_by_name(self,
marker_name,
x_colname='x',
y_colname='y',
skycoord_colname='coord'):
raise NotImplementedError

# Methods that modify the view
Expand Down
8 changes: 4 additions & 4 deletions astrowidgets/tests/test_image_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def test_get_marker_with_names():
assert len(image._marktags) == 3

for marker in image._marktags:
out_table = image.get_markers(marker_name=marker)
out_table = image.get_markers_by_name(marker_name=marker)
# No guarantee markers will come back in the same order, so sort them.
out_table.sort('x')
assert (out_table['x'] == input_markers['x']).all()
assert (out_table['y'] == input_markers['y']).all()

# Get all of markers at once
all_marks = image.get_markers(marker_name='all')
all_marks = image.get_all_markers()

# That should have given us three copies of the input table
expected = vstack([input_markers] * 3, join_type='exact')
Expand All @@ -129,7 +129,7 @@ def test_unknown_marker_name_error():
iw = ImageWidget()
bad_name = 'not a real marker name'
with pytest.raises(ValueError) as e:
iw.get_markers(marker_name=bad_name)
iw.get_markers_by_name(marker_name=bad_name)

assert f"No markers named '{bad_name}'" in str(e.value)

Expand All @@ -155,6 +155,6 @@ def test_empty_marker_name_works_with_all():
# Start marking to create a new marker set that is empty
iw.start_marking(marker_name='empty')

marks = iw.get_markers(marker_name='all')
marks = iw.get_all_markers()
assert len(marks) == len(x)
assert 'empty' not in marks['marker name']
Loading