[FRONTEND][WIP][RFC] Rewrite AST conversion to improve metaprogramming #5284
+1,589
−6
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Problem Statement
The current limitations of metaprogramming in Triton have led major users, such as Torch Inductor, to resort to using string-based templating. This RFC aims to address some of these limitations by extending Triton's metaprogramming capabilities.
I also found several performance issues (like backtracking codegen) with the current code generator that I intend to fix.
Current metaprogramming limitations
Except for simple assignments that are marked constexpr, if conditions and simple loops it's not possible to embed python expressions inside triton.
Current design relies on interpreting python expressions inside
CodeGenerator
. This approach is inherently limited because it's not possible to have good metaprogramming support without building a full python interpreter inside this class.This proposal also makes it possible to use
while
loops and usefor
loops with arbitrary iterators for metaprogramming.Proposal Overview
I propose that instead of converting the Python AST directly to Triton IR, we make a code generator generator that for a given Triton AST generates a function of tensor argument types and constant args that returns a Triton IR function as a result. I also propose a technique to differentiate Triton expressions from python metaprogramming expressions.
This approach allows you to embed any Python expression you want inside Triton.
Input Triton Function
Resulting Function Generator for the Triton Code(exec'ed)
Code Generator generation from Triton AST
At this stage we process the Python AST and generate a new python function that will generate the triton IR. We also do loop-carried variables analysis to later make it easier to construct the SSA correctly.
Separating Triton and Python expressions from each other
To distinguish between Triton expressions and Python metaprogramming expressions, we will use the following rules:
Triton Function definition arguments: We assume all arguments not marked as
tl.constexpr
are triton variables.Binary expressions: If the left or right part of a binary expression is a Triton expression, it's assumed to be a Triton expression.
Control flow:
if
orwhile
blocks use Triton expressions as conditions, these are interpreted as Triton control flow blocks.for
loops that iterate over Triton iterables, the loop is considered a Tritonfor
loop.Function Calls: Function calls that are going to Triton builtins and other Triton functions are considered to be Triton expressions.
One exception to this rule is
min
andmax
, for those functions we look at the arguments and assume the expression is a Triton expression if any of the arguments is a trition expression.We use the global scope of the function to resolve things like
tl.full((1,), 1.0, tl.float32)
thePseudoInterpreter
class uses the global scope to resolve the called function. Limitations of this approach, which I think won't affect backwards compatibility, are discussed later.For builtins we inject the
_builder
keyword argument to the call (note: the PoC does not currently support _generator arg, this breaks reductions. This limitation will be addressed later)name = ...
. Any more complex expressions are considered Python metaprogramming expressions. A simple assignment is considered a Triton assignment if one of the following is true:name
is recognized as a Triton variable.Most of these rules are implemented in the
ExpressionAnalyser
.Kernel Launch performance
Triton must generate different kernels for different constant expressions and call argument types. Generating a function to generate the IR moves some work from kernel launch time to code initialization time. Python Bytecode interpretation will have better performance compared to AST based interpretation done by the older approach.
Also, old code generator had a bad backtracking behaviour that the new code generator fixes.
Old code:
Old generator does this to find loop carried variables and construct SSA correctly. Instead of compiling the loop body twice my approach patches the generated AST after compiling the loop body.
Return support is not complete yet but, I intend to also add caching to ContainsReturnChecker.
Status of Implementation
Along side this RFC, a PR with a PoC implementation of my ideas is included, I believe it will be enough to demonstrate the bulk of my ideas.
You can try the PoC implementation like this:
I already invested a lot of time without getting any feedback from the community.
I am already aware that there are some features that are not supported in the experimental frontend, I don't have a complete list of missing features as of now but here are some known ones:
Discussion Questions
I would like some feedback from community about overriden builtins. My goal with this proposal was to turn Triton into a superset of Python, but triton overriding the behaviours of some builtins makes that k.
As noted earlier we have special rules for
min
andmax
. We can implemenet a similar rule forprint
(by default overriden with device_print) but since it has side effects (printing to the console) it matters if we run it in code generation time or code execution time (eg, in the gpu)Functionality of
range
is also overriden by triton. Assuming allrange
s with non-Triton arguments are python expressions would be like fully unrolling them, which would not be desirable.