diff --git a/clients/samples/15_gemm_scale_a_b/sample_hipblaslt_gemm_with_scale_a_b.cpp b/clients/samples/15_gemm_scale_a_b_c/sample_hipblaslt_gemm_with_scale_a_b_c.cpp similarity index 87% rename from clients/samples/15_gemm_scale_a_b/sample_hipblaslt_gemm_with_scale_a_b.cpp rename to clients/samples/15_gemm_scale_a_b_c/sample_hipblaslt_gemm_with_scale_a_b_c.cpp index 6b6d92a639..e557797668 100644 --- a/clients/samples/15_gemm_scale_a_b/sample_hipblaslt_gemm_with_scale_a_b.cpp +++ b/clients/samples/15_gemm_scale_a_b_c/sample_hipblaslt_gemm_with_scale_a_b_c.cpp @@ -32,7 +32,7 @@ #include #include -void simpleGemmScaleAB(hipblasLtHandle_t handle, +void simpleGemmScaleABC(hipblasLtHandle_t handle, hipblasOperation_t trans_a, hipblasOperation_t trans_b, int64_t m, @@ -49,18 +49,20 @@ void simpleGemmScaleAB(hipblasLtHandle_t handle, int64_t max_workspace_size, hipStream_t stream, float h_scale_a, - float h_scale_b); + float h_scale_b, + float h_scale_c); int main() { Runner runner( - 128, 128, 128, 1, 1.f, 0.f, 32 * 1024 * 1024); + 128, 128, 128, 1, 1.f, 1.f, 32 * 1024 * 1024); float scale_a = 0.5f; // scale A setting float scale_b = 2.0f; // scale B setting - std::cout << "Running with Scale A = " << scale_a << " and Scale B = " << scale_b << std::endl; - runner.run([&runner, scale_a, scale_b] { - simpleGemmScaleAB(runner.handle, + float scale_c = 2.0f; // scale C setting + std::cout << "Running with Scale A = " << scale_a << ", Scale B = " << scale_b << ", and Scale C = " << scale_c << std::endl; + runner.run([&runner, scale_a, scale_b, scale_c] { + simpleGemmScaleABC(runner.handle, HIPBLAS_OP_N, HIPBLAS_OP_N, runner.m, @@ -77,13 +79,14 @@ int main() runner.max_workspace_size, runner.stream, scale_a, - scale_b); + scale_b, + scale_c); }); return 0; } -void simpleGemmScaleAB(hipblasLtHandle_t handle, +void simpleGemmScaleABC(hipblasLtHandle_t handle, hipblasOperation_t trans_a, hipblasOperation_t trans_b, int64_t m, @@ -100,22 +103,27 @@ void simpleGemmScaleAB(hipblasLtHandle_t handle, int64_t max_workspace_size, hipStream_t stream, float h_scale_a, - float h_scale_b) + float h_scale_b, + float h_scale_c) { float* d_scale_a; float* d_scale_b; + float* d_scale_c; CHECK_HIP_ERROR(hipMalloc(&d_scale_a, sizeof(float))); CHECK_HIP_ERROR(hipMalloc(&d_scale_b, sizeof(float))); + CHECK_HIP_ERROR(hipMalloc(&d_scale_c, sizeof(float))); CHECK_HIP_ERROR( hipMemcpyAsync(d_scale_a, &h_scale_a, sizeof(float), hipMemcpyHostToDevice, stream)); CHECK_HIP_ERROR( hipMemcpyAsync(d_scale_b, &h_scale_b, sizeof(float), hipMemcpyHostToDevice, stream)); + CHECK_HIP_ERROR( + hipMemcpyAsync(d_scale_c, &h_scale_c, sizeof(float), hipMemcpyHostToDevice, stream)); hipblasLtMatrixLayout_t matA, matB, matC, matD; CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIP_R_8F_E4M3_FNUZ, m, k, m)); CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIP_R_8F_E4M3_FNUZ, k, n, k)); - CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIP_R_16F, m, n, m)); - CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIP_R_16F, m, n, m)); + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIP_R_8F_E4M3_FNUZ, m, n, m)); + CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIP_R_8F_E4M3_FNUZ, m, n, m)); hipblasLtMatmulDesc_t matmul; CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); @@ -124,11 +132,13 @@ void simpleGemmScaleAB(hipblasLtHandle_t handle, CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t))); - // Set A and B matrix scale factors + //Set A, B, and C matrix scale factors CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &d_scale_a, sizeof(float*))); CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &d_scale_b, sizeof(float*))); + CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_C_SCALE_POINTER, &d_scale_c, sizeof(float*))); hipblasLtMatmulPreference_t pref; CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref)); diff --git a/clients/samples/15_gemm_scale_a_b_ext/sample_hipblaslt_gemm_with_scale_a_b_ext.cpp b/clients/samples/15_gemm_scale_a_b_c_ext/sample_hipblaslt_gemm_with_scale_a_b_c_ext.cpp similarity index 85% rename from clients/samples/15_gemm_scale_a_b_ext/sample_hipblaslt_gemm_with_scale_a_b_ext.cpp rename to clients/samples/15_gemm_scale_a_b_c_ext/sample_hipblaslt_gemm_with_scale_a_b_c_ext.cpp index e9c097fef7..8800fe7bb9 100644 --- a/clients/samples/15_gemm_scale_a_b_ext/sample_hipblaslt_gemm_with_scale_a_b_ext.cpp +++ b/clients/samples/15_gemm_scale_a_b_c_ext/sample_hipblaslt_gemm_with_scale_a_b_c_ext.cpp @@ -30,7 +30,7 @@ #include "helper.h" -void simpleGemmScaleABExt(hipblasLtHandle_t handle, +void simpleGemmScaleABCExt(hipblasLtHandle_t handle, hipblasOperation_t trans_a, hipblasOperation_t trans_b, int64_t m, @@ -47,20 +47,22 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle, int64_t max_workspace_size, hipStream_t stream, float h_scale_a, - float h_scale_b); + float h_scale_b, + float h_scale_c); int main() { - // This is an example using hipblaslt extension API: ScaleA & ScaleB + // This is an example using hipblaslt extension API: ScaleA & ScaleB & ScaleC Runner runner( 128, 128, 128, 1, 1.f, 1.f, 32 * 128 * 128); float scale_a = 0.5f; // scale A setting float scale_b = 2.0f; // scale B setting - std::cout << "Running with Scale A = " << scale_a << " and Scale B = " << scale_b << std::endl; + float scale_c = 2.0f; // scale C setting + std::cout << "Running with Scale A = " << scale_a << ", Scale B = " << scale_b << ", and Scale C = " << scale_c << std::endl; - runner.run([&runner, scale_a, scale_b] { - simpleGemmScaleABExt(runner.handle, + runner.run([&runner, scale_a, scale_b, scale_c] { + simpleGemmScaleABCExt(runner.handle, HIPBLAS_OP_N, HIPBLAS_OP_N, runner.m, @@ -77,13 +79,14 @@ int main() runner.max_workspace_size, runner.stream, scale_a, - scale_b); + scale_b, + scale_c); }); return 0; } -void simpleGemmScaleABExt(hipblasLtHandle_t handle, +void simpleGemmScaleABCExt(hipblasLtHandle_t handle, hipblasOperation_t trans_a, hipblasOperation_t trans_b, int64_t m, @@ -100,7 +103,8 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle, int64_t max_workspace_size, hipStream_t stream, float h_scale_a, - float h_scale_b) + float h_scale_b, + float h_scale_c) { hipblaslt_ext::GemmPreferenceV2 gemmPref; gemmPref.setMaxWorkspaceBytes(max_workspace_size); @@ -109,8 +113,8 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle, trans_b, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, - HIP_R_16F, - HIP_R_16F, + HIP_R_8F_E4M3_FNUZ, + HIP_R_8F_E4M3_FNUZ, HIPBLAS_COMPUTE_32F); hipblaslt_ext::GemmEpilogueV2 @@ -118,12 +122,16 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle, hipblaslt_ext::GemmInputsV2 inputs; float* d_scale_a; float* d_scale_b; + float* d_scale_c; CHECK_HIP_ERROR(hipMalloc(&d_scale_a, sizeof(float))); CHECK_HIP_ERROR(hipMalloc(&d_scale_b, sizeof(float))); + CHECK_HIP_ERROR(hipMalloc(&d_scale_c, sizeof(float))); CHECK_HIP_ERROR( hipMemcpyAsync(d_scale_a, &h_scale_a, sizeof(float), hipMemcpyHostToDevice, stream)); CHECK_HIP_ERROR( hipMemcpyAsync(d_scale_b, &h_scale_b, sizeof(float), hipMemcpyHostToDevice, stream)); + CHECK_HIP_ERROR( + hipMemcpyAsync(d_scale_c, &h_scale_c, sizeof(float), hipMemcpyHostToDevice, stream)); inputs.setA(d_a); inputs.setB(d_b); @@ -133,6 +141,7 @@ void simpleGemmScaleABExt(hipblasLtHandle_t handle, inputs.setBeta(&beta); inputs.setScaleA(d_scale_a); inputs.setScaleB(d_scale_b); + inputs.setScaleC(d_scale_c); gemm.setProblem(m, n, k, batch_count, epilogue, inputs); const int request_solutions = 1; diff --git a/clients/samples/CMakeLists.txt b/clients/samples/CMakeLists.txt index 6cb2a7d162..67318d0a9d 100644 --- a/clients/samples/CMakeLists.txt +++ b/clients/samples/CMakeLists.txt @@ -45,8 +45,8 @@ add_executable( sample_hipblaslt_gemm_dgelu_bgrad 12_gemm_dgelu_bgrad/sample_hip add_executable( sample_hipblaslt_gemm_dgelu_bgrad_ext 12_gemm_dgelu_bgrad_ext/sample_hipblaslt_gemm_dgelu_bgrad_ext.cpp) add_executable( sample_hipblaslt_gemm_is_tuned_ext 13_is_tuned_gemm_ext/sample_hipblaslt_gemm_is_tuned_ext.cpp) add_executable( sample_hipblaslt_gemm_tuning_wgm_ext 14_tuning_wgm_gemm_ext/sample_hipblaslt_gemm_tuning_wgm_ext.cpp) -add_executable( sample_hipblaslt_gemm_with_scale_a_b 15_gemm_scale_a_b/sample_hipblaslt_gemm_with_scale_a_b.cpp) -add_executable( sample_hipblaslt_gemm_with_scale_a_b_ext 15_gemm_scale_a_b_ext/sample_hipblaslt_gemm_with_scale_a_b_ext.cpp) +add_executable( sample_hipblaslt_gemm_with_scale_a_b_c 15_gemm_scale_a_b_c/sample_hipblaslt_gemm_with_scale_a_b_c.cpp) +add_executable( sample_hipblaslt_gemm_with_scale_a_b_c_ext 15_gemm_scale_a_b_c_ext/sample_hipblaslt_gemm_with_scale_a_b_c_ext.cpp) add_executable( sample_hipblaslt_groupedgemm_ext 16_gemm_grouped_ext/sample_hipblaslt_groupedgemm_ext.cpp) add_executable( sample_hipblaslt_groupedgemm_fixed_mk_ext 17_fixed_mk_gemm_grouped_ext/sample_hipblaslt_groupedgemm_fixed_mk_ext.cpp) add_executable( sample_hipblaslt_groupedgemm_get_all_algos_ext 18_get_all_algos_gemm_grouped_ext/sample_hipblaslt_groupedgemm_get_all_algos_ext.cpp) @@ -83,8 +83,8 @@ set(samples sample_hipblaslt_gemm sample_hipblaslt_gemm_dgelu_bgrad_ext sample_hipblaslt_gemm_is_tuned_ext sample_hipblaslt_gemm_tuning_wgm_ext - sample_hipblaslt_gemm_with_scale_a_b - sample_hipblaslt_gemm_with_scale_a_b_ext + sample_hipblaslt_gemm_with_scale_a_b_c + sample_hipblaslt_gemm_with_scale_a_b_c_ext sample_hipblaslt_groupedgemm_ext sample_hipblaslt_groupedgemm_fixed_mk_ext sample_hipblaslt_groupedgemm_get_all_algos_ext