diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py index b9c16c0996..07928c5e4a 100644 --- a/returnn/torch/engine.py +++ b/returnn/torch/engine.py @@ -108,6 +108,7 @@ def __init__(self, config: Config): print(f"Start running torch distributed training on local rank {local_rank}.", file=log.v2) assert self._device == "cuda", f"torch distributed: unexpected device {self._device!r}" self._device = f"cuda:{local_rank}" + torch.cuda.set_device(local_rank) # Theano and TensorFlow print sth like: Using gpu device 2: GeForce GTX 980 (...) # Print in a similar format so that some scripts which grep our stdout work just as before.