Skip to content

Commit

Permalink
little print bug
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Oct 9, 2019
1 parent 50bb46b commit cf0247e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
14 changes: 7 additions & 7 deletions solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ def tta(self, images, seg=True):
images_vflip = torch.flip(images, dims=[2])
pred_vflip = self.model(images_vflip)
# 水平加垂直翻转
images_hvflip = torch.flip(images, dims=[3])
images_hvflip = torch.flip(images_hvflip, dims=[2])
pred_hvflip = self.model(images_hvflip)
# images_hvflip = torch.flip(images, dims=[3])
# images_hvflip = torch.flip(images_hvflip, dims=[2])
# pred_hvflip = self.model(images_hvflip)

if seg:
# 分割需要将预测结果翻转回去
pred_hflip = torch.flip(pred_hflip, dims=[3])
pred_vflip = torch.flip(pred_vflip, dims=[2])
pred_hvflip = torch.flip(pred_hvflip, dims=[2])
pred_hvflip = torch.flip(pred_hvflip, dims=[3])
preds = preds + pred_origin + pred_hflip + pred_vflip + pred_hvflip
# pred_hvflip = torch.flip(pred_hvflip, dims=[2])
# pred_hvflip = torch.flip(pred_hvflip, dims=[3])
preds = preds + pred_origin + pred_hflip + pred_vflip # + pred_hvflip
# 求平均
pred = preds / 4.0
pred = preds / 3.0

return pred

Expand Down
3 changes: 2 additions & 1 deletion train_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ def train(self, train_loader, valid_loader):
# 保存到tensorboard,每一步存储一个
self.writer.add_scalar('train_loss', loss.item(), global_step+i)

descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (self.fold, loss.item(), self.lr)
descript = "Fold: %d, Train Loss: %.7f, lr: %.7f" % (self.fold, loss.item(), self.lr)
tbar.set_description(desc=descript)

# 每一个epoch完毕之后,执行学习率衰减
lr_scheduler.step()
self.lr = lr_scheduler.get_lr()
global_step += len(train_loader)

# Print the log info
Expand Down
3 changes: 2 additions & 1 deletion train_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ def train(self, train_loader, valid_loader):
# 保存到tensorboard,每一步存储一个
self.writer.add_scalar('train_loss', loss.item(), global_step+i)

descript = "Fold: %d, Train Loss: %.7f, lr: %s" % (self.fold, loss.item(), self.lr)
descript = "Fold: %d, Train Loss: %.7f, lr: %.7f" % (self.fold, loss.item(), self.lr)
tbar.set_description(desc=descript)

# 每一个epoch完毕之后,执行学习率衰减
lr_scheduler.step()
self.lr = lr_scheduler.get_lr()
global_step += len(train_loader)

# Print the log info
Expand Down

0 comments on commit cf0247e

Please sign in to comment.