diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 6c93538a24f5..0e90e00bf22a 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -215,7 +215,7 @@ jobs:
       - name: Install pip dependencies
         run: |
           python3 -m pip install --upgrade pip
-          python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
       - name: Install Triton
         env:
           TRITON_BUILD_WITH_CCACHE: "true"
@@ -239,8 +239,9 @@ jobs:
             echo "Could not find '${SHARED_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
-          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
+          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
           python3 -m pytest -s -n 8 language/test_subprocess.py
+          python3 -m pytest -s -n 8 test_debug.py --forked
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
           # Run hopper/test_flashattention.py separately to avoid out of gpu memory
@@ -382,7 +383,10 @@ jobs:
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
           pytest --capture=tee-sys -rfs -n 16 language runtime \
-            --ignore=language/test_line_info.py
+            --ignore=language/test_line_info.py \
+            --ignore=test_debug.py
+          # TODO: uncomment
+          # pytest --capture=tee-sys -rfs test_debug.py
           TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
             pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
@@ -494,7 +498,7 @@ jobs:
           python3 -m venv ~/.venv
           source ~/.venv/bin/activate
           python3 -m pip install --upgrade pip
-          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
       - name: Install Triton
         env:
           TRITON_BUILD_WITH_CCACHE: "true"
diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in
index 1b4c46a26c5b..411c5b00e922 100644
--- a/.github/workflows/integration-tests.yml.in
+++ b/.github/workflows/integration-tests.yml.in
@@ -246,7 +246,7 @@ jobs:
       - name: Install pip dependencies
         run: |
           python3 -m pip install --upgrade pip
-          python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install wheel cmake==3.24 ninja pytest-forked pytest-xdist lit

      - name: Install Triton
        env:
@@ -274,8 +274,9 @@ jobs:
             echo "Could not find '${SHARED_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
-          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
+          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
           python3 -m pytest -s -n 8 language/test_subprocess.py
+          python3 -m pytest -s -n 8 test_debug.py --forked
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
           # Run hopper/test_flashattention.py separately to avoid out of gpu memory
@@ -387,7 +388,10 @@ jobs:
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
           pytest --capture=tee-sys -rfs -n 16 language runtime \
-            --ignore=language/test_line_info.py
+            --ignore=language/test_line_info.py \
+            --ignore=test_debug.py
+          # TODO: uncomment
+          # pytest --capture=tee-sys -rfs test_debug.py
           TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
             pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
@@ -440,7 +444,7 @@ jobs:
           python3 -m venv ~/.venv
           source ~/.venv/bin/activate
           python3 -m pip install --upgrade pip
-          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
       - name: Install Triton
         env:
           TRITON_BUILD_WITH_CCACHE: "true"
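
Note on the CI change above: `test_debug.py` has to run under `pytest-forked` because a `tl.device_assert` that fires leaves the CUDA context in an unrecoverable error state, and forking gives every test its own process to absorb that damage. A minimal sketch of the pattern (kernel and tensor names here are illustrative, not taken from this PR):

```python
import pytest
import torch
import triton
import triton.language as tl


@pytest.mark.forked  # same per-test process isolation as `pytest --forked`
def test_assert_fires():
    @triton.jit
    def _kernel(X):
        tl.device_assert(tl.load(X) == 0, "x != 0")

    x = torch.ones(1, dtype=torch.int32, device="cuda")
    with pytest.raises(RuntimeError):
        _kernel[(1, )](x, debug=True)
        torch.cuda.synchronize()  # device-side asserts surface at sync time
```
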
diff --git a/python/test/unit/language/assert_helper.py b/python/test/unit/language/assert_helper.py
deleted file mode 100644
index 1b13ce948134..000000000000
--- a/python/test/unit/language/assert_helper.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import sys
-
-import torch
-from torch.testing import assert_close
-
-import triton
-import triton.language as tl
-
-
-def get_current_target_warp_size():
-    return triton.runtime.driver.active.get_current_target().warp_size
-
-
-@triton.jit
-def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
-    x = tl.load(X + tl.arange(0, BLOCK))
-    tl.device_assert(x == 0, "x != 0")
-    tl.store(Y + tl.arange(0, BLOCK), x)
-
-
-@triton.jit
-def kernel_assert_passes(X, Y, BLOCK: tl.constexpr):
-    x = tl.load(X + tl.arange(0, BLOCK))
-    # Trivial assert, should not be an error.
-    tl.device_assert(0 == 0, "x != 0")
-    tl.store(Y + tl.arange(0, BLOCK), x)
-
-
-@triton.jit(debug=False)
-def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr):
-    x = tl.load(X + tl.arange(0, BLOCK))
-    tl.device_assert(x == 0, "x != 0")
-    tl.store(Y + tl.arange(0, BLOCK), x)
-
-
-@triton.jit
-def kernel_assert(X, Y, BLOCK: tl.constexpr):
-    x = tl.load(X + tl.arange(0, BLOCK))
-    assert x == 0, "x != 0"
-    tl.store(Y + tl.arange(0, BLOCK), x)
-
-
-@triton.jit
-def kernel_static_assert(X, Y, BLOCK: tl.constexpr):
-    x = tl.load(X + tl.arange(0, BLOCK))
-    tl.static_assert(BLOCK == 128, "BLOCK != 128")
-    tl.store(Y + tl.arange(0, BLOCK), x)
-
-
-def test_assert(func: str, device: str):
-    N = 128  # This value should match with test_print in test_subprocess.py.
-    num_warps = N // get_current_target_warp_size()
-
-    x = torch.arange(0, N, dtype=torch.int32, device='cuda')
-    y = torch.zeros((N, ), dtype=x.dtype, device="cuda")
-    if func == "device_assert":
-        kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N)
-    if func == "device_assert_passes":
-        # Assert passes; no error.
-        kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N)
-    elif func == "no_debug":
-        # TRITON_DEBUG=1 can override the debug flag
-        kernel_device_assert_no_debug[(1, )](x, y, num_warps=num_warps, BLOCK=N)
-    elif func == "assert":
-        kernel_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N)
-    elif func == "static_assert":
-        kernel_static_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N)
-    elif func == "double_assert":
-        # Launching a different kernel after the first one asserted used to
-        # segfault. What seems to have happened is:
-        # - The first kernel is enqueued but doesn't run yet.
-        # - We go to launch the second kernel. Because this is the first time
-        #   we're running it, we have to load the kernel into the GPU.
-        # - Loading the kernel takes some time, during which the first launch
-        #   completes.
-        # - Now the GPU is in an error state. We need to detect this inside
-        #   the kernel-launch/loading code and bail out properly. If we don't,
-        #   we segfault.
-        kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N)
-        kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N)
-    assert_close(y, x)
-    # GPU/host synchronization before exiting the test.
-    torch.cuda.synchronize()
-
-
-@triton.jit
-def jit_device_assert_none(x):
-    tl.device_assert(x == 0, "x != 0")
-
-
-@triton.jit(debug=True)
-def jit_device_assert_true(x):
-    tl.device_assert(x == 0, "x != 0")
-
-
-@triton.jit(debug=False)
-def jit_device_assert_false(x):
-    tl.device_assert(x == 0, "x != 0")
-
-
-@triton.jit
-def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
-    x = tl.load(X + tl.arange(0, BLOCK))
-    if jit_debug == "true":
-        jit_device_assert_true(x)
-    elif jit_debug == "false":
-        jit_device_assert_false(x)
-    else:
-        jit_device_assert_none(x)
-    tl.store(Y + tl.arange(0, BLOCK), x)
-
-
-@triton.jit(debug=True)
-def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
-    x = tl.load(X + tl.arange(0, BLOCK))
-    if jit_debug == "true":
-        jit_device_assert_true(x)
-    elif jit_debug == "false":
-        jit_device_assert_false(x)
-    else:
-        jit_device_assert_none(x)
-    tl.store(Y + tl.arange(0, BLOCK), x)
-
-
-@triton.jit(debug=False)
-def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr):
-    x = tl.load(X + tl.arange(0, BLOCK))
-    if jit_debug == "true":
-        jit_device_assert_true(x)
-    elif jit_debug == "false":
-        jit_device_assert_false(x)
-    else:
-        jit_device_assert_none(x)
-    tl.store(Y + tl.arange(0, BLOCK), x)
-
-
-def test_assert_nested(caller: str, callee: str, device: str):
-    N = 128  # This value should match with test_print in test_subprocess.py.
-    num_warps = N // get_current_target_warp_size()
-
-    x = torch.arange(0, N, dtype=torch.int32, device=device)
-    y = torch.zeros((N, ), dtype=x.dtype, device=device)
-    if caller == "none":
-        kernel_device_assert_nested[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee)
-    elif caller == "true":
-        kernel_device_assert_nested_true[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee)
-    elif caller == "false":
-        kernel_device_assert_nested_false[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee)
-    assert_close(y, x)
-
-
-if __name__ == "__main__":
-    fn = globals()[sys.argv[1]]
-    fn(*sys.argv[2:])
diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py
index 17118a29bd1f..2ad97e8a6815 100644
--- a/python/test/unit/language/test_subprocess.py
+++ b/python/test/unit/language/test_subprocess.py
@@ -8,11 +8,6 @@
 dir_path = os.path.dirname(os.path.realpath(__file__))
 print_path = os.path.join(dir_path, "print_helper.py")
-assert_path = os.path.join(dir_path, "assert_helper.py")
-
-# TODO: bfloat16 after LLVM-15
-assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"]
-nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
 torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]
@@ -120,59 +115,3 @@ def test_print(func_type: str, data_type: str, device: str):
             continue
         print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
     assert all(delta == 0 for delta in diff.values())
-
-
-@pytest.mark.parametrize("func_type", assert_types)
-def test_assert(func_type: str, device: str):
-    # The total number of elements in the 1-D tensor to assert on.
-    N = 128
-
-    proc = subprocess.run(
-        [sys.executable, assert_path, "test_assert", func_type, device],
-        capture_output=True,
-        env={**os.environ, "TRITON_DEBUG": "1"},
-    )
-    errs = proc.stderr.splitlines()
-    num_errs = 0
-    for err in errs:
-        if "x != 0" in err.decode("utf-8", errors="ignore"):
-            num_errs += 1
-
-    # Check for segfaults.
-    assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs)
-
-    if func_type == "static_assert" or func_type == "device_assert_passes":
-        assert num_errs == 0
-    else:
-        assert num_errs == N - 1
-
-
-@pytest.mark.parametrize("caller_type, callee_type", nested_types)
-def test_assert_nested(caller_type, callee_type, device):
-    # The total number of elements in the 1-D tensor to assert on.
-    N = 128
-
-    proc = subprocess.run(
-        [sys.executable, assert_path, "test_assert_nested", caller_type, callee_type, device],
-        capture_output=True,
-    )
-    errs = proc.stderr.splitlines()
-    num_errs = 0
-    for err in errs:
-        if "x != 0" in err.decode("utf-8", errors="ignore"):
-            num_errs += 1
-    if caller_type == "none":
-        if callee_type == "true":
-            assert num_errs == N - 1
-        else:
-            assert num_errs == 0
-    elif caller_type == "true":
-        if callee_type == "false":
-            assert num_errs == 0
-        else:
-            assert num_errs == N - 1
-    elif caller_type == "false":
-        if callee_type == "true":
-            assert num_errs == N - 1
-        else:
-            assert num_errs == 0
diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index d896e7766859..6a80c5d8492d 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -427,23 +427,18 @@ def kernel_add(a, b, o, N: tl.constexpr):
 def test_jit_debug() -> None:

     @triton.jit
-    def kernel_add(a, b, o, N: tl.constexpr):
-        idx = tl.arange(0, N)
-        tl.device_assert(idx < 32, "idx < 32")
-        tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))
+    def kernel(tmp):
+        tl.device_assert(tl.load(tmp) == 1, "tmp == 1")

     device = torch.cuda.current_device()
-    assert len(kernel_add.cache[device]) == 0
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
-    kernel_add.debug = False
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 2
-    kernel_add.debug = True
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 3
-    bins = list(kernel_add.cache[device].values())
-    assert bins[2].asm['ttir'] != bins[1].asm['ttir']
+    tmp = torch.tensor([1], dtype=torch.int32, device="cuda")
+    assert len(kernel.cache[device]) == 0
+    kernel[(1, )](tmp, debug=False)
+    assert len(kernel.cache[device]) == 1
+    kernel[(1, )](tmp, debug=True)
+    assert len(kernel.cache[device]) == 2
+    bins = list(kernel.cache[device].values())
+    assert bins[0].asm['ttir'] != bins[1].asm['ttir']


 @triton.jit
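
The rewritten cache test depends on `debug` now being a per-launch compile option, so the `debug=False` and `debug=True` launches produce two cache entries whose TTIR differs. A hedged way to see why the IR differs, assuming the device assert lowers to a `tt.assert` op in TTIR (my reading of the dialect, not something this diff states):

```python
# Sketch, reusing `kernel` and `device` from the test above.
bins = list(kernel.cache[device].values())
assert "tt.assert" not in bins[0].asm["ttir"]  # compiled with debug=False
assert "tt.assert" in bins[1].asm["ttir"]      # compiled with debug=True
```
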
diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py
new file mode 100644
index 000000000000..d37062396687
--- /dev/null
+++ b/python/test/unit/test_debug.py
@@ -0,0 +1,124 @@
+import os
+import pytest
+import torch
+import triton.language as tl
+import triton
+
+
+@pytest.mark.parametrize('cond, opt_flag, env_var', [
+    (cond, opt_flag, env_var) for cond in [True, False] \
+    for opt_flag in [True, False] \
+    for env_var in [True, False] \
+])
+@pytest.mark.forked
+def test_device_assert(cond, opt_flag, env_var, device="cuda"):
+    os.environ['TRITON_DEBUG'] = str(int(env_var))
+    torch.zeros([1], dtype=torch.int32, device=device)  # initialize the CUDA context before launching
+
+    @triton.jit
+    def _kernel(COND: tl.constexpr):
+        tl.device_assert(COND, 'test')
+
+    if not cond and (opt_flag or env_var):
+        with pytest.raises(RuntimeError):
+            _kernel[(1, )](cond, debug=opt_flag)
+            torch.cuda.synchronize()
+        return
+
+    _kernel[(1, )](cond, debug=opt_flag)
+    torch.cuda.synchronize()
+
+
+@pytest.mark.parametrize("cond", [False, True])
+def test_static_assert(cond):
+
+    @triton.jit
+    def _kernel(COND: tl.constexpr):
+        tl.static_assert(COND)
+
+    if not cond:
+        with pytest.raises(triton.compiler.errors.CompileTimeAssertionFailure):
+            _kernel[(1, )](cond)
+        return
+
+    _kernel[(1, )](cond)
+
+
+def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func):
+    device = "cuda"
+    x = torch.tensor([x], dtype=getattr(torch, x_dtype), device=device)
+    y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device)
+    z = torch.empty_like(x)
+    if should_overflow and debug:
+        with pytest.raises(RuntimeError) as exc_info:
+            tri_func[(1, )](x, y, z, debug=debug)
+            torch.cuda.synchronize()
+        assert "device-side assert" in str(exc_info.value)
+    else:
+        tri_func[(1, )](x, y, z, debug=debug)
+        torch.cuda.synchronize()
+        assert int(z) == int(ref_func(x, y))
+
+
+# integer overflow sanitization
+
+
+@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [
+    (-2**31, -1, 'int32', 'int32', False, False),
+    (-2**31, -1, 'int32', 'int32', True, True),
+    (2**31 - 1, 1, 'int32', 'int32', True, True),
+    (2**31 - 1, 100, 'int32', 'int32', True, True),
+    (-2**31, 0, 'int32', 'int32', True, False),
+    (-2**31, 2, 'int32', 'int32', True, False),
+    (0, -1, 'int32', 'int32', True, False),
+    (-2**15, -1, 'int16', 'int16', True, True),
+    (2**15 - 1, 1, 'int16', 'int16', True, True),
+])
+@pytest.mark.forked
+def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
+
+    @triton.jit
+    def _kernel_add(X, Y, Z):
+        tl.store(Z, tl.load(X) + tl.load(Y))
+
+    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y)
+
+
+# mul overflow
+
+
+@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [
+    (2**30, 4, 'int32', 'int32', False, False),
+    (2**30, 4, 'int32', 'int32', True, True),
+    (2**30, 2, 'int32', 'int32', True, True),
+    (-2**30, -4, 'int32', 'int32', True, True),
+    (-2**31, 1, 'int32', 'int32', True, False),
+    (-2**30, 2, 'int32', 'int32', True, False),
+])
+@pytest.mark.forked
+def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
+
+    @triton.jit
+    def _kernel_mul(X, Y, Z):
+        tl.store(Z, tl.load(X) * tl.load(Y))
+
+    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y)
+
+
+# sub overflow
+
+
+@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [
+    (-2**31, 1, 'int32', 'int32', False, False),
+    (-2**31, 1, 'int32', 'int32', True, True),
+    (2**31 - 1, -1, 'int32', 'int32', True, True),
+    (2**31 - 1, 1, 'int32', 'int32', True, False),
+    (-2**31, -1, 'int32', 'int32', True, False),
+])
+@pytest.mark.forked
+def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
+
+    @triton.jit
+    def _kernel_sub(X, Y, Z):
+        tl.store(Z, tl.load(X) - tl.load(Y))
+
+    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_sub, lambda x, y: x - y)
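
The overflow tables above follow two's-complement wraparound: for example, `-2**31 + -1` in int32 wraps to `2**31 - 1`, which is exactly the case the sanitizer must flag. A quick self-contained check of the arithmetic behind the parametrized cases:

```python
def wrap_int32(v: int) -> int:
    # Reduce an unbounded Python int into the two's-complement int32 range
    # [-2**31, 2**31 - 1], the way hardware add/sub/mul results wrap.
    return (v + 2**31) % 2**32 - 2**31


assert wrap_int32(-2**31 + -1) == 2**31 - 1  # add case: wraps, assert fires
assert wrap_int32(2**30 * 4) == 0            # mul case: wraps, assert fires
assert wrap_int32(-2**31 - 1) == 2**31 - 1   # sub case: wraps, assert fires
assert wrap_int32(-2**31 + 2) == -2**31 + 2  # in range: no assert expected
```
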
diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py
index ee3426bd4c65..b3f2b4af3882 100644
--- a/python/triton/compiler/code_generator.py
+++ b/python/triton/compiler/code_generator.py
@@ -188,8 +188,8 @@ def visit_Call(self, node: ast.Call) -> bool:
 class CodeGenerator(ast.NodeVisitor):

     def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
-                 codegen_fns, module_map, debug=None, module=None, is_kernel=False,
-                 function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0):
+                 codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None,
+                 noinline=False, file_name: Optional[str] = None, begin_line=0):
         self.context = context
         self.builder = ir.builder(context)
         self.file_name = file_name
@@ -225,7 +225,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
         self.function_name = function_name
         self.is_kernel = is_kernel
         self.cur_node = None
-        self.debug = options.debug if debug is None else debug
         self.noinline = noinline
         self.scf_stack = []
         self.ret_type = None
@@ -1037,11 +1036,8 @@ def visit_keyword(self, node) -> Tuple[str, Any]:
         return node.arg, self.visit(node.value)

     def visit_Assert(self, node) -> Any:
-        if not self.debug:
-            return
         test = self.visit(node.test)
         msg = self.visit(node.msg) if node.msg is not None else ""
-        # Convert assert to triton's device_assert which happens on the device
         return language.core.device_assert(test, msg, _builder=self.builder)

     def call_JitFunction(self, fn: JITFunction, args, kwargs):
@@ -1063,12 +1059,11 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
         gscope = fn.__globals__
         # If the callee is not set, we use the same debug setting as the caller
         file_name, begin_line = get_jit_fn_file_line(fn)
-        debug = self.debug if fn.debug is None else fn.debug
         generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
                                   jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
                                   noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
                                   options=self.builder.options, codegen_fns=self.builder.codegen_fns,
-                                  module_map=self.builder.module_map, debug=debug)
+                                  module_map=self.builder.module_map)
         try:
             generator.visit(fn.parse())
         except Exception as e:
@@ -1100,9 +1095,6 @@ def visit_Call(self, node):
         kws = dict(self.visit(keyword) for keyword in node.keywords)
         args = [self.visit(arg) for arg in node.args]
-        # TODO: this should not be so hardcoded
-        if fn is language.core.device_assert and not self.debug:
-            return
         if isinstance(fn, JITFunction):
             _check_fn_args(node, fn, args)
             return self.call_JitFunction(fn, args, kws)
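
With this change the AST visitor always lowers a Python-level `assert` to `tl.device_assert`, and the keep-or-drop decision lives in exactly one place, `semantic.device_assert` (see the `semantic.py` hunk below). A sketch of what that means for kernel authors: the two spellings below take the same lowering path, and both disappear when the launch is not in debug mode.

```python
import triton
import triton.language as tl


@triton.jit
def _checked_copy(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # Lowered unconditionally to device_assert by the code generator;
    # semantic.device_assert turns it into a no-op when debug is off.
    assert x >= 0, "negative input"
    tl.device_assert(x >= 0, "negative input")  # equivalent spelling
    tl.store(Y + tl.arange(0, BLOCK), x)
```
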
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index cf86e9296a6a..5894a81c4dc0 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -461,6 +461,20 @@ def is_int(self):
     def is_bool(self):
         return self.is_int1()

+    def get_int_max_value(self):
+        if self.is_int_signed():
+            return 2**(self.int_bitwidth - 1) - 1
+        if self.is_int_unsigned():
+            return 2**self.int_bitwidth - 1
+        assert False
+
+    def get_int_min_value(self):
+        if self.is_int_signed():
+            return -2**(self.int_bitwidth - 1)
+        if self.is_int_unsigned():
+            return 0
+        assert False
+
     @staticmethod
     def is_dtype(type_str):
         return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 119fac4bc1e6..4224a705c1a1 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -128,6 +128,26 @@ def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.bui
     return lhs, rhs


+def binary_op_sanitize_overflow_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, binary_op: callable):
+    if lhs.type.scalar.int_bitwidth >= 64 or not builder.options.debug:
+        return
+    lhs_sca_ty = lhs.type.scalar
+    rhs_sca_ty = rhs.type.scalar
+    assert lhs_sca_ty == rhs_sca_ty
+    assert lhs_sca_ty.is_int()
+    lhs = cast(lhs, tl.int64, builder)
+    rhs = cast(rhs, tl.int64, builder)
+    ret = binary_op(lhs, rhs, builder)
+    max_value = lhs_sca_ty.get_int_max_value()
+    max_value = tl.tensor(builder.get_int64(max_value), tl.int64)
+    min_value = lhs_sca_ty.get_int_min_value()
+    min_value = tl.tensor(builder.get_int64(min_value), tl.int64)
+    cond = and_(less_equal(ret, max_value, builder), greater_equal(ret, min_value, builder), builder)
+    cond = splat(cond, [1], builder)
+    msg = "integer overflow detected"
+    builder.create_assert(cond.handle, msg, "unknown", "unknown", 0)
+
+
 def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
     input, other = binary_op_type_checking_impl(input, other, builder, True, True)
     input_scalar_ty = input.type.scalar
@@ -148,6 +168,7 @@ def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
         return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
     # int + int
     elif input_scalar_ty.is_int():
+        binary_op_sanitize_overflow_impl(input, other, builder, add)
         return tl.tensor(builder.create_add(input.handle, other.handle), input.type)
     raise TypeError(f"unexpected type {input_scalar_ty}")
@@ -163,6 +184,7 @@ def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
         return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
     # int - int
     elif scalar_ty.is_int():
+        binary_op_sanitize_overflow_impl(input, other, builder, sub)
         return tl.tensor(builder.create_sub(input.handle, other.handle), input.type)
     raise TypeError(f"unexpected type {scalar_ty}")
@@ -175,6 +197,7 @@ def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
         return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type)
     # int * int
     elif scalar_ty.is_int():
+        binary_op_sanitize_overflow_impl(input, other, builder, mul)
         return tl.tensor(builder.create_mul(input.handle, other.handle), input.type)
     raise TypeError(f"unexpected type {scalar_ty}")
@@ -1544,6 +1567,8 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
 def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
+    if not builder.options.debug:
+        return
     cond_ty = cond.type
     if not cond_ty.is_block():
         cond_ty = tl.block_type(cond_ty.scalar, (1, ))
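
The new `dtype` helpers encode the standard n-bit integer bounds: signed `[-2**(n-1), 2**(n-1) - 1]`, unsigned `[0, 2**n - 1]`. The sanitizer redoes the operation in int64 and asserts that the widened result stays inside the original type's bounds, which is also why it bails out for 64-bit operands: int64 cannot hold their widened results. A few spot checks against the formulas:

```python
import triton.language as tl

# Signed: n = 32 -> [-2**31, 2**31 - 1]
assert tl.int32.get_int_max_value() == 2**31 - 1
assert tl.int32.get_int_min_value() == -2**31
# Unsigned: n = 8 -> [0, 2**8 - 1]
assert tl.uint8.get_int_max_value() == 255
assert tl.uint8.get_int_min_value() == 0
```
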
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index d65510624f55..2a3c31a23be9 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -598,10 +598,11 @@ def create_binder(self):
         ]

     def run(self, *args, grid, warmup, **kwargs):
+        kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1"
+
         # parse options
         device = driver.active.get_current_device()
         stream = driver.active.get_current_stream(device)
-        kwargs["debug"] = self.debug

         # Execute pre run hooks with args and kwargs
         for hook in self.pre_run_hooks:
@@ -732,7 +733,6 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
         # JITFunction can be instantiated as kernel
         # when called with a grid using __getitem__
         self.kernel = None
-        self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug
         self.noinline = noinline

         # TODO(jlebar): Remove uses of these fields outside this file, then
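
End to end, the reworked flag is now resolved per launch: checks are enabled when either the `debug=` keyword is true or `TRITON_DEBUG=1` is set in the environment (the env var ORs in, so it also overrides an explicit `debug=False`). A usage sketch; the kernel and tensor are illustrative:

```python
import os
import torch
import triton
import triton.language as tl


@triton.jit
def _kernel(X):
    tl.device_assert(tl.load(X) == 0, "x != 0")


x = torch.zeros(1, dtype=torch.int32, device="cuda")
_kernel[(1, )](x, debug=True)  # assertions compiled in for this launch only
_kernel[(1, )](x)              # default debug=False: assertions compiled out

os.environ["TRITON_DEBUG"] = "1"
_kernel[(1, )](x)              # env var enables assertions for this launch
```
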