Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix code generation for pow and fabs when using float32 #2504

Merged
merged 9 commits into from
Dec 23, 2024
39 changes: 29 additions & 10 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from mpmath.libmp import prec_to_dps, to_str
from packaging.version import Version
from numbers import Real

from sympy.core import S
from sympy.core.numbers import equal_valued, Float
from sympy.logic.boolalg import BooleanFunction
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
from sympy.printing.c import C99CodePrinter
Expand Down Expand Up @@ -122,15 +125,22 @@ def _print_math_func(self, expr, nest=False, known=None):
return f'{cname}({args})'

def _print_Pow(self, expr):
# Need to override because of issue #1627
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
try:
if expr.exp == -1 and self.single_prec():
PREC = precedence(expr)
return '1.0F/%s' % self.parenthesize(expr.base, PREC)
except AttributeError:
pass
return super()._print_Pow(expr)
# Completely reimplement `_print_Pow` from sympy, since it doesn't
# correctly handle precision
if "Pow" in self.known_functions:
return self._print_Function(expr)
PREC = precedence(expr)
suffix = 'f' if self.single_prec(expr) else ''
if equal_valued(expr.exp, -1):
return f'{self._print_Float(Float(1.0))}/' + \
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved
f'{self.parenthesize(expr.base, PREC)}'
elif equal_valued(expr.exp, 0.5):
return f'{self._ns}sqrt{suffix}({self._print(expr.base)}'
elif expr.exp == S.One/3 and self.standard != 'C89':
return f'{self._ns}cbrt{suffix}({self._print(expr.base)})'
else:
return f'{self._ns}pow{suffix}({self._print(expr.base)}, ' + \
f'{self._print(expr.exp)})'

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
Expand Down Expand Up @@ -159,7 +169,16 @@ def _print_Abs(self, expr):
if isinstance(self.compiler, AOMPCompiler):
return "fabs(%s)" % self._print(expr.args[0])
# Check if argument is an integer
func = "abs" if has_integer_args(*expr.args[0].args) else "fabs"
if has_integer_args(*expr.args[0].args):
func = "abs"
elif self.single_prec(expr):
func = "fabsf"
elif any([isinstance(a, Real) for a in expr.args[0].args]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the look of it this seems un-necessary. You just need

If self.single_prec(expr) or self.single_prec():
    fabsf
else:
    # not integer or f32, default 
    fabs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not so, this is the example that I shared on slack, but will repeat here for reference:

>>> i1 = Dimension(name="i1")
>>> sympy_dtype(i1 - 0.5)
<class 'numpy.int32'>

In the second part of the or in your example self.single_prec() will always return True, regardless of the types of the expression (unless the default is changed).

This code is a bit horrible, but ensures that the "default" is only used if there is a floating point number whose type cannot be determined. Otherwise strictly use the correct function call for float or double.

Copy link
Contributor Author

@JDBetteridge JDBetteridge Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but my test is wrong anyway... Okay the test is fine, the issue is trying to detect numpy float32 or float64 as sympy eagerly squashes them to a sympy.core.numbers.Float.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to merge this, but perhaps in a subsequent PR you could attach a comment to that if

# The previous condition isn't sufficient to detect case with
# Python `float`s in that case, fall back to the "default"
func = "fabsf" if self.single_prec() else "fabs"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see two self.single_prec tests. Isn;t this redundant here?
Can it be one check?

else:
func = "fabs"
return "%s(%s)" % (func, self._print(expr.args[0]))
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved

def _print_Add(self, expr, order=None):
Expand Down
49 changes: 48 additions & 1 deletion tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,10 @@ def test_extended_sympy_arithmetic():
def test_integer_abs():
i1 = Dimension(name="i1")
assert ccode(Abs(i1 - 1)) == "abs(i1 - 1)"
assert ccode(Abs(i1 - .5)) == "fabs(i1 - 5.0e-1F)"
assert ccode(Abs(i1 - .5)) == "fabsf(i1 - 5.0e-1F)"
assert ccode(
Abs(i1 - Constant('half', dtype=np.float64, default_value=0.5))
) == "fabs(i1 - half)"
Comment on lines +302 to +304
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not happy with this, but don't 100% know how I'm supposed to generate a literal double precision 0.5 symbolically

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just np.float64(0.5)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

>>> ccode(Abs(i1 - np.float64(0.5)))
'fabs(i1 - 5.0e-1F)'

It looks like the float64 is being demoted to float32, this could point to an issue elsewhere



def test_cos_vs_cosf():
Expand Down Expand Up @@ -587,6 +590,50 @@ def test_minmax_precision(dtype, expected):
assert np.all(f.data == 6.0)


@pytest.mark.parametrize('dtype,expected', [
(np.float32, "powf"),
(np.float64, "pow"),
])
def test_pow_precision(dtype, expected):
grid = Grid(shape=(5, 5), dtype=dtype)

f = Function(name="f", grid=grid)
g = Function(name="g", grid=grid)

eqn = Eq(f, g**1.5)

op = Operator(eqn)

g.data[:] = 4.0

op.apply()

assert expected in str(op)
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved
assert np.all(f.data == 8.0)


@pytest.mark.parametrize('dtype,expected', [
(np.float32, "absf"),
(np.float64, "abs"),
])
def test_abs_precision(dtype, expected):
grid = Grid(shape=(5, 5), dtype=dtype)

f = Function(name="f", grid=grid)
g = Function(name="g", grid=grid)

eqn = Eq(f, abs(g))

op = Operator(eqn)

g.data[:] = -1.0

op.apply()

assert expected in str(op)
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved
assert np.all(f.data == 1.0)


class TestRelationsWithAssumptions:

def test_multibounds_op(self):
Expand Down
Loading