Commit

fix SeqParallelMultiHeadCrossAttention for consistent results in distributed mode (#510)
Kipsora authored Jun 24, 2024
1 parent a6036e4 commit 0312a0d
Showing 1 changed file with 1 addition and 1 deletion.
opensora/models/layers/blocks.py (1 addition, 1 deletion)

@@ -499,7 +499,7 @@ def forward(self, x, cond, mask=None):

         # shape:
         # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
-        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+        q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
         kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
         k, v = kv.unbind(2)
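The one-line change above swaps the hard-coded batch size of 1 for the true batch size B when reshaping the query projection. A minimal standalone sketch (not the OpenSora code; the sizes B=2, SUB_N=3, num_heads=2, head_dim=4 are illustrative) shows what goes wrong with the old reshape: `view(1, -1, ...)` collapses all batch elements into one long sequence, so attention would mix tokens across unrelated samples, while `view(B, -1, ...)` keeps each sample's tokens in its own batch slot.

```python
import torch

# Illustrative sizes, not taken from the OpenSora config.
B, SUB_N, num_heads, head_dim = 2, 3, 2, 4
hidden = num_heads * head_dim

x = torch.randn(B, SUB_N, hidden)
q_proj = torch.nn.Linear(hidden, hidden)

# Buggy reshape: the batch dimension is hard-coded to 1, so the
# B * SUB_N tokens are flattened into a single 6-token sequence.
q_bug = q_proj(x).view(1, -1, num_heads, head_dim)
print(tuple(q_bug.shape))  # (1, 6, 2, 4)

# Fixed reshape: each batch element keeps its own SUB_N tokens,
# matching the documented [B, SUB_N, NUM_HEADS, HEAD_DIM] layout.
q_fix = q_proj(x).view(B, -1, num_heads, head_dim)
print(tuple(q_fix.shape))  # (2, 3, 2, 4)
```

With B=1 the two reshapes coincide, which is why the bug only surfaces once batched inputs are run under sequence parallelism.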
