diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 3853860c4636..75718f608065 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -468,27 +468,29 @@ def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None): f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", ranks=[0], ) - + num_rings = sp_size // inner_ring_size inner_ring_group = None inter_ring_group = None world_size = dist.get_world_size() rank = dist.get_rank() - inner_rings = world_size // sp_size - num_rings = sp_size // inner_ring_size + num_ring_size = world_size // num_rings + num_inner_group = num_ring_size // inner_ring_size if tp_size > 1: - for i in range(inner_rings): - for j in range(sp_size // tp_size): - # find inner ring group in one sp group - ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size)) + # inner_ring_size = 2 + for i in range(num_rings): + for j in range(num_inner_group): + # find inner ring group in one sp groups + ranks = list(range(j + i * num_ring_size, j + (i + 1) * num_ring_size, tp_size)) group = dist.new_group(ranks) if rank in ranks: inner_ring_group = group - for i in range(inner_rings): - for j in range(sp_size // tp_size): - ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size)) + for i in range(num_rings): + for j in range(num_inner_group): + start = j + (i * num_inner_group) + ranks = list(range(start, start + num_ring_size + 1, num_ring_size)) group = dist.new_group(ranks) if rank in ranks: inter_ring_group = group