Skip to content

Commit

Permalink
k
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Nov 12, 2024
1 parent 6ff521e commit 7cd2b19
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tsfc/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _assign_dtype(expression, self):

@_assign_dtype.register(gem.Terminal)
def _assign_dtype_terminal(expression, self):
return {self.scalar_type}
return {expression.dtype or self.scalar_type}


@_assign_dtype.register(gem.Variable)
Expand All @@ -59,7 +59,7 @@ def _assign_dtype_variable(expression, self):
@_assign_dtype.register(gem.Identity)
@_assign_dtype.register(gem.Delta)
def _assign_dtype_real(expression, self):
return {self.real_type}
return {expression.dtype or self.real_type}


@_assign_dtype.register(gem.Literal)
Expand All @@ -70,15 +70,15 @@ def _assign_dtype_identity(expression, self):
@_assign_dtype.register(gem.Power)
def _assign_dtype_power(expression, self):
# Conservative
return {self.scalar_type}
return {expression.dtype or self.scalar_type}


@_assign_dtype.register(gem.MathFunction)
def _assign_dtype_mathfunction(expression, self):
if expression.name in {"abs", "real", "imag"}:
return {self.real_type}
return {expression.dtype or self.real_type}
elif expression.name == "sqrt":
return {self.scalar_type}
return {expression.dtype or self.scalar_type}
else:
return set.union(*map(self, expression.children))

Expand All @@ -87,7 +87,7 @@ def _assign_dtype_mathfunction(expression, self):
@_assign_dtype.register(gem.MaxValue)
def _assign_dtype_minmax(expression, self):
# UFL did correctness checking
return {self.real_type}
return {expression.dtype or self.real_type}


@_assign_dtype.register(gem.Conditional)
Expand All @@ -100,7 +100,7 @@ def _assign_dtype_conditional(expression, self):
@_assign_dtype.register(gem.LogicalAnd)
@_assign_dtype.register(gem.LogicalOr)
def _assign_dtype_logical(expression, self):
return {numpy.int8}
return {expression.dtype or numpy.int8}


def assign_dtypes(expressions, scalar_type):
Expand Down

0 comments on commit 7cd2b19

Please sign in to comment.