
Commit

Merge remote-tracking branch 'origin/main' into master-patched
xwang233 committed Nov 21, 2024
2 parents 536a21f + 620cb4f commit e06f1ef
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions train.py
@@ -930,20 +930,29 @@ def main():
                 # step LR for next epoch
                 lr_scheduler.step(epoch + 1, latest_metric)
 
-            results.append({
+            latest_results = {
                 'epoch': epoch,
                 'train': train_metrics,
-                'validation': eval_metrics,
-            })
+            }
+            if eval_metrics is not None:
+                latest_results['validation'] = eval_metrics
+            results.append(latest_results)
 
     except KeyboardInterrupt:
         pass
 
-    results = {'all': results}
     if best_metric is not None:
-        results['best'] = results['all'][best_epoch - start_epoch]
+        # log best metric as tracked by checkpoint saver
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
-    print(f'--result\n{json.dumps(results, indent=4)}')
 
+    if utils.is_primary(args):
+        # for parsable results display, dump top-10 summaries to avoid excess console spam
+        display_results = sorted(
+            results,
+            key=lambda x: x.get('validation', x.get('train')).get(eval_metric, 0),
+            reverse=decreasing_metric,
+        )
+        print(f'--result\n{json.dumps(display_results[-10:], indent=4)}')
+
 
 def train_one_epoch(
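A side note on this first hunk: eval_metrics can be None when no validation pass ran that epoch, which is why the 'validation' key is now added conditionally and why the display sort key falls back to the train metrics and defaults to 0. A minimal, self-contained sketch of the top-10 selection with hypothetical metric values (eval_metric, decreasing_metric, and all numbers below are illustrative, not taken from the commit):

# Hypothetical illustration of the new results display; values are made up.
import json

eval_metric = 'top1'          # assumed accuracy-style metric (higher is better)
decreasing_metric = False     # so the ascending sort puts the best entries last

results = [
    {'epoch': 0, 'train': {'loss': 2.1}},  # epoch without a validation pass
    {'epoch': 1, 'train': {'loss': 1.7}, 'validation': {'top1': 61.2, 'loss': 1.6}},
    {'epoch': 2, 'train': {'loss': 1.5}, 'validation': {'top1': 64.8, 'loss': 1.4}},
]

# Same key as in the diff: prefer validation metrics, fall back to train metrics,
# and use 0 when the chosen metric is missing (e.g. no 'top1' without validation).
display_results = sorted(
    results,
    key=lambda x: x.get('validation', x.get('train')).get(eval_metric, 0),
    reverse=decreasing_metric,
)
print(json.dumps(display_results[-10:], indent=4))  # last entries are the best ones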
@@ -1042,8 +1051,7 @@ def _backward(_loss):
             loss = _forward()
             _backward(loss)
 
-        if not args.distributed:
-            losses_m.update(loss.item() * accum_steps, input.size(0))
+        losses_m.update(loss.item() * accum_steps, input.size(0))
         update_sample_count += input.size(0)
 
         if not need_update:
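For context on this hunk: when gradient accumulation is enabled, the loss inside _forward is divided by accum_steps, so multiplying by accum_steps here restores the unscaled per-batch loss before updating the meter, and the update now happens on every rank rather than only in the non-distributed case. losses_m tracks a latest value and a sample-weighted running mean; roughly, as an assumed paraphrase of the meter's behavior rather than code from timm:

# Assumed sketch of an AverageMeter-style running mean used for loss logging.
class AverageMeter:
    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total number of samples
        self.avg = 0.0    # running mean over samples

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

losses_m = AverageMeter()
losses_m.update(1.25, 128)   # e.g. a de-scaled batch loss of 1.25 over 128 samples
print(losses_m.val, losses_m.avg)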
@@ -1068,16 +1076,18 @@ def _backward(_loss):
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
+            loss_avg, loss_now = losses_m.avg, losses_m.val
             if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
+                # synchronize current step and avg loss, each process keeps its own running avg
+                loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item()
+                loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item()
                 update_sample_count *= args.world_size
 
             if utils.is_primary(args):
                 _logger.info(
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
-                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
+                    f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
                     f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
                     f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
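One detail in this logging hunk: loss.new([loss_avg]) wraps a Python float in a one-element tensor with the same dtype and device as loss, so the scalar can go through the collective reduction; the meters themselves are no longer modified, and each rank only synchronizes the two numbers it is about to print. A small standalone illustration of the idiom (not part of the commit, values are made up):

# Illustration of the Tensor.new([...]) idiom used above.
import torch

loss = torch.tensor(0.73)        # stand-in for the current training loss
wrapped = loss.new([0.41])       # 1-element tensor sharing loss's dtype and device
print(wrapped.dtype, wrapped.device)
print(wrapped.item())            # back to a plain Python float after any reduction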
@@ -1106,7 +1116,12 @@ def _backward(_loss):
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
 
-    return OrderedDict([('loss', losses_m.avg)])
+    loss_avg = losses_m.avg
+    if args.distributed:
+        # synchronize avg loss, each process keeps its own running avg
+        loss_avg = torch.tensor([loss_avg], device=device, dtype=torch.float32)
+        loss_avg = utils.reduce_tensor(loss_avg, args.world_size).item()
+    return OrderedDict([('loss', loss_avg)])
 
 
 def validate(
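Both train_one_epoch changes lean on utils.reduce_tensor to average a scalar across ranks. The commit does not show that helper; as an assumption, such a helper typically performs an all-reduce sum followed by a division by the world size, along these lines:

# Assumed behavior of the reduce_tensor helper; not copied from the repository.
import torch
import torch.distributed as dist

def reduce_tensor(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    rt = tensor.clone()                           # avoid mutating the caller's tensor
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)     # sum the value across all processes
    rt /= world_size                              # convert the sum into a mean
    return rt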
