Add scale C to the sample code #1105

Open · wants to merge 1 commit into base: develop
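This change wires a third scale factor, scale C, through both sample paths: the plain hipBLASLt API path sets HIPBLASLT_MATMUL_DESC_C_SCALE_POINTER on the matmul descriptor, and the extension API path calls inputs.setScaleC(). Beta also moves from 0.f to 1.f so the scaled C term actually contributes to the result. The host-side reference below is a minimal sketch of what the updated samples are expected to compute, assuming the scale factors act as per-matrix dequantization multipliers; the function name and the exact scaling semantics are illustrative assumptions, not taken from the diff or the hipBLASLt documentation.

// Hypothetical host-side reference, column-major, for illustration only.
// Assumes hipBLASLt applies the A/B/C scale pointers as
//   D = alpha * (scale_a * A) * (scale_b * B) + beta * (scale_c * C).
#include <cstdint>
#include <vector>

void referenceScaledGemm(int64_t m, int64_t n, int64_t k,
                         float alpha, float beta,
                         float scale_a, float scale_b, float scale_c,
                         const std::vector<float>& A,  // m x k
                         const std::vector<float>& B,  // k x n
                         const std::vector<float>& C,  // m x n
                         std::vector<float>&       D)  // m x n
{
    for(int64_t j = 0; j < n; ++j)
        for(int64_t i = 0; i < m; ++i)
        {
            float acc = 0.f;
            // accumulate the scaled A*B product for element (i, j)
            for(int64_t l = 0; l < k; ++l)
                acc += (scale_a * A[i + l * m]) * (scale_b * B[l + j * k]);
            // add the scaled C term; with beta = 0.f this term would vanish,
            // which is why the sample switches beta to 1.f
            D[i + j * m] = alpha * acc + beta * (scale_c * C[i + j * m]);
        }
}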
@@ -32,7 +32,7 @@
#include <iostream>
#include <vector>

-void simpleGemmScaleAB(hipblasLtHandle_t handle,
+void simpleGemmScaleABC(hipblasLtHandle_t handle,
hipblasOperation_t trans_a,
hipblasOperation_t trans_b,
int64_t m,
@@ -49,18 +49,20 @@ void simpleGemmScaleAB(hipblasLtHandle_t handle,
int64_t max_workspace_size,
hipStream_t stream,
float h_scale_a,
-float h_scale_b);
+float h_scale_b,
+float h_scale_c);

int main()
{
Runner<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, hipblasLtHalf, float, float> runner(
-128, 128, 128, 1, 1.f, 0.f, 32 * 1024 * 1024);
+128, 128, 128, 1, 1.f, 1.f, 32 * 1024 * 1024);

float scale_a = 0.5f; // scale A setting
float scale_b = 2.0f; // scale B setting
std::cout << "Running with Scale A = " << scale_a << " and Scale B = " << scale_b << std::endl;
runner.run([&runner, scale_a, scale_b] {
simpleGemmScaleAB(runner.handle,
float scale_c = 2.0f; // scale C setting
std::cout << "Running with Scale A = " << scale_a << ", Scale B = " << scale_b << ", and Scale C = " << scale_c << std::endl;
runner.run([&runner, scale_a, scale_b, scale_c] {
simpleGemmScaleABC(runner.handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
runner.m,
@@ -77,13 +79,14 @@ int main()
runner.max_workspace_size,
runner.stream,
scale_a,
-scale_b);
+scale_b,
+scale_c);
});

return 0;
}

-void simpleGemmScaleAB(hipblasLtHandle_t handle,
+void simpleGemmScaleABC(hipblasLtHandle_t handle,
hipblasOperation_t trans_a,
hipblasOperation_t trans_b,
int64_t m,
@@ -100,22 +103,27 @@ void simpleGemmScaleAB(hipblasLtHandle_t handle,
int64_t max_workspace_size,
hipStream_t stream,
float h_scale_a,
-float h_scale_b)
+float h_scale_b,
+float h_scale_c)
{
float* d_scale_a;
float* d_scale_b;
+float* d_scale_c;
CHECK_HIP_ERROR(hipMalloc(&d_scale_a, sizeof(float)));
CHECK_HIP_ERROR(hipMalloc(&d_scale_b, sizeof(float)));
+CHECK_HIP_ERROR(hipMalloc(&d_scale_c, sizeof(float)));
CHECK_HIP_ERROR(
hipMemcpyAsync(d_scale_a, &h_scale_a, sizeof(float), hipMemcpyHostToDevice, stream));
CHECK_HIP_ERROR(
hipMemcpyAsync(d_scale_b, &h_scale_b, sizeof(float), hipMemcpyHostToDevice, stream));
+CHECK_HIP_ERROR(
+hipMemcpyAsync(d_scale_c, &h_scale_c, sizeof(float), hipMemcpyHostToDevice, stream));

hipblasLtMatrixLayout_t matA, matB, matC, matD;
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIP_R_8F_E4M3_FNUZ, m, k, m));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIP_R_8F_E4M3_FNUZ, k, n, k));
-CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIP_R_16F, m, n, m));
-CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIP_R_16F, m, n, m));
+CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIP_R_8F_E4M3_FNUZ, m, n, m));
+CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIP_R_8F_E4M3_FNUZ, m, n, m));

hipblasLtMatmulDesc_t matmul;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
@@ -124,11 +132,13 @@ void simpleGemmScaleAB(hipblasLtHandle_t handle,
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t)));

-// Set A and B matrix scale factors
+// Set A, B, and C matrix scale factors
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &d_scale_a, sizeof(float*)));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &d_scale_b, sizeof(float*)));
+CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
+matmul, HIPBLASLT_MATMUL_DESC_C_SCALE_POINTER, &d_scale_c, sizeof(float*)));

hipblasLtMatmulPreference_t pref;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref));

@@ -30,7 +30,7 @@

#include "helper.h"

-void simpleGemmScaleABExt(hipblasLtHandle_t handle,
+void simpleGemmScaleABCExt(hipblasLtHandle_t handle,
hipblasOperation_t trans_a,
hipblasOperation_t trans_b,
int64_t m,
@@ -47,20 +47,22 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle,
int64_t max_workspace_size,
hipStream_t stream,
float h_scale_a,
-float h_scale_b);
+float h_scale_b,
+float h_scale_c);

int main()
{
-// This is an example using hipblaslt extension API: ScaleA & ScaleB
+// This is an example using hipblaslt extension API: ScaleA & ScaleB & ScaleC
Runner<hipblaslt_f8_fnuz, hipblaslt_f8_fnuz, hipblasLtHalf, float, float> runner(
128, 128, 128, 1, 1.f, 1.f, 32 * 128 * 128);

float scale_a = 0.5f; // scale A setting
float scale_b = 2.0f; // scale B setting
std::cout << "Running with Scale A = " << scale_a << " and Scale B = " << scale_b << std::endl;
float scale_c = 2.0f; // scale C setting
std::cout << "Running with Scale A = " << scale_a << ", Scale B = " << scale_b << ", and Scale C = " << scale_c << std::endl;

-runner.run([&runner, scale_a, scale_b] {
-simpleGemmScaleABExt(runner.handle,
+runner.run([&runner, scale_a, scale_b, scale_c] {
+simpleGemmScaleABCExt(runner.handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
runner.m,
@@ -77,13 +79,14 @@ int main()
runner.max_workspace_size,
runner.stream,
scale_a,
-scale_b);
+scale_b,
+scale_c);
});

return 0;
}

-void simpleGemmScaleABExt(hipblasLtHandle_t handle,
+void simpleGemmScaleABCExt(hipblasLtHandle_t handle,
hipblasOperation_t trans_a,
hipblasOperation_t trans_b,
int64_t m,
@@ -100,7 +103,8 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle,
int64_t max_workspace_size,
hipStream_t stream,
float h_scale_a,
-float h_scale_b)
+float h_scale_b,
+float h_scale_c)
{
hipblaslt_ext::GemmPreferenceV2 gemmPref;
gemmPref.setMaxWorkspaceBytes(max_workspace_size);
@@ -109,21 +113,25 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle,
trans_b,
HIP_R_8F_E4M3_FNUZ,
HIP_R_8F_E4M3_FNUZ,
-HIP_R_16F,
-HIP_R_16F,
+HIP_R_8F_E4M3_FNUZ,
+HIP_R_8F_E4M3_FNUZ,
HIPBLAS_COMPUTE_32F);

hipblaslt_ext::GemmEpilogueV2
epilogue; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
hipblaslt_ext::GemmInputsV2 inputs;
float* d_scale_a;
float* d_scale_b;
+float* d_scale_c;
CHECK_HIP_ERROR(hipMalloc(&d_scale_a, sizeof(float)));
CHECK_HIP_ERROR(hipMalloc(&d_scale_b, sizeof(float)));
+CHECK_HIP_ERROR(hipMalloc(&d_scale_c, sizeof(float)));
CHECK_HIP_ERROR(
hipMemcpyAsync(d_scale_a, &h_scale_a, sizeof(float), hipMemcpyHostToDevice, stream));
CHECK_HIP_ERROR(
hipMemcpyAsync(d_scale_b, &h_scale_b, sizeof(float), hipMemcpyHostToDevice, stream));
+CHECK_HIP_ERROR(
+hipMemcpyAsync(d_scale_c, &h_scale_c, sizeof(float), hipMemcpyHostToDevice, stream));

inputs.setA(d_a);
inputs.setB(d_b);
Expand All @@ -133,6 +141,7 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle,
inputs.setBeta(&beta);
inputs.setScaleA(d_scale_a);
inputs.setScaleB(d_scale_b);
+inputs.setScaleC(d_scale_c);
gemm.setProblem(m, n, k, batch_count, epilogue, inputs);

const int request_solutions = 1;

8 changes: 4 additions & 4 deletions clients/samples/CMakeLists.txt
@@ -45,8 +45,8 @@ add_executable( sample_hipblaslt_gemm_dgelu_bgrad 12_gemm_dgelu_bgrad/sample_hip
add_executable( sample_hipblaslt_gemm_dgelu_bgrad_ext 12_gemm_dgelu_bgrad_ext/sample_hipblaslt_gemm_dgelu_bgrad_ext.cpp)
add_executable( sample_hipblaslt_gemm_is_tuned_ext 13_is_tuned_gemm_ext/sample_hipblaslt_gemm_is_tuned_ext.cpp)
add_executable( sample_hipblaslt_gemm_tuning_wgm_ext 14_tuning_wgm_gemm_ext/sample_hipblaslt_gemm_tuning_wgm_ext.cpp)
-add_executable( sample_hipblaslt_gemm_with_scale_a_b 15_gemm_scale_a_b/sample_hipblaslt_gemm_with_scale_a_b.cpp)
-add_executable( sample_hipblaslt_gemm_with_scale_a_b_ext 15_gemm_scale_a_b_ext/sample_hipblaslt_gemm_with_scale_a_b_ext.cpp)
+add_executable( sample_hipblaslt_gemm_with_scale_a_b_c 15_gemm_scale_a_b_c/sample_hipblaslt_gemm_with_scale_a_b_c.cpp)
+add_executable( sample_hipblaslt_gemm_with_scale_a_b_c_ext 15_gemm_scale_a_b_c_ext/sample_hipblaslt_gemm_with_scale_a_b_c_ext.cpp)
add_executable( sample_hipblaslt_groupedgemm_ext 16_gemm_grouped_ext/sample_hipblaslt_groupedgemm_ext.cpp)
add_executable( sample_hipblaslt_groupedgemm_fixed_mk_ext 17_fixed_mk_gemm_grouped_ext/sample_hipblaslt_groupedgemm_fixed_mk_ext.cpp)
add_executable( sample_hipblaslt_groupedgemm_get_all_algos_ext 18_get_all_algos_gemm_grouped_ext/sample_hipblaslt_groupedgemm_get_all_algos_ext.cpp)
@@ -83,8 +83,8 @@ set(samples sample_hipblaslt_gemm
sample_hipblaslt_gemm_dgelu_bgrad_ext
sample_hipblaslt_gemm_is_tuned_ext
sample_hipblaslt_gemm_tuning_wgm_ext
-sample_hipblaslt_gemm_with_scale_a_b
-sample_hipblaslt_gemm_with_scale_a_b_ext
+sample_hipblaslt_gemm_with_scale_a_b_c
+sample_hipblaslt_gemm_with_scale_a_b_c_ext
sample_hipblaslt_groupedgemm_ext
sample_hipblaslt_groupedgemm_fixed_mk_ext
sample_hipblaslt_groupedgemm_get_all_algos_ext