Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT][FRONTEND] added support for tuples #5220

Draft
wants to merge 39 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
51ab367
progress
ptillet Sep 3, 2024
65da71f
.
ptillet Sep 4, 2024
746a2a3
prototype works
ptillet Sep 4, 2024
630ec6c
added test
ptillet Sep 5, 2024
f2439f9
fixup
ptillet Sep 5, 2024
d9af0ba
cleanup
ptillet Sep 6, 2024
f758ef2
.
ptillet Sep 6, 2024
1b58df2
progress
ptillet Sep 8, 2024
1558d6c
bugfix
ptillet Sep 8, 2024
812af43
progress
ptillet Sep 9, 2024
e226cfd
.
ptillet Oct 9, 2024
98c526e
Merge remote-tracking branch 'origin/main' into phil/tuple-support
ptillet Oct 9, 2024
8cee89d
.
ptillet Oct 11, 2024
627bef2
.
ptillet Oct 12, 2024
756d75a
.
ptillet Oct 12, 2024
2a86fb4
.
ptillet Oct 12, 2024
d614226
.
ptillet Oct 13, 2024
5d29bef
fails again?
ptillet Oct 13, 2024
a790867
more hacks
ptillet Oct 13, 2024
fa23bfc
giant mess; more tests pass
ptillet Oct 13, 2024
d88cca0
very hacky but tests pass; TO REFACTOR
ptillet Oct 13, 2024
e299bf2
.
ptillet Oct 15, 2024
fcae528
.
ptillet Oct 16, 2024
d0168c9
progress
ptillet Nov 16, 2024
0ba41ff
more progress
ptillet Nov 16, 2024
e7289dc
more progress
ptillet Nov 17, 2024
b7d8117
.
ptillet Nov 19, 2024
33505ac
more progress
ptillet Nov 21, 2024
3c08877
more fixes
ptillet Nov 21, 2024
dba9b2d
all tests pass
ptillet Nov 21, 2024
a35e89a
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Nov 21, 2024
ae2ebf6
Merge branch 'main' into phil/tuple-support-2
ptillet Nov 22, 2024
18f24ef
fixed TMA descriptors
ptillet Nov 22, 2024
bba29ae
.
ptillet Nov 22, 2024
67fc1b4
more fixes
ptillet Nov 22, 2024
04d463f
.
ptillet Nov 24, 2024
aa74737
.
ptillet Nov 24, 2024
6161e78
.
ptillet Nov 29, 2024
2ab9b39
more fixes
ptillet Nov 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ void init_triton_ir(py::module &&m) {
"Function argument index out of range");
return self.getArgument(idx);
})
.def("get_num_args", &FuncOp::getNumArguments)
.def(
"add_entry_block",
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },
Expand Down
8 changes: 4 additions & 4 deletions python/test/unit/language/test_compile_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 4)

triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))
triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constants={"N": 0}))

@triton.jit
def kernel2(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 8)

triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constants={"N": 1}))


@triton.jit
Expand Down Expand Up @@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
def kernel(a=1, B: tl.constexpr = ""):
pass

triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constants={'B': ""}))


def test_where_warning(fresh_triton_cache):
Expand Down Expand Up @@ -369,7 +369,7 @@ def dtype_kernel(dtype: tl.constexpr):
ctx = pytest.raises(CompilationError, match="")

with ctx as e:
triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype}))
triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constants={"dtype": dtype}))

if dtype not in supported_dtypes:
try:
Expand Down
98 changes: 98 additions & 0 deletions python/test/unit/language/test_tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import pytest
import triton
import triton.language as tl
import torch


@triton.jit
def _tuple_increment(values):
    # Add 1 to every element of the tuple and return it.
    # len(values) is known at compile time, so tl.static_range unrolls
    # the loop over the tuple's slots; values[i] = ... exercises tuple
    # __setitem__ inside a JIT'd function.
    for i in tl.static_range(len(values)):
        values[i] = values[i] + 1
    return values


@triton.jit
def _tuple_index_func(Ptrs, values):
    # Store values[i] through the pointer Ptrs[i] for each tuple slot.
    # Exercises indexing of a tuple received as a function argument.
    for i in tl.static_range(len(values)):
        tl.store(Ptrs[i], values[i])


@triton.jit
def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):
    # Kernel entry point.  The unused _0.._4 arguments (a mix of runtime
    # and constexpr parameters) pad the signature so that the tuple
    # arguments sit at non-trivial positions in the argument list.
    values = _tuple_increment(values)
    _tuple_index_func(Ptrs, values)


@pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
def test_index(size, device="cuda"):
    """Tuples of varying length can be passed to a kernel and indexed."""
    inputs = tuple(range(1, size + 1))
    outputs = tuple(torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(size))
    _tuple_index[(1, )](0, outputs, 0, inputs, 0, 0, 0)
    # The kernel increments each value before storing it.
    assert inputs == tuple(t.item() - 1 for t in outputs)


# ----


@triton.jit
def _tuple_assign(XPtrs, YPtrs, values):
    # assign from tuple: unpack a 2-tuple of pointers and a 2-tuple of
    # values, then store each value through the matching pointer
    X0, X1 = XPtrs
    x0, x1 = values
    tl.store(X0, x0)
    tl.store(X1, x1)
    # assign to tuple: re-pack unpacked names into fresh tuples (mixing
    # runtime values with the literal 10) and store via indexing
    Y0, Y1, Y2 = YPtrs
    Y = Y0, Y1, Y2
    y = x0, 10, x1
    tl.store(Y[0], y[0])
    tl.store(Y[1], y[1])
    tl.store(Y[2], y[2])


def test_assign(device="cuda"):
    """Tuple packing/unpacking inside a kernel writes the expected values."""
    values = (2., 3.)
    x_out = tuple(torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2))
    y_out = tuple(torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3))
    _tuple_assign[(1, )](x_out, y_out, values)
    # x holds the values stored via "assign from tuple" ...
    assert x_out[0] == values[0]
    assert x_out[1] == values[1]
    # ... and y those stored via "assign to tuple" (middle slot is the literal 10).
    assert y_out[0] == values[0]
    assert y_out[1] == 10
    assert y_out[2] == values[1]

# -------
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: add more unit tests


@triton.jit
def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
    # Callee used to check that nested tuples survive a function call:
    # tuple1 is expected to be of the form (scalar, (ptr, (scalar, ptr)))
    # as built by the caller.  Writes occupy Ptr[5..9].
    tl.store(Ptr + 5, cst2)
    tl.store(Ptr + 6, tuple1[0])
    tl.store(Ptr + 7, tl.load(tuple1[1][0]))
    tl.store(Ptr + 8, tuple1[1][1][0])
    tl.store(Ptr + 9, tl.load(tuple1[1][1][1]))

# test serialization/deserialization of tuple arguments in
# the frontend.
@triton.jit
def _tuple_serdes(Ptr, tuple1, cst1: tl.constexpr, val1, tuple2):
    # Exercises serialization/deserialization of tuple kernel arguments:
    # tuple1 is a nested (ptr, (scalar, ptr)) tuple, tuple2 a 1-tuple of a
    # pointer.  Writes occupy Ptr[0..4]; the _tuple_fn0 call below fills
    # Ptr[5..9] and additionally checks tuples passed through a call.
    tl.store(Ptr + 0, tl.load(tuple1[0]))
    tl.store(Ptr + 1, tuple1[1][0])
    tl.store(Ptr + 2, tl.load(tuple1[1][1]))
    tl.store(Ptr + 3, cst1 + val1)
    tl.store(Ptr + 4, tl.load(tuple2[0]))
    _tuple_fn0(Ptr, 15, (-1, tuple1))

def test_serdes(device="cuda"):
    """Nested tuple arguments are serialized/deserialized correctly.

    Launches ``_tuple_serdes`` (which in turn calls ``_tuple_fn0``) and
    checks every slot of the 10-element output buffer against the value
    the kernels store there:

      z[0] = *x0 = 8        z[5] = cst2 = 15
      z[1] = 1              z[6] = -1
      z[2] = *x1 = 12       z[7] = *x0 = 8
      z[3] = 20 + 1 = 21    z[8] = 1
      z[4] = *y0 = 10       z[9] = *x1 = 12
    """
    x0 = torch.tensor([8], dtype=torch.int32, device=device)
    x1 = torch.tensor([12], dtype=torch.int32, device=device)
    y0 = torch.tensor([10], dtype=torch.int32, device=device)
    # zeros (not empty) so a missing store cannot accidentally match
    z = torch.zeros((10, ), dtype=torch.int32, device=device)
    # we want to check that JIT specialization propagates to tuples:
    _tuple_serdes[(1, )](z, (x0, (1, x1)), 20, 1, (y0, ))
    expected = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], dtype=torch.int32, device=device)
    # previously this test only printed z and could never fail
    assert torch.equal(z, expected)


# function call (tuple argument)
# function call (tuple return value)
# __getitem__ and __setitem__
# assignment (into a tuple, from a tuple)
1 change: 0 additions & 1 deletion python/test/unit/runtime/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def walk_fn(op):
signature={
kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
for i, arg in enumerate(args)
if i not in kernel.constexprs
},
constants={kernel.arg_names[i]: arg
for i, arg in enumerate(args)
Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def kernel_sub(a, b, o, N: tl.constexpr):
src = ASTSource(
fn=kernel_sub,
constants={'N': 32},
signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"},
signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32", 'N': 'constexpr'},
attrs=attrs,
)
triton.compile(src=src, target=target)
Expand Down
1 change: 1 addition & 0 deletions python/test/unit/test_perf_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr)
"in_ptr2": "*fp16",
"in_ptr3": "*fp32",
"out_ptr0": "*fp16",
"XBLOCK": "constexpr",
},
constants={"XBLOCK": XBLOCK},
),
Expand Down
61 changes: 48 additions & 13 deletions python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,28 @@
import hashlib
import subprocess
import sysconfig

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
from types import ModuleType

def find_paths_if(iterable, pred):
    """Return the index paths of every element of ``iterable`` satisfying ``pred``.

    Walks ``iterable`` — an arbitrarily nested structure of lists/tuples —
    depth-first and collects, for each element where ``pred(elem)`` is true,
    the tuple of indices leading to it.  A container that itself satisfies
    ``pred`` is recorded and NOT descended into.  If ``iterable`` is a
    scalar, the result is ``[()]`` when it matches and ``[]`` otherwise.

    TODO: remove duplicate (the same helper exists elsewhere in the tree).

    :param iterable: nested list/tuple structure (or a scalar).
    :param pred: unary predicate applied to every node.
    :return: list of index-path tuples, in depth-first order.
    """

    def _is_container(x):
        return isinstance(x, (list, tuple))

    ret = []

    def _walk(node, path):
        if pred(node):
            # tuple(path) covers the single-index case too: ([i] -> (i,)).
            ret.append(tuple(path))
        elif _is_container(node):
            for idx, item in enumerate(node):
                _walk(item, path + [idx])

    if _is_container(iterable):
        _walk(iterable, [])
    else:
        # Scalar root: the (empty) path to it matches or nothing does.
        ret = [()] if pred(iterable) else []
    return ret
# Table that associates strings to AttrsDescriptor (sub)classes.
# In this way we can dynamically select the correct class
# constructor
Expand Down Expand Up @@ -52,7 +68,7 @@ class AttrsDescriptor:
`constant_properties`: a set containing the properties that can be used to determine if a parameter is constant

"""
__slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')
__slots__ = ('divisibility_16', 'equal_to_1', 'equal_to_none', 'arg_properties', 'property_values', 'constant_properties')

def __init__(self, params=None, values=None):
"""
Expand All @@ -67,6 +83,7 @@ def __init__(self, params=None, values=None):
# Default initialization
self.arg_properties = {}
self.property_values = {}
self.equal_to_none = {}
self.constant_properties = set()

self._add_common_properties(params, values)
Expand All @@ -86,18 +103,34 @@ def _add_common_properties(self, params, values):
assert (len(params) == len(values))

# Divisibility property
self.arg_properties["tt.divisibility"] = [
param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
]
divisibility_16 = []
for param, arg in zip(params, values):
if not AttrsDescriptor.is_divisible_by_16(arg) or \
param.do_not_specialize or \
param.do_not_specialize_on_alignment:
continue
paths = find_paths_if(arg, AttrsDescriptor.is_divisible_by_16)
divisibility_16 += [(param.num,) + x for x in paths]
self.arg_properties["tt.divisibility"] = divisibility_16

# Equal to 1 property
self.arg_properties["tt.equal_to"] = [
param.num
for param, arg in zip(params, values)
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
]

equal_to_1 = []
for param, arg in zip(params, values):
if not AttrsDescriptor.is_equal_to_1(arg) or param.do_not_specialize:
continue
paths = find_paths_if(arg, AttrsDescriptor.is_equal_to_1)
equal_to_1 += [(param.num,) + x for x in paths]
self.arg_properties["tt.equal_to"] = equal_to_1

# Equal to None property
equal_to_none = []
for param, arg in zip(params, values):
if arg is not None or param.do_not_specialize:
continue
paths = find_paths_if(arg, lambda v: v is None)
equal_to_none += [(param.num,) + x for x in paths]
self.equal_to_none = equal_to_none

def _add_backend_properties(self, params=None, values=None):
""" This method is for different subclasses to implement their own compile-time properties """
pass
Expand Down Expand Up @@ -130,6 +163,8 @@ def get_constants(self) -> Dict:
for prop_name in self.constant_properties:
for p in self.arg_properties.get(prop_name, []):
constants[p] = self.property_values[prop_name]
for v in self.equal_to_none:
constants[v] = None
return constants

def filter_out_constants(self):
Expand Down Expand Up @@ -166,7 +201,7 @@ def from_dict(data):
"""
attrs_descriptor = _descriptor_table[data["cls"]]()
for prop_name, param_ids in data["arg_properties"].items():
attrs_descriptor.arg_properties[prop_name] = param_ids
attrs_descriptor.arg_properties[prop_name] = list(map(tuple, param_ids))
attrs_descriptor._init_slots()
return attrs_descriptor

Expand Down
Loading