WIP: Begin to add Contextual positional encoding #1645

Open · wants to merge 2 commits into master
Changes from 1 commit

fix the implementation of CoPE
csukuangfj committed Jul 3, 2024
commit 36808b89406d97ee5ab68c43136a509eb0d193fc
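
For context, since the comments in this diff cite "the paper" only by page number: CoPE is contextual position encoding from "Contextual Position Encoding: Learning to Count What's Important" (Golovneva et al., 2024). Each attention head computes a gate gates[i, j] = sigmoid(q_i · k_j) for query i and key j, and the position of key j relative to query i is the sum of the gates between them, so positions are fractional and content-dependent; a fractional position is turned into a logit by interpolating between the embeddings of the two nearest integer positions. The commit below fixes the PR's first attempt by masking out gates[i, j] with j >= i and computing the gate sums with a reversed cumulative sum instead of a plain sum.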
8 changes: 6 additions & 2 deletions egs/librispeech/ASR/zipformer/test_cope.py
@@ -1,14 +1,17 @@
 #!/usr/bin/env python3

 import torch
 from zipformer import ContextualPositionalEncoding


 def test():
     embed_dim = 5
     npos_max = 10
+
     cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
-    q = torch.rand(2, 3, 4, embed_dim)
-    qk = torch.rand(2, 3, 4, 6)
+    q = torch.rand(2, 3, npos_max, embed_dim)
+
+    qk = torch.rand(2, 3, npos_max, npos_max)
+
     p = cope(q=q, qk=qk)
     print(p.shape)
@@ -19,4 +22,5 @@ def main():


 if __name__ == "__main__":
+    torch.manual_seed(20240703)
     main()
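
To make the test self-contained, here is a minimal sketch of the module it exercises, assembled from the new code in the zipformer.py diff below. The class name, argument names, and forward() logic are taken from this PR; the Embedding-based initialization is an assumption (the diff only shows that self.embedding.weight has shape (npos_max, embed_dim)).

import torch


class ContextualPositionalEncoding(torch.nn.Module):
    def __init__(self, embed_dim: int, npos_max: int):
        super().__init__()
        self.npos_max = npos_max
        # One learnable vector per integer position 0 .. npos_max - 1
        # (assumed; the PR only uses self.embedding.weight).
        self.embedding = torch.nn.Embedding(npos_max, embed_dim)

    def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
        # q:  (head, batch, time1, embed_dim)
        # qk: (head, batch, time1, time2), assuming time1 == time2
        # mask[i, j] is True if j >= i, so only gates with j < i survive.
        mask = torch.triu(
            torch.full((qk.size(3), qk.size(3)), True, dtype=torch.bool),
            diagonal=0,
        )
        gates = torch.sigmoid(qk)
        gates = gates.masked_fill(mask, 0)
        # Reversed inclusive cumsum: pos[..., i, j] = gates[..., i, j:].sum()
        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
        pos = pos.clamp(max=self.npos_max - 1)
        pos_ceil = pos.ceil().long()
        pos_floor = pos.floor().long()
        # Logits against all integer position embeddings, then linear
        # interpolation between the floor and ceil positions.
        logits_int = torch.matmul(q, self.embedding.weight.t())
        logits_cell = logits_int.gather(-1, pos_ceil)
        logits_floor = logits_int.gather(-1, pos_floor)
        w = pos - pos_floor
        return logits_cell * w + logits_floor * (1 - w)

Run against this sketch, the updated test prints torch.Size([2, 3, 10, 10]): gather() returns a tensor shaped like its index argument, so the output has shape (head, batch, time1, time2).
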
49 changes: 41 additions & 8 deletions egs/librispeech/ASR/zipformer/zipformer.py
@@ -1402,26 +1402,59 @@ def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
           qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
         Returns:
           Return a tensor of shape (head, batch, time1, time2)
+
+        Note: The implementation assumes time1 == time2 and npos_max <= time2.
+        This is reasonable for the non-streaming ASR encoder, where only
+        self-attention is used.
         """
+        # The implementation in Listing 1 on page 13 of the paper does not use
+        # a mask to ensure that only gates[:, :, i, j] with j < i contribute.
+        #
+        # Here we fix that by introducing a mask.
+        mask = torch.triu(
+            torch.full((qk.size(3), qk.size(3)), True, dtype=torch.bool),
+            diagonal=0,
+        )
+        #
+        # If qk.size(3) is 4, mask is
+        #
+        # tensor([[ True,  True,  True,  True],
+        #         [False,  True,  True,  True],
+        #         [False, False,  True,  True],
+        #         [False, False, False,  True]])
+        #
+        # That is, mask[i, j] is True if j >= i.
         gates = torch.sigmoid(qk)
-        pos = gates.sum(dim=-1, keepdim=True)  # (head, batch, dim1, 1)
-        # Note: We don't use cumulative sum here for non-streaming
-        # speech recognition

+        # We don't use an in-place operation here for the sake of autograd.
+        gates = gates.masked_fill(mask, 0)
+
+        # cumsum() is an inclusive sum in PyTorch.
+        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)  # (head, batch, time1, time2)
+        # pos[:, :, i, j] is 0 for j >= i.
+        # pos[:, :, i, j] contains the position between i and j. If gates
+        # is a 0-1 matrix, then pos[:, :, i, j] equals i - j (for j < i).
+        # Note: The paper says on page 4 that it equals i - j + 1 instead of i - j.
+
         pos = pos.clamp(max=self.npos_max - 1)
         pos_ceil = pos.ceil().long()
         pos_floor = pos.floor().long()
+
+        # We assume query_head_dim equals embed_dim.
+
         logits_int = torch.matmul(
             q, self.embedding.weight.t()
         )  # (head, batch, time1, npos_max)
-        logits_cell = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape))
-        logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape))
+
+        # We assume that npos_max <= time2.
+        logits_cell = logits_int.gather(-1, pos_ceil)
+        logits_floor = logits_int.gather(-1, pos_floor)

         w = pos - pos_floor
-        return logits_cell * w + logits_floor * (1 - w)

-    def streaming_forward(self):
-        raise RuntimeError("To be implemented")
+        # Note: The code in the paper on page 13 is correct,
+        # while the description in equation (5) on page 4 is wrong.
+        return logits_cell * w + logits_floor * (1 - w)


 class CompactRelPositionalEncoding(torch.nn.Module):
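
The comments above claim that, for a 0-1 gate matrix, the masked flip/cumsum/flip recipe yields pos[i, j] = i - j for j < i and 0 otherwise (whereas the paper's unmasked sum over k = j..i gives i - j + 1). A small standalone check of that claim, written for this note rather than taken from the PR:

import torch

t = 6
gates = torch.ones(t, t)  # pretend every sigmoid gate saturated to 1
# Zero out gates[i, j] for j >= i, as the PR does.
mask = torch.triu(torch.full((t, t), True, dtype=torch.bool), diagonal=0)
gates = gates.masked_fill(mask, 0)
# Reversed inclusive cumsum over the last dimension.
pos = gates.flip(-1).cumsum(dim=-1).flip(-1)

i, j = torch.meshgrid(torch.arange(t), torch.arange(t), indexing="ij")
expected = (i - j).clamp(min=0).float()
assert torch.equal(pos, expected)
print(pos)  # row i reads: i, i-1, ..., 1, 0, 0, ...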