Softmax use mask #644
Commits on Jun 19, 2024
-
[BACKEND] Update LLVM to llvm/llvm-project@657ec7320d8a (triton-lang#4147)
Upgrading the LLVM repo again, because we need a feature that was recently submitted in llvm/llvm-project#95057. Changes made: `MathExtras` has been merged with its LLVM version, so `mlir::ceilDiv` had to be replaced with `llvm::divideCeilSigned`.
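The semantics of signed ceiling division can be illustrated with a small Python sketch (the function name mirrors LLVM's `llvm::divideCeilSigned` for readability; this is an illustration, not the LLVM implementation):

```python
def divide_ceil_signed(numerator: int, denominator: int) -> int:
    """Ceiling division for signed integers, mirroring the behavior of
    LLVM's llvm::divideCeilSigned (illustrative sketch)."""
    if denominator == 0:
        raise ZeroDivisionError("denominator must be nonzero")
    # Python's // floors toward negative infinity, so negating twice
    # turns floor division into ceiling division for any sign combination.
    return -(-numerator // denominator)

print(divide_ceil_signed(7, 2))    # 4
print(divide_ceil_signed(-7, 2))   # -3
```

The double-negation trick works uniformly for positive and negative operands, which is what distinguishes the signed variant from a naive `(a + b - 1) // b`.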
(commit 3e233d7)
[Pipeliner] NFC: Expose Pipeliner infrastructure for use by other target backends (triton-lang#4155)
Non-functional changes to expose the `lib/Dialect/TritonGPU/Transforms/Pipeliner` infrastructure for use by other target backends.
(commit 6f6d032)
Support performance warning (triton-lang#3922)
This commit adds a performance warning when MMA v3 is not selected for `tl.dot` on Hopper. For the added test case, we get:

```
test-warning.py:24:18: remark: Warning: can't use MMA V3 for the dot op
    c = tl.dot(a, b)
        ^
test-warning.py:24:18: note: see current operation: %39 = tt.dot %37, %38, %cst, inputPrecision = tf32 :
```
(commit 75b0321)
- (commit a06add0)
[AMD][NFC] Refactor AccelerateAMDMatmul operand legalization (triton-lang#4136)
- Choose the proper operand configuration according to the number of conversions, where possible;
- Get rid of the complicated logic for finding an operand config;
- Remove the helper `supportWMMA()` to get rid of implicit logical dependencies with AccelerateAMDMatmul.
Signed-off-by: Ilya Veselov <[email protected]>
(commit f04df24)
Commits on Jun 20, 2024
-
Tiny change: Added/Improved error message in visit_For (triton-lang#4171)
While developing a kernel, I was given the error message "AssertionError()" without much helpful context on how to proceed with debugging. I could only solve it by understanding that part of the Triton source code, which took half a day. That's why I'm (1) adding an error message to this part of the code, and (2) making the error message above it clearer (like it is in visit_While). This should allow end users to debug this error without having to dive into the Triton source code.
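The general idea can be sketched in Python: attach context to assertions so a failure is self-explanatory instead of a bare "AssertionError()". The function and argument names below are illustrative, not Triton's actual `visit_For` code:

```python
def check_loop_bounds(node_name: str, num_bounds: int) -> None:
    # A bare `assert cond` surfaces as an unhelpful "AssertionError()".
    # Attaching a message tells the user what went wrong and where.
    # (Hypothetical checker, not the actual Triton frontend code.)
    assert num_bounds in (1, 2, 3), (
        f"{node_name}: expected 1 to 3 loop bounds "
        f"(stop / start,stop / start,stop,step), got {num_bounds}"
    )

check_loop_bounds("visit_For", 2)  # passes silently
try:
    check_loop_bounds("visit_For", 5)
except AssertionError as e:
    print(e)
```

The error message now names the node and the expected shapes, which is exactly the kind of context the commit adds.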
(commit a5b3783)
[FRONTEND] Wrap experimental TMA descriptor creation into a helper (triton-lang#4179)
We need to potentially flush the TMA cache when reusing TMA memory. To make this safe, we flush the cache for every TMA descriptor created.
(commit 416600a)
- (commit cf413b8)
Commits on Jun 21, 2024
-
Use custom stacktrace signal handler w/ python interrupt (triton-lang#4182)
The LLVM signal handler appears to supersede the Python signal handler. Calling `PyErr_SetInterrupt` is equivalent to raising SIGINT and ensures the program terminates after the stack trace is printed; this function is also thread safe. It might be better to re-raise the original signal, but LLVM does not appear to pass the signal to its handler when captured, and the Python function to raise a specific signal is only supported in 3.10+. Using the repro from triton-lang#4129: before this change, the process printed the stack dump and then stayed alive (`Still alive`); after this change, the stack dump is followed by a `KeyboardInterrupt` traceback (`os.kill(os.getpid(), signal.SIGABRT)` in repro.py) and the process exits.
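The Python-side effect of `PyErr_SetInterrupt` can be sketched with `_thread.interrupt_main()`, which calls the same C API. This is a minimal illustration of the mechanism, not the handler code from this commit:

```python
import _thread

def simulate_native_handler() -> str:
    # _thread.interrupt_main() calls PyErr_SetInterrupt under the hood:
    # it schedules a KeyboardInterrupt that is delivered to the main
    # thread at the next bytecode checkpoint, and is safe to call from
    # any thread -- which is why the commit can use it from a native
    # signal handler context.
    try:
        _thread.interrupt_main()
        for _ in range(10_000_000):  # give the pending interrupt time to land
            pass
        return "not interrupted"
    except KeyboardInterrupt:
        return "interrupted"

print(simulate_native_handler())
```

Because the interrupt arrives asynchronously at a bytecode boundary, the program gets a normal `KeyboardInterrupt` traceback and unwinds cleanly instead of staying alive after the native stack dump.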
(commit 0ef1848)
- (commit d1fa40f)
[BACKEND] Fix bugs in load/storeDShared. (triton-lang#4181)
Fix bugs in load/storeDShared. Unfortunately we can't test this directly today, but all of these bugs were found by a WIP PR running the existing unit tests.
(commit 0654d75)
Avoid underflow in f2reduce (triton-lang#4168)
When running [convert_blocked1d_to_slice0](https://github.com/triton-lang/triton/blob/0ba5f0c3cd029d5c3d1f01b9bf29dac32c27345e/test/Conversion/tritongpu_to_llvm.mlir#L924), Triton ends up computing the rank of a matrix with 0 columns during linear layout lowering, which trips up f2reduce and causes undefined behavior, detectable through [UBSAN](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html). Fix this by returning the rank (0) early in these cases, without calling f2reduce. The UBSAN report begins:

```
third_party/triton/third_party/f2reduce/f2reduce.cpp:421:30: runtime error: shift exponent 18446744073709551615 is too large for 64-bit type 'unsigned long long'
```

with a stack trace running from `f2reduce::inplace_rref_strided` through `getMatrixRank`, `LinearLayout::checkInvariants`, `LinearLayout::tryCreate`, `LinearLayout::divideRight`, `mlir::cvtNeedsSharedMemory`, and the allocation analysis (`AllocationAnalysis::getScratchValueSize`, reached from `AllocateSharedMemory::runOnOperation`).
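The shape of the fix can be sketched in Python: a GF(2) row-reduction rank computation that returns early for a zero-column matrix instead of entering the elimination loop, where a shift amount like `num_cols - 1` would underflow. This is an illustrative sketch, not the actual f2reduce code:

```python
def gf2_rank(rows: list[int], num_cols: int) -> int:
    """Rank of a binary matrix over GF(2); each row is an int bitmask."""
    # Early return: with 0 columns there is nothing to eliminate, and the
    # loop below would compute 1 << (num_cols - 1) with num_cols == 0,
    # i.e. a shift by -1 -- the undefined behavior the commit fixes.
    if num_cols == 0:
        return 0
    rows = list(rows)
    rank = 0
    for col in range(num_cols):
        pivot_bit = 1 << (num_cols - 1 - col)
        pivot_idx = next(
            (i for i in range(rank, len(rows)) if rows[i] & pivot_bit), None
        )
        if pivot_idx is None:
            continue  # no pivot in this column
        rows[rank], rows[pivot_idx] = rows[pivot_idx], rows[rank]
        for i in range(len(rows)):
            if i != rank and rows[i] & pivot_bit:
                rows[i] ^= rows[rank]  # eliminate the pivot bit elsewhere
        rank += 1
    return rank

# Rows 110, 011, 101 are linearly dependent over GF(2): 110 ^ 011 == 101.
print(gf2_rank([0b110, 0b011, 0b101], 3))  # 2
```

In unsigned C++ arithmetic the analogous `num_cols - 1` wraps to 2^64 - 1, which is exactly the oversized shift exponent UBSAN reported.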
(commit 7ca6d12)
[AMD] Disable block merging to avoid block argument explosion (triton-lang#4176)
This PR disables block merging when running `convert-builtin-func-to-llvm`. The reason is that, for now, block merging can double the arguments of the blocks, so after a while we can witness a block argument "explosion" that hangs the compiler. I am working on llvm/llvm-project#63230 to make block merging better, but in the meantime we should stop merging blocks to avoid compiler hangs. I added a minimal test to reproduce the explosion; for now the test checks that we don't try to merge blocks.
(commit cf2ad02)
Add powerpc as a possible installation target and link in correct LLV… (triton-lang#4183)
Support the PowerPC architecture when installing from source and add the respective LLVM libraries as dependencies. All tests (lit/ctest) pass on the PowerPC cluster I am working on. This PR does not need a test because it only modifies the CPU installation architecture.
(commit 1959a08)
[TRANSFORM] Fix while op layout propagation (triton-lang#4185)
We shouldn't propagate to the `operand` of a while op, since doing so would contradict the **forward** propagation rule. In addition, this PR changes the queue implementation from *vector* to *deque* when rewriting regions, to ensure parent regions are visited first.
(commit 0154c76)
[FrontEnd] Allow MLIR dumping a single kernel (triton-lang#4188)
A PyTorch run can consist of many Triton kernel launches. This adds functionality to dump the IRs for one particularly interesting kernel.
(commit d041891)
[FrontEnd] Truncate cached file names when names are too long (triton-lang#4166)
Truncate cached file names to 150 characters when names are too long, to avoid file paths that exceed the system limit (255).
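A sketch of the idea in Python: truncate long cache-file names while keeping them unique by appending a short digest of the full name. The 150-character limit comes from the commit; the hashing scheme below is an assumption for illustration, not necessarily what Triton does:

```python
import hashlib

MAX_NAME_LEN = 150  # limit from the commit; typical filesystem limit is 255

def safe_cache_name(name: str) -> str:
    """Truncate an over-long cache file name, disambiguating with a hash
    so two different long names never collide (illustrative scheme)."""
    if len(name) <= MAX_NAME_LEN:
        return name
    digest = hashlib.sha256(name.encode()).hexdigest()[:16]
    # Keep a recognizable prefix, then "_" + 16 hex chars of the digest.
    return name[: MAX_NAME_LEN - len(digest) - 1] + "_" + digest

print(len(safe_cache_name("kernel_" + "x" * 300)))  # 150
```

Plain truncation alone would map distinct long names to the same file; the digest suffix preserves uniqueness while staying under the limit.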
(commit 0f18662)
Commits on Jun 22, 2024
-
Add a more meaningful check to make sure we are not merging blocks (triton-lang#4186)
This is a follow-up to triton-lang#4176 (comment). I am now counting the number of blocks with (17) and without (31) block merging. I double-checked to make sure this does not pass when we use an aggressive region simplification strategy.
(commit c7a37a9)
Commits on Jun 24, 2024
-
[AMD] Skip mfma layout in maybeDuplicate (triton-lang#4170)
The workaround introduced in triton-lang#4048 "forgot" to skip mfma layout.
(commit 0a66c1b)
- (commit 8f6b4de)
[DOCS][NFC] Fix doc formatting problems (triton-lang#4195)
1. f-strings cannot be used as docstrings in Python.
2. URLs should follow the reStructuredText format.
3. Code snippets in a code block should be indented.
Tested and passing on a local machine.
(commit 784b537)
[BACKEND] Fix regression in pipeliner pre-checks. (triton-lang#4196)
During some previous refactoring we changed the logic and started pipelining cases that had incompatible shared encodings. This was missed because one of the lit tests had not been updated.
(commit d0cd1c0)
- (commit 810e046)
Commits on Jun 25, 2024
-
[AMD] Guard against null in BypassEpilogueSMEM (triton-lang#4203)
`val.getDefiningOp()` can return `nullptr`; in this case, we must fail the `BypassEpilogueSMEM` rewrite pass for the given op. This prevents run-time crashes.
(commit fc8d1a5)
Commits on Jun 26, 2024
-
[FRONTEND][NFC] Fix type checking, conditional logic, and loop structures for improved readability and performance (triton-lang#4208)
(commit 948a3e8)
Document TRITON_HOME (triton-lang#4210)
Document the existence of the `TRITON_HOME` environment variable. `TRITON_HOME` controls the location of the `.triton` directory that stores, among other things, the files downloaded during a `pip install -e python` virtualenv build. By default, this is located in the user's home directory, at `~/.triton`. I was trying to build Triton on a system with a large local disk but limited network home directory space, and the `pip` command kept failing with out-of-disk-space errors. It turned out that during installation, large files were downloaded to the `~/.triton` directory, causing the failure. After checking that it was not `pip` doing this, I found the `TRITON_HOME` variable, which allowed me to work around the issue and build Triton successfully. After seconding triton-lang#4007, I decided to contribute this documentation fix. Co-authored-by: sree <sree@buckyball>
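The documented lookup can be sketched as follows (a hedged sketch of the behavior described above; the helper name is illustrative, not Triton's API):

```python
import os

def triton_cache_root() -> str:
    # TRITON_HOME overrides where the .triton directory lives; otherwise
    # it falls back to the user's home directory (~/.triton by default).
    home = os.environ.get("TRITON_HOME", os.path.expanduser("~"))
    return os.path.join(home, ".triton")

os.environ["TRITON_HOME"] = "/scratch/build"
print(triton_cache_root())  # /scratch/build/.triton (on POSIX)
```

Pointing `TRITON_HOME` at a large local disk is exactly the workaround the commit documents for home directories with tight quotas.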
(commit 4fd733d)
[BACKEND] Fix regression in i1 reduction (triton-lang#4215)
Recent refactoring broke i1 shared memory load.
(commit 9bf0c4f)
- (commit 06e6799)
Commits on Jun 27, 2024
-
[BACKEND] Fix divisibility analysis for shift ops (triton-lang#4221)
Divisibility does not ensure that a value is nonzero, so we cannot use divisibility as a minimum for shifted values.
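The pitfall can be shown concretely: knowing `x` is divisible by 16 does not mean `x >> 3` is at least 2, because `x` may be 0. A small illustrative check, not the analysis code:

```python
def min_after_shift_buggy(divisibility: int, shift: int) -> int:
    # The buggy reasoning the commit removes: assume a value with known
    # divisibility d is at least d, so its minimum after `>> shift` is
    # d >> shift. (Hypothetical helper for illustration.)
    return divisibility >> shift

# Counterexample: x = 0 is divisible by 16, yet 0 >> 3 == 0, which is
# below the "minimum" of 2 the buggy analysis would report.
x = 0
assert x % 16 == 0
print(min_after_shift_buggy(16, 3), x >> 3)  # 2 0
```

Divisibility by `2**k` only constrains the low `k` bits to be zero; it gives no lower bound on the value itself.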
(commit ab7b89b)
Commits on Jun 28, 2024
-
Support FP8 constant (triton-lang#4222)
To unblock the compilation of kernels like the one below, which don't operate arithmetically on FP8:

```python
@triton.jit
def triton_poi_fused__scaled_mm__to_copy_constant_pad_nd_lift_fresh_2(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xnumel = 400624
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 784
    x1 = (xindex // 784)
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 769, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x0 + (769 * x1)), tmp2 & xmask, other=0.0)
    tmp4 = tmp3.to(tl.float8e4nv)
    tmp5 = tl.full(tmp4.shape, 0.0, tmp4.dtype)
    tmp6 = tl.where(tmp2, tmp4, tmp5)
    tl.store(out_ptr0 + (x2), tmp6, xmask)
```
(commit 938e388)
- (commit 1b35f11)
[PROTON] Fix improper use of reinterpret_cast (triton-lang#4225)
Casting from a generic `void *` object to other types should use `static_cast` instead of `reinterpret_cast`.
(commit 8e96b71)
Commits on Jul 1, 2024
-
[nvgpu] Expose helper code from NVGPU to LLVM lowering (NFC). (triton-lang#4235)
OpenXLA would like to reuse code from `NVGPUOpPatternBase` for sparsity lowering. Currently this is done by patching `NVGPUToLLVMPass.cpp` to add an extra pattern deriving from `NVGPUOpPatternBase`. Instead, we would like to move this pattern to the downstream OpenXLA repository. This change removes the `NVGPUOpPatternBase` base class and exposes the core lowering code as a utility function (`rewriteAsPtxAsm()`) instead. The existing patterns now inherit directly from `mlir::OpRewritePattern` and use this utility function. I hope that exposing an extra function is acceptable; I tried to balance it by cleaning up the code a bit (e.g., no more CRTP, though that is of course subjective).
(commit 54960ca)
Optimize code generated by EmitIndices for LinearLayout (triton-lang#4213)
1. Split the LinearLayout conversion into block+warp+thread and register parts.
2. Move the block+warp+thread part out of the register loop in emitIndicesUsingLinearLayouts.
LLVM optimizes the generated IR better and the number of required registers is lowered. With this we are able to re-enable the linear layout for the AMD backend.
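The optimization is classic loop-invariant hoisting. In Python terms (the index arithmetic and strides below are illustrative, not Triton's actual layout math):

```python
def emit_indices(block_id: int, warp_id: int, thread_id: int, num_registers: int) -> list[int]:
    # Hoist the part that depends only on block/warp/thread out of the
    # per-register loop, mirroring the commit: the backend then keeps one
    # value live instead of recomputing it for every register, which
    # lowers register pressure in the generated IR.
    base = block_id * 4096 + warp_id * 256 + thread_id * 8  # loop-invariant
    return [base + r for r in range(num_registers)]  # only this varies

print(emit_indices(0, 1, 2, 4))  # [272, 273, 274, 275]
```

Before the change, the moral equivalent of `base` was computed inside the register loop, so LLVM had to prove the hoist itself; doing it at emission time makes the savings unconditional.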
(commit a7e3476)
Commits on Jul 2, 2024
-
Enhance triton functionality to allow use of custom cuda toolkit at runtime (triton-lang#4237)
Fixes triton-lang#4224. As introduced in triton-lang#4224, this PR addresses two main issues when building Triton from source with a custom CUDA toolkit:
1. The `cupti_include_dir` lookup path for CMake should also include `TRITON_CUPTI_PATH`, as the `TRITON_XXX_PATH` variables seem to be the standard for custom Triton build environment variables.
2. As described in the linked issue, when we specify a `TRITON_CUDART_PATH`, the CUDA runtime headers from conda are not downloaded into the triton/third_party paths. This leaves the default include path bundled in the Triton python package empty and thus leads to a `fatal error: cuda.h: No such file or directory` error. If we support users bringing their own CUDA toolkit to build Triton from source, we should also support running Triton compilation with their own CUDA toolkit runtime.
(commit 5a6668e)
Commits on Jul 3, 2024
-
[FRONTEND] fix conflicting multithreading and fork management (triton-lang#4169)
This PR disables multithreading in the MLIR context after compilation ends. This is done to safely finalize the thread pool implemented in MLIRContext: a thread pool that is not properly finalized can hang or crash in the child process after a fork. Co-authored-by: Lei Zhang <[email protected]>
(commit 7809164)
[AMD] Fix tt.fp_to_fp bf16 to fp32 conversion (triton-lang#4238)
Added missing handling of `tt.fp_to_fp` in `FpToFpOpConversion::createDestOps` when the type of the src operand is bf16 and the type of the dst operand is fp32. Previously, such a `tt.fp_to_fp` produced the following error: `error: failed to legalize operation 'tt.fp_to_fp' that was explicitly marked illegal`. The BF16 to FP32 case was handled neither in `FpToFpOpConversion::createDestOps` nor in `getConversionFunc`, so the conversion is added directly in `FpToFpOpConversion::createDestOps`, similarly to the case where the source and destination operands have the same type.
(commit dcf83df)
[AMD] Search libamdhip64.so in PyTorch user site installation (triton-lang#4246)
Previously, Triton could not find the HIP runtime bundled in PyTorch if the wheel was installed with `--user`. Searching the user site-packages directory may solve this problem.
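The search can be sketched with the standard `site` module (a hedged illustration; the `torch/lib` layout is an assumption about where PyTorch bundles the library, and the helper is not Triton's actual code):

```python
import os
import site

def candidate_hip_runtime_paths(libname: str = "libamdhip64.so") -> list[str]:
    # Check both the global site-packages and the per-user site-packages;
    # the latter is where `pip install --user` places wheels, which is the
    # case the commit fixes.
    roots = list(site.getsitepackages()) + [site.getusersitepackages()]
    return [os.path.join(root, "torch", "lib", libname) for root in roots]

for path in candidate_hip_runtime_paths():
    if os.path.exists(path):
        print("found HIP runtime at", path)
```

`site.getusersitepackages()` resolves to something like `~/.local/lib/pythonX.Y/site-packages` on Linux, which plain `site.getsitepackages()` does not cover.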
(commit 13edc45)
Update the 07-extern-functions tutorial and include a unit test for extern function (triton-lang#4226)
1. Updated the doc of the 07-extern-functions tutorial: `tl.math.fn` no longer works in the nightly (see https://github.com/triton-lang/triton/pull/3172/files#diff-8bb547c8082ddf083d49bdb5e6f126c1cebae899390c95afc0f713229982f516L35 for the reference change).
2. Added an example of an `extern_lib` path (just using the default path explicitly to showcase the usage).
3. Added a unit test, so that future backward-breaking changes to extern math function usage will break it. Picked tanh since it is heavily used in gelu.
Tested on CUDA.
(commit 77e64c2)
Tutorial - 06-fused-attention.py - plotting (triton-lang#4244)
This PR fixes a simple matplotlib error: if `HAS_FLASH_ATTN` and `torch.float8` are available, the color array goes out of possible colors.
(commit 0929654)
[nvgpu] Reorganize and complete tests for convert-nv-gpu-to-llvm pass. (triton-lang#4242)
As promised in PR triton-lang#4235, add tests for nvgpu-to-llvm conversion for the remaining ops.
(commit 59d4b68)
[BACKEND] add folder for AdvanceOp (triton-lang#4240)
Add a folder pattern for AdvanceOp: `advance(ptr, 0, 0) -> ptr`.
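The fold is the standard all-zero-offsets identity; a toy Python model of the rewrite (illustrative, not MLIR code):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Advance:
    """Toy stand-in for tt.advance: move a block pointer by offsets."""
    ptr: str
    offsets: tuple[int, ...]

def fold_advance(op: Advance):
    # advance(ptr, 0, 0, ...) does not move the block pointer, so the op
    # folds away to its input pointer; any nonzero offset means there is
    # nothing to fold (return None, i.e. "no fold applies").
    if all(off == 0 for off in op.offsets):
        return op.ptr
    return None

print(fold_advance(Advance("ptr", (0, 0))))  # ptr
print(fold_advance(Advance("ptr", (1, 0))))  # None
```

Folders like this let later passes see through no-op pointer arithmetic, e.g. when loop canonicalization leaves a zero-step advance behind.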
(commit 3b03565)
Commits on Jul 4, 2024
-
[AMD] Enable integration tests for MI300 (triton-lang#4245)
This PR updates the CI script to enable integration tests for MI300. This only enables the same coverage as the MI250 machine. One test failure is left as a TODO in test_core.py to be figured out and fixed. Co-authored-by: Lixun Zhang <[email protected]> Co-authored-by: Lei Zhang <[email protected]>
(commit c5d016d)
[CI] Reduce the number of CPU threads to prevent memory overflow issues during build (triton-lang#4258)
The documentation pipeline has failed multiple times due to this issue. Since this CI node has 24 CPU cores, we hardcode the number of threads to 24 to reduce memory usage.
Configuration menu - View commit details
-
Copy full SHA for e7350f4 - Browse repository at this point
Copy the full SHA e7350f4View commit details
Commits on Jul 5, 2024
-
[BUILD] Fix compilation errors under RelWithDebInfo (triton-lang#4261)
Flags used by the CXX compiler during RELWITHDEBINFO builds: ``` CMAKE_CXX_FLAGS_RELWITHDEBINFO:STRING=-O2 -g -DNDEBUG ``` `setCurrentDebugTypes` is a different kind of symbol (function or macro) depending on the NDEBUG macro, so it cannot be used directly; otherwise it causes compilation errors. https://github.com/llvm/llvm-project/blob/00c622e596f918d9d83674b58097c8982ae1af95/llvm/include/llvm/Support/Debug.h#L71 https://github.com/llvm/llvm-project/blob/959ff45bdad1777598c89b662fc98653a19eb8a3/mlir/lib/CAPI/Debug/Debug.cpp#L28-L31
Commit 0430b9e
Commits on Jul 7, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 12f7537 - Browse repository at this point
Copy the full SHA 12f7537View commit details -
Configuration menu - View commit details
-
Copy full SHA for 6463241 - Browse repository at this point
Copy the full SHA 6463241View commit details
Commits on Jul 8, 2024
-
[TEST][NFC] Add back check for wgmma wait (triton-lang#4273)
Add back check to introduce dummy dependency in wgmma.wait. Also move back ops within llvm functions.
Commit a8fd9a3
[CI] Preserve environment variables when building docs (triton-lang#4270)
Commit b8ea727
Make load/storeDShared handle large vectors (triton-lang#4189)
Previously we errored out when given a vector of more than 4 elements or more than 128 bits. Now we handle this properly by merging elements (e.g. 16xi8 -> 4xi32) or splitting the vector (e.g. 16xi32 -> four 4xi32 loads/stores).
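The merge/split arithmetic described above can be sketched as follows. This is an illustrative model, not the backend code: it assumes a single shared-memory access is at most 128 bits (four 32-bit words), merges narrow elements into 32-bit words, and splits anything wider into multiple accesses.

```python
# Sketch of the vector packing logic: small elements are merged into
# 32-bit words, and totals wider than 128 bits are split into multiple
# <=128-bit accesses.

def plan_accesses(elem_bits, num_elems):
    total_bits = elem_bits * num_elems
    words = max(1, total_bits // 32)   # merge, e.g. 16xi8 -> 4xi32
    accesses, rem = divmod(words, 4)   # split into <=128-bit chunks
    if rem:
        accesses += 1
    return words, accesses

# 16 x i8  = 128 bits -> 4 words, one access
# 16 x i32 = 512 bits -> 16 words, four 4xi32 accesses
```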
Commit 22631da
Commit ca469d7
[AMD] Remove promotion for ReduceOp (triton-lang#4269)
ReduceOp promotion was introduced in triton-lang#3153 to fix `test_reduce1d` and `test_reduce`. This is no longer necessary, and it doesn't work when there is a CallOp in the body of the ReduceOp, since the type of the callee function definition is not promoted.
Commit 3a648f6
[AMDGPU] Fix occupancy calculation in softmax tutorial (triton-lang#4243)
This PR fixes the occupancy calculation for AMDGPU so that the number of waves does not exceed the maximum number of waves that can execute in parallel on a CU. Additionally, it fixes the number of general-purpose registers available on a CU. Co-authored-by: Ognjen Plavsic <[email protected]>
Commit 47fc046
Commits on Jul 9, 2024
-
[BACKEND] Fix an issue with the pipeliner (triton-lang#4247)
During pipelining, operations that neither depend on nor are depended on by anchor operations are considered remaining ops and are scheduled into the last stage. These ops are not present in the existing stages but can be visited by other staged ops. This fixes an ICE when looking up clusters for them.
Commit 050b41d
[PROTON] Flush ops before activating sessions (triton-lang#4288)
To ensure that metrics from kernels launched before are not attributed to the current active session. Co-authored-by: Philippe Tillet <[email protected]>
Commit c14b033
[IR] Add verifier for tt.broadcast (triton-lang#4286)
Summary: This PR adds a verifier for the tt.broadcast primitive to provide early error detection for incorrect usage. Test Plan: new lit test. Reviewers: @htyu
Commit 7a78c04
[BACKEND] Fix hopper mma to linear layout constraints (triton-lang#4283)
n = 8 should be a valid option https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
Commit 18996e7
[Tutorial] Fix a pointer-not-advanced issue in the persistent kernels (triton-lang#4218)
It looks like the operand pointers are not advanced in the baseline and persistent kernels.
Commit 4073423
Commits on Jul 10, 2024
-
[FRONTEND] Let CacheManager write to a temp dir instead of a temp file (triton-lang#4295)
# Summary
There have been multiple issues discussing the `FileNotFoundError` raised during compilation when `CompiledKernel` tries to read the listed ASM files: triton-lang#2688, triton-lang#4002, vllm-project/vllm#6103, etc., and there have been attempts to address it such as triton-lang#3544. This PR attempts to explain the root cause and suggest a fix.
# Why
When a kernel is compiled, Triton first writes IRs to the Triton cache dir ([ref](https://github.com/triton-lang/triton/blob/78091647fccb6825ed9956ff7c0300859856d261/python/triton/compiler/compiler.py#L289)). The write operation first goes to a temp file unique to the current process (plus a uuid to distinguish processes with the same PID on different hosts sharing the same underlying FS) ([ref](https://github.com/triton-lang/triton/blob/c14b033cd979d5c39e5fdb3847c022fa5d71a0c1/python/triton/runtime/cache.py#L124-L130)) and then atomically `os.replace`s it to the final file name. Afterwards `CompiledKernel` lists all the IRs and reads them ([ref](https://github.com/triton-lang/triton/blob/78091647fccb6825ed9956ff7c0300859856d261/python/triton/compiler/compiler.py#L362-L367)). In a multiprocess setup this can result in a race condition. Consider one host with two processes. ![Triton RC (1)](https://github.com/triton-lang/triton/assets/43726198/ffc20e0c-0404-4e7a-bd6c-022e710e97b9) At the time `pid 1` lists ASMs, the dir may contain temp files generated by `pid 2`. By the time `pid 1` proceeds to read the listed files, `pid 2` may have already `os.replace`d its temp files, so `pid 1` encounters `FileNotFoundError` when trying to read a temp file generated by `pid 2`. IBM/vllm#35 (comment) also believes this is the root cause.
# How
There are multiple potential solutions, as mentioned in IBM/vllm#35 (comment) as well: let each process write to a private temp dir instead, so `glob` won't pick up the temp files; or exclude `tmp.pid_*` from `glob`. This PR goes with the first approach to avoid adding an assumption about the temp-file pattern (which is only used in `runtime/cache.py`) in `compiler/compiler.py`, but is open to any suggestion. Thanks!
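The stage-then-publish pattern described above can be sketched as follows. This is a hedged, simplified model of the idea, not Triton's `CacheManager` code: each writer stages files in its own private temporary directory, then publishes with an atomic `os.replace`, so a concurrent reader listing the cache directory never sees another process's half-written temp files.

```python
import os
import tempfile

# Sketch of a race-free cache write: stage in a private temp dir,
# then atomically publish into the shared cache dir.

def atomic_write(cache_dir, filename, data: bytes):
    os.makedirs(cache_dir, exist_ok=True)
    # Private staging dir: a reader globbing cache_dir for IR files
    # never observes the partially written file inside it.
    with tempfile.TemporaryDirectory(dir=cache_dir) as staging:
        tmp_path = os.path.join(staging, filename)
        with open(tmp_path, "wb") as f:
            f.write(data)
        final_path = os.path.join(cache_dir, filename)
        # os.replace is atomic on POSIX within a single filesystem.
        os.replace(tmp_path, final_path)
    return final_path
```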
Commit b674269
[BACKEND] Remove special handling for bf16 in fp->int, int->fp conversions (triton-lang#4281)
This PR removes some special handling for int->bf16 and bf16->int conversions in the TritonNVIDIAGPU->LLVM lowerings, in order to use, e.g., the `cvt.bf16.s32` and `cvt.s32.bf16` instructions that are now available on Hopper. Before this PR there was special handling for conversions to and from bf16; for int->bf16, the conversion was done as int->fp32 followed by fp32->bf16. Presumably this was because, before sm90, the PTX `cvt` instruction doesn't support conversions to/from bf16. However, sm90 does support direct conversions to/from bf16, so this PR removes the special handling in order to make use of the direct `cvt` instructions. For Ampere, the special handling appears to be no longer needed, and LLVM handles the details of the different hardware implementations (perhaps thanks to llvm/llvm-project#74827?).
Commit 0ac0d2a
Commits on Jul 11, 2024
-
[INTERPRETER] Only parse the AST of the JIT function (triton-lang#4294)
``` @triton.autotune(..., configs) @triton.jit(...) def fn(): ``` Suppose we have the above decorators, and configs are local variables. These variables are not visible in `interpreter.py`. Therefore, the source code should only contain the JIT function itself without additional decorators.
Commit 0f093ea
Commit 96e5145
[BACKEND] Remove `EmitIndicesTest.cpp` and always use linear layout to emit offsets (triton-lang#4299)
Commit d511d1b
Commits on Jul 12, 2024
-
[TEST] Fix numpy out-of-bound warnings/failures (triton-lang#4280)
In newer versions of numpy, converting Python integers that are out of bounds results in a failure - e.g. see this warning that currently appears in CI: ``` DeprecationWarning: NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 67830198458 to uint32 will fail in the future. ``` This PR fixes two tests which fail when using numpy 2.0.0: * test_randint - explicitly truncate the integer that was too large * test_reduce1d - previously, the test would set `x[3:10] = argmax(x)`. However: (1) this now errors due to overflow, e.g. if argmax() returns a value >= 256; and (2) it appears that the intent is actually to set `x[3:10] = x[argmax(x)]`, so that there are duplicate elements with the max value. I've changed this line to `x[3:10] = x[argmax(x)]`.
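The explicit truncation for the test_randint fix can be sketched as follows. The helper name is illustrative, not the test's actual code: newer NumPy refuses to silently wrap out-of-range Python ints, so the value is reduced to the uint32 range by hand before conversion.

```python
# Sketch of explicit truncation to the uint32 range, matching what the
# old implicit NumPy conversion used to do (keep only the low 32 bits).

def to_uint32(value):
    return value & 0xFFFFFFFF

# The out-of-bounds value quoted in the CI warning above:
truncated = to_uint32(67830198458)
```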
Commit 241fc61
Commit c71073f
[AMD] NFC: Add AMD GPUs to official support in main branch (triton-lang#4307)
This commit changes the main README.md to list AMD GPUs as officially supported hardware.
Commit 332ddd9
Commits on Jul 13, 2024
-
[CI][AMD] Fix excessive cache storage usage (triton-lang#4300)
This commit adds steps to remove the cache for AMD build bots so it does not accumulate without limit, given that the cache is generated inside docker as root and so won't be cleaned up by GitHub automatically. --------- Co-authored-by: Lei Zhang <[email protected]>
Commit 056fa0f
Update README instruction to run C++ unit tests (triton-lang#4306)
The command to run C++ unit tests, `ninja test`, was failing with the following error: ``` ninja: error: build.ninja:1002: multiple outputs aren't (yet?) supported by depslog; bring this up on the mailing list if it affects you ``` CI uses `ctest -j32` to run C++ unit tests, so this changes the failing command to the one used by CI.
Commit 790ddf1
Commit e4a0d93
Commits on Jul 15, 2024
-
Commit 9562303
[BACKEND] Fix canonicalization of LocalLoad op (triton-lang#4313)
Currently, we create the new op at the location of the ConvertLayout op during canonicalization. For pure ops this is fine, but for a LocalLoad op we need to create the new op at the location of the LocalLoad op, as memory could have been changed in between.
Commit 58454f8
[OPTIMIZER] Fix layout error after backward remat (triton-lang#4311)
Fixes pytorch/pytorch#130101 In the linked issue `getConvertBackwardSlice` is returning a slice that requires two different layouts for the same operation, which isn't supported. There is a guard that checks `layouts.find(currentValue)` and bails out when this doesn't match, however this check came after the `visited` check so we would skip the queue item before we ever got to the check. To fix this I change the visited check to only apply if the current `(value, encoding)` pair has been enqueued previously and apply this uniformly to everywhere we want to update a layout. The reproducer is admittedly long, but I was able to reduce it down to the exact slice that was triggering the failure.
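The visited-set fix described above can be sketched abstractly. This is a toy worklist model, not the actual `getConvertBackwardSlice` code: marking items visited by the (value, encoding) pair, rather than by value alone, means the same value can still be processed under a second, different layout, so the conflicting-layout check is actually reached.

```python
# Toy model: collect required layouts over a def/use graph, bailing out
# when one value would need two different encodings.

def collect_slice(start, needed_encoding, users):
    # users: value -> list of (next_value, next_encoding) edges
    layouts = {}
    queue = [(start, needed_encoding)]
    visited = set()
    while queue:
        value, enc = queue.pop()
        if (value, enc) in visited:   # key includes the encoding
            continue
        visited.add((value, enc))
        if value in layouts and layouts[value] != enc:
            return None               # conflicting layouts: unsupported
        layouts[value] = enc
        queue.extend(users.get(value, []))
    return layouts

# "b" is required in two different layouts -> conflict detected.
edges = {"a": [("b", "L1"), ("b", "L2")]}
```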
Commit a6b15ef
[CI][AMD] Drop testing Hopper-specific tests (triton-lang#4326)
These tests are Hopper-related and will likely grow further as we enable more functionality there, so they are not suitable for the AMD side.
Commit 6cdc650
[Frontend][AMD] Support K<16 dot cases (triton-lang#3908)
This commit relaxes tl.dot semantic requirements regarding shapes to allow per-backend limitations. This enables using K=8 intrinsics for AMD GPUs.
Commit 538556a
Commits on Jul 16, 2024
-
[BACKEND][AMD] Attach triple right after converting to LLVM (triton-lang#4329)
This makes sure all following transformation/optimization steps see it, rather than only those on or after the optimization step. It also makes `optimize_module` cleaner. This effectively supersedes triton-lang#3901.
Commit 69539b8
[BE] Minimize sort_with_index.mlir reproducer (triton-lang#4327)
I managed to ablate this reproducer from triton-lang#4311 down a lot. I believe this is close to minimal now. cc @Jokeren --------- Co-authored-by: Peter Bell <[email protected]>
Commit 79297ec
[INTERPRETER] Refactor function rewriter (triton-lang#4325)
1. Use the builtin `ast.increment_lineno` function to make it more robust 2. Clean up function rewrite logic 3. Resolve global variable reference issues 4. Enable line info tests
Commit 7c42f6b
[Runtime] Dynamically load cuTensorMapEncodeTiled (triton-lang#4330)
That symbol is only present in CUDA-12-compatible drivers and is missing in CUDA-11 ones. A spiritual follow-up to triton-lang#2771: this change queries the symbol dynamically, and if run on an older driver it returns an error. Also fixes the `occupancyMaxActiveClusters` behavior when the symbol is not found (before this change it would crash with a null pointer dereference; now it returns a structured exception).
Commit f9f2960
Use `device` fixture for `test_subprocess.py::test_print` (triton-lang#4333)
This pull request adds the use of the `device` fixture to the test so that it is not CUDA-specific. This simplifies testing of various devices downstream, without having to modify the test code itself. Signed-off-by: Anatoly Myachev <[email protected]>
Commit 2946cd1
Add Perf Kernels
This is a combination of several squashed commits:
* add perf-kernels
* fix formatting issues
* fix unused variables and other bugs
* fix other issues
* remove scripts
* format and run pre-commit checks
Commit 2d2dbe1
Commit 17575ea
Change all block pointers to tensor pointers (#585)
Change all block pointers to tensor pointers Block pointers are for nvidia TMAs. They are useful for regular loads as well but not well supported. Also cleaned up some code I came across along the way and updated comment at the top.
Commit a3d784a
Add support for bshd layout (#587)
Add support for the bshd layout commonly used by users. Also add an option for the varlen / thd layout to specify equal context lengths for all batches, which is likewise often used.
Commit aa6685a
* remove on-push trigger for Integration Tests
* rename
* add post-merge test
* dtype params
* skip bad config
* fix more stuff
Commit dbe1173
Commits on Jul 18, 2024
-
Commit 23ba546
Commits on Jul 19, 2024
-
Couple of FA optimizations (#608)
Set SM scale multiplication to a constexpr. Minor asm improvement. Changed acc scaling: the softmax division is replaced by multiplication with the reciprocal. ~10% perf improvement. --------- Co-authored-by: Michael Melesse <[email protected]>
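The division-to-reciprocal rewrite can be sketched numerically. This is a hedged illustration of the arithmetic, not the kernel code: one expensive reciprocal replaces many per-element divisions, with the same result for these values.

```python
# Sketch of the acc-scaling change: divide each accumulator element by
# the softmax denominator vs. one reciprocal followed by multiplies.

def scale_div(acc, denom):
    return [a / denom for a in acc]

def scale_recip(acc, denom):
    inv = 1.0 / denom               # one expensive op ...
    return [a * inv for a in acc]   # ... then cheap multiplies

row = [1.0, 2.0, 4.0]
```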
Commit df4c4d3
Commits on Jul 31, 2024
-
* streamk v0.1
* remove unused variable
* fix format issues
* add README
* fix format issue
* change num_sms to num_cus
Commit 52a908f
Commits on Aug 6, 2024
-
Add explicit multiply-reduce GEMM kernel (#621)
* Add explicit multiply-reduce GEMM kernel
* Remove `SPLIT_K` argument from kernel
* Remove `GROUP_SIZE_M` argument from kernel
* Remove conditional call to `tl.dot` from kernel
* Remove table with performance data from README
Commit 1d2e066
Commits on Aug 13, 2024
-
Copy *tune_gemm* from `triton-mlir` branch to `main_perf` branch (#614)
The source commit in the `triton-mlir` branch is the following one:
```
commit cf44637 Author: Lixun Zhang <[email protected]> Date: Tue Jul 23 14:22:01 2024 -0500 [tuning] gemm tuning script v3.3 (#606)
```
*tune_gemm* was copied from the source branch directory `scripts/amd/gemm` to the destination branch directory `python/perf-kernels/tune_gemm`. The SHA-256 hashes of the *tune_gemm* files are the following ones:
```
423aef1deb6c60f6578a1ecfc94d2473f8746b00d0368c553d31641fcfa5e354 README.md
46ab93978fee33f75df23332f12546dae7910478c391f08b7b1ebd415d8266b7 icache_flush.py
f18711544641b810a652e6a6629bfa2b613f6ade87399e88fdf05b81d4af58a4 matmul.py
84a1c80ede36d3154e51188276eda2d2d0f52ed4f496ff69349c390d83b8ec10 matmul_kernel.py
2812b40183637bc8d7e47d283c7d66b1792134a43de76f3eacf7b9b3e1c2431a one_config.py
0ac09c33b0173cea06ddabbf9f4e3afa1816781dea4fdcce5894a7e7d6a80e19 rocprof_gemm.py
00eff41cf1c0bfc41d623e42b51706af67639fec76146741e2067d2a93e0148a utils/file_generator.py
cb7afb773ccee835b00396cccf87e0d44fe513131161f031fae42453725b3c82 utils/utils.py
59f23811b660e49e566927853926a21f02a7014bb19c8ea67e6b382db6c59900 tune_gemm.py
e787f35d750b869f113b3c01692f64243a9cb8a71a18ade2f0465f614f7284e4 tune_gemm.sh
```
The files were kept as-is despite `pre-commit` intentions to change them. After that, the *tune_gemm* directory in code and documentation was fixed to reflect its new location.
Commit 11e4447
Commits on Aug 16, 2024
-
Clean up *tune_gemm* script from `main_perf` branch (#629)
* Reformat *tune_gemm* files with Triton's pre-commit. The following command was executed to reformat the files: ``` $ pre-commit run --files python/perf-kernels/tune_gemm/* python/perf-kernels/tune_gemm/utils/* ```
* Fix *tune_gemm* issue with (1, 1) bias tensors
* Fix `ruff` F405 errors (F405: `identifier` may be undefined, or defined from star imports)
* Fix `ruff` F841 errors (F841: local variable `identifier` is assigned to but never used)
* Fix minor issues in the README file
* Add `--` to the `num_threads` argument
* Replace the `--icahe` argument (non-existent) with `--icache_flush` (existent)
* Remove old files from *tune_gemm* V1
* Add dependency graph to the README file
* Selectively disable `yapf` for parts of `one_config.py`
Commit 624335f
Commits on Aug 19, 2024
-
[tune gemm v3.4] Add xcd-based pid remapping and change back to rocprofv1 (#630)
* Change to rocprofv1
* Improve post-processing of rocprof results: set --iters=200 as default (enough, since the time is stable after the first few runs), and filter out kernel times that are too large, using the first kernel time as the threshold. There must be something wrong with the kernel if its elapsedTime is larger than the first run; we need to investigate the reason, but for now just filter them out.
* Add xcd-based pid remapping
* Enable EVEN_K=false for large gemms
* Update readme
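The filtering rule described above can be sketched as follows. This is an illustrative model of the heuristic, not the tuning script itself: the first measured time serves as the threshold, and later runs that exceed it are dropped.

```python
# Sketch of the rocprof post-processing heuristic: treat the first
# measured kernel time as a threshold and drop later outliers, since
# timings should be stable after warm-up.

def filter_times(times):
    if not times:
        return []
    threshold = times[0]
    # Keep the first run and every run at or below it.
    return [t for t in times if t <= threshold]

samples = [100.0, 98.5, 250.0, 99.1]   # 250.0 is a suspect outlier
```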
Commit 15cb3a8
Commit 177d0bd
Commits on Sep 6, 2024
-
Rahul Batra committed Sep 6, 2024. Commit e42690d
Merge pull request #634 from ROCm/main_perf-softmax
Softmax kernel
Commit 6d283a2
Move utility tools from triton-mlir to main_perf branch (#635)
* Move utility tools from triton-mlir to main_perf branch: plot layout script, occ.sh, amdgcn-cfg
* yapf format
* More formats
* Remove executability of plot_layout.py
* Address ruff complaints
* Move tune_gemm to tools
Commit 3704738
Rahul Batra committed Sep 6, 2024. Commit f80aed7
Commit 9da4278
Commits on Sep 7, 2024
-
Merge pull request #633 from ROCm/main_perf-rmsnorm
Add rmsnorm kernel
Commit c4bd738
Commits on Sep 13, 2024
-
Rahul Batra committed Sep 13, 2024. Commit a782caf
Commits on Sep 16, 2024
-
Merge pull request #639 from ROCm/softmax_updates
Online softmax implementation
Commit 96b3d37
Commits on Sep 24, 2024
-
Use mask during load for Softmax
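The title suggests masking out-of-bounds columns during the load. The numerics can be sketched without a GPU; this is a hedged, pure-Python model (the actual kernel is a Triton kernel, and `block_softmax` is a hypothetical name): a softmax kernel loads a fixed BLOCK_SIZE per row, so out-of-bounds lanes must be filled with -inf via the load mask, contributing exp(-inf) = 0 to the row sum.

```python
import math

# Sketch of why the load mask matters for softmax: pad the row to
# BLOCK_SIZE with -inf (what a masked load's `other=-inf` provides),
# so the padded lanes vanish from the normalization.

def block_softmax(row, block_size):
    padded = row + [float("-inf")] * (block_size - len(row))  # masked lanes
    m = max(padded)                         # row max for numerical stability
    num = [math.exp(x - m) for x in padded]  # exp(-inf) == 0.0 for padding
    denom = sum(num)
    return [n / denom for n in num]

out = block_softmax([0.0, 0.0], 4)
```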
Rahul Batra committed Sep 24, 2024. Commit 8939325