Skip to content

Commit

Permalink
fix[next]: gtfn with offset name != local dimension name (#1789)
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt authored Jan 14, 2025
1 parent 8346bcd commit db5325b
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/gt4py/next/otf/binding/nanobind.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _type_string(type_: ts.TypeSpec) -> str:
return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>"
elif isinstance(type_, ts.FieldType):
ndims = len(type_.dims)
# cannot be ListType: the concept is represented as Field with local Dimension in this interface
assert isinstance(type_.dtype, ts.ScalarType)
dtype = cpp_interface.render_scalar_type(type_.dtype)
shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ def _process_connectivity_args(
# connectivity argument expression
nbtbl = (
f"gridtools::fn::sid_neighbor_table::as_neighbor_table<"
f"generated::{connectivity_type.source_dim.value}_t, "
f"generated::{name}_t, {connectivity_type.max_neighbors}"
f"generated::{connectivity_type.domain[0].value}_t, "
f"generated::{connectivity_type.domain[1].value}_t, "
f"{connectivity_type.max_neighbors}"
f">(std::forward<decltype({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()})>({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()}))"
)
arg_exprs.append(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def _collect_offset_definitions(
):
assert grid_type == common.GridType.UNSTRUCTURED
offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name))
if offset_name != connectivity_type.neighbor_dim.value:
offset_definitions[connectivity_type.neighbor_dim.value] = TagDefinition(
name=Sym(id=connectivity_type.neighbor_dim.value)
)

for dim in [connectivity_type.source_dim, connectivity_type.codomain]:
if dim.kind != common.DimensionKind.HORIZONTAL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def _add_storage(
# represent zero-dimensional fields as scalar arguments
return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient)
# handle default case: field with one or more dimensions
# ListType not supported: concept is represented as Field with local Dimension
assert isinstance(gt_type.dtype, ts.ScalarType)
dc_dtype = dace_utils.as_dace_type(gt_type.dtype)
if tuple_name is None:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __str__(self) -> str:
class ListType(DataType):
"""Represents a neighbor list in the ITIR representation.
Note: not used in the frontend.
Note: not used in the frontend. The concept is represented as Field with local Dimension.
"""

element_type: DataType
Expand Down
2 changes: 1 addition & 1 deletion tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
common,
constructors,
field_utils,
utils as gt_utils,
)
from gt4py.next.ffront import decorator
from gt4py.next.type_system import type_specifications as ts, type_translation
Expand Down Expand Up @@ -55,7 +56,6 @@
mesh_descriptor,
)

from gt4py.next import utils as gt_utils

# mypy does not accept [IDim, ...] as a type

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import pytest

from gt4py import next as gtx
from gt4py.next import Dims, Field, common

from next_tests import definitions as test_defs
from next_tests.integration_tests import cases
from next_tests.integration_tests.feature_tests.ffront_tests import ffront_test_utils


V = gtx.Dimension("V")
E = gtx.Dimension("E")
Neigh = gtx.Dimension("Neigh", kind=common.DimensionKind.LOCAL)
Off = gtx.FieldOffset("Off", source=E, target=(V, Neigh))


@pytest.fixture
def case():
mesh = ffront_test_utils.simple_mesh()
exec_alloc_descriptor = test_defs.ProgramBackendId.GTFN_CPU.load()
v2e_arr = mesh.offset_provider["V2E"].ndarray
return cases.Case(
exec_alloc_descriptor,
offset_provider={
"Off": common._connectivity(
v2e_arr,
codomain=E,
domain={V: v2e_arr.shape[0], Neigh: 4},
skip_value=None,
),
},
default_sizes={
V: mesh.num_vertices,
E: mesh.num_edges,
},
grid_type=common.GridType.UNSTRUCTURED,
allocator=exec_alloc_descriptor.allocator,
)


def test_offset_dimension_name_differ(case):
"""
Ensure that gtfn works with offset name that differs from the name of the local dimension.
If the value of the `NeighborConnectivityType.neighbor_dim` did not match the `FieldOffset` value,
gtfn would silently ignore the neighbor index, see https://github.com/GridTools/gridtools/pull/1814.
"""

@gtx.field_operator
def foo(a: Field[Dims[E], float]) -> Field[Dims[V], float]:
return a(Off[1])

cases.verify_with_default_data(
case, foo, lambda a: a[case.offset_provider["Off"].ndarray[:, 1]]
)

0 comments on commit db5325b

Please sign in to comment.