Skip to content

Commit

Permalink
Work around numpy2's changed repr for scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Apr 16, 2024
1 parent 52abd21 commit 38deb28
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
22 changes: 17 additions & 5 deletions loopy/target/c/codegen/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,10 @@ def map_constant(self, expr, type_context):

# FIXME: This assumes a 32-bit architecture.
if isinstance(expr, np.float32):
return Literal(repr(expr)+"f")
return Literal(repr(float(expr))+"f")

elif isinstance(expr, np.float64):
return Literal(repr(expr))
return Literal(repr(float(expr)))

# Disabled for now, possibly should be a subtarget.
# elif isinstance(expr, np.float128):
Expand All @@ -464,7 +464,7 @@ def map_constant(self, expr, type_context):
suffix += "u"
if iinfo.max > (2**31-1):
suffix += "l"
return Literal(repr(expr)+suffix)
return Literal(repr(int(expr))+suffix)
elif isinstance(expr, np.bool_):
return Literal("true") if expr else Literal("false")
else:
Expand All @@ -473,7 +473,7 @@ def map_constant(self, expr, type_context):

elif np.isfinite(expr):
if type_context == "f":
return Literal(repr(np.float32(expr))+"f")
return Literal(repr(float((expr)))+"f")
elif type_context == "d":
return Literal(repr(float(expr)))
elif type_context in ["i", "b"]:
Expand Down Expand Up @@ -633,7 +633,19 @@ def join(self, joiner, iterable):
# }}}

def map_constant(self, expr, prec):
return repr(expr)
if isinstance(expr, np.generic):
if isinstance(expr, np.integer):
# FIXME: Add type suffixes?
return repr(int(expr))
elif isinstance(expr, np.float32):
return f"{repr(float(expr))}f"
elif isinstance(expr, np.float64):
return repr(float(expr))
else:
raise NotImplementedError(
f"unimplemented numpy-to-C conversion: {type(expr)}")
else:
return repr(expr)

def map_call(self, expr, enclosing_prec):
from pymbolic.primitives import Variable
Expand Down
4 changes: 3 additions & 1 deletion test/test_fortran.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def test_assign_single_precision_scalar(ctx_factory):
"""

t_unit = lp.parse_fortran(fortran_src)
assert "1.1f" in lp.generate_code_v2(t_unit).device_code()

import re
assert re.search("1.1000000[0-9]*f", lp.generate_code_v2(t_unit).device_code())

a_dev = cl.array.empty(queue, 1, dtype=np.float64, order="F")
t_unit(queue, a=a_dev)
Expand Down

0 comments on commit 38deb28

Please sign in to comment.