Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into snnn/update_ndk
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Oct 12, 2023
2 parents 1ab4b33 + fa0a79a commit ee32a50
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def ReduceKernelNode( # noqa: N802
"Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n",
"Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n",
"Log": "{indent}{o0} = tl.log({i0})\n",
"DropoutGrad": "{indent}p = 1 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n",
"DropoutGrad": "{indent}p = 1.0 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n",
"Identity": "{indent}{o0} = {i0}\n",
}

Expand Down Expand Up @@ -420,7 +420,7 @@ def DropoutNode( # noqa: N802
offset_str = f"{node.global_offset} + " if node.global_offset != sympy.Integer(0) else ""
offset_str += self._get_offset_mask(node.offset_calc, node.inputs[0].name)[0]
code_buffer += (
f"{space_indent}p = 1 - {p_var_name}\n"
f"{space_indent}p = 1.0 - {p_var_name}\n"
f"{space_indent}random = tl.rand(t_seed_cuda, {offset_str})\n"
f"{space_indent}{mask_var_name} = random < p\n"
f"{space_indent}{output_var_name} = tl.where({mask_var_name}, {input_var_name} / p, 0.0)\n"
Expand Down

0 comments on commit ee32a50

Please sign in to comment.