Commit aeaf8e2: format

wangyems committed Mar 23, 2024
1 parent fd7b37b commit aeaf8e2
Showing 33 changed files with 888 additions and 673 deletions.
@@ -26,7 +26,7 @@ namespace ort_fastertransformer {

template <typename GemmKernel, bool enable_cutlass_3x = false>
inline int compute_occupancy_for_kernel() {
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));

if (smem_size > (48 << 10)) {
cudaFuncAttributes attr;
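For context: the branch above exists because a kernel whose SharedStorage exceeds the 48 KB static limit must opt in to larger dynamic shared memory before it can be launched or queried for occupancy. A minimal sketch of that opt-in, assuming only a kernel pointer and its shared-storage size (this is not the code elided from the hunk):

#include <cuda_runtime.h>

// Kernels needing more than 48 KB of dynamic shared memory must raise
// cudaFuncAttributeMaxDynamicSharedMemorySize, or launches and occupancy
// queries fail.
inline cudaError_t opt_in_to_large_smem(const void* kernel, int smem_size) {
  if (smem_size > (48 << 10)) {
    return cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  }
  return cudaSuccess;
}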
@@ -74,18 +74,18 @@ struct GELU_taylor<float> {

CUTLASS_DEVICE
float operator()(float const& z) const {
float k0 = float(0.7978845608028654);
float k1 = float(0.044715);
float k0 = static_cast<float>(0.7978845608028654);
float k1 = static_cast<float>(0.044715);

return float(cutlass::constants::half<float>() * z * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
return static_cast<float>(
cutlass::constants::half<float>() * z *
(cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
}

using Params = LinearCombinationGenericParams<float>;

CUTLASS_DEVICE
float operator()(float const& scalar, Params const& params_) const {
return this->operator()(scalar);
}
float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); }
};

} // namespace thread
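For reference, GELU_taylor<float> above evaluates the standard tanh approximation of GELU, with k0 = 0.7978845608... = sqrt(2/pi) and k1 = 0.044715:

\mathrm{GELU}(z) \approx \tfrac{1}{2}\, z \left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(z + 0.044715\, z^{3}\bigr)\right)\right)

The code folds the cubic term as k0 * z * (1 + k1 * z * z), which is the same expression after factoring out z.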
@@ -53,7 +53,8 @@ namespace epilogue {
namespace threadblock {

template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_,
bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
using ThreadblockShape = ThreadblockShape_;
@@ -90,18 +91,17 @@ class EpilogueVisitorPerRowPerCol {
//
// Methods
//
Arguments()
: batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {
}
Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {
}

Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
int64_t batch_stride_C_, int64_t batch_stride_D_)
: elementwise(elementwise_), batch_stride_alpha(batch_stride_alpha_), batch_stride_C(batch_stride_C_), batch_stride_D(batch_stride_D_) {
}
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_,
int64_t batch_stride_D_)
: elementwise(elementwise_),
batch_stride_alpha(batch_stride_alpha_),
batch_stride_C(batch_stride_C_),
batch_stride_D(batch_stride_D_) {}
};

struct Params {
@@ -118,13 +118,14 @@ class EpilogueVisitorPerRowPerCol {

CUTLASS_HOST_DEVICE
Params(Arguments const& args)
: elementwise(args.elementwise), batch_stride_alpha(args.batch_stride_alpha), batch_stride_C(args.batch_stride_C), batch_stride_D(args.batch_stride_D) {
}
: elementwise(args.elementwise),
batch_stride_alpha(args.batch_stride_alpha),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D) {}
};

/// Shared storage
struct SharedStorage {
};
struct SharedStorage {};

private:
Params const& params_;
@@ -158,13 +159,26 @@ class EpilogueVisitorPerRowPerCol {
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
typename ScaleTileIterator::Params params_alpha_col,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D, tk::QuantMode quant_option,
AlphaScaleElementType* ptr_alpha_row, AlphaScaleElementType* ptr_alpha_col,
typename OutputTileIterator::Element* ptr_C, typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params), shared_storage_(shared_storage), extent_(problem_size), elementwise_(params.elementwise), per_token_quant_(quant_option.hasPerTokenScaling()), per_channel_quant_(quant_option.hasPerChannelScaling()), ptr_alpha_row_(ptr_alpha_row), ptr_alpha_col_(ptr_alpha_col), iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset), iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), extent_real_(problem_size_real) {
: params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
elementwise_(params.elementwise),
per_token_quant_(quant_option.hasPerTokenScaling()),
per_channel_quant_(quant_option.hasPerChannelScaling()),
ptr_alpha_row_(ptr_alpha_row),
ptr_alpha_col_(ptr_alpha_col),
iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
extent_real_(problem_size_real) {
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

if (beta_ == ElementAccumulator()) {
@@ -219,7 +233,8 @@ class EpilogueVisitorPerRowPerCol {
void begin_row(int row_idx) {
// load alpha_row in begin_step only when per token(row) scaling is used
if (per_token_quant_) {
int thread_offset_row = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
int thread_offset_row =
iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();

arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
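The guarded arch::global_load above behaves like the plain scalar load below; a sketch with illustrative names, not the device code:

// Load the per-row (per-token) scale only when this thread's row is in
// bounds; out-of-range threads leave element_alpha_row untouched.
if (thread_offset_row < extent_row) {
  element_alpha_row = ptr_alpha_row[thread_offset_row];
}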
@@ -262,8 +277,8 @@ class EpilogueVisitorPerRowPerCol {

private:
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col,
AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
@@ -274,8 +289,8 @@
}

CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col,
AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
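To make the two fragment helpers above concrete, here is a host-side sketch of the same math; the row-major layout and all names below are illustrative assumptions, not code from this file:

#include <vector>

// out[r][c] = accum[r][c] * alpha_col[c] * alpha_row[r]: per-channel (column)
// and per-token (row) dequantization scales applied in one pass.
std::vector<float> scale_accumulators(const std::vector<float>& accum,      // M x N, row-major
                                      const std::vector<float>& alpha_col,  // N per-channel scales
                                      const std::vector<float>& alpha_row,  // M per-token scales
                                      int M, int N) {
  std::vector<float> out(accum.size());
  for (int r = 0; r < M; ++r) {
    for (int c = 0; c < N; ++c) {
      out[r * N + c] = accum[r * N + c] * alpha_col[c] * alpha_row[r];
    }
  }
  return out;
}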
@@ -92,7 +92,8 @@ namespace detail {
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;
using WarpTileIterator =
cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;

using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;

@@ -133,8 +134,9 @@ class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
static int const kThreads = ThreadMap::kThreads;

/// Fragment object
using Fragment = Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
using Fragment =
Array<Element, ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup *
ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;

/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
@@ -163,8 +165,7 @@ class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {

/// Constructor
CUTLASS_DEVICE
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
: stride_((ref.stride(0) / LoadType::kElements)) {
SharedLoadIteratorMixed(TensorRef ref, int thread_idx) : stride_((ref.stride(0) / LoadType::kElements)) {
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);

// Initialize pointers
@@ -173,7 +174,7 @@ class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());

int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
int bank_offset = (col_idx * static_cast<int>(sizeof(LoadType)) / 128) % kLoadsPerAccess;

col_idx += (bank_offset + i) % kLoadsPerAccess;

@@ -207,7 +208,8 @@ class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements;
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ +
cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements;

int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));

@@ -233,9 +235,7 @@ class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {

/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment& frag) const {
load_with_pointer_offset(frag, 0);
}
void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
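The bank_offset computation in the constructor above staggers which of the kLoadsPerAccess pointers serves which column, so the loads of one access fall into different shared-memory banks. A self-contained sketch of the arithmetic with assumed sizes (sizeof(LoadType) == 16 and kLoadsPerAccess == 2 are illustrative, not taken from the file):

// The swizzle rotates pointer assignment once per 128 bytes, i.e. one full
// sweep of the 32 four-byte banks.
constexpr int kLoadBytes = 16;
constexpr int kLoadsPerAccess = 2;

constexpr int bank_offset(int col_idx) {
  return (col_idx * kLoadBytes / 128) % kLoadsPerAccess;
}

static_assert(bank_offset(0) == 0, "first 128-byte span keeps pointer order 0,1");
static_assert(bank_offset(8) == 1, "next 128-byte span rotates to order 1,0");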
@@ -133,14 +133,14 @@ constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu> {
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, BiasScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, BiasScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
@@ -160,16 +160,14 @@ constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu> {
using Op =
cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, DefaultScaleMode>;
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, DefaultScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU> {
using Op =
cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, DefaultScaleMode>;
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, DefaultScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
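A usage sketch of the Epilogue selector above; the enclosing namespace, the half-precision element type, and the vector width are assumptions for illustration, not taken from the diff:

// Resolves, via the specialization above, to LinearCombinationSilu with
// NoBetaScaling (required header include elided).
using SiluEpilogueOp =
    ort_fastertransformer::Epilogue<cutlass::half_t,
                                    /*ElementsPerVectorAccess=*/8,
                                    /*ElementAccumulator=*/float,
                                    ort_fastertransformer::EpilogueOpBiasSilu>::Op;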
@@ -115,7 +115,8 @@ class GemmUniversalBaseCompat {
gemm_k_size = args.problem_size.k();

if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
int const kAlignK =
const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

@@ -200,26 +201,26 @@ class GemmUniversalBaseCompat {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");

int max_active_blocks = -1;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));

CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");

if (smem_size <= (48 << 10)) {
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
GemmKernel::kThreadCount, smem_size);

if (result == cudaSuccess) {
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
} else {
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
GemmKernel::kThreadCount, 0);

if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));

return -1;
}
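A sketch of the "occupancy limit based on SMEM" step that the diff truncates here: the zero-smem occupancy result is capped by how many copies of the kernel's shared storage fit on one SM. The helper name and the use of cudaGetDeviceProperties are assumptions, not the elided code:

#include <algorithm>
#include <cuda_runtime.h>

inline int cap_blocks_by_smem(int max_active_blocks, int smem_size, int device = 0) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) return -1;
  // How many blocks' worth of SharedStorage fit in one SM's shared memory.
  int smem_limited = smem_size > 0
                         ? static_cast<int>(prop.sharedMemPerMultiprocessor / smem_size)
                         : max_active_blocks;
  return std::min(max_active_blocks, smem_limited);
}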
@@ -292,10 +293,11 @@ class GemmUniversalBaseCompat {
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));

// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));

if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
cudaError_t result =
cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

if (result != cudaSuccess) {
return Status::kErrorInternal;
@@ -333,7 +335,7 @@ class GemmUniversalBaseCompat {
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);

int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));

//
// Launch kernel
@@ -358,9 +360,7 @@ class GemmUniversalBaseCompat {
}

/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
Status operator()(cudaStream_t stream = nullptr) { return run(stream); }

/// Runs the kernel using initialized state.
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
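For orientation, the wrapper above follows the usual CUTLASS device-GEMM call sequence. A sketch under the assumption that GemmUniversalBaseCompat keeps GemmUniversalBase's get_workspace_size/initialize/run interface (argument construction and the allocation helper are hypothetical):

// Query workspace, initialize params_ (which also sets the large-smem
// attribute, per the hunk above), then launch.
GemmUniversalBaseCompat<GemmKernel> gemm_op;
size_t workspace_bytes = GemmUniversalBaseCompat<GemmKernel>::get_workspace_size(args);
void* workspace = device_alloc(workspace_bytes);  // hypothetical helper

if (gemm_op.initialize(args, workspace, stream) == cutlass::Status::kSuccess) {
  cutlass::Status status = gemm_op(stream);  // operator() forwards to run(stream)
}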