Skip to content

Commit

Permalink
fix bugs in cogvideo allgather (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
xibosun authored Oct 18, 2024
1 parent 9e35ef7 commit fb291ad
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 19 deletions.
6 changes: 3 additions & 3 deletions examples/run_cogvideo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ mkdir -p ./results
TASK_ARGS="--height 480 --width 720 --num_frames 9"

# CogVideoX parallel configuration
- N_GPUS=4
- PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1"
- CFG_ARGS="--use_cfg_parallel"
+ N_GPUS=6
+ PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 3"
+ #CFG_ARGS="--use_cfg_parallel"

# Uncomment and modify these as needed
# PIPEFUSION_ARGS="--num_pipeline_patch 8"
Expand Down
17 changes: 1 addition & 16 deletions xfuser/model_executor/pipelines/pipeline_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,22 +343,7 @@ def __call__(
progress_bar.update()

if get_sequence_parallel_world_size() > 1:
- sp_degree = get_sequence_parallel_world_size()
- sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
- latents_list = []
- for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
- latents_list += [
- sp_latents_list[sp_patch_idx][
- :,
- :,
- get_runtime_state()
- .pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
- .pp_patches_start_idx_local[pp_patch_idx + 1],
- :,
- ]
- for sp_patch_idx in range(sp_degree)
- ]
- latents = torch.cat(latents_list, dim=-2)
+ latents = get_sp_group().all_gather(latents, dim=-2)

if is_dp_last_group():
if not (output_type == "latents" or output_type == "latent"):
Expand Down

0 comments on commit fb291ad

Please sign in to comment.