Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Summarize changes to support prediction #1

Open
wants to merge 221 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
221 commits
Select commit Hold shift + click to select a range
f5c377e
Add tag to store array creation traceback
matthiasdiener Mar 23, 2022
4c32cb6
don't make it a unique tag
matthiasdiener Mar 23, 2022
56fbf4c
adds a common _get_default_tags
kaushikcfd Mar 23, 2022
8fee6d4
Change back to UniqueTag
matthiasdiener Mar 23, 2022
da3e962
Merge branch 'default_tags' into tag_created_at
matthiasdiener Mar 23, 2022
bdff59c
use _get_default_tags
matthiasdiener Mar 23, 2022
6d18144
store a tupleized StackSummary
matthiasdiener Mar 24, 2022
30dfc49
Merge branch 'main' into tag_created_at
matthiasdiener Mar 24, 2022
36e720b
Merge branch 'main' into tag_created_at
matthiasdiener Mar 24, 2022
2c3faeb
work around mypy
matthiasdiener Mar 24, 2022
739f3d3
more line fixes
matthiasdiener Mar 24, 2022
897e39b
Merge branch 'main' into tag_created_at
matthiasdiener Mar 25, 2022
3ec977f
Merge branch 'main' into tag_created_at
matthiasdiener Mar 25, 2022
fdb2906
Merge branch 'main' into tag_created_at
matthiasdiener Mar 28, 2022
4fd3d64
use a class for the traceback instead of tuples
matthiasdiener Mar 28, 2022
7702550
also test to_stacksummary
matthiasdiener Mar 28, 2022
7a86557
flake8
matthiasdiener Mar 28, 2022
31bcda8
Add remove_tags_of_type
inducer Mar 28, 2022
5c32222
test_array_dot_repr: Remove CreatedAt tags before comparing
inducer Mar 28, 2022
2d0358f
only add CreatedAt in debug mode
matthiasdiener Mar 28, 2022
1c275fd
restructure test_created_at
matthiasdiener Mar 28, 2022
02362bf
make _PytatoStackSummary a dataclass
matthiasdiener Mar 28, 2022
7b1f7b8
add __repr__
matthiasdiener Mar 29, 2022
07b2fa1
fix 2 tests
matthiasdiener Mar 29, 2022
437954c
illustrate test failure with construct_intestine_graph
matthiasdiener Mar 29, 2022
f05592e
shorten traceback printing
matthiasdiener Mar 29, 2022
d040996
use separate field for CreatedAt
matthiasdiener Mar 30, 2022
91d0929
Merge branch 'main' into tag_created_at
matthiasdiener Mar 30, 2022
9fdd602
fix tests
matthiasdiener Mar 30, 2022
e606a48
fix doctest
matthiasdiener Mar 30, 2022
235d9a7
make it a tag again
matthiasdiener Mar 30, 2022
fc9a873
Merge branch 'main' into tag_created_at
matthiasdiener Mar 31, 2022
d80066f
use tooltip instead of table row
matthiasdiener Apr 1, 2022
ef3339f
force openmpi usage
matthiasdiener Apr 1, 2022
1ff1a2b
check for existing CreatedAt and make it a UniqueTag again
matthiasdiener Apr 1, 2022
f509f41
Merge branch 'main' into tag_created_at
matthiasdiener Apr 2, 2022
73d6ab6
Merge branch 'main' into tag_created_at
matthiasdiener Apr 4, 2022
638e6e6
Merge branch 'main' into tag_created_at
matthiasdiener May 13, 2022
c57a4a1
flake8
matthiasdiener May 16, 2022
4ae31b1
add simple equality test
matthiasdiener May 16, 2022
f559a59
lint fixes
matthiasdiener May 16, 2022
0794bdb
add InfoTag class and filter tags based on it
matthiasdiener May 16, 2022
43c83ec
fix doc
matthiasdiener May 16, 2022
c09bbf3
another doc fix
matthiasdiener May 16, 2022
cd67d68
use IgnoredForEqualityTag
matthiasdiener May 17, 2022
71dd791
UNDO BEFORE MERGE: use external project branches
matthiasdiener May 17, 2022
b4f8b82
Revert "UNDO BEFORE MERGE: use external project branches"
matthiasdiener May 18, 2022
bfb22ba
Revert "use IgnoredForEqualityTag"
matthiasdiener May 18, 2022
99ff0cd
rename InfoTag -> IgnoredForEqualityTag
matthiasdiener May 18, 2022
a818694
more stringent tests
matthiasdiener May 18, 2022
8a0a773
undo unnecessary test changes
matthiasdiener May 18, 2022
91fe92f
Revert "Revert "use IgnoredForEqualityTag""
matthiasdiener May 19, 2022
1111b79
simplify condition
matthiasdiener May 19, 2022
4d5e5e2
Merge branch 'main' into tag_created_at
matthiasdiener May 19, 2022
ff2582f
Revert "simplify condition"
matthiasdiener May 19, 2022
26c3590
bump pytools version + a few spelling fixes
matthiasdiener May 21, 2022
423e3fb
remove duplicated self.axes in hash()
matthiasdiener May 21, 2022
a3e6120
Merge branch 'main' into tag_created_at
matthiasdiener Jun 1, 2022
87606ed
use Taggable{__eq__,__hash__}
matthiasdiener Jun 2, 2022
73c7f77
add another test
matthiasdiener Jun 4, 2022
b84b66e
add vis test
matthiasdiener Jun 4, 2022
26824fa
Merge branch 'main' into tag_created_at
matthiasdiener Jun 20, 2022
6c653bf
make _PytatoFrameSummary, _PytatoStackSummary undocumented
matthiasdiener Jun 20, 2022
02dd6f5
use Taggable.__hash__ for tags in Array.__hash__
matthiasdiener Jun 20, 2022
05b5383
Merge branch 'main' into tag_created_at
matthiasdiener Jun 23, 2022
38d475e
Merge branch 'main' into tag_created_at
matthiasdiener Aug 1, 2022
6fe4129
Merge branch 'main' into tag_created_at
matthiasdiener Aug 26, 2022
88887bf
Merge branch 'main' into tag_created_at
matthiasdiener Oct 4, 2022
15e78d6
Merge branch 'main' into tag_created_at
matthiasdiener Oct 20, 2022
3f084ba
Merge branch 'main' into tag_created_at
matthiasdiener Mar 24, 2023
7c93707
change dataclass to attrs
matthiasdiener Mar 28, 2023
dd9916b
flake8
matthiasdiener Mar 28, 2023
61c029c
Taggable.__eq__
matthiasdiener Mar 28, 2023
3cf1559
add Array.tagged()
matthiasdiener Mar 29, 2023
9d528c3
Merge branch 'main' into tag_created_at
matthiasdiener Apr 28, 2023
f9251d2
Merge branch 'main' into tag_created_at
matthiasdiener May 19, 2023
44d1c34
restrict to DEBUG_ENABLED
matthiasdiener May 19, 2023
a150c79
force DEBUG_ENABLED for test
matthiasdiener May 19, 2023
0769067
CHERRY-PICK: Preserve High-Level Info in the Pymbolic expressions
kaushikcfd Nov 17, 2021
8b3a13b
[CHERRY-PICK]: Call BranchMorpher after dw deduplication
kaushikcfd May 26, 2022
945a147
Merge branch 'main' into production-mrgup
MTCam Jun 27, 2023
818aec4
Merge branch 'main' into production-pilot
MTCam Jul 19, 2023
f4123b4
Update to inducer@main
MTCam Jul 25, 2023
afe340f
Merge branch 'main' into production-pilot
MTCam Jul 27, 2023
c0ea704
Merge branch 'main' into production-pilot
MTCam Jul 28, 2023
3d8ba63
Merge branch 'main' into production-pilot
MTCam Jul 31, 2023
1435a00
Merge remote-tracking branch 'origin/production-pilot' into productio…
MTCam Jul 31, 2023
7bbafe4
Merge branch 'main' into production-pilot
MTCam Aug 3, 2023
8665140
Merge branch 'main' into production-pilot
MTCam Aug 4, 2023
8e20d15
Define __attrs_post_init__ only if __debug__, for all Array classes
inducer Aug 4, 2023
9a4e01d
Merge remote-tracking branch 'inducer/attrs-post-init-only-if-debug' …
MTCam Aug 4, 2023
229f0f2
Merge branch 'main' into production-pilot
MTCam Aug 9, 2023
efcae65
First shot at implementing 'F' ordered array reshapes
a-alveyblanc Sep 11, 2023
cc8f07f
Resolve merge conflicts
a-alveyblanc Sep 11, 2023
35c6d1f
Remove restriction on reshape order
a-alveyblanc Sep 11, 2023
e8e11ce
Merge branch 'main' into production-pilot
MTCam Sep 18, 2023
a8ae5e2
Merge branch 'inducer:main' into implement-f-ordered-reshapes
a-alveyblanc Oct 8, 2023
351bb6f
Merge branch 'main' into production-pilot
MTCam Oct 10, 2023
ad8ff90
Merge branch 'main' into tag_created_at
matthiasdiener Oct 14, 2023
86233c6
work around mypy/attrs issue
matthiasdiener Oct 14, 2023
4b2d2cf
Merge branch 'attrs-mypy' into tag_created_at
matthiasdiener Oct 14, 2023
3b6fdad
fix for fields
matthiasdiener Oct 14, 2023
f6b8d98
Merge remote-tracking branch 'addison/implement-f-ordered-reshapes' i…
MTCam Oct 24, 2023
148527f
Merge branch 'main' into production-pilot
MTCam Oct 26, 2023
a64ea5a
Merge branch 'main' into production-pilot
MTCam Nov 1, 2023
8a390a5
Update comments a little
MTCam Nov 1, 2023
87efc4f
Merge branch 'main' into production-pilot
MTCam Nov 2, 2023
8dc06cf
Merge branch 'main' into production-pilot
MTCam Nov 8, 2023
870849a
attempt to fix tag issue
MTCam Nov 9, 2023
060f864
number_distributed_tags: non-set, non-sorted numbering
matthiasdiener Nov 9, 2023
65d0142
make the test a bit more difficult
matthiasdiener Nov 9, 2023
41a6998
Merge remote-tracking branch 'inducer/deterministic-mpi_tag-v2' into …
MTCam Nov 9, 2023
90954ea
Merge branch 'main' into production-pilot
MTCam Nov 14, 2023
9244899
Merge branch 'main' into tag_created_at
matthiasdiener Nov 14, 2023
3ebfcfd
undo mypy ignores
matthiasdiener Nov 14, 2023
eb1c052
rewrite to use a new field in Array, non_equality_tags
matthiasdiener Nov 14, 2023
a5cec50
misc fixes
matthiasdiener Nov 14, 2023
d9898c9
undo some unecessary changes
matthiasdiener Nov 15, 2023
c5c8920
more misc fixes
matthiasdiener Nov 15, 2023
4ec3cbf
copymapper, tests
matthiasdiener Nov 15, 2023
176595d
explicitly enable/disable traceback
matthiasdiener Nov 17, 2023
40557e9
add hash test
matthiasdiener Nov 17, 2023
5240495
undo more unnecessary changes
matthiasdiener Nov 17, 2023
6e047f4
Merge branch 'main' into tag_created_at
matthiasdiener Nov 17, 2023
48b1723
Merge branch 'main' into tag_created_at
matthiasdiener Nov 21, 2023
9338f0b
more lint fixes
matthiasdiener Nov 21, 2023
f5cb92f
run all examples, fix demo_distributed_node_duplication
matthiasdiener Nov 21, 2023
36166c6
enable CreatedAt for distributed nodes
matthiasdiener Nov 21, 2023
d110d0f
Merge branch 'main' into tag_created_at
matthiasdiener Nov 22, 2023
65f7317
Merge branch 'main' into tag_created_at
matthiasdiener Nov 26, 2023
c377937
Merge branch 'main' into tag_created_at
matthiasdiener Nov 28, 2023
ea93dc1
Merge branch 'main' into production-pilot
MTCam Nov 29, 2023
4c3b06a
undo MPI tag ordering
matthiasdiener Nov 29, 2023
db9f5c9
Merge branch 'main' into production-pilot
MTCam Jan 8, 2024
50bea3e
Merge branch 'main' into tag_created_at
matthiasdiener Jan 18, 2024
708114f
Merge branch 'production' into merge-addison-with-production
MTCam Feb 2, 2024
06503b1
get precise traceback of array creation
majosm Feb 2, 2024
ab87fbf
partialmethod doesn't introduce a stack frame
majosm Feb 2, 2024
d17db17
Merge branch 'main' into tag_created_at
matthiasdiener Feb 6, 2024
c30a320
Merge branch 'main' into production-pilot
MTCam Feb 6, 2024
d8df5f8
add support for make_distributed_send_ref_holder
matthiasdiener Feb 6, 2024
dee201f
Merge branch 'main' into production-pilot
MTCam Feb 7, 2024
dd7b288
Merge branch 'main' into tag_created_at
matthiasdiener Feb 7, 2024
49be05a
Merge branch 'main' into production-pilot
MTCam Feb 13, 2024
f0d52aa
Merge remote-tracking branch 'inducer/tag_created_at' into pytato-arr…
MTCam Feb 13, 2024
a3feae3
Merge branch 'main' into tag_created_at
matthiasdiener Feb 13, 2024
e1b9181
add to MPMSMaterializer
matthiasdiener Feb 16, 2024
74965e4
Merge remote-tracking branch 'inducer/tag_created_at' into pytato-arr…
MTCam Feb 16, 2024
be9dcdd
Spew array tracing to stdout.
MTCam Mar 2, 2024
c04d053
Merge branch 'main' into production-pilot
MTCam Mar 2, 2024
ba74f02
Merge branch 'main' into production-pilot
MTCam Mar 6, 2024
eade18d
Merge branch 'production' into array-tracing
MTCam Mar 7, 2024
3856ab7
Merge remote-tracking branch 'majosm/tag_created_at-precise-tb' into …
MTCam Mar 7, 2024
ad0aa4c
Get precise traceback of array creation (#480)
majosm Mar 7, 2024
655db9a
Merge branch 'main' into tag_created_at
matthiasdiener Mar 7, 2024
331dff1
Merge remote-tracking branch 'inducer/tag_created_at' into precise-ar…
MTCam Mar 7, 2024
6e879ad
Merge branch 'main' into production-pilot
MTCam Mar 24, 2024
0f5680b
Merge with inducer/main
MTCam Apr 4, 2024
ab5728e
Disable assert non_equality_tag
MTCam Apr 11, 2024
2b3eed8
Merge branch 'main' into production-pilot
MTCam Apr 15, 2024
69042c3
Merge branch 'main' into production-pilot
MTCam May 11, 2024
38e4332
add PytatoKeyBuilder
matthiasdiener Sep 25, 2023
4dd3250
mypy fixes
matthiasdiener Sep 25, 2023
970e7bb
support TaggableCLArray, Subscript
matthiasdiener Sep 28, 2023
95dec09
CL Array, function
matthiasdiener Sep 28, 2023
2ac10ee
add prim.Variable
matthiasdiener Feb 5, 2024
62a13ae
fixes to ndarray, pymb expressions
matthiasdiener Feb 5, 2024
b8e04bf
flake8
matthiasdiener Feb 5, 2024
ad9aa28
improve test
matthiasdiener Feb 5, 2024
60d8e41
add full invocation test
matthiasdiener Feb 5, 2024
9d45e65
lint fixes
matthiasdiener Feb 5, 2024
08be380
add missing pymbolic expressions
matthiasdiener Feb 5, 2024
058f6f9
flake8
matthiasdiener Feb 6, 2024
70b25c2
Merge branch 'main' into production-pilot
MTCam Jun 3, 2024
a9d33fc
Merge branch 'main' into production-pilot
MTCam Jun 11, 2024
352bab6
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jun 13, 2024
0360e21
remove update_for_function (now handled directly by pytools)
matthiasdiener Jun 13, 2024
8e3277c
Merge remote-tracking branch 'refs/remotes/origin/PytatoKeyBuilder' i…
matthiasdiener Jun 13, 2024
6c7c786
Merge branch 'PytatoKeyBuilder' into production-pilot
matthiasdiener Jun 13, 2024
2e9aae3
Merge branch 'main' into production-pilot
MTCam Jul 2, 2024
52643d3
Merge branch 'main' into production-pilot-update
MTCam Jul 17, 2024
32124fa
Merge branch 'main' into production-pilot-update
MTCam Jul 18, 2024
fefadcf
Merge branch 'main' into production-pilot-update
MTCam Jul 24, 2024
3364a4f
working pass 1
matthiasdiener Jul 25, 2024
ef7ea0b
cleanups
matthiasdiener Jul 25, 2024
817b255
enable determinism test
matthiasdiener Jul 25, 2024
f3f3c7d
eliminate _OrderedSets
matthiasdiener Jul 25, 2024
8bf2daf
misc improvements
matthiasdiener Jul 25, 2024
5d906b5
revert change to SubsetDependencyMapper
matthiasdiener Jul 25, 2024
142c8e6
some mypy fixes
matthiasdiener Jul 25, 2024
4e0e174
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Aug 12, 2024
bd70620
ruff
matthiasdiener Aug 12, 2024
076a76e
replace orderedsets with unique tuples in DirectPredecessorsGetter
matthiasdiener Aug 13, 2024
ea1462c
mypy fixes
matthiasdiener Aug 14, 2024
168ef53
remove unnecesary cast
matthiasdiener Aug 14, 2024
d711989
adjust comment
matthiasdiener Aug 14, 2024
679f5cd
Revert "Implement numpy 2 type promotion"
matthiasdiener Aug 20, 2024
2a79348
Merge branch 'main' into production-pilot-up2date
MTCam Aug 29, 2024
c4459d8
Merge branch 'production' into prod-revert-update
MTCam Aug 30, 2024
c976c23
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Sep 4, 2024
fde8f77
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Sep 12, 2024
e15b647
Merge branch 'main' into production-pilot
MTCam Sep 13, 2024
56ae346
Merge branch 'main' into production-pilot
MTCam Sep 23, 2024
a6a91c6
Fix a merge fail
MTCam Sep 26, 2024
3b7385e
Merge with main
MTCam Sep 26, 2024
7e03b35
Merge branch 'main' into deterministic-fdp-nonset
matthiasdiener Sep 27, 2024
5847800
performance fix
matthiasdiener Sep 27, 2024
7dd83bb
switch to dicts
matthiasdiener Sep 27, 2024
1ea962c
more dict usage
matthiasdiener Sep 27, 2024
f166cd7
Merge branch 'deterministic-fdp-nonset' into production-pilot
matthiasdiener Sep 30, 2024
345cde7
Merge branch 'main' into production-pilot
MTCam Oct 13, 2024
6e58548
Merge branch 'main' into production-pilot
MTCam Oct 21, 2024
b4f7230
Merge branch 'main' into production-pilot
MTCam Nov 1, 2024
afb59ca
Merge branch 'main' into production-pilot
MTCam Nov 11, 2024
99f4d10
Import union
MTCam Nov 11, 2024
cede10c
Use IntegralT --> IntegerT
MTCam Nov 11, 2024
02b1980
Use Scalar --> ScalarT
MTCam Nov 11, 2024
b820049
Disable update_for_pymbolic_expression
MTCam Nov 11, 2024
1337590
remove duplicate Hashable
matthiasdiener Nov 12, 2024
2c001d6
add missing import
matthiasdiener Nov 12, 2024
39ba8eb
Merge branch 'main' into production-pilot
MTCam Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,3 @@ Pytato is written to pose no particular restrictions on the version of numpy
used for execution. To use mypy-based type checking on Pytato itself or
packages using Pytato, numpy 1.20 or newer is required, due to the
typing-based changes to numpy in that release.

Furthermore, pytato now uses type promotion rules based on those in
`numpy <https://numpy.org/devdocs/numpy_2_0_migration_guide.html#changes-to-numpy-data-type-promotion>`__ that should result in the same
data types as the currently installed version of numpy.
94 changes: 69 additions & 25 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method
from loopy.tools import LoopyKeyBuilder

from pytato.array import (
Array,
Expand Down Expand Up @@ -326,37 +327,37 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], []]):

We only consider the predecessors of a nodes in a data-flow sense.
"""
def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]:
return frozenset({dim for dim in shape if isinstance(dim, Array)})
def _get_preds_from_shape(self, shape: ShapeType) -> dict[Array, None]:
return dict.fromkeys(dim for dim in shape if isinstance(dim, Array))

def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]:
return (frozenset(expr.bindings.values())
def map_index_lambda(self, expr: IndexLambda) -> dict[Array, None]:
return (dict.fromkeys(expr.bindings.values())
| self._get_preds_from_shape(expr.shape))

def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
def map_stack(self, expr: Stack) -> dict[Array, None]:
return (dict.fromkeys(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]:
return (frozenset(expr.arrays)
def map_concatenate(self, expr: Concatenate) -> dict[Array, None]:
return (dict.fromkeys(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]:
return (frozenset(expr.args)
def map_einsum(self, expr: Einsum) -> dict[Array, None]:
return (dict.fromkeys(expr.args)
| self._get_preds_from_shape(expr.shape))

def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]:
def map_loopy_call_result(self, expr: NamedArray) -> dict[Array, None]:
from pytato.loopy import LoopyCall, LoopyCallResult
assert isinstance(expr, LoopyCallResult)
assert isinstance(expr._container, LoopyCall)
return (frozenset(ary
return (dict.fromkeys(ary
for ary in expr._container.bindings.values()
if isinstance(ary, Array))
| self._get_preds_from_shape(expr.shape))

def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]:
return (frozenset([expr.array])
| frozenset(idx for idx in expr.indices
def _map_index_base(self, expr: IndexBase) -> dict[Array, None]:
return (dict.fromkeys([expr.array])
| dict.fromkeys(idx for idx in expr.indices
if isinstance(idx, Array))
| self._get_preds_from_shape(expr.shape))

Expand All @@ -365,34 +366,34 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]:
map_non_contiguous_advanced_index = _map_index_base

def _map_index_remapping_base(self, expr: IndexRemappingBase
) -> frozenset[ArrayOrNames]:
return frozenset([expr.array])
) -> dict[ArrayOrNames, None]:
return dict.fromkeys([expr.array])

map_roll = _map_index_remapping_base
map_axis_permutation = _map_index_remapping_base
map_reshape = _map_index_remapping_base

def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]:
def _map_input_base(self, expr: InputArgumentBase) -> dict[Array, None]:
return self._get_preds_from_shape(expr.shape)

map_placeholder = _map_input_base
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]:
def map_distributed_recv(self, expr: DistributedRecv) -> dict[Array, None]:
return self._get_preds_from_shape(expr.shape)

def map_distributed_send_ref_holder(self,
expr: DistributedSendRefHolder
) -> frozenset[ArrayOrNames]:
return frozenset([expr.passthrough_data])
) -> dict[ArrayOrNames, None]:
return dict.fromkeys([expr.passthrough_data])

def map_call(self, expr: Call) -> frozenset[ArrayOrNames]:
return frozenset(expr.bindings.values())
def map_call(self, expr: Call) -> dict[ArrayOrNames, None]:
return dict.fromkeys(expr.bindings.values())

def map_named_call_result(
self, expr: NamedCallResult) -> frozenset[ArrayOrNames]:
return frozenset([expr._container])
self, expr: NamedCallResult) -> dict[ArrayOrNames, None]:
return dict.fromkeys([expr._container])


# }}}
Expand Down Expand Up @@ -565,4 +566,47 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int:

# }}}


# {{{ PytatoKeyBuilder

class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc]
"""A custom :class:`pytools.persistent_dict.KeyBuilder` subclass
for objects within :mod:`pytato`.
"""

def update_for_ndarray(self, key_hash: Any, key: Any) -> None:
self.rec(key_hash, key.data.tobytes())

def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None:
self.rec(key_hash, key.get())

def update_for_Array(self, key_hash: Any, key: Any) -> None:
# CL Array
self.rec(key_hash, key.get())

# update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
# update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
# update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
# update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
# update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501
# update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815
# update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501

# }}}

# vim: fdm=marker
65 changes: 43 additions & 22 deletions pytato/array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from traceback import FrameSummary, StackSummary


__copyright__ = """
Expand Down Expand Up @@ -296,19 +297,25 @@ def normalize_shape_component(
# {{{ array interface

ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType]
IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None]
IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None, EllipsisType]
DtypeOrScalar = Union[_dtype_any, ScalarT]
ArrayOrScalar = Union["Array", ScalarT]
PyScalarType = type[bool] | type[int] | type[float] | type[complex]
DtypeOrPyScalarType = _dtype_any | PyScalarType


def _np_result_dtype(
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
# https://github.com/numpy/numpy/issues/19302
def _np_result_type(
# actual dtype:
#*arrays_and_dtypes: Union[np.typing.ArrayLike, np.typing.DTypeLike],
# our dtype:
*arrays_and_dtypes: DtypeOrScalar,
) -> np.dtype[Any]:
return np.result_type(*arrays_and_dtypes)


def _truediv_result_type(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]:
dtype = _np_result_dtype(*dtypes)
def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype[Any]:
dtype = _np_result_type(arg1, arg2)
# See: test_true_divide in numpy/core/tests/test_ufunc.py
# pylint: disable=no-member
if dtype.kind in "iu":
Expand Down Expand Up @@ -571,11 +578,12 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array:

def _binary_op(
self,
op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression],
op: Callable[[ScalarExpression, ScalarExpression],
ScalarExpression],
other: ArrayOrScalar,
get_result_type: Callable[
[ArrayOrScalar, ArrayOrScalar],
np.dtype[Any]] = _np_result_dtype,
np.dtype[Any]] = _np_result_type,
reverse: bool = False,
cast_to_result_dtype: bool = True,
is_pow: bool = False,
Expand Down Expand Up @@ -632,21 +640,33 @@ def _unary_op(self, op: Any) -> Array:
non_equality_tags=_get_created_at_tag(),
var_to_reduction_descr=immutabledict())

__mul__ = partialmethod(_binary_op, operator.mul)
__rmul__ = partialmethod(_binary_op, operator.mul, reverse=True)
# NOTE: Initializing the expression to "prim.Product(expr1, expr2)" is
# essential as opposed to performing "expr1 * expr2". This is to account
# for pymbolic's implementation of the "*" operator which might not
# instantiate the node corresponding to the operation when one of
# the operands is the neutral element of the operation.
#
# For the same reason 'prim.(Sum|FloorDiv|Quotient)' is preferred over the
# python operators on the operands.

__add__ = partialmethod(_binary_op, operator.add)
__radd__ = partialmethod(_binary_op, operator.add, reverse=True)
__mul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)))
__rmul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)),
reverse=True)

__sub__ = partialmethod(_binary_op, operator.sub)
__rsub__ = partialmethod(_binary_op, operator.sub, reverse=True)
__add__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)))
__radd__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)),
reverse=True)

__floordiv__ = partialmethod(_binary_op, operator.floordiv)
__rfloordiv__ = partialmethod(_binary_op, operator.floordiv, reverse=True)
__sub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)))
__rsub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)),
reverse=True)

__truediv__ = partialmethod(_binary_op, operator.truediv,
__floordiv__ = partialmethod(_binary_op, prim.FloorDiv)
__rfloordiv__ = partialmethod(_binary_op, prim.FloorDiv, reverse=True)

__truediv__ = partialmethod(_binary_op, prim.Quotient,
get_result_type=_truediv_result_type)
__rtruediv__ = partialmethod(_binary_op, operator.truediv,
__rtruediv__ = partialmethod(_binary_op, prim.Quotient,
get_result_type=_truediv_result_type, reverse=True)

__mod__ = partialmethod(_binary_op, operator.mod)
Expand Down Expand Up @@ -1397,7 +1417,7 @@ class Stack(_SuppliedAxesAndTagsMixin, Array):

@property
def dtype(self) -> np.dtype[Any]:
return _np_result_dtype(*(arr.dtype for arr in self.arrays))
return _np_result_type(*(arr.dtype for arr in self.arrays))

@property
def shape(self) -> ShapeType:
Expand Down Expand Up @@ -1430,7 +1450,7 @@ class Concatenate(_SuppliedAxesAndTagsMixin, Array):

@property
def dtype(self) -> np.dtype[Any]:
return _np_result_dtype(*(arr.dtype for arr in self.arrays))
return _np_result_type(*(arr.dtype for arr in self.arrays))

@property
def shape(self) -> ShapeType:
Expand Down Expand Up @@ -1547,6 +1567,7 @@ class Reshape(_SuppliedAxesAndTagsMixin, IndexRemappingBase):

if __debug__:
def __attrs_post_init__(self) -> None:
# assert self.non_equality_tags
super().__attrs_post_init__()

@property
Expand Down Expand Up @@ -2058,9 +2079,9 @@ def reshape(array: Array, newshape: int | Sequence[int],
*and* the output array are linearized according to this order
and 'matched up'.

Groups are found by multiplying axis lengths on the input and output side,
a matching input/output group is found once adding an input or axis to the
group makes the two products match.
Groups are found by multiplying axis lengths on the input and output
side, a matching input/output group is found once adding an input or
axis to the group makes the two products match.

The semantics are identical to :func:`numpy.reshape`.

Expand Down
Loading
Loading