Skip to content

Commit

Permalink
Batched GEMM Multiple D based on Universal GEMM (#1655)
Browse files Browse the repository at this point in the history
* Batched GEMM Multiple D based on Universal GEMM

Co-authored-by: Jing Zhang <[email protected]>

* CI fixes

Co-authored-by: Jing Zhang <[email protected]>

---------

Co-authored-by: Jing Zhang <[email protected]>
  • Loading branch information
bartekxk and Jing Zhang authored Nov 18, 2024
1 parent efb3474 commit 754adc7
Show file tree
Hide file tree
Showing 21 changed files with 2,655 additions and 11 deletions.
6 changes: 6 additions & 0 deletions example/24_batched_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp)
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)

add_example_executable(example_batched_gemm_xdl_bf16_v3 batched_gemm_xdl_bf16_v3.cpp)
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16_v3)

add_example_executable(example_batched_gemm_xdl_fp8_rowwise_v3 batched_gemm_xdl_fp8_rowwise_v3.cpp)
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp8_rowwise_v3)

add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)

Expand Down
99 changes: 99 additions & 0 deletions example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/literals.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using BF16 = ck::bhalf_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = BF16;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;

using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmDefault,
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S<8>, // CDEShuffleBlockTransferScalarPerVectors
ck::BlockGemmPipelineScheduler::Intrawave, // BlockGemmPipelineScheduler
ck::BlockGemmPipelineVersion::v3 // BlockGemmPipelineVersion
>;

#include "run_batched_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
106 changes: 106 additions & 0 deletions example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/literals.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F8 = ck::f8_t;
using BF16 = ck::bhalf_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;

using ADataType = F8;
using BDataType = F8;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = BF16;

using ALayout = Row;
using BLayout = Col;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmDefault,
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S<8, 8, 1>, // CDEShuffleBlockTransferScalarPerVectors
ck::BlockGemmPipelineScheduler::Interwave, // BlockGemmPipelineScheduler
ck::BlockGemmPipelineVersion::v1, // BlockGemmPipelineVersion
F8 // ComputeTypeA
>;

#include "run_batched_gemm_example_rowwise.inc"

int main(int argc, char* argv[]) { return !run_batched_gemm_rowwise_example(argc, argv); }
36 changes: 26 additions & 10 deletions example/24_batched_gemm/run_batched_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -210,31 +210,47 @@ bool run_batched_gemm_example(int argc, char* argv[])

problem_size.M = 256 * (dis(gen) + 1);
problem_size.N = 128 * (dis(gen) + 1);
problem_size.K = 64 * (dis(gen) + 2);
problem_size.K = 128 * (dis(gen) + 2);

problem_size.stride_A = problem_size.K;
problem_size.stride_B = problem_size.K;
problem_size.stride_C = problem_size.N;

problem_size.batch_stride_A = problem_size.M * problem_size.K;
problem_size.batch_stride_B = problem_size.K * problem_size.N;
problem_size.batch_stride_C = problem_size.M * problem_size.N;

problem_size.batch_count = 16;
problem_size.batch_count = 2;

if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 8)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.batch_count = std::stoi(argv[7]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("optinal\n");
printf("arg4-7: M = %d N = %d K = %d Batch = %d\n",
problem_size.M,
problem_size.N,
problem_size.K,
problem_size.batch_count);
exit(0);
}

problem_size.stride_A = problem_size.K;
problem_size.stride_B = problem_size.K;
problem_size.stride_C = problem_size.N;

problem_size.batch_stride_A = problem_size.M * problem_size.K;
problem_size.batch_stride_B = problem_size.K * problem_size.N;
problem_size.batch_stride_C = problem_size.M * problem_size.N;

return run_batched_gemm(problem_size, config);
}
Loading

0 comments on commit 754adc7

Please sign in to comment.