diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b3672727f222..294779732c98 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1229,7 +1229,7 @@ def configure( zero_stage = 0 if not isinstance(model, ModelWrapper): - # Can't use pp (frequent grad accumulation) with torch ddp + # Shouldn't use pp (frequent grad accumulation) with torch ddp use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( self.dp_size == 1 and self.pp_size == 1 ) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index c1b5245b8b91..8bc83881d93c 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -128,6 +128,7 @@ def prepare_attn_kwargs( q_padding_mask: Optional[torch.Tensor] = None, kv_padding_mask: Optional[torch.Tensor] = None, is_causal: bool = False, + invert: bool = True, ) -> Dict[str, torch.Tensor]: """Return a dictionary of keyword arguments for attention function. It supports 4 mask type. 1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves. @@ -145,6 +146,7 @@ def prepare_attn_kwargs( The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token. If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None. is_causal (bool, optional): Whether to use causal attention mask. Defaults to False. + invert_mask (bool, optional): Whether to invert the mask. Defaults to True. Returns: Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function. """ @@ -164,19 +166,21 @@ def prepare_attn_kwargs( assert q_padding_mask.shape == ( b, s_q, - ), f"q_padding_mask shape {q_padding_mask.shape} should be the same. ({shape_4d})" + ), f"q_padding_mask shape {q_padding_mask.shape} should be {b, s_q}." max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask) if kv_padding_mask is None: # self attention kv_padding_mask = q_padding_mask max_seqlen_kv, cu_seqlens_kv, kv_indices = max_seqlen_q, cu_seqlens_q, q_indices + attention_mask = q_padding_mask[:, :, None].expand(b, s_q, s_kv).to(dtype=dtype, device=device) else: max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask) + attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device) assert kv_padding_mask.shape == ( b, s_kv, ), f"q_padding_mask shape {kv_padding_mask.shape} should be the same. 
({shape_4d})" - attention_mask = q_padding_mask[:, None, :].expand(b, s_kv, s_q).to(dtype=dtype, device=device) + outputs.update( { "cu_seqlens_q": cu_seqlens_q, @@ -192,7 +196,8 @@ def prepare_attn_kwargs( attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0) else: outputs["attention_mask_type"] = AttnMaskType.PADDED - attention_mask = invert_mask(attention_mask).unsqueeze(1) + if invert: + attention_mask = invert_mask(attention_mask).unsqueeze(1) outputs["attention_mask"] = attention_mask return outputs @@ -412,19 +417,22 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # min_scale = torch.min(lse, block_lse) # max_scale = torch.max(lse, block_lse) - # lse.data = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + # NOTE: directly assigning to .data here is buggy + # probably due to casting dtypes/strides new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) - new_block_lse = torch.exp(block_lse - lse) - out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) - lse.copy_(new_lse) + new_block_lse = torch.exp(block_lse - new_lse) + out = (torch.exp(lse - new_lse) * out + new_block_lse * block_out).to(out) + lse = new_lse + # Equivalent to the above # See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - # out.data = (out - F.sigmoid(block_lse - lse) * (out - block_out)) - # lse.data = (lse - F.logsigmoid(lse - block_lse)) - + # out = (out - F.sigmoid(block_lse - lse) * (out - block_out)) + # lse = (lse - F.logsigmoid(lse - block_lse)) assert not (lse.isnan().any() or lse.isinf().any()), f"lse is nan: {lse}" + return out, lse class RingAttention(torch.autograd.Function): @@ -439,7 +447,10 @@ class RingAttention(torch.autograd.Function): # Globle cache to avoid recomputation for same-lengthed sequences CU_SEQLENS: torch.Tensor = None # [B+1] TOTAL_SEQLEN: int = None + HALF_INDICES: Tuple = None SUPPORTED_MASK_TYPES = (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL) + CORRECTION_DONE = torch.cuda.Event() + ATTN_DONE = torch.cuda.Event() @staticmethod def attention( @@ -452,11 +463,10 @@ def attention( cu_seqlens=None, max_seqlen=None, valid_indices=None, - dropout_p=0, + dropout_p=0.0, softmax_scale=None, deterministic=False, return_softmax=False, - pad_output=True, **kwargs, ): """ @@ -479,8 +489,6 @@ def attention( softmax_scale (Optional[float], optional): Scaling factor applied prior to softmax. deterministic (bool, optional): Whether to force deterministic backward pass. See https://github.com/Dao-AILab/flash-attention/issues/349 return_softmax (bool, optional): Whether to return the softmax denominator (logsumexp). - pad_output: (bool, optional): If True, return a batch of sequences in the casual padded mode. Else, - return a packed sequence. Returns: out: Output tensor of shape [B, nHeads, Sq, D] or [T, nHeads, D] if pad_output is False. 
@@ -498,27 +506,28 @@ def attention( # (B, H, Sq, D) -> (B, Sq, H, D) q, k, v = [x.transpose(1, 2).contiguous() for x in (q, k, v)] - b, sq, h, d = q.shape + pad_output = q.dim() == 4 # Get sequence length info for varlen forward if attention_mask_type == AttnMaskType.CAUSAL: # All sequences share the same length + b, sq, h, d = q.shape + max_seqlen = sq # Cache to avoid recreation for a single sequence - if b * sq == RingAttention.TOTAL_SEQLEN: + if sq * b == RingAttention.TOTAL_SEQLEN: cu_seqlens = RingAttention.CU_SEQLENS - max_seqlen = RingAttention.TOTAL_SEQLEN else: - RingAttention.CU_SEQLENS = cu_seqlens = torch.arange( - 0, b * sq + 1, sq, device=q.device, dtype=torch.int32 - ) - RingAttention.TOTAL_SEQLEN = max_seqlen = b * sq + cu_seqlens = torch.arange(0, b * sq + 1, sq, device=q.device, dtype=torch.int32) + RingAttention.TOTAL_SEQLEN = b * sq # "Packed" mode where sequences of different lengths are packed into [total_q_seqlen, H, D] elif attention_mask_type == AttnMaskType.PADDED_CAUSAL: assert ( cu_seqlens is not None and max_seqlen is not None and valid_indices is not None ), "Packed mode requires pre-computed cu_seqlens and max_seq_len." - q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)] + if pad_output: + b, sq, h, d = q.shape + q, k, v = [_unpad_input(x, valid_indices) for x in (q, k, v)] out, softmax_lse = RingAttention.apply( q, @@ -573,17 +582,29 @@ def forward( "softcap": 0.0, "return_softmax": False, } - if is_packed: - t, h, d = q.shape - # half of each seq + + # For Flash Attn, indexing blocks of contiguous mem has the same perf + # as indexing one big contiguous block. + # Also the former avoids frequent mem copies, e.g. when indexing + # half of the seq dim and reshaping + if ( + RingAttention.HALF_INDICES is not None + and cu_seqlens.shape == RingAttention.CU_SEQLENS.shape + and (cu_seqlens == RingAttention.CU_SEQLENS).all() + ): + half_idx_front, half_idx_back = RingAttention.HALF_INDICES + else: half_idx_front = get_half_index(cu_seqlens, front=True) half_idx_back = get_half_index(cu_seqlens, front=False) + RingAttention.HALF_INDICES = (half_idx_front, half_idx_back) + RingAttention.CU_SEQLENS = cu_seqlens + + if is_packed: + t, h, d = q.shape else: b, sq, h, d = q.shape t = b * sq q, k, v = [x.view(t, h, d) for x in (q, k, v)] - half_cu_seqlens = cu_seqlens // 2 - half_max_seqlen = max_seqlen // 2 kv_comms = [RingComm(sp_group) for _ in range(2)] sp_size = kv_comms[0].world_size @@ -591,7 +612,7 @@ def forward( # Pre-allocate double buffer for overlapping and receiving next step's inputs kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D) - kv_buffers.append(None) + kv_buffers.append(torch.empty_like(kv_buffers[0])) # outputs out = None @@ -600,30 +621,30 @@ def forward( block_softmax_lse = [None, None] # log sum exp, the denominator of softmax in attention rng_states = [None for _ in range(sp_size)] sp_streams = [torch.cuda.current_stream(), sp_stream] - correction_done = torch.cuda.Event() # Overlap output correction with next flash attn for i in range(sp_size): with torch.cuda.stream(sp_streams[i % 2]): - if i < sp_size - 1: - # Avoid overwriting the kv block used by the last flash attn call - kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) # Wait for current kv from prev rank # NOTE: waiting outside the current stream will NOT correctly synchronize. 
kv_comms[(i + 1) % 2].wait() + # Avoid overwriting attn input when it shares mem with buffer + if not RingAttention.ATTN_DONE.query(): + kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2]) + if i < sp_size - 1: kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2]) if i == 0: # Compute with local KV; no mask + kv_block = kv_buffers[0] q_block = q - kv_block = kv_buffers[i % 2] ( _, _, _, _, - block_out[i % 2], # (B, Sq, H, D) + block_out[i % 2], # (B * Sq, H, D) block_softmax_lse[i % 2], # (H, total_q_seqlen) _, rng_states[i], @@ -641,20 +662,15 @@ def forward( elif i <= sp_rank: # Received the "surrounding" kv chunks # Drop the second half of received kv - q_block = q # (2, t // 2, H, D) - kv_block = kv_buffers[i % 2] - if is_packed: - kv_block = kv_block[:, half_idx_front] - else: - kv_block = kv_block[:, : t // 2] - + kv_block = kv_buffers[i % 2][:, half_idx_front] + q_block = q ( _, _, _, _, - block_out[i % 2], # (B, Sq, H, D) + block_out[i % 2], # (B * Sq, H, D) block_softmax_lse[i % 2], # (H, total_q_seqlen) _, rng_states[i], @@ -663,9 +679,9 @@ def forward( kv_block[0], kv_block[1], cu_seqlens_q, - half_cu_seqlens, + cu_seqlens_kv // 2, max_seqlen_q, - half_max_seqlen, + max_seqlen_kv // 2, causal=False, **misc_kwargs, ) @@ -673,18 +689,16 @@ def forward( else: # Received the inner kv chunks # Drop the first half of q - if is_packed: - q_block = q[half_idx_back] - else: - q_block = q[t // 2 :] - kv_block = kv_buffers[i % 2] + q_block = q[half_idx_back] + + # dist.barrier() ( _, _, _, _, - block_out[i % 2], # (B, Sq // 2, H, D) + block_out[i % 2], # (B * Sq // 2, H, D) block_softmax_lse[i % 2], # (H, total_q_seqlen) _, rng_states[i], @@ -692,41 +706,38 @@ def forward( q_block, kv_block[0], kv_block[1], - half_cu_seqlens, + cu_seqlens_q // 2, cu_seqlens_kv, - half_max_seqlen, + max_seqlen_q // 2, max_seqlen_kv, causal=False, **misc_kwargs, ) + RingAttention.ATTN_DONE.record(sp_streams[i % 2]) - block_out[i % 2] = block_out[i % 2].view(-1, h, d) block_softmax_lse[i % 2] = ( block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float() ) # (H, T) -> (T, H, 1) assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1] - if i == 1: - pass # Output and log sum exp correction if i > 0: - sp_streams[i % 2].wait_event(correction_done) + sp_streams[i % 2].wait_event(RingAttention.CORRECTION_DONE) + if sp_rank == 0: + pass # Overlap output correction with next flash attn kernel if i == 0: out = block_out[0] softmax_lse = block_softmax_lse[0] elif i <= sp_rank: - _rescale_out_lse(out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]) + out, softmax_lse = _rescale_out_lse(out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2]) else: - # Dropped the first half of q sequence - if is_packed: - _out, _lse = out[half_idx_back], softmax_lse[half_idx_back] - else: - _out, _lse = out[t // 2 :], softmax_lse[t // 2 :] - _rescale_out_lse(_out, block_out[i % 2], _lse, block_softmax_lse[i % 2]) + out[half_idx_back], softmax_lse[half_idx_back] = _rescale_out_lse( + out[half_idx_back], block_out[i % 2], softmax_lse[half_idx_back], block_softmax_lse[i % 2] + ) - sp_streams[i % 2].record_event(correction_done) - torch.cuda.current_stream().wait_event(correction_done) + RingAttention.CORRECTION_DONE.record(sp_streams[i % 2]) + torch.cuda.current_stream().wait_event(RingAttention.CORRECTION_DONE) out = out.to(q.dtype) if not is_packed: @@ -735,13 +746,12 @@ def forward( softmax_lse = softmax_lse.squeeze(-1) ctx.sp_group = sp_group - ctx.max_seqlen_q 
= max_seqlen - ctx.max_seqlen_kv = max_seqlen_kv + ctx.max_seqlen_q = ctx.max_seqlen_kv = max_seqlen misc_kwargs["deterministic"] = deterministic del misc_kwargs["return_softmax"] ctx.misc_kwargs = misc_kwargs ctx.is_packed = is_packed - indices = (half_idx_front, half_idx_back) if is_packed else tuple() + ctx.save_for_backward( q, k, @@ -750,7 +760,8 @@ def forward( softmax_lse.transpose(0, 1).contiguous(), # (T, H) -> (H, T) cu_seqlens_q, cu_seqlens_kv, - *indices, + half_idx_front, + half_idx_back, *rng_states, ) @@ -763,23 +774,9 @@ def backward(ctx, dout, _): During backward, we accumulate q grads on each rank locally, but iterate kv and their grads over all ranks for accumulation. """ - ( - q, - k, - v, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_kv, - ) = ctx.saved_tensors[:7] + (q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9] is_packed = ctx.is_packed - if is_packed: - half_idx_front, half_idx_back = ctx.saved_tensors[7:9] - rng_states = ctx.saved_tensors[9:] - softmax_lse1 = softmax_lse[:, half_idx_back].contiguous() - else: - rng_states = ctx.saved_tensors[7:] - softmax_lse1 = softmax_lse.chunk(2, dim=-1)[1].contiguous() # Second half of seq + rng_states = ctx.saved_tensors[9:] max_seqlen_q = ctx.max_seqlen_q max_seqlen_kv = ctx.max_seqlen_kv @@ -796,7 +793,7 @@ def backward(ctx, dout, _): else: b, sq, h, d = q.shape t = b * sq - q, k, v, out, dout = [x.view(t, h, d) for x in (q, k, v, out, dout)] + q, k, v, out, dout = [x.view(t, h, d) for x in (q, k, v, out, dout)] # Sequence parallel args sp_group = ctx.sp_group @@ -849,13 +846,12 @@ def backward(ctx, dout, _): rng_state=rng_states[i], **misc_kwargs, ) + elif i <= sp_rank: # Drop the second half of kv # (T, H, D) -> (T // 2, H, D) - if is_packed: - k_, v_, dk_, dv_ = [x[half_idx_front] for x in (*kv_buffers[i % 2], dk_block, dv_block)] - else: - k_, v_, dk_, dv_ = [x[: t // 2] for x in (*kv_buffers[i % 2], dk_block, dv_block)] + k_, v_ = [x[half_idx_front] for x in kv_buffers[i % 2]] + dk_, dv_ = [x[: t // 2] for x in (dk_block, dv_block)] dq_, q_, out_, dout_ = (dq_block, q, out, dout) _flash_attn_backward( @@ -881,17 +877,16 @@ def backward(ctx, dout, _): # Drop the first half of q k_, v_ = kv_buffers[i % 2] dk_, dv_ = dk_block, dv_block - if is_packed: - q_, dq_, out_, dout_ = [x[half_idx_back] for x in (q, dq_block, out, dout)] - else: - q_, dq_, out_, dout_ = [x[t // 2 :] for x in (q, dq_block, out, dout)] + q_, out_, dout_ = [x[half_idx_back] for x in (q, out, dout)] + dq_ = dq_block[: t // 2] + _flash_attn_backward( dout_, q_, k_, v_, out_, - softmax_lse1, + softmax_lse[:, half_idx_back], dq_, dk_, dv_, @@ -909,37 +904,29 @@ def backward(ctx, dout, _): dkv_recv = dkv_buffers[(i + 1) % 2] if i == 0: dq = dq_block.float() - dkv_recv[0].copy_(dk_block) - dkv_recv[1].copy_(dv_block) + dkv_recv[0] = dk_block.float() + dkv_recv[1] = dv_block.float() else: # Accumulate local dq if i <= sp_rank: - dq += dq_block # (T, H, D) + dq += dq_ # (T, H, D) else: - if is_packed: - dq[half_idx_back] += dq_block[half_idx_back] - else: - dq[t // 2 :] += dq_block[t // 2 :] # (T // 2, H, D) + dq[half_idx_back] += dq_ # Wait for mobile kv grad accumulators dkv_comm.wait() if i <= sp_rank: # q blocks "surrounded" by kv blocks - if is_packed: - dkv_recv[0][half_idx_front] += dk_block[half_idx_front] - dkv_recv[1][half_idx_front] += dv_block[half_idx_front] - else: - dkv_recv[0][: t // 2] += dk_block[: t // 2] # (T // 2, H, D) - dkv_recv[1][: t // 2] += dv_block[: t // 2] + 
dkv_recv[0][half_idx_front] += dk_ + dkv_recv[1][half_idx_front] += dv_ else: # q blocks "surrounding" kv blocks - dkv_recv[0] += dk_block - dkv_recv[1] += dv_block + dkv_recv[0] += dk_ + dkv_recv[1] += dv_ dkv_comm.send_recv(send_tensor=dkv_recv, recv_tensor=dkv_send) dkv_comm.wait() dkv_recv = dkv_send - dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv_recv)] if not is_packed: dq, dk, dv = [x.view(b, sq, h, d) for x in (dq, dk, dv)] @@ -960,18 +947,31 @@ def backward(ctx, dout, _): @staticmethod def prepare_varlen_batch( - inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, sp_group: dist.ProcessGroup, + inputs_embeds: torch.Tensor = None, position_ids: Optional[torch.Tensor] = None, + is_label: bool = False, + is_2d: bool = True, ): """ Preprocess a batch of padded sequence by splitting input sequence by sp_size sequence-wise and packing them into one sequence. Updates the mask info accordingly. Args: + attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked. + sp_group (dist.ProcessGroup): Process group for sequence parallelism inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...] - attention_mask (torch.Tensor): Contains the mask [B, Sq] - position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq]. Defaults to None. + position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None. + is_label (bool, optional): Whether the input is a label tensor. If True, mask out the first + token of each sequence. + is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten + the batch dim to a packed 1d sequence. Contingent on model forward shape definitions. + + Returns: + inputs_embeds: Packed input embeddings of shape [B, Sq // sp_size, ...]. + mask_info: A dictionary of mask info. + position_ids: Packed position ids of shape [..., Sq // sp_size]. 
+ """ _load_varlen_helpers() sp_size = dist.get_world_size(group=sp_group) @@ -982,21 +982,32 @@ def prepare_varlen_batch( # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size) # Split mask to compute local nonzero position indices # (B, Sq) -> (B, max_seqlen // sp_size) - attention_mask, inputs_embeds = split_varlen_zigzag( - [attention_mask, inputs_embeds], - mask_info["cu_seqlens"], - sp_group, - mask_info["max_seqlen"], - is_2d=True, + attention_mask = attention_mask[:, : mask_info["max_seqlen"]] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]] + inputs_embeds = split_varlen_zigzag( + inputs_embeds, + mask_info["cu_seqlens"], + sp_group, + mask_info["max_seqlen"], + is_2d=is_2d, + is_label=is_label, + ) + attention_mask = split_varlen_zigzag( + attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d ) - mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - mask_info["max_seqlen"] //= sp_size - mask_info["cu_seqlens"] //= sp_size - mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL if position_ids is not None: indices = torch.tensor([sp_rank, 2 * sp_size - sp_rank - 1], device=inputs_embeds.device) position_ids = ( - position_ids[: mask_info["max_seqlen"]].view(sp_size * 2, -1).index_select(0, indices).view(-1) + position_ids[..., : mask_info["max_seqlen"]] + .view(-1, sp_size * 2, mask_info["max_seqlen"] // (sp_size * 2)) + .index_select(-1, indices) + .view(-1, mask_info["max_seqlen"] // sp_size) ) - return inputs_embeds, position_ids, mask_info + mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + mask_info["max_seqlen"] //= sp_size + mask_info["cu_seqlens"] //= sp_size + mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL + + return inputs_embeds, mask_info, position_ids diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index a91f0207e371..bc38d1c68b58 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -1,6 +1,5 @@ import torch import torch.distributed as dist -import torch.nn.functional as F from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss @@ -151,7 +150,7 @@ def cross_entropy_1d( def dist_cross_entropy( - labels: torch.Tensor, # [B, S] + labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] logits: torch.Tensor, # [B, S, Vocab_size] shard_config: ShardConfig, out_features: int, @@ -169,35 +168,54 @@ def dist_cross_entropy( sp_mode = shard_config.sequence_parallelism_mode parallel_output = shard_config.parallel_output is_tp = shard_config.enable_tensor_parallelism - - bs, seq_len = labels.shape + is_packed = labels.dim() == 2 + if is_packed: + bs, seq_len = labels.shape + else: + # padded sequence + seq_len = labels.shape[-1] + logits = logits.reshape(-1, *logits.shape[2:]) + seq_dim = 0 # Shift labels to predict the next token, and remove the tail logit predicting is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode)) split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward if is_sp: - # shift only once + # shift only once: either before splitting or on the last rank without splitting if split_labels_here or (sp_rank == sp_size - 1): labels = labels[..., 1:] # Split labels when logits are split if split_labels_here: labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] - # Pad to the 
same shape across all ranks in TP all_reduce if sp_rank == sp_size - 1: logits = logits[..., :-1, :] + # Pad logits and labels to the same shape across all ranks for TP all_reduce if is_tp and parallel_output: - pad_shape = [0] * logits.dim() * 2 - pad_shape[-3] = 1 # Right side, dim = -2 - logits = F.pad(logits, pad_shape, value=_IGNORE_IDX) - labels = F.pad(labels, (0, 1, 0, 0), value=_IGNORE_IDX) + # If is packed sequence (label dim is 1), then each seq already has the end label token padded. + # NOTE: torch.cat is faster than F.pad... + pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:]) + padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device) + logits = torch.cat([logits, padding], dim=seq_dim) + + pad_shape = (labels.shape[0], 1) if is_packed else (1,) + padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device) + labels = torch.cat([labels, padding], dim=seq_dim) + # pad_shape = [0] * labels.dim() * 2 + # pad_shape[1] = 1 + # labels = F.pad(labels, (0, 1, 0, 0), value=_IGNORE_IDX) else: labels = labels[..., 1:] logits = logits[..., :-1, :] labels = labels.contiguous() logits = logits.contiguous() num_nonzero = (labels != _IGNORE_IDX).sum() - assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + assert ( + labels.shape == logits.shape[:-1] + ), f"label shape {labels.shape} does not match logit shape {logits.shape}" # Flatten the tokens loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum") @@ -218,7 +236,7 @@ def dist_cross_entropy( else: # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D logits = logits.view(-1, vocab_size) loss = loss_fct(logits, labels) # Reduce loss instead of gathering logits over seq dim for savings if split_labels_here or sp_mode == "ring_attn": diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 1c4808de5fa8..a525eff05a2c 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -291,9 +291,7 @@ def create_randomizer_with_offset( return Randomizer(seed=base_seed) -def split_batch_zigzag( - batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False -): +def split_batch_zigzag(batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim=1): """ Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask in the causal setting will result in the preceding ranks having much less workload. @@ -304,9 +302,7 @@ def split_batch_zigzag( batch (List[torch.Tensor] or Tensor): The input tensor(s) to split. sp_group (ProcessGroup): The process group for sequence parallelism. seq_dim (int): The sequence dimension to split. - varlen (bool): If the input is padded (aka "packing" mode), such that - sequences in a batch have different lengths, and we need to unpad and - split each sequence evenly by sp_size. + """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) @@ -329,7 +325,7 @@ def split_batch_zigzag( indices = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device) tensor = tensor.index_select(seq_dim, indices).contiguous() # (B, 2, Sq // (2 * sp_size), ...) -> (B, Sq // sp_size, ...)
- batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]).contiguous() + batch[idx] = tensor.view(*tensor.shape[:seq_dim], -1, *tensor.shape[seq_dim + 2 :]) if len(batch) == 1: return batch[0] @@ -342,6 +338,7 @@ def split_varlen_zigzag( sp_group: ProcessGroup, max_seqlen: int = 0, is_2d: bool = False, + is_label: bool = False, ) -> Union[List[torch.Tensor], torch.Tensor]: """Split each sequence in a batch of packed sequences in a zigzag fashion. For each tensor in batch, return packed sequences if is_2d is False; @@ -353,6 +350,7 @@ def split_varlen_zigzag( sp_group (ProcessGroup): The process group for sequence parallelism. max_seqlen (int): The maximum sequence length in the batch before splitting. is_2d (bool): If True, then input has batch size and sequence length split into two dimensions. + is_label (bool): If True, mask out the first token in each sequence (). Returns: batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size) @@ -373,7 +371,10 @@ def split_varlen_zigzag( assert max_seqlen % (sp_size * 2) == 0 # Recreate a padded tensor with the new max seqlen shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:]) - local_seq = torch.zeros(shape, dtype=dtype, device=device) + if is_label: + local_seq = torch.full(shape, -100, dtype=dtype, device=device) + else: + local_seq = torch.zeros(shape, dtype=dtype, device=device) else: total_seqlen = cu_seqlens[-1] assert ( @@ -389,12 +390,18 @@ def split_varlen_zigzag( ), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting" if is_2d: - seq = packed_seq[j][:seqlen].chunk(2 * sp_size, dim=0) + seq = packed_seq[j][:seqlen] + if is_label: + seq[0] = -100 + seq = seq.chunk(2 * sp_size, dim=0) half = seqlen // sp_size // 2 local_seq[j][:half] = seq[sp_rank] local_seq[j][half : seqlen // sp_size] = seq[2 * sp_size - 1 - sp_rank] else: - seq = packed_seq[start:end].chunk(sp_size * 2) + seq = packed_seq[start:end] + if is_label: + seq[0] = -100 + seq = seq.chunk(sp_size * 2) local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]]) if is_2d: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index e7fd0e3e0611..9ea6b320d30f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -30,7 +30,7 @@ from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, RingAttention, dist_cross_entropy, get_pad_info +from ..layer import ColoAttention, RingAttention, dist_cross_entropy _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"] @@ -132,18 +132,24 @@ def llama_model_forward( position_ids = cache_position.unsqueeze(0) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage - if shard_config.enable_flash_attention: + if not stage_manager.is_first_stage() and sp_mode == "ring_attn": + _, attention_mask, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group) + elif shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attn_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) + try: + 
attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + invert=(sp_mode != "ring_attn"), + ) + except: + pass else: - attn_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) # Support SP + PP # TODO: support padded casual cu_seqlens across stages @@ -151,14 +157,12 @@ def llama_model_forward( # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." - if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: - attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info( - attn_mask["attention_mask"].squeeze(1).any(dim=-1) - ) # [B, 1, Sq, Skv] -> [B, Sq] - - batch = [hidden_states, position_ids] - # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) - hidden_states, position_ids = split_batch_zigzag(batch, sp_group) + if attention_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + hidden_states, attention_mask, position_ids = RingAttention.prepare_varlen_batch( + attention_mask["attention_mask"].squeeze(1).any(-1), sp_group, hidden_states, position_ids + ) + else: + hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group) elif is_share_sp_tp(sp_mode): hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) @@ -199,7 +203,7 @@ def llama_model_forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attn_mask, + attention_mask, position_ids, past_key_values, output_attentions, @@ -209,7 +213,7 @@ def llama_model_forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attn_mask, + attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -312,9 +316,13 @@ def llama_for_causal_lm_forward( logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - if stage_manager.is_first_stage(): - if shard_config.sequence_parallelism_mode == "ring_attn": - labels = split_batch_zigzag(labels, shard_config.sequence_parallel_process_group) + if shard_config.sequence_parallelism_mode == "ring_attn": + sp_group = shard_config.sequence_parallel_process_group + if attention_mask.bool().all(): + labels = split_batch_zigzag(labels, sp_group, seq_dim=1) + else: + # [B, max_seqlen // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( @@ -545,8 +553,12 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + try: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + except: + pass if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} @@ -560,6 +572,7 @@ def forward( attn_output = RingAttention.attention( query_states, key_states, value_states, sp_group, shard_config.sp_stream, 
**attention_mask ) + elif shard_config.enable_flash_attention: assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict." attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) @@ -670,27 +683,30 @@ def forward( if shard_config.enable_flash_attention: mask_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) - mask_info: dict = ColoAttention.prepare_attn_kwargs( + attention_mask: dict = ColoAttention.prepare_attn_kwargs( mask_shape, inputs_embeds.dtype, inputs_embeds.device, q_padding_mask=attention_mask, is_causal=True, + invert=(sp_mode != "ring_attn"), ) else: - mask_info: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + attention_mask: torch.Tensor = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) # Ring Attention zigzag batch processing if sp_mode == "ring_attn": assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." - if mask_info["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: - inputs_embeds, position_ids, mask_info = RingAttention.prepare_varlen_batch( - inputs_embeds, mask_info["attention_mask"], sp_group, position_ids + if attention_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + inputs_embeds, attention_mask, position_ids = RingAttention.prepare_varlen_batch( + attention_mask["attention_mask"].squeeze(1).any(-1), sp_group, inputs_embeds, position_ids ) else: inputs_embeds, position_ids = split_batch_zigzag([inputs_embeds, position_ids], sp_group) - mask_info = {"attention_mask_type": mask_info["attention_mask_type"]} # drop redundant tensors + attention_mask = { + "attention_mask_type": attention_mask["attention_mask_type"] + } # drop redundant tensors elif is_share_sp_tp(sp_mode): inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) @@ -710,7 +726,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - mask_info, + attention_mask, position_ids, past_key_values, output_attentions, @@ -721,7 +737,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=mask_info, + attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, @@ -813,7 +829,13 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if shard_config.sequence_parallelism_mode == "ring_attn": - labels = split_batch_zigzag(labels, shard_config.sequence_parallel_process_group) + sp_group = shard_config.sequence_parallel_process_group + if attention_mask.bool().all(): + labels = split_batch_zigzag(labels, sp_group, seq_dim=1) + else: + # [B, max_seq_len // sp_size] + labels, _, _ = RingAttention.prepare_varlen_batch(attention_mask, sp_group, labels, is_label=True) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 9b3db7ca96eb..05ac9d8d24ed 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -33,20 +33,21 @@ def data_gen(): [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082], ] ).long() - - attention_mask = torch.Tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1], - ] - ).long() - + attention_mask = torch.ones_like(input_ids) return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for causal lm def data_gen_for_causal_lm(): data = data_gen() + + # Test padded sequence + padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long) + data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1) + data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1) + + ignore_idx = -100 labels = data["input_ids"].clone() + labels[~data["attention_mask"].bool()] = ignore_idx data["labels"] = labels return data diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 9d39c69783e3..0044d3533573 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -1,10 +1,11 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer import AttnMaskType, ColoAttention +from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -12,10 +13,10 @@ @parameterize("seq_len", [4096]) -@parameterize("bs", [1]) +@parameterize("bs", [2]) @parameterize("nheads", [5]) @parameterize("d", [128]) -@parameterize("dtype", [torch.float16, torch.bfloat16]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) def check_ring_attn(seq_len, bs, nheads, d, dtype): torch.cuda.manual_seed(2) device = get_current_device() @@ -46,8 +47,8 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): local_out = split_batch_zigzag(out, sp_group) local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1) local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads) - assert_close(ring_out, local_out, atol=atol, rtol=rtol) assert_close(ring_lse, local_lse, atol=atol, rtol=rtol) + assert_close(ring_out, local_out, atol=atol, rtol=rtol) # Check grads ring_out.sum().backward() @@ -60,44 +61,40 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol) -@parameterize("seqlen", [16]) +@parameterize("seqlen", [4096]) @parameterize("bs", [2]) @parameterize("nheads", [5]) @parameterize("d", [128]) -@parameterize("dtype", [torch.float16, torch.bfloat16]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) def check_packed_seq(seqlen, bs, nheads, d, dtype): device = get_current_device() sp_group = dist.group.WORLD sp_size = dist.get_world_size() sp_stream = torch.cuda.Stream() atol = rtol = 7e-3 - + torch.cuda.manual_seed(2) # Prepare varlen attention mask padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device) - # padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 + padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0 padding_mask[:, seqlen // 2 :] = 0 - mask_info = ColoAttention.prepare_attn_kwargs( - (bs, 1, seqlen, seqlen), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True - ) - # input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) - input_embeds = ( - torch.arange(seqlen, 
device=device, dtype=dtype, requires_grad=True) - .repeat(bs, nheads, d, 1) - .permute(0, 3, 1, 2) - .contiguous() - ) - q, k, v = [input_embeds.clone().transpose(1, 2) for _ in range(3)] + + input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True) # Forward # out = ColoAttention.attention(q, k, v, **mask_info) flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()] qkv = torch.stack([flat_input] * 3, dim=1) qkv.retain_grad() + + input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds) out, lse, _ = flash_attn_varlen_qkvpacked_func( - qkv, mask_info["cu_seqlens_q"], mask_info["max_seqlen_q"], return_attn_probs=True, causal=True + qkv, + mask_info["cu_seqlens"] * sp_size, + mask_info["max_seqlen"] * sp_size, + return_attn_probs=True, + causal=True, + # deterministic=True ) - - input_embeds, _, mask_info = RingAttention.prepare_varlen_batch(input_embeds, padding_mask, sp_group) # Test the splitting function local_input = split_varlen_zigzag( flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size @@ -109,23 +106,31 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): q_ring.retain_grad() k_ring.retain_grad() v_ring.retain_grad() + ring_out, ring_lse = RingAttention.attention( - q_ring, k_ring, v_ring, sp_group, sp_stream, **mask_info, pad_output=False, return_softmax=True + q_ring, + k_ring, + v_ring, + sp_group, + sp_stream, + **mask_info, + pad_output=False, + return_softmax=True, + # deterministic=True ) # Check output - # ring_out, out = [x.transpose(1, 2) for x in (ring_out, out)] # to (B, Sq, nHeads, D) - # out = split_varlen_zigzag(out, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size, is_2d=True) lse = lse.transpose(0, 1) out, lse = split_varlen_zigzag( [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size ) - # assert_close(lse, ring_lse, atol=atol, rtol=rtol) + assert_close(lse, ring_lse, atol=atol, rtol=rtol) assert_close(out, ring_out, atol=atol, rtol=rtol) # Check grads - out.sum().backward() - ring_out.sum().backward() + labels = torch.ones(out.shape[0], dtype=dtype, device=device) + F.mse_loss(out.sum((-2, -1)), labels).backward() + F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward() dq, dk, dv = [ split_varlen_zigzag( qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size @@ -136,6 +141,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]] for x in (q_ring.grad, k_ring.grad, v_ring.grad) ] + assert_close(dq, dq_ring, atol=atol, rtol=rtol) assert_close(dk, dk_ring, atol=atol, rtol=rtol) assert_close(dv, dv_ring, atol=atol, rtol=rtol) @@ -143,13 +149,13 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): def launch(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - # check_packed_seq() + check_packed_seq() check_ring_attn() @rerun_if_address_is_in_use() def test_ring_attn(): - spawn(launch, nprocs=8) + spawn(launch, nprocs=2) if __name__ == "__main__":
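Note on the zigzag layout used by split_batch_zigzag / split_varlen_zigzag above: both cut each sequence into 2 * sp_size equal chunks and assign rank r the pair (r, 2 * sp_size - 1 - r), so every rank gets a comparable share of the causal-attention workload. A standalone sketch of that layout (plain PyTorch, no process group; sp_size and sp_rank are illustrative):

import torch

def zigzag_split(seq: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 0) -> torch.Tensor:
    # Cut the sequence into 2 * sp_size equal chunks and give rank `sp_rank`
    # the pair (sp_rank, 2 * sp_size - 1 - sp_rank): one early chunk, one late chunk.
    chunks = seq.chunk(2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)

positions = torch.arange(16)                          # token positions of one sequence
print(zigzag_split(positions, sp_size=2, sp_rank=0))  # tensor([ 0,  1,  2,  3, 12, 13, 14, 15])
print(zigzag_split(positions, sp_size=2, sp_rank=1))  # tensor([ 4,  5,  6,  7,  8,  9, 10, 11])

This is also why the forward/backward hunks index half_idx_front and half_idx_back: the first half of each local shard comes from the early chunk and the second half from the late chunk.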
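The varlen path (prepare_varlen_batch, flash_attn_varlen_qkvpacked_func) is driven by cumulative sequence lengths rather than a dense mask. A rough sketch of the bookkeeping a get_pad_info-style helper derives from a [B, Sq] padding mask (1 = valid token); this is an illustration under that assumption, not the library's exact implementation:

import torch
import torch.nn.functional as F

def pad_info(padding_mask: torch.Tensor):
    # From a [B, Sq] mask of 1s (valid) and 0s (padding), derive what the
    # varlen flash-attention kernels expect.
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                # [B] valid tokens per sequence
    max_seqlen = int(seqlens.max())                                      # longest real sequence
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # [B + 1] offsets into the packed dim
    valid_indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()  # where the real tokens live
    return max_seqlen, cu_seqlens, valid_indices

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(pad_info(mask))  # (3, tensor([0, 3, 5], dtype=torch.int32), tensor([0, 1, 2, 4, 5]))

Since prepare_varlen_batch divides cu_seqlens and max_seqlen by sp_size, the test hunks above appear to multiply them back by sp_size when building the full-sequence reference.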
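Finally, the dist_cross_entropy hunks shift labels by one position and right-pad both logits and labels with the ignore index so every rank reduces over tensors of a common shape. A single-device toy version of that shift-and-pad pattern (shapes are illustrative; _IGNORE_IDX = -100 as in the source):

import torch
from torch.nn import CrossEntropyLoss

_IGNORE_IDX = -100
bs, seq_len, vocab_size = 2, 8, 32
logits = torch.randn(bs, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bs, seq_len))

# Token t predicts token t + 1; pad back to the original length with the
# ignore index so the extra position contributes nothing to the loss.
shift_logits = torch.cat(
    [logits[:, :-1], torch.full((bs, 1, vocab_size), float(_IGNORE_IDX), dtype=logits.dtype)], dim=1
)
shift_labels = torch.cat(
    [labels[:, 1:], torch.full((bs, 1), _IGNORE_IDX, dtype=labels.dtype)], dim=1
)

loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum")
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
num_nonzero = (shift_labels != _IGNORE_IDX).sum()
loss = loss / num_nonzero  # normalize by real targets; in the patch this normalization happens after the distributed reduce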