From e3c49d230689cc7043f23079d2224ab6f5b4bfb6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 12 Sep 2024 18:01:28 -0500 Subject: [PATCH] restrict special scalar handling to complex values --- pytato/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index a7ff40cfb..82d3768f8 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -236,14 +236,13 @@ def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression ) -> ScalarExpression: - from pytato.scalar_expr import PYTHON_SCALAR_CLASSES if ((isinstance(array, Array) or isinstance(array, np.generic)) and array.dtype != result_dtype): # Loopy's type casts don't like casting to bool assert result_dtype != np.bool_ expr = TypeCast(result_dtype, expr) - elif isinstance(expr, PYTHON_SCALAR_CLASSES): + elif isinstance(expr, (complex,)) and not isinstance(expr, np.generic): # See https://github.com/inducer/pytato/pull/247 and # https://github.com/inducer/pytato/issues/542 expr = np.dtype(type(expr)).type(expr)