From 36b5d1adaa85f8c9d27862a5cd00c20cb2c37e3b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 20 Nov 2024 11:45:53 -0800 Subject: [PATCH 1/2] In dist training, update loss running avg every step, only sync on log updates / final. --- train.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index ff11622bb..95b4a9852 100755 --- a/train.py +++ b/train.py @@ -1042,8 +1042,7 @@ def _backward(_loss): loss = _forward() _backward(loss) - if not args.distributed: - losses_m.update(loss.item() * accum_steps, input.size(0)) + losses_m.update(loss.item() * accum_steps, input.size(0)) update_sample_count += input.size(0) if not need_update: @@ -1068,16 +1067,18 @@ def _backward(_loss): lrl = [param_group['lr'] for param_group in optimizer.param_groups] lr = sum(lrl) / len(lrl) + loss_avg, loss_now = losses_m.avg, losses_m.val if args.distributed: - reduced_loss = utils.reduce_tensor(loss.data, args.world_size) - losses_m.update(reduced_loss.item() * accum_steps, input.size(0)) + # synchronize current step and avg loss, each process keeps its own running avg + loss_avg = utils.reduce_tensor(loss.new([loss_avg]), args.world_size).item() + loss_now = utils.reduce_tensor(loss.new([loss_now]), args.world_size).item() update_sample_count *= args.world_size if utils.is_primary(args): _logger.info( f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} ' f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] ' - f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) ' + f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) ' f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s ' f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) ' f'LR: {lr:.3e} ' @@ -1106,7 +1107,12 @@ def _backward(_loss): if hasattr(optimizer, 'sync_lookahead'): optimizer.sync_lookahead() - return OrderedDict([('loss', losses_m.avg)]) + loss_avg = losses_m.avg + if args.distributed: + # synchronize avg loss, each process keeps its own running avg + loss_avg = torch.tensor([loss_avg], device=device, dtype=torch.float32) + loss_avg = utils.reduce_tensor(loss_avg, args.world_size).item() + return OrderedDict([('loss', loss_avg)]) def validate( From 620cb4f3cb02cd489c9a77711e7f5fc8d40e54b7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 20 Nov 2024 16:43:16 -0800 Subject: [PATCH 2/2] Improve the parsable results dump at end of train, stop excessive output, only display top-10. --- train.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 95b4a9852..aa9db22b6 100755 --- a/train.py +++ b/train.py @@ -930,20 +930,29 @@ def main(): # step LR for next epoch lr_scheduler.step(epoch + 1, latest_metric) - results.append({ + latest_results = { 'epoch': epoch, 'train': train_metrics, - 'validation': eval_metrics, - }) + } + if eval_metrics is not None: + latest_results['validation'] = eval_metrics + results.append(latest_results) except KeyboardInterrupt: pass - results = {'all': results} if best_metric is not None: - results['best'] = results['all'][best_epoch - start_epoch] + # log best metric as tracked by checkpoint saver _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) - print(f'--result\n{json.dumps(results, indent=4)}') + + if utils.is_primary(args): + # for parsable results display, dump top-10 summaries to avoid excess console spam + display_results = sorted( + results, + key=lambda x: x.get('validation', x.get('train')).get(eval_metric, 0), + reverse=decreasing_metric, + ) + print(f'--result\n{json.dumps(display_results[-10:], indent=4)}') def train_one_epoch(