Skip to content

Commit

Permalink
UCT/CUDA-IPC: Fix reachability check
Browse files Browse the repository at this point in the history
  • Loading branch information
brminich committed Sep 29, 2024
1 parent 7d0bf01 commit 0e5cbb4
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions src/uct/cuda/cuda_ipc/cuda_ipc_iface.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ static ucs_status_t uct_cuda_ipc_iface_get_address(uct_iface_h tl_iface,
return UCS_OK;
}

#if HAVE_CUDA_FABRIC
static int uct_cuda_ipc_iface_is_mnnvl_supported(uct_cuda_ipc_md_t *md)
{
#if HAVE_CUDA_FABRIC
CUdevice cu_device;
int coherent;
ucs_status_t status;
Expand All @@ -95,17 +95,17 @@ static int uct_cuda_ipc_iface_is_mnnvl_supported(uct_cuda_ipc_md_t *md)
}

return coherent && (md->enable_mnnvl != UCS_NO);
}
#endif

return 0;
}

static int
uct_cuda_ipc_iface_is_reachable_v2(const uct_iface_h tl_iface,
const uct_iface_is_reachable_params_t *params)
{
#if HAVE_CUDA_FABRIC
uct_base_iface_t *base_iface = ucs_derived_of(tl_iface, uct_base_iface_t);
uct_cuda_ipc_md_t *md = ucs_derived_of(base_iface->md, uct_cuda_ipc_md_t);
#endif

if (!uct_iface_is_reachable_params_addrs_valid(params)) {
return 0;
Expand All @@ -116,16 +116,10 @@ uct_cuda_ipc_iface_is_reachable_v2(const uct_iface_h tl_iface,
return 0;
}

#if HAVE_CUDA_FABRIC
if (uct_cuda_ipc_iface_is_mnnvl_supported(md)) {
/* multi-node nvlink is supported and enabled */
return 1;
}
#endif

/* Not fabric capable or multi-node nvlink disabled, so iface has to be on
* the same node for cuda-ipc to be reachable */
if ((ucs_get_system_id() != *((const uint64_t*)params->device_addr))) {
/* Either multi-node NVLINK should be supported or iface has to be on the
* same node for cuda-ipc to be reachable */
if ((ucs_get_system_id() != *((const uint64_t*)params->device_addr)) &&
!uct_cuda_ipc_iface_is_mnnvl_supported(md)) {
uct_iface_fill_info_str_buf(params,
"different system id %"PRIx64" vs %"PRIx64"",
ucs_get_system_id(),
Expand Down Expand Up @@ -612,13 +606,12 @@ uct_cuda_ipc_query_devices(
unsigned *num_tl_devices_p)
{
uct_device_type_t dev_type = UCT_DEVICE_TYPE_SHM;
#if HAVE_CUDA_FABRIC
uct_cuda_ipc_md_t *md = ucs_derived_of(uct_md, uct_cuda_ipc_md_t);

if (uct_cuda_ipc_iface_is_mnnvl_supported(md)) {
dev_type = UCT_DEVICE_TYPE_NET;
}
#endif

return uct_cuda_base_query_devices_common(uct_md, dev_type,
tl_devices_p, num_tl_devices_p);
}
Expand Down

0 comments on commit 0e5cbb4

Please sign in to comment.