-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
compiler: Fix code generation for pow and fabs when using float32 #2504
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
5e98bec
sympy: Update printer to better handle precision for Pow and Abs
JDBetteridge 7114100
tests: Add test to exercise pow and abs functionality, Fix test_integ…
JDBetteridge 9630b13
misc: Typo
JDBetteridge 3b6ae7e
misc: Add f-string
JDBetteridge 9753118
sympy: Remove namespacing and tidy
JDBetteridge b6a3a4a
docs: Update jupyter notebooks
JDBetteridge e034211
Revert "docs: Update jupyter notebooks"
JDBetteridge 8e3090e
docs: Update jupyter notebooks, second attempt
JDBetteridge ad97f5e
tests: Make maths functions precision tests more stringent
JDBetteridge File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,10 @@ | |
|
||
from mpmath.libmp import prec_to_dps, to_str | ||
from packaging.version import Version | ||
from numbers import Real | ||
|
||
from sympy.core import S | ||
from sympy.core.numbers import equal_valued, Float | ||
from sympy.logic.boolalg import BooleanFunction | ||
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence | ||
from sympy.printing.c import C99CodePrinter | ||
|
@@ -122,15 +125,21 @@ def _print_math_func(self, expr, nest=False, known=None): | |
return f'{cname}({args})' | ||
|
||
def _print_Pow(self, expr): | ||
# Need to override because of issue #1627 | ||
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x | ||
try: | ||
if expr.exp == -1 and self.single_prec(): | ||
PREC = precedence(expr) | ||
return '1.0F/%s' % self.parenthesize(expr.base, PREC) | ||
except AttributeError: | ||
pass | ||
return super()._print_Pow(expr) | ||
# Completely reimplement `_print_Pow` from sympy, since it doesn't | ||
# correctly handle precision | ||
if "Pow" in self.known_functions: | ||
return self._print_Function(expr) | ||
PREC = precedence(expr) | ||
suffix = 'f' if self.single_prec(expr) else '' | ||
if equal_valued(expr.exp, -1): | ||
return self._print_Float(Float(1.0)) + '/' + \ | ||
self.parenthesize(expr.base, PREC) | ||
elif equal_valued(expr.exp, 0.5): | ||
return f'sqrt{suffix}({self._print(expr.base)})' | ||
elif expr.exp == S.One/3 and self.standard != 'C89': | ||
return f'cbrt{suffix}({self._print(expr.base)})' | ||
else: | ||
return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})' | ||
|
||
def _print_Mod(self, expr): | ||
"""Print a Mod as a C-like %-based operation.""" | ||
|
@@ -159,8 +168,17 @@ def _print_Abs(self, expr): | |
if isinstance(self.compiler, AOMPCompiler): | ||
return "fabs(%s)" % self._print(expr.args[0]) | ||
# Check if argument is an integer | ||
func = "abs" if has_integer_args(*expr.args[0].args) else "fabs" | ||
return "%s(%s)" % (func, self._print(expr.args[0])) | ||
if has_integer_args(*expr.args[0].args): | ||
func = "abs" | ||
elif self.single_prec(expr): | ||
func = "fabsf" | ||
elif any([isinstance(a, Real) for a in expr.args[0].args]): | ||
# The previous condition isn't sufficient to detect case with | ||
# Python `float`s in that case, fall back to the "default" | ||
func = "fabsf" if self.single_prec() else "fabs" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can see two self.single_prec tests. Isn;t this redundant here? |
||
else: | ||
func = "fabs" | ||
return f"{func}({self._print(expr.args[0])})" | ||
|
||
def _print_Add(self, expr, order=None): | ||
"""" | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the look of it this seems un-necessary. You just need
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not so, this is the example that I shared on slack, but will repeat here for reference:
In the second part of the
or
in your exampleself.single_prec()
will always returnTrue
, regardless of the types of the expression (unless the default is changed).This code is a bit horrible, but ensures that the "default" is only used if there is a floating point number whose type cannot be determined. Otherwise strictly use the correct function call for
float
ordouble
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, but my test is wrong anyway...Okay the test is fine, the issue is trying to detect numpy float32 or float64 as sympy eagerly squashes them to asympy.core.numbers.Float
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to merge this, but perhaps in a subsequent PR you could attach a comment to that
if