Skip to content

Commit

Permalink
Merge pull request #2504 from devitocodes/JDBetteridge/fabspow-codegen
Browse files Browse the repository at this point in the history
compiler: Fix code generation for pow and fabs when using float32
  • Loading branch information
FabioLuporini authored Dec 23, 2024
2 parents cc04242 + ad97f5e commit 484c832
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 149 deletions.
40 changes: 29 additions & 11 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from mpmath.libmp import prec_to_dps, to_str
from packaging.version import Version
from numbers import Real

from sympy.core import S
from sympy.core.numbers import equal_valued, Float
from sympy.logic.boolalg import BooleanFunction
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
from sympy.printing.c import C99CodePrinter
Expand Down Expand Up @@ -122,15 +125,21 @@ def _print_math_func(self, expr, nest=False, known=None):
return f'{cname}({args})'

def _print_Pow(self, expr):
# Need to override because of issue #1627
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
try:
if expr.exp == -1 and self.single_prec():
PREC = precedence(expr)
return '1.0F/%s' % self.parenthesize(expr.base, PREC)
except AttributeError:
pass
return super()._print_Pow(expr)
# Completely reimplement `_print_Pow` from sympy, since it doesn't
# correctly handle precision
if "Pow" in self.known_functions:
return self._print_Function(expr)
PREC = precedence(expr)
suffix = 'f' if self.single_prec(expr) else ''
if equal_valued(expr.exp, -1):
return self._print_Float(Float(1.0)) + '/' + \
self.parenthesize(expr.base, PREC)
elif equal_valued(expr.exp, 0.5):
return f'sqrt{suffix}({self._print(expr.base)})'
elif expr.exp == S.One/3 and self.standard != 'C89':
return f'cbrt{suffix}({self._print(expr.base)})'
else:
return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})'

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
Expand Down Expand Up @@ -159,8 +168,17 @@ def _print_Abs(self, expr):
if isinstance(self.compiler, AOMPCompiler):
return "fabs(%s)" % self._print(expr.args[0])
# Check if argument is an integer
func = "abs" if has_integer_args(*expr.args[0].args) else "fabs"
return "%s(%s)" % (func, self._print(expr.args[0]))
if has_integer_args(*expr.args[0].args):
func = "abs"
elif self.single_prec(expr):
func = "fabsf"
elif any([isinstance(a, Real) for a in expr.args[0].args]):
# The previous condition isn't sufficient to detect case with
# Python `float`s in that case, fall back to the "default"
func = "fabsf" if self.single_prec() else "fabs"
else:
func = "fabs"
return f"{func}({self._print(expr.args[0])})"

def _print_Add(self, expr, order=None):
""""
Expand Down
Loading

0 comments on commit 484c832

Please sign in to comment.