diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index bf858104850e..239294e65739 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -643,7 +643,9 @@ def forward(
             with torch.cuda.stream(sp_streams[i % 2]):
                 # Wait for current kv from prev rank
                 # NOTE: waiting outside the current stream will NOT correctly synchronize.
-                kv_comms[(i + 1) % 2].wait()
+                if i > 0:
+                    kv_comms[(i + 1) % 2].wait()
+
                 # Avoid overwriting attn input when it shares mem with buffer
                 if not RingAttention.ATTN_DONE.query():
                     kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
@@ -836,7 +838,9 @@ def backward(ctx, dout, _):
         # NOTE: We avoid using two streams since it requires doubling dkv and kv buffers,
        # and backward is more communication intensive than forward
         for i in range(sp_size):
-            kv_comm.wait()
+            if i > 0:
+                kv_comm.wait()
+
             if i < sp_size - 1:
                 # Send kv to next rank for backward
                 kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
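
Below is a minimal, self-contained sketch of the control flow this patch establishes: on the first ring iteration no transfer has been posted yet, so calling `wait()` there would block on nothing; the `if i > 0` guard defers the wait until a prior `send_recv()` exists. The `MockRingComm` and `ring_loop` names are hypothetical stand-ins used only to illustrate the loop shape, not part of the ColossalAI API.

```python
# Sketch of the guarded ring-communication loop, assuming a hypothetical
# communicator with send_recv()/wait() semantics. Illustration only.

class MockRingComm:
    """Stand-in communicator that records calls instead of doing P2P ops."""

    def __init__(self) -> None:
        self.pending = False
        self.log = []

    def send_recv(self, send_buf, recv_buf) -> None:
        # In the real code this launches async send/recv to neighbor ranks.
        self.pending = True
        self.log.append("send_recv")

    def wait(self) -> None:
        # Waiting with no pending transfer is exactly what the `i > 0`
        # guard in the patch prevents on the first iteration.
        assert self.pending, "wait() called before any send_recv() was issued"
        self.pending = False
        self.log.append("wait")


def ring_loop(sp_size: int) -> list:
    kv_comm = MockRingComm()
    kv_buffers = [object(), object()]  # double buffer, as in the patch
    for i in range(sp_size):
        if i > 0:
            # Only wait once a previous iteration actually posted a transfer.
            kv_comm.wait()
        if i < sp_size - 1:
            # Send current kv to the next rank, receive the next chunk.
            kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
        # ... compute attention on kv_buffers[i % 2] here ...
    return kv_comm.log


if __name__ == "__main__":
    print(ring_loop(sp_size=4))
    # ['send_recv', 'wait', 'send_recv', 'wait', 'send_recv', 'wait']
```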