Skip to content

Commit

Permalink
CUTLASS 3.0 Hopper GEMMs are GETTs in disguise (#897)
Browse files Browse the repository at this point in the history
  • Loading branch information
thakkarV authored and ttl10101 committed Feb 7, 2024
1 parent e954131 commit b2b17db
Show file tree
Hide file tree
Showing 10 changed files with 1,231 additions and 71 deletions.
371 changes: 371 additions & 0 deletions examples/51_hopper_gett/51_hopper_gett.cu

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions examples/51_hopper_gett/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Declare the example executable built from 51_hopper_gett.cu using the
# CUTLASS example helper (presumably also wires up include paths and any
# example-wide build settings — defined elsewhere in the CUTLASS build).
cutlass_example_add_executable(
51_hopper_gett
51_hopper_gett.cu
)
136 changes: 136 additions & 0 deletions examples/51_hopper_gett/gett_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once

#include "cute/tensor.hpp"

#include "cutlass/arch/arch.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"

namespace example {

//
// GETT entry point
//

/// Builds and launches a single Hopper GETT (general tensor-tensor contraction)
/// by composing a CUTLASS 3.x collective mainloop and epilogue. Each of the
/// M/N/K/L modes of the problem shape and strides may itself be a multi-mode
/// (hierarchical) cute::Shape/Stride, which is what turns a GEMM into a GETT.
///
/// @param problem_shape_mnkl  Rank-4 (M, N, K, L) problem shape; modes may be nested
/// @param ptr_A, stride_a_mkl Operand A pointer and its (M, K, L)-mode strides
/// @param ptr_B, stride_b_nkl Operand B pointer and its (N, K, L)-mode strides
/// @param _                   Unused value; only its type selects ElementAccumulator
/// @param ptr_C, stride_c_mnl Source C pointer and its (M, N, L)-mode strides
/// @param ptr_D, stride_d_mnl Destination D pointer and its (M, N, L)-mode strides
/// @param alpha, beta         Epilogue scalars (linear combination of acc and C)
/// @param stream              CUDA stream the kernel is enqueued on (default 0)
/// @return                    Status reported by the underlying device operator
template <
  class ProblemShapeMNKL,
  class ElementA,
  class StrideA,
  class ElementB,
  class StrideB,
  class ElementAccumulator,
  class ElementC,
  class StrideC,
  class ElementD,
  class StrideD,
  class ElementEpilogue>
cutlass::Status
gett_kernel(
    ProblemShapeMNKL problem_shape_mnkl,
    ElementA const* ptr_A, StrideA stride_a_mkl,
    ElementB const* ptr_B, StrideB stride_b_nkl,
    ElementAccumulator _,
    ElementC const* ptr_C, StrideC stride_c_mnl,
    ElementD * ptr_D, StrideD stride_d_mnl,
    ElementEpilogue alpha, ElementEpilogue beta,
    cudaStream_t stream = 0) {
  using namespace cute;

  // TileShape -- GETT configuration
  // Specify the number of elements to take from each mode
  // BLK_M = (M0,M1,...) BLK_N = (N0,N1,...) BLK_K = (K0,K1,...)

  // Take 128 from m0, 128 from n0, 64 from k0
  using TileShape = Shape<Shape<_128>, Shape<_128>, Shape<_64>>;

  /* Other examples:
   * Take 32 elements from m0 and 4 elements from m1
   * Take 64 elements from n0 and 2 elements from n1
   * Take 8 elements from k0 and 8 elements from k1
   **/
  // using TileShape = Shape<Shape<_32,_4>, Shape<_64,_2>, Shape<_8,_8>>;

  using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination<
    ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default,
    cutlass::FloatRoundStyle::round_to_nearest, ElementC>;

  // No changes are required to the default epilogue
  using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
    StrideC,
    StrideD,
    EpilogueThreadOp>;

  // CollectiveMma for GETTs can be built using the CollectiveBuilders
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, StrideA, 128 / cutlass::sizeof_bits<ElementA>::value,
    ElementB, StrideB, 128 / cutlass::sizeof_bits<ElementB>::value,
    ElementAccumulator,
    TileShape, Shape<_1,_2,_1>,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;

  // The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM
  using GettKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShapeMNKL,
    CollectiveMainloop,
    CollectiveEpilogue>;

  using GettOperator = cutlass::gemm::device::GemmUniversalAdapter<GettKernel>;

  typename GettOperator::Arguments args {
    cutlass::gemm::GemmUniversalMode::kBatched,
    problem_shape_mnkl,
    ptr_A, stride_a_mkl,
    ptr_B, stride_b_nkl,
    { ptr_C, stride_c_mnl, ptr_D, stride_d_mnl, {alpha, beta} }
  };

#if CUTLASS_DEBUG_TRACE_LEVEL > 0
  print("Problem shape:");
  print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n");
  print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n");
  print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n");
  print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n");
  // Fixed typo in debug label: was "TileSape:"
  print("TileShape:"); print(TileShape{}); print("\n");
#endif

  GettOperator op;
  return op(args, stream);
}

} // namespace example
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ foreach(EXAMPLE
48_hopper_warp_specialized_gemm
49_hopper_gemm_schedules_with_collective_builder
50_hopper_gemm_with_epilogue_swizzle
51_hopper_gett
)

add_subdirectory(${EXAMPLE})
Expand Down
12 changes: 11 additions & 1 deletion include/cute/atom/copy_traits_sm90_tma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,17 @@ make_tma_copy(CopyOp,
print("layout_tv : "); print(layout_tv); print("\n");
#endif

return TiledCopy<Copy_Atom<Traits,T>, decltype(layout_tv), decltype(cta_tile)>{tma_desc, gmem_stride_bases};
// If CTA_Tile and SLayout are incompatible, product_each makes sure
// that the TiledCopy generates consistent accesses.
auto cta_tile_tiled = [&]() {
if constexpr (compatible(shape(CTA_Tile{}), shape(SLayout{}))) {
return cta_tile;
} else {
return product_each(cta_tile);
}
}();

return TiledCopy<Copy_Atom<Traits,T>, decltype(layout_tv), decltype(cta_tile_tiled)>{tma_desc, gmem_stride_bases};
}

// Explicit defaulting
Expand Down
14 changes: 8 additions & 6 deletions include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ template <class ElementA, class LayoutA>
constexpr cute::GMMA::Major
tag_to_gmma_major_A() {
// MN major mode is only valid for non-TF32 and non-int MMAs
if constexpr (std::is_same_v<LayoutA, cutlass::layout::ColumnMajor> &&
if constexpr (cutlass::gemm::detail::is_mn_major_A<LayoutA>() &&
not std::is_same_v<ElementA, tfloat32_t> &&
not std::is_same_v<ElementA, int8_t> &&
not std::is_same_v<ElementA, uint8_t>) {
Expand All @@ -77,7 +77,7 @@ template <class ElementB, class LayoutB>
constexpr cute::GMMA::Major
tag_to_gmma_major_B() {
// MN major mode is only valid for non-TF32 and non-int MMAs
if constexpr (std::is_same_v<LayoutB, cutlass::layout::RowMajor> &&
if constexpr (cutlass::gemm::detail::is_mn_major_B<LayoutB>() &&
not std::is_same_v<ElementB, tfloat32_t> &&
not std::is_same_v<ElementB, int8_t> &&
not std::is_same_v<ElementB, uint8_t>) {
Expand Down Expand Up @@ -113,7 +113,7 @@ make_cp_async_gmem_tiled_copy() {

// Maximize the number of threads along the gmem major mode to promote coalesced reads
// While making sure our thread layout tiles the threadblock tile evenly
if constexpr (cute::size<1>(StrideType{}) == 1) {
if constexpr (cutlass::gemm::detail::is_k_major<StrideType>()) {
// K major thread layout for K major gmem
constexpr int threads_major = TileSizeK / Alignment;
constexpr int threads_minor = ThreadCount / threads_major;
Expand All @@ -126,7 +126,7 @@ make_cp_async_gmem_tiled_copy() {
Stride<Int<threads_major>, _1>>{},
Layout<Shape<_1,Int<Alignment>>>{});
}
else if constexpr (cute::size<0>(StrideType{}) == 1) {
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideType>()) {
// MN major thread layout for MN major gmem
constexpr int threads_major = TileSizeMN / Alignment;
constexpr int threads_minor = ThreadCount / threads_major;
Expand Down Expand Up @@ -257,7 +257,8 @@ struct CollectiveBuilder<
not std::is_same_v<KernelScheduleType, KernelMultistage> &&
// dispatch TN tf32 and int8 kernels only to TMA builder
((sizeof(ElementA) == 2 && sizeof(ElementB) == 2) ||
(std::is_same_v<GmemLayoutA, layout::RowMajor> && std::is_same_v<GmemLayoutB, layout::ColumnMajor>))>
(cutlass::gemm::detail::is_k_major_A<GmemLayoutA>() &&
cutlass::gemm::detail::is_k_major_B<GmemLayoutB>()))>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
Expand Down Expand Up @@ -346,7 +347,8 @@ struct CollectiveBuilder<
((sizeof(ElementB) * AlignmentB) % detail::tma_alignment_bytes != 0) ||
// dispatch non-TN tf32 and int8 kernels only to cp_async builder
((sizeof(ElementA) != 2 || sizeof(ElementB) != 2) &&
(not std::is_same_v<GmemLayoutA, layout::RowMajor> || not std::is_same_v<GmemLayoutB, layout::ColumnMajor>))>
(not cutlass::gemm::detail::is_k_major_A<GmemLayoutA>() ||
not cutlass::gemm::detail::is_k_major_B<GmemLayoutB>()))>
> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
Expand Down
63 changes: 51 additions & 12 deletions include/cutlass/gemm/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
#include "cutlass/coord.h"
#include "cutlass/layout/matrix.h"
#include "cute/layout.hpp"
#include "cute/arch/copy_sm90.hpp"

#include "cute/arch/copy_sm90_tma.hpp"
namespace cutlass {
namespace gemm {

Expand Down Expand Up @@ -426,7 +425,9 @@ enum class SharedMemoryClearOption {
// For each cutlass::layout, provides its corresponding cute stride types, 64b by default

template <class L>
struct TagToStrideA {};
struct TagToStrideA {
using type = L;
};

// Maps to modes [M, K, L]
template <>
Expand All @@ -443,7 +444,9 @@ struct TagToStrideA<layout::ColumnMajor> {
};

template <class L>
struct TagToStrideB {};
struct TagToStrideB {
using type = L;
};

// Maps to modes [N, K, L]
template <>
Expand Down Expand Up @@ -479,13 +482,19 @@ using TagToStrideC_t = typename TagToStrideC<LayoutTag>::type;

namespace detail {

template<class Stride>
constexpr bool
is_mn_major() {
// Account for stride types with and without batch mode and batch modes with static zero stride
return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value;
}

// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices
template<class StrideAC>
constexpr
auto
stride_to_layout_tag_A() {
// Account for stride types with and without batch mode and batch modes with static zero stride
if constexpr (cute::size<0>(StrideAC{}) == 1) { // M major
if constexpr (is_mn_major<StrideAC>()) { // M major
return layout::ColumnMajor{};
}
else { // K major
Expand All @@ -499,8 +508,7 @@ template<class StrideB>
constexpr
auto
stride_to_layout_tag_B() {
// Account for stride types with and without batch mode and batch modes with static zero stride
if constexpr (cute::size<0>(StrideB{}) == 1) { // N major
if constexpr (is_mn_major<StrideB>()) { // N major
return layout::RowMajor{};
}
else { // K major
Expand All @@ -515,12 +523,12 @@ template <class GmemTiledCopy, class Element>
constexpr int
get_alignment_count_from_gmem_tiled_copy() {
// For TMA tiled copies, we know the alignment has to be 128 bits
if constexpr (std::is_base_of_v<cute::SM90_TMA_LOAD, GmemTiledCopy> ||
std::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST, GmemTiledCopy>) {
if constexpr ( std::is_base_of_v<cute::SM90_TMA_LOAD, GmemTiledCopy>
|| std::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST, GmemTiledCopy>
) {
return 128 / sizeof_bits<Element>::value;
}
else
{
else {
// For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN
return GmemTiledCopy::NumValSrc;
}
Expand Down Expand Up @@ -551,6 +559,37 @@ using StrideToLayoutTagB_t = typename StrideToLayoutTagB<S>::type;
template<class S>
using StrideToLayoutTagC_t = typename StrideToLayoutTagC<S>::type;

// Complement of is_mn_major: a stride that is not MN-major is K-major.
template<class Stride>
constexpr bool is_k_major() {
  return not is_mn_major<Stride>();
}

template<class LayoutA>
constexpr bool
is_mn_major_A() {
return is_mn_major<TagToStrideA_t<LayoutA>>();
}

template<class LayoutB>
constexpr bool
is_mn_major_B() {
return is_mn_major<TagToStrideB_t<LayoutB>>();
}

template<class LayoutA>
constexpr bool
is_k_major_A() {
return is_k_major<TagToStrideA_t<LayoutA>>();
}

template<class LayoutB>
constexpr bool
is_k_major_B() {
return is_k_major<TagToStrideB_t<LayoutB>>();
}

///////////////////////////////////////////////////////////////////////////////

// The following two metafunctions are used to detect whether a `kernel::Gemm` or `kernel::GemmUniversal`
Expand Down
Loading

0 comments on commit b2b17db

Please sign in to comment.