Skip to content

Commit

Permalink
Type-annotate main body of precompute
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 18, 2023
1 parent 7551c34 commit 28b0eb5
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 33 deletions.
33 changes: 28 additions & 5 deletions loopy/transform/array_buffer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
"""


from abc import ABC, abstractmethod
from typing import Tuple
import islpy as isl
from islpy import dim_type
from loopy.symbolic import (get_dependencies, SubstitutionMapper)
Expand All @@ -29,6 +31,8 @@
from pytools import ImmutableRecord, memoize_method
from pymbolic import var

from loopy.typing import ExpressionT


class AccessDescriptor(ImmutableRecord):
"""
Expand Down Expand Up @@ -192,7 +196,23 @@ def compute_bounds(kernel, domain, stor2sweep,

# {{{ array-to-buffer map

class ArrayToBufferMap:
class ArrayToBufferMapBase(ABC):
non1_storage_axis_names: Tuple[str, ...]
storage_base_indices: Tuple[str, ...]
non1_storage_shape: Tuple[ExpressionT, ...]
non1_storage_axis_flags: Tuple[ExpressionT, ...]

@abstractmethod
def is_access_descriptor_in_footprint(self, accdesc: AccessDescriptor) -> bool:
...

@abstractmethod
def augment_domain_with_sweep(self, domain, new_non1_storage_axis_names,
boxify_sweep=False):
...


class ArrayToBufferMap(ArrayToBufferMapBase):
def __init__(self, kernel, domain, sweep_inames, access_descriptors,
storage_axis_count):
self.kernel = kernel
Expand Down Expand Up @@ -298,7 +318,7 @@ def __init__(self, kernel, domain, sweep_inames, access_descriptors,
self.non1_storage_axis_flags = non1_storage_axis_flags
self.aug_domain = aug_domain
self.storage_base_indices = storage_base_indices
self.non1_storage_shape = non1_storage_shape
self.non1_storage_shape = tuple(non1_storage_shape)

def augment_domain_with_sweep(self, domain, new_non1_storage_axis_names,
boxify_sweep=False):
Expand Down Expand Up @@ -336,7 +356,7 @@ def augment_domain_with_sweep(self, domain, new_non1_storage_axis_names,
else:
return convexify(domain)

def is_access_descriptor_in_footprint(self, accdesc):
def is_access_descriptor_in_footprint(self, accdesc: AccessDescriptor) -> bool:
return self._is_access_descriptor_in_footprint_inner(
tuple(accdesc.storage_axis_exprs))

Expand Down Expand Up @@ -399,17 +419,20 @@ def _is_access_descriptor_in_footprint_inner(self, storage_axis_exprs):
aligned_g_s2s_parm_dom)


class NoOpArrayToBufferMap:
class NoOpArrayToBufferMap(ArrayToBufferMapBase):
non1_storage_axis_names = ()
storage_base_indices = ()
non1_storage_shape = ()

def is_access_descriptor_in_footprint(self, accdesc):
def is_access_descriptor_in_footprint(self, accdesc: AccessDescriptor) -> bool:
# no index dependencies--every reference to the subst rule
# is necessarily in the footprint.

return True

def augment_domain_with_sweep(self, domain, new_non1_storage_axis_names,
boxify_sweep=False):
return domain
# }}}

# vim: foldmethod=marker
71 changes: 43 additions & 28 deletions loopy/transform/precompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,34 @@
"""


from typing import FrozenSet, List, Mapping, Optional, Sequence, Type, Union
from immutables import Map
import numpy as np
import islpy as isl
from pytools.tag import Tag
from loopy.kernel import LoopKernel
from loopy.typing import auto
from loopy.match import ToStackMatchCovertible
from loopy.symbolic import (get_dependencies,
RuleAwareIdentityMapper, RuleAwareSubstitutionMapper,
SubstitutionRuleMappingContext, CombineMapper)
from loopy.diagnostic import LoopyError
from pymbolic.mapper.substitutor import make_subst_func
from loopy.translation_unit import TranslationUnit
from loopy.kernel.instruction import MultiAssignmentBase
from loopy.kernel.function_interface import CallableKernel, ScalarCallable
from loopy.kernel.instruction import InstructionBase, MultiAssignmentBase
from loopy.kernel.function_interface import (CallableKernel, InKernelCallable,
ScalarCallable)
from loopy.kernel.tools import (kernel_has_global_barriers,
find_most_recent_global_barrier)
from loopy.kernel.data import AddressSpace
from loopy.types import LoopyType
from loopy.types import LoopyType, ToLoopyTypeConvertible, to_loopy_type

from pymbolic import var
from pytools import memoize_on_first_arg

from loopy.transform.array_buffer_map import (ArrayToBufferMap, NoOpArrayToBufferMap,
AccessDescriptor)
from loopy.transform.array_buffer_map import (ArrayToBufferMap,
ArrayToBufferMapBase,
NoOpArrayToBufferMap,
AccessDescriptor)


# {{{ contains_subst_rule_invocation
Expand Down Expand Up @@ -348,18 +355,25 @@ def map_kernel(self, kernel):
# }}}


def precompute_for_single_kernel(kernel, callables_table, subst_use,
sweep_inames=None, within=None, storage_axes=None, temporary_name=None,
precompute_inames=None, precompute_outer_inames=None,
def precompute_for_single_kernel(
kernel: LoopKernel,
callables_table: Mapping[str, InKernelCallable], subst_use,
sweep_inames=None,
within: ToStackMatchCovertible = None,
*,
storage_axes=None,
temporary_name: Optional[str] = None,
precompute_inames: Optional[Sequence[str]] = None,
precompute_outer_inames: Optional[FrozenSet[str]] = None,
storage_axis_to_tag=None,

default_tag=None,
default_tag: Union[None, Tag, str] = None,

dtype=None,
fetch_bounding_box=False,
temporary_address_space=None,
compute_insn_id=None,
**kwargs):
dtype: Optional[ToLoopyTypeConvertible] = None,
fetch_bounding_box: bool = False,
temporary_address_space: Union[AddressSpace, None, Type[auto]] = None,
compute_insn_id: Optional[str] = None,
**kwargs) -> LoopKernel:
"""Precompute the expression described in the substitution rule determined by
*subst_use* and store it in a temporary array. A precomputation needs two
things to operate, a list of *sweep_inames* (order irrelevant) and an
Expand Down Expand Up @@ -428,11 +442,8 @@ def precompute_for_single_kernel(kernel, callables_table, subst_use,
May also be specified as a comma-separated string.
:arg default_tag: The :ref:`iname tag <iname-tags>` to be applied to the
inames created to perform the precomputation. The current default will
make them local axes and automatically split them to fit the work
group size, but this default will disappear in favor of simply leaving them
untagged in 2019. For 2018, a warning will be issued if no *default_tag* is
specified.
inames created to perform the precomputation. By default, new
inames remain untagged.
:arg dtype: The dtype of the temporary variable to precompute the result
in. Can be either a dtype as understood by :class:`numpy.dtype` or
Expand Down Expand Up @@ -498,7 +509,7 @@ def precompute_for_single_kernel(kernel, callables_table, subst_use,

footprint_generators = None

subst_name = None
subst_name: Optional[str] = None
subst_tag = None

from pymbolic.primitives import Variable, Call
Expand Down Expand Up @@ -536,6 +547,8 @@ def precompute_for_single_kernel(kernel, callables_table, subst_use,
from loopy.match import parse_stack_match
within = parse_stack_match(within)

assert subst_name is not None

try:
subst = kernel.substitutions[subst_name]
except KeyError:
Expand Down Expand Up @@ -742,7 +755,8 @@ def precompute_for_single_kernel(kernel, callables_table, subst_use,

# }}}

abm = ArrayToBufferMap(kernel, domch.domain, sweep_inames,
abm: ArrayToBufferMapBase = ArrayToBufferMap(
kernel, domch.domain, sweep_inames,
access_descriptors, len(storage_axis_names))

non1_storage_axis_names = []
Expand Down Expand Up @@ -917,7 +931,7 @@ def add_assumptions(d):
# within_inames determined below
)
compute_dep_id = compute_insn_id
added_compute_insns = [compute_insn]
added_compute_insns: List[InstructionBase] = [compute_insn]

if temporary_address_space == AddressSpace.GLOBAL:
barrier_insn_id = kernel.make_unique_instruction_id(
Expand Down Expand Up @@ -949,7 +963,7 @@ def add_assumptions(d):

kernel = invr.map_kernel(kernel)
kernel = kernel.copy(
instructions=added_compute_insns + kernel.instructions)
instructions=added_compute_insns + list(kernel.instructions))
kernel = rule_mapping_context.finish_kernel(kernel)

# }}}
Expand Down Expand Up @@ -1024,19 +1038,19 @@ def add_assumptions(d):
# {{{ set up temp variable

import loopy as lp
if dtype is not None:
dtype = np.dtype(dtype)

loopy_type = to_loopy_type(dtype, allow_none=True)

if temporary_address_space is None:
temporary_address_space = lp.auto

new_temp_shape = tuple(abm.non1_storage_shape)

new_temporary_variables = kernel.temporary_variables.copy()
new_temporary_variables = dict(kernel.temporary_variables)
if temporary_name not in new_temporary_variables:
temp_var = lp.TemporaryVariable(
name=temporary_name,
dtype=dtype,
dtype=loopy_type,
base_indices=(0,)*len(new_temp_shape),
shape=tuple(abm.non1_storage_shape),
address_space=temporary_address_space,
Expand All @@ -1060,6 +1074,7 @@ def add_assumptions(d):

temp_var = temp_var.copy(dtype=dtype)

assert isinstance(temp_var.shape, tuple)
if len(temp_var.shape) != len(new_temp_shape):
raise LoopyError("Existing and new temporary '%s' do not "
"have matching number of dimensions ('%d' vs. '%d') "
Expand Down

0 comments on commit 28b0eb5

Please sign in to comment.