Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: Prevent builtins on transient functions #2506

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import numpy as np

import devito as dv
from devito.builtins.utils import make_retval
from devito.builtins.utils import make_retval, check_builtins_args

__all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax']


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def norm(f, order=2):
"""
Compute the norm of a Function.
Expand Down Expand Up @@ -41,6 +42,7 @@ def norm(f, order=2):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def sum(f, dims=None):
"""
Compute the sum of the Function data over specified dimensions.
Expand Down Expand Up @@ -94,6 +96,7 @@ def sum(f, dims=None):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def sumall(f):
"""
Compute the sum of all Function data.
Expand Down Expand Up @@ -123,6 +126,7 @@ def sumall(f):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def inner(f, g):
"""
Inner product of two Functions.
Expand Down Expand Up @@ -177,6 +181,7 @@ def inner(f, g):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def mmin(f):
"""
Retrieve the minimum.
Expand All @@ -200,6 +205,7 @@ def mmin(f):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def mmax(f):
"""
Retrieve the maximum.
Expand Down
6 changes: 5 additions & 1 deletion devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import devito as dv
from devito.tools import as_tuple, as_list
from devito.builtins.utils import nbl_to_padsize, pad_outhalo
from devito.builtins.utils import check_builtins_args, nbl_to_padsize, pad_outhalo

__all__ = ['assign', 'smooth', 'gaussian_smooth', 'initialize_function']


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
"""
Assign a list of RHSs to a list of Functions.
Expand Down Expand Up @@ -85,6 +86,7 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
op(time_M=f._time_size)


@check_builtins_args
def smooth(f, g, axis=None):
"""
Smooth a Function through simple moving average.
Expand Down Expand Up @@ -114,6 +116,7 @@ def smooth(f, g, axis=None):
dv.Operator(dv.Eq(f, g.avg(dims=axis)), name='smoother')()


@check_builtins_args
def gaussian_smooth(f, sigma=1, truncate=4.0, mode='reflect'):
"""
Gaussian smooth function.
Expand Down Expand Up @@ -273,6 +276,7 @@ def buff(i, j):
return lhs, rhs, options


@check_builtins_args
def initialize_function(function, data, nbl, mapper=None, mode='constant',
name=None, pad_halo=True, **kwargs):
"""
Expand Down
28 changes: 27 additions & 1 deletion devito/builtins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import numpy as np

import devito as dv
from devito.arch import Device
from devito.symbolics import uxreplace
from devito.tools import as_tuple

__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']
__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args',
'check_builtins_args']


accumulator_mapper = {
Expand Down Expand Up @@ -131,3 +133,27 @@ def wrapper(*args, **kwargs):
return func(*processed, argmap=argmap, **kwargs)

return wrapper


def check_builtins_args(func):
"""
Perform checks on the arguments supplied to a builtin.
"""

@wraps(func)
def wrapper(*args, **kwargs):
platform = dv.configuration['platform']
if not isinstance(platform, Device):
return func(*args, **kwargs)

for i in args:
try:
mloubout marked this conversation as resolved.
Show resolved Hide resolved
if i.is_transient:
raise ValueError(f"Cannot apply `{func.__name__}` to transient "
f"function `{i.name}` on backend `{platform}`")
except AttributeError:
pass

return func(*args, **kwargs)

return wrapper
10 changes: 9 additions & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Dimension, MatrixSparseTimeFunction, SparseTimeFunction,
SubDimension, SubDomain, SubDomainSet, TimeFunction,
Operator, configuration, switchconfig, TensorTimeFunction,
Buffer)
Buffer, assign)
from devito.arch import get_gpu_info, get_cpu_info, Device, Cpu64
from devito.exceptions import InvalidArgument
from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols,
Expand Down Expand Up @@ -1491,6 +1491,14 @@ def test_pickling(self):

assert str(op) == str(new_op)

def test_is_transient_w_builtins(self):
grid = Grid(shape=(4, 4))

f = Function(name='f', grid=grid, is_transient=True)

with pytest.raises(ValueError):
assign(f, 4)


class TestEdgeCases:

Expand Down
Loading