Skip to content

Commit

Permalink
Merge pull request #2345 from astrofrog/fix-1d-wcs
Browse files Browse the repository at this point in the history
Fix world coordinates for 1D WCS
  • Loading branch information
astrofrog authored Jan 16, 2023
2 parents ff81f84 + 06c6c50 commit a609e60
Show file tree
Hide file tree
Showing 16 changed files with 119 additions and 57 deletions.
5 changes: 2 additions & 3 deletions glue/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pandas as pd

from glue.core.coordinate_helpers import dependent_axes, pixel2world_single_axis
from glue.utils import (shape_to_string, coerce_numeric,
broadcast_to, categorical_ndarray)
from glue.utils import shape_to_string, coerce_numeric, categorical_ndarray

try:
import dask.array as da
Expand Down Expand Up @@ -330,7 +329,7 @@ def _calculate(self, view=None):
world_coords = world_coords[tuple(final_slice)]

# We then broadcast the final array back to what it should be
world_coords = broadcast_to(world_coords, tuple(final_shape))
world_coords = np.broadcast_to(world_coords, tuple(final_shape))

# We apply the view if we weren't able to optimize before
if optimize_view:
Expand Down
8 changes: 4 additions & 4 deletions glue/core/component_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
world2pixel_single_axis)
from glue.core.subset import InequalitySubsetState
from glue.core.util import join_component_view
from glue.utils import unbroadcast, broadcast_to
from glue.utils import unbroadcast
from glue.logger import logger

__all__ = ['ComponentLink', 'BinaryComponentLink', 'CoordinateComponentLink']
Expand Down Expand Up @@ -198,7 +198,7 @@ def compute(self, data, view=None):
result.shape = args[0].shape

# Finally we broadcast the final result to desired shape
result = broadcast_to(result, original_shape)
result = np.broadcast_to(result, original_shape)

return result

Expand Down Expand Up @@ -386,7 +386,7 @@ def using(self, *args):
args2[f] = a
for i in range(self.ndim):
if args2[i] is None:
args2[i] = broadcast_to(default[self.ndim - 1 - i], args[0].shape)
args2[i] = np.broadcast_to(default[self.ndim - 1 - i], args[0].shape)
args2 = tuple(args2)

if self.pixel2world:
Expand Down Expand Up @@ -487,7 +487,7 @@ def compute(self, data, view=None):
if original_shape is None:
return result
else:
return broadcast_to(result, original_shape)
return np.broadcast_to(result, original_shape)

def __gluestate__(self, context):
left = context.id(self._left)
Expand Down
32 changes: 27 additions & 5 deletions glue/core/coordinate_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from astropy.wcs import WCS

from glue.utils import unbroadcast, broadcast_to
from glue.utils import unbroadcast
from glue.core.coordinates import LegacyCoordinates


Expand Down Expand Up @@ -53,9 +53,20 @@ def pixel2world_single_axis(wcs, *pixel, world_axis=None):
pixel_new.append(p.flat[0])
pixel = np.broadcast_arrays(*pixel_new)

result = wcs.pixel_to_world_values(*pixel)
# In the case of 1D WCS, there is an astropy issue which prevents us from
# passing arbitrary shapes - see https://github.com/astropy/astropy/issues/12154
# Therefore, we ravel the values and reshape afterwards

return broadcast_to(result[world_axis], original_shape)
if len(pixel) == 1 and pixel[0].ndim > 1:
pixel_shape = pixel[0].shape
result = wcs.pixel_to_world_values(pixel[0].ravel())
result = result.reshape(pixel_shape)
else:
result = wcs.pixel_to_world_values(*pixel)
if len(pixel) > 1:
result = result[world_axis]

return np.broadcast_to(result, original_shape)


def world2pixel_single_axis(wcs, *world, pixel_axis=None):
Expand Down Expand Up @@ -99,9 +110,20 @@ def world2pixel_single_axis(wcs, *world, pixel_axis=None):
world_new.append(w.flat[0])
world = np.broadcast_arrays(*world_new)

result = wcs.world_to_pixel_values(*world)
# In the case of 1D WCS, there is an astropy issue which prevents us from
# passing arbitrary shapes - see https://github.com/astropy/astropy/issues/12154
# Therefore, we ravel the values and reshape afterwards

if len(world) == 1 and world[0].ndim > 1:
world_shape = world[0].shape
result = wcs.world_to_pixel_values(world[0].ravel())
result = result.reshape(world_shape)
else:
result = wcs.world_to_pixel_values(*world)
if len(world) > 1:
result = result[pixel_axis]

return broadcast_to(result[pixel_axis], original_shape)
return np.broadcast_to(result, original_shape)


def world_axis(wcs, data, *, pixel_axis=None, world_axis=None):
Expand Down
20 changes: 16 additions & 4 deletions glue/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,16 @@ def __init__(self):
super().__init__(pixel_n_dim=10, world_n_dim=10)

def pixel_to_world_values(self, *pixel):
return pixel
if len(pixel) == 1:
return pixel[0]
else:
return pixel

def world_to_pixel_values(self, *world):
return world
if len(world) == 1:
return world[0]
else:
return world


class IdentityCoordinates(Coordinates):
Expand All @@ -102,10 +108,16 @@ def __init__(self, n_dim=None):
super().__init__(pixel_n_dim=n_dim, world_n_dim=n_dim)

def pixel_to_world_values(self, *pixel):
return pixel
if self.pixel_n_dim == 1:
return pixel[0]
else:
return pixel

def world_to_pixel_values(self, *world):
return world
if self.world_n_dim == 1:
return world[0]
else:
return world

@property
def axis_correlation_matrix(self):
Expand Down
8 changes: 4 additions & 4 deletions glue/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from glue.core.joins import get_mask_with_key_joins
from glue.config import settings, data_translator, subset_state_translator
from glue.utils import (compute_statistic, unbroadcast, iterate_chunks,
datetime64_to_mpl, broadcast_to, categorical_ndarray,
datetime64_to_mpl, categorical_ndarray,
format_choices, random_views_for_dask_array)
from glue.core.coordinate_helpers import axis_label

Expand Down Expand Up @@ -445,9 +445,9 @@ def get_data(self, cid, view=None):
shape = tuple(-1 if i == cid.axis else 1 for i in range(self.ndim))
pix = np.arange(self.shape[cid.axis], dtype=float).reshape(shape)
if view is None:
return broadcast_to(pix, self.shape)
return np.broadcast_to(pix, self.shape)
else:
return broadcast_to(pix, self.shape)[view]
return np.broadcast_to(pix, self.shape)[view]
elif cid in self.world_component_ids:
comp = self._world_components[cid]
elif cid in self._externally_derivable_components:
Expand Down Expand Up @@ -1822,7 +1822,7 @@ def compute_statistic(self, statistic, cid, subset_state=None, axis=None,
if isinstance(axis, int):
axis = [axis]
final_shape = [mask.shape[i] for i in range(mask.ndim) if i not in axis]
return broadcast_to(np.nan, final_shape)
return np.broadcast_to(np.nan, final_shape)
else:
data = self.get_data(cid, view=view)
mask = None
Expand Down
6 changes: 3 additions & 3 deletions glue/core/fixed_resolution_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from glue.core.exceptions import IncompatibleAttribute, IncompatibleDataException
from glue.core.component import DaskComponent
from glue.core.coordinate_helpers import dependent_axes
from glue.utils import unbroadcast, broadcast_to, broadcast_arrays_minimal
from glue.utils import unbroadcast, broadcast_arrays_minimal

# TODO: cache needs to be updated when links are removed/changed

Expand Down Expand Up @@ -73,7 +73,7 @@ def translate_pixel(data, pixel_coords, target_cid):
shape = values_all[0].shape
values_all = broadcast_arrays_minimal(*values_all)
results = link._using(*values_all)
result = broadcast_to(results, shape)
result = np.broadcast_to(results, shape)
else:
result = None
return result, sorted(set(dimensions_all))
Expand Down Expand Up @@ -222,7 +222,7 @@ def compute_fixed_resolution_buffer(data, bounds, target_data=None, target_cid=N
invalid_all |= invalid

# Broadcast back to the original shape and add to the list
translated_coords.append(broadcast_to(translated_coord, original_shape))
translated_coords.append(np.broadcast_to(translated_coord, original_shape))

# Also keep track of all the dimensions that contributed to this coordinate
dimensions_all.extend(dimensions)
Expand Down
10 changes: 5 additions & 5 deletions glue/core/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from glue.core.decorators import memoize
from glue.core.visual import VisualAttributes
from glue.config import settings
from glue.utils import (view_shape, broadcast_to, floodfill, combine_slices,
polygon_line_intersections, categorical_ndarray, iterate_chunks)
from glue.utils import (categorical_ndarray, combine_slices, floodfill, iterate_chunks,
polygon_line_intersections, view_shape)


__all__ = ['Subset', 'SubsetState', 'RoiSubsetStateNd', 'RoiSubsetState', 'CategoricalROISubsetState',
Expand Down Expand Up @@ -458,7 +458,7 @@ def to_mask(self, data, view=None):
Any object that returns a valid view for a Numpy array.
"""
shp = view_shape(data.shape, view)
return broadcast_to(False, shp)
return np.broadcast_to(False, shp)

@contract(returns='isinstance(SubsetState)')
def copy(self):
Expand Down Expand Up @@ -1327,7 +1327,7 @@ def to_mask(self, data, view=None):

if order is None:
# We use broadcast_to for minimal memory usage
return broadcast_to(False, shape)
return np.broadcast_to(False, shape)
else:
# Reorder slices
slices = [self.slices[idx] for idx in order]
Expand All @@ -1350,7 +1350,7 @@ def to_mask(self, data, view=None):
elif np.isscalar(view[i]):
beg, end, stp = slices[i].indices(data.shape[i])
if view[i] < beg or view[i] >= end or (view[i] - beg) % stp != 0:
return broadcast_to(False, shape)
return np.broadcast_to(False, shape)
elif isinstance(view[i], slice):
if view[i].step is not None and view[i].step < 0:
beg, end, step = view[i].indices(data.shape[i])
Expand Down
17 changes: 17 additions & 0 deletions glue/core/tests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from unittest.mock import MagicMock

from astropy.wcs import WCS

from glue import core
from glue.tests.helpers import requires_astropy

Expand Down Expand Up @@ -386,3 +388,18 @@ def test_update_cid_used_in_derived():
np.testing.assert_equal(data['b'], [4, 5, 2])
data.update_id(data.id['a'], ComponentID('x'))
np.testing.assert_equal(data['b'], [4, 5, 2])


def test_coordinate_component_1d_coord():

# Regression test for a bug that caused incorrect world coordinate values
# for 1D coordinates.

wcs = WCS(naxis=1)
wcs.wcs.ctype = ['FREQ']
wcs.wcs.crpix = [1]
wcs.wcs.crval = [1]
wcs.wcs.cdelt = [1]

data = Data(flux=np.random.random(5), coords=wcs, label='data')
np.testing.assert_equal(data['Frequency'], [1, 2, 3, 4, 5])
18 changes: 18 additions & 0 deletions glue/core/tests/test_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,24 @@ def test_pixel2world_single_axis():
assert_allclose(pixel2world_single_axis(coord, x, y, z, world_axis=2), [1.5, 1.5, 1.5])


def test_pixel2world_single_axis_1d():

# Regression test for issues that occurred for 1D WCSes

coord = WCSCoordinates(naxis=1)
coord.wcs.ctype = ['FREQ']
coord.wcs.crpix = [1]
coord.wcs.crval = [1]
coord.wcs.cdelt = [1]

x = np.array([0.2, 0.4, 0.6])
expected = np.array([1.2, 1.4, 1.6])

assert_allclose(pixel2world_single_axis(coord, x, world_axis=0), expected)
assert_allclose(pixel2world_single_axis(coord, x.reshape((1, 3)), world_axis=0), expected.reshape((1, 3)))
assert_allclose(pixel2world_single_axis(coord, x.reshape((3, 1)), world_axis=0), expected.reshape((3, 1)))


def test_affine():

matrix = np.array([[2, 3, -1], [1, 2, 2], [0, 0, 1]])
Expand Down
15 changes: 10 additions & 5 deletions glue/core/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from astropy.utils import NumpyRNGContext

from glue import core
from glue.utils import broadcast_to

from ..component import Component, DerivedComponent, CategoricalComponent, DateTimeComponent
from ..component_id import ComponentID
Expand Down Expand Up @@ -819,10 +818,16 @@ def world_axis_names(self):
return ['Custom {0}'.format(axis) for axis in range(3)]

def world_to_pixel_values(self, *world):
return tuple([0.4 * w for w in world])
if self.pixel_n_dim == 1:
return 0.4 * world[0]
else:
return tuple([0.4 * w for w in world])

def pixel_to_world_values(self, *pixel):
return tuple([2.5 * p for p in pixel])
if self.world_n_dim == 1:
return 2.5 * pixel[0]
else:
return tuple([2.5 * p for p in pixel])

data1.coords = CustomCoordinates()

Expand Down Expand Up @@ -930,10 +935,10 @@ def test_compute_statistic_empty_subset():
assert_equal(result, np.nan)

result = data.compute_statistic('maximum', data.id['x'], subset_state=subset_state, axis=1)
assert_equal(result, broadcast_to(np.nan, (30, 40)))
assert_equal(result, np.broadcast_to(np.nan, (30, 40)))

result = data.compute_statistic('median', data.id['x'], subset_state=subset_state, axis=(1, 2))
assert_equal(result, broadcast_to(np.nan, (30)))
assert_equal(result, np.broadcast_to(np.nan, (30)))

result = data.compute_statistic('sum', data.id['x'], subset_state=subset_state, axis=(0, 1, 2))
assert_equal(result, np.nan)
Expand Down
1 change: 1 addition & 0 deletions glue/core/tests/test_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_1d_world_link():
dc.add_link(LinkSame(d2.world_component_ids[0], d1.id['x']))

assert d2.world_component_ids[0] in d1.externally_derivable_components

np.testing.assert_array_equal(d1[d2.world_component_ids[0]], x)
np.testing.assert_array_equal(d1[d2.pixel_component_ids[0]], x)

Expand Down
14 changes: 1 addition & 13 deletions glue/utils/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numpy import nanmin, nanmax, nanmean, nanmedian, nansum # noqa

__all__ = ['unique', 'shape_to_string', 'view_shape', 'stack_view',
'coerce_numeric', 'check_sorted', 'broadcast_to', 'unbroadcast',
'coerce_numeric', 'check_sorted', 'unbroadcast',
'iterate_chunks', 'combine_slices', 'format_minimal', 'compute_statistic',
'categorical_ndarray', 'index_lookup', 'ensure_numerical',
'broadcast_arrays_minimal', 'random_views_for_dask_array']
Expand Down Expand Up @@ -201,18 +201,6 @@ def pretty_number(numbers):
return result


def broadcast_to(array, shape):
"""
Compatibility function - can be removed once we support only Numpy 1.10
and above
"""
try:
return np.broadcast_to(array, shape)
except AttributeError:
array = np.asarray(array)
return np.broadcast_arrays(array, np.ones(shape, array.dtype))[0]


def find_chunk_shape(shape, n_max=None):
"""
Given the shape of an n-dimensional array, and the maximum number of
Expand Down
4 changes: 2 additions & 2 deletions glue/utils/geometry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from glue.utils import unbroadcast, broadcast_to
from glue.utils import unbroadcast

__all__ = ['points_inside_poly', 'polygon_line_intersections', 'floodfill', 'rotation_matrix_2d']

Expand Down Expand Up @@ -82,7 +82,7 @@ def points_inside_poly(x, y, vx, vy):
inside[keep][~good] = False

inside = inside.reshape(reduced_shape)
inside = broadcast_to(inside, original_shape)
inside = np.broadcast_to(inside, original_shape)

return inside

Expand Down
Loading

0 comments on commit a609e60

Please sign in to comment.