diff --git a/README.md b/README.md
index 3841c6f3e829..f916461d8c79 100644
--- a/README.md
+++ b/README.md
@@ -199,6 +199,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to this
 - `TRITON_ALWAYS_COMPILE=1` forces to compile kernels regardless of cache hit.
 - `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
 - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
+- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).

 # Changelog
diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index 7d0d41075ab1..b873fe236aa3 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -23,6 +23,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "MLIR_ENABLE_DIAGNOSTICS",
     "MLIR_ENABLE_DUMP",
     "MLIR_ENABLE_TIMING",
+    "TRITON_DEFAULT_FP_FUSION",
     "TRITON_DISABLE_LINE_INFO",
     "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
     "TRITON_ENABLE_LLVM_DEBUG",
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a3488c178e0a..f453e6bdbb79 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -5185,7 +5185,8 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):

 @pytest.mark.parametrize("enable_fp_fusion", [False, True])
-def test_enable_fp_fusion(enable_fp_fusion, device):
+@pytest.mark.parametrize("default_override", [False, True])
+def test_enable_fp_fusion(enable_fp_fusion, default_override, device):
     if is_hip():
         pytest.skip(
             'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton'
         )
@@ -5198,7 +5199,11 @@ def mul_add(data):
         tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)

     data = torch.randn((128, ), device=device, dtype=torch.float32)
-    h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion)
+    if default_override:
+        os.environ["TRITON_DEFAULT_FP_FUSION"] = "1" if enable_fp_fusion else "0"
+        h = mul_add[(1, )](data)
+    else:
+        h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion)

     if not is_cuda():
         return
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index b092e32e6ce1..9f2a321a314b 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -63,6 +63,8 @@ def __init__(self, target: GPUTarget) -> None:

     def parse_options(self, opts) -> Any:
         args = {'arch': self.target.arch}
+        if "enable_fp_fusion" not in args:
+            args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
         args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts})
         return HIPOptions(**args)
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 6d7994923495..6ee7f2281e04 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -110,6 +110,8 @@ def parse_options(self, opts) -> Any:
         args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
         args["allow_fp8e4nv"] = self.capability >= 89
         args["allow_fp8e4b15"] = self.capability < 90
+        if "enable_fp_fusion" not in args:
+            args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
         args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
         return CUDAOptions(**args)
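
For reference, a minimal sketch of how the new knob behaves from user code. It mirrors the `mul_add` kernel from the test above; the `"cuda"` device string and the launch calls are illustrative assumptions, not part of this patch. Note that because both backends only consult `TRITON_DEFAULT_FP_FUSION` when `enable_fp_fusion` is absent from the parsed options, an explicit launch option always wins over the environment default, and since the variable is listed in `CACHE_INVALIDATING_ENV_VARS`, changing it forces a recompile rather than a stale cache hit.

```python
# Sketch (not part of this patch): exercising TRITON_DEFAULT_FP_FUSION.
# Assumes a CUDA device; kernel mirrors mul_add from the test above.
import os

import torch
import triton
import triton.language as tl


@triton.jit
def mul_add(data):
    ptrs = data + tl.arange(0, 128)
    tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)


data = torch.randn((128, ), device="cuda", dtype=torch.float32)

# No explicit launch option: the env var supplies the default
# (here, mul+add->fma fusion disabled).
os.environ["TRITON_DEFAULT_FP_FUSION"] = "0"
mul_add[(1, )](data)

# An explicit enable_fp_fusion at launch still overrides the env default,
# since parse_options only reads the env var when the option is absent.
mul_add[(1, )](data, enable_fp_fusion=True)
```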