Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ksagiyam/attach dtype to nodes #327

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 74 additions & 39 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
'Inverse', 'Solve', 'extract_type', 'uint_type']


uint_type = numpy.uintc
uint_type = numpy.dtype(numpy.uintc)


class NodeMeta(type):
Expand All @@ -56,14 +56,17 @@ def __call__(self, *args, **kwargs):
if not hasattr(obj, 'free_indices'):
obj.free_indices = unique(chain(*[c.free_indices
for c in obj.children]))
# Set dtype if not set already.
if not hasattr(obj, 'dtype'):
obj.dtype = obj.inherit_dtype_from_children(obj.children)

return obj


class Node(NodeBase, metaclass=NodeMeta):
"""Abstract GEM node class."""

__slots__ = ('free_indices',)
__slots__ = ('free_indices', 'dtype')

def is_equal(self, other):
"""Common subexpression eliminating equality predicate.
Expand Down Expand Up @@ -153,16 +156,46 @@ def __mod__(self, other):
def __rmod__(self, other):
return as_gem_uint(other).__mod__(self)

@staticmethod
def inherit_dtype_from_children(children):
if any(c.dtype is None for c in children):
# Set dtype = None will let _assign_dtype()
# assign the default dtype for this node later.
return
else:
return numpy.result_type(*(c.dtype for c in children))


class Terminal(Node):
"""Abstract class for terminal GEM nodes."""

__slots__ = ()
__slots__ = ('_dtype',)

children = ()

is_equal = NodeBase.is_equal

@property
def dtype(self):
"""dtype of the node.

We only need to set dtype (or _dtype) on terminal nodes, and
other nodes inherit dtype from their children.

Currently dtype is significant only for nodes under index DAGs
(DAGs underneath `VariableIndex`s representing indices), and
`VariableIndex` checks if the dtype of the node that it wraps is
of uint_type. _assign_dtype() will then assign uint_type to those nodes.

dtype can be `None` otherwise, and _assign_dtype() will assign
the default dtype to those nodes.

"""
if hasattr(self, '_dtype'):
return self._dtype
else:
raise AttributeError(f"Must set _dtype on terminal node, {type(self)}")


class Scalar(Node):
"""Abstract class for scalar-valued GEM nodes."""
Expand All @@ -181,6 +214,7 @@ class Failure(Terminal):
def __init__(self, shape, exception):
self.shape = shape
self.exception = exception
self._dtype = None


class Constant(Terminal):
Expand All @@ -190,35 +224,36 @@ class Constant(Terminal):
- array: numpy array of values
- value: float or complex value (scalars only)
"""
__slots__ = ('dtype',)
__back__ = ('dtype',)
pass


class Zero(Constant):
"""Symbolic zero tensor"""

__slots__ = ('shape',)
__front__ = ('shape',)
__back__ = ('dtype',)

def __init__(self, shape=(), dtype=float):
def __init__(self, shape=(), dtype=None):
self.shape = shape
self.dtype = dtype
self._dtype = dtype

@property
def value(self):
assert not self.shape
return numpy.array(0, dtype=self.dtype).item()
return numpy.array(0, dtype=self.dtype or float).item()


class Identity(Constant):
"""Identity matrix"""

__slots__ = ('dim',)
__front__ = ('dim',)
__back__ = ('dtype',)

def __init__(self, dim, dtype=float):
def __init__(self, dim, dtype=None):
self.dim = dim
self.dtype = dtype
self._dtype = dtype

@property
def shape(self):
Expand All @@ -234,6 +269,7 @@ class Literal(Constant):

__slots__ = ('array',)
__front__ = ('array',)
__back__ = ('dtype',)

def __new__(cls, array, dtype=None):
array = asarray(array)
Expand All @@ -245,14 +281,12 @@ def __init__(self, array, dtype=None):
# Assume float or complex.
try:
self.array = array.astype(float, casting="safe")
self.dtype = float
except TypeError:
self.array = array.astype(complex)
self.dtype = complex
else:
# Can be int, etc.
self.array = array.astype(dtype)
self.dtype = dtype
self._dtype = self.array.dtype

def is_equal(self, other):
if type(self) is not type(other):
Expand All @@ -277,13 +311,14 @@ def shape(self):
class Variable(Terminal):
"""Symbolic variable tensor"""

__slots__ = ('name', 'shape', 'dtype')
__front__ = ('name', 'shape', 'dtype')
__slots__ = ('name', 'shape')
__front__ = ('name', 'shape')
__back__ = ('dtype',)

def __init__(self, name, shape, dtype=None):
self.name = name
self.shape = shape
self.dtype = dtype
self._dtype = dtype


class Sum(Scalar):
Expand All @@ -300,8 +335,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value + b.value, dtype=dtype)
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b]))

self = super(Sum, cls).__new__(cls)
self.children = a, b
Expand All @@ -325,8 +359,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value * b.value, dtype=dtype)
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b]))

self = super(Product, cls).__new__(cls)
self.children = a, b
Expand All @@ -350,8 +383,7 @@ def __new__(cls, a, b):
return a

if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value / b.value, dtype=dtype)
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b]))

self = super(Division, cls).__new__(cls)
self.children = a, b
Expand All @@ -364,18 +396,17 @@ class FloorDiv(Scalar):
def __new__(cls, a, b):
assert not a.shape
assert not b.shape
# TODO: Attach dtype property to Node and check that
# numpy.result_dtype(a.dtype, b.dtype) is uint type.
# dtype is currently attached only to {Constant, Variable}.
dtype = Node.inherit_dtype_from_children([a, b])
if dtype != uint_type:
raise ValueError(f"dtype ({dtype}) != unit_type ({uint_type})")
# Constant folding
if isinstance(b, Zero):
raise ValueError("division by zero")
if isinstance(a, Zero):
return Zero(dtype=a.dtype)
return Zero(dtype=dtype)
if isinstance(b, Constant) and b.value == 1:
return a
if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value // b.value, dtype=dtype)
self = super(FloorDiv, cls).__new__(cls)
self.children = a, b
Expand All @@ -388,18 +419,17 @@ class Remainder(Scalar):
def __new__(cls, a, b):
assert not a.shape
assert not b.shape
# TODO: Attach dtype property to Node and check that
# numpy.result_dtype(a.dtype, b.dtype) is uint type.
# dtype is currently attached only to {Constant, Variable}.
dtype = Node.inherit_dtype_from_children([a, b])
if dtype != uint_type:
ksagiyam marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"dtype ({dtype}) != uint_type ({uint_type})")
# Constant folding
if isinstance(b, Zero):
raise ValueError("division by zero")
if isinstance(a, Zero):
return Zero(dtype=a.dtype)
return Zero(dtype=dtype)
if isinstance(b, Constant) and b.value == 1:
return Zero(dtype=b.dtype)
return Zero(dtype=dtype)
if isinstance(a, Constant) and isinstance(b, Constant):
dtype = numpy.result_type(a.dtype, b.dtype)
return Literal(a.value % b.value, dtype=dtype)
self = super(Remainder, cls).__new__(cls)
self.children = a, b
Expand All @@ -412,18 +442,16 @@ class Power(Scalar):
def __new__(cls, base, exponent):
assert not base.shape
assert not exponent.shape
dtype = Node.inherit_dtype_from_children([base, exponent])

# Constant folding
if isinstance(base, Zero):
dtype = numpy.result_type(base.dtype, exponent.dtype)
if isinstance(exponent, Zero):
raise ValueError("cannot solve 0^0")
return Zero(dtype=dtype)
elif isinstance(exponent, Zero):
dtype = numpy.result_type(base.dtype, exponent.dtype)
return Literal(1, dtype=dtype)
elif isinstance(base, Constant) and isinstance(exponent, Constant):
dtype = numpy.result_type(base.dtype, exponent.dtype)
return Literal(base.value ** exponent.value, dtype=dtype)

self = super(Power, cls).__new__(cls)
Expand Down Expand Up @@ -483,6 +511,7 @@ def __init__(self, op, a, b):

self.operator = op
self.children = a, b
self.dtype = None # Do not inherit dtype from children.


class LogicalNot(Scalar):
Expand Down Expand Up @@ -529,6 +558,7 @@ def __new__(cls, condition, then, else_):
self = super(Conditional, cls).__new__(cls)
self.children = condition, then, else_
self.shape = then.shape
self.dtype = Node.inherit_dtype_from_children([then, else_])
return self


Expand Down Expand Up @@ -591,6 +621,8 @@ class VariableIndex(IndexBase):
def __init__(self, expression):
assert isinstance(expression, Node)
assert not expression.shape
if expression.dtype != uint_type:
raise ValueError(f"expression.dtype ({expression.dtype}) != uint_type ({uint_type})")
self.expression = expression

def __eq__(self, other):
Expand Down Expand Up @@ -846,6 +878,7 @@ class ListTensor(Node):
def __new__(cls, array):
array = asarray(array)
assert numpy.prod(array.shape)
dtype = Node.inherit_dtype_from_children(tuple(array.flat))

# Handle children with shape
child_shape = array.flat[0].shape
Expand All @@ -861,7 +894,7 @@ def __new__(cls, array):

# Constant folding
if all(isinstance(elem, Constant) for elem in array.flat):
return Literal(numpy.vectorize(attrgetter('value'))(array))
return Literal(numpy.vectorize(attrgetter('value'))(array), dtype=dtype)

self = super(ListTensor, cls).__new__(cls)
self.array = array
Expand Down Expand Up @@ -907,9 +940,9 @@ class Concatenate(Node):
__slots__ = ('children',)

def __new__(cls, *children):
dtype = Node.inherit_dtype_from_children(children)
if all(isinstance(child, Zero) for child in children):
size = int(sum(numpy.prod(child.shape, dtype=int) for child in children))
dtype = numpy.result_type(*(child.dtype for child in children))
return Zero((size,), dtype=dtype)

self = super(Concatenate, cls).__new__(cls)
Expand All @@ -924,8 +957,9 @@ def shape(self):
class Delta(Scalar, Terminal):
__slots__ = ('i', 'j')
__front__ = ('i', 'j')
__back__ = ('dtype',)

def __new__(cls, i, j):
def __new__(cls, i, j, dtype=None):
assert isinstance(i, IndexBase)
assert isinstance(j, IndexBase)

Expand All @@ -948,6 +982,7 @@ def __new__(cls, i, j):
elif isinstance(index, VariableIndex):
raise NotImplementedError("Can not make Delta with VariableIndex")
self.free_indices = tuple(unique(free_indices))
self._dtype = dtype
return self


Expand Down
2 changes: 1 addition & 1 deletion tests/test_pickle_gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@pytest.mark.parametrize('protocol', range(3))
def test_pickle_gem(protocol):
f = gem.VariableIndex(gem.Indexed(gem.Variable('facet', (2,)), (1,)))
f = gem.VariableIndex(gem.Indexed(gem.Variable('facet', (2,), dtype=gem.uint_type), (1,)))
q = gem.Index()
r = gem.Index()
_1 = gem.Indexed(gem.Literal(numpy.random.rand(3, 6, 8)), (f, q, r))
Expand Down
8 changes: 4 additions & 4 deletions tsfc/kernel_interface/firedrake_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def __init__(self, scalar_type, interior_facet=False):

# Cell orientation
if self.interior_facet:
cell_orientations = gem.Variable("cell_orientations", (2,))
cell_orientations = gem.Variable("cell_orientations", (2,), dtype=gem.uint_type)
self._cell_orientations = (gem.Indexed(cell_orientations, (0,)),
gem.Indexed(cell_orientations, (1,)))
else:
cell_orientations = gem.Variable("cell_orientations", (1,))
cell_orientations = gem.Variable("cell_orientations", (1,), dtype=gem.uint_type)
self._cell_orientations = (gem.Indexed(cell_orientations, (0,)),)

def _coefficient(self, coefficient, name):
Expand Down Expand Up @@ -257,12 +257,12 @@ def __init__(self, integral_data_info, scalar_type,

# Facet number
if integral_type in ['exterior_facet', 'exterior_facet_vert']:
facet = gem.Variable('facet', (1,))
facet = gem.Variable('facet', (1,), dtype=gem.uint_type)
self._entity_number = {None: gem.VariableIndex(gem.Indexed(facet, (0,)))}
facet_orientation = gem.Variable('facet_orientation', (1,), dtype=gem.uint_type)
self._entity_orientation = {None: gem.OrientationVariableIndex(gem.Indexed(facet_orientation, (0,)))}
elif integral_type in ['interior_facet', 'interior_facet_vert']:
facet = gem.Variable('facet', (2,))
facet = gem.Variable('facet', (2,), dtype=gem.uint_type)
self._entity_number = {
'+': gem.VariableIndex(gem.Indexed(facet, (0,))),
'-': gem.VariableIndex(gem.Indexed(facet, (1,)))
Expand Down
Loading
Loading