Skip to content

Commit

Permalink
[Hackathon 7th] 修复 vctkernie_sat 训练时出现的类型提升问题 (#3943)
Browse files Browse the repository at this point in the history
* [Fix] vctk type promotion

* [Fix] type promotion
  • Loading branch information
megemini authored Dec 9, 2024
1 parent b84e86d commit e4038b4
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions paddlespeech/t2s/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,7 @@ def forward(
paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype)
mlm_loss = paddle.sum((loss * paddle.reshape(
mlm_loss_pos,
[-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10)
Expand Down

0 comments on commit e4038b4

Please sign in to comment.