From 5375d71bbd41b24abfa966042bb57916b56973b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 11 Oct 2024 16:32:34 +0200 Subject: [PATCH] `trl env` report all cuda devices (#2216) --- trl/commands/cli.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trl/commands/cli.py b/trl/commands/cli.py index dcc13aaff6..3a11984747 100644 --- a/trl/commands/cli.py +++ b/trl/commands/cli.py @@ -38,6 +38,9 @@ def print_env(): + if torch.cuda.is_available(): + devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] + accelerate_config = accelerate_config_str = "not found" # Get the default from the config file. @@ -56,7 +59,7 @@ def print_env(): "Platform": platform.platform(), "Python version": platform.python_version(), "PyTorch version": version("torch"), - "CUDA device": torch.cuda.get_device_name() if torch.cuda.is_available() else "not available", + "CUDA device(s)": ", ".join(devices) if torch.cuda.is_available() else "not available", "Transformers version": version("transformers"), "Accelerate version": version("accelerate"), "Accelerate config": accelerate_config_str,