diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index 9a555ebab..ecaef85b6 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -210,7 +210,20 @@ def _call(
                 tensordict_out: TensorDictBase | None = None,
                 **kwargs: Any,
             ) -> Any:
-                if self.counter < self._warmup:
+                if self.counter >= self._warmup:
+                    self._tensordict.update_(tensordict, non_blocking=True)
+                    self.graph.replay()
+                    if self._out_matches_in:
+                        result = tensordict.update(
+                            self._out, keys_to_update=self._selected_keys
+                        )
+                    elif tensordict_out is not None:
+                        result = tensordict_out.update(self._out, clone=True)
+                    else:
+                        result = self._out.clone() if self._out is not None else None
+                    return result
+
+                if not self._has_cuda or self.counter < self._warmup - 1:
                     if self._has_cuda:
                         torch.cuda.synchronize()
                     with self._warmup_stream_cm:
@@ -223,7 +236,7 @@ def _call(
                 if self._has_cuda:
                     torch.cuda.synchronize()
                 return out
-            elif self.counter == self._warmup - 1:
+            else:
                 if tensordict.device is None:
                     tensordict.apply(self._check_device_and_grad, filter_empty=True)
                 elif tensordict.device.type != "cuda":
@@ -270,23 +283,30 @@ def check_tensor_id(name, t0, t1):
                             filter_empty=True,
                         )
                     return this_out
-                else:
-                    self._tensordict.update_(tensordict, non_blocking=True)
-                    self.graph.replay()
-                    if self._out_matches_in:
-                        result = tensordict.update(
-                            self._out, keys_to_update=self._selected_keys
-                        )
-                    elif tensordict_out is not None:
-                        result = tensordict_out.update(self._out, clone=True)
-                    else:
-                        result = self._out.clone() if self._out is not None else None
-                    return result
 
         else:
 
             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
-                if self.counter < self._warmup:
+                if self.counter >= self._warmup:
+                    tree_map(
+                        lambda x, y: x.copy_(y, non_blocking=True),
+                        (self._args, self._kwargs),
+                        (args, kwargs),
+                    )
+                    self.graph.replay()
+                    if self._return_unchanged == "clone":
+                        result = self._out.clone()
+                    elif self._return_unchanged:
+                        result = self._out
+                    else:
+                        result = tree_map(
+                            lambda x: x.detach().clone() if x is not None else x,
+                            self._out,
+                        )
+                    # torch.cuda.synchronize()
+                    return result
+
+                if not self._has_cuda or self.counter < self._warmup - 1:
                     if self._has_cuda:
                         torch.cuda.synchronize()
                     with self._warmup_stream_cm:
@@ -295,8 +315,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                         torch.cuda.synchronize()
                     self.counter += self._has_cuda
                     return out
-                elif self.counter == self._warmup - 1:
-
+                else:
                     self._args, self._kwargs = tree_map(
                         self._check_device_and_clone, (args, kwargs)
                     )
@@ -332,24 +351,6 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     else:
                         self._return_unchanged = False
                     return this_out
-                else:
-                    tree_map(
-                        lambda x, y: x.copy_(y, non_blocking=True),
-                        (self._args, self._kwargs),
-                        (args, kwargs),
-                    )
-                    self.graph.replay()
-                    if self._return_unchanged == "clone":
-                        result = self._out.clone()
-                    elif self._return_unchanged:
-                        result = self._out
-                    else:
-                        result = tree_map(
-                            lambda x: x.detach().clone() if x is not None else x,
-                            self._out,
-                        )
-                    # torch.cuda.synchronize()
-                    return result
 
         _call_func = functools.wraps(self.module)(_call)
         self._call_func = _call_func
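
Usage note (not part of the patch): the change above only hoists the post-warmup replay path to the top of each `_call` variant and folds the remaining warmup/capture branches into an if/else; caller-visible behavior is unchanged. Below is a minimal sketch of the call pattern these branches serve, assuming a CUDA device is available and using the public `CudaGraphModule` wrapper whose `warmup` argument maps to `self._warmup` in the diff:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CudaGraphModule, TensorDictModule

    module = TensorDictModule(
        torch.nn.Linear(3, 4, device="cuda"), in_keys=["x"], out_keys=["y"]
    )
    gm = CudaGraphModule(module, warmup=2)

    td = TensorDict({"x": torch.randn(8, 3, device="cuda")}, batch_size=[8])
    for i in range(4):
        # Calls 0..warmup-2 run eagerly on the warmup stream, call warmup-1
        # captures the CUDA graph, and every later call takes the replay
        # branch that the patch moves to the top of _call.
        out = gm(td)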