Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/juliagmt-google/benchmark i…
Browse files Browse the repository at this point in the history
…nto testing
  • Loading branch information
juliagmt-google committed Oct 1, 2024
2 parents 2e8ea36 + 252a3b1 commit fbf2498
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchbenchmark/operators/low_mem_dropout/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def triton_dropout(self, p, x):
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
x_keep = (torch.rand(size=(n_elements,)) > p).to(torch.int32).cuda()

def _inner():
return _triton_dropout[grid](
Expand Down

0 comments on commit fbf2498

Please sign in to comment.