From fa0a79a921e60f3db2df3cc07f2e8ef858818acf Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 12 Oct 2023 20:57:14 +0800 Subject: [PATCH] Fix Triton Compile Error for Codegened Dropout Code (#17899) --- .../orttraining/python/training/ort_triton/_codegen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index 8e21013da2353..0bf402b750115 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -280,7 +280,7 @@ def ReduceKernelNode( # noqa: N802 "Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n", "Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n", "Log": "{indent}{o0} = tl.log({i0})\n", - "DropoutGrad": "{indent}p = 1 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n", + "DropoutGrad": "{indent}p = 1.0 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n", "Identity": "{indent}{o0} = {i0}\n", } @@ -420,7 +420,7 @@ def DropoutNode( # noqa: N802 offset_str = f"{node.global_offset} + " if node.global_offset != sympy.Integer(0) else "" offset_str += self._get_offset_mask(node.offset_calc, node.inputs[0].name)[0] code_buffer += ( - f"{space_indent}p = 1 - {p_var_name}\n" + f"{space_indent}p = 1.0 - {p_var_name}\n" f"{space_indent}random = tl.rand(t_seed_cuda, {offset_str})\n" f"{space_indent}{mask_var_name} = random < p\n" f"{space_indent}{output_var_name} = tl.where({mask_var_name}, {input_var_name} / p, 0.0)\n"