Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Fix code generation for pow and fabs when using float32 #2504

Merged
merged 9 commits into from
Dec 23, 2024
40 changes: 29 additions & 11 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from mpmath.libmp import prec_to_dps, to_str
from packaging.version import Version
from numbers import Real

from sympy.core import S
from sympy.core.numbers import equal_valued, Float
from sympy.logic.boolalg import BooleanFunction
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
from sympy.printing.c import C99CodePrinter
Expand Down Expand Up @@ -122,15 +125,21 @@ def _print_math_func(self, expr, nest=False, known=None):
return f'{cname}({args})'

def _print_Pow(self, expr):
# Need to override because of issue #1627
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
try:
if expr.exp == -1 and self.single_prec():
PREC = precedence(expr)
return '1.0F/%s' % self.parenthesize(expr.base, PREC)
except AttributeError:
pass
return super()._print_Pow(expr)
# Completely reimplement `_print_Pow` from sympy, since it doesn't
# correctly handle precision
if "Pow" in self.known_functions:
return self._print_Function(expr)
PREC = precedence(expr)
suffix = 'f' if self.single_prec(expr) else ''
if equal_valued(expr.exp, -1):
return self._print_Float(Float(1.0)) + '/' + \
self.parenthesize(expr.base, PREC)
elif equal_valued(expr.exp, 0.5):
return f'sqrt{suffix}({self._print(expr.base)})'
elif expr.exp == S.One/3 and self.standard != 'C89':
return f'cbrt{suffix}({self._print(expr.base)})'
else:
return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})'

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
Expand Down Expand Up @@ -159,8 +168,17 @@ def _print_Abs(self, expr):
if isinstance(self.compiler, AOMPCompiler):
return "fabs(%s)" % self._print(expr.args[0])
# Check if argument is an integer
func = "abs" if has_integer_args(*expr.args[0].args) else "fabs"
return "%s(%s)" % (func, self._print(expr.args[0]))
if has_integer_args(*expr.args[0].args):
func = "abs"
elif self.single_prec(expr):
func = "fabsf"
elif any([isinstance(a, Real) for a in expr.args[0].args]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the look of it this seems un-necessary. You just need

If self.single_prec(expr) or self.single_prec():
    fabsf
else:
    # not integer or f32, default 
    fabs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not so, this is the example that I shared on slack, but will repeat here for reference:

>>> i1 = Dimension(name="i1")
>>> sympy_dtype(i1 - 0.5)
<class 'numpy.int32'>

In the second part of the or in your example self.single_prec() will always return True, regardless of the types of the expression (unless the default is changed).

This code is a bit horrible, but ensures that the "default" is only used if there is a floating point number whose type cannot be determined. Otherwise strictly use the correct function call for float or double.

Copy link
Contributor Author

@JDBetteridge JDBetteridge Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but my test is wrong anyway... Okay the test is fine, the issue is trying to detect numpy float32 or float64 as sympy eagerly squashes them to a sympy.core.numbers.Float.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to merge this, but perhaps in a subsequent PR you could attach a comment to that if

# The previous condition isn't sufficient to detect case with
# Python `float`s in that case, fall back to the "default"
func = "fabsf" if self.single_prec() else "fabs"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see two self.single_prec tests. Isn;t this redundant here?
Can it be one check?

else:
func = "fabs"
return f"{func}({self._print(expr.args[0])})"

def _print_Add(self, expr, order=None):
""""
Expand Down
Loading
Loading