
Commit e26c910: clarify kv_comm.wait()

Edenzzzz committed Aug 8, 2024
1 parent e90e984, commit e26c910
Showing 1 changed file with 6 additions and 2 deletions.

colossalai/shardformer/layer/attn.py (6 additions, 2 deletions)
@@ -643,7 +643,9 @@ def forward(
             with torch.cuda.stream(sp_streams[i % 2]):
                 # Wait for current kv from prev rank
                 # NOTE: waiting outside the current stream will NOT correctly synchronize.
-                kv_comms[(i + 1) % 2].wait()
+                if i > 0:
+                    kv_comms[(i + 1) % 2].wait()
+
                 # Avoid overwriting attn input when it shares mem with buffer
                 if not RingAttention.ATTN_DONE.query():
                     kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
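
For illustration only, not part of the commit: the NOTE in this hunk is about CUDA stream ordering. A wait only orders work on the stream it is issued to, so it has to be issued inside the torch.cuda.stream(sp_streams[i % 2]) context that will consume the received kv. A minimal sketch of that rule using plain CUDA events (the stream and event names here are hypothetical stand-ins, not the ColossalAI comm objects):

    import torch

    if torch.cuda.is_available():
        comm_stream = torch.cuda.Stream()     # stands in for the stream the recv runs on
        compute_stream = torch.cuda.Stream()  # stands in for sp_streams[i % 2]
        recv_done = torch.cuda.Event()

        with torch.cuda.stream(comm_stream):
            kv = torch.randn(1024, 1024, device="cuda")  # pretend this is the received kv
            recv_done.record()                           # "recv finished" is recorded on comm_stream

        with torch.cuda.stream(compute_stream):
            # Correct: the wait is issued on the stream that reads kv, so its kernels
            # are ordered after the recv. Issuing it on a different stream (e.g. the
            # default stream) would leave compute_stream free to read kv too early.
            compute_stream.wait_event(recv_done)
            out = kv @ kv
        torch.cuda.synchronize()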
@@ -836,7 +838,9 @@ def backward(ctx, dout, _):
         # NOTE: We avoid using two streams since it requires doubling dkv and kv buffers,
         # and backward is more communication intensive than forward
         for i in range(sp_size):
-            kv_comm.wait()
+            if i > 0:
+                kv_comm.wait()
+
             if i < sp_size - 1:
                 # Send kv to next rank for backward
                 kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
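
For illustration only, not part of the commit: the new "if i > 0:" guard reflects that nothing has been posted on the communicator before the first iteration, so there is nothing to wait on; the transfer completed at step i is always the one posted at step i - 1. A minimal CPU-only sketch of the same loop shape, with a hypothetical FakeComm standing in for the P2P communicator:

    class FakeComm:
        """Hypothetical stand-in for the P2P communicator:
        send_recv posts a transfer, wait completes it."""
        def __init__(self):
            self.pending = None

        def send_recv(self, send_buf, recv_buf):
            self.pending = (list(send_buf), recv_buf)  # post the transfer

        def wait(self):
            src, dst = self.pending
            dst[:] = src        # the "recv" lands here
            self.pending = None


    def ring_loop(sp_size, local_kv):
        kv_buffers = [list(local_kv), [0] * len(local_kv)]  # double buffer
        kv_comm = FakeComm()
        for i in range(sp_size):
            if i > 0:
                kv_comm.wait()  # complete the recv posted last iteration; i == 0 has none in flight
            if i < sp_size - 1:
                # Post the transfer for the next iteration while this one computes on kv_buffers[i % 2]
                kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
            # ... attention / grad compute on kv_buffers[i % 2] goes here ...
        return kv_buffers


    ring_loop(4, [1.0, 2.0, 3.0])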
