Skip to content

Commit

Permalink
Numpy 2.0 compatibility
Browse files Browse the repository at this point in the history
Numpy 2.0 changes scalar repr as per:

     https://numpy.org/neps/nep-0051-scalar-representation.html

Make matplotlib_export convert to python scalars where appropriate
to avoid 'np.*' in generated code.
  • Loading branch information
ales-erjavec committed Jul 11, 2024
1 parent 1d621d0 commit 470635c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
24 changes: 22 additions & 2 deletions orangewidget/tests/test_matplotlib_export.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import numpy as np
import pyqtgraph as pg

from orangewidget.tests.base import GuiTest
from orangewidget.utils.matplotlib_export import scatterplot_code
from orangewidget.utils.matplotlib_export import (
scatterplot_code, numpy_repr, compress_if_all_same, numpy_repr_int
)


def add_intro(a):
r = "import matplotlib.pyplot as plt\n" + \
"import numpy as np\n" + \
"from numpy import array\n" + \
"plt.clf()"
return r + a
Expand All @@ -15,8 +19,24 @@ class TestScatterPlot(GuiTest):
def test_scatterplot_simple(self):
plotWidget = pg.PlotWidget(background="w")
scatterplot = pg.ScatterPlotItem()
scatterplot.setData(x=[1, 2, 3], y=[3, 2, 1])
scatterplot.setData(
x=np.array([1., 2, 3]),
y=np.array([3., 2, 1]),
size=np.array([1., 1, 1])
)
plotWidget.addItem(scatterplot)
code = scatterplot_code(scatterplot)
self.assertIn("plt.scatter", code)
exec(add_intro(code), {})

def test_utils(self):
a = np.array([1.5, 2.5])
self.assertIn("1.5, 2.5", numpy_repr(a))
a = np.array([1, 1])
v = compress_if_all_same(a)
self.assertEqual(v, 1)
self.assertEqual(repr(v), "1")
self.assertIs(type(v), int)
a = np.array([1, 2], dtype=int)
v = numpy_repr_int(a)
self.assertIn("1, 2", v)
10 changes: 7 additions & 3 deletions orangewidget/utils/matplotlib_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def numpy_repr(a):
# avoid numpy repr as it changes between versions
# TODO handle numpy repr differences
if isinstance(a, np.ndarray):
return "array(" + repr(list(a)) + ")"
return "array(" + repr(a.tolist()) + ")"
try:
np.set_printoptions(threshold=10**10)
return repr(a)
Expand All @@ -25,12 +25,16 @@ def numpy_repr(a):
def numpy_repr_int(a):
# avoid numpy repr as it changes between versions
# TODO handle numpy repr differences
return "array(" + repr(list(a)) + ", dtype='int')"
return "array(" + repr(a.tolist()) + ", dtype='int')"


def compress_if_all_same(l):
s = set(l)
return s.pop() if len(s) == 1 else l
if len(s) == 1:
v = s.pop()
return v.item() if isinstance(v, np.generic) else v
else:
return l


def is_sequence_not_string(a):
Expand Down

0 comments on commit 470635c

Please sign in to comment.