Commit 1f2e931: init
vmoens committed Sep 30, 2024
1 parent 4894d1e commit 1f2e931
Showing 2 changed files with 22 additions and 34 deletions.
8 changes: 4 additions & 4 deletions tensordict/base.py
@@ -7936,10 +7936,10 @@ def lerp_(self, end: TensorDictBase | float, weight: TensorDictBase | float):
             end_val = end._values_list(True, True)
         else:
             end_val = end
-        if _is_tensor_collection(type(weight)):
-            weight_val = weight._values_list(True, True)
-        else:
-            weight_val = weight
+        # if isinstance(weight, TensorDictBase) or _is_tensor_collection(type(weight)):
+        #     weight_val = weight._values_list(True, True)
+        #else:
+        weight_val = weight
         torch._foreach_lerp_(self._values_list(True, True), end_val, weight_val)
         return self
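With the branch above commented out, `weight` is now passed to `torch._foreach_lerp_` as-is instead of being flattened into a list of leaf tensors. A standalone sketch of the underlying foreach call with a scalar weight (illustration only, not part of the diff):

    import torch

    # In-place lerp across flat lists of leaf tensors, as dispatched by lerp_:
    # start[i] <- start[i] + weight * (end[i] - start[i])
    start = [torch.zeros(3), torch.zeros(3)]
    end = [torch.ones(3), torch.full((3,), 2.0)]
    torch._foreach_lerp_(start, end, 0.5)
    print(start)  # [tensor([0.5000, 0.5000, 0.5000]), tensor([1., 1., 1.])]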

48 changes: 18 additions & 30 deletions tensordict/nn/cudagraphs.py
@@ -36,7 +36,6 @@ def tree_leaves(pytree):
     """Torch 2.0 compatible version of tree_leaves."""
     return tree_flatten(pytree)[0]
 
-
 class CudaGraphModule:
     """A cudagraph wrapper for PyTorch callables.
@@ -209,17 +208,15 @@ def _call(
                 **kwargs: Any,
             ) -> Any:
                 if self.counter < self._warmup:
-                    if self._warmup_stream is not None:
-                        self._warmup_stream.wait_stream(torch.cuda.current_stream())
+                    torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         if tensordict_out is not None:
                             kwargs["tensordict_out"] = tensordict_out
                         out = self.module(tensordict, *args, **kwargs)
                         if self._out_matches_in is None:
                             self._out_matches_in = out is tensordict
                         self.counter += self._has_cuda
-                    if self._warmup_stream is not None:
-                        torch.cuda.current_stream().wait_stream(self._warmup_stream)
+                    torch.cuda.synchronize()
                     return out
                 elif self.counter == self._warmup:
                     if tensordict.device is None:
@@ -230,15 +227,17 @@
                         )
 
                 tree_map(self._check_non_tensor, (args, kwargs))
-                self._tensordict = tensordict.copy()
-
-                self.graph = torch.cuda.CUDAGraph()
+                torch.cuda.synchronize()
+                self._tensordict = tensordict.copy()
+                this_out = self.module(self._tensordict, *args, **kwargs)
+                torch.cuda.synchronize()
+
+                self.graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.graph):
                     if tensordict_out is not None:
                         kwargs["tensordict_out"] = tensordict_out
                     out = self.module(self._tensordict, *args, **kwargs)
+                self.graph.replay()
 
                 if not is_tensor_collection(out) and out is not None:
                     raise RuntimeError(
@@ -265,15 +264,9 @@ def check_tensor_id(name, t0, t1):
                             default=None,
                             filter_empty=True,
                         )
-                    return tensordict.update(
-                        self._out, keys_to_update=self._selected_keys
-                    )
-                if tensordict_out is not None:
-                    return tensordict_out.update(out, clone=True)
-                return out.clone() if self._out is not None else None
+                return this_out
             else:
                 self._tensordict.update_(tensordict, non_blocking=True)
-                torch.cuda.synchronize()
                 self.graph.replay()
                 if self._out_matches_in:
                     result = tensordict.update(
@@ -283,17 +276,14 @@ def check_tensor_id(name, t0, t1):
                     result = tensordict_out.update(self._out, clone=True)
                 else:
                     result = self._out.clone() if self._out is not None else None
-                torch.cuda.synchronize()
                 return result
         else:
             def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 if self.counter < self._warmup:
-                    if self._warmup_stream is not None:
-                        self._warmup_stream.wait_stream(torch.cuda.current_stream())
+                    torch.cuda.synchronize()
                     with self._warmup_stream_cm:
                         out = self.module(*args, **kwargs)
-                    if self._warmup_stream is not None:
-                        torch.cuda.current_stream().wait_stream(self._warmup_stream)
+                    torch.cuda.synchronize()
                     self.counter += self._has_cuda
                     return out
                 elif self.counter == self._warmup:
@@ -317,6 +307,9 @@ def check_device_and_clone(x):
                 self._args, self._kwargs = tree_map(
                     check_device_and_clone, (args, kwargs)
                 )
+                torch.cuda.synchronize()
+                this_out = self.module(*self._args, **self._kwargs)
+                torch.cuda.synchronize()
                 self.graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.graph):
                     out = self.module(*self._args, **self._kwargs)
@@ -341,30 +334,25 @@ def check_device_and_clone(x):
                         self._return_unchanged = (
                             "clone" if self._out is not None else True
                         )
-                        return (
-                            self._out.clone()
-                            if self._return_unchanged == "clone"
-                            else self._out
-                        )
-                    self._return_unchanged = False
-                    return tree_map(lambda x: x.clone(), out)
+                    else:
+                        self._return_unchanged = False
+                    return this_out
                 else:
                     tree_map(
                         lambda x, y: x.copy_(y, non_blocking=True),
                         (self._args, self._kwargs),
                         (args, kwargs),
                     )
-                    torch.cuda.synchronize()
                     self.graph.replay()
                     if self._return_unchanged == "clone":
                         result = self._out.clone()
                     elif self._return_unchanged:
                         result = self._out
                     else:
                         result = tree_map(
-                            lambda x: x.clone() if x is not None else x, self._out
+                            lambda x: x.detach().clone() if x is not None else x, self._out
                         )
-                    torch.cuda.synchronize()
+                    # torch.cuda.synchronize()
                     return result
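The pattern these hunks introduce is PyTorch's usual capture/replay recipe: run the callable eagerly once between `torch.cuda.synchronize()` calls (the new `this_out` warmup call), capture it under `torch.cuda.graph`, then replay after copying fresh inputs into the captured buffers. A minimal standalone sketch of that recipe, assuming a CUDA device (illustration only, not code from this commit):

    import torch

    def f(x):
        return x * 2 + 1

    static_x = torch.ones(8, device="cuda")

    torch.cuda.synchronize()   # mirrors the sync before the warmup call above
    warmup_out = f(static_x)   # eager warmup run, like `this_out`
    torch.cuda.synchronize()   # mirrors the sync after the warmup call

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):   # capture: records kernels, pins buffers
        static_out = f(static_x)

    static_x.copy_(torch.arange(8.0, device="cuda"))  # refresh inputs in place
    graph.replay()                                    # rerun recorded kernels
    print(static_out)                                 # reflects the new inputs

Replay reuses the memory addresses recorded at capture time, which is why the wrapper copies into `self._tensordict` / `self._args` in place before calling `self.graph.replay()`.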


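For context, a hedged sketch of how the wrapper modified above is typically driven; the `warmup` argument and the `TensorDictModule` wiring are assumptions based on the surrounding code, not part of this commit:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CudaGraphModule, TensorDictModule

    module = TensorDictModule(
        torch.nn.Linear(4, 4, device="cuda"), in_keys=["x"], out_keys=["y"]
    )
    gm = CudaGraphModule(module, warmup=2)  # warmup kwarg assumed

    td = TensorDict({"x": torch.randn(8, 4, device="cuda")}, batch_size=[8], device="cuda")
    for _ in range(4):
        out = gm(td)  # early calls run eagerly (warmup); later calls replay the graph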