From fb291adc5f425c0fe0ea8842b1da4edb88268c73 Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Fri, 18 Oct 2024 21:45:17 +0800 Subject: [PATCH] fix bugs in cogvideo allgather (#314) --- examples/run_cogvideo.sh | 6 +++--- .../pipelines/pipeline_cogvideox.py | 17 +---------------- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/examples/run_cogvideo.sh b/examples/run_cogvideo.sh index 0e46894e..2b10caa9 100644 --- a/examples/run_cogvideo.sh +++ b/examples/run_cogvideo.sh @@ -14,9 +14,9 @@ mkdir -p ./results TASK_ARGS="--height 480 --width 720 --num_frames 9" # CogVideoX parallel configuration -N_GPUS=4 -PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1" -CFG_ARGS="--use_cfg_parallel" +N_GPUS=6 +PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 3" +#CFG_ARGS="--use_cfg_parallel" # Uncomment and modify these as needed # PIPEFUSION_ARGS="--num_pipeline_patch 8" diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py index 4cb2be6f..5d04352c 100644 --- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py +++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py @@ -343,22 +343,7 @@ def __call__( progress_bar.update() if get_sequence_parallel_world_size() > 1: - sp_degree = get_sequence_parallel_world_size() - sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True) - latents_list = [] - for pp_patch_idx in range(get_runtime_state().num_pipeline_patch): - latents_list += [ - sp_latents_list[sp_patch_idx][ - :, - :, - get_runtime_state() - .pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state() - .pp_patches_start_idx_local[pp_patch_idx + 1], - :, - ] - for sp_patch_idx in range(sp_degree) - ] - latents = torch.cat(latents_list, dim=-2) + latents = get_sp_group().all_gather(latents, dim=-2) if is_dp_last_group(): if not (output_type == "latents" or output_type == "latent"):