diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_md.c b/src/uct/cuda/cuda_ipc/cuda_ipc_md.c index 74a7bf0e5c2..f7edbc99d70 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_md.c +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_md.c @@ -116,9 +116,13 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh, uct_cuda_ipc_lkey_t *key; ucs_status_t status; #if HAVE_CUDA_FABRIC +#define UCT_CUDA_IPC_QUERY_NUM_ATTRS 2 CUmemGenericAllocationHandle handle; CUmemoryPool mempool; + CUpointer_attribute attr_type[UCT_CUDA_IPC_QUERY_NUM_ATTRS]; + void *attr_data[UCT_CUDA_IPC_QUERY_NUM_ATTRS]; int legacy_capable; + int allowed_handle_types; #endif key = ucs_calloc(1, sizeof(*key), "uct_cuda_ipc_lkey_t"); @@ -134,15 +138,28 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh, /* cuda_ipc can handle VMM, mallocasync, and legacy pinned device so need to * pack appropriate handle */ - status = UCT_CUDADRV_FUNC_LOG_ERR(cuPointerGetAttribute(&legacy_capable, - CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE, - (CUdeviceptr)addr)); + attr_type[0] = CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE; + attr_data[0] = &legacy_capable; + attr_type[1] = CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES; + attr_data[1] = &allowed_handle_types; + + status = UCT_CUDADRV_FUNC_LOG_ERR( + cuPointerGetAttributes(ucs_static_array_size(attr_data), attr_type, + attr_data, (CUdeviceptr)addr)); + if (status != UCS_OK) { + goto err; + } + if (legacy_capable) { key->ph.handle_type = UCT_CUDA_IPC_KEY_HANDLE_TYPE_LEGACY; legacy_handle = &key->ph.handle.legacy; goto legacy_path; } + if (!(allowed_handle_types & CU_MEM_HANDLE_TYPE_FABRIC)) { + goto non_ipc; + } + status = UCT_CUDADRV_FUNC(cuMemRetainAllocationHandle(&handle, addr), UCS_LOG_LEVEL_DIAG); @@ -153,7 +170,8 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh, CU_MEM_HANDLE_TYPE_FABRIC, 0)); if (status != UCS_OK) { cuMemRelease(handle); - goto err; + ucs_debug("unable to export handle for VMM ptr: %p", addr); + goto non_ipc; } status = UCT_CUDADRV_FUNC_LOG_ERR(cuMemRelease(handle)); @@ -179,7 +197,8 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh, (void *)&key->ph.handle.fabric_handle, mempool, CU_MEM_HANDLE_TYPE_FABRIC, 0)); if (status != UCS_OK) { - goto err; + ucs_debug("unable to export handle for mempool ptr: %p", addr); + goto non_ipc; } status = UCT_CUDADRV_FUNC_LOG_ERR(cuMemPoolExportPointer(&key->ph.ptr, @@ -191,6 +210,10 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh, key->ph.handle_type = UCT_CUDA_IPC_KEY_HANDLE_TYPE_MEMPOOL; ucs_trace("packed mempool handle and export pointer for %p", addr); goto common_path; + +non_ipc: + key->ph.handle_type = UCT_CUDA_IPC_KEY_HANDLE_TYPE_ERROR; + goto common_path; #endif legacy_path: status = UCT_CUDADRV_FUNC(cuIpcGetMemHandle(legacy_handle, (CUdeviceptr)addr),