diff --git a/OmniGen/transformer.py b/OmniGen/transformer.py
index 9733626..3da5031 100644
--- a/OmniGen/transformer.py
+++ b/OmniGen/transformer.py
@@ -33,13 +33,15 @@ def prefetch_layer(self, layer_idx: int, device: torch.device):
         "Starts prefetching the next layer cache"
         with torch.cuda.stream(self.prefetch_stream):
             # Prefetch next layer tensors to GPU
-            self.layers[layer_idx] = self.layers[layer_idx].to(device, non_blocking=True)
+            for name, param in self.layers[layer_idx].named_parameters():
+                param.data = param.data.to(device, non_blocking=True)
 
     def evict_previous_layer(self, layer_idx: int):
         "Moves the previous layer cache to the CPU"
         prev_layer_idx = layer_idx - 1
-        self.layers[prev_layer_idx] = self.layers[prev_layer_idx].to("cpu")
-
+        for name, param in self.layers[prev_layer_idx].named_parameters():
+            param.data = param.data.to("cpu")
+
     def get_offload_layer(self, layer_idx: int, device: torch.device):
         # init stream
         if not hasattr(self, "prefetch_stream"):
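For context, this diff is part of a layer-by-layer CPU-offload scheme: the next transformer layer's weights are copied to the GPU on a dedicated CUDA stream while the current layer computes, and the layer just finished is moved back to the CPU. The change swaps whole-module `.to()` calls (which rebound the `ModuleList` entries) for in-place rebinding of each `param.data`, so the `nn.Parameter` objects and module references stay stable while only their storage moves. Below is a minimal runnable sketch of how the pieces might fit together; `OffloadedStack` is a hypothetical class name, the warm-up prefetch of layer 0 is my addition, and everything in `get_offload_layer` past the stream init is my guess at the surrounding logic, which the diff does not show:

```python
import torch
import torch.nn as nn


class OffloadedStack(nn.Module):
    """Hypothetical container illustrating the prefetch/evict pattern above."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers

    def prefetch_layer(self, layer_idx: int, device: torch.device):
        # Copy the next layer's weights to the GPU on a side stream so the
        # transfer can overlap with compute running on the default stream.
        with torch.cuda.stream(self.prefetch_stream):
            for _, param in self.layers[layer_idx].named_parameters():
                param.data = param.data.to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        # Move the previous layer's weights back to the CPU in place, so the
        # nn.Parameter objects and ModuleList entries are never replaced.
        for _, param in self.layers[layer_idx - 1].named_parameters():
            param.data = param.data.to("cpu")

    def get_offload_layer(self, layer_idx: int, device: torch.device):
        # Lazily create the side stream, as in the diff above.
        if not hasattr(self, "prefetch_stream"):
            self.prefetch_stream = torch.cuda.Stream()
        # Assumed orchestration: finish compute before evicting the previous
        # layer, wait for the in-flight prefetch before using this layer,
        # then kick off the copy of the next one (wrapping around for reuse).
        torch.cuda.current_stream().synchronize()
        self.evict_previous_layer(layer_idx)
        self.prefetch_stream.synchronize()
        self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
        return self.layers[layer_idx]


# Usage sketch: fetch each layer through get_offload_layer instead of
# indexing self.layers directly.
if torch.cuda.is_available():
    device = torch.device("cuda")
    stack = OffloadedStack(nn.ModuleList(nn.Linear(64, 64) for _ in range(4)))
    stack.prefetch_stream = torch.cuda.Stream()
    stack.prefetch_layer(0, device)  # warm-up: load the first layer
    x = torch.randn(1, 64, device=device)
    for i in range(len(stack.layers)):
        x = stack.get_offload_layer(i, device)(x)
    torch.cuda.current_stream().synchronize()
```

One caveat worth noting: `named_parameters()` covers parameters only, so layers that register buffers would need the same in-place treatment via `named_buffers()`. Also, `non_blocking=True` only yields a truly asynchronous host-to-device copy when the source tensors live in pinned memory; otherwise the transfer degrades to a synchronous copy.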