diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index 8e21013da2353..0bf402b750115 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -280,7 +280,7 @@ def ReduceKernelNode( # noqa: N802 "Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n", "Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n", "Log": "{indent}{o0} = tl.log({i0})\n", - "DropoutGrad": "{indent}p = 1 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n", + "DropoutGrad": "{indent}p = 1.0 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n", "Identity": "{indent}{o0} = {i0}\n", } @@ -420,7 +420,7 @@ def DropoutNode( # noqa: N802 offset_str = f"{node.global_offset} + " if node.global_offset != sympy.Integer(0) else "" offset_str += self._get_offset_mask(node.offset_calc, node.inputs[0].name)[0] code_buffer += ( - f"{space_indent}p = 1 - {p_var_name}\n" + f"{space_indent}p = 1.0 - {p_var_name}\n" f"{space_indent}random = tl.rand(t_seed_cuda, {offset_str})\n" f"{space_indent}{mask_var_name} = random < p\n" f"{space_indent}{output_var_name} = tl.where({mask_var_name}, {input_var_name} / p, 0.0)\n"