Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Sep 20, 2023
1 parent cd3479d commit 2f7c953
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather);
#endif

#if defined(USE_MPI) && defined(ORT_USE_NCCL)
#if defined(ORT_USE_NCCL)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
Expand Down Expand Up @@ -298,7 +298,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
#endif

#if defined(USE_MPI) && defined(ORT_USE_NCCL)
#if defined(ORT_USE_NCCL)
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3003,7 +3003,7 @@ Having this op allows runtime to do operator re-ordering to reduce compute FLOPs
}
#endif

#ifdef USE_MPI
#ifdef ORT_USE_NCCL
RegisterCollectiveOps();
#endif
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void RegisterContribSchemas();
void RegisterNchwcSchemas();
void RegisterQuantizationSchemas();

#if defined(USE_MPI)
#if defined(ORT_USE_NCCL)
void RegisterCollectiveOps();
#endif

Expand Down

0 comments on commit 2f7c953

Please sign in to comment.