Fix build when flash attention and memory efficient attention are disabled (microsoft#18761)

### Fix build when flash attention and memory efficient attention are disabled

In a customer environment with a CUDA version older than 11.6, both flash attention and memory efficient attention are turned OFF according to
https://github.com/microsoft/onnxruntime/blob/e8f33b54bab5129b0dea177669bbd1c1d0894dd8/cmake/CMakeLists.txt#L701.
As a result, the condition check at
https://github.com/microsoft/onnxruntime/blob/e8f33b54bab5129b0dea177669bbd1c1d0894dd8/cmake/external/cutlass.cmake#L1
returns false and the cutlass library is never built. CMake reports:

```
Turn off flash attention since CUDA compiler version < 11.6
```
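
For reference, the pre-fix gating in `cmake/external/cutlass.cmake` looked roughly like the sketch below (simplified; the `DEP_URL_cutlass`/`DEP_SHA1_cutlass` variable names and the exact fetch call are assumptions, not a verbatim copy). It shows why cutlass disappears as soon as both attention options are OFF:

```cmake
# Simplified sketch of the pre-fix cmake/external/cutlass.cmake gating.
# When both options below end up OFF (CUDA compiler < 11.6, or the user
# disabled them), cutlass is never fetched, even though the ft_moe kernels
# still #include cutlass headers.
if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
  include(FetchContent)
  FetchContent_Declare(
    cutlass
    URL ${DEP_URL_cutlass}             # assumed variable names; the real file
    URL_HASH SHA1=${DEP_SHA1_cutlass}  # pins a specific cutlass release
  )
  FetchContent_MakeAvailable(cutlass)
endif()
```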

However, the kernels under
https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/contrib_ops/cuda/moe/ft_moe
depend on cutlass to build, so compilation fails with errors like this:

```
[ 77%] Building CUDA object CMakeFiles/onnxruntime_providers_cuda.dir/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu.o
In file included from /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu:17:
/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h:23:10: fatal error: cutlass/array.h: No such file or directory
   23 | #include "cutlass/array.h"
      |          ^~~~~~~~~~~~~~~~~
compilation terminated.
fatal   : Could not open input file /tmp/tmpxft_00044da3_00000000-11_moe_gemm_kernels_fp16_fp16.compute_60.cpp1.ii
make[2]: *** [CMakeFiles/onnxruntime_providers_cuda.dir/build.make:6290: CMakeFiles/onnxruntime_providers_cuda.dir/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu.o] Error 1
make[2]: *** Waiting for unfinished jobs....
make[1]: *** [CMakeFiles/Makefile2:2210: CMakeFiles/onnxruntime_providers_cuda.dir/all] Error 2
make: *** [Makefile:166: all] Error 2
Traceback (most recent call last):
  File "/tmp/onnxruntime/tools/ci_build/build.py", line 2746, in <module>
    sys.exit(main())
  File "/tmp/onnxruntime/tools/ci_build/build.py", line 2639, in main
    build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, args.target)
  File "/tmp/onnxruntime/tools/ci_build/build.py", line 1527, in build_targets
    run_subprocess(cmd_args, env=env)
  File "/tmp/onnxruntime/tools/ci_build/build.py", line 824, in run_subprocess
    return run(*args, cwd=cwd, capture_stdout=capture_stdout, shell=shell, env=my_env)
  File "/tmp/onnxruntime/tools/python/util/run.py", line 49, in run
    completed_process = subprocess.run(
  File "/opt/conda/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
```


### Motivation and Context

To summarize, there are two cases in which the Linux CUDA build fails (a condensed sketch of the fix follows this list):
1. The user builds with a CUDA version older than 11.6.
2. The user explicitly disables flash attention and memory efficient attention via onnxruntime_USE_FLASH_ATTENTION and onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION.
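
Condensed from the `cmake/CMakeLists.txt` change below, the fix ties cutlass to a single `onnxruntime_USE_CUTLASS` switch so that both cases leave the build consistent. This is a sketch, not a verbatim copy of the diff; see the changed files for the exact code:

```cmake
include(CMakeDependentOption)

# New option: cutlass follows CUDA, not the attention kernels.
cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF)

if (onnxruntime_USE_CUDA)
  if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
    message(STATUS "Turn off cutlass since CUDA compiler version < 11.6")
    set(onnxruntime_USE_CUTLASS OFF)   # case 1: CUDA toolkit older than 11.6
  endif()
else()
  set(onnxruntime_USE_CUTLASS OFF)
endif()

# Flash/memory-efficient attention now depend on cutlass rather than the
# other way around, so disabling them (case 2) no longer removes cutlass.
if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS)
  set(onnxruntime_USE_FLASH_ATTENTION OFF)
  set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()

# Sources that need cutlass (e.g. the ft_moe kernels) are compiled only when
# USE_CUTLASS is defined; this lives in onnxruntime_set_compile_flags.
function(onnxruntime_set_compile_flags target_name)
  if (onnxruntime_USE_CUTLASS)
    target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
  endif()
endfunction()
```
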
pengwa authored Dec 26, 2023
1 parent 9dd9461 commit 37f7436
Showing 25 changed files with 115 additions and 10 deletions.
24 changes: 17 additions & 7 deletions cmake/CMakeLists.txt
@@ -96,6 +96,7 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)

cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF)
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

@@ -693,16 +694,20 @@ if (onnxruntime_USE_CUDA)
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")

if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
message( STATUS "Turn off cutlass since CUDA compiler version < 11.6")
set(onnxruntime_USE_CUTLASS OFF)
endif()
else()
set(onnxruntime_USE_CUTLASS OFF)
endif()

if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS)
if (onnxruntime_DISABLE_CONTRIB_OPS)
message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled")
else()
message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled")
endif()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
@@ -887,6 +892,11 @@ function(onnxruntime_set_compile_flags target_name)
if (onnxruntime_ENABLE_ATEN)
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
endif()

if (onnxruntime_USE_CUTLASS)
target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
endif()

set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON)
if (onnxruntime_USE_CUDA)
# Suppress a "conversion_function_not_usable" warning in gsl/span
2 changes: 1 addition & 1 deletion cmake/external/cutlass.cmake
@@ -1,4 +1,4 @@
if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
if (onnxruntime_USE_CUTLASS)
include(FetchContent)
FetchContent_Declare(
cutlass
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
@@ -202,3 +204,5 @@ Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#pragma once

#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
@@ -34,3 +36,5 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -70,8 +70,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
#ifdef USE_CUTLASS
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
@@ -165,8 +167,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);

#ifdef USE_CUTLASS
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);
#endif

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);
@@ -266,8 +270,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
#ifdef USE_CUTLASS
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
@@ -367,8 +373,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,

#ifdef USE_CUTLASS
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,
#endif

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul)>,
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
@@ -13,6 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

#include <cuda_runtime_api.h>
@@ -49,3 +52,5 @@ inline int compute_occupancy_for_kernel() {
}

} // namespace ort_fastertransformer

#endif
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_CUTLASS

#include "cutlass_heuristic.h"

@@ -185,3 +186,5 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
}

} // namespace ort_fastertransformer

#endif
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_CUTLASS

#pragma once

@@ -37,3 +38,4 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
const int multi_processor_count, const int is_weight_only);

} // namespace ort_fastertransformer
#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
@@ -22,6 +22,8 @@
*
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/array.h"
@@ -131,3 +133,5 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
};

} // namespace ort_fastertransformer

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
@@ -14,6 +14,8 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

namespace ort_fastertransformer {
@@ -56,3 +58,5 @@ struct CutlassGemmConfig {
};

} // namespace ort_fastertransformer

#endif
@@ -29,6 +29,8 @@
*
**************************************************************************************************/

#ifdef USE_CUTLASS

/*! \file
\brief Scheduler for grouped GEMM
*/
@@ -77,3 +79,5 @@ struct GemmMoeProblemVisitor
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

#endif
@@ -22,6 +22,8 @@
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/layout/matrix.h"
@@ -150,4 +152,6 @@ struct MixedGemmArchTraits<

} // namespace kernel
} // namespace gemm
} // namespace cutlass
} // namespace cutlass

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
@@ -23,6 +23,8 @@
*
**************************************************************************************************/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/complex.h"
@@ -461,3 +463,5 @@ struct MoeFCGemm {
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
@@ -14,6 +14,8 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

#include <cuda_runtime_api.h>
@@ -62,3 +64,5 @@ class MoeGemmRunner {
};

} // namespace ort_fastertransformer

#endif
@@ -14,8 +14,12 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#include "moe_gemm_kernels_template.h"

namespace ort_fastertransformer {
template class MoeGemmRunner<half, half>;
} // namespace ort_fastertransformer

#endif
@@ -14,8 +14,12 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#include "moe_gemm_kernels_template.h"

namespace ort_fastertransformer {
template class MoeGemmRunner<float, float>;
} // namespace ort_fastertransformer

#endif
@@ -14,6 +14,8 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__
#pragma GCC diagnostic push
@@ -426,3 +428,5 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A, const WeightType* B, con
}

} // namespace ort_fastertransformer

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -16,6 +16,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
@@ -898,3 +900,5 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half
cudaStream_t);

} // namespace ort_fastertransformer

#endif
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -16,6 +16,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#pragma once

#include "moe_gemm_kernels.h"
@@ -172,4 +174,6 @@ class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_s
}
};

} // namespace ort_fastertransformer
} // namespace ort_fastertransformer

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h
@@ -33,6 +33,8 @@
\brief Base scheduler for grouped problems, using MoE
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
@@ -288,3 +290,5 @@ struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode:
} // namespace kernel
} // namespace gemm
} // namespace cutlass

#endif
@@ -31,6 +31,9 @@
/*! \file
\brief Defines new layouts needed for MoE
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/cutlass.h"
@@ -59,3 +62,5 @@ struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {

} // namespace layout
} // namespace cutlass

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "moe.h"
@@ -117,3 +119,5 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/moe.h
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#pragma once

#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
@@ -24,3 +26,5 @@ class MoE final : public CudaKernel, public MoEBase {
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif