diff --git a/decoders/balance_cross_entropy_loss.py b/decoders/balance_cross_entropy_loss.py index c962bd7..44c7bf7 100644 --- a/decoders/balance_cross_entropy_loss.py +++ b/decoders/balance_cross_entropy_loss.py @@ -37,8 +37,8 @@ def forward(self, gt: shape :math:`(N, 1, H, W)`, the target mask: shape :math:`(N, H, W)`, the mask indicates positive regions ''' - positive = (gt * mask).byte() - negative = ((1 - gt) * mask).byte() + positive = (gt[:,0,:,:] * mask).byte() + negative = ((1 - gt[:,0,:,:]) * mask).byte() positive_count = int(positive.float().sum()) negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio))