Skip to content

Commit

Permalink
restrict special scalar handling to complex values
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 12, 2024
1 parent 6ceb241 commit e3c49d2
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,13 @@ def cast_to_result_type(
array: ArrayOrScalar,
expr: ScalarExpression
) -> ScalarExpression:
from pytato.scalar_expr import PYTHON_SCALAR_CLASSES
if ((isinstance(array, Array) or isinstance(array, np.generic))
and array.dtype != result_dtype):
# Loopy's type casts don't like casting to bool
assert result_dtype != np.bool_

expr = TypeCast(result_dtype, expr)
elif isinstance(expr, PYTHON_SCALAR_CLASSES):
elif isinstance(expr, (complex,)) and not isinstance(expr, np.generic):
# See https://github.com/inducer/pytato/pull/247 and
# https://github.com/inducer/pytato/issues/542
expr = np.dtype(type(expr)).type(expr)
Expand Down

0 comments on commit e3c49d2

Please sign in to comment.