diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 3c431941..23284092 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -39,4 +39,4 @@ jobs: - name: Remove Files run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}}_${{github.run_number}} sh -c "rm -r *" - name: Destroy docker - run: docker stop xfuser_test_docker_${{github.repository_owner_id}}_${{github.run_number}} + run: docker stop xfuser_test_docker_${{github.repository_owner_id}}_${{github.run_number}} \ No newline at end of file diff --git a/examples/cogvideox_example.py b/examples/cogvideox_example.py index 86abfa7d..4ee43bf5 100644 --- a/examples/cogvideox_example.py +++ b/examples/cogvideox_example.py @@ -1,3 +1,4 @@ +import logging import time import torch import torch.distributed @@ -35,12 +36,21 @@ def main(): torch_dtype=torch.bfloat16, ) if args.enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} sequential CPU offload enabled") + elif args.enable_model_cpu_offload: pipe.enable_model_cpu_offload(gpu_id=local_rank) - pipe.vae.enable_tiling() + logging.info(f"rank {local_rank} model CPU offload enabled") else: device = torch.device(f"cuda:{local_rank}") pipe = pipe.to(device) + if args.enable_tiling: + pipe.vae.enable_tiling() + + if args.enable_slicing: + pipe.vae.enable_slicing() + torch.cuda.reset_peak_memory_stats() start_time = time.time() diff --git a/examples/run_cogvideo.sh b/examples/run_cogvideo.sh index 4b7a771d..0e46894e 100644 --- a/examples/run_cogvideo.sh +++ b/examples/run_cogvideo.sh @@ -22,6 +22,7 @@ CFG_ARGS="--use_cfg_parallel" # PIPEFUSION_ARGS="--num_pipeline_patch 8" # OUTPUT_ARGS="--output_type latent" # PARALLLEL_VAE="--use_parallel_vae" +ENABLE_TILING="--enable_tiling" # COMPILE_FLAG="--use_torch_compile" torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \ @@ -35,4 +36,5 @@ $OUTPUT_ARGS \ --prompt "A small dog" \ $CFG_ARGS \ 
$PARALLLEL_VAE \ +$ENABLE_TILING \ $COMPILE_FLAG \ No newline at end of file diff --git a/xfuser/config/args.py b/xfuser/config/args.py index 2ef59a02..88c39f6a 100644 --- a/xfuser/config/args.py +++ b/xfuser/config/args.py @@ -247,6 +247,21 @@ def add_cli_args(parser: FlexibleArgumentParser): action="store_true", help="Offloading the weights to the CPU.", ) + runtime_group.add_argument( + "--enable_model_cpu_offload", + action="store_true", + help="Offloading the model weights to the CPU, moving submodules to GPU only when needed.", + ) + runtime_group.add_argument( + "--enable_tiling", + action="store_true", + help="Making VAE decode a tile at a time to save GPU memory.", + ) + runtime_group.add_argument( + "--enable_slicing", + action="store_true", + help="Making VAE decode one slice of the batch at a time to save GPU memory.", + ) # DiTFastAttn arguments fast_attn_group = parser.add_argument_group("DiTFastAttn Options") diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index 33010a73..87e8a0af 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -1161,6 +1161,7 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states, hidden_states = hidden_states.split( [text_seq_length, latent_seq_length], dim=1 )