diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp index fb07f4cc4..db0432820 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp @@ -156,7 +156,7 @@ void nccl_allgather(at::Tensor dst, at::Tensor src, int64_t comm_idx) { "ncclAllGather"); } -void nccl_alltoall( +void nccl_alltoall_single( at::Tensor dst, at::Tensor src, int64_t world_size, @@ -271,7 +271,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def("nccl_allgather(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()"); m.def( - "nccl_alltoall(Tensor(a!) dst, Tensor src, int world_size, int comm_idx=0) -> ()"); + "nccl_alltoall_single(Tensor(a!) dst, Tensor src, int world_size, int comm_idx=0) -> ()"); m.def("nccl_reducescatter(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()"); @@ -298,7 +298,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("nccl_allreduce", nccl_allreduce); m.impl("nccl_allgather", nccl_allgather); - m.impl("nccl_alltoall", nccl_alltoall); + m.impl("nccl_alltoall_single", nccl_alltoall_single); m.impl("nccl_reducescatter", nccl_reducescatter); m.impl("one_shot_car_allreduce", one_shot_car_allreduce); m.impl("two_shot_car_allreduce", two_shot_car_allreduce); @@ -309,7 +309,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { m.impl("nccl_allreduce", nccl_allreduce); m.impl("nccl_allgather", nccl_allgather); - m.impl("nccl_alltoall", nccl_alltoall); + m.impl("nccl_alltoall_single", nccl_alltoall_single); m.impl("nccl_reducescatter", nccl_reducescatter); m.impl("one_shot_car_allreduce", one_shot_car_allreduce); m.impl("two_shot_car_allreduce", two_shot_car_allreduce); @@ -331,7 +331,7 @@ void nccl_allgather_meta( return; } -void nccl_alltoall_meta( +void nccl_alltoall_single_meta( at::Tensor /* dst */, at::Tensor /* src */, int64_t /* world_size */, @@ -365,7 +365,7 @@ void two_shot_car_allreduce_meta( TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl("nccl_allreduce", nccl_allreduce_meta); m.impl("nccl_allgather", nccl_allgather_meta); - m.impl("nccl_alltoall", nccl_alltoall_meta); + m.impl("nccl_alltoall_single", nccl_alltoall_single_meta); m.impl("nccl_reducescatter", nccl_reducescatter_meta); m.impl("one_shot_car_allreduce", one_shot_car_allreduce_meta); m.impl("two_shot_car_allreduce", two_shot_car_allreduce_meta);