bug fixes and enhancement to gemm reduction-K fusion (#682)
* add two missing files

* fix a bunch of bugs in the gemm-reducek fusion and add a device interface

* small changes

Co-authored-by: Haicheng Wu <[email protected]>
2 people authored and ttl10101 committed Feb 7, 2024
1 parent c9d9a7b commit 2633cf8
Showing 8 changed files with 445 additions and 21 deletions.
@@ -45,7 +45,7 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/device/gemm_with_k_reduction.h"
#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
@@ -101,6 +101,12 @@ constexpr int NumStages = 4;
// Reduce A or B operand along the K dimension
constexpr bool ReduceKForA = true;

// Alignment of A operand
constexpr int AlignmentA = 8;

// Alignment of B operand
constexpr int AlignmentB = 8;

// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // Data type of output matrix.
@@ -110,9 +116,9 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue>;

using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8,
ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8,
using Gemm = typename cutlass::gemm::device::GemmWithKReduction<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
@@ -124,10 +130,12 @@ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
AlignmentA,
AlignmentB,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;

// Below is the reduction kernel used in the case of parallel split-k
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;
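
For context, a device-level operator like the Gemm alias above is normally driven through the standard CUTLASS lifecycle of can_implement / get_workspace_size / initialize / run. The sketch below is not part of this commit; it assumes the Gemm alias defined above, a pre-built Gemm::Arguments object supplied by the caller, and the usual GemmUniversal-style device interface. The helper name run_gemm_with_k_reduction is purely illustrative.

#include <cstdint>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// Illustrative helper (not from this commit): launch the device-level op.
cutlass::Status run_gemm_with_k_reduction(Gemm::Arguments const &args) {
  Gemm gemm_op;

  // Check that this tile/alignment configuration can handle the problem.
  cutlass::Status status = gemm_op.can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Allocate whatever workspace the kernel requests (e.g. for split-K).
  size_t workspace_bytes = Gemm::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);

  // Bind arguments and workspace, then launch on the default stream.
  status = gemm_op.initialize(args, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  return gemm_op();
}
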
@@ -368,21 +376,21 @@ Result profile(Options const &options) {
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
1997,
ElementInputA(2),
ElementInputA(-2),
0); // <- Fill tensor A on host with uniform-distribution random data

cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
2003,
ElementInputB(2),
ElementInputB(-2),
0); // <- Fill tensor B on host with uniform-distribution random data

cutlass::reference::host::TensorFillRandomUniform(
tensor_c.host_view(),
1,
2017,
ElementOutput(2),
ElementOutput(-2),
0); // <- Fill matrix C on host with uniform-distribution random data
@@ -561,7 +569,7 @@ Result profile(Options const &options) {

tensor_reduction.sync_host();

// ReduceK in host code
// Reduce K in host code
if (ReduceKForA) {
for (int m = 0; m < options.problem_size.m(); ++m) {
for (int k = 0; k < options.problem_size.k(); ++k) {
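
The diff view cuts off the host-side reference check that starts above. As an illustration only, reusing the example's names (this is a sketch of the idea, not the commit's exact loop body): the reference reduction sums operand A over K per row when ReduceKForA is true, and operand B over K per column otherwise.

// Illustrative sketch of the host reference reduction (not the exact commit code).
if (ReduceKForA) {
  // One reduction value per row of A: sum over the K dimension.
  for (int m = 0; m < options.problem_size.m(); ++m) {
    ElementAccumulator sum = ElementAccumulator(0);
    for (int k = 0; k < options.problem_size.k(); ++k) {
      sum += ElementAccumulator(tensor_a.at({m, k}));
    }
    tensor_ref_reduction.at({m, 0}) = ElementOutput(sum);
  }
} else {
  // One reduction value per column of B: sum over the K dimension.
  for (int n = 0; n < options.problem_size.n(); ++n) {
    ElementAccumulator sum = ElementAccumulator(0);
    for (int k = 0; k < options.problem_size.k(); ++k) {
      sum += ElementAccumulator(tensor_b.at({k, n}));
    }
    tensor_ref_reduction.at({n, 0}) = ElementOutput(sum);
  }
}
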
@@ -581,7 +589,7 @@ Result profile(Options const &options) {
// Check if output from CUTLASS kernel and reference kernel are equal or not
bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(),
tensor_ref_d.host_view());

pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(),
tensor_reduction.host_view());

@@ -149,13 +149,12 @@ class EpilogueGemmKReduction {

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations / 4; ++i) {
ElementOutput tmp;
ElementOutput *source_ptr = reinterpret_cast<ElementOutput *>(&source);
cutlass::arch::global_load<ElementOutput, sizeof(ElementOutput)>(
tmp,
source_ptr[i],
(void *)(pointer_ + i * 32),
guard[i] && LoadForSerialSplitK);

source[i] = tmp;
}

FragmentAccumulator sum = gemm_k_with_reduction_accumulation;
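
The updated loop above stages each guarded load through a local temporary (tmp) and then writes it into the fragment as source[i], instead of loading through a pointer obtained by reinterpret_cast-ing the fragment's storage. Below is a minimal standalone sketch of the same predicated-load pattern; the helper name, fragment, and pointer parameters are illustrative and are not members of this epilogue.

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/arch/memory.h"

// Illustrative device-side helper: guarded element-wise load of a fragment
// from global memory, staging each element through a register temporary.
template <typename Element, int kElements>
CUTLASS_DEVICE
void load_fragment_guarded(
    cutlass::Array<Element, kElements> &frag,  // destination fragment
    Element const *global_ptr,                 // base pointer in global memory
    int stride,                                // element stride between loads
    bool const (&guard)[kElements]) {          // per-element predicates

  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < kElements; ++i) {
    Element tmp = Element();
    cutlass::arch::global_load<Element, sizeof(Element)>(
        tmp,
        (void const *)(global_ptr + i * stride),
        guard[i]);
    frag[i] = tmp;
  }
}
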