Commit b3bc198: amend
vmoens committed Sep 30, 2024
1 parent 6d1b1dd

Showing 1 changed file with 36 additions and 35 deletions: tensordict/nn/cudagraphs.py
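The diff below restructures `CudaGraphModule._call` so that the steady-state replay path (`counter >= warmup`) is handled first with an early return, rather than in a trailing `else` branch. For orientation, here is a minimal usage sketch of the wrapper being edited; the exact `CudaGraphModule` signature (notably the `warmup` argument) is an assumption inferred from this code, not confirmed by this page:

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule, TensorDictModule

# Wrap a plain nn.Module so it reads and writes TensorDict keys.
net = torch.nn.Linear(4, 4, device="cuda")
module = TensorDictModule(net, in_keys=["x"], out_keys=["y"])

# warmup: number of eager calls before the CUDA graph is captured (assumed arg).
gmodule = CudaGraphModule(module, warmup=2)

td = TensorDict({"x": torch.randn(8, 4, device="cuda")}, batch_size=[8])
for _ in range(4):  # early calls warm up; later calls replay the captured graph
    out = gmodule(td)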
@@ -210,7 +210,20 @@ def _call(
                 tensordict_out: TensorDictBase | None = None,
                 **kwargs: Any,
             ) -> Any:
-                if self.counter < self._warmup:
+                if self.counter >= self._warmup:
+                    self._tensordict.update_(tensordict, non_blocking=True)
+                    self.graph.replay()
+                    if self._out_matches_in:
+                        result = tensordict.update(
+                            self._out, keys_to_update=self._selected_keys
+                        )
+                    elif tensordict_out is not None:
+                        result = tensordict_out.update(self._out, clone=True)
+                    else:
+                        result = self._out.clone() if self._out is not None else None
+                    return result
+
+                if not self._has_cuda or self.counter < self._warmup - 1:
                     if self._has_cuda:
                         torch.cuda.synchronize()
                     with self._warmup_stream_cm:
@@ -223,7 +236,7 @@ def _call(
                     if self._has_cuda:
                         torch.cuda.synchronize()
                     return out
-                elif self.counter == self._warmup - 1:
+                else:
                     if tensordict.device is None:
                         tensordict.apply(self._check_device_and_grad, filter_empty=True)
                     elif tensordict.device.type != "cuda":
@@ -270,23 +283,30 @@ def check_tensor_id(name, t0, t1):
                         filter_empty=True,
                     )
                     return this_out
-                else:
-                    self._tensordict.update_(tensordict, non_blocking=True)
-                    self.graph.replay()
-                    if self._out_matches_in:
-                        result = tensordict.update(
-                            self._out, keys_to_update=self._selected_keys
-                        )
-                    elif tensordict_out is not None:
-                        result = tensordict_out.update(self._out, clone=True)
-                    else:
-                        result = self._out.clone() if self._out is not None else None
-                    return result

         else:

             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
-                if self.counter < self._warmup:
+                if self.counter >= self._warmup:
+                    tree_map(
+                        lambda x, y: x.copy_(y, non_blocking=True),
+                        (self._args, self._kwargs),
+                        (args, kwargs),
+                    )
+                    self.graph.replay()
+                    if self._return_unchanged == "clone":
+                        result = self._out.clone()
+                    elif self._return_unchanged:
+                        result = self._out
+                    else:
+                        result = tree_map(
+                            lambda x: x.detach().clone() if x is not None else x,
+                            self._out,
+                        )
+                    # torch.cuda.synchronize()
+                    return result
+
+                if not self._has_cuda or self.counter < self._warmup - 1:
                     if self._has_cuda:
                         torch.cuda.synchronize()
                     with self._warmup_stream_cm:
@@ -295,8 +315,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                         torch.cuda.synchronize()
                     self.counter += self._has_cuda
                     return out
-                elif self.counter == self._warmup - 1:
-
+                else:
                     self._args, self._kwargs = tree_map(
                         self._check_device_and_clone, (args, kwargs)
                     )
@@ -332,24 +351,6 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     else:
                         self._return_unchanged = False
                     return this_out
-                else:
-                    tree_map(
-                        lambda x, y: x.copy_(y, non_blocking=True),
-                        (self._args, self._kwargs),
-                        (args, kwargs),
-                    )
-                    self.graph.replay()
-                    if self._return_unchanged == "clone":
-                        result = self._out.clone()
-                    elif self._return_unchanged:
-                        result = self._out
-                    else:
-                        result = tree_map(
-                            lambda x: x.detach().clone() if x is not None else x,
-                            self._out,
-                        )
-                    # torch.cuda.synchronize()
-                    return result

         _call_func = functools.wraps(self.module)(_call)
         self._call_func = _call_func
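Both branches above follow the same three-phase CUDA-graph lifecycle that this commit reorders: run eagerly on a side stream while warming up, capture once into static buffers on the last warmup step, then replay with copy-in/copy-out on every later call. A standalone sketch of that pattern with raw torch.cuda.CUDAGraph (the helper name and structure here are illustrative, not the module's internals):

import torch

@torch.no_grad()
def make_graphed(fn, example, warmup=3):
    # Static input buffer: the captured graph always reads from this tensor.
    static_in = example.clone()

    # Warm up eagerly on a side stream before capture, as PyTorch recommends.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(warmup):
            fn(static_in)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture: record the kernels launched by fn into a replayable graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = fn(static_in)

    def run(x):
        static_in.copy_(x, non_blocking=True)  # feed new inputs into the static buffer
        graph.replay()                         # re-launch the recorded kernels
        return static_out.clone()              # detach the result from the static buffer

    return run

net = torch.nn.Linear(4, 4, device="cuda").eval()
graphed = make_graphed(net, torch.randn(8, 4, device="cuda"))
out = graphed(torch.randn(8, 4, device="cuda"))

Cloning the static output on the way out mirrors the `self._out.clone()` calls in the diff: replay overwrites the same buffers on every invocation, so callers must not keep references into them.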