Skip to content

Commit

Permalink
Fix Triton Compile Error for Codegened Dropout Code (#17899)
Browse files Browse the repository at this point in the history
  • Loading branch information
centwang authored Oct 12, 2023
1 parent 9d07ca3 commit fa0a79a
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def ReduceKernelNode( # noqa: N802
"Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n",
"Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n",
"Log": "{indent}{o0} = tl.log({i0})\n",
"DropoutGrad": "{indent}p = 1 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n",
"DropoutGrad": "{indent}p = 1.0 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n",
"Identity": "{indent}{o0} = {i0}\n",
}

Expand Down Expand Up @@ -420,7 +420,7 @@ def DropoutNode( # noqa: N802
offset_str = f"{node.global_offset} + " if node.global_offset != sympy.Integer(0) else ""
offset_str += self._get_offset_mask(node.offset_calc, node.inputs[0].name)[0]
code_buffer += (
f"{space_indent}p = 1 - {p_var_name}\n"
f"{space_indent}p = 1.0 - {p_var_name}\n"
f"{space_indent}random = tl.rand(t_seed_cuda, {offset_str})\n"
f"{space_indent}{mask_var_name} = random < p\n"
f"{space_indent}{output_var_name} = tl.where({mask_var_name}, {input_var_name} / p, 0.0)\n"
Expand Down

0 comments on commit fa0a79a

Please sign in to comment.