Apply suggestions from code review
Co-authored-by: Albert Zeyer <[email protected]>
kuacakuaca and albertz authored Sep 5, 2024
1 parent c2a301f commit a4929dc
Showing 2 changed files with 5 additions and 6 deletions.
i6_models/parts/conformer/mhsa_rel_pos.py (6 changes: 2 additions & 4 deletions)
@@ -207,10 +207,8 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
         # sequence of weighted sums over value sequence
         v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head)  # [B, T, H, F']
         attn_output = (
-            torch.einsum("bhij, bjhf -> bhif", attn_output_weights, v)
-            .transpose(1, 2)
-            .contiguous()
-            .view(batch_dim_size, -1, self.embed_dim)
+            torch.einsum("bhij, bjhf -> bihf", attn_output_weights, v)
+            .reshape(batch_dim_size, -1, self.embed_dim)
         )

         output_tensor = self.out_proj(attn_output)
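
The old and the new formulation are equivalent: by emitting the einsum result directly in [B, T, H, F'] order, the transpose(1, 2).contiguous() step before flattening the head and per-head feature dimensions is no longer needed. A minimal self-contained sketch (toy shapes and random inputs assumed for illustration only, not part of the commit) that checks the two paths agree:

import torch

# assumed toy dimensions, chosen only for the comparison
batch_dim_size, time_dim_size, num_heads, embed_dim_per_head = 2, 7, 4, 16
embed_dim = num_heads * embed_dim_per_head

attn_output_weights = torch.softmax(
    torch.randn(batch_dim_size, num_heads, time_dim_size, time_dim_size), dim=-1
)  # [B, H, T, T]
v = torch.randn(batch_dim_size, time_dim_size, num_heads, embed_dim_per_head)  # [B, T, H, F']

# old path: head-major output [B, H, T, F'], then swap axes and flatten
old = (
    torch.einsum("bhij, bjhf -> bhif", attn_output_weights, v)
    .transpose(1, 2)
    .contiguous()
    .view(batch_dim_size, -1, embed_dim)
)

# new path: einsum already emits [B, T, H, F'], so a single reshape flattens H and F'
new = torch.einsum("bhij, bjhf -> bihf", attn_output_weights, v).reshape(batch_dim_size, -1, embed_dim)

assert torch.allclose(old, new, atol=1e-6)  # same values, same [B, T, embed_dim] shape
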
i6_models/parts/dropout.py (5 changes: 3 additions & 2 deletions)
@@ -27,12 +27,13 @@ def __init__(self, p: float, dropout_broadcast_axes: Optional[Literal["B", "T",

     def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         """
-        assumes input tensor of shape [B, T, F]
-        return tensor of shape [B, T, F]
+        :param tensor: input tensor of shape [B, T, F]
+        :return: tensor of shape [B, T, F]
         """
         if self.dropout_broadcast_axes is None:
             tensor = torch.nn.functional.dropout(tensor, p=self.p, training=self.training)
         elif self.dropout_broadcast_axes == "T":  # [B, T, F] -> [B, F, T] -> [B, T, F]
+            # torch.nn.functional.dropout1d expects a 3D tensor and broadcasts in the last dimension.
             tensor = torch.nn.functional.dropout1d(tensor.transpose(1, 2), p=self.p, training=self.training).transpose(
                 1, 2
             )
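
Regarding the comment added above: torch.nn.functional.dropout1d draws one dropout decision per (batch, channel) of a 3D [N, C, L] input and broadcasts it over the last dimension. Transposing [B, T, F] to [B, F, T] before the call therefore shares the mask across the time axis T. A small self-contained sketch (toy shapes assumed for illustration, not part of the commit):

import torch

torch.manual_seed(0)
x = torch.ones(1, 5, 3)  # [B, T, F] with B=1, T=5, F=3

# broadcast dropout over T: move F into the channel position, apply dropout1d, move it back
y = torch.nn.functional.dropout1d(x.transpose(1, 2), p=0.5, training=True).transpose(1, 2)  # [B, T, F]

# every feature column is either all zeros or all 1/(1-p) = 2.0 across the whole time axis
print(y[0])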
