feat[next]: Output argument with non-zero domain start (#1780)
```python
field = gtx.as_field(gtx.domain({IDim: (1, 10)}), arr)
field_operator(out=field)
```

This PR also adds a test for input arguments with a non-zero domain start,
which already worked before.
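
For illustration, a minimal end-to-end sketch of the usage pattern (embedded, backend-less execution assumed; the `IDim` dimension and the `copy` field operator below are illustrative and not part of this diff):

```python
# Minimal sketch, assuming embedded execution; `IDim` and `copy` are illustrative.
import numpy as np
import gt4py.next as gtx

IDim = gtx.Dimension("IDim")

@gtx.field_operator
def copy(a: gtx.Field[[IDim], gtx.float64]) -> gtx.Field[[IDim], gtx.float64]:
    return a

# The domain (1, 10) covers 9 points, so each buffer holds 9 values.
inp = gtx.as_field(gtx.domain({IDim: (1, 10)}), np.arange(1.0, 10.0))
out = gtx.as_field(gtx.domain({IDim: (1, 10)}), np.zeros(9))

copy(inp, out=out, offset_provider={})

assert out.domain.ranges[0].start == 1  # the non-zero domain start is preserved
assert np.array_equal(out.asnumpy(), inp.asnumpy())
```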

---------

Co-authored-by: Edoardo Paone <[email protected]>
tehrengruber and edopao authored Jan 17, 2025
1 parent 1b88276 commit 489ccbb
Showing 11 changed files with 187 additions and 54 deletions.
26 changes: 16 additions & 10 deletions src/gt4py/next/ffront/past_process_args.py
@@ -83,41 +83,47 @@ def _process_args(
# TODO(tehrengruber): Previously this function was called with the actual arguments
# not their type. The check using the shape here is not functional anymore and
# should instead be placed in a proper location.
shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)]
ranges_and_dims = [*_field_constituents_range_and_dims(args[param_idx], param.type)]
# check that all non-scalar like constituents have the same shape and dimension, e.g.
# for `(scalar, (field1, field2))` the two fields need to have the same shape and
# dimension
if shapes_and_dims:
shape, dims = shapes_and_dims[0]
if ranges_and_dims:
range_, dims = ranges_and_dims[0]
if not all(
el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims
el_range == range_ and el_dims == dims
for (el_range, el_dims) in ranges_and_dims
):
raise ValueError(
"Constituents of composite arguments (e.g. the elements of a"
" tuple) need to have the same shape and dimensions."
)
index_type = ts.ScalarType(kind=ts.ScalarKind.INT32)
size_args.extend(
shape if shape else [ts.ScalarType(kind=ts.ScalarKind.INT32)] * len(dims) # type: ignore[arg-type] # shape is always empty
range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty
)
return tuple(rewritten_args), tuple(size_args), kwargs


def _field_constituents_shape_and_dims(
def _field_constituents_range_and_dims(
arg: Any, # TODO(havogt): improve typing
arg_type: ts.DataType,
) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]:
) -> Iterator[tuple[tuple[tuple[int, int], ...], list[common.Dimension]]]:
match arg_type:
case ts.TupleType():
for el, el_type in zip(arg, arg_type.types):
assert isinstance(el_type, ts.DataType)
yield from _field_constituents_shape_and_dims(el, el_type)
yield from _field_constituents_range_and_dims(el, el_type)
case ts.FieldType():
dims = type_info.extract_dims(arg_type)
if isinstance(arg, ts.TypeSpec): # TODO
yield (tuple(), dims)
elif dims:
assert hasattr(arg, "shape") and len(arg.shape) == len(dims)
yield (arg.shape, dims)
assert (
hasattr(arg, "domain")
and isinstance(arg.domain, common.Domain)
and len(arg.domain.dims) == len(dims)
)
yield (tuple((r.start, r.stop) for r in arg.domain.ranges), dims)
else:
yield from [] # ignore 0-dim fields
case ts.ScalarType():
37 changes: 21 additions & 16 deletions src/gt4py/next/ffront/past_to_itir.py
@@ -138,8 +138,8 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]
return iter(scanops_per_axis.keys()).__next__()


def _size_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_size_{dim}"
def _range_arg_from_field(field_name: str, dim: int) -> str:
return f"__{field_name}_{dim}_range"


def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]:
@@ -217,13 +217,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]:
)
if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType`
assert all(field_dims == fields_dims[0] for field_dims in fields_dims)
index_type = ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
)
for dim_idx in range(len(fields_dims[0])):
size_params.append(
itir.Sym(
id=_size_arg_from_field(param.id, dim_idx),
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
id=_range_arg_from_field(param.id, dim_idx),
type=ts.TupleType(types=[index_type, index_type]),
)
)

Expand Down Expand Up @@ -286,7 +287,8 @@ def _visit_slice_bound(
self,
slice_bound: Optional[past.Constant],
default_value: itir.Expr,
dim_size: itir.Expr,
start_idx: itir.Expr,
stop_idx: itir.Expr,
**kwargs: Any,
) -> itir.Expr:
if slice_bound is None:
@@ -296,11 +298,9 @@
slice_bound.type
)
if slice_bound.value < 0:
lowered_bound = itir.FunCall(
fun=itir.SymRef(id="plus"), args=[dim_size, self.visit(slice_bound, **kwargs)]
)
lowered_bound = im.plus(stop_idx, self.visit(slice_bound, **kwargs))
else:
lowered_bound = self.visit(slice_bound, **kwargs)
lowered_bound = im.plus(start_idx, self.visit(slice_bound, **kwargs))
else:
raise AssertionError("Expected 'None' or 'past.Constant'.")
if slice_bound:
@@ -348,8 +348,9 @@ def _construct_itir_domain_arg(
domain_args = []
domain_args_kind = []
for dim_i, dim in enumerate(out_dims):
# an expression for the size of a dimension
dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i))
# an expression for the range of a dimension
dim_range = itir.SymRef(id=_range_arg_from_field(out_field.id, dim_i))
dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range)
# bounds
lower: itir.Expr
upper: itir.Expr
@@ -359,11 +360,15 @@
else:
lower = self._visit_slice_bound(
slices[dim_i].lower if slices else None,
im.literal("0", itir.INTEGER_INDEX_BUILTIN),
dim_size,
dim_start,
dim_start,
dim_stop,
)
upper = self._visit_slice_bound(
slices[dim_i].upper if slices else None, dim_size, dim_size
slices[dim_i].upper if slices else None,
dim_stop,
dim_start,
dim_stop,
)

if dim.kind == common.DimensionKind.LOCAL:
9 changes: 6 additions & 3 deletions src/gt4py/next/otf/arguments.py
@@ -122,7 +122,7 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]:
return None


def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
def iter_size_args(args: tuple[Any, ...]) -> Iterator[tuple[int, int]]:
"""
Yield the size of each field argument in each dimension.
@@ -136,7 +136,9 @@ def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]:
if first_field:
yield from iter_size_args((first_field,))
case common.Field():
yield from arg.ndarray.shape
for range_ in arg.domain.ranges:
assert isinstance(range_, common.UnitRange)
yield (range_.start, range_.stop)
case _:
pass

@@ -156,6 +158,7 @@ def iter_size_compile_args(
)
if field_constituents:
# we only need the first field, because all fields in a tuple must have the same dims and sizes
index_type = ts.ScalarType(kind=ts.ScalarKind.INT32)
yield from [
ts.ScalarType(kind=ts.ScalarKind.INT32) for dim in field_constituents[0].dims
ts.TupleType(types=[index_type, index_type]) for dim in field_constituents[0].dims
]
@@ -24,7 +24,7 @@


# regex to match the symbols for field shape and strides
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+")
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"^__.+_((\d+_range_[01])|((size|stride)_\d+))$")


def as_dace_type(type_: ts.ScalarType) -> dace.typeclass:
@@ -23,15 +23,16 @@
from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils


class CompiledDaceProgram(stages.CompiledProgram):
class CompiledDaceProgram(stages.ExtendedCompiledProgram):
sdfg_program: dace.CompiledSDFG

# Sorted list of SDFG arguments as they appear in program ABI and corresponding data type;
# scalar arguments that are not used in the SDFG will not be present.
sdfg_arglist: list[tuple[str, dace.dtypes.Data]]

def __init__(self, program: dace.CompiledSDFG):
def __init__(self, program: dace.CompiledSDFG, implicit_domain: bool):
self.sdfg_program = program
self.implicit_domain = implicit_domain
# `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument
# name to its data type, in the same order as arguments appear in the program ABI.
# This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`.
@@ -88,7 +89,7 @@ def __call__(
dace.config.Config.set("compiler", "cpu", "args", value=compiler_args)
sdfg_program = sdfg.compile(validate=False)

return CompiledDaceProgram(sdfg_program)
return CompiledDaceProgram(sdfg_program, inp.program_source.implicit_domain)


class DaCeCompilationStepFactory(factory.Factory):
@@ -113,9 +114,11 @@ def decorated_program(
if out is not None:
args = (*args, out)
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
if len(sdfg.arg_names) > len(flat_args):
# The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments.
flat_args = (*flat_args, *arguments.iter_size_args(args))
if inp.implicit_domain:
# generate implicit domain size arguments only if necessary
size_args = arguments.iter_size_args(args)
flat_size_args: Sequence[int] = gtx_utils.flatten_nested_tuple(tuple(size_args))
flat_args = (*flat_args, *flat_size_args)

if sdfg_program._lastargs:
kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True))
@@ -84,6 +84,11 @@ def builtin_if(*args: Any) -> str:
return f"{true_val} if {cond} else {false_val}"


def builtin_tuple_get(*args: Any) -> str:
index, tuple_name = args
return f"{tuple_name}_{index}"


def make_const_list(arg: str) -> str:
"""
Takes a single scalar argument and broadcasts this value on the local dimension
@@ -97,6 +102,7 @@ def make_const_list(arg: str) -> str:
"cast_": builtin_cast,
"if_": builtin_if,
"make_const_list": make_const_list,
"tuple_get": builtin_tuple_get,
}


@@ -42,6 +42,9 @@ def test_sdfgConvertible_laplap(cartesian_case):  # noqa: F811
if not cartesian_case.backend or "dace" not in cartesian_case.backend.name:
pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs")

# TODO(edopao): add support for range symbols in field domain and re-enable this test
pytest.skip("Requires support for field domain range.")

backend = cartesian_case.backend

in_field = cases.allocate(cartesian_case, laplap_program, "in_field")()
@@ -87,6 +90,9 @@ def test_sdfgConvertible_connectivities(unstructured_case):  # noqa: F811
if not unstructured_case.backend or "dace" not in unstructured_case.backend.name:
pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs")

# TODO(edopao): add support for range symbols in field domain and re-enable this test
pytest.skip("Requires support for field domain range.")

allocator, backend = unstructured_case.allocator, unstructured_case.backend

if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE):
@@ -90,6 +90,9 @@ def unstructured(request, gtir_dace_backend, mesh_descriptor):  # noqa: F811
def test_halo_exchange_helper_attrs(unstructured):
local_int = gtx.int

# TODO(edopao): add support for range symbols in field domain and re-enable this test
pytest.skip("Requires support for field domain range.")

@gtx.field_operator(backend=unstructured.backend)
def testee_op(
a: gtx.Field[[Vertex, KDim], gtx.int],
@@ -998,7 +998,7 @@ def program_domain(a: cases.IField, out: cases.IField):
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain
ref = out.asnumpy().copy() # ensure we are not writing to out outside the domain
ref[1:9] = a.asnumpy()[1:9] * 2

cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref)
@@ -14,7 +14,7 @@
import pytest

import gt4py.next as gtx
from gt4py.next import errors
from gt4py.next import errors, constructors, common

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
@@ -251,3 +251,42 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField):
ValueError, match=(r"Dimensions in out field and field domain are not equivalent")
):
cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={})


@pytest.mark.uses_origin
def test_out_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend)

size = cartesian_case.default_sizes[IDim]

inp = cases.allocate(cartesian_case, copy_program, "in_field").unique()()
out = constructors.empty(
common.domain({IDim: (1, size - 2)}),
allocator=cartesian_case.allocator,
)
ref = inp.ndarray[1:-2]

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)


@pytest.mark.uses_origin
def test_in_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
@gtx.field_operator
def identity(a: cases.IField) -> cases.IField:
return a

@gtx.program
def copy_program(a: cases.IField, out: cases.IField):
identity(a, out=out, domain={IDim: (1, 9)})

inp = constructors.empty(
common.domain({IDim: (1, 9)}),
dtype=np.int32,
allocator=cartesian_case.allocator,
)
inp.ndarray[...] = 42
out = cases.allocate(cartesian_case, copy_program, "out", sizes={IDim: 10})()
ref = out.asnumpy().copy() # ensure we are not writing to `out` outside the domain
ref[1:9] = inp.asnumpy()

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)