Skip to content

Commit

Permalink
fix dupliate gpu check (#9578)
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou authored Sep 13, 2023
1 parent 300f9ac commit a343ae3
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions demo/nvflare/vertical/custom/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ def _do_training(self, fl_ctx: FLContext):
'eval_metric': 'auc',
}
if self._use_gpus:
if self._use_gpus:
self.log_info(fl_ctx, f'Training with GPU {rank}')
param['device'] = f"cuda:{rank}"
self.log_info(fl_ctx, f'Training with GPU {rank}')
param['device'] = f"cuda:{rank}"

# specify validations set to watch performance
watchlist = [(dtest, "eval"), (dtrain, "train")]
Expand Down

0 comments on commit a343ae3

Please sign in to comment.