diff --git a/pytato/utils.py b/pytato/utils.py index a7ff40cfb..82d3768f8 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -236,14 +236,13 @@ def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression ) -> ScalarExpression: - from pytato.scalar_expr import PYTHON_SCALAR_CLASSES if ((isinstance(array, Array) or isinstance(array, np.generic)) and array.dtype != result_dtype): # Loopy's type casts don't like casting to bool assert result_dtype != np.bool_ expr = TypeCast(result_dtype, expr) - elif isinstance(expr, PYTHON_SCALAR_CLASSES): + elif isinstance(expr, (complex,)) and not isinstance(expr, np.generic): # See https://github.com/inducer/pytato/pull/247 and # https://github.com/inducer/pytato/issues/542 expr = np.dtype(type(expr)).type(expr)