From 13715e97283a612665c5f738d82307fc7df80499 Mon Sep 17 00:00:00 2001 From: sabrina Date: Tue, 27 Apr 2021 16:59:20 +0800 Subject: [PATCH] fix bce loss --- decoders/balance_cross_entropy_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/decoders/balance_cross_entropy_loss.py b/decoders/balance_cross_entropy_loss.py index c962bd7..44c7bf7 100644 --- a/decoders/balance_cross_entropy_loss.py +++ b/decoders/balance_cross_entropy_loss.py @@ -37,8 +37,8 @@ def forward(self, gt: shape :math:`(N, 1, H, W)`, the target mask: shape :math:`(N, H, W)`, the mask indicates positive regions ''' - positive = (gt * mask).byte() - negative = ((1 - gt) * mask).byte() + positive = (gt[:,0,:,:] * mask).byte() + negative = ((1 - gt[:,0,:,:]) * mask).byte() positive_count = int(positive.float().sum()) negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio))