Skip to content

Commit

Permalink
add class accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Sep 27, 2019
1 parent 0f1803f commit 9601ca6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
21 changes: 16 additions & 5 deletions train_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,14 @@ def train(self, train_loader, valid_loader):
print('Finish Epoch [%d/%d], Average Loss: %.7f' % (epoch, self.epoch, epoch_loss/len(tbar)))

# 验证模型
neg_accuracy, pos_accuracy, accuracy, loss_valid = self.validation(valid_loader)
class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_valid = \
self.validation(valid_loader)

if accuracy > self.max_accuracy_valid:
is_best = True
self.max_accuracy_valid = accuracy
else: is_best = False
else:
is_best = False

state = {
'epoch': epoch,
Expand All @@ -110,6 +113,10 @@ def train(self, train_loader, valid_loader):
self.solver.save_checkpoint(os.path.join(self.model_path, '%s_classify_fold%d.pth' % (self.model_name, self.fold)), state, is_best)
self.writer.add_scalar('valid_loss', loss_valid, epoch)
self.writer.add_scalar('valid_accuracy', accuracy, epoch)
self.writer.add_scalar('valid_class_0_accuracy', class_accuracy[0], epoch)
self.writer.add_scalar('valid_class_1_accuracy', class_accuracy[1], epoch)
self.writer.add_scalar('valid_class_2_accuracy', class_accuracy[2], epoch)
self.writer.add_scalar('valid_class_3_accuracy', class_accuracy[3], epoch)

def validation(self, valid_loader):
''' 完成模型的验证过程
Expand All @@ -135,9 +142,13 @@ def validation(self, valid_loader):
tbar.set_description(desc=descript)
loss_mean = loss_sum / len(tbar)

neg_accuracy, pos_accuracy, accuracy = meter.get_metrics()
print("Negative accuracy: %0.4f | positive accuracy: %0.4f | accuracy: %0.4f" % (neg_accuracy, pos_accuracy, accuracy))
return neg_accuracy, pos_accuracy, accuracy, loss_mean
class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy = meter.get_metrics()
print("Class_0_accuracy: %0.4f | Class_1_accuracy: %0.4f | Class_2_accuracy: %0.4f | Class_3_accuracy: %0.4f | "
"Negative accuracy: %0.4f | positive accuracy: %0.4f | accuracy: %0.4f" %
(class_accuracy[0], class_accuracy[1], class_accuracy[2], class_accuracy[3],
neg_accuracy, pos_accuracy, accuracy))
return class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy, loss_mean


class ChooseMinArea():
def __init__(self, ):
Expand Down
22 changes: 15 additions & 7 deletions utils/cal_classify_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def metric(logit, truth, threshold=0.5):
num_pos = t.sum(dim=[0, 2])
num_neg = batch_size * H * W - num_pos
# 预测正确的正样本和负样本的数目
tp = tp.data.cpu().numpy().sum()
tn = tn.data.cpu().numpy().sum()
tp = tp.data.cpu().numpy()
tn = tn.data.cpu().numpy()
# 正样本、负样本的数目
num_pos = num_pos.data.cpu().numpy().sum()
num_neg = num_neg.data.cpu().numpy().sum()
num_pos = num_pos.data.cpu().numpy()
num_neg = num_neg.data.cpu().numpy()

# tp = np.nan_to_num(tp / (num_pos + 1e-12), 0)
# tn = np.nan_to_num(tn / (num_neg + 1e-12), 0)
Expand Down Expand Up @@ -60,15 +60,23 @@ def update(self, targets, outputs):
self.number_positive.append(num_pos)

def get_metrics(self):
# 预测正确的样本的数目
# 各类预测正确的样本数目,样本总数目
class_tn = np.sum(np.array(self.true_negative), axis=0)
class_tp = np.sum(np.array(self.true_poisitive), axis=0)
class_num_neg = np.sum(np.array(self.number_negative), axis=0)
class_num_pos = np.sum(np.array(self.number_positive), axis=0)
# 预测正确的样本的总数目,样本总数目
tn = np.sum(self.true_negative)
tp = np.sum(self.true_poisitive)
# 负样本和正样本各自的数目
num_neg = np.sum(self.number_negative)
num_pos = np.sum(self.number_positive)
# 各类的正负样本的准确率和总的准确率
class_neg_accuracy = class_tn / class_num_neg
class_pos_accuracy = class_tp / class_num_pos
class_accuracy = (class_tn + class_tp) / (class_num_neg + class_num_pos)
# 正负样本各自的准确率和总的准确率
neg_accuracy = tn / (num_neg + 1e-12)
pos_accuracy = tp / (num_pos + 1e-12)
accuracy = (tn + tp) / (num_neg + num_pos)

return neg_accuracy, pos_accuracy, accuracy
return class_neg_accuracy, class_pos_accuracy, class_accuracy, neg_accuracy, pos_accuracy, accuracy

0 comments on commit 9601ca6

Please sign in to comment.