Skip to content

Commit

Permalink
Set cuda device max connections based on cuda capability
Browse files Browse the repository at this point in the history
Signed-off-by: Guyue Huang <[email protected]>
  • Loading branch information
Guyue Huang committed Dec 20, 2024
1 parent 5730fac commit 5ea86a0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
12 changes: 12 additions & 0 deletions launcher_scripts/nemo_launcher/collections/conditional_cfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,23 @@ def get_ag_overlap(cfg):
else:
print(1)

@hydra.main(version_base=None, config_path="conf", config_name="get_cuda_device_max_connections")
def get_cuda_device_max_connections(cfg):
"""
Set CUDA_DEVICE_MAX_CONNECTIONS to 32 for blackwell, to 1 for hopper and earlier generations
"""
global cuda_capability
if cuda_capability >= 10:
print(32)
else:
print(1)

if __name__ == "__main__":
if sys.argv[1] == "name=get_ln_sm_margin":
get_ln_sm_margin()
elif sys.argv[1] == "name=get_ag_overlap":
get_ag_overlap()
elif sys.argv[1] == "name=get_cuda_device_max_connections":
get_cuda_device_max_connections()
else:
raise ValueError("The provided conditional config function does not exist.")
11 changes: 11 additions & 0 deletions launcher_scripts/nemo_launcher/core/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,17 @@ def _cuda_device_max_connections(self) -> str:
tensor_model_parallel_size = model_cfg.get("tensor_model_parallel_size", 1)
context_parallel_size = model_cfg.get("context_parallel_size", 1)
fsdp = model_cfg.get("fsdp", False)
if (
(tensor_model_parallel_size > 1 or context_parallel_size > 1)
and not fsdp
):
get_cuda_device_max_connections_command= (
f"python3 {self._launcher_scripts_path / 'nemo_launcher/collections/conditional_cfgs.py'} "
f"name=get_cuda_device_max_connections"
)
return f"CUDA_DEVICE_MAX_CONNECTIONS=\$({get_cuda_device_max_connections_command})"
return ""

return (
"CUDA_DEVICE_MAX_CONNECTIONS=1"
if (
Expand Down

0 comments on commit 5ea86a0

Please sign in to comment.