torch._scaled_mm: support dims of size 0 for tensorwise scaling (pytorch#140967)

Summary:

Ensures we support dims of size 0 properly in `torch._scaled_mm`, following the behavior of `torch.mm`.

For now we only enable support for tensorwise scaling; rowwise can be tackled in a future PR.
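For illustration, a minimal sketch of the new behavior (not part of the commit; assumes a CUDA device with float8 support, and mirrors the call signature used in the test below):

```
import torch

# M x K with K == 0; torch._scaled_mm expects mat1 to be row-major
x = torch.zeros(16, 0, device="cuda").to(torch.float8_e4m3fn)
# K x N, built as (N x K).t() so it is column-major, as the kernel expects
y = torch.zeros(32, 0, device="cuda", dtype=torch.float8_e4m3fn).t()
# tensorwise scaling: one scalar scale per tensor
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)        # torch.Size([16, 32])
print(out.abs().sum())  # tensor(0., ...) -- K == 0 yields zeros, like torch.mm
```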

Test Plan:

```
python test/test_matmul_cuda.py -k test_zero_dim
```


Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#140967
Approved by: https://github.com/eqy, https://github.com/drisspg
vkuzo authored and pytorchmergebot committed Nov 27, 2024
1 parent 6e61ff4 commit 3d5fe0c
Showing 3 changed files with 47 additions and 4 deletions.
13 changes: 13 additions & 0 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -1050,6 +1050,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   IntArrayRef mat2_sizes = mat2.sizes();
   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

+  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels
+  // do not support this case).
+  if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
+    // `out` was created with `at::empty`. In the case where we are multiplying
+    // MxK by KxN and K is the zero dim, we need to initialize here to properly
+    // return a tensor of zeros.
+    if (mat1_sizes[1] == 0) {
+      out.zero_();
+    }
+
+    return out;
+  }
+
   // We are doing row-wise scaling
   if (scaling_choice == ScalingType::RowWise) {
     TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
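Note: when M or N is 0 the output tensor is itself empty, so there is nothing to initialize; only the K == 0 case produces a non-empty output whose values must be defined, which is why `out.zero_()` is guarded on `mat1_sizes[1] == 0` (K).
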
27 changes: 27 additions & 0 deletions test/test_matmul_cuda.py
@@ -700,6 +700,33 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):

         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
+
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @parametrize("which_dim_zero", [0, 1, 2])
+    @parametrize("use_torch_compile", [False, True])
+    def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
+        device = "cuda"
+        x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
+        out_dtype = torch.bfloat16
+        M, K, N = 32, 32, 32
+        if which_dim_zero == 0:
+            M = 0
+        elif which_dim_zero == 1:
+            K = 0
+        elif which_dim_zero == 2:
+            N = 0
+
+        x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
+        y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
+        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
+        scale_a = torch.tensor(float('-inf'), device=device)
+        scale_b = torch.tensor(float('-inf'), device=device)
+        f = torch._scaled_mm
+        if use_torch_compile:
+            f = torch.compile(torch._scaled_mm)
+        out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
+        self.assertEqual(out_dtype, out_fp8.dtype)
+        self.assertEqual(out_fp32, out_fp8.to(torch.float))


 @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
 @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
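Note the scales used by the test: `float('-inf')` would poison any result actually produced by the scaled gemm, so the comparison against `torch.mm` can only pass if the new zero-dim early-return path is taken (presumably the reason for that choice).
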
11 changes: 7 additions & 4 deletions torch/_meta_registrations.py
@@ -5589,13 +5589,16 @@ def is_row_major(stride):
     def is_col_major(stride):
         return stride[0] == 1 and stride[1] > 1

+    def has_zero_dim(tensor_2d):
+        return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0
+
     torch._check(
-        is_row_major(self.stride()),
-        lambda: "self must be row_major",
+        is_row_major(self.stride()) or has_zero_dim(self),
+        lambda: f"self must be row_major, got stride {self.stride()}",
     )
     torch._check(
-        is_col_major(mat2.stride()),
-        lambda: "mat2 must be col_major",
+        is_col_major(mat2.stride()) or has_zero_dim(mat2),
+        lambda: f"mat2 must be col_major, got stride {mat2.stride()}",
     )
     torch._check(
         self.size(1) % 16 == 0,
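For context on why the stride checks needed the `has_zero_dim` escape: zero-size tensors can have degenerate strides (for example, a contiguous `torch.zeros(16, 0)` has stride `(1, 1)`, which fails the `stride[0] > 1` half of `is_row_major`) even though memory layout is irrelevant once a dim is 0. The rewritten messages also report the offending stride, making genuine layout failures easier to debug.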
