Skip to content

Commit

Permalink
Merge branch 'main' into dona-repr-keys
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer authored Nov 25, 2024
2 parents 25c5311 + bdb4aca commit 03e4901
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 100 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dependencies = [
"immutabledict",
"loopy>=2020.2",
"pytools>=2024.1.14",
"pymbolic>=2024.1",
"pymbolic>=2024.2",
"typing_extensions>=4",
]

Expand Down
20 changes: 10 additions & 10 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@
from typing_extensions import Self

import pymbolic.primitives as prim
from pymbolic import ArithmeticExpressionT, var
from pymbolic.typing import IntegerT, ScalarT, not_none
from pymbolic import ArithmeticExpression, var
from pymbolic.typing import Integer, Scalar, not_none
from pytools import memoize_method
from pytools.tag import Tag, Taggable

Expand Down Expand Up @@ -246,7 +246,7 @@ class EllipsisType(Enum):

# {{{ shape

ShapeComponent = Union[IntegerT, "Array"]
ShapeComponent = Union[Integer, "Array"]
ShapeType = tuple[ShapeComponent, ...]
ConvertibleToShape = ShapeComponent | Sequence[ShapeComponent]

Expand Down Expand Up @@ -401,7 +401,7 @@ def {cls.__name__}_hash(self):
# {{{ array interface

ConvertibleToIndexExpr = Union[int, slice, "Array", EllipsisType, None]
IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None]
IndexExpr = Union[Integer, "NormalizedSlice", "Array", None]
PyScalarType = type[bool] | type[int] | type[float] | type[complex]
DtypeOrPyScalarType = _dtype_any | PyScalarType

Expand Down Expand Up @@ -444,7 +444,7 @@ class NormalizedSlice:
"""
start: ShapeComponent
stop: ShapeComponent
step: IntegerT
step: Integer


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -834,7 +834,7 @@ def __repr__(self) -> str:
return Reprifier()(self)


ArrayOrScalar: TypeAlias = Array | ScalarT
ArrayOrScalar: TypeAlias = Array | Scalar

# }}}

Expand Down Expand Up @@ -1725,7 +1725,7 @@ def shape(self) -> ShapeType:
for i_basic_idx in i_basic_indices)

adv_idx_shape = get_shape_after_broadcasting([
cast(Array | IntegerT, not_none(self.indices[i_idx]))
cast(Array | Integer, not_none(self.indices[i_idx]))
for i_idx in i_adv_indices])

# type-ignored because mypy cannot figure out basic-indices only refer
Expand Down Expand Up @@ -1773,7 +1773,7 @@ def shape(self) -> ShapeType:
for i_basic_idx in i_basic_indices)

adv_idx_shape = get_shape_after_broadcasting([
cast(Array | IntegerT, not_none(self.indices[i_idx]))
cast(Array | Integer, not_none(self.indices[i_idx]))
for i_idx in i_adv_indices])

# type-ignored because mypy cannot figure out basic-indices only refer slices
Expand Down Expand Up @@ -2319,7 +2319,7 @@ def make_data_wrapper(data: DataInterface,

# {{{ full

def full(shape: ConvertibleToShape, fill_value: ScalarT | prim.NaN,
def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN,
dtype: Any = None, order: str = "C") -> Array:
"""
Returns an array of shape *shape* with all entries equal to *fill_value*.
Expand All @@ -2340,7 +2340,7 @@ def full(shape: ConvertibleToShape, fill_value: ScalarT | prim.NaN,
else:
fill_value = conv_dtype.type(fill_value)

return IndexLambda(expr=cast(ArithmeticExpressionT, fill_value),
return IndexLambda(expr=cast(ArithmeticExpression, fill_value),
shape=shape, dtype=conv_dtype,
bindings=immutabledict(),
tags=_get_default_tags(),
Expand Down
7 changes: 4 additions & 3 deletions pytato/cmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
from immutabledict import immutabledict

import pymbolic.primitives as prim
from pymbolic import ExpressionT, ScalarT, var
from pymbolic import Scalar, var
from pymbolic.typing import Expression

from pytato.array import (
Array,
Expand Down Expand Up @@ -94,7 +95,7 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...],

shape = None

sym_args: list[ExpressionT] = []
sym_args: list[Expression] = []
bindings: dict[str, Array] = {}
for index, inp in enumerate(inputs):
if isinstance(inp, Array):
Expand Down Expand Up @@ -232,7 +233,7 @@ def imag(x: ArrayOrScalar) -> ArrayOrScalar:
result_dtype = np.empty(0, dtype=x_dtype).real.dtype
else:
if np.isscalar(x):
return cast(ScalarT, x_dtype.type(0))
return cast(Scalar, x_dtype.type(0))
else:
assert isinstance(x, Array)
import pytato as pt
Expand Down
14 changes: 6 additions & 8 deletions pytato/loopy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from pymbolic import ArithmeticExpressionT, ExpressionT


__copyright__ = """
Copyright (C) 2021 Kaushik Kulkarni
Expand Down Expand Up @@ -41,7 +39,7 @@

import loopy as lp
import pymbolic.primitives as prim
from pymbolic.typing import IntegerT, not_none
from pymbolic.typing import ArithmeticExpression, Expression, Integer, not_none
from pytools import memoize_method

from pytato.array import (
Expand Down Expand Up @@ -114,7 +112,7 @@ def _result_names(self) -> frozenset[str]:
if lp_arg.is_output})

@memoize_method
def _to_pytato(self, expr: ScalarExpression) -> ExpressionT:
def _to_pytato(self, expr: ScalarExpression) -> Expression:
from pytato.scalar_expr import substitute
return substitute(expr, self.bindings)

Expand Down Expand Up @@ -329,8 +327,8 @@ def _get_val_in_bset(bset: isl.BasicSet, idim: int) -> ScalarExpression:

def solve_constraints(variables: Iterable[str],
parameters: Iterable[str],
constraints: Sequence[tuple[ArithmeticExpressionT,
ArithmeticExpressionT]],
constraints: Sequence[tuple[ArithmeticExpression,
ArithmeticExpression]],

) -> Mapping[str, ScalarExpression]:
"""
Expand Down Expand Up @@ -392,7 +390,7 @@ def _pt_var_to_global_namespace(name: str | None) -> str:
return f"_pt_{name}"


def _get_pt_dim_expr(dim: IntegerT | Array) -> ScalarExpression:
def _get_pt_dim_expr(dim: Integer | Array) -> ScalarExpression:
from pytato.scalar_expr import substitute
from pytato.utils import dim_to_index_lambda_components
dim_expr, dim_bnds = dim_to_index_lambda_components(dim)
Expand Down Expand Up @@ -449,7 +447,7 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel,

# }}}

constraints: list[tuple[ArithmeticExpressionT, ArithmeticExpressionT]] = []
constraints: list[tuple[ArithmeticExpression, ArithmeticExpression]] = []

# {{{ collect constraints from passed arguments

Expand Down
17 changes: 9 additions & 8 deletions pytato/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import numpy as np

import pymbolic.primitives as prim
from pymbolic.typing import IntegerT, ScalarT
from pymbolic import Scalar
from pymbolic.typing import Integer
from pytools import UniqueNameGenerator

from pytato.array import Array, IndexLambda
Expand All @@ -22,8 +23,8 @@

def _get_constant_padded_idx_lambda(
array: Array,
pad_widths: Sequence[tuple[IntegerT, IntegerT]],
constant_vals: Sequence[tuple[ScalarT, ScalarT]]
pad_widths: Sequence[tuple[Integer, Integer]],
constant_vals: Sequence[tuple[Scalar, Scalar]]
) -> IndexLambda:
"""
Internal routine used by :func:`pad` for constant-mode padding.
Expand Down Expand Up @@ -72,9 +73,9 @@ def _get_constant_padded_idx_lambda(

def _normalize_pad_width(
array: Array,
pad_width: IntegerT | Sequence[IntegerT],
) -> Sequence[tuple[IntegerT, IntegerT]]:
processed_pad_widths: list[tuple[IntegerT, IntegerT]]
pad_width: Integer | Sequence[Integer],
) -> Sequence[tuple[Integer, Integer]]:
processed_pad_widths: list[tuple[Integer, Integer]]

if isinstance(pad_width, INT_CLASSES):
processed_pad_widths = [(pad_width, pad_width)
Expand Down Expand Up @@ -118,7 +119,7 @@ def _normalize_pad_width(


def pad(array: Array,
pad_width: IntegerT | Sequence[IntegerT],
pad_width: Integer | Sequence[Integer],
mode: str = "constant",
**kwargs: Any) -> Array:
r"""
Expand Down Expand Up @@ -173,7 +174,7 @@ def pad(array: Array,

# {{{ normalize constant_values

processed_constant_vals: Sequence[tuple[ScalarT, ScalarT]]
processed_constant_vals: Sequence[tuple[Scalar, Scalar]]

try:
constant_vals = kwargs.pop("constant_values")
Expand Down
4 changes: 2 additions & 2 deletions pytato/raising.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from immutabledict import immutabledict

import pymbolic.primitives as p
from pymbolic.typing import ScalarT
from pymbolic.typing import Scalar

from pytato.array import Array, ArrayOrScalar, IndexLambda, ShapeType
from pytato.diagnostic import UnknownIndexLambdaExpr
Expand Down Expand Up @@ -47,7 +47,7 @@ class HighLevelOp:

@dataclass(frozen=True, eq=True, repr=True)
class FullOp(HighLevelOp):
fill_value: ScalarT
fill_value: Scalar


@unique
Expand Down
8 changes: 4 additions & 4 deletions pytato/reductions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from pymbolic import ArithmeticExpressionT


__copyright__ = """
Copyright (C) 2020 Andreas Kloeckner
Expand Down Expand Up @@ -30,6 +28,7 @@
THE SOFTWARE.
"""


from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import Any
Expand All @@ -38,6 +37,7 @@
from immutabledict import immutabledict

import pymbolic.primitives as prim
from pymbolic import ArithmeticExpression

from pytato.array import Array, ReductionDescriptor, ShapeType, make_index_lambda
from pytato.scalar_expr import INT_CLASSES, Reduce, ScalarExpression
Expand Down Expand Up @@ -190,8 +190,8 @@ def _normalize_reduction_axes(
def _get_reduction_indices_bounds(shape: ShapeType,
axes: tuple[int, ...],
) -> tuple[Sequence[prim.Variable],
Mapping[str, tuple[ArithmeticExpressionT,
ArithmeticExpressionT]]]:
Mapping[str, tuple[ArithmeticExpression,
ArithmeticExpression]]]:
"""
Given *shape* and reduction axes *axes*, produce a list of inames
``indices`` named appropriately for reduction inames.
Expand Down
Loading

0 comments on commit 03e4901

Please sign in to comment.