Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Sep 30, 2024
1 parent b3bc198 commit 2608f16
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions tensordict/nn/cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _call(
self._tensordict = tensordict.copy()

torch.cuda.synchronize()
this_out = self.module(self._tensordict, *args, **kwargs)
this_out = self.module(tensordict, *args, **kwargs)
torch.cuda.synchronize()

self.graph = torch.cuda.CUDAGraph()
Expand Down Expand Up @@ -303,7 +303,6 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
lambda x: x.detach().clone() if x is not None else x,
self._out,
)
# torch.cuda.synchronize()
return result

if not self._has_cuda or self.counter < self._warmup - 1:
Expand All @@ -321,13 +320,12 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
)

torch.cuda.synchronize()
this_out = self.module(*self._args, **self._kwargs)
this_out = self.module(*args, **kwargs)
torch.cuda.synchronize()

self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
out = self.module(*self._args, **self._kwargs)
self.graph.replay()
self._out = out
self.counter += 1
# Check that there is no intersection between the identity of inputs and outputs, otherwise warn
Expand Down

0 comments on commit 2608f16

Please sign in to comment.