
Commit 6d1b1dd

amend
vmoens committed Sep 30, 2024
1 parent 1f2e931 commit 6d1b1dd
Showing 2 changed files with 46 additions and 30 deletions.
14 changes: 9 additions & 5 deletions tensordict/base.py
@@ -7916,7 +7916,9 @@ def lerp(
             end_val = end._values_list(True, True)
         else:
             end_val = end
-        if _is_tensor_collection(type(weight)):
+        if isinstance(weight, (float, torch.Tensor)):
+            weight_val = weight
+        elif _is_tensor_collection(type(weight)):
             weight_val = weight._values_list(True, True)
         else:
             weight_val = weight
@@ -7936,10 +7938,12 @@ def lerp_(self, end: TensorDictBase | float, weight: TensorDictBase | float):
             end_val = end._values_list(True, True)
         else:
             end_val = end
-        # if isinstance(weight, TensorDictBase) or _is_tensor_collection(type(weight)):
-        #     weight_val = weight._values_list(True, True)
-        #else:
-        weight_val = weight
+        if isinstance(weight, (float, torch.Tensor)):
+            weight_val = weight
+        elif _is_tensor_collection(type(weight)):
+            weight_val = weight._values_list(True, True)
+        else:
+            weight_val = weight
         torch._foreach_lerp_(self._values_list(True, True), end_val, weight_val)
         return self

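For context on the tensordict/base.py hunks above: lerp and lerp_ now pass a plain float or torch.Tensor weight straight through to torch._foreach_lerp_ instead of treating it as a tensor collection, while TensorDict weights are still flattened with _values_list. A minimal usage sketch of the patched in-place variant (keys and shapes are illustrative, not part of the commit):

import torch
from tensordict import TensorDict

start = TensorDict({"a": torch.zeros(3), "b": torch.zeros(3)}, batch_size=[3])
end = TensorDict({"a": torch.ones(3), "b": torch.full((3,), 2.0)}, batch_size=[3])

# Scalar (or tensor) weight: forwarded as-is to torch._foreach_lerp_.
start.lerp_(end, 0.5)
print(start["a"])  # tensor([0.5000, 0.5000, 0.5000])

# TensorDict weight: flattened into a list of leaf tensors, one weight per leaf.
weight = TensorDict(
    {"a": torch.full((3,), 0.25), "b": torch.full((3,), 0.75)}, batch_size=[3]
)
start.lerp_(end, weight)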
62 changes: 37 additions & 25 deletions tensordict/nn/cudagraphs.py
@@ -36,6 +36,7 @@ def tree_leaves(pytree):
     """Torch 2.0 compatible version of tree_leaves."""
     return tree_flatten(pytree)[0]
 
+
 class CudaGraphModule:
     """A cudagraph wrapper for PyTorch callables.
@@ -198,6 +199,8 @@ def __init__(
             )
             _exclude_td_from_pytree().set()
 
+        functools.update_wrapper(self, module)
+
         if self._is_tensordict_module:
 
             @dispatch(source=self.in_keys, dest=self.out_keys, auto_batch_size=False)
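The functools.update_wrapper(self, module) call added above copies the wrapped callable's metadata (__name__, __doc__, __module__, ...) onto the CudaGraphModule instance and records it as __wrapped__. A minimal standalone toy of the stdlib call, unrelated to the library's actual class (Wrapper and my_policy are made-up names):

import functools

def my_policy(x):
    """Compute an action from an observation."""
    return x

class Wrapper:
    def __init__(self, fn):
        self.fn = fn
        # Copy __name__, __doc__, __module__, ... from fn onto the wrapper instance
        # and set self.__wrapped__ = fn, so introspection still sees the original.
        functools.update_wrapper(self, fn)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

w = Wrapper(my_policy)
print(w.__name__)  # my_policy
print(w.__doc__)   # Compute an action from an observation.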
@@ -208,17 +211,19 @@ def _call(
                 **kwargs: Any,
             ) -> Any:
                 if self.counter < self._warmup:
-                    torch.cuda.synchronize()
+                    if self._has_cuda:
+                        torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         if tensordict_out is not None:
                             kwargs["tensordict_out"] = tensordict_out
                         out = self.module(tensordict, *args, **kwargs)
                     if self._out_matches_in is None:
                         self._out_matches_in = out is tensordict
                     self.counter += self._has_cuda
-                    torch.cuda.synchronize()
+                    if self._has_cuda:
+                        torch.cuda.synchronize()
                     return out
-                elif self.counter == self._warmup:
+                elif self.counter == self._warmup - 1:
                     if tensordict.device is None:
                         tensordict.apply(self._check_device_and_grad, filter_empty=True)
                     elif tensordict.device.type != "cuda":
@@ -277,39 +282,29 @@ def check_tensor_id(name, t0, t1):
                     else:
                         result = self._out.clone() if self._out is not None else None
                     return result

         else:

             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 if self.counter < self._warmup:
-                    torch.cuda.synchronize()
+                    if self._has_cuda:
+                        torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         out = self.module(*args, **kwargs)
-                    torch.cuda.synchronize()
+                    if self._has_cuda:
+                        torch.cuda.synchronize()
                     self.counter += self._has_cuda
                     return out
-                elif self.counter == self._warmup:
-
-                    def check_device_and_clone(x):
-                        if isinstance(x, torch.Tensor) or is_tensor_collection(x):
-                            if x.requires_grad:
-                                raise RuntimeError(self._REQUIRES_GRAD_ERROR)
-                            if x.device is None:
-                                # Check device of leaves of tensordict
-                                x.apply(self._check_device_and_grad, filter_empty=True)
-                            elif x.device.type != "cuda":
-                                raise ValueError(
-                                    f"All tensors must be stored on CUDA. Got {x.device.type}."
-                                )
-                            return x.clone()
-                        return x
+                elif self.counter == self._warmup - 1:

                     self._args, self._kwargs = tree_map(
-                        check_device_and_clone, (args, kwargs)
+                        self._check_device_and_clone, (args, kwargs)
                     )

                     torch.cuda.synchronize()
                     this_out = self.module(*self._args, **self._kwargs)
                     torch.cuda.synchronize()

                     self.graph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.graph):
                         out = self.module(*self._args, **self._kwargs)
@@ -350,15 +345,32 @@ def check_device_and_clone(x):
                         result = self._out
                     else:
                         result = tree_map(
-                            lambda x: x.detach().clone() if x is not None else x, self._out
+                            lambda x: x.detach().clone() if x is not None else x,
+                            self._out,
                         )
                     # torch.cuda.synchronize()
                     return result

         _call_func = functools.wraps(self.module)(_call)
         self._call_func = _call_func

+    @classmethod
+    def _check_device_and_clone(cls, x):
+        if isinstance(x, torch.Tensor) or is_tensor_collection(x):
+            if x.requires_grad:
+                raise RuntimeError(cls._REQUIRES_GRAD_ERROR)
+            if x.device is None:
+                # Check device of leaves of tensordict
+                x.apply(cls._check_device_and_grad, filter_empty=True)
+            elif x.device.type != "cuda":
+                raise ValueError(
+                    f"All tensors must be stored on CUDA. Got {x.device.type}."
+                )
+            return x.clone()
+        return x
+
     @classmethod
     def _check_device_and_grad(cls, x):
         if isinstance(x, torch.Tensor):
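Taken together, the cudagraphs.py changes make the eager-warmup synchronizations conditional on CUDA being available, capture the graph when self.counter reaches self._warmup - 1, and promote the device/grad check to a reusable _check_device_and_clone classmethod. A rough usage sketch of the wrapper in its plain-callable mode (the function, shapes, and warmup value are made up for illustration):

import torch
from tensordict.nn import CudaGraphModule

def step(x: torch.Tensor) -> torch.Tensor:
    # A side-effect-free callable operating on CUDA tensors.
    return (x * 2.0).sin()

if torch.cuda.is_available():
    gm = CudaGraphModule(step, warmup=2)
    x = torch.randn(16, device="cuda")  # must live on CUDA and not require grad
    for _ in range(4):
        # Early calls run eagerly for warmup; once the graph is captured,
        # later calls copy the inputs into the captured buffers and replay it.
        y = gm(x)
    print(y.shape)  # torch.Size([16])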
