
Commit
fix: revert layer offload iteration
Rypo committed Nov 28, 2024
1 parent 3b8057c commit 2fd6a5d
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions OmniGen/transformer.py
@@ -33,13 +33,15 @@ def prefetch_layer(self, layer_idx: int, device: torch.device):
         "Starts prefetching the next layer cache"
         with torch.cuda.stream(self.prefetch_stream):
             # Prefetch next layer tensors to GPU
-            self.layers[layer_idx] = self.layers[layer_idx].to(device, non_blocking=True)
+            for name, param in self.layers[layer_idx].named_parameters():
+                param.data = param.data.to(device, non_blocking=True)
 
     def evict_previous_layer(self, layer_idx: int):
         "Moves the previous layer cache to the CPU"
         prev_layer_idx = layer_idx - 1
-        self.layers[prev_layer_idx] = self.layers[prev_layer_idx].to("cpu")
-
+        for name, param in self.layers[prev_layer_idx].named_parameters():
+            param.data = param.data.to("cpu")
+
     def get_offload_layer(self, layer_idx: int, device: torch.device):
         # init stream
         if not hasattr(self, "prefetch_stream"):
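For context: the parent commit had replaced the per-parameter loops with whole-module reassignments (self.layers[i] = self.layers[i].to(...)), and this commit reverts that. Below is a minimal, self-contained sketch of how the three methods fit together in a stream-based layer-offloading scheme. The OffloadedLayers wrapper class is hypothetical, and the body of get_offload_layer past the visible hasattr check is an assumption based on the common prefetch/evict pattern, not the repository's actual code.

    import torch
    import torch.nn as nn

    class OffloadedLayers(nn.Module):
        # Hypothetical wrapper; in OmniGen these methods live on the
        # transformer class itself.
        def __init__(self, layers: nn.ModuleList):
            super().__init__()
            self.layers = layers

        def prefetch_layer(self, layer_idx: int, device: torch.device):
            "Starts prefetching the next layer cache"
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU. Mutating param.data
                # moves the storage while keeping the Parameter and Module
                # objects intact, so any external references to
                # self.layers[layer_idx] stay valid.
                for name, param in self.layers[layer_idx].named_parameters():
                    param.data = param.data.to(device, non_blocking=True)

        def evict_previous_layer(self, layer_idx: int):
            "Moves the previous layer cache to the CPU"
            prev_layer_idx = layer_idx - 1
            for name, param in self.layers[prev_layer_idx].named_parameters():
                param.data = param.data.to("cpu")

        def get_offload_layer(self, layer_idx: int, device: torch.device):
            # init stream (lazily, on first use)
            if not hasattr(self, "prefetch_stream"):
                self.prefetch_stream = torch.cuda.Stream()
            # Assumed body: wait for any in-flight prefetch of this layer to
            # finish, drop the layer just computed, then start fetching the
            # next one on the side stream.
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(layer_idx)
            self.prefetch_layer((layer_idx + 1) % len(self.layers), device)

As a design note: moving storage via param.data keeps the original Parameter objects, whereas reassigning self.layers[i] swaps in the object returned by Module.to(); the commit message only says "revert", so the exact breakage the reassignment caused is not stated on this page.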
