From 1d350d6611bcd0434abe80f697aff9b540555f1f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 Jan 2025 23:18:56 +0800 Subject: [PATCH 01/10] add w4a4 --- torchao/csrc/cuda/int4_cutlass.cu | 231 ++++++++++++++++++++++++++++++ torchao/ops.py | 53 +++++++ 2 files changed, 284 insertions(+) create mode 100644 torchao/csrc/cuda/int4_cutlass.cu diff --git a/torchao/csrc/cuda/int4_cutlass.cu b/torchao/csrc/cuda/int4_cutlass.cu new file mode 100644 index 0000000000..452abcceaa --- /dev/null +++ b/torchao/csrc/cuda/int4_cutlass.cu @@ -0,0 +1,231 @@ +#include +#include + +// copied from s8s4_linear_cutlass.cu +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) +#define BUILD_INT4_MM_CUTLASS +#endif + +#if defined(BUILD_INT4_MM_CUTLASS) +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#define CUTLASS_STATUS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + __func__, " : Got CUTLASS error: ", \ + cutlassGetStatusString(status)); \ + } +#endif + +namespace torchao { + +#if defined(BUILD_INT4_MM_CUTLASS) +// define common params +using ElementA = cutlass::int4b_t; +using ElementB = cutlass::int4b_t; +using ElementAccumulator = int32_t; +using OpClass = cutlass::arch::OpClassTensorOp; +using ArchTag = cutlass::arch::Sm80; + +// how many elements to load at a time -> load 128-bit = 32 x 4-bit +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; +#endif + +// we will do input checks in python. A and B are stored as int8 +torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) { +#if defined(BUILD_INT4_MM_CUTLASS) + int M = A.size(0); + int K = A.size(1) * 2; + int N = B.size(1); + torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32)); + + // some configs for int4 mma + // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu + // using default config. this can be tuned. 
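+  // The shapes below follow the CUTLASS SM80 s4t_s4n_s32t unit test linked
+  // above: InstructionShape 16x8x64 is the Ampere (SM80) tensor-core MMA
+  // shape for 4-bit integer operands, while ThreadblockShape and WarpShape
+  // set the per-CTA and per-warp output tiles and are the main knobs to
+  // tune for other problem sizes.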
+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + // static int const kStages = 3; + using ElementC = int32_t; + using Gemm = cutlass::gemm::device::Gemm< + ElementA, cutlass::layout::RowMajor, // A matrix + ElementB, cutlass::layout::ColumnMajor, // B matrix + ElementC, cutlass::layout::RowMajor, // C matrix + ElementAccumulator, OpClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape + >; + Gemm::Arguments args { + {M, N, K}, + {reinterpret_cast(A.data_ptr()), K}, + {reinterpret_cast(B.data_ptr()), K}, + {C.data_ptr(), N}, + {C.data_ptr(), N}, + {1, 0} // epilogue + }; + Gemm gemm_op; + CUTLASS_STATUS_CHECK(gemm_op(args)); + return C; +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +template< + typename ElementC, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + int numStages> +void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) { + // problem shape + int M = A.size(0); + int K = A.size(1) * 2; + int N = B.size(1); + + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // 8 for BF16/FP16 + using ElementEpilogue = float; + constexpr int numEpilogueStages = 1; + + // build epilogue visitor tree + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages + >; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest; + using Multiply = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode + >; + + // (1, N) + using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, ElementC, + cute::Stride // MNL + >; + using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT; + + // (M, 1) + using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, ElementC, + cute::Stride // MNL + >; + using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT; + + using Output = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementC, RoundMode, + cute::Stride // MNL + >; + using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT; + + using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, cutlass::layout::RowMajor, AlignmentC, + ElementAccumulator, ElementEpilogue, OpClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, + EVTOutput, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + numStages, + cutlass::arch::OpMultiplyAddSaturate, // OpMultiplyAdd does not work + numEpilogueStages + >::GemmKernel; + using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter; + + // col_scale, row_scale, and C must have the same dtype + const ElementA *A_ptr = reinterpret_cast(A.data_ptr()); + const ElementB *B_ptr = reinterpret_cast(B.data_ptr()); + const ElementC *col_scale_ptr = reinterpret_cast(col_scale.data_ptr()); + const ElementC *row_scale_ptr = reinterpret_cast(row_scale.data_ptr()); + ElementC *C_ptr = reinterpret_cast(C.data_ptr()); + + 
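+  // The epilogue visitor tree defined above computes
+  //   C = (acc * col_scale) * row_scale
+  // in float (ElementEpilogue) and stores the result as ElementC. The nested
+  // braces below mirror the Sm80EVT type nesting: each node lists its child
+  // arguments first, then its own.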
typename EVTOutput::Arguments callback_args{ + { + { + {}, // Accum + {col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}}, // ColScale + {} // Multiply + }, // EVTCompute0 + {row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}}, // RowScale + {} // Multiply + }, // EVTCompute1 + {C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}} // EVTOutput + }; + + typename DeviceGemm::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + cutlass::gemm::GemmCoord{M, N, K}, + 1, // batch_split + callback_args, + A_ptr, B_ptr, nullptr, nullptr, // unsued C_ptr and D_ptr + M * K, N * K, 0, 0, // batch_stride A, B, C, D + K, K, 0, 0 // stride A, B, C, D + ); + + DeviceGemm gemm_op; + auto stream = at::cuda::getCurrentCUDAStream(); + CUTLASS_STATUS_CHECK(gemm_op.can_implement(args)); + CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream)); +} + +// we will do input checks in python. A and B are stored as int8 +// this function is based on the following cutlass example +// https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu +// also with the help of emitted code from cutlass Python +torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) { +#if defined(BUILD_INT4_MM_CUTLASS) + int M = A.size(0); + int N = B.size(1); + torch::Tensor C = torch::empty({M, N}, row_scale.options()); + + // some configs for int4 mma + // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu + // using default config. this can be tuned. + using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + constexpr int numStages = 3; + + AT_DISPATCH_SWITCH( + row_scale.scalar_type(), + "scaled_int4_mm_cutlass", + AT_DISPATCH_CASE( + torch::ScalarType::Half, + [&]() { + using ElementC = cutlass::half_t; + scaled_int4_mm_cutlass_dispatch< + ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>( + A, B, row_scale, col_scale, C); + } + ) + AT_DISPATCH_CASE( + torch::ScalarType::BFloat16, + [&]() { + using ElementC = cutlass::bfloat16_t; + scaled_int4_mm_cutlass_dispatch< + ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>( + A, B, row_scale, col_scale, C); + } + ) + ); + + return C; +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass); + m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass); +} + +} // namespace torchao diff --git a/torchao/ops.py b/torchao/ops.py index f4b55c4951..840dbc0e97 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -22,6 +22,10 @@ lib.define( "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) +lib.define("int4_mm_cutlass(Tensor A, Tensor B) -> Tensor") +lib.define( + "scaled_int4_mm_cutlass(Tensor A, Tensor B, Tensor row_scale, Tensor col_scale) -> Tensor" +) def register_custom_op(name): @@ -615,3 +619,52 @@ def _( dtype=input_scale.dtype, device=input.device, ) + + +def int4_mm_cutlass(A: Tensor, B: Tensor) -> Tensor: + """ + CUTLASS-based W4A4 matmul. + Args: + A: first INT4 tensor, packed in INT8 dtype, row-major layout. + B: second INT4 tensor, packed in INT8 dtype, column-major layout. 
+ Returns: + output: result tensor, in row-major layout. + """ + assert A.dtype == B.dtype == torch.int8 + assert A.ndim == B.ndim == 2 + assert A.shape[1] == B.shape[0] + assert A.is_contiguous() and B.T.is_contiguous() + return torch.ops.torchao.int4_mm_cutlass.default(A, B) + + +@register_custom_op("torchao::int4_mm_cutlass") +def _(A: Tensor, B: Tensor) -> Tensor: + return A.new_empty(A.shape[0], B.shape[1], dtype=torch.int32) + + +def scaled_int4_mm_cutlass( + A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor +) -> Tensor: + """ + CUTLASS-based W4A4 scaled-matmul. + Args: + A: first INT4 tensor, packed in INT8 dtype, row-major layout. + B: second INT4 tensor, packed in INT8 dtype, column-major layout. + row_scale: scaling for each output row. + col_scale: scaling for each output column. + Returns: + output: result tensor, in row-major layout. + """ + assert A.dtype == B.dtype == torch.int8 + assert A.ndim == B.ndim == 2 + assert A.shape[1] == B.shape[0] + assert A.is_contiguous() and B.T.is_contiguous() + assert row_scale.ndim == col_scale.ndim == 1 + assert row_scale.dtype == col_scale.dtype + assert row_scale.dtype in (torch.float16, torch.bfloat16) + return torch.ops.torchao.scaled_int4_mm_cutlass.default(A, B, row_scale, col_scale) + + +@register_custom_op("torchao::scaled_int4_mm_cutlass") +def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: + return row_scale.new_empty(A.shape[0], B.shape[1]) From 7e277df04ab75ceeda1ec0efc71116c91e9c6083 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 7 Jan 2025 23:41:15 +0800 Subject: [PATCH 02/10] add test --- test/test_ops.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index c5821eed44..a7e05843ec 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -536,5 +536,63 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "M,N,K", + [(1, 256, 512), (18, 512, 256), (17, 256, 512)] +) +def test_int4_mm_cutlass(M, N, K): + A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") + B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda") + actual = torchao.ops.int4_mm_cutlass(A, B.T) + + # NOTE: A >> 4 will perform sign-bit extension + unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=1).reshape(M, K) + unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=1).reshape(N, K) + expected = (unpacked_A.float() @ unpacked_B.float().T).to(torch.int32) + + torch.testing.assert_close(actual, expected) + + # Performs opcheck + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] + opcheck( + torch.ops.torchao.int4_mm_cutlass, + (A, B), + test_utils=test_utils, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "M,N,K", + [(1, 256, 512), (18, 512, 256), (17, 256, 512)] +) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +def test_scaled_int4_mm_cutlass(M, N, K, dtype): + A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") + B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda") + row_scale = torch.randn(M, dtype=dtype, device="cuda") + col_scale = torch.randn(N, dtype=dtype, device="cuda") + actual = torchao.ops.scaled_int4_mm_cutlass(A, B.T, row_scale, col_scale) + + # NOTE: A >> 4 will 
perform sign-bit extension + unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=1).reshape(M, K) + unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=1).reshape(N, K) + + expected = unpacked_A.float() @ unpacked_B.float().T + expected = expected * row_scale.view(-1, 1) * col_scale.view(1, -1) + expected = expected.to(dtype) + + torch.testing.assert_close(actual, expected) + + # Performs opcheck + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] + opcheck( + torch.ops.torchao.scaled_int4_mm_cutlass, + (A, B, row_scale, col_scale), + test_utils=test_utils, + ) + + if __name__ == "__main__": run_tests() From a44df9e421d45826a6003a554c7a0718a1889b3e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 Jan 2025 17:40:52 +0700 Subject: [PATCH 03/10] hook up to AQT --- torchao/dtypes/affine_quantized_tensor_ops.py | 6 ++ .../uintx/cutlass_int4_packed_layout.py | 36 ++++++++++++ torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 56 +++++++++++++++++++ 4 files changed, 100 insertions(+) diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 76df949852..f844a6a89a 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -21,6 +21,8 @@ _linear_int8_act_int8_weight_block_sparse_impl, ) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ) @@ -151,6 +153,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ), + ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index a6412ec88c..018c2e2ad5 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -154,3 +154,39 @@ def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias) out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias) return out + + +def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int4(input_tensor) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int4(weight_tensor) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and len(weight_tensor.tensor_impl.scale.shape) == 1 + ) + + +def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import scaled_int4_mm_cutlass + + weight = weight_tensor.tensor_impl.int_data + weight_scale = weight_tensor.tensor_impl.scale + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + batch_dims = input_tensor.shape[:-2] + input = input.view(-1, input.shape[-1]) + input_scale = input_scale.view(-1) + 
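+    # weight is (out_features, in_features // 2) int8 packed row-major, so
+    # weight.T is the column-major operand the kernel expects. The kernel
+    # fuses both scales into its epilogue:
+    #   out[m, n] = (sum_k unpack(input)[m, k] * unpack(weight)[n, k])
+    #               * input_scale[m] * weight_scale[n]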
out = scaled_int4_mm_cutlass(input, weight.T, input_scale, weight_scale) + if bias is not None: + out = out + bias + out = out.view(*batch_dims, out.shape[-1]) + + return out diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index d0d29cf4be..49e75822b2 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -51,6 +51,7 @@ fpx_weight_only, gemlite_uintx_weight_only, int4_weight_only, + int4_dynamic_activation_int4_weight, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_semi_sparse_weight, int8_dynamic_activation_int8_weight, @@ -102,6 +103,7 @@ "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", + "int4_dynamic_activation_int4_weight", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", "int8_dynamic_activation_int8_semi_sparse_weight", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b2eff196fd..801911b01e 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -654,6 +654,62 @@ def int8_dynamic_activation_int4_weight( ) +def int4_dynamic_activation_int4_weight( + layout=CutlassInt4PackedLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, +): + """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear + + Args: + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric + """ + + if not isinstance(layout, CutlassInt4PackedLayout): + raise NotImplementedError( + f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." 
+ ) + if mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only mapping_type=SYMMETRIC is supported.") + if act_mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") + + def _int4_symm_per_token_quant_cutlass(x): + return to_affine_quantized_intx( + x, + mapping_type=act_mapping_type, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=1e-5, + _layout=layout, + ) + + def apply_int4_dynamic_activation_int4_weight_quant(weight): + weight = to_affine_quantized_intx( + weight, + mapping_type=mapping_type, + block_size=(1, weight.shape[1]), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=torch.finfo(torch.float32).eps, + _layout=layout, + ) + weight = to_linear_activation_quantized( + weight, + _int4_symm_per_token_quant_cutlass, + ) + return weight + + return _get_linear_subclass_inserter( + apply_int4_dynamic_activation_int4_weight_quant + ) + + def gemlite_uintx_weight_only( group_size: Optional[int] = 64, bit_width: int = 4, From de167f062ccf4d4b5f3f7385558587d4769cb469 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 Jan 2025 21:21:59 +0800 Subject: [PATCH 04/10] fix quant api test --- test/dtypes/test_affine_quantized.py | 2 + test/test_ops.py | 10 +-- .../uintx/cutlass_int4_packed_layout.py | 2 +- torchao/quantization/__init__.py | 2 +- torchao/quantization/quant_api.py | 84 +++++++++++-------- 5 files changed, 53 insertions(+), 47 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index f08ba7aa72..eb79b05332 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -11,6 +11,7 @@ from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -57,6 +58,7 @@ def get_quantization_functions( layout=CutlassInt4PackedLayout(), ) ) + base_functions.append(int4_dynamic_activation_int4_weight()) if do_sparse: base_functions.append( diff --git a/test/test_ops.py b/test/test_ops.py index 8d66bc516c..38797d8cab 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -608,10 +608,7 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "M,N,K", - [(1, 256, 512), (18, 512, 256), (17, 256, 512)] -) +@pytest.mark.parametrize("M,N,K", [(1, 256, 512), (18, 512, 256), (17, 256, 512)]) def test_int4_mm_cutlass(M, N, K): A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda") @@ -634,10 +631,7 @@ def test_int4_mm_cutlass(M, N, K): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "M,N,K", - [(1, 256, 512), (18, 512, 256), (17, 256, 512)] -) +@pytest.mark.parametrize("M,N,K", [(1, 256, 512), (18, 512, 256), (17, 256, 512)]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) def test_scaled_int4_mm_cutlass(M, N, K, dtype): A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index 018c2e2ad5..d7374c8d50 100644 
--- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -181,7 +181,7 @@ def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias) input = input_tensor.tensor_impl.int_data input_scale = input_tensor.tensor_impl.scale - batch_dims = input_tensor.shape[:-2] + batch_dims = input_tensor.shape[:-1] input = input.view(-1, input.shape[-1]) input_scale = input_scale.view(-1) out = scaled_int4_mm_cutlass(input, weight.T, input_scale, weight_scale) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 49e75822b2..aa4a51d497 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -50,8 +50,8 @@ float8_weight_only, fpx_weight_only, gemlite_uintx_weight_only, - int4_weight_only, int4_dynamic_activation_int4_weight, + int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_semi_sparse_weight, int8_dynamic_activation_int8_weight, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 801911b01e..b209f28043 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -654,19 +654,12 @@ def int8_dynamic_activation_int4_weight( ) -def int4_dynamic_activation_int4_weight( +def apply_int4_dynamic_activation_int4_weight_quant( + weight: torch.Tensor, layout=CutlassInt4PackedLayout(), mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, ): - """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear - - Args: - `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now - `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric - `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric - """ - if not isinstance(layout, CutlassInt4PackedLayout): raise NotImplementedError( f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." 
@@ -676,37 +669,40 @@ def int4_dynamic_activation_int4_weight( if act_mapping_type != MappingType.SYMMETRIC: raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") - def _int4_symm_per_token_quant_cutlass(x): - return to_affine_quantized_intx( - x, - mapping_type=act_mapping_type, - block_size=_get_per_token_block_size(x), - target_dtype=torch.int8, - quant_min=-8, - quant_max=7, - eps=1e-5, - _layout=layout, - ) + weight = to_affine_quantized_intx( + weight, + mapping_type=mapping_type, + block_size=(1, weight.shape[1]), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=torch.finfo(torch.float32).eps, + _layout=layout, + ) + weight = to_linear_activation_quantized( + weight, + _int4_symm_per_token_quant_cutlass, + ) + return weight - def apply_int4_dynamic_activation_int4_weight_quant(weight): - weight = to_affine_quantized_intx( - weight, - mapping_type=mapping_type, - block_size=(1, weight.shape[1]), - target_dtype=torch.int8, - quant_min=-8, - quant_max=7, - eps=torch.finfo(torch.float32).eps, - _layout=layout, - ) - weight = to_linear_activation_quantized( - weight, - _int4_symm_per_token_quant_cutlass, - ) - return weight +def int4_dynamic_activation_int4_weight( + layout=CutlassInt4PackedLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, +): + """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear + + Args: + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric + """ return _get_linear_subclass_inserter( - apply_int4_dynamic_activation_int4_weight_quant + apply_int4_dynamic_activation_int4_weight_quant, + layout=layout, + mapping_type=mapping_type, + act_mapping_type=act_mapping_type, ) @@ -909,6 +905,19 @@ def _int8_symm_per_token_reduced_range_quant_cutlass( ) +def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: + return to_affine_quantized_intx( + x, + mapping_type=MappingType.SYMMETRIC, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=1e-5, + _layout=CutlassInt4PackedLayout(), + ) + + def int8_dynamic_activation_int8_weight( layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC, @@ -1348,6 +1357,7 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, _int8_symm_per_token_reduced_range_quant_cutlass, + _int4_symm_per_token_quant_cutlass, _input_activation_quant_func_fp8, ] ) From fe1f0eb4af306a31175df8f6464473e829d6bb14 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 8 Jan 2025 22:17:11 +0800 Subject: [PATCH 05/10] fix test --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 38797d8cab..d99bc6b055 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -641,8 +641,8 @@ def test_scaled_int4_mm_cutlass(M, N, K, dtype): actual = torchao.ops.scaled_int4_mm_cutlass(A, B.T, row_scale, col_scale) # NOTE: A >> 4 will perform sign-bit extension - unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=1).reshape(M, K) - unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=1).reshape(N, K) + unpacked_A = 
torch.stack([A >> 4, A << 4 >> 4], dim=2).reshape(M, K) + unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=2).reshape(N, K) expected = unpacked_A.float() @ unpacked_B.float().T expected = expected * row_scale.view(-1, 1) * col_scale.view(1, -1) From 883384b1a07f6425940984cb4a7c1f9a3d3745ab Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 Jan 2025 14:03:59 +0800 Subject: [PATCH 06/10] make threadblockswizzle a template param --- .../s8s4_linear_cutlass.cu | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu index 411343f0da..0504fef619 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -33,6 +33,7 @@ template< typename WarpShape, typename InstructionShape, int NumStages, + typename ThreadblockSwizzle, typename ElementA, typename ElementB, typename ElementAccumulator, @@ -53,10 +54,6 @@ void s8s4_linear_kernel_cutlass_sm8x( using LayoutOutput = cutlass::layout::RowMajor; using ElementEpilogue = float; - - using ThreadblockSwizzle = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; - constexpr auto NumEVTEpilogueStages = 1; const int m = tensor_a.size(0); @@ -293,13 +290,15 @@ template static void select_config( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const at::Tensor& tensor_c, at::Tensor& tensor_d) { const auto dprops = at::cuda::getCurrentDeviceProperties(); const auto is_sm8x = dprops->major == 8; if (is_sm8x) { if constexpr (std::is_same::value && std::is_same::value) { + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; // A minimal heuristic to improve performance for small number @@ -309,8 +308,8 @@ static void select_config( using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; constexpr auto NumStages = 6; s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + ThreadblockShape, WarpShape, InstructionShape, NumStages, + ThreadblockSwizzle, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } else if (tensor_a.size(0) <= 32) { @@ -318,8 +317,8 @@ static void select_config( using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; constexpr auto NumStages = 5; s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + ThreadblockShape, WarpShape, InstructionShape, NumStages, + ThreadblockSwizzle, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } else { @@ -327,8 +326,8 @@ static void select_config( using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; constexpr auto NumStages = 4; s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + ThreadblockShape, WarpShape, InstructionShape, NumStages, + ThreadblockSwizzle, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } From ee34bb2f32fe083066f64a28acc1378e7c5b2c95 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 22 Jan 2025 15:40:51 +0800 Subject: [PATCH 07/10] 
re-use s8s4 cutlass template --- .../s4s4_linear_cutlass.cu | 268 ++++++++++++++++ .../s8s4_linear_cutlass.cu | 275 +---------------- .../cuda/s8s4_linear_cutlass/scaled_linear.h | 288 ++++++++++++++++++ 3 files changed, 561 insertions(+), 270 deletions(-) create mode 100644 torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu create mode 100644 torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu new file mode 100644 index 0000000000..8faf13c1bd --- /dev/null +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu @@ -0,0 +1,268 @@ +#include + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) +#define BUILD_S4S4_LINEAR_CUTLASS +#endif + +#if defined(BUILD_S4S4_LINEAR_CUTLASS) +#include "scaled_linear.h" +#include +#include +#include +#endif + +namespace torchao { + +#if defined(BUILD_S4S4_LINEAR_CUTLASS) + +template +static void select_config( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm8x = dprops->major >= 8; + + if (is_sm8x) { + using ElementA = cutlass::int4b_t; + using ElementB = cutlass::int4b_t; + using ElementAccumulator = int32_t; + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + constexpr auto NumStages = 3; + using Operator = cutlass::arch::OpMultiplyAddSaturate; + // using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; // this does not work + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + + scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, + ThreadblockSwizzle, ElementA, ElementB, ElementAccumulator, Operator, + Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } + + TORCH_CHECK(false, + __func__, " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); +} + +template +static void +dispatch_on_tensor_c( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + if (tensor_c.numel() == 0) { + using ElementC = ElementOutput; + using UseTensorC = std::false_type; + select_config< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } + + using UseTensorC = std::true_type; + if (tensor_c.scalar_type() == at::ScalarType::Half) { + using ElementC = cutlass::half_t; + select_config< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { + using ElementC = cutlass::bfloat16_t; + select_config< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } + + TORCH_CHECK(false, + __func__, " : 
Operator not supported for datatype ", + tensor_c.scalar_type(), " for addend"); +} + +static void +dispatch_on_tensor_a_scale_and_tensor_b_scale( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), + __func__, " : Operator not supported for output datatype ", + tensor_d.scalar_type(), " as it's different from the first ", + " operand scale datatype ", tensor_a_scale.scalar_type()); + + if (tensor_a_scale.scalar_type() == at::ScalarType::Half && + tensor_b_scale.scalar_type() == at::ScalarType::Half) { + using ElementAScale = cutlass::half_t; + using ElementBScale = cutlass::half_t; + using ElementOutput = cutlass::half_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && + tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) { + using ElementAScale = cutlass::bfloat16_t; + using ElementBScale = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } + + TORCH_CHECK(false, + __func__, " : Operator not supported for combination of data ", + "types ", tensor_a_scale.scalar_type(), + " for first operand scale and ", tensor_b_scale.scalar_type(), + " for second operand scale"); +} + +static void +check_inputs( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { + // Validate layouts of arguments. + TORCH_CHECK(xq.dim() >= 2, + __func__, " : Expected xq argument to be 2D or " + "higher-dimensional tensor, got ", xq.dim(), " dims"); + TORCH_CHECK(xq.layout() == at::Layout::Strided, + __func__, " : Expected xq argument to be strided, got layout ", + xq.layout()); + TORCH_CHECK(x_scale.dim() == xq.dim() - 1, + __func__, " : Expected xq scale argument to be ", xq.dim() - 1, + "D tensor, got ", x_scale.dim(), " dims"); + TORCH_CHECK(x_scale.layout() == at::Layout::Strided, + __func__, " : Expected xq scale argument to be strided, got " + "layout ", x_scale.layout()); + TORCH_CHECK(wq.dim() == 2, + __func__, " : Expected wq argument to be 2D tensor, got ", + wq.dim(), " dims"); + TORCH_CHECK(wq.layout() == at::Layout::Strided, + __func__, " : Expected wq argument to be strided, got layout ", + wq.layout()); + TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, + __func__, " : Expected wq scale argument to be 1D or 2D tensor, ", + "got ", w_scale.dim(), " dims"); + TORCH_CHECK(w_scale.layout() == at::Layout::Strided, + __func__, " : Expected wq scale argument to be strided, got " + "layout ", w_scale.layout()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dim() == 1, + __func__, " : Expected bias argument to be 1D tensor, got ", + bias.dim(), " dims"); + TORCH_CHECK(bias.layout() == at::Layout::Strided, + __func__, " : Expected bias argument to be strided, got ", + "layout ", bias.layout()); + } + + // Validate sizes of arguments. 
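+  // Both xq and wq store two 4-bit values per int8 element, so the checks
+  // below compare packed widths: xq's last dimension and wq.size(1) are
+  // both K/2, where K is the logical reduction dimension.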
+ const auto xq_sizes = xq.sizes().vec(); + TORCH_CHECK(xq_sizes.back() == wq.size(1), + __func__, " : Expected xq argument to have ", wq.size(1), + " columns, but got ", xq_sizes.back()); + const auto x_scale_sizes = x_scale.sizes().vec(); + for (auto i = 0; i < x_scale_sizes.size(); ++i) + TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], + __func__, " : Expected xq scale argument size at position ", + i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]); + TORCH_CHECK(w_scale.numel() == wq.size(0), + __func__, " : Expected wq scale argument to have ", wq.size(0), + " elements, got ", w_scale.numel(), " elements"); + if (bias.numel() > 0) { + TORCH_CHECK(bias.numel() == wq.size(0), + __func__, " : Expected bias argument to have ", wq.size(0), + " elements, got ", bias.numel(), " elements"); + } + + // Validate strides of arguments. + const auto xq_strides = xq.strides(); + TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, + __func__, " : Expected xq argument in row-major layout"); + auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; + for (int i = xq_strides.size() - 3; i >= 0; --i) { + xq_stride_expected *= xq_sizes[i + 1]; + TORCH_CHECK(xq_strides[i] == xq_stride_expected, + __func__, " : Expected xq argument in row-major layout"); + } + TORCH_CHECK(x_scale.is_contiguous(), + __func__, " : Expected xq scale argument to be contiguous"); + const auto wq_strides = wq.strides(); + TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, + __func__, " : Expected wq argument in row-major layout"); + TORCH_CHECK(w_scale.is_contiguous(), + __func__, " : Expected wq scale argument to be contiguous"); + if (bias.numel() > 0) { + const auto bias_strides = bias.strides(); + TORCH_CHECK(bias_strides[0] == 1, + __func__, " : Expected bias argument to be contiguous"); + } +} +#endif + +// Perform linear operation, using corresponding CUTLASS mixed +// data-types GEMM kernel, to given arguments: +// result = (xq * x_scale) @ (wq * w_scale).T + bias +// Notes: The "x_scale" tensor is expected to be a vector, of size +// equal to number of rows of "xq" tensor. The "w_scale" tensor is +// expected to be a vector, of size equal to number of rows of "wq" +// tensor. The "bias" tensor is expected to be a vector, of size equal +// to number of rows of "wq" tensor. +at::Tensor +s4s4_linear_cutlass( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { +#if defined(BUILD_S4S4_LINEAR_CUTLASS) + // Check inputs. + check_inputs(xq, x_scale, wq, w_scale, bias); + + // Squash the input tensors as appropriate. + const auto xq_sizes = xq.sizes().vec(); + const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); + const auto x_scale_sizes = x_scale.sizes().vec(); + const auto x_scale_1d = x_scale.reshape({-1}); + const auto w_scale_1d = w_scale.reshape({-1}); + + // Introduce alias names for arguments, according to the CUTLASS + // naming conventions. + const auto& tensor_a = xq_2d; + const auto& tensor_a_scale = x_scale_1d; + const auto& tensor_b = wq; + const auto& tensor_b_scale = w_scale_1d; + const auto& tensor_c = bias; + + // Create output tensor. + at::Tensor tensor_d = + tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); + + // Dispatch to appropriate kernel template. + dispatch_on_tensor_a_scale_and_tensor_b_scale( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + + // Reshape and return output tensor. 
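+  // tensor_d was allocated with the scale dtype (fp16/bf16); its leading
+  // batch dimensions are restored from xq and its last dimension becomes
+  // the number of rows of wq.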
+ auto tensor_d_sizes = xq_sizes; + tensor_d_sizes.back() = wq.size(0); + return tensor_d.reshape(tensor_d_sizes); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::s4s4_linear_cutlass", &s4s4_linear_cutlass); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu index 0504fef619..53eaf53961 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -11,280 +11,15 @@ #endif #if defined(BUILD_S8S4_LINEAR_CUTLASS) +#include "scaled_linear.h" #include #include #include -#include -#include - -#define CUTLASS_STATUS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - __func__, " : Got CUTLASS error: ", \ - cutlassGetStatusString(status)); \ - } #endif namespace torchao { #if defined(BUILD_S8S4_LINEAR_CUTLASS) -template< - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - int NumStages, - typename ThreadblockSwizzle, - typename ElementA, - typename ElementB, - typename ElementAccumulator, - typename Operator, - typename ElementAScale, - typename ElementBScale, - typename ElementC, - typename UseTensorC, - typename ElementOutput> -void s8s4_linear_kernel_cutlass_sm8x( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - using SmArch = cutlass::arch::Sm80; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using ElementEpilogue = float; - constexpr auto NumEVTEpilogueStages = 1; - - const int m = tensor_a.size(0); - const int n = tensor_b.size(0); - const int k = tensor_a.size(1); - - constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentAScale = - 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentBScale = - 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentOutput = - 128 / cutlass::sizeof_bits::value; - - // Check for current CUTLASS limitations w.r.t. alignments. 
- TORCH_CHECK(k % AlignmentA == 0, - __func__, " : Number of columns of tensor A must be divisible ", - "by ", AlignmentA); - TORCH_CHECK(k % AlignmentB == 0, - __func__, " : Number of columns of tensor B must be divisible ", - "by ", AlignmentB); - TORCH_CHECK(n % AlignmentC == 0, - __func__, " : Number of columns of tensor C must be divisible ", - "by ", AlignmentC); - - using TensorAScaleTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementAScale, - AlignmentAScale, - NumEVTEpilogueStages>; - using TensorBScaleTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementBScale, - AlignmentBScale, - NumEVTEpilogueStages>; - using TensorCTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementC, - AlignmentC, - NumEVTEpilogueStages>; - using OutputTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementOutput, - AlignmentOutput, - NumEVTEpilogueStages>; - - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - using TensorAScale = - cutlass::epilogue::threadblock::VisitorColBroadcast< - TensorAScaleTileThreadMap, - ElementAScale, - cute::Stride>; - using TensorAScaleArguments = typename TensorAScale::Arguments; - - using TensorBScale = - cutlass::epilogue::threadblock::VisitorRowBroadcast< - TensorBScaleTileThreadMap, - ElementBScale, - cute::Stride>; - using TensorBScaleArguments = typename TensorBScale::Arguments; - - using TensorCScalar = - cutlass::epilogue::threadblock::VisitorScalarBroadcast; - using TensorCTensor = - cutlass::epilogue::threadblock::VisitorRowBroadcast< - TensorCTileThreadMap, - ElementC, - cute::Stride>; - using TensorC = - std::conditional_t; - using TensorCArguments = typename TensorC::Arguments; - - using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementEpilogue, ElementEpilogue, - cutlass::FloatRoundStyle::round_to_nearest - >; - using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT< - ApplyAScale, - Accum, - TensorAScale>; - - using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementEpilogue, ElementEpilogue, - cutlass::FloatRoundStyle::round_to_nearest - >; - using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT< - ApplyBScale, - EVTApplyAScale, - TensorBScale>; - - using ApplySum = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::plus, ElementEpilogue, ElementEpilogue, - cutlass::FloatRoundStyle::round_to_nearest - >; - using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT< - ApplySum, - EVTApplyBScale, - TensorC>; - - using Output = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementOutput, - cutlass::FloatRoundStyle::round_to_nearest, - cute::Stride // StrideMNL - >; - - using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< - Output, - EVTApplySum>; - - using EVTKernel = - typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, - ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, - ElementOutput, LayoutOutput, AlignmentOutput, - ElementAccumulator, - ElementEpilogue, - cutlass::arch::OpClassTensorOp, - SmArch, - ThreadblockShape, - WarpShape, - InstructionShape, - EVTOutput, - ThreadblockSwizzle, - NumStages, - Operator, - NumEVTEpilogueStages - >::GemmKernel; - - using Gemm = 
cutlass::gemm::device::GemmUniversalBase; - - cutlass::gemm::GemmCoord problem_size(m, n, k); - constexpr auto SplitKFactor = 1; - - TensorAScaleArguments tensor_a_scale_arguments{ - (ElementAScale*)tensor_a_scale.data_ptr(), - ElementAScale(1), - {cute::_1{}, cute::_0{}, problem_size.m()} - }; - TensorBScaleArguments tensor_b_scale_arguments{ - (ElementBScale*)tensor_b_scale.data_ptr(), - ElementBScale(1), - {cute::_0{}, cute::_1{}, problem_size.n()} - }; - TensorCArguments tensor_c_arguments{ - [&]() -> TensorCArguments { - if constexpr (UseTensorC::value) { - return {(ElementC*)tensor_c.data_ptr(), - ElementC(0), - {cute::_0{}, cute::_1{}, problem_size.n()}}; - } else { - return {ElementC(0)}; - } - }() - }; - typename Output::Arguments output_arguments{ - (ElementOutput*)tensor_d.data_ptr(), - {problem_size.n(), cute::_1{}, problem_size.mn().product()} - }; - typename EVTOutput::Arguments callback_arguments{ - { - { - { - {}, // Accum - tensor_a_scale_arguments, // TensorAScale - {} // ApplyAScale - }, // EVTApplyAScale - tensor_b_scale_arguments, // TensorBScale - {}, // ApplyBScale - }, // EVTApplyBScale - tensor_c_arguments, // TensorC - {} // ApplySum - }, // EVTApplySum - output_arguments // Output - }; // EVTOutput - constexpr auto AvailSms = -1; - - typename Gemm::Arguments arguments( - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - SplitKFactor, - callback_arguments, // arguments of EVT callbacks - (ElementA*)tensor_a.data_ptr(), - (ElementB*)tensor_b.data_ptr(), - nullptr, // ptr C (unused) - nullptr, // ptr D (unused) - problem_size.mk().product(), // batch stride A - problem_size.nk().product(), // batch stride B - 0, // batch stride C (unused) - 0, // batch stride D (unused) - problem_size.k(), // stride A - problem_size.k(), // stride B - 0, // stride C (unused) - 0, // stride D (unused) - AvailSms); - - Gemm gemm_op; - - cutlass::Status status; - - // Verify that GEMM operation with given arguments can be performed - // by CUTLASS. - status = gemm_op.can_implement(arguments); - CUTLASS_STATUS_CHECK(status); - - // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. - const auto workspace_size = Gemm::get_workspace_size(arguments); - auto workspace = tensor_a.new_empty({(int64_t)workspace_size}, - at::TensorOptions().dtype(at::kByte)); - - // Initialize CUTLASS mixed datatypes GEMM object. - status = gemm_op.initialize(arguments, workspace.data_ptr(), - at::cuda::getCurrentCUDAStream()); - CUTLASS_STATUS_CHECK(status); - - // Perform mixed datatypes GEMM operation. 
- status = gemm_op.run(at::cuda::getCurrentCUDAStream()); - CUTLASS_STATUS_CHECK(status); - - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} template static void select_config( @@ -307,7 +42,7 @@ static void select_config( using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; constexpr auto NumStages = 6; - s8s4_linear_kernel_cutlass_sm8x< + scaled_linear_kernel_cutlass_sm8x< ThreadblockShape, WarpShape, InstructionShape, NumStages, ThreadblockSwizzle, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, @@ -316,7 +51,7 @@ static void select_config( using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; constexpr auto NumStages = 5; - s8s4_linear_kernel_cutlass_sm8x< + scaled_linear_kernel_cutlass_sm8x< ThreadblockShape, WarpShape, InstructionShape, NumStages, ThreadblockSwizzle, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, @@ -325,7 +60,7 @@ static void select_config( using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; constexpr auto NumStages = 4; - s8s4_linear_kernel_cutlass_sm8x< + scaled_linear_kernel_cutlass_sm8x< ThreadblockShape, WarpShape, InstructionShape, NumStages, ThreadblockSwizzle, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, @@ -442,7 +177,7 @@ dispatch_on_tensor_a_scale_and_tensor_b_scale( " for second operand scale"); } -void +static void check_inputs( const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, const at::Tensor& w_scale, const at::Tensor& bias) { diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h b/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h new file mode 100644 index 0000000000..991384b572 --- /dev/null +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h @@ -0,0 +1,288 @@ +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#define CUTLASS_STATUS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + __func__, " : Got CUTLASS error: ", \ + cutlassGetStatusString(status)); \ + } + +namespace torchao { + +template< + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + int NumStages, + typename ThreadblockSwizzle, + typename ElementA, + typename ElementB, + typename ElementAccumulator, + typename Operator, + typename ElementAScale, + typename ElementBScale, + typename ElementC, + typename UseTensorC, + typename ElementOutput> +void scaled_linear_kernel_cutlass_sm8x( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + using SmArch = cutlass::arch::Sm80; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementEpilogue = float; + constexpr auto NumEVTEpilogueStages = 1; + + const int m = tensor_a.size(0); + const int n = tensor_b.size(0); + const int k = std::is_same::value ? 
+ tensor_a.size(1) * 2 : + tensor_a.size(1); + + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentAScale = + 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentBScale = + 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentOutput = + 128 / cutlass::sizeof_bits::value; + + // Check for current CUTLASS limitations w.r.t. alignments. + TORCH_CHECK(k % AlignmentA == 0, + __func__, " : Number of columns of tensor A must be divisible ", + "by ", AlignmentA); + TORCH_CHECK(k % AlignmentB == 0, + __func__, " : Number of columns of tensor B must be divisible ", + "by ", AlignmentB); + TORCH_CHECK(n % AlignmentC == 0, + __func__, " : Number of columns of tensor C must be divisible ", + "by ", AlignmentC); + + using TensorAScaleTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementAScale, + AlignmentAScale, + NumEVTEpilogueStages>; + using TensorBScaleTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementBScale, + AlignmentBScale, + NumEVTEpilogueStages>; + using TensorCTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementC, + AlignmentC, + NumEVTEpilogueStages>; + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementOutput, + AlignmentOutput, + NumEVTEpilogueStages>; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using TensorAScale = + cutlass::epilogue::threadblock::VisitorColBroadcast< + TensorAScaleTileThreadMap, + ElementAScale, + cute::Stride>; + using TensorAScaleArguments = typename TensorAScale::Arguments; + + using TensorBScale = + cutlass::epilogue::threadblock::VisitorRowBroadcast< + TensorBScaleTileThreadMap, + ElementBScale, + cute::Stride>; + using TensorBScaleArguments = typename TensorBScale::Arguments; + + using TensorCScalar = + cutlass::epilogue::threadblock::VisitorScalarBroadcast; + using TensorCTensor = + cutlass::epilogue::threadblock::VisitorRowBroadcast< + TensorCTileThreadMap, + ElementC, + cute::Stride>; + using TensorC = + std::conditional_t; + using TensorCArguments = typename TensorC::Arguments; + + using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT< + ApplyAScale, + Accum, + TensorAScale>; + + using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT< + ApplyBScale, + EVTApplyAScale, + TensorBScale>; + + using ApplySum = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT< + ApplySum, + EVTApplyBScale, + TensorC>; + + using Output = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementOutput, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride // StrideMNL + >; + + using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< + Output, + EVTApplySum>; + + 
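+  // The visitor tree above evaluates
+  //   D = (acc * a_scale) * b_scale + c
+  // in float (ElementEpilogue): a_scale is broadcast per output row, b_scale
+  // per output column, and c is either a row-broadcast bias vector or the
+  // scalar 0 when no bias tensor is passed (UseTensorC).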
using EVTKernel = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementOutput, LayoutOutput, AlignmentOutput, + ElementAccumulator, + ElementEpilogue, + cutlass::arch::OpClassTensorOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EVTOutput, + ThreadblockSwizzle, + NumStages, + Operator, + NumEVTEpilogueStages + >::GemmKernel; + + // GemmUniversalBase doesn't work with W4A4 + // using Gemm = cutlass::gemm::device::GemmUniversalBase; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + cutlass::gemm::GemmCoord problem_size(m, n, k); + constexpr auto SplitKFactor = 1; + + TensorAScaleArguments tensor_a_scale_arguments{ + (ElementAScale*)tensor_a_scale.data_ptr(), + ElementAScale(1), + {cute::_1{}, cute::_0{}, problem_size.m()} + }; + TensorBScaleArguments tensor_b_scale_arguments{ + (ElementBScale*)tensor_b_scale.data_ptr(), + ElementBScale(1), + {cute::_0{}, cute::_1{}, problem_size.n()} + }; + TensorCArguments tensor_c_arguments{ + [&]() -> TensorCArguments { + if constexpr (UseTensorC::value) { + return {(ElementC*)tensor_c.data_ptr(), + ElementC(0), + {cute::_0{}, cute::_1{}, problem_size.n()}}; + } else { + return {ElementC(0)}; + } + }() + }; + typename Output::Arguments output_arguments{ + (ElementOutput*)tensor_d.data_ptr(), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename EVTOutput::Arguments callback_arguments{ + { + { + { + {}, // Accum + tensor_a_scale_arguments, // TensorAScale + {} // ApplyAScale + }, // EVTApplyAScale + tensor_b_scale_arguments, // TensorBScale + {}, // ApplyBScale + }, // EVTApplyBScale + tensor_c_arguments, // TensorC + {} // ApplySum + }, // EVTApplySum + output_arguments // Output + }; // EVTOutput + // constexpr auto AvailSms = -1; + + typename Gemm::Arguments arguments( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + SplitKFactor, + callback_arguments, // arguments of EVT callbacks + (ElementA*)tensor_a.data_ptr(), + (ElementB*)tensor_b.data_ptr(), + nullptr, // ptr C (unused) + nullptr, // ptr D (unused) + problem_size.mk().product(), // batch stride A + problem_size.nk().product(), // batch stride B + 0, // batch stride C (unused) + 0, // batch stride D (unused) + problem_size.k(), // stride A + problem_size.k(), // stride B + 0, // stride C (unused) + 0 + // , // stride D (unused) + // AvailSms // GemmUniversalBase requires passing AvailSms, but GemmUniversalAdapter doesn't + ); + + Gemm gemm_op; + + cutlass::Status status; + + // Verify that GEMM operation with given arguments can be performed + // by CUTLASS. + status = gemm_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status); + + // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. + const auto workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = tensor_a.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize CUTLASS mixed datatypes GEMM object. + status = gemm_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + + // Perform mixed datatypes GEMM operation. 
+ status = gemm_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace torchao From f513523098e0520727461f2f892ebf1964789bcf Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 23 Jan 2025 17:56:02 +0800 Subject: [PATCH 08/10] add Alex's patch and some changes --- ...enchmark_rowwise_scaled_linear_cutlass.py} | 12 +- test/test_ops.py | 52 -- test/test_rowwise_scaled_linear_cutlass.py | 127 ++++ test/test_s8s4_linear_cutlass.py | 77 --- .../rowwise_scaled_linear_cutlass.cuh | 580 ++++++++++++++++++ .../rowwise_scaled_linear_cutlass_s4s4.cu | 28 + .../rowwise_scaled_linear_cutlass_s8s4.cu | 28 + .../s4s4_linear_cutlass.cu | 268 -------- .../s8s4_linear_cutlass.cu | 315 ---------- .../cuda/s8s4_linear_cutlass/scaled_linear.h | 288 --------- .../uintx/cutlass_int4_packed_layout.py | 6 +- torchao/ops.py | 81 +-- 12 files changed, 807 insertions(+), 1055 deletions(-) rename benchmarks/{benchmark_s8s4_cutlass.py => benchmark_rowwise_scaled_linear_cutlass.py} (73%) create mode 100644 test/test_rowwise_scaled_linear_cutlass.py delete mode 100644 test/test_s8s4_linear_cutlass.py create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu delete mode 100644 torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu delete mode 100644 torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu delete mode 100644 torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py similarity index 73% rename from benchmarks/benchmark_s8s4_cutlass.py rename to benchmarks/benchmark_rowwise_scaled_linear_cutlass.py index fbf07ebb35..00bcb0aa21 100644 --- a/benchmarks/benchmark_s8s4_cutlass.py +++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py @@ -2,7 +2,7 @@ import torch from tqdm import tqdm -from torchao.ops import s8s4_linear_cutlass +from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 from torchao.utils import benchmark_torch_function_in_microseconds @@ -24,8 +24,8 @@ def benchmark(m: int, k: int, n: int): A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k) fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref) - s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds( - s8s4_linear_cutlass, A, A_scale, B, B_scale, C + rowwise_scaled_linear_cutlass_s8s4_time = benchmark_torch_function_in_microseconds( + rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C ) return { @@ -33,8 +33,8 @@ def benchmark(m: int, k: int, n: int): "k": k, "n": n, "fp16_latency (ms)": fp16_time, - "s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time, - "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time, + "rowwise_scaled_linear_cutlass latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time, + "speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time, } @@ -48,5 +48,5 @@ def benchmark(m: int, k: int, n: int): results.append(benchmark(m, k, n)) df = pd.DataFrame(results) - df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False) + df.to_csv("rowwise_scaled_linear_cutlass_s8s4_time_results.csv", index=False) print(df.to_markdown(index=False)) diff --git a/test/test_ops.py b/test/test_ops.py index 
0372ea3be6..26671ddf40 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -613,57 +613,5 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("M,N,K", [(1, 256, 512), (18, 512, 256), (17, 256, 512)]) -def test_int4_mm_cutlass(M, N, K): - A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") - B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda") - actual = torchao.ops.int4_mm_cutlass(A, B.T) - - # NOTE: A >> 4 will perform sign-bit extension - unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=1).reshape(M, K) - unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=1).reshape(N, K) - expected = (unpacked_A.float() @ unpacked_B.float().T).to(torch.int32) - - torch.testing.assert_close(actual, expected) - - # Performs opcheck - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] - opcheck( - torch.ops.torchao.int4_mm_cutlass, - (A, B), - test_utils=test_utils, - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("M,N,K", [(1, 256, 512), (18, 512, 256), (17, 256, 512)]) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) -def test_scaled_int4_mm_cutlass(M, N, K, dtype): - A = torch.randint(-128, 127, size=(M, K // 2), dtype=torch.int8, device="cuda") - B = torch.randint(-128, 127, size=(N, K // 2), dtype=torch.int8, device="cuda") - row_scale = torch.randn(M, dtype=dtype, device="cuda") - col_scale = torch.randn(N, dtype=dtype, device="cuda") - actual = torchao.ops.scaled_int4_mm_cutlass(A, B.T, row_scale, col_scale) - - # NOTE: A >> 4 will perform sign-bit extension - unpacked_A = torch.stack([A >> 4, A << 4 >> 4], dim=2).reshape(M, K) - unpacked_B = torch.stack([B >> 4, B << 4 >> 4], dim=2).reshape(N, K) - - expected = unpacked_A.float() @ unpacked_B.float().T - expected = expected * row_scale.view(-1, 1) * col_scale.view(1, -1) - expected = expected.to(dtype) - - torch.testing.assert_close(actual, expected) - - # Performs opcheck - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] - opcheck( - torch.ops.torchao.scaled_int4_mm_cutlass, - (A, B, row_scale, col_scale), - test_utils=test_utils, - ) - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py new file mode 100644 index 0000000000..a2a5022489 --- /dev/null +++ b/test/test_rowwise_scaled_linear_cutlass.py @@ -0,0 +1,127 @@ +import itertools + +import pytest +import torch + +from torchao.ops import ( + rowwise_scaled_linear_cutlass_s4s4, + rowwise_scaled_linear_cutlass_s8s4, +) +from torchao.quantization.utils import group_quantize_tensor_symmetric + +ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] +ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [ + (2, 512, 128), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] +ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True] +ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list( + itertools.product( + ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE, + ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE, + ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK, + ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS, + ) +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not 
available") +@pytest.mark.parametrize( + "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS +) +def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias): + size_m, size_n, size_k = size_mnk + + input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") + weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda") + bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None + + input_2d = input.view(-1, input.shape[-1]) + input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric( + input_2d, 4, size_k, dtype + ) + assert torch.all(input_2d_zeros == 0) + input_s8 = input_2d_s8.reshape(input.shape) + input_s4 = (input_s8[..., 1::2] << 4) | (input_s8[..., 0::2] & 0xF) + input_scales = input_2d_scales.reshape(input.shape[:-1]) + + weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric( + weight, 4, size_n, dtype + ) + assert torch.all(weight_zeros == 0) + weight_s4 = (weight_s8[:, 1::2] << 4) | (weight_s8[:, 0::2] & 0xF) + + # If torch.nn.functional.linear(input, weight, bias) used as + # reference, the error would be too big. The calculation below is + # approximately what rowwise_scaled_linear_cutlass kernel is doing + # (except that matrix multiplication is over integers there)). + size_m_2d = input_2d.shape[0] + output_ref = ( + (input_2d_s8.float() @ weight_s8.float().T) + * input_2d_scales.view(size_m_2d, 1) + * weight_scales.view(1, size_n) + ) + if bias is not None: + output_ref += bias + output_ref = output_ref.to(dtype).reshape(input.shape[:-1] + (size_n,)) + + fn_inputs = (input_s4, input_scales, weight_s4, weight_scales, bias) + try: + output = rowwise_scaled_linear_cutlass_s4s4(*fn_inputs) + except NotImplementedError: + pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented") + + torch.testing.assert_close(output, output_ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS +) +def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias): + size_m, size_n, size_k = size_mnk + + input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") + weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda") + bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None + + input_2d = input.view(-1, input.shape[-1]) + input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric( + input_2d, 8, size_k, dtype + ) + assert torch.all(input_2d_zeros == 0) + input_s8 = input_2d_s8.reshape(input.shape) + input_scales = input_2d_scales.reshape(input.shape[:-1]) + + weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric( + weight, 4, size_n, dtype + ) + assert torch.all(weight_zeros == 0) + weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF) + + # If torch.nn.functional.linear(input, weight, bias) used as + # reference, the error would be too big. The calculation below is + # approximately what rowwise_scaled_linear_cutlass kernel is doing + # (except that matrix multiplication is over integers there)). 
+ size_m_2d = input_2d.shape[0] + output_ref = ( + (input_2d_s8.float() @ weight_s8.float().T) + * input_2d_scales.view(size_m_2d, 1) + * weight_scales.view(1, size_n) + ) + if bias is not None: + output_ref += bias + output_ref = output_ref.to(dtype).reshape(input.shape[:-1] + (size_n,)) + + fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias) + try: + output = rowwise_scaled_linear_cutlass_s8s4(*fn_inputs) + except NotImplementedError: + pytest.xfail("rowwise_scaled_linear_cutlass() op not implemented") + + torch.testing.assert_close(output, output_ref) diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py deleted file mode 100644 index 6510adaea3..0000000000 --- a/test/test_s8s4_linear_cutlass.py +++ /dev/null @@ -1,77 +0,0 @@ -import itertools - -import pytest -import torch - -from torchao.ops import s8s4_linear_cutlass -from torchao.quantization.utils import group_quantize_tensor_symmetric -from torchao.utils import compute_max_diff - -S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] -S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] -S8S4_LINEAR_CUTLASS_SIZE_MNK = [ - (2, 512, 128), - (3, 2048, 2048), - (4, 3584, 640), - (13, 8704, 8576), - (26, 18944, 1664), - (67, 6656, 1408), -] -S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True] -S8S4_LINEAR_CUTLASS_TEST_PARAMS = list( - itertools.product( - S8S4_LINEAR_CUTLASS_DTYPE, - S8S4_LINEAR_CUTLASS_BATCH_SIZE, - S8S4_LINEAR_CUTLASS_SIZE_MNK, - S8S4_LINEAR_CUTLASS_USE_BIAS, - ) -) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS -) -def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias): - size_m, size_n, size_k = size_mnk - - input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") - weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda") - bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None - - input_2d = input.view(-1, input.shape[-1]) - input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric( - input_2d, 8, size_k, dtype - ) - assert torch.all(input_2d_zeros == 0) - input_s8 = input_2d_s8.reshape(input.shape) - input_scales = input_2d_scales.reshape(input.shape[:-1]) - - weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric( - weight, 4, size_n, dtype - ) - assert torch.all(weight_zeros == 0) - weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF) - - # If torch.nn.functional.linear(input, weight, bias) used as - # reference, the error would be too big. The calculation below is - # approximately what s8s4_linear_cutlass kernel is doing (except - # that matrrix multiplication is over integers there)). 
- size_m_2d = input_2d.shape[0] - output_ref = ( - (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T) - * input_2d_scales.view(size_m_2d, 1) - * weight_scales.view(1, size_n) - ) - if bias is not None: - output_ref += bias - output_ref = output_ref.reshape(input.shape[:-1] + (size_n,)) - - fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias) - try: - output = s8s4_linear_cutlass(*fn_inputs) - except NotImplementedError: - pytest.xfail("s8s4_linear_cutlass() op not implemented") - - max_diff = compute_max_diff(output, output_ref) - assert max_diff < 5e-3 diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh new file mode 100644 index 0000000000..ab7cda07f6 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh @@ -0,0 +1,580 @@ +#pragma once + +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) +#define BUILD_ROWWISE_SCALED_LINEAR_CUTLASS +#endif + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) +#include +#include +#include +#include +#include +#include + +#define CUTLASS_STATUS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + __func__, " : Got CUTLASS error: ", \ + cutlassGetStatusString(status)); \ + } +#endif + +namespace torchao { + +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) +template< + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadblockSwizzle, + int NumStages, + typename ElementA, + typename ElementB, + typename ElementOutput, + typename ElementC, + typename UseTensorC, + typename ElementAScale, + typename ElementBScale> +void rowwise_scaled_linear_kernel_cutlass_sm8x( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + static_assert((cutlass::sizeof_bits::value >= 8 || + 8 % cutlass::sizeof_bits::value == 0) && + (cutlass::sizeof_bits::value >= 8 || + 8 % cutlass::sizeof_bits::value == 0)); + + using SmArch = cutlass::arch::Sm80; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // TODO: use FP32 if either ElementA/B is FP + using ElementAccumulator = int32_t; + using Operator = + std::conditional_t::value, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAddMixedInputUpcast>; + + using ElementEpilogue = float; + + constexpr auto NumEVTEpilogueStages = 1; + + const int m = tensor_a.size(0); + const int n = tensor_b.size(0); + int k = tensor_a.size(1); + if constexpr (cutlass::sizeof_bits::value < 8) { + k *= 8 / cutlass::sizeof_bits::value; + } + + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentAScale = + 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentBScale = + 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr int AlignmentOutput = + 128 / cutlass::sizeof_bits::value; + + // Check for current CUTLASS limitations w.r.t. alignments. 
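+  // (The alignment values above are counted in elements and correspond to
+  // 128-bit vectorized accesses, e.g. 32 packed int4 values or 16 int8
+  // values per access.)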
+ TORCH_CHECK(k % AlignmentA == 0, + __func__, " : Number of columns of tensor A must be divisible ", + "by ", AlignmentA); + TORCH_CHECK(k % AlignmentB == 0, + __func__, " : Number of columns of tensor B must be divisible ", + "by ", AlignmentB); + TORCH_CHECK(n % AlignmentC == 0, + __func__, " : Number of columns of tensor C must be divisible ", + "by ", AlignmentC); + + using TensorAScaleTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementAScale, + AlignmentAScale, + NumEVTEpilogueStages>; + using TensorBScaleTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementBScale, + AlignmentBScale, + NumEVTEpilogueStages>; + using TensorCTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementC, + AlignmentC, + NumEVTEpilogueStages>; + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementOutput, + AlignmentOutput, + NumEVTEpilogueStages>; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using TensorAScale = + cutlass::epilogue::threadblock::VisitorColBroadcast< + TensorAScaleTileThreadMap, + ElementAScale, + cute::Stride>; + using TensorAScaleArguments = typename TensorAScale::Arguments; + + using TensorBScale = + cutlass::epilogue::threadblock::VisitorRowBroadcast< + TensorBScaleTileThreadMap, + ElementBScale, + cute::Stride>; + using TensorBScaleArguments = typename TensorBScale::Arguments; + + using TensorCScalar = + cutlass::epilogue::threadblock::VisitorScalarBroadcast; + using TensorCTensor = + cutlass::epilogue::threadblock::VisitorRowBroadcast< + TensorCTileThreadMap, + ElementC, + cute::Stride>; + using TensorC = + std::conditional_t; + using TensorCArguments = typename TensorC::Arguments; + + using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT< + ApplyAScale, + Accum, + TensorAScale>; + + using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT< + ApplyBScale, + EVTApplyAScale, + TensorBScale>; + + using ApplySum = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementEpilogue, ElementEpilogue, + cutlass::FloatRoundStyle::round_to_nearest + >; + using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT< + ApplySum, + EVTApplyBScale, + TensorC>; + + using Output = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementOutput, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride // StrideMNL + >; + + using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< + Output, + EVTApplySum>; + + using EVTKernel = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementOutput, LayoutOutput, AlignmentOutput, + ElementAccumulator, + ElementEpilogue, + cutlass::arch::OpClassTensorOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EVTOutput, + ThreadblockSwizzle, + NumStages, + Operator, + NumEVTEpilogueStages + >::GemmKernel; + + using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter; + + cutlass::gemm::GemmCoord problem_size(m, n, k); + constexpr auto SplitKFactor = 1; + + TensorAScaleArguments tensor_a_scale_arguments{ + (ElementAScale*)tensor_a_scale.data_ptr(), + ElementAScale(1), + {cute::_1{}, cute::_0{}, problem_size.m()} + }; + TensorBScaleArguments tensor_b_scale_arguments{ + (ElementBScale*)tensor_b_scale.data_ptr(), + ElementBScale(1), + {cute::_0{}, cute::_1{}, problem_size.n()} + }; + TensorCArguments tensor_c_arguments{ + [&]() -> TensorCArguments { + if constexpr (UseTensorC::value) { + return {(ElementC*)tensor_c.data_ptr(), + ElementC(0), + {cute::_0{}, cute::_1{}, problem_size.n()}}; + } else { + return {ElementC(0)}; + } + }() + }; + typename Output::Arguments output_arguments{ + (ElementOutput*)tensor_d.data_ptr(), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename EVTOutput::Arguments callback_arguments{ + { + { + { + {}, // Accum + tensor_a_scale_arguments, // TensorAScale + {} // ApplyAScale + }, // EVTApplyAScale + tensor_b_scale_arguments, // TensorBScale + {}, // ApplyBScale + }, // EVTApplyBScale + tensor_c_arguments, // TensorC + {} // ApplySum + }, // EVTApplySum + output_arguments // Output + }; // EVTOutput + + typename Gemm::Arguments arguments( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + SplitKFactor, + callback_arguments, // arguments of EVT callbacks + (ElementA*)tensor_a.data_ptr(), + (ElementB*)tensor_b.data_ptr(), + nullptr, // ptr C (unused) + nullptr, // ptr D (unused) + problem_size.mk().product(), // batch stride A + problem_size.nk().product(), // batch stride B + 0, // batch stride C (unused) + 0, // batch stride D (unused) + problem_size.k(), // stride A + problem_size.k(), // stride B + 0, // stride C (unused) + 0 // stride D (unused) + ); + + Gemm gemm_op; + + cutlass::Status status; + + // Verify that GEMM operation with given arguments can be performed + // by CUTLASS. + status = gemm_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status); + + // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. + const auto workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = tensor_a.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize CUTLASS mixed datatypes GEMM object. + status = gemm_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + + // Perform mixed datatypes GEMM operation. 
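+  // (The kernel is enqueued on the current PyTorch CUDA stream; failures are
+  // surfaced by the status check and the C10_CUDA_KERNEL_LAUNCH_CHECK that
+  // follow.)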
+ status = gemm_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +static void select_config( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm8x = dprops->major == 8; + + if (is_sm8x) { + if constexpr (std::is_same::value && + std::is_same::value) { + // TODO: add some tuning + using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + constexpr auto NumStages = 3; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } else if constexpr (std::is_same::value && + std::is_same::value) { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + + // A minimal heuristic to improve performance for small number + // of inputs cases. + if (tensor_a.size(0) <= 16) { + using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; + constexpr auto NumStages = 6; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else if (tensor_a.size(0) <= 32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + constexpr auto NumStages = 5; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + constexpr auto NumStages = 4; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + return; + } + } + + TORCH_CHECK(false, + __func__, " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); +} + +template< + typename ElementA, + typename ElementB, + typename ElementOutput, + typename... 
Types> +static void +dispatch_on_tensor_c( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + if (tensor_c.numel() == 0) { + using ElementC = ElementOutput; + using UseTensorC = std::false_type; + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } + + using UseTensorC = std::true_type; + if (tensor_c.scalar_type() == at::ScalarType::Half) { + using ElementC = cutlass::half_t; + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { + using ElementC = cutlass::bfloat16_t; + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } + + TORCH_CHECK(false, + __func__, " : Operator not supported for datatype ", + tensor_c.scalar_type(), " for addend"); +} + +template +static void +dispatch_on_tensor_a_scale_and_tensor_b_scale( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), + __func__, " : Operator not supported for output datatype ", + tensor_d.scalar_type(), " as it's different from the first ", + " operand scale datatype ", tensor_a_scale.scalar_type()); + + if (tensor_a_scale.scalar_type() == at::ScalarType::Half && + tensor_b_scale.scalar_type() == at::ScalarType::Half) { + using ElementAScale = cutlass::half_t; + using ElementBScale = cutlass::half_t; + using ElementOutput = cutlass::half_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && + tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) { + using ElementAScale = cutlass::bfloat16_t; + using ElementBScale = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } + + TORCH_CHECK(false, + __func__, " : Operator not supported for combination of data ", + "types ", tensor_a_scale.scalar_type(), + " for first operand scale and ", tensor_b_scale.scalar_type(), + " for second operand scale"); +} + +template +void +rowwise_scaled_linear_cutlass_check_inputs( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { + // Validate layouts of arguments. 
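+  // (xq may be 2D or higher-dimensional; all leading dimensions are treated
+  // as the batch/M dimension and are flattened before the kernel is
+  // dispatched.)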
+ TORCH_CHECK(xq.dim() >= 2, + __func__, " : Expected xq argument to be 2D or " + "higher-dimensional tensor, got ", xq.dim(), " dims"); + TORCH_CHECK(xq.layout() == at::Layout::Strided, + __func__, " : Expected xq argument to be strided, got layout ", + xq.layout()); + TORCH_CHECK(x_scale.dim() == xq.dim() - 1, + __func__, " : Expected xq scale argument to be ", xq.dim() - 1, + "D tensor, got ", x_scale.dim(), " dims"); + TORCH_CHECK(x_scale.layout() == at::Layout::Strided, + __func__, " : Expected xq scale argument to be strided, got " + "layout ", x_scale.layout()); + TORCH_CHECK(wq.dim() == 2, + __func__, " : Expected wq argument to be 2D tensor, got ", + wq.dim(), " dims"); + TORCH_CHECK(wq.layout() == at::Layout::Strided, + __func__, " : Expected wq argument to be strided, got layout ", + wq.layout()); + TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, + __func__, " : Expected wq scale argument to be 1D or 2D tensor, ", + "got ", w_scale.dim(), " dims"); + TORCH_CHECK(w_scale.layout() == at::Layout::Strided, + __func__, " : Expected wq scale argument to be strided, got " + "layout ", w_scale.layout()); + if (bias.numel() > 0) { + TORCH_CHECK(bias.dim() == 1, + __func__, " : Expected bias argument to be 1D tensor, got ", + bias.dim(), " dims"); + TORCH_CHECK(bias.layout() == at::Layout::Strided, + __func__, " : Expected bias argument to be strided, got ", + "layout ", bias.layout()); + } + + // Validate sizes of arguments. + const auto xq_sizes = xq.sizes().vec(); + TORCH_CHECK(xq_sizes.back() == wq.size(1) || + xq_sizes.back() == 2 * wq.size(1), + __func__, " : Expected xq argument to have ", wq.size(1), " or ", + 2 * wq.size(1), " columns, but got ", xq_sizes.back()); + const auto x_scale_sizes = x_scale.sizes().vec(); + for (auto i = 0; i < x_scale_sizes.size(); ++i) + TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], + __func__, " : Expected xq scale argument size at position ", + i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]); + TORCH_CHECK(w_scale.numel() == wq.size(0), + __func__, " : Expected wq scale argument to have ", wq.size(0), + " elements, got ", w_scale.numel(), " elements"); + if (bias.numel() > 0) { + TORCH_CHECK(bias.numel() == wq.size(0), + __func__, " : Expected bias argument to have ", wq.size(0), + " elements, got ", bias.numel(), " elements"); + } + + // Validate strides of arguments. 
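+  // ("Row-major" below means the innermost stride is 1 and the leading
+  // strides are consistent with flattening the leading dimensions, so a
+  // higher-dimensional xq can be viewed as a 2D row-major matrix without
+  // copying.)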
+ const auto xq_strides = xq.strides(); + TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, + __func__, " : Expected xq argument in row-major layout"); + auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; + for (int i = xq_strides.size() - 3; i >= 0; --i) { + xq_stride_expected *= xq_sizes[i + 1]; + TORCH_CHECK(xq_strides[i] == xq_stride_expected, + __func__, " : Expected xq argument in row-major layout"); + } + TORCH_CHECK(x_scale.is_contiguous(), + __func__, " : Expected xq scale argument to be contiguous"); + const auto wq_strides = wq.strides(); + TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, + __func__, " : Expected wq argument in row-major layout"); + TORCH_CHECK(w_scale.is_contiguous(), + __func__, " : Expected wq scale argument to be contiguous"); + if (bias.numel() > 0) { + const auto bias_strides = bias.strides(); + TORCH_CHECK(bias_strides[0] == 1, + __func__, " : Expected bias argument to be contiguous"); + } +} +#endif + +// Perform linear operation, using corresponding CUTLASS datatypes +// GEMM kernel, to given arguments - result produced is: +// (tensor_a * tensor_a_scale) @ (tensor_b * tensor_b_scale).T + tensor_c +// +// Notes: The "tensor_a" and "tensor_b" are expected to be 2D tensors. +// The "tensor_a_scale" tensor is expected to be a vector, of size +// equal to number of rows of "tensor_a" tensor. The "tensor_b_scale" +// tensor is expected to be a vector, of size equal to number of rows +// of "tensor_b" tensor. The "tensor_c" tensor is expected to be a +// vector, of size equal to number of rows of "tensor_b" tensor. +template +at::Tensor +rowwise_scaled_linear_cutlass( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) + // Check inputs. + rowwise_scaled_linear_cutlass_check_inputs( + xq, x_scale, wq, w_scale, bias); + + // Squash the input tensors as appropriate. + const auto xq_sizes = xq.sizes().vec(); + const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); + const auto x_scale_1d = x_scale.reshape({-1}); + const auto w_scale_1d = w_scale.reshape({-1}); + + // Create result tensor. + at::Tensor result = + x_scale.new_empty({xq_2d.size(0), wq.size(0)}); + + // Dispatch to appropriate kernel template. + dispatch_on_tensor_a_scale_and_tensor_b_scale( + xq_2d, x_scale_1d, wq, w_scale_1d, bias, result); + + // Reshape and return result tensor. + auto result_sizes = xq_sizes; + result_sizes.back() = wq.size(0); + return result.reshape(result_sizes); +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu new file mode 100644 index 0000000000..9a64b2bdfb --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu @@ -0,0 +1,28 @@ +#include + +#include "rowwise_scaled_linear_cutlass.cuh" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_cutlass_s4s4( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { + // Validate input datatypes. 
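+  // (Both xq and wq are expected to be int8 tensors whose bytes each carry
+  // two packed signed 4-bit values, e.g. packed = (hi << 4) | (lo & 0xF),
+  // matching the cutlass::int4b_t element types used below.)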
+ TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, + __func__, " : The input datatypes combination ", xq.dtype(), + " for xq and ", wq.dtype(), " for wq is not supported"); + + // Dispatch to appropriate kernel template. + using ElementA = cutlass::int4b_t; + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + xq, x_scale, wq, w_scale, bias); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_cutlass_s4s4", + &rowwise_scaled_linear_cutlass_s4s4); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu new file mode 100644 index 0000000000..752c557e79 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu @@ -0,0 +1,28 @@ +#include + +#include "rowwise_scaled_linear_cutlass.cuh" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_cutlass_s8s4( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { + // Validate input datatypes. + TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, + __func__, " : The input datatypes combination ", xq.dtype(), + " for xq and ", wq.dtype(), " for wq is not supported"); + + // Dispatch to appropriate kernel template. + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + xq, x_scale, wq, w_scale, bias); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_cutlass_s8s4", + &rowwise_scaled_linear_cutlass_s8s4); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu deleted file mode 100644 index 8faf13c1bd..0000000000 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s4s4_linear_cutlass.cu +++ /dev/null @@ -1,268 +0,0 @@ -#include - -#include -#include -#include -#include - -#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ - defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) -#define BUILD_S4S4_LINEAR_CUTLASS -#endif - -#if defined(BUILD_S4S4_LINEAR_CUTLASS) -#include "scaled_linear.h" -#include -#include -#include -#endif - -namespace torchao { - -#if defined(BUILD_S4S4_LINEAR_CUTLASS) - -template -static void select_config( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - const auto dprops = at::cuda::getCurrentDeviceProperties(); - const auto is_sm8x = dprops->major >= 8; - - if (is_sm8x) { - using ElementA = cutlass::int4b_t; - using ElementB = cutlass::int4b_t; - using ElementAccumulator = int32_t; - - using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; - using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; - constexpr auto NumStages = 3; - using Operator = cutlass::arch::OpMultiplyAddSaturate; - // using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; // this does not work - using ThreadblockSwizzle = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; - - scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, - ThreadblockSwizzle, ElementA, ElementB, ElementAccumulator, Operator, - Types...>( - tensor_a, 
tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - return; - } - - TORCH_CHECK(false, - __func__, " : Operator not supported on SM", dprops->major, ".", - dprops->minor, " for given operands"); -} - -template -static void -dispatch_on_tensor_c( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - if (tensor_c.numel() == 0) { - using ElementC = ElementOutput; - using UseTensorC = std::false_type; - select_config< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - return; - } - - using UseTensorC = std::true_type; - if (tensor_c.scalar_type() == at::ScalarType::Half) { - using ElementC = cutlass::half_t; - select_config< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - return; - } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { - using ElementC = cutlass::bfloat16_t; - select_config< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - return; - } - - TORCH_CHECK(false, - __func__, " : Operator not supported for datatype ", - tensor_c.scalar_type(), " for addend"); -} - -static void -dispatch_on_tensor_a_scale_and_tensor_b_scale( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), - __func__, " : Operator not supported for output datatype ", - tensor_d.scalar_type(), " as it's different from the first ", - " operand scale datatype ", tensor_a_scale.scalar_type()); - - if (tensor_a_scale.scalar_type() == at::ScalarType::Half && - tensor_b_scale.scalar_type() == at::ScalarType::Half) { - using ElementAScale = cutlass::half_t; - using ElementBScale = cutlass::half_t; - using ElementOutput = cutlass::half_t; - dispatch_on_tensor_c( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); - return; - } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && - tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) { - using ElementAScale = cutlass::bfloat16_t; - using ElementBScale = cutlass::bfloat16_t; - using ElementOutput = cutlass::bfloat16_t; - dispatch_on_tensor_c( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); - return; - } - - TORCH_CHECK(false, - __func__, " : Operator not supported for combination of data ", - "types ", tensor_a_scale.scalar_type(), - " for first operand scale and ", tensor_b_scale.scalar_type(), - " for second operand scale"); -} - -static void -check_inputs( - const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { - // Validate layouts of arguments. 
- TORCH_CHECK(xq.dim() >= 2, - __func__, " : Expected xq argument to be 2D or " - "higher-dimensional tensor, got ", xq.dim(), " dims"); - TORCH_CHECK(xq.layout() == at::Layout::Strided, - __func__, " : Expected xq argument to be strided, got layout ", - xq.layout()); - TORCH_CHECK(x_scale.dim() == xq.dim() - 1, - __func__, " : Expected xq scale argument to be ", xq.dim() - 1, - "D tensor, got ", x_scale.dim(), " dims"); - TORCH_CHECK(x_scale.layout() == at::Layout::Strided, - __func__, " : Expected xq scale argument to be strided, got " - "layout ", x_scale.layout()); - TORCH_CHECK(wq.dim() == 2, - __func__, " : Expected wq argument to be 2D tensor, got ", - wq.dim(), " dims"); - TORCH_CHECK(wq.layout() == at::Layout::Strided, - __func__, " : Expected wq argument to be strided, got layout ", - wq.layout()); - TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, - __func__, " : Expected wq scale argument to be 1D or 2D tensor, ", - "got ", w_scale.dim(), " dims"); - TORCH_CHECK(w_scale.layout() == at::Layout::Strided, - __func__, " : Expected wq scale argument to be strided, got " - "layout ", w_scale.layout()); - if (bias.numel() > 0) { - TORCH_CHECK(bias.dim() == 1, - __func__, " : Expected bias argument to be 1D tensor, got ", - bias.dim(), " dims"); - TORCH_CHECK(bias.layout() == at::Layout::Strided, - __func__, " : Expected bias argument to be strided, got ", - "layout ", bias.layout()); - } - - // Validate sizes of arguments. - const auto xq_sizes = xq.sizes().vec(); - TORCH_CHECK(xq_sizes.back() == wq.size(1), - __func__, " : Expected xq argument to have ", wq.size(1), - " columns, but got ", xq_sizes.back()); - const auto x_scale_sizes = x_scale.sizes().vec(); - for (auto i = 0; i < x_scale_sizes.size(); ++i) - TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], - __func__, " : Expected xq scale argument size at position ", - i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]); - TORCH_CHECK(w_scale.numel() == wq.size(0), - __func__, " : Expected wq scale argument to have ", wq.size(0), - " elements, got ", w_scale.numel(), " elements"); - if (bias.numel() > 0) { - TORCH_CHECK(bias.numel() == wq.size(0), - __func__, " : Expected bias argument to have ", wq.size(0), - " elements, got ", bias.numel(), " elements"); - } - - // Validate strides of arguments. - const auto xq_strides = xq.strides(); - TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, - __func__, " : Expected xq argument in row-major layout"); - auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; - for (int i = xq_strides.size() - 3; i >= 0; --i) { - xq_stride_expected *= xq_sizes[i + 1]; - TORCH_CHECK(xq_strides[i] == xq_stride_expected, - __func__, " : Expected xq argument in row-major layout"); - } - TORCH_CHECK(x_scale.is_contiguous(), - __func__, " : Expected xq scale argument to be contiguous"); - const auto wq_strides = wq.strides(); - TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, - __func__, " : Expected wq argument in row-major layout"); - TORCH_CHECK(w_scale.is_contiguous(), - __func__, " : Expected wq scale argument to be contiguous"); - if (bias.numel() > 0) { - const auto bias_strides = bias.strides(); - TORCH_CHECK(bias_strides[0] == 1, - __func__, " : Expected bias argument to be contiguous"); - } -} -#endif - -// Perform linear operation, using corresponding CUTLASS mixed -// data-types GEMM kernel, to given arguments: -// result = (xq * x_scale) @ (wq * w_scale).T + bias -// Notes: The "x_scale" tensor is expected to be a vector, of size -// equal to number of rows of "xq" tensor. 
The "w_scale" tensor is -// expected to be a vector, of size equal to number of rows of "wq" -// tensor. The "bias" tensor is expected to be a vector, of size equal -// to number of rows of "wq" tensor. -at::Tensor -s4s4_linear_cutlass( - const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { -#if defined(BUILD_S4S4_LINEAR_CUTLASS) - // Check inputs. - check_inputs(xq, x_scale, wq, w_scale, bias); - - // Squash the input tensors as appropriate. - const auto xq_sizes = xq.sizes().vec(); - const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); - const auto x_scale_sizes = x_scale.sizes().vec(); - const auto x_scale_1d = x_scale.reshape({-1}); - const auto w_scale_1d = w_scale.reshape({-1}); - - // Introduce alias names for arguments, according to the CUTLASS - // naming conventions. - const auto& tensor_a = xq_2d; - const auto& tensor_a_scale = x_scale_1d; - const auto& tensor_b = wq; - const auto& tensor_b_scale = w_scale_1d; - const auto& tensor_c = bias; - - // Create output tensor. - at::Tensor tensor_d = - tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); - - // Dispatch to appropriate kernel template. - dispatch_on_tensor_a_scale_and_tensor_b_scale( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); - - // Reshape and return output tensor. - auto tensor_d_sizes = xq_sizes; - tensor_d_sizes.back() = wq.size(0); - return tensor_d.reshape(tensor_d_sizes); -#else - TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); - return at::Tensor{}; -#endif -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::s4s4_linear_cutlass", &s4s4_linear_cutlass); -} - -} // namespace torchao diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu deleted file mode 100644 index 53eaf53961..0000000000 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ /dev/null @@ -1,315 +0,0 @@ -#include - -#include -#include -#include -#include - -#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ - defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) -#define BUILD_S8S4_LINEAR_CUTLASS -#endif - -#if defined(BUILD_S8S4_LINEAR_CUTLASS) -#include "scaled_linear.h" -#include -#include -#include -#endif - -namespace torchao { - -#if defined(BUILD_S8S4_LINEAR_CUTLASS) - -template -static void select_config( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - const auto dprops = at::cuda::getCurrentDeviceProperties(); - const auto is_sm8x = dprops->major == 8; - - if (is_sm8x) { - if constexpr (std::is_same::value && - std::is_same::value) { - using ThreadblockSwizzle = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - - // A minimal heuristic to improve performance for small number - // of inputs cases. 
- if (tensor_a.size(0) <= 16) { - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; - constexpr auto NumStages = 6; - scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, - ThreadblockSwizzle, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else if (tensor_a.size(0) <= 32) { - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; - constexpr auto NumStages = 5; - scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, - ThreadblockSwizzle, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; - constexpr auto NumStages = 4; - scaled_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, - ThreadblockSwizzle, ElementA, ElementB, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - return; - } - } - - TORCH_CHECK(false, - __func__, " : Operator not supported on SM", dprops->major, ".", - dprops->minor, " for given operands"); -} - -template -static void -dispatch_on_tensor_a_and_tensor_b( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - if (tensor_a.scalar_type() == at::ScalarType::Char) { - if (tensor_b.scalar_type() == at::ScalarType::Char) { - if (tensor_a.size(1) == 2 * tensor_b.size(1)) { - using ElementA = int8_t; - using ElementB = cutlass::int4b_t; - using ElementAccumulator = int32_t; - using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; - select_config< - ElementA, ElementB, ElementAccumulator, Operator, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - return; - } - } - - TORCH_CHECK(false, - __func__, " : Operator not supported for combination of data ", - "types ", tensor_a.scalar_type(), " for first operand and ", - tensor_b.scalar_type(), " for second operand"); -} - - -template -static void -dispatch_on_tensor_c( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - if (tensor_c.numel() == 0) { - using ElementC = ElementOutput; - using UseTensorC = std::false_type; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - return; - } - - using UseTensorC = std::true_type; - if (tensor_c.scalar_type() == at::ScalarType::Half) { - using ElementC = cutlass::half_t; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - return; - } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { - using ElementC = cutlass::bfloat16_t; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - return; - } - - TORCH_CHECK(false, - __func__, " : 
Operator not supported for datatype ", - tensor_c.scalar_type(), " for addend"); -} - -static void -dispatch_on_tensor_a_scale_and_tensor_b_scale( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), - __func__, " : Operator not supported for output datatype ", - tensor_d.scalar_type(), " as it's different from the first ", - " operand scale datatype ", tensor_a_scale.scalar_type()); - - if (tensor_a_scale.scalar_type() == at::ScalarType::Half && - tensor_b_scale.scalar_type() == at::ScalarType::Half) { - using ElementAScale = cutlass::half_t; - using ElementBScale = cutlass::half_t; - using ElementOutput = cutlass::half_t; - dispatch_on_tensor_c( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); - return; - } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && - tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) { - using ElementAScale = cutlass::bfloat16_t; - using ElementBScale = cutlass::bfloat16_t; - using ElementOutput = cutlass::bfloat16_t; - dispatch_on_tensor_c( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); - return; - } - - TORCH_CHECK(false, - __func__, " : Operator not supported for combination of data ", - "types ", tensor_a_scale.scalar_type(), - " for first operand scale and ", tensor_b_scale.scalar_type(), - " for second operand scale"); -} - -static void -check_inputs( - const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { - // Validate layouts of arguments. - TORCH_CHECK(xq.dim() >= 2, - __func__, " : Expected xq argument to be 2D or " - "higher-dimensional tensor, got ", xq.dim(), " dims"); - TORCH_CHECK(xq.layout() == at::Layout::Strided, - __func__, " : Expected xq argument to be strided, got layout ", - xq.layout()); - TORCH_CHECK(x_scale.dim() == xq.dim() - 1, - __func__, " : Expected xq scale argument to be ", xq.dim() - 1, - "D tensor, got ", x_scale.dim(), " dims"); - TORCH_CHECK(x_scale.layout() == at::Layout::Strided, - __func__, " : Expected xq scale argument to be strided, got " - "layout ", x_scale.layout()); - TORCH_CHECK(wq.dim() == 2, - __func__, " : Expected wq argument to be 2D tensor, got ", - wq.dim(), " dims"); - TORCH_CHECK(wq.layout() == at::Layout::Strided, - __func__, " : Expected wq argument to be strided, got layout ", - wq.layout()); - TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, - __func__, " : Expected wq scale argument to be 1D or 2D tensor, ", - "got ", w_scale.dim(), " dims"); - TORCH_CHECK(w_scale.layout() == at::Layout::Strided, - __func__, " : Expected wq scale argument to be strided, got " - "layout ", w_scale.layout()); - if (bias.numel() > 0) { - TORCH_CHECK(bias.dim() == 1, - __func__, " : Expected bias argument to be 1D tensor, got ", - bias.dim(), " dims"); - TORCH_CHECK(bias.layout() == at::Layout::Strided, - __func__, " : Expected bias argument to be strided, got ", - "layout ", bias.layout()); - } - - // Validate sizes of arguments. 
- const auto xq_sizes = xq.sizes().vec(); - TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1), - __func__, " : Expected xq argument to have ", 2 * wq.size(1), - " columns, but got ", xq_sizes.back()); - const auto x_scale_sizes = x_scale.sizes().vec(); - for (auto i = 0; i < x_scale_sizes.size(); ++i) - TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], - __func__, " : Expected xq scale argument size at position ", - i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]); - TORCH_CHECK(w_scale.numel() == wq.size(0), - __func__, " : Expected wq scale argument to have ", wq.size(0), - " elements, got ", w_scale.numel(), " elements"); - if (bias.numel() > 0) { - TORCH_CHECK(bias.numel() == wq.size(0), - __func__, " : Expected bias argument to have ", wq.size(0), - " elements, got ", bias.numel(), " elements"); - } - - // Validate strides of arguments. - const auto xq_strides = xq.strides(); - TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, - __func__, " : Expected xq argument in row-major layout"); - auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; - for (int i = xq_strides.size() - 3; i >= 0; --i) { - xq_stride_expected *= xq_sizes[i + 1]; - TORCH_CHECK(xq_strides[i] == xq_stride_expected, - __func__, " : Expected xq argument in row-major layout"); - } - TORCH_CHECK(x_scale.is_contiguous(), - __func__, " : Expected xq scale argument to be contiguous"); - const auto wq_strides = wq.strides(); - TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, - __func__, " : Expected wq argument in row-major layout"); - TORCH_CHECK(w_scale.is_contiguous(), - __func__, " : Expected wq scale argument to be contiguous"); - if (bias.numel() > 0) { - const auto bias_strides = bias.strides(); - TORCH_CHECK(bias_strides[0] == 1, - __func__, " : Expected bias argument to be contiguous"); - } -} -#endif - -// Perform linear operation, using corresponding CUTLASS mixed -// data-types GEMM kernel, to given arguments: -// result = (xq * x_scale) @ (wq * w_scale).T + bias -// Notes: The "x_scale" tensor is expected to be a vector, of size -// equal to number of rows of "xq" tensor. The "w_scale" tensor is -// expected to be a vector, of size equal to number of rows of "wq" -// tensor. The "bias" tensor is expected to be a vector, of size equal -// to number of rows of "wq" tensor. -at::Tensor -s8s4_linear_cutlass( - const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { -#if defined(BUILD_S8S4_LINEAR_CUTLASS) - // Check inputs. - check_inputs(xq, x_scale, wq, w_scale, bias); - - // Squash the input tensors as appropriate. - const auto xq_sizes = xq.sizes().vec(); - const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); - const auto x_scale_sizes = x_scale.sizes().vec(); - const auto x_scale_1d = x_scale.reshape({-1}); - const auto w_scale_1d = w_scale.reshape({-1}); - - // Introduce alias names for arguments, according to the CUTLASS - // naming conventions. - const auto& tensor_a = xq_2d; - const auto& tensor_a_scale = x_scale_1d; - const auto& tensor_b = wq; - const auto& tensor_b_scale = w_scale_1d; - const auto& tensor_c = bias; - - // Create output tensor. - at::Tensor tensor_d = - tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); - - // Dispatch to appropriate kernel template. - dispatch_on_tensor_a_scale_and_tensor_b_scale( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); - - // Reshape and return output tensor. 
- auto tensor_d_sizes = xq_sizes; - tensor_d_sizes.back() = wq.size(0); - return tensor_d.reshape(tensor_d_sizes); -#else - TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); - return at::Tensor{}; -#endif -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass); -} - -} // namespace torchao diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h b/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h deleted file mode 100644 index 991384b572..0000000000 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/scaled_linear.h +++ /dev/null @@ -1,288 +0,0 @@ -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include "cutlass/gemm/device/gemm_universal_adapter.h" - -#define CUTLASS_STATUS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - __func__, " : Got CUTLASS error: ", \ - cutlassGetStatusString(status)); \ - } - -namespace torchao { - -template< - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - int NumStages, - typename ThreadblockSwizzle, - typename ElementA, - typename ElementB, - typename ElementAccumulator, - typename Operator, - typename ElementAScale, - typename ElementBScale, - typename ElementC, - typename UseTensorC, - typename ElementOutput> -void scaled_linear_kernel_cutlass_sm8x( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - using SmArch = cutlass::arch::Sm80; - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - - using ElementEpilogue = float; - constexpr auto NumEVTEpilogueStages = 1; - - const int m = tensor_a.size(0); - const int n = tensor_b.size(0); - const int k = std::is_same::value ? - tensor_a.size(1) * 2 : - tensor_a.size(1); - - constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentAScale = - 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentBScale = - 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - constexpr int AlignmentOutput = - 128 / cutlass::sizeof_bits::value; - - // Check for current CUTLASS limitations w.r.t. alignments. 
- TORCH_CHECK(k % AlignmentA == 0, - __func__, " : Number of columns of tensor A must be divisible ", - "by ", AlignmentA); - TORCH_CHECK(k % AlignmentB == 0, - __func__, " : Number of columns of tensor B must be divisible ", - "by ", AlignmentB); - TORCH_CHECK(n % AlignmentC == 0, - __func__, " : Number of columns of tensor C must be divisible ", - "by ", AlignmentC); - - using TensorAScaleTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementAScale, - AlignmentAScale, - NumEVTEpilogueStages>; - using TensorBScaleTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementBScale, - AlignmentBScale, - NumEVTEpilogueStages>; - using TensorCTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementC, - AlignmentC, - NumEVTEpilogueStages>; - using OutputTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementOutput, - AlignmentOutput, - NumEVTEpilogueStages>; - - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - using TensorAScale = - cutlass::epilogue::threadblock::VisitorColBroadcast< - TensorAScaleTileThreadMap, - ElementAScale, - cute::Stride>; - using TensorAScaleArguments = typename TensorAScale::Arguments; - - using TensorBScale = - cutlass::epilogue::threadblock::VisitorRowBroadcast< - TensorBScaleTileThreadMap, - ElementBScale, - cute::Stride>; - using TensorBScaleArguments = typename TensorBScale::Arguments; - - using TensorCScalar = - cutlass::epilogue::threadblock::VisitorScalarBroadcast; - using TensorCTensor = - cutlass::epilogue::threadblock::VisitorRowBroadcast< - TensorCTileThreadMap, - ElementC, - cute::Stride>; - using TensorC = - std::conditional_t; - using TensorCArguments = typename TensorC::Arguments; - - using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementEpilogue, ElementEpilogue, - cutlass::FloatRoundStyle::round_to_nearest - >; - using EVTApplyAScale = cutlass::epilogue::threadblock::Sm80EVT< - ApplyAScale, - Accum, - TensorAScale>; - - using ApplyBScale = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementEpilogue, ElementEpilogue, - cutlass::FloatRoundStyle::round_to_nearest - >; - using EVTApplyBScale = cutlass::epilogue::threadblock::Sm80EVT< - ApplyBScale, - EVTApplyAScale, - TensorBScale>; - - using ApplySum = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::plus, ElementEpilogue, ElementEpilogue, - cutlass::FloatRoundStyle::round_to_nearest - >; - using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT< - ApplySum, - EVTApplyBScale, - TensorC>; - - using Output = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementOutput, - cutlass::FloatRoundStyle::round_to_nearest, - cute::Stride // StrideMNL - >; - - using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< - Output, - EVTApplySum>; - - using EVTKernel = - typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, - ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, - ElementOutput, LayoutOutput, AlignmentOutput, - ElementAccumulator, - ElementEpilogue, - cutlass::arch::OpClassTensorOp, - SmArch, - ThreadblockShape, - WarpShape, - InstructionShape, - EVTOutput, - ThreadblockSwizzle, - NumStages, - Operator, - NumEVTEpilogueStages - >::GemmKernel; - - // 
GemmUniversalBase doesn't work with W4A4 - // using Gemm = cutlass::gemm::device::GemmUniversalBase; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - cutlass::gemm::GemmCoord problem_size(m, n, k); - constexpr auto SplitKFactor = 1; - - TensorAScaleArguments tensor_a_scale_arguments{ - (ElementAScale*)tensor_a_scale.data_ptr(), - ElementAScale(1), - {cute::_1{}, cute::_0{}, problem_size.m()} - }; - TensorBScaleArguments tensor_b_scale_arguments{ - (ElementBScale*)tensor_b_scale.data_ptr(), - ElementBScale(1), - {cute::_0{}, cute::_1{}, problem_size.n()} - }; - TensorCArguments tensor_c_arguments{ - [&]() -> TensorCArguments { - if constexpr (UseTensorC::value) { - return {(ElementC*)tensor_c.data_ptr(), - ElementC(0), - {cute::_0{}, cute::_1{}, problem_size.n()}}; - } else { - return {ElementC(0)}; - } - }() - }; - typename Output::Arguments output_arguments{ - (ElementOutput*)tensor_d.data_ptr(), - {problem_size.n(), cute::_1{}, problem_size.mn().product()} - }; - typename EVTOutput::Arguments callback_arguments{ - { - { - { - {}, // Accum - tensor_a_scale_arguments, // TensorAScale - {} // ApplyAScale - }, // EVTApplyAScale - tensor_b_scale_arguments, // TensorBScale - {}, // ApplyBScale - }, // EVTApplyBScale - tensor_c_arguments, // TensorC - {} // ApplySum - }, // EVTApplySum - output_arguments // Output - }; // EVTOutput - // constexpr auto AvailSms = -1; - - typename Gemm::Arguments arguments( - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - SplitKFactor, - callback_arguments, // arguments of EVT callbacks - (ElementA*)tensor_a.data_ptr(), - (ElementB*)tensor_b.data_ptr(), - nullptr, // ptr C (unused) - nullptr, // ptr D (unused) - problem_size.mk().product(), // batch stride A - problem_size.nk().product(), // batch stride B - 0, // batch stride C (unused) - 0, // batch stride D (unused) - problem_size.k(), // stride A - problem_size.k(), // stride B - 0, // stride C (unused) - 0 - // , // stride D (unused) - // AvailSms // GemmUniversalBase requires passing AvailSms, but GemmUniversalAdapter doesn't - ); - - Gemm gemm_op; - - cutlass::Status status; - - // Verify that GEMM operation with given arguments can be performed - // by CUTLASS. - status = gemm_op.can_implement(arguments); - CUTLASS_STATUS_CHECK(status); - - // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. - const auto workspace_size = Gemm::get_workspace_size(arguments); - auto workspace = tensor_a.new_empty({(int64_t)workspace_size}, - at::TensorOptions().dtype(at::kByte)); - - // Initialize CUTLASS mixed datatypes GEMM object. - status = gemm_op.initialize(arguments, workspace.data_ptr(), - at::cuda::getCurrentCUDAStream()); - CUTLASS_STATUS_CHECK(status); - - // Perform mixed datatypes GEMM operation. 
- status = gemm_op.run(at::cuda::getCurrentCUDAStream()); - CUTLASS_STATUS_CHECK(status); - - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -} // namespace torchao diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index d7374c8d50..0287745b53 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -144,14 +144,16 @@ def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): - from torchao.ops import s8s4_linear_cutlass + from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 weight = weight_tensor.tensor_impl.int_data weight_scale = weight_tensor.tensor_impl.scale input = input_tensor.tensor_impl.int_data input_scale = input_tensor.tensor_impl.scale - out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias) + out = rowwise_scaled_linear_cutlass_s8s4( + input, input_scale, weight, weight_scale, bias + ) return out diff --git a/torchao/ops.py b/torchao/ops.py index 840dbc0e97..d74d6ac3f2 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -20,11 +20,10 @@ "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor" ) lib.define( - "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" + "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) -lib.define("int4_mm_cutlass(Tensor A, Tensor B) -> Tensor") lib.define( - "scaled_int4_mm_cutlass(Tensor A, Tensor B, Tensor row_scale, Tensor col_scale) -> Tensor" + "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) @@ -518,7 +517,7 @@ def _( return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device) -def s8s4_linear_cutlass( +def rowwise_scaled_linear_cutlass_s8s4( input: Tensor, input_scale: Tensor, weight: Tensor, @@ -526,23 +525,23 @@ def s8s4_linear_cutlass( bias: Tensor, ) -> Tensor: """ - CUTLASS-based W4A8 linear operator. + CUTLASS-based row-wise scaled W4A8 linear operator. Args: - input: input tensor, quantized to 8-bit integer values. + input: quantized input tensor, in row-major layout. input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. - weight: weight matrix, quantized to 4-bit integer values, in row-major layout. + weight: quantized weight matrix, in row-major layout. weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). bias: a vector of size equal to number of rows of weight tensor, or None. Returns: output: result tensor, in row-major layout. """ - return torch.ops.torchao.s8s4_linear_cutlass.default( + return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default( input, input_scale, weight, weight_scale, bias ) -@register_custom_op("torchao::s8s4_linear_cutlass") +@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s8s4") def _( input: Tensor, input_scale: Tensor, @@ -550,6 +549,8 @@ def _( weight_scale: Tensor, bias: Tensor, ) -> Tensor: + # FIXME: update this!!! + # Validate dtypes. 
     torch._check(
         input.dtype == torch.int8,
@@ -621,50 +622,36 @@ def _(
     )
 
 
-def int4_mm_cutlass(A: Tensor, B: Tensor) -> Tensor:
+def rowwise_scaled_linear_cutlass_s4s4(
+    input: Tensor,
+    input_scale: Tensor,
+    weight: Tensor,
+    weight_scale: Tensor,
+    bias: Tensor,
+) -> Tensor:
     """
-    CUTLASS-based W4A4 matmul.
+    CUTLASS-based row-wise scaled W4A4 linear operator.
     Args:
-        A: first INT4 tensor, packed in INT8 dtype, row-major layout.
-        B: second INT4 tensor, packed in INT8 dtype, column-major layout.
+        input: quantized input tensor, in row-major layout.
+        input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension.
+        weight: quantized weight matrix, in row-major layout.
+        weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension).
+        bias: a vector of size equal to number of rows of weight tensor, or None.
     Returns:
         output: result tensor, in row-major layout.
     """
-    assert A.dtype == B.dtype == torch.int8
-    assert A.ndim == B.ndim == 2
-    assert A.shape[1] == B.shape[0]
-    assert A.is_contiguous() and B.T.is_contiguous()
-    return torch.ops.torchao.int4_mm_cutlass.default(A, B)
-
-@register_custom_op("torchao::int4_mm_cutlass")
-def _(A: Tensor, B: Tensor) -> Tensor:
-    return A.new_empty(A.shape[0], B.shape[1], dtype=torch.int32)
+    return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default(
+        input, input_scale, weight, weight_scale, bias
+    )
 
 
-def scaled_int4_mm_cutlass(
-    A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor
+@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s4s4")
+def _(
+    input: Tensor,
+    input_scale: Tensor,
+    weight: Tensor,
+    weight_scale: Tensor,
+    bias: Tensor,
 ) -> Tensor:
-    """
-    CUTLASS-based W4A4 scaled-matmul.
-    Args:
-        A: first INT4 tensor, packed in INT8 dtype, row-major layout.
-        B: second INT4 tensor, packed in INT8 dtype, column-major layout.
-        row_scale: scaling for each output row.
-        col_scale: scaling for each output column.
-    Returns:
-        output: result tensor, in row-major layout.
- """ - assert A.dtype == B.dtype == torch.int8 - assert A.ndim == B.ndim == 2 - assert A.shape[1] == B.shape[0] - assert A.is_contiguous() and B.T.is_contiguous() - assert row_scale.ndim == col_scale.ndim == 1 - assert row_scale.dtype == col_scale.dtype - assert row_scale.dtype in (torch.float16, torch.bfloat16) - return torch.ops.torchao.scaled_int4_mm_cutlass.default(A, B, row_scale, col_scale) - - -@register_custom_op("torchao::scaled_int4_mm_cutlass") -def _(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tensor) -> Tensor: - return row_scale.new_empty(A.shape[0], B.shape[1]) + return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) From 9a1ce25dd8c83be80fdff65914ad557fff439d15 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 23 Jan 2025 18:17:46 +0800 Subject: [PATCH 09/10] fix aqt test --- torchao/dtypes/uintx/cutlass_int4_packed_layout.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index 0287745b53..037ae1f3ad 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -176,19 +176,15 @@ def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): - from torchao.ops import scaled_int4_mm_cutlass + from torchao.ops import rowwise_scaled_linear_cutlass_s4s4 weight = weight_tensor.tensor_impl.int_data weight_scale = weight_tensor.tensor_impl.scale input = input_tensor.tensor_impl.int_data input_scale = input_tensor.tensor_impl.scale - batch_dims = input_tensor.shape[:-1] - input = input.view(-1, input.shape[-1]) - input_scale = input_scale.view(-1) - out = scaled_int4_mm_cutlass(input, weight.T, input_scale, weight_scale) - if bias is not None: - out = out + bias - out = out.view(*batch_dims, out.shape[-1]) + out = rowwise_scaled_linear_cutlass_s4s4( + input, input_scale, weight, weight_scale, bias + ) return out From b9db0f11497c21420f38ebc4b4cefc86ef47b02d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 23 Jan 2025 23:45:40 +0800 Subject: [PATCH 10/10] remove int4_cutlass.cu --- torchao/csrc/cuda/int4_cutlass.cu | 231 ------------------------------ 1 file changed, 231 deletions(-) delete mode 100644 torchao/csrc/cuda/int4_cutlass.cu diff --git a/torchao/csrc/cuda/int4_cutlass.cu b/torchao/csrc/cuda/int4_cutlass.cu deleted file mode 100644 index 452abcceaa..0000000000 --- a/torchao/csrc/cuda/int4_cutlass.cu +++ /dev/null @@ -1,231 +0,0 @@ -#include -#include - -// copied from s8s4_linear_cutlass.cu -#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ - defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) -#define BUILD_INT4_MM_CUTLASS -#endif - -#if defined(BUILD_INT4_MM_CUTLASS) -#include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_universal.h" -#include "cutlass/gemm/device/gemm.h" -#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" -#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" - -#define CUTLASS_STATUS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - __func__, " : Got CUTLASS error: ", \ - cutlassGetStatusString(status)); \ - } -#endif - -namespace torchao { - -#if defined(BUILD_INT4_MM_CUTLASS) -// define common params -using ElementA = cutlass::int4b_t; -using ElementB = cutlass::int4b_t; -using ElementAccumulator = int32_t; 
-using OpClass = cutlass::arch::OpClassTensorOp; -using ArchTag = cutlass::arch::Sm80; - -// how many elements to load at a time -> load 128-bit = 32 x 4-bit -constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; -constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; -#endif - -// we will do input checks in python. A and B are stored as int8 -torch::Tensor int4_mm_cutlass(torch::Tensor A, torch::Tensor B) { -#if defined(BUILD_INT4_MM_CUTLASS) - int M = A.size(0); - int K = A.size(1) * 2; - int N = B.size(1); - torch::Tensor C = torch::empty({M, N}, A.options().dtype(torch::kInt32)); - - // some configs for int4 mma - // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu - // using default config. this can be tuned. - using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; - using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; - // static int const kStages = 3; - using ElementC = int32_t; - using Gemm = cutlass::gemm::device::Gemm< - ElementA, cutlass::layout::RowMajor, // A matrix - ElementB, cutlass::layout::ColumnMajor, // B matrix - ElementC, cutlass::layout::RowMajor, // C matrix - ElementAccumulator, OpClass, ArchTag, - ThreadblockShape, WarpShape, InstructionShape - >; - Gemm::Arguments args { - {M, N, K}, - {reinterpret_cast(A.data_ptr()), K}, - {reinterpret_cast(B.data_ptr()), K}, - {C.data_ptr(), N}, - {C.data_ptr(), N}, - {1, 0} // epilogue - }; - Gemm gemm_op; - CUTLASS_STATUS_CHECK(gemm_op(args)); - return C; -#else - TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); - return at::Tensor{}; -#endif -} - -template< - typename ElementC, - typename ThreadblockShape, - typename WarpShape, - typename InstructionShape, - int numStages> -void scaled_int4_mm_cutlass_dispatch(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale, torch::Tensor C) { - // problem shape - int M = A.size(0); - int K = A.size(1) * 2; - int N = B.size(1); - - constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // 8 for BF16/FP16 - using ElementEpilogue = float; - constexpr int numEpilogueStages = 1; - - // build epilogue visitor tree - using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages - >; - - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest; - using Multiply = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode - >; - - // (1, N) - using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, ElementC, - cute::Stride // MNL - >; - using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT; - - // (M, 1) - using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, ElementC, - cute::Stride // MNL - >; - using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT; - - using Output = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementC, RoundMode, - cute::Stride // MNL - >; - using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT; - - using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementA, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, AlignmentA, - ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB, - 
ElementC, cutlass::layout::RowMajor, AlignmentC, - ElementAccumulator, ElementEpilogue, OpClass, ArchTag, - ThreadblockShape, WarpShape, InstructionShape, - EVTOutput, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, - numStages, - cutlass::arch::OpMultiplyAddSaturate, // OpMultiplyAdd does not work - numEpilogueStages - >::GemmKernel; - using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter; - - // col_scale, row_scale, and C must have the same dtype - const ElementA *A_ptr = reinterpret_cast(A.data_ptr()); - const ElementB *B_ptr = reinterpret_cast(B.data_ptr()); - const ElementC *col_scale_ptr = reinterpret_cast(col_scale.data_ptr()); - const ElementC *row_scale_ptr = reinterpret_cast(row_scale.data_ptr()); - ElementC *C_ptr = reinterpret_cast(C.data_ptr()); - - typename EVTOutput::Arguments callback_args{ - { - { - {}, // Accum - {col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}}, // ColScale - {} // Multiply - }, // EVTCompute0 - {row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}}, // RowScale - {} // Multiply - }, // EVTCompute1 - {C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}} // EVTOutput - }; - - typename DeviceGemm::Arguments args( - cutlass::gemm::GemmUniversalMode::kGemm, - cutlass::gemm::GemmCoord{M, N, K}, - 1, // batch_split - callback_args, - A_ptr, B_ptr, nullptr, nullptr, // unsued C_ptr and D_ptr - M * K, N * K, 0, 0, // batch_stride A, B, C, D - K, K, 0, 0 // stride A, B, C, D - ); - - DeviceGemm gemm_op; - auto stream = at::cuda::getCurrentCUDAStream(); - CUTLASS_STATUS_CHECK(gemm_op.can_implement(args)); - CUTLASS_STATUS_CHECK(gemm_op(args, nullptr, stream)); -} - -// we will do input checks in python. A and B are stored as int8 -// this function is based on the following cutlass example -// https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu -// also with the help of emitted code from cutlass Python -torch::Tensor scaled_int4_mm_cutlass(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) { -#if defined(BUILD_INT4_MM_CUTLASS) - int M = A.size(0); - int N = B.size(1); - torch::Tensor C = torch::empty({M, N}, row_scale.options()); - - // some configs for int4 mma - // https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu - // using default config. this can be tuned. - using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; - using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; - constexpr int numStages = 3; - - AT_DISPATCH_SWITCH( - row_scale.scalar_type(), - "scaled_int4_mm_cutlass", - AT_DISPATCH_CASE( - torch::ScalarType::Half, - [&]() { - using ElementC = cutlass::half_t; - scaled_int4_mm_cutlass_dispatch< - ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>( - A, B, row_scale, col_scale, C); - } - ) - AT_DISPATCH_CASE( - torch::ScalarType::BFloat16, - [&]() { - using ElementC = cutlass::bfloat16_t; - scaled_int4_mm_cutlass_dispatch< - ElementC, ThreadblockShape, WarpShape, InstructionShape, numStages>( - A, B, row_scale, col_scale, C); - } - ) - ); - - return C; -#else - TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); - return at::Tensor{}; -#endif -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::int4_mm_cutlass", &int4_mm_cutlass); - m.impl("torchao::scaled_int4_mm_cutlass", &scaled_int4_mm_cutlass); -} - -} // namespace torchao
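
With the full series applied, the two CUTLASS kernels are exposed only through the renamed wrappers rowwise_scaled_linear_cutlass_s8s4 (W4A8) and rowwise_scaled_linear_cutlass_s4s4 (W4A4) in torchao/ops.py. The sketch below is illustrative only: the call pattern follows the operator docstrings above, but the shapes, the random int8 data standing in for properly quantized, nibble-packed int4 values, and the choice of bf16 scales are assumptions, and running it requires an SM80-class GPU with a CUTLASS-enabled torchao build.

import torch
from torchao.ops import (
    rowwise_scaled_linear_cutlass_s4s4,
    rowwise_scaled_linear_cutlass_s8s4,
)

M, N, K = 32, 64, 128  # K is the unpacked inner dimension

# Row-wise scales and bias; the output dtype follows the scale dtype (fp16/bf16).
x_scale = torch.rand(M, dtype=torch.bfloat16, device="cuda")
w_scale = torch.rand(N, dtype=torch.bfloat16, device="cuda")
bias = torch.zeros(N, dtype=torch.bfloat16, device="cuda")

# W4A8: int8 activations, int4 weights packed two per int8 byte -> K // 2 columns.
xq_s8 = torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda")
wq_s4 = torch.randint(-128, 128, (N, K // 2), dtype=torch.int8, device="cuda")
out_w4a8 = rowwise_scaled_linear_cutlass_s8s4(xq_s8, x_scale, wq_s4, w_scale, bias)

# W4A4: activations are also packed int4, so the input likewise has K // 2 columns.
xq_s4 = torch.randint(-128, 128, (M, K // 2), dtype=torch.int8, device="cuda")
out_w4a4 = rowwise_scaled_linear_cutlass_s4s4(xq_s4, x_scale, wq_s4, w_scale, bias)

print(out_w4a8.shape, out_w4a4.shape)  # both (M, N), dtype matching the scales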