From 3d5fe0ce78b3f2c6653ac592878ad735f4ae0c68 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 27 Nov 2024 04:07:49 +0000 Subject: [PATCH] torch._scaled_mm: support dims of size 0 for tensorwise scaling (#140967) Summary: Ensures we support dims of size 0 properly in `torch._scaled_mm`. Follows the behavior from `torch.mm`. For now, only tensorwise scaling is supported; rowwise scaling can be tackled in a future PR. Test Plan: ``` python test/test_matmul_cuda.py -k test_zero_dim ``` Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/140967 Approved by: https://github.com/eqy, https://github.com/drisspg --- aten/src/ATen/native/cuda/Blas.cpp | 13 +++++++++++++ test/test_matmul_cuda.py | 27 +++++++++++++++++++++++++++ torch/_meta_registrations.py | 11 +++++++---- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 1ede6964be6c0..2367285f9d492 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1050,6 +1050,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, IntArrayRef mat2_sizes = mat2.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); + // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels + // do not support this case). + if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) { + // `out` was created with `at::empty`. In the case where we are multiplying + // MxK by KxN and K is the zero dim, we need to initialize here to properly + // return a tensor of zeros. 
+ if (mat1_sizes[1] == 0) { + out.zero_(); + } + + return out; + } + // We are doing row-wise scaling if (scaling_choice == ScalingType::RowWise) { TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precsion output types are supported for row-wise scaling."); diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 1d5f6bd711f8e..0af0969526920 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -700,6 +700,33 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @parametrize("which_dim_zero", [0, 1, 2]) + @parametrize("use_torch_compile", [False, True]) + def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None: + device = "cuda" + x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn + out_dtype = torch.bfloat16 + M, K, N = 32, 32, 32 + if which_dim_zero == 0: + M = 0 + elif which_dim_zero == 1: + K = 0 + elif which_dim_zero == 2: + N = 0 + + x_fp8 = torch.zeros(M, K, device=device).to(x_dtype) + y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t() + out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) + scale_a = torch.tensor(float('-inf'), device=device) + scale_b = torch.tensor(float('-inf'), device=device) + f = torch._scaled_mm + if use_torch_compile: + f = torch.compile(torch._scaled_mm) + out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype) + self.assertEqual(out_dtype, out_fp8.dtype) + self.assertEqual(out_fp32, out_fp8.to(torch.float)) + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 291e302228811..74cfac4298766 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5589,13 +5589,16 @@ def is_row_major(stride): def 
is_col_major(stride): return stride[0] == 1 and stride[1] > 1 + def has_zero_dim(tensor_2d): + return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0 + torch._check( - is_row_major(self.stride()), - lambda: "self must be row_major", + is_row_major(self.stride()) or has_zero_dim(self), + lambda: f"self must be row_major, got stride {self.stride()}", ) torch._check( - is_col_major(mat2.stride()), - lambda: "mat2 must be col_major", + is_col_major(mat2.stride()) or has_zero_dim(mat2), + lambda: f"mat2 must be col_major, got stride {mat2.stride()}", ) torch._check( self.size(1) % 16 == 0,