-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
compiler: Fix code generation for pow and fabs when using float32 #2504
Changes from 5 commits
5e98bec
7114100
9630b13
3b6ae7e
9753118
b6a3a4a
e034211
8e3090e
ad97f5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,10 @@ | |
|
||
from mpmath.libmp import prec_to_dps, to_str | ||
from packaging.version import Version | ||
from numbers import Real | ||
|
||
from sympy.core import S | ||
from sympy.core.numbers import equal_valued, Float | ||
from sympy.logic.boolalg import BooleanFunction | ||
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence | ||
from sympy.printing.c import C99CodePrinter | ||
|
@@ -122,15 +125,21 @@ def _print_math_func(self, expr, nest=False, known=None): | |
return f'{cname}({args})' | ||
|
||
def _print_Pow(self, expr): | ||
# Need to override because of issue #1627 | ||
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x | ||
try: | ||
if expr.exp == -1 and self.single_prec(): | ||
PREC = precedence(expr) | ||
return '1.0F/%s' % self.parenthesize(expr.base, PREC) | ||
except AttributeError: | ||
pass | ||
return super()._print_Pow(expr) | ||
# Completely reimplement `_print_Pow` from sympy, since it doesn't | ||
# correctly handle precision | ||
if "Pow" in self.known_functions: | ||
return self._print_Function(expr) | ||
PREC = precedence(expr) | ||
suffix = 'f' if self.single_prec(expr) else '' | ||
if equal_valued(expr.exp, -1): | ||
return self._print_Float(Float(1.0)) + '/' + \ | ||
self.parenthesize(expr.base, PREC) | ||
elif equal_valued(expr.exp, 0.5): | ||
return f'sqrt{suffix}({self._print(expr.base)})' | ||
elif expr.exp == S.One/3 and self.standard != 'C89': | ||
return f'cbrt{suffix}({self._print(expr.base)})' | ||
else: | ||
return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})' | ||
|
||
def _print_Mod(self, expr): | ||
"""Print a Mod as a C-like %-based operation.""" | ||
|
@@ -159,8 +168,17 @@ def _print_Abs(self, expr): | |
if isinstance(self.compiler, AOMPCompiler): | ||
return "fabs(%s)" % self._print(expr.args[0]) | ||
# Check if argument is an integer | ||
func = "abs" if has_integer_args(*expr.args[0].args) else "fabs" | ||
return "%s(%s)" % (func, self._print(expr.args[0])) | ||
if has_integer_args(*expr.args[0].args): | ||
func = "abs" | ||
elif self.single_prec(expr): | ||
func = "fabsf" | ||
elif any([isinstance(a, Real) for a in expr.args[0].args]): | ||
# The previous condition isn't sufficient to detect case with | ||
# Python `float`s in that case, fall back to the "default" | ||
func = "fabsf" if self.single_prec() else "fabs" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can see two self.single_prec tests. Isn;t this redundant here? |
||
else: | ||
func = "fabs" | ||
return f"{func}({self._print(expr.args[0])})" | ||
|
||
def _print_Add(self, expr, order=None): | ||
"""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -298,7 +298,10 @@ def test_extended_sympy_arithmetic(): | |
def test_integer_abs(): | ||
i1 = Dimension(name="i1") | ||
assert ccode(Abs(i1 - 1)) == "abs(i1 - 1)" | ||
assert ccode(Abs(i1 - .5)) == "fabs(i1 - 5.0e-1F)" | ||
assert ccode(Abs(i1 - .5)) == "fabsf(i1 - 5.0e-1F)" | ||
assert ccode( | ||
Abs(i1 - Constant('half', dtype=np.float64, default_value=0.5)) | ||
) == "fabs(i1 - half)" | ||
Comment on lines
+302
to
+304
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not happy with this, but don't 100% know how I'm supposed to generate a literal double precision There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just np.float64(0.5) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. >>> ccode(Abs(i1 - np.float64(0.5)))
'fabs(i1 - 5.0e-1F)' It looks like the float64 is being demoted to float32, this could point to an issue elsewhere |
||
|
||
|
||
def test_cos_vs_cosf(): | ||
|
@@ -587,6 +590,50 @@ def test_minmax_precision(dtype, expected): | |
assert np.all(f.data == 6.0) | ||
|
||
|
||
@pytest.mark.parametrize('dtype,expected', [ | ||
(np.float32, "powf"), | ||
(np.float64, "pow"), | ||
]) | ||
def test_pow_precision(dtype, expected): | ||
grid = Grid(shape=(5, 5), dtype=dtype) | ||
|
||
f = Function(name="f", grid=grid) | ||
g = Function(name="g", grid=grid) | ||
|
||
eqn = Eq(f, g**1.5) | ||
|
||
op = Operator(eqn) | ||
|
||
g.data[:] = 4.0 | ||
|
||
op.apply() | ||
|
||
assert expected in str(op) | ||
JDBetteridge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert np.all(f.data == 8.0) | ||
|
||
|
||
@pytest.mark.parametrize('dtype,expected', [ | ||
(np.float32, "absf"), | ||
(np.float64, "abs"), | ||
]) | ||
def test_abs_precision(dtype, expected): | ||
grid = Grid(shape=(5, 5), dtype=dtype) | ||
|
||
f = Function(name="f", grid=grid) | ||
g = Function(name="g", grid=grid) | ||
|
||
eqn = Eq(f, abs(g)) | ||
|
||
op = Operator(eqn) | ||
|
||
g.data[:] = -1.0 | ||
|
||
op.apply() | ||
|
||
assert expected in str(op) | ||
JDBetteridge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert np.all(f.data == 1.0) | ||
|
||
|
||
class TestRelationsWithAssumptions: | ||
|
||
def test_multibounds_op(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the look of it this seems un-necessary. You just need
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not so, this is the example that I shared on slack, but will repeat here for reference:
In the second part of the
or
in your exampleself.single_prec()
will always returnTrue
, regardless of the types of the expression (unless the default is changed).This code is a bit horrible, but ensures that the "default" is only used if there is a floating point number whose type cannot be determined. Otherwise strictly use the correct function call for
float
ordouble
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, but my test is wrong anyway...Okay the test is fine, the issue is trying to detect numpy float32 or float64 as sympy eagerly squashes them to asympy.core.numbers.Float
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to merge this, but perhaps in a subsequent PR you could attach a comment to that
if