Skip to content

Commit

Permalink
Merge pull request astropy#15929 from neutrinoceros/tst/numpy2/ma_sor…
Browse files Browse the repository at this point in the history
…t_stable

BUG: fix compatibility with numpy 2.0 for ndarray suclasses overriding ndarray.sort and ndarray.argsort
  • Loading branch information
pllim authored Jan 25, 2024
2 parents f85f796 + 50470df commit 74bdf27
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
13 changes: 11 additions & 2 deletions astropy/units/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,8 +1784,17 @@ def choose(self, choices, out=None, mode="raise"):
)

# ensure we do not return indices as quantities
def argsort(self, axis=-1, kind="quicksort", order=None):
return self.view(np.ndarray).argsort(axis=axis, kind=kind, order=order)
if NUMPY_LT_2_0:

def argsort(self, axis=-1, kind=None, order=None):
return self.view(np.ndarray).argsort(axis=axis, kind=kind, order=order)

else:

def argsort(self, axis=-1, kind=None, order=None, *, stable=None):
return self.view(np.ndarray).argsort(
axis=axis, kind=kind, order=order, stable=stable
)

def searchsorted(self, v, *args, **kwargs):
return np.searchsorted(
Expand Down
20 changes: 16 additions & 4 deletions astropy/utils/masked/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ def argmax(self, axis=None, out=None, *, keepdims=False):
at_max = self == self.max(axis=axis, keepdims=True)
return at_max.filled(False).argmax(axis=axis, out=out, keepdims=keepdims)

def argsort(self, axis=-1, kind=None, order=None):
def argsort(self, axis=-1, kind=None, order=None, *, stable=None):
"""Returns the indices that would sort an array.
Perform an indirect sort along the given axis on both the array
Expand All @@ -1044,6 +1044,8 @@ def argsort(self, axis=-1, kind=None, order=None):
second, etc. A single field can be specified as a string, and not
all fields need be specified, but unspecified fields will still be
used, in dtype order, to break ties.
stable: bool, keyword-only, ignored
Sort stability. Present only to allow subclasses to work.
Returns
-------
Expand Down Expand Up @@ -1078,10 +1080,20 @@ def argsort(self, axis=-1, kind=None, order=None):

return np.lexsort(keys, axis=axis)

def sort(self, axis=-1, kind=None, order=None):
"""Sort an array in-place. Refer to `numpy.sort` for full documentation."""
def sort(self, axis=-1, kind=None, order=None, *, stable=False):
"""Sort an array in-place. Refer to `numpy.sort` for full documentation.
Notes
-----
Masked items will be sorted to the end. The implementation
is via `numpy.lexsort` and thus ignores the ``kind`` and ``stable`` arguments;
they are present only so that subclasses can pass them on.
"""
# TODO: probably possible to do this faster than going through argsort!
indices = self.argsort(axis, kind=kind, order=order)
argsort_kwargs = dict(kind=kind, order=order)
if not NUMPY_LT_2_0:
argsort_kwargs["stable"] = stable
indices = self.argsort(axis, **argsort_kwargs)
self[:] = np.take_along_axis(self, indices, axis=axis)

def argpartition(self, kth, axis=-1, kind="introselect", order=None):
Expand Down

0 comments on commit 74bdf27

Please sign in to comment.