From 8db0bc9b8e06396d82924b5b0597c74cbf2c2b6d Mon Sep 17 00:00:00 2001 From: ANIKET SHIVAM <3268307+ANIKET-SHIVAM@users.noreply.github.com> Date: Fri, 14 Apr 2023 20:19:34 -0700 Subject: [PATCH] CUTLASS 3.1 (#915) Co-authored-by: Aniket Shivam --- CHANGELOG.md | 15 + CMakeLists.txt | 25 +- README.md | 57 +- cmake/googletest.cmake | 2 +- .../turing_tensorop_conv2dfprop.cu | 1 + .../ampere_tensorop_conv2dfprop.cu | 358 +- .../22_quaternion_conv/quaternion_conv.cu | 3 +- examples/24_gemm_grouped/gemm_grouped.cu | 68 +- ...plex_gemm.cu => 29_3xtf32_complex_gemm.cu} | 2 +- .../CMakeLists.txt | 6 +- examples/39_gemm_permute/gemm_permute.cu | 1422 ++++--- examples/39_gemm_permute/layouts.h | 510 +++ examples/39_gemm_permute/permute_info.h | 344 ++ examples/40_cutlass_py/README.md | 11 +- examples/40_cutlass_py/conv2d.py | 47 +- examples/40_cutlass_py/customizable/conv2d.py | 66 +- examples/40_cutlass_py/customizable/gemm.py | 36 +- .../customizable/gemm_grouped.py | 32 +- examples/40_cutlass_py/gemm.py | 24 +- examples/40_cutlass_py/gemm_grouped.py | 26 +- examples/45_dual_gemm/device/dual_gemm.h | 4 +- examples/45_dual_gemm/dual_gemm.cu | 54 +- examples/45_dual_gemm/dual_gemm_run.h | 71 +- examples/45_dual_gemm/kernel/dual_gemm.h | 16 +- .../depthwise_simt_conv2dfprop.cu | 30 +- .../48_hopper_warp_specialized_gemm.cu | 26 +- .../49_collective_builder.cu} | 131 +- .../CMakeLists.txt | 7 +- .../50_hopper_gemm_with_epilogue_swizzle.cu | 22 +- examples/51_hopper_gett/gett_kernel.cuh | 15 +- examples/CMakeLists.txt | 2 +- examples/common/helper.h | 1 + examples/cute/tutorial/CMakeLists.txt | 1 - examples/python/00_basic_gemm.ipynb | 340 ++ examples/python/01_epilogue.ipynb | 202 + .../02_pytorch_extension_grouped_gemm.ipynb | 264 ++ examples/python/README.md | 14 + include/cute/algorithm/copy.hpp | 49 +- include/cute/algorithm/functional.hpp | 4 +- include/cute/algorithm/gemm.hpp | 68 +- include/cute/algorithm/prefer.hpp | 2 +- include/cute/algorithm/tensor_algorithms.hpp | 21 + include/cute/algorithm/tuple_algorithms.hpp | 59 +- include/cute/arch/cluster_sm90.hpp | 56 +- include/cute/arch/copy_sm75.hpp | 2 +- include/cute/arch/copy_sm90_desc.hpp | 27 +- include/cute/arch/copy_sm90_tma.hpp | 333 +- include/cute/arch/mma_sm90.hpp | 190 +- include/cute/arch/mma_sm90_desc.hpp | 12 +- include/cute/arch/mma_sm90_gmma.hpp | 3733 ++++++++++------- include/cute/arch/util.hpp | 44 +- include/cute/atom/copy_atom.hpp | 351 +- include/cute/atom/copy_traits.hpp | 57 +- include/cute/atom/copy_traits_sm90_tma.hpp | 632 ++- include/cute/atom/mma_atom.hpp | 240 +- include/cute/atom/mma_traits.hpp | 140 +- include/cute/atom/mma_traits_sm90_gmma.hpp | 932 ++-- include/cute/config.hpp | 43 +- include/cute/container/alignment.hpp | 4 +- include/cute/container/array.hpp | 92 +- include/cute/container/array_aligned.hpp | 240 +- include/cute/container/array_subbyte.hpp | 154 +- include/cute/container/array_view.hpp | 274 -- include/cute/container/bit_field.hpp | 2 +- include/cute/container/cuda_types.hpp | 175 + include/cute/container/tuple.hpp | 263 +- include/cute/container/type_list.hpp | 62 +- include/cute/int_tuple.hpp | 159 +- include/cute/layout.hpp | 175 +- include/cute/numeric/arithmetic_tuple.hpp | 58 +- include/cute/numeric/bfloat.hpp | 2 + include/cute/numeric/complex.hpp | 38 +- include/cute/numeric/int.hpp | 30 +- include/cute/numeric/integer_sequence.hpp | 86 +- include/cute/numeric/integer_subbyte.hpp | 6 +- include/cute/numeric/integral_constant.hpp | 142 +- include/cute/numeric/math.hpp | 56 +- include/cute/numeric/tfloat.hpp | 2 + include/cute/pointer.hpp | 9 +- include/cute/stride.hpp | 137 +- include/cute/swizzle.hpp | 2 + include/cute/swizzle_layout.hpp | 10 +- include/cute/swizzle_ptr.hpp | 8 +- include/cute/tensor.hpp | 138 +- include/cute/underscore.hpp | 10 +- include/cute/util/debug.hpp | 30 +- include/cute/util/print.hpp | 7 +- include/cute/util/type_traits.hpp | 146 +- include/cutlass/arch/arch.h | 1 - include/cutlass/arch/barrier.h | 102 +- include/cutlass/arch/memory.h | 4 +- include/cutlass/arch/mma.h | 6 + include/cutlass/barrier.h | 12 +- include/cutlass/cluster_launch.hpp | 93 +- include/cutlass/complex.h | 10 - include/cutlass/conv/conv2d_problem_size.h | 7 - include/cutlass/conv/convolution.h | 59 +- .../default_conv2d_fprop_with_broadcast.h | 2 +- .../default_conv2d_fprop_with_reduction.h | 2 +- .../conv/kernel/default_conv2d_group_fprop.h | 132 + ...activation_tile_access_iterator_analytic.h | 2 +- ...ctivation_tile_access_iterator_optimized.h | 2 +- ...rop_filter_tile_access_iterator_analytic.h | 2 +- ...op_filter_tile_access_iterator_optimized.h | 2 +- ...activation_tile_access_iterator_analytic.h | 2 +- ...ctivation_tile_access_iterator_optimized.h | 2 +- ...t_gradient_tile_access_iterator_analytic.h | 2 +- ..._gradient_tile_access_iterator_optimized.h | 2 +- ...activation_tile_access_iterator_analytic.h | 2 +- ...ctivation_tile_access_iterator_optimized.h | 2 +- ...t_gradient_tile_access_iterator_analytic.h | 2 +- ..._gradient_tile_access_iterator_optimized.h | 2 +- .../threadblock/implicit_gemm_multistage.h | 21 +- .../conv/threadblock/threadblock_swizzle.h | 4 +- include/cutlass/core_io.h | 2 +- include/cutlass/cutlass.h | 31 +- include/cutlass/detail/dependent_false.hpp | 86 + include/cutlass/device_kernel.h | 6 +- .../collective/builders/sm90_builder.inl | 536 +++ .../collective/collective_builder.hpp | 77 + .../collective/collective_epilogue.hpp | 12 +- .../epilogue/collective/default_epilogue.hpp | 48 +- .../cutlass/epilogue/collective/detail.hpp | 211 + ...ogue.hpp => epilogue_tensor_broadcast.hpp} | 164 +- ...logue.hpp => sm70_epilogue_vectorized.hpp} | 43 +- .../sm90_epilogue_tma_warpspecialized.hpp | 582 +++ ...e_tma_warpspecialized_bias_elementwise.hpp | 558 +++ include/cutlass/epilogue/dispatch_policy.hpp | 97 + .../cutlass/epilogue/thread/detail.hpp | 34 +- .../epilogue/thread/linear_combination.h | 102 +- .../linear_combination_bias_elementwise.h | 15 +- .../thread/linear_combination_clamp.h | 1 + .../thread/linear_combination_dgelu.h | 2 +- .../thread/linear_combination_drelu.h | 2 +- .../thread/linear_combination_generic.h | 1 + .../thread/linear_combination_leaky_relu.h | 1 + .../thread/linear_combination_params.h | 6 +- .../epilogue/thread/linear_combination_relu.h | 4 +- .../thread/linear_combination_relu0.h | 4 +- .../linear_combination_residual_block.h | 9 +- .../linear_combination_tensor_broadcast.hpp | 251 ++ .../linear_combination_with_elementwise.h | 2 +- .../threadblock/predicated_tile_iterator.h | 34 +- .../threadblock/shared_load_iterator_mixed.h | 24 +- .../warp/tile_iterator_tensor_op_mixed.h | 1 - include/cutlass/float8.h | 102 +- include/cutlass/functional.h | 11 + .../collective/builders/sm90_gmma_builder.inl | 670 ++- .../gemm/collective/collective_builder.hpp | 6 +- .../gemm/collective/collective_mma.hpp | 5 +- .../gemm/collective/sm70_mma_twostage.hpp | 24 +- .../gemm/collective/sm80_mma_multistage.hpp | 24 +- .../sm90_mma_multistage_gmma_ss.hpp | 46 +- .../sm90_mma_tma_gmma_rs_warpspecialized.hpp | 608 +++ .../gemm/collective/sm90_mma_tma_gmma_ss.hpp | 87 +- .../sm90_mma_tma_gmma_ss_warpspecialized.hpp | 155 +- include/cutlass/gemm/device/base_grouped.h | 2 +- .../gemm/device/default_gemm_configuration.h | 3 - include/cutlass/gemm/device/gemm_universal.h | 33 +- .../gemm/device/gemm_universal_adapter.h | 25 +- .../cutlass/gemm/device/gemm_universal_base.h | 4 + include/cutlass/gemm/dispatch_policy.hpp | 30 +- include/cutlass/gemm/gemm.h | 10 +- include/cutlass/gemm/kernel/default_gemm.h | 91 +- .../gemm/kernel/default_gemm_universal.h | 18 +- .../gemm/kernel/default_gemm_with_broadcast.h | 4 +- .../gemm/kernel/default_gemm_with_reduction.h | 4 +- include/cutlass/gemm/kernel/gemm_universal.h | 38 +- .../cutlass/gemm/kernel/gemm_universal.hpp | 3 +- .../gemm/kernel/gemm_with_fused_epilogue.h | 8 +- .../kernel/rank_2k_grouped_problem_visitor.h | 2 +- include/cutlass/gemm/kernel/sm70_gemm.hpp | 19 +- include/cutlass/gemm/kernel/sm90_gemm_tma.hpp | 48 +- .../kernel/sm90_gemm_tma_warpspecialized.hpp | 224 +- ...0_gemm_tma_warpspecialized_cooperative.hpp | 496 +++ ...m90_gemm_tma_warpspecialized_pingpong.hpp} | 323 +- .../gemm/kernel/sm90_tile_scheduler.hpp | 99 +- .../cutlass/gemm/threadblock/default_mma.h | 84 +- .../default_mma_core_with_access_size.h | 2 +- .../cutlass/gemm/threadblock/mma_multistage.h | 5 +- .../threadblock/threadblock_swizzle_streamk.h | 5 +- .../cutlass/gemm/warp/default_mma_tensor_op.h | 2 +- .../cutlass/gemm/warp/mma_complex_tensor_op.h | 5 + include/cutlass/layout/permute.h | 747 +++- include/cutlass/numeric_conversion.h | 26 +- include/cutlass/pipeline.hpp | 529 --- include/cutlass/pipeline/pipeline.hpp | 36 + include/cutlass/pipeline/sm90_pipeline.hpp | 989 +++++ include/cutlass/platform/platform.h | 2 + include/cutlass/relatively_equal.h | 16 + include/cutlass/semaphore.h | 4 +- .../collective/sm90_wgmma_transpose.hpp | 336 ++ .../predicated_tile_access_iterator.h | 125 +- .../threadblock/predicated_tile_iterator.h | 31 +- .../regular_tile_iterator_pitch_linear.h | 2 +- include/cutlass/uint128.h | 7 +- media/docs/cute/01_layout.md | 4 +- media/docs/cute/02_layout_operations.md | 8 +- media/docs/cute/0t_mma_atom.md | 4 +- media/docs/efficient_gemm.md | 17 +- media/docs/gemm_api_3x.md | 9 +- media/docs/implicit_gemm_convolution.md | 65 +- media/docs/pipeline.md | 2 +- media/docs/profiler.md | 9 +- media/docs/programming_guidelines.md | 68 +- media/docs/quickstart.md | 3 +- python/README.md | 180 + python/cutlass/__init__.py | 117 + python/cutlass/backend/__init__.py | 27 + .../cutlass/backend}/arguments.py | 55 +- .../cutlass/backend}/c_types.py | 174 +- .../cutlass/backend}/compiler.py | 245 +- .../cutlass/backend}/conv2d_operation.py | 367 +- .../cutlass/backend}/epilogue.py | 385 +- .../cutlass/backend}/frontend.py | 28 +- .../cutlass/backend}/gemm_operation.py | 1334 ++++-- python/cutlass/backend/library.py | 714 ++++ .../cutlass/backend}/memory_manager.py | 2 +- .../cutlass/backend}/operation.py | 48 +- .../cutlass/backend}/parser.py | 545 ++- .../cutlass/backend}/reduction_operation.py | 264 +- .../cutlass/backend}/tensor_ref.py | 37 +- .../cutlass/backend/test/__init__.py | 12 +- python/cutlass/backend/test/conv2d_testbed.py | 783 ++++ .../backend}/test/gemm_grouped_testbed.py | 153 +- .../cutlass/backend}/test/gemm_testbed.py | 480 ++- .../cutlass/backend}/test/profiler.py | 13 +- .../cutlass/backend}/test/utils.py | 84 +- .../cutlass/backend}/type_hint.py | 8 +- python/cutlass/backend/utils/__init__.py | 41 + .../cutlass/backend}/utils/datatypes.py | 100 +- .../cutlass/backend}/utils/device.py | 0 .../cutlass/backend}/utils/reference_model.py | 202 +- python/cutlass/backend/utils/software.py | 111 + .../src => python/cutlass}/cpp/compiler.h | 0 .../cutlass/cpp/cutlass_bindings.cpp | 4 +- .../src => python/cutlass}/cpp/include/arch.h | 0 .../cpp/include/conv/conv_problem_size.h | 0 .../cutlass}/cpp/include/conv/convolution.h | 0 .../cutlass}/cpp/include/conv/host.h | 0 .../epilogue/epilogue_visitor_generic.h | 0 .../epilogue/epilogue_visitor_op/binary_ops.h | 0 .../epilogue/epilogue_visitor_op/unary_ops.h | 0 .../visitor_op_accumulator.h | 0 .../epilogue_visitor_op/visitor_op_binary.h | 0 .../visitor_op_column_broadcast.h | 0 .../visitor_op_column_reduction.h | 0 .../visitor_op_linear_combination.h | 0 .../visitor_op_row_broadcast.h | 0 .../visitor_op_row_reduction.h | 0 .../visitor_op_tensor_input.h | 0 .../visitor_op_tensor_output.h | 0 .../epilogue_visitor_op/visitor_op_unary.h | 0 .../epilogue_visitor_with_layernorm.h | 0 .../cutlass}/cpp/include/gemm/gemm.h | 0 .../gemm/gemm_universal_with_visitor.h | 10 + .../cutlass}/cpp/include/gemm/host.h | 0 .../cutlass}/cpp/include/layout/layout.h | 0 .../cutlass}/cpp/include/layout/matrix.h | 0 .../cutlass}/cpp/include/layout/tensor.h | 0 .../cutlass}/cpp/include/swizzling.h | 25 +- .../cutlass}/cpp/include/tensor_coord.h | 0 .../cutlass}/cpp/include/tensor_ref_view.h | 0 .../cutlass}/cpp/include/types.h | 0 .../src => python/cutlass}/cpp/library.h | 0 .../cutlass}/cpp/test/conv/conv_problems.h | 0 .../cutlass}/cpp/test/conv/convolution.h | 0 .../cutlass}/cpp/test/conv/host.h | 2 +- .../cutlass}/cpp/test/gemm/gemm.h | 0 .../cutlass}/cpp/test/gemm/host.h | 0 .../cutlass/emit/__init__.py | 4 +- python/cutlass/emit/common.py | 182 + python/cutlass/emit/pytorch.py | 639 +++ python/cutlass/epilogue.py | 107 + python/cutlass/library_defaults.py | 445 ++ .../cutlass/op/__init__.py | 9 +- python/cutlass/op/gemm.py | 696 +++ python/cutlass/op/gemm_grouped.py | 270 ++ python/cutlass/op/op.py | 116 + .../Makefile => python/cutlass/swizzle.py | 48 +- python/cutlass/utils/__init__.py | 40 + python/cutlass/utils/check.py | 192 + python/cutlass/utils/datatypes.py | 339 ++ .../docker/Dockerfile-cuda11.8-pytorch | 2 +- python/docker/Dockerfile-cuda12.0-pytorch | 38 + python/docs_src/Makefile | 20 + .../docs => python/docs_src}/make.bat | 8 +- .../source/_static/cutlass-logo-small.png | Bin 0 -> 1488 bytes .../source/_static/logo-dark-mode.png | Bin 0 -> 50546 bytes .../source/_static/logo-light-mode.png | Bin 0 -> 48816 bytes python/docs_src/source/_templates/layout.html | 94 + python/docs_src/source/conf.py | 100 + python/docs_src/source/contribute.md | 9 + python/docs_src/source/cutlass.emit.rst | 18 + python/docs_src/source/cutlass.op.rst | 26 + python/docs_src/source/cutlass.rst | 36 + python/docs_src/source/cutlass.utils.rst | 18 + python/docs_src/source/examples.rst | 9 + .../source/externals/00_basic_gemm.nblink | 3 + .../source/externals/01_epilogue.nblink | 3 + .../02_pytorch_extension_grouped_gemm.nblink | 3 + python/docs_src/source/index.rst | 55 + python/docs_src/source/install.md | 37 + python/docs_src/source/modules.rst | 7 + python/setup.py | 106 + .../python/backend}/conv/__init__.py | 0 ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py | 110 +- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py | 82 +- ...m_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py | 48 +- ...hwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py | 46 +- ...nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py | 143 +- ...nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py | 133 +- ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py | 226 +- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py | 28 +- ...m_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py | 48 +- ...hwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py | 60 +- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py | 142 +- ...nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py | 46 +- ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py | 128 +- ...m_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py | 48 +- ...hwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py | 60 +- .../python/backend}/conv/run_all_tests.py | 6 +- .../python/backend}/gemm/__init__.py | 0 .../python/backend}/gemm/gemm_bf16_sm80.py | 40 +- .../python/backend}/gemm/gemm_bf16_sm90.py | 44 +- .../python/backend}/gemm/gemm_f16_sm80.py | 152 +- .../python/backend}/gemm/gemm_f16_sm90.py | 114 +- .../python/backend}/gemm/gemm_f32_sm80.py | 58 +- .../python/backend}/gemm/gemm_f64_sm80.py | 40 +- .../python/backend}/gemm/gemm_f64_sm90.py | 34 +- .../python/backend}/gemm/gemm_grouped_sm80.py | 76 +- .../python/backend}/gemm/gemm_s8_sm80.py | 78 +- .../python/backend}/gemm/gemm_s8_sm90.py | 58 +- .../python/backend}/gemm/run_all_tests.py | 4 +- test/python/emit/pytorch.py | 161 + test/python/gemm/gemm_f16_sm80.py | 167 + test/python/gemm/gemm_f16_sm90.py | 173 + test/python/gemm/gemm_f32_sm80.py | 155 + test/python/gemm/gemm_f64_sm80.py | 156 + test/python/gemm/gemm_f64_sm90.py | 142 + test/python/gemm/gemm_s8_sm80.py | 156 + test/python/gemm/gemm_s8_sm90.py | 155 + test/python/gemm/run_all_tests.py | 42 + test/python/interface/gemm_interface.py | 354 ++ test/unit/CMakeLists.txt | 2 + test/unit/cluster_launch/CMakeLists.txt | 32 + test/unit/cluster_launch/cluster_launch.cu | 370 ++ test/unit/conv/device/CMakeLists.txt | 1 + test/unit/conv/device/conv2d_problems.h | 2 - test/unit/conv/device/conv2d_testbed.h | 2 +- .../conv/device/conv2d_testbed_interleaved.h | 3 +- .../device/conv2d_with_broadcast_testbed.h | 2 +- .../device/conv2d_with_reduction_testbed.h | 2 +- test/unit/conv/device/conv3d_testbed.h | 2 +- ...nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu | 104 +- test/unit/cute/CMakeLists.txt | 3 + test/unit/cute/core/CMakeLists.txt | 2 + test/unit/cute/core/array_subbyte.cpp | 114 + test/unit/cute/core/compact_xmajor.cpp | 231 + test/unit/cute/hopper/CMakeLists.txt | 15 + test/unit/cute/hopper/bulk_load.cu | 196 + test/unit/cute/hopper/bulk_store.cu | 178 + test/unit/cute/hopper/stsm.cu | 4 +- test/unit/cute/hopper/tma_load.cu | 580 ++- test/unit/cute/hopper/tma_store.cu | 524 +-- .../unit/cute/msvc_compilation/CMakeLists.txt | 33 + test/unit/cute/msvc_compilation/tuple.cpp | 161 + test/unit/gemm/device/CMakeLists.txt | 27 +- .../device/default_gemm_configuration.hpp | 63 +- test/unit/gemm/device/gemm_grouped_sm80.cu | 124 +- .../gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu | 32 + .../gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu | 20 + .../gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu | 32 + .../gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu | 20 + .../gemm_s8t_s8n_f16t_tensor_op_s32_sm80.cu | 77 + .../gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu | 18 + .../gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu | 20 + .../gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu | 18 + .../gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu | 20 + test/unit/gemm/device/gemm_testbed_3x.hpp | 599 ++- .../gemm_testbed_3x_tensor_broadcast.hpp | 488 +++ ...emm_bf16_bf16_bf16_alignx_tensor_op_f32.cu | 61 +- .../sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu | 61 +- .../sm90_gemm_f16_f16_f16_alignx_tensor_op.cu | 181 +- .../device/sm90_gemm_f16_f16_f16_tensor_op.cu | 640 ++- ...f16_tensor_op_f32_cluster_unspecialized.cu | 209 +- ...6_tensor_op_f32_cluster_warpspecialized.cu | 209 +- ...f32_cluster_warpspecialized_cooperative.cu | 850 ++++ ...pecialized_cooperative_bias_elementwise.cu | 366 ++ ...p_f32_cluster_warpspecialized_pingpong.cu} | 447 +- ...rpspecialized_pingpong_bias_elementwise.cu | 365 ++ ..._f16_f16_tensor_op_f32_tensor_broadcast.cu | 298 ++ .../sm90_gemm_f32_f32_f32_tensor_op_f32.cu | 15 +- ..._f32_f32_tensor_op_f32_tensor_broadcast.cu | 102 + ...sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu | 46 +- .../sm90_gemm_s8_s8_s8_tensor_op_s32.cu | 91 +- ...s8_s8_s8_tensor_op_s32_tensor_broadcast.cu | 102 + ...gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu | 46 +- .../sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu | 61 +- ..._op_f32_gmma_rs_cluster_warpspecialized.cu | 566 +++ test/unit/gemm/device/testing_elementwise.hpp | 81 + test/unit/gemm/warp/wmma_sm72.cu | 2 +- test/unit/pipeline/pipeline_async.cu | 12 +- test/unit/pipeline/pipeline_tma_async.cu | 30 +- .../pipeline_tma_async_warp_specialized.cu | 6 +- ...e_tma_async_warp_specialized_persistent.cu | 12 +- test/unit/pipeline/sequence_barrier.cu | 6 +- test/unit/substrate/CMakeLists.txt | 33 + test/unit/substrate/dependent_false.cpp | 88 + tools/CMakeLists.txt | 2 +- tools/library/CMakeLists.txt | 7 +- .../include/cutlass/library/arch_mappings.h | 6 + .../library/include/cutlass/library/handle.h | 22 +- .../library/include/cutlass/library/library.h | 15 +- .../include/cutlass/library/operation_table.h | 46 +- tools/library/scripts/gemm_operation.py | 67 +- tools/library/scripts/generator.py | 558 ++- tools/library/scripts/library.py | 52 + tools/library/scripts/pycutlass/README.md | 143 - .../pycutlass/docker/Dockerfile-cuda12.0 | 46 - .../scripts/pycutlass/docs/source/conf.py | 96 - .../pycutlass/docs/source/conv2d_op.rst | 13 - .../scripts/pycutlass/docs/source/cutlass.rst | 100 - .../scripts/pycutlass/docs/source/gemm_op.rst | 18 - .../scripts/pycutlass/docs/source/index.rst | 31 - .../docs/source/md/EpilogueVisitorTree.md | 225 - .../pycutlass/docs/source/md/basic_idea.md | 283 -- .../pycutlass/docs/source/user_guide.rst | 4 - .../pycutlass/docs/source/visitor_tree.rst | 4 - .../pycutlass/profile/conv/conv2d_f16_sm80.py | 106 - .../pycutlass/profile/gemm/gemm_f32_sm80.py | 91 - .../library/scripts/pycutlass/pyproject.toml | 9 - tools/library/scripts/pycutlass/setup.py | 116 - .../pycutlass/src/pycutlass/__init__.py | 55 - .../builder/collective_op_builder.py | 395 -- .../pycutlass/src/pycutlass/library.py | 870 ---- .../pycutlass/src/pycutlass/test/__init__.py | 4 - .../src/pycutlass/test/conv2d_testbed.py | 632 --- .../pycutlass/src/pycutlass/utils/__init__.py | 1 - .../test/conv/cached_results_SM80.txt | 274 -- .../pycutlass/test/example/run_all_example.sh | 112 - .../pycutlass/test/frontend/test_frontend.py | 154 - .../test/unit/cached_results_SM80_2080.txt | 363 -- .../scripts/pycutlass/test/unit/test_sm80.py | 464 -- tools/library/src/gemm_operation.h | 21 +- tools/library/src/gemm_operation_3x.hpp | 41 +- tools/library/src/handle.cu | 43 +- tools/library/src/library_internal.h | 8 + tools/library/src/manifest.cpp | 2 +- tools/library/src/operation_table.cu | 5 +- .../reduction/init_reduction_operations.cu | 1 - .../library/src/reduction/reduction_device.cu | 21 +- .../src/reference/conv_reference_operation.h | 1 + tools/library/src/reference/gemm.cu | 367 ++ .../src/reference/gemm_reference_operation.h | 85 +- tools/library/src/util.cu | 64 + tools/profiler/CMakeLists.txt | 2 +- .../profiler/src/conv2d_operation_profiler.cu | 14 +- .../profiler/src/conv3d_operation_profiler.cu | 11 +- tools/profiler/src/cublas_helpers.cu | 13 +- tools/profiler/src/cudnn_helpers.cpp | 2 +- tools/profiler/src/cudnn_helpers.h | 2 +- tools/profiler/src/device_allocation.cu | 76 + tools/profiler/src/device_context.cu | 20 +- tools/profiler/src/device_context.h | 10 +- tools/profiler/src/gemm_operation_profiler.cu | 32 +- tools/profiler/src/gemm_operation_profiler.h | 3 +- tools/profiler/src/options.cu | 3 +- .../src/rank_2k_operation_profiler.cu | 13 +- .../profiler/src/rank_k_operation_profiler.cu | 9 +- .../src/sparse_gemm_operation_profiler.cu | 20 +- tools/profiler/src/symm_operation_profiler.cu | 15 +- tools/profiler/src/trmm_operation_profiler.cu | 12 +- tools/util/CMakeLists.txt | 2 +- .../include/cutlass/util/gett_commandline.hpp | 10 +- .../util/reference/device/gemm_complex.h | 17 +- .../util/reference/device/tensor_fill.h | 11 + .../cutlass/util/reference/host/convolution.h | 34 +- .../util/reference/host/gemm_complex.h | 10 +- .../cutlass/util/reference/host/gett.hpp | 96 +- .../util/reference/host/tensor_compare.h | 136 +- .../include/cutlass/util/tensor_view_io.h | 16 +- 482 files changed, 37001 insertions(+), 16236 deletions(-) rename examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/{29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu => 29_3xtf32_complex_gemm.cu} (99%) create mode 100644 examples/39_gemm_permute/layouts.h create mode 100644 examples/39_gemm_permute/permute_info.h rename examples/{49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu => 49_hopper_gemm_with_collective_builder/49_collective_builder.cu} (78%) rename examples/{49_hopper_gemm_schedules_with_collective_builder => 49_hopper_gemm_with_collective_builder}/CMakeLists.txt (93%) create mode 100644 examples/python/00_basic_gemm.ipynb create mode 100644 examples/python/01_epilogue.ipynb create mode 100644 examples/python/02_pytorch_extension_grouped_gemm.ipynb create mode 100644 examples/python/README.md delete mode 100644 include/cute/container/array_view.hpp create mode 100644 include/cute/container/cuda_types.hpp create mode 100644 include/cutlass/detail/dependent_false.hpp create mode 100644 include/cutlass/epilogue/collective/builders/sm90_builder.inl create mode 100644 include/cutlass/epilogue/collective/collective_builder.hpp create mode 100644 include/cutlass/epilogue/collective/detail.hpp rename include/cutlass/epilogue/collective/{default_transposed_epilogue.hpp => epilogue_tensor_broadcast.hpp} (50%) rename include/cutlass/epilogue/collective/{epilogue.hpp => sm70_epilogue_vectorized.hpp} (92%) create mode 100644 include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp create mode 100644 include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp rename tools/library/scripts/pycutlass/src/cpp/cute.cpp => include/cutlass/epilogue/thread/detail.hpp (72%) create mode 100644 include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp create mode 100644 include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp create mode 100644 include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp rename include/cutlass/gemm/kernel/{sm90_gemm_tma_warpspecialized_persistent.hpp => sm90_gemm_tma_warpspecialized_pingpong.hpp} (55%) delete mode 100644 include/cutlass/pipeline.hpp create mode 100644 include/cutlass/pipeline/pipeline.hpp create mode 100644 include/cutlass/pipeline/sm90_pipeline.hpp create mode 100644 include/cutlass/transform/collective/sm90_wgmma_transpose.hpp create mode 100644 python/README.md create mode 100644 python/cutlass/__init__.py create mode 100644 python/cutlass/backend/__init__.py rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/arguments.py (79%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/c_types.py (63%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/compiler.py (66%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/conv2d_operation.py (63%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/epilogue.py (87%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/frontend.py (85%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/gemm_operation.py (52%) create mode 100644 python/cutlass/backend/library.py rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/memory_manager.py (100%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/operation.py (80%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/parser.py (62%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/reduction_operation.py (62%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/tensor_ref.py (83%) rename tools/library/scripts/pycutlass/build.sh => python/cutlass/backend/test/__init__.py (88%) create mode 100644 python/cutlass/backend/test/conv2d_testbed.py rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/test/gemm_grouped_testbed.py (66%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/test/gemm_testbed.py (56%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/test/profiler.py (91%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/test/utils.py (64%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/type_hint.py (89%) create mode 100644 python/cutlass/backend/utils/__init__.py rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/utils/datatypes.py (58%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/utils/device.py (100%) rename {tools/library/scripts/pycutlass/src/pycutlass => python/cutlass/backend}/utils/reference_model.py (64%) create mode 100644 python/cutlass/backend/utils/software.py rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/compiler.h (100%) rename tools/library/scripts/pycutlass/src/cpp/cutlass.cpp => python/cutlass/cpp/cutlass_bindings.cpp (98%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/arch.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/conv/conv_problem_size.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/conv/convolution.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/conv/host.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_generic.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/epilogue/epilogue_visitor_with_layernorm.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/gemm/gemm.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/gemm/gemm_universal_with_visitor.h (99%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/gemm/host.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/layout/layout.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/layout/matrix.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/layout/tensor.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/swizzling.h (93%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/tensor_coord.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/tensor_ref_view.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/include/types.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/library.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/test/conv/conv_problems.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/test/conv/convolution.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/test/conv/host.h (99%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/test/gemm/gemm.h (100%) rename {tools/library/scripts/pycutlass/src => python/cutlass}/cpp/test/gemm/host.h (100%) rename tools/library/scripts/pycutlass/test/frontend/run_test.sh => python/cutlass/emit/__init__.py (94%) create mode 100644 python/cutlass/emit/common.py create mode 100644 python/cutlass/emit/pytorch.py create mode 100644 python/cutlass/epilogue.py create mode 100644 python/cutlass/library_defaults.py rename tools/library/scripts/pycutlass/build_doc.sh => python/cutlass/op/__init__.py (90%) create mode 100644 python/cutlass/op/gemm.py create mode 100644 python/cutlass/op/gemm_grouped.py create mode 100644 python/cutlass/op/op.py rename tools/library/scripts/pycutlass/docs/Makefile => python/cutlass/swizzle.py (60%) create mode 100644 python/cutlass/utils/__init__.py create mode 100644 python/cutlass/utils/check.py create mode 100644 python/cutlass/utils/datatypes.py rename {tools/library/scripts/pycutlass => python}/docker/Dockerfile-cuda11.8-pytorch (96%) create mode 100644 python/docker/Dockerfile-cuda12.0-pytorch create mode 100644 python/docs_src/Makefile rename {tools/library/scripts/pycutlass/docs => python/docs_src}/make.bat (94%) create mode 100644 python/docs_src/source/_static/cutlass-logo-small.png create mode 100644 python/docs_src/source/_static/logo-dark-mode.png create mode 100644 python/docs_src/source/_static/logo-light-mode.png create mode 100644 python/docs_src/source/_templates/layout.html create mode 100644 python/docs_src/source/conf.py create mode 100644 python/docs_src/source/contribute.md create mode 100644 python/docs_src/source/cutlass.emit.rst create mode 100644 python/docs_src/source/cutlass.op.rst create mode 100644 python/docs_src/source/cutlass.rst create mode 100644 python/docs_src/source/cutlass.utils.rst create mode 100644 python/docs_src/source/examples.rst create mode 100644 python/docs_src/source/externals/00_basic_gemm.nblink create mode 100644 python/docs_src/source/externals/01_epilogue.nblink create mode 100644 python/docs_src/source/externals/02_pytorch_extension_grouped_gemm.nblink create mode 100644 python/docs_src/source/index.rst create mode 100644 python/docs_src/source/install.md create mode 100644 python/docs_src/source/modules.rst create mode 100644 python/setup.py rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/__init__.py (100%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py (64%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py (69%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py (73%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py (74%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py (52%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py (58%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py (53%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py (79%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py (73%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py (69%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py (62%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py (75%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py (64%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py (73%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py (69%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/conv/run_all_tests.py (94%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/__init__.py (100%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_bf16_sm80.py (74%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_bf16_sm90.py (77%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_f16_sm80.py (68%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_f16_sm90.py (58%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_f32_sm80.py (72%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_f64_sm80.py (75%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_f64_sm90.py (82%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_grouped_sm80.py (71%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_s8_sm80.py (71%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/gemm_s8_sm90.py (73%) rename {tools/library/scripts/pycutlass/test => test/python/backend}/gemm/run_all_tests.py (96%) create mode 100644 test/python/emit/pytorch.py create mode 100644 test/python/gemm/gemm_f16_sm80.py create mode 100644 test/python/gemm/gemm_f16_sm90.py create mode 100644 test/python/gemm/gemm_f32_sm80.py create mode 100644 test/python/gemm/gemm_f64_sm80.py create mode 100644 test/python/gemm/gemm_f64_sm90.py create mode 100644 test/python/gemm/gemm_s8_sm80.py create mode 100644 test/python/gemm/gemm_s8_sm90.py create mode 100644 test/python/gemm/run_all_tests.py create mode 100644 test/python/interface/gemm_interface.py create mode 100644 test/unit/cluster_launch/CMakeLists.txt create mode 100644 test/unit/cluster_launch/cluster_launch.cu create mode 100644 test/unit/cute/core/array_subbyte.cpp create mode 100644 test/unit/cute/core/compact_xmajor.cpp create mode 100644 test/unit/cute/hopper/bulk_load.cu create mode 100644 test/unit/cute/hopper/bulk_store.cu create mode 100644 test/unit/cute/msvc_compilation/CMakeLists.txt create mode 100644 test/unit/cute/msvc_compilation/tuple.cpp create mode 100644 test/unit/gemm/device/gemm_s8t_s8n_f16t_tensor_op_s32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp create mode 100644 test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu create mode 100644 test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu rename test/unit/gemm/device/{sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu => sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu} (71%) create mode 100644 test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu create mode 100644 test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu create mode 100644 test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu create mode 100644 test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu create mode 100644 test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu create mode 100644 test/unit/gemm/device/testing_elementwise.hpp create mode 100644 test/unit/substrate/CMakeLists.txt create mode 100644 test/unit/substrate/dependent_false.cpp delete mode 100644 tools/library/scripts/pycutlass/README.md delete mode 100644 tools/library/scripts/pycutlass/docker/Dockerfile-cuda12.0 delete mode 100644 tools/library/scripts/pycutlass/docs/source/conf.py delete mode 100644 tools/library/scripts/pycutlass/docs/source/conv2d_op.rst delete mode 100644 tools/library/scripts/pycutlass/docs/source/cutlass.rst delete mode 100644 tools/library/scripts/pycutlass/docs/source/gemm_op.rst delete mode 100644 tools/library/scripts/pycutlass/docs/source/index.rst delete mode 100644 tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md delete mode 100644 tools/library/scripts/pycutlass/docs/source/md/basic_idea.md delete mode 100644 tools/library/scripts/pycutlass/docs/source/user_guide.rst delete mode 100644 tools/library/scripts/pycutlass/docs/source/visitor_tree.rst delete mode 100644 tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py delete mode 100644 tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py delete mode 100644 tools/library/scripts/pycutlass/pyproject.toml delete mode 100644 tools/library/scripts/pycutlass/setup.py delete mode 100644 tools/library/scripts/pycutlass/src/pycutlass/__init__.py delete mode 100644 tools/library/scripts/pycutlass/src/pycutlass/builder/collective_op_builder.py delete mode 100644 tools/library/scripts/pycutlass/src/pycutlass/library.py delete mode 100644 tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py delete mode 100644 tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py delete mode 100644 tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py delete mode 100644 tools/library/scripts/pycutlass/test/conv/cached_results_SM80.txt delete mode 100755 tools/library/scripts/pycutlass/test/example/run_all_example.sh delete mode 100644 tools/library/scripts/pycutlass/test/frontend/test_frontend.py delete mode 100644 tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt delete mode 100644 tools/library/scripts/pycutlass/test/unit/test_sm80.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8744e0a6..9eeb34c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,21 @@ # NVIDIA CUTLASS Changelog +## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14) +* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python). +* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) for FP16 datatype using TMA for Hopper. +* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues. +* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. +* New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper. +* An [example](examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper. +* Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization. +* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler. +* Performance optimizations for the [*warp-specialized persistent ping-pong*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel. +* Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs. +* The GitHub branch is renamed from `master` to `main` in this release. +* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads) +* Updates and bugfixes from the community (thanks!) + ## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23) * [CuTe](/media/docs/cute/00_quickstart.md), a [new core library and backend](/include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors. * [A new conceptual operation hierarchy](media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](/media/docs/gemm_api_3x.md). diff --git a/CMakeLists.txt b/CMakeLists.txt index e879f780..1136d095 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +cmake_minimum_required(VERSION 3.19 FATAL_ERROR) +cmake_policy(SET CMP0112 NEW) if(cutlass_LOADED) # If CUTLASS has been previously fetched and loaded, don't do it again. @@ -39,7 +40,7 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") -project(CUTLASS VERSION 3.0.0 LANGUAGES CXX) +project(CUTLASS VERSION 3.1.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 11.3) @@ -124,6 +125,17 @@ endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.") +# Find unsupported and deprecated compute capabilities +if (CUTLASS_NVCC_ARCHS_SUPPORTED) + set(CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS}) + list(REMOVE_ITEM CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS_SUPPORTED}) + if (CUTLASS_NVCC_ARCHS_UNSUPPORTED) + message(WARNING "Using unsupported or deprecated compute capabilities ${CUTLASS_NVCC_ARCHS_UNSUPPORTED}. Support may be removed in future versions.") + endif() +else() + message(WARNING "No supported compute capabilities for CUDA ${CUDA_VERSION}.") +endif() + # Special policy introduced in CMake 3.13 if (POLICY CMP0076) cmake_policy(SET CMP0076 NEW) @@ -287,9 +299,10 @@ if (CUTLASS_ENABLE_OPENMP_TESTS) message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.") endif() endif() - -list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-Wconversion>) -list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-fno-strict-aliasing>) +if(UNIX) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing) +endif() # Don't leak lineinfo in release builds if (NOT CMAKE_BUILD_TYPE MATCHES "Release") @@ -838,3 +851,5 @@ install( ################################################################################ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassPackageConfig.cmake) + + diff --git a/README.md b/README.md index 79c11b3e..f7bad67d 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.0 +# CUTLASS 3.1 -_CUTLASS 3.0 - January 2023_ +_CUTLASS 3.1 - April 2023_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -31,33 +31,37 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. See the [functionality listing](/media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. -CUTLASS 3.0 introduces a new core library, CuTe, to describe and manipulate tensors of threads and data. +CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data. CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations. The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning. -CUTLASS 3.0 adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design +CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](/media/docs/cute/00_quickstart.md). In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -# What's New in CUTLASS 3.0 +# What's New in CUTLASS 3.1 -CUTLASS 3.0, as the next major version of the CUTLASS API, brings with it CuTe, a new programming model and backend designed for massively parallel heterogenous agents. Using CuTe, CUTLASS 3.0 provides implementations of GEMM kernels for the NVIDIA Hopper architecture. +CUTLASS 3.1 is an update to CUTLASS adding: -- [CuTe-based layouts and layout algebra](/media/docs/cute/00_quickstart.md) -- [A new GEMM template API](/media/docs/gemm_api_3x.md) that eschews the architecture-centric hierarchy of 2.x in favour of a new conceptual framing. Read more in the [3.0 design documentation](/media/docs/cutlass_3x_design.md). -- Support for 4th generation Hopper Tensor Core instructions (WGMMA) through CuTe. -- Support for Hopper asynchronous Tensor Memory Accelerator (TMA) instructions and associated transaction barriers through CuTe. -- New warp-specialized GEMM kernels targeting Hopper TMA + WGMMA for speed-of-light GEMMs. -- New warp-specialized persistent GEMM kernels targeting Hopper TMA + WGMMA. -- Support for CUDA Threadblock Clusters and programmatic TMA multicast for greater execution and data locality. -- A new way to instantiate default GEMM kernels using `CollectiveBuilder`s that supersede the 2.x `DefaultXConfiguration` types in favour a metaprogramming based kernel generator functionality. See [example 49](/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu). -- Extensions to the CUTLASS library and profiler to support CUTLASS 3.0 Hopper kernels, and a new format -for kernel procedural names. -- *Announcement*: CUTLASS plans to rename the GitHub branch `master` to `main` with a future release. +- New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python). +- New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) for FP16 datatype using TMA for Hopper. +- Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues. +- New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. +- New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that improves performance on Hopper. +- An [example](examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper. +- New Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization. +- Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler. +- Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs. +- *Announcement*: + - The GitHub branch is renamed from `master` to `main` in this release. + - A slight modification has been made to the ordering of arguments passed in to epilogues in 3.x kernels. + Existing CUTLASS 3.x kernel invocations will need to be modified to reflect this change. 2.x kernels + remain unaffected. See [#890](https://github.com/NVIDIA/cutlass/issues/890) for additional information. + - The CUTLASS Python interface supersedes PyCUTLASS. PyCUTLASS has been moved to [/python/cutlass/backend](/python/cutlass/backend). + Backward compatibility between the Python interface and PyCUTLASS will not be maintained moving forward. -## New architecture, compiler, and CUDA Toolkit requirements Minimum requirements: @@ -65,7 +69,7 @@ Minimum requirements: - Compiler: Must support at least C++17 - CUDA Toolkit version: 11.4 -CUTLASS 3.0 *removes support* for the following: +Starting from CUTLASS 3.0, CUTLASS removed support for the following: - Maxwell and Pascal GPU architectures - Ubuntu 16.04 @@ -87,20 +91,21 @@ an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) (NVIDIA Ampere and an [NVIDIA A40](https://www.nvidia.com/en-us/data-center/a40/) (NVIDIA Ampere architecture). CUTLASS 3.0 was compiled with the [CUDA 12.0 Toolkit](https://developer.nvidia.com/cuda-downloads). Tensor Core operations are implemented using CUDA's -[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma). +[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and +[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions.

When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad) kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) -as shown in the above figure. Tensor Core operations are still implemented using CUDA's +as shown in the above figure. Tensor Core operations are implemented using CUDA's [mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma). # Compatibility CUTLASS requires a C++17 host compiler and -performs best when built with the [**CUDA 12.0 Toolkit**](https://developer.nvidia.com/cuda-toolkit). -It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, and CUDA 11.8. +performs best when built with the [**CUDA 12.1 Toolkit**](https://developer.nvidia.com/cuda-toolkit). +It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and CUDA 12.0. ## Operating Systems We have tested the following environments. @@ -112,6 +117,7 @@ We have tested the following environments. | Ubuntu 22.04 | GCC 11.2.0 | Note: We plan to add Windows (MSVC) & Clang compiler support soon. +Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended. ## Hardware CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs. @@ -131,9 +137,9 @@ CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be ## Target Architecture -In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduces the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability). +In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability). -The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12.0 or 11.8, the kernel is expected to fail with a runtime error. +The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CTK 12 or 11.8, the kernel is expected to fail with a runtime error. ``` cmake .. -DCUTLASS_NVCC_ARCHS="90a" @@ -558,4 +564,3 @@ SPDX-License-Identifier: BSD-3-Clause OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ``` - diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 85edc807..a16231a1 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -9,7 +9,7 @@ endif() FetchContent_Declare( googletest GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG 0fe9660 + GIT_TAG v1.13.0 ) FetchContent_GetProperties(googletest) diff --git a/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu b/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu index bfa4f8f3..ade0b979 100644 --- a/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu +++ b/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu @@ -555,6 +555,7 @@ Result profile_convolution(Options const &options) { LayoutOutput, ElementComputeEpilogue, ElementAccumulator, + ElementOutput, cutlass::NumericConverterClamp >( problem_size, diff --git a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu index b30d9086..11ece8a6 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu +++ b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu @@ -31,83 +31,181 @@ /** -This example shows how to run convolution kernels using functions and data structures -provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU. - -Writing a single high performance convolution kernel is hard but do-able. Whereas writing -high performance kernels at scale which works for multiple problem sizes with good abstractions is -really hard. CUTLASS solves this problem by providing simplified abstractions to compose -multiple sections of implicit gemm kernel. When used properly, the kernels can hit peak performance -of GPU easily. - -CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -and thread-block level, they compute on their own tile-size with higher level of tile sizes being -composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -threadblock-tile (tile size computed by a threadblock). - -In thie example, we split variable initialization into -1. Setting up data properties : describes how tensors are laid out in the memory and how the kernel -can view them (logical to physical mapping) -2. Setting up computation properties : describes how the above set tensors will be used to compute -output of convolution. - -First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along -with alpha, beta as the equation for convolution is C = alpha * Conv2dFprop(A, B) + beta * C. In CUTLASS, -the kernels first compute Conv2dFprop(A, B) and leave the rest of the computation to end of the kernel as -alpha * X + beta * C is a simple element-wise operation on X (Conv2dFprop(A, B)) and C. We call this as -epilogue of kernel. Hence, we setup data types for alpha and beta to be equal to -ElementComputeEpilogue = float. We use the data type for elements in input tensor A and B as -cutlass::half_t. We convey this to CUTLASS kernel by initializing template variables ElementAccumulator (float), -ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), ElementInputB (cutlass::half_t), -ElementOutput (float). Communicating just the data type is not enough. As the data is laid out -linearly in memory, we have to convey the layout of tensors. We do that by initializing template -variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup -rules to comptue alpha * X + beta * C which is called epilogue of the kernel. We initialize template -variable EpilogueOp, which takes the data type of output ElementOutput (float), the number of -elements per vector memory access (8), data type of accumulator (float) and data type of -computation of linear combination (alpha * X + beta * C). - -Now that we setup the properties of data, we have to setup properties of computation. - -Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x64, -64x64x64, 16x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it -internally deduces the amount of threads needed per thread-block, amount of shared memory, storing -data in bank-conflict free manner, and ton of other variables required to compose, initialize and -launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer -from understanding and coding complicated hardware optimizations which can easily go wrong. - -CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines -constitute the whole process of loading input data from global memory to shared memory, loading data -from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -sequence shows a typical mma multistage pipeline. -(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h) - -tensor in global memory --cp_async--> tile in shared memory --smem loads--> registers ---mma--> registers --global stores--> output to global memory - -NVIDIA Ampere uses `cp_async` to build multistage software pipeline to better hide latencies. - - -There are few more template variables initialized such as, which threadblock tile of output matrix -is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. - -These are all put together to create a template variable which describes CUTLASS Implicit GEMM -kernel using cutlass::conv::device::ImplicitGemm template. - -The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it. -We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come -in the way of learning CUTLASS. - -Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS -kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64, -R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the -important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -memory required by the kernel we instantiated. If yes, we create it and pass it along with other -arguments created to initialize CUTLASS kernel then, the kernel is launched. - -In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to -compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel. +This example shows how to run CUTLASS's convolution kernels +based on the Implicit GEMM algorithm, that use the Tensor Cores +on an NVIDIA Ampere GPU. + +Writing a single high-performance convolution kernel is hard enough, +let alone writing kernels that perform well for multiple problem sizes +and use good software abstractions. +CUTLASS provides simplified abstractions +to compose multiple sections of a convolution kernel. +When used properly, the kernels can reach peak GPU performance. + +CUTLASS divides a kernel into hierarchical composable sections +for each level of the GPU hardware hierarchy: +thread, warp, and threadblock. +Each section computes on its own tile shape, +with each higher level's tile shape +being composed from lower-level tile shapes. +Multiple thread tiles (the tile shape each thread computes) +can be used to form warp tiles (the tile shape each warp computes), +and multiple warp tiles can be used to compute threadblock tiles +(the tile shape computed by a threadblock). + +In thie example, we split variable initialization into two parts. + +1. Setting up data properties: describes how tensors are laid out in the memory + and how the kernel can view them (logical to physical mapping) + +2. Setting up computation properties: describes how the above tensors + will be used to compute the output of convolution + +We begin by setting up the data types +of all the input and output elements of a convolution. +A convolution computes +C = alpha * Conv2dFprop(A, B) + beta * C, +so we set up data types for the input tensor A, +weights tensor B, output tensor C, +and the scaling factors alpha and beta. +CUTLASS divides the convolution into two parts: +the "mainloop" that computes X = Conv2dFprop(A, B), +and the "epilogue" that computes C = alpha * X + beta * C. +The epilogue is an element-wise operation on X and C. +In this case, it is a linear combination, +but other epilogues are possible. + +In this example, we want + +* the scaling factors alpha and beta to be float, + +* the elements of A and B to be cutlass::half_t + (a 16-bit floating-point type), + +* the elements of C to be float, and + +* intermediate sums to be accumulated in float. + +We convey this to the CUTLASS kernel +by setting the following template parameters. + +* alpha and beta: ElementComputeEpilogue = float + +* Elements of input tensor A: ElementInputA = cutlass::half_t + +* Elements of input tensor B: ElementInputB = cutlass::half_t + +* Elements of output tensor C: ElementOutput = float + +* Accumulation type: ElementAccumulator = float + +Next, we describe the layout of the input and output tensors. +We convey this to the CUTLASS kernel +by setting the following template parameters. + +* Layout of input tensor A: LayoutInputA = TensorNHWC + +* Layout of input tensor B: LayoutInputB = TensorNHWC + +* Layout of output tensor C: LayoutOutput = TensorNHWC + +After that, we set up rules to compute the epilogue. +The epilogue in this case is a simple linear combination +C = alpha * X + beta * C. +Thus, we set the kernel's template parameter EpilogueOp +to LinearCombination. LinearCombination itself +has template parameters: + +* the element type of the output tensor (ElementOutput), + +* the number of elements per vector memory access (8), + +* the data type of the accumulator (ElementAccumulator), + +* and the data type used to compute the linear combination + (ElementComputeEpilogue). + +We then define the tile shapes +that each level of the computation uses. +We define these as types that encode the tile shapes +as compile-time integer values. +Each shape expresses the dimensions M x N x K. +Here, the letters refer to the dimensions +of a matrix-matrix multiply. + +* ThreadblockShape defines the threadblock tile shape + as 128 x 128 x 64. + +* WarpShape defines the warp tile shape as 64 x 64 x 64. + +* InstructionShape defines the MMA + (matrix multiply-accumulate) operation shape + as 16 x 8 x 16. + +These types become template arguments +of the kernel properties type +cutlass::conv::kernel::DefaultConv2dFprop. +The kernel uses these shapes to deduce +the number of threads needed per threadblock, +the required amount of shared memory, +the internal layouts needed to access +shared memory without bank conflicts, +and many other properties that the kernel needs +for good performance. +CUTLASS deduces all these properties automatically, +so that users don't have to. +DefaultConv2dFprop accepts other template parameters +that describe things like the target CUDA SM architecture. + +CUTLASS also supports multiple MMA pipelines in a threadblock. +An MMA pipeline constitutes the whole process +of loading input data from global memory to shared memory, +loading data from shared memory to registers, +doing matrix multiplication, +and storing the result to global memory. +The below flow sequence shows a typical MMA multistage pipeline +(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h). + +tensor in global memory +--cp_async--> +tile in shared memory +--smem loads--> +registers +--mma--> +registers +--global stores--> +output to global memory + +On NVIDIA Ampere, the kernel uses `cp_async` +to build a multistage software pipeline. +This helps it better hide latency. + +At this point, we can define the actual CUTLASS kernel type +as the alias ImplicitGemm, a specialization of +cutlass::conv::device::ImplicitGemmConvolution. +The latter accepts the kernel properties type alias +Conv2dFpropKernel as its one template argument. + +This example then sets up a test problem +and arguments to the kernel. +We use CUTLASS utilities to allocate +the input and output tensors +and fill them with sample input data. +We then create the kernel arguments +as an instance of ImplicitGemm::Arguments. +The arguments include +the problem size (N = 1, H = 64, W = 64, C = 128), +filter size (K = 64, R = 3, S = 3, C = 128), +padding, strides, dilation, tensors, alpha, beta, +and the split k-dimension factor. +We also query CUTLASS if the kernel we instantiated +requires any memory for scratch space. +If yes, we reserve scratch space and pass it along +with other arguments to initialize the CUTLASS kernel. + +After lauching the CUTLASS kernel, this example runs +a reference convolution kernel (from CUTLASS utilities) +to check correctness. */ #include @@ -131,8 +229,8 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM #include "helper.h" -// The code section below describes datatype for input, output tensors and computation between -// elements +// Data types for input and output tensors +// and computation between elements using ElementAccumulator = float; // Data type of accumulator using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) using ElementInputA = cutlass::half_t; // Data type of elements in input tensor @@ -143,39 +241,40 @@ using LayoutInputA = cutlass::layout::TensorNHWC; using LayoutInputB = cutlass::layout::TensorNHWC; using LayoutOutput = cutlass::layout::TensorNHWC; -// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +// Whether to use tensor cores or regular SIMT cores on GPU SM using MMAOp = cutlass::arch::OpClassTensorOp; -// This code section describes CUDA SM architecture number +// SM architecture number using SmArch = cutlass::arch::Sm80; -// This code section describes the tile size a thread block will compute -using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape +// Threadblock tile shape +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -// This code section describes tile size a warp will compute -using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape +// Warp tile shape +using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -// This code section describes the size of MMA op -using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape +// MMA (Tensor Core instruction, in this case) tile shape +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -// This code section describes how threadblocks are scheduled on GPU +// How the kernel schedules threadblocks using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -// Number of pipelines you want to use +// Number of pipeline stages to use constexpr int NumStages = 3; -// This code section describe iterator algorithm selected is Analytic or Optimized +// Which iterator algorithm to use: Analytic or Optimized static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -// This code section describes the epilogue part of the kernel, we use default value +// The epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // Data type of output matrix. - 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. + 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized // memory access. This becomes the vector width of // math instructions in the epilogue too. ElementAccumulator, // Data type of accumulator ElementComputeEpilogue>; // Data type for alpha/beta in linear combination +// Kernel properties type using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, @@ -193,6 +292,7 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< IteratorAlgorithm >::Kernel; +// Type of the actual kernel using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -230,7 +330,7 @@ struct Options { beta(0), benchmark(false) { } - // Verify the problem size is compatible with the CUTLASS Convolution implementation. + // Verify that the problem size is compatible with CUTLASS's convolution implementation bool valid() { // @@ -256,7 +356,7 @@ struct Options { return true; } - /// Updates input and filter sizes + /// Update input and filter sizes void update( cutlass::Tensor4DCoord input_size, cutlass::Tensor4DCoord filter_size) { @@ -270,7 +370,7 @@ struct Options { padding.c() = filter_size.w() / 2; } - // Parses the command line + // Parse command-line arguments void parse(int argc, char const **args) { cutlass::CommandLine cmd(argc, args); @@ -302,11 +402,11 @@ struct Options { cmd.get_cmd_line_argument("k", filter_size.n()); cmd.get_cmd_line_argument("r", filter_size.h()); cmd.get_cmd_line_argument("s", filter_size.w()); - filter_size.c() = input_size.c(); + filter_size.c() = input_size.c(); cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("tag", tag); @@ -320,12 +420,12 @@ struct Options { } } - /// Prints the usage statement. + /// Print an explanation of the command-line arguments std::ostream & print_usage(std::ostream &out) const { out << "16_ampere_tensorop_conv2dfprop example\n\n" - << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" - << " forward convolution on tensors of layout NHWC.\n\n" + << " This example uses Ampere's Tensor Core operators on F16 data types\n" + << " to compute forward convolution on tensors of layout NHWC.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" << " --n= Input tensor extent N\n" @@ -350,7 +450,7 @@ struct Options { return out; } - + /// Computes the output tensor size (NPQK) cutlass::Tensor4DCoord output_size() const { return cutlass::Tensor4DCoord( @@ -360,19 +460,20 @@ struct Options { filter_size.n()); } - /// Compute performance in GFLOP/s + /// Compute performance in Gflop/s + /// + /// Gflop/s stands for billions (10^9) of + /// floating-point operations per second (Gflop/s). double gflops(double runtime_s) const { // Number of multiply-adds = NPQK * CRS int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } }; -///////////////////////////////////////////////////////////////////////////////////////////////// - struct Result { double runtime_ms; double gflops; @@ -380,14 +481,14 @@ struct Result { cutlass::Status reference_check; cudaError_t error; - Result(): - runtime_ms(0), + Result(): + runtime_ms(0), gflops(0), status(cutlass::Status::kSuccess), reference_check(cutlass::Status::kInvalid), error(cudaSuccess) { } - static std::ostream & print_header(std::ostream &out, Options const &options) { + static std::ostream& print_header(std::ostream &out, Options const &options) { if (!options.tag.empty()) { out << "Name,"; @@ -404,7 +505,7 @@ struct Result { out << options.tag << ","; } - out + out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << "," @@ -420,8 +521,6 @@ struct Result { } }; -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Runs one benchmark Result profile_convolution(Options const &options) { @@ -441,7 +540,7 @@ Result profile_convolution(Options const &options) { // Initialize tensors // - // Fill tensor A on host with uniform-distribution random data + // Fill tensor A on host with uniformly distributed random data cutlass::reference::host::TensorFillRandomUniform( tensor_a.host_view(), 1, @@ -449,7 +548,7 @@ Result profile_convolution(Options const &options) { ElementInputA(-8), 0); - // Fill tensor B on host with uniform-distribution random data + // Fill tensor B on host with uniformly distributed random data cutlass::reference::host::TensorFillRandomUniform( tensor_b.host_view(), 1, @@ -457,7 +556,7 @@ Result profile_convolution(Options const &options) { ElementInputB(-8), 0); - // Fill tensor C on host with uniform-distribution random data + // Fill tensor C on host with uniformly distributed random data cutlass::reference::host::TensorFillRandomUniform( tensor_c.host_view(), 1, @@ -490,7 +589,7 @@ Result profile_convolution(Options const &options) { int split_k_slices = 1; // Construct Conv2dProblemSize with user defined output size - cutlass::conv::Conv2dProblemSize problem_size( + cutlass::conv::Conv2dProblemSize problem_size( options.input_size, options.filter_size, options.padding, @@ -501,7 +600,7 @@ Result profile_convolution(Options const &options) { split_k_slices ); - // Construct ImplicitGemm::Argument structure with conv2d + // Construct ImplicitGemm::Argument structure with conv2d // problem size, data pointers, and epilogue values typename ImplicitGemm::Arguments arguments{ problem_size, @@ -539,7 +638,7 @@ Result profile_convolution(Options const &options) { // // Optional reference check // - + if (options.reference_check) { std::cout << "Verification on host...\n"; @@ -552,8 +651,7 @@ Result profile_convolution(Options const &options) { ElementOutput, LayoutOutput, ElementComputeEpilogue, - ElementAccumulator, - cutlass::NumericConverter + ElementAccumulator >( problem_size, tensor_a.host_ref(), @@ -564,7 +662,7 @@ Result profile_convolution(Options const &options) { options.beta ); - // Check if output from CUTLASS kernel and reference kernel are equal or not + // Check if CUTLASS kernel and reference kernel produced the same output tensor_d.sync_host(); bool passed = cutlass::reference::host::TensorEquals( @@ -589,14 +687,14 @@ Result profile_convolution(Options const &options) { std::stringstream ss; ss << "16_ampere_workspace_conv2dfprop_" - << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() + << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" - << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() + << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() << ".dat"; std::ofstream output_workspace(ss.str()); - output_workspace + output_workspace << "Input = \n" << tensor_a.host_view() << "\n\n" << "Filters = \n" << tensor_b.host_view() << "\n\n"; @@ -616,7 +714,7 @@ Result profile_convolution(Options const &options) { if (options.measure_performance) { cudaEvent_t events[2]; - + for (auto & event : events) { result.error = cudaEventCreate(&event); if (result.error != cudaSuccess) { @@ -632,7 +730,7 @@ Result profile_convolution(Options const &options) { return result; } - // Launch a sequence of implicit GEMM operations on the device + // Launch a sequence of implicit GEMM operations on the device. for (int iteration = 0; iteration < options.iterations; ++iteration) { result.status = implicit_gemm_op(); CUTLASS_CHECK(result.status); @@ -652,7 +750,7 @@ Result profile_convolution(Options const &options) { return result; } - // Measure elapsed runtime + // Measure elapsed runtime. float runtime_ms = 0; result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); if (result.error != cudaSuccess) { @@ -660,7 +758,7 @@ Result profile_convolution(Options const &options) { return result; } - // Print average runtime and GFLOPs. + // Print average run time and floating-point throughput (Gflop/s). result.runtime_ms = double(runtime_ms) / double(options.iterations); result.gflops = options.gflops(result.runtime_ms / 1000.0); @@ -673,8 +771,6 @@ Result profile_convolution(Options const &options) { return result; } -///////////////////////////////////////////////////////////////////////////////////////////////// - int main(int argc, char const **args) { bool notSupported = false; @@ -701,7 +797,7 @@ int main(int argc, char const **args) { } Options options; - + options.parse(argc, args); if (options.help) { @@ -768,5 +864,3 @@ int main(int argc, char const **args) { return 0; } - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/22_quaternion_conv/quaternion_conv.cu b/examples/22_quaternion_conv/quaternion_conv.cu index 57df73f1..2c6a3502 100644 --- a/examples/22_quaternion_conv/quaternion_conv.cu +++ b/examples/22_quaternion_conv/quaternion_conv.cu @@ -470,8 +470,7 @@ Result profile_convolution(Options const &options) { ElementOutput, LayoutOutput, ElementComputeEpilogue, - ElementAccumulator, - cutlass::NumericConverter + ElementAccumulator >( problem_size, tensor_a.host_ref(), diff --git a/examples/24_gemm_grouped/gemm_grouped.cu b/examples/24_gemm_grouped/gemm_grouped.cu index 71b77b04..66ff1432 100644 --- a/examples/24_gemm_grouped/gemm_grouped.cu +++ b/examples/24_gemm_grouped/gemm_grouped.cu @@ -37,7 +37,7 @@ leading dimensions and problem sizes are stored in arrays in GMEM. This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM - concept may be distinct. + concept may be distinct. This benchmark program initializes a workspace with random problem sizes for a given number of groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to @@ -186,7 +186,7 @@ struct Options { // // Methods - // + // Options(): help(false), @@ -216,7 +216,7 @@ struct Options { cmd.get_cmd_line_argument("alignment", alignment, 8); cmd.get_cmd_line_argument("groups", problem_count, 15); cmd.get_cmd_line_argument("alpha", alpha, 1.0f); - cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); cmd.get_cmd_line_argument("iterations", iterations, 20); cmd.get_cmd_line_argument("streams", cuda_streams, 0); cmd.get_cmd_line_argument("verbose", verbose, false); @@ -455,13 +455,13 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds + // Number of real-valued multiply-adds int64_t fmas = int64_t(); for (auto const & problem : problem_sizes) { fmas += problem.product(); } - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } @@ -546,7 +546,7 @@ public: template void initialize_tensor( Element *ptr, - size_t capacity, + size_t capacity, cutlass::Distribution::Kind dist_kind, uint32_t seed) { @@ -578,7 +578,7 @@ public: cutlass::reference::device::BlockFillRandomUniform( ptr, capacity, seed, scope_max, scope_min, 0); - } + } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::device::BlockFillRandomGaussian( @@ -589,7 +589,7 @@ public: // Fill with increasing elements cutlass::reference::device::BlockFillSequential( ptr, capacity, Element(1), Element()); - } + } else { // Fill with all 1s @@ -674,13 +674,13 @@ public: ptr_A.reset(problem_count()); ptr_A.copy_from_host(ptr_A_host.data()); - + ptr_B.reset(problem_count()); ptr_B.copy_from_host(ptr_B_host.data()); - + ptr_C.reset(problem_count()); ptr_C.copy_from_host(ptr_C_host.data()); - + ptr_D.reset(problem_count()); ptr_D.copy_from_host(ptr_D_host.data()); @@ -712,7 +712,7 @@ public: MatrixCoord extent_A{problem.m(), problem.k()}; MatrixCoord extent_B{problem.k(), problem.n()}; MatrixCoord extent_C{problem.m(), problem.n()}; - + cutlass::TensorView view_A(block_A.get() + offset_A.at(i), layout_A, extent_A); cutlass::TensorView view_B(block_B.get() + offset_B.at(i), layout_B, extent_B); cutlass::TensorView view_C(block_C.get() + offset_C.at(i), layout_C, extent_C); @@ -724,18 +724,18 @@ public: cutlass::reference::device::GemmComplex< ElementA, LayoutA, ElementB, LayoutB, - ElementC, LayoutC, + ElementC, LayoutC, ElementCompute, ElementAccumulator >( problem, - options.alpha, + options.alpha, view_A, Gemm::kTransformA, view_B, Gemm::kTransformB, - options.beta, - view_C, - view_Ref_device, + options.beta, + view_C, + view_Ref_device, ElementAccumulator(0) ); @@ -781,8 +781,8 @@ public: std::cout << "Conventionally executed as " << this->options.problem_bins.size() << " batched GEMMs:\n"; for (auto const & bin : this->options.problem_bins) { - std::cout << " [" << bin_idx << "]: " - << bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k() + std::cout << " [" << bin_idx << "]: " + << bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k() << ", batch count: " << bin.second.size() << "\n"; ++bin_idx; @@ -832,7 +832,7 @@ public: for (auto const & bin : this->options.problem_bins) { int first_idx = bin.second.front(); - + bin_problem_sizes.push_back(this->options.problem_sizes.at(first_idx)); bin_count.push_back(int32_t(bin.second.size())); @@ -974,7 +974,7 @@ public: std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; return result; } - + } // @@ -1027,7 +1027,7 @@ public: int last_stream_idx = 0; for (int iter = 0; iter < this->options.iterations; ++iter) { - + for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { cutlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx]; @@ -1098,7 +1098,7 @@ public: std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } - + // // Wait for work to be completed // @@ -1129,10 +1129,10 @@ public: for (auto event : events) { (void)cudaEventDestroy(event); } - + for (auto stream : cuda_streams) { if (stream) { - (void)cudaStreamDestroy(stream); + (void)cudaStreamDestroy(stream); } } @@ -1203,8 +1203,8 @@ public: int tiles = Gemm::problem_tile_count(problem); total_tiles += tiles; - std::cout << " [" << idx << "]: " - << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() << " (" << tiles << " threadblock tiles)" << "\n"; ++idx; @@ -1442,12 +1442,12 @@ int main(int argc, char const **args) { } if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { - + // // This example requires an NVIDIA Ampere-architecture GPU. // - std::cout + std::cout << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " << "later (compute capability 80 or greater).\n"; @@ -1497,9 +1497,9 @@ int main(int argc, char const **args) { cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< - ElementOutput, + ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, + ElementAccumulator, ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, @@ -1519,8 +1519,8 @@ int main(int argc, char const **args) { cutlass::ComplexTransform::kNone, 8, ElementOutput, LayoutC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, @@ -1531,7 +1531,7 @@ int main(int argc, char const **args) { // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. // This parameter is passed in at present to match the APIs of other kernels. The parameter // is unused within the kernel. - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 4>::GemmKernel; using GemmGrouped = cutlass::gemm::device::GemmGrouped; diff --git a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu similarity index 99% rename from examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu rename to examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu index fc6f6af8..adc9b407 100644 --- a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu +++ b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu @@ -181,7 +181,7 @@ struct Options { << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_complex_gemm --m=1024 --n=512 \\\n" + << "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm --m=1024 --n=512 \\\n" << " --alpha=2 --beta=0.707 \n\n"; return out; diff --git a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/CMakeLists.txt b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/CMakeLists.txt index 679ada2a..1efc3056 100644 --- a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/CMakeLists.txt +++ b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/CMakeLists.txt @@ -27,9 +27,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +# Both filenames are shorter to avoid MAX_PATH issues on Windows. cutlass_example_add_executable( - 29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm - 29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu + 29_3xtf32_complex_gemm + 29_3xtf32_complex_gemm.cu ) diff --git a/examples/39_gemm_permute/gemm_permute.cu b/examples/39_gemm_permute/gemm_permute.cu index ed3e3998..84e9052c 100644 --- a/examples/39_gemm_permute/gemm_permute.cu +++ b/examples/39_gemm_permute/gemm_permute.cu @@ -34,10 +34,12 @@ This example computes batched GEMM operations with output results permuted as reshaped tensors. - We provide layout plugin as a flexible tool for users to add any customized output tensor permute operation, + We provide layout plugin as a flexible tool for users to add any customized input/output tensor permute operation, or any other generalized global memory writeout address computation. To add a customized layout, add new class in include/cutlass/layout/permute.h + In this example we use several permute operations (permute([0, 2, 1, 3])) + In this example, we used Tensor4DPermuteBMM0213 layout to perform Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor, and used Tensor5DPermute20314 layout to perform Normal GEMM with permute([2, 0, 3, 1, 4]) on output matrix. The address computations are performed in compute(col_init, row_init, stride_init, @@ -46,12 +48,13 @@ Tips: - 1) Make sure to set batch_stride_D to zero for BMM permute; Also the BMM GEMM should be in mode - cutlass::gemm::GemmUniversalMode::kBatched instead of kArray + 1) Make sure to set batch_stride to zero for BMM permute; also the BMM GEMM should be in mode + cutlass::gemm::GemmUniversalMode::kBatched instead of kArray. - 2) When the last dimension is touched in permute op (for example permute([0, 2, 3, 1])), AlignmentC should - be set to 1. If the last dimension is untouched, one can set AlignmentC to be larger like 8 in our example. - As a result, permute op without touching the last dimension is recommended to obtain the best performance gain. + 2) When the contiguous dimension is touched in permute op (for example [0, 2, 3, 1] for row-major matrix + or [1, 0, 2, 3] for column-major), Alignment should be set to 1 for the corresponding matrix. + If the last dimension is untouched, one can set Alignment to be larger like 8 in our example. + As a result, permute op without touching the unit stride dimension is recommended to obtain the best performance. Examples: @@ -87,50 +90,65 @@ #include "cutlass/util/reference/host/gemm_complex.h" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/layout/permute.h" +#include "layouts.h" +#include "permute_info.h" + /// Tensor4DPermuteBMM0213 ---> -/// Permute layout function for 4-D permuted tensors for BMM with BMM output tensor (dimension as [B, M, N]) reshaped -/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM output tensor. -const int D1 = 12; +/// Permute layout function for 4-D permuted tensors for BMM with BMM tensor (dimension as [B, M, N]) reshaped +/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM tensor. +int constexpr D1 = 12; /// Tensor5DPermute20314 ---> -/// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped -/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. -const int T1 = 16; -const int T2 = 3; -const int T3 = 8; - -// Alignment C -const int AlignmentC = 8; +/// Permute layout function for 5-D permuted tensors with matrix (dimension as [M, N]) reshaped +/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding tensor. +int constexpr T1 = 16; +int constexpr T2 = 3; +int constexpr T3 = 8; + +/// Tensor4DPermute0213 ---> +/// Permute layout function for 4-D permuted tensors with matrix (dimension as [M, N]) reshaped +/// as [M/S1, S1, S2, N/S2]. Then perform permute([0, 2, 1, 3]) on the corresponding tensor. +int constexpr S1 = 8; +int constexpr S2 = 4; + +// // // Alignments +int constexpr AlignmentA = 8; +int constexpr AlignmentB = 8; +int constexpr AlignmentC = 8; + +/// GEMM element types +using ElementInput = cutlass::half_t; +using ElementOutput = cutlass::half_t; +using ElementAccumulator = float; ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Result structure -struct Result { - - double runtime_ms; - double gflops; - cutlass::Status status; - cudaError_t error; - bool passed; - - // - // Methods - // - - Result( - double runtime_ms = 0, - double gflops = 0, - cutlass::Status status = cutlass::Status::kSuccess, - cudaError_t error = cudaSuccess - ): - runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } -}; +/// Useful macros + +#define CHECK_CUDA_CALL(call, handler) \ +do { \ + cudaError_t __err = (call); \ + if (__err != cudaSuccess) { \ + std::cerr << #call " failed: " << cudaGetErrorString(__err) << std::endl; \ + handler; \ + } \ +} while(0) + +#define CHECK_CUTLASS_CALL(call, handler) \ +do { \ + cutlass::Status __status = (call); \ + if (__status != cutlass::Status::kSuccess) { \ + std::cerr << #call " failed: " << cutlass::cutlassGetStatusString(__status) << std::endl; \ + handler; \ + } \ +} while(0) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -184,81 +202,79 @@ struct Options { int m, n, k; - cmd.get_cmd_line_argument("m", m, 128); + cmd.get_cmd_line_argument("m", m, 384); cmd.get_cmd_line_argument("n", n, 192); - cmd.get_cmd_line_argument("k", k, 128); - cmd.get_cmd_line_argument("batch-count", batch_count, 768); + cmd.get_cmd_line_argument("k", k, 384); + cmd.get_cmd_line_argument("batch-count", batch_count, 96); - cutlass::gemm::GemmCoord problem(m, n, k); - problem_each = problem; - - if (batch_count % D1 != 0){ - std::cerr << "\nProblem count error (problem-count = " << batch_count << "). " - << "problem-count needs to be divided with no remain by " << D1 << " (D1)." - << " (Required by the Batched GEMM permute Tensor4DPermuteBMM0213)\n\n"; - error = true; - } - - if (m % (AlignmentC * T1) != 0){ - std::cerr << "\nProblem m size error (m = " << m << "). " - << "m needs to be divided with no remain by " << (AlignmentC * T1) << " (AlignmentC * T1)." - << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n"; - error = true; - } - - if (n % (AlignmentC * (T2 * T3)) != 0){ - std::cerr << "\nProblem n size error (n = " << n << "). " - << "n needs to be divided with no remain by " << (AlignmentC * (T2 * T3)) << " (AlignmentC * T2 * T3)." - << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n"; - error = true; - } + problem_each = cutlass::gemm::GemmCoord(m, n, k); } /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "39_gemm_permute\n\n" - << " 1) This example firstly profiles the performance of a batched GEMM kernel with BMM whole output" - << " (including output matrices for each batch) as permuted 4D Tensor." - << " The BMM tensor output in shape of [B, M, N] is reshaped as [B/D1, D1, M, N] and then permuted with" - << " permute([0, 2, 1, 3]) to be in shape of [B/D1, M, D1, N].\n\n" - << " 2) This example also profiles the performance of a normal GEMM kernel with output as permuted 5D Tensor." - << " The GEMM matrix output in shape of [M, N] is reshaped as [M/T1, T1, T2, T3, N/T2/T3] and then permuted" - << " with permute([2, 0, 3, 1, 4]) to be in shape of [T2, M/T1, T3, T1, N/T2/T3].\n\n" - << " Note: D1, T1, T2, T3 are compile-time constants defined in gemm_permute.cu\n\n" - << "Options:\n\n" - << " --help If specified, displays this usage statement.\n\n" - << " --batch-count= Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n" - << " --m= Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n" - << " --n= Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n" - << " --k= Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=128)\n" - << " --alpha= Epilogue scalar alpha (real part)\n" - << " --beta= Epilogue scalar beta (real part)\n\n" - << " --iterations= Number of profiling iterations to perform.\n" - << " --reference-check= If true, performs reference check.\n" - << " --verbose= If true, prints problem sizes and batching structure.\n"; - - out << "\n\nExamples:\n\n" - - << "# Runs a batched GEMM with 96 batches\n" - << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96\n\n" - - << "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n" - << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true\n\n" - - << "# Execute batched GEMM and profile with NSight\n" - << "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; + out << + "39_gemm_permute\n" + "\n" + " This example tests and profiles the performance of normal GEMM and batched GEMM with different" + " combinations of fused permutations of input and output tensors." + "\n" + " Permutations considered in this example:\n" + "\n" + " Normal GEMM:\n" + " 1) Tensor4DPermute0213: matrix of shape [X, Y] is reshaped as [X/S1, S1, S2, Y/S2] and has its dimensions" + " permuted as [0, 2, 1, 3], resulting in shape [X/S1, S2, S1, Y/S2] viewed as matrix of shape [X*S2/S1, Y*S1/S2].\n" + " 2) Tensor5DPermute20314: matrix of shape [X, Y] is reshaped as [X/T1, T1, T2, T3, Y/T2/T3] and has its dimensions" + " permuted as [2, 0, 3, 1, 4], resulting in shape [T2, X/T1, T3, T1, Y/T2/T3] viewed as matrix of shape [X*T2/T1, Y*T1/T2].\n" + "\n" + " Batched GEMM:\n" + " 3) Tensor4DPermuteBMM0213: batched tensor of 3D shape [B, X, Y] is reshaped as 4D shape [B/D1, D1, X, Y]" + " and has its dimensions permuted as [0, 2, 1, 3], resulting in shape [B/D1, X, D1, Y] viewed as" + " a matrix of shape [B/D1, X, Y*D1] for batched GEMM purposes.\n" + "\n" + " Note: S1, S2, D1, D2, T1, T2, T3 are compile-time constants defined in gemm_permute.cu." + " Runtime specification of these values is not supported." + " These values along with alignment requirements place constraints on supported matrix sizes.\n" + "\n" + " Note: X, Y above may refer to M, N or K dimensions of GEMM problem, depending on the tensor considered (A, B or D)." + " For the output tensor D the values correspond directly to dimensions of D, whereas for A and B the original dimensions" + " X', Y' are inferred from the ones supplied to the GEMM, taking into account the permute operation.\n" + "\n" + "Options:\n" + "\n" + " --help If specified, displays this usage statement.\n\n" + " --batch-count= Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n" + " --m= Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n" + " --n= Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n" + " --k= Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=384)\n" + " --alpha= Epilogue scalar alpha (real part)\n" + " --beta= Epilogue scalar beta (real part)\n\n" + " --iterations= Number of profiling iterations to perform.\n" + " --reference-check= If true, performs reference check.\n" + " --verbose= If true, prints problem sizes and batching structure.\n" + "\n" + "Examples:\n" + "\n" + "# Runs a batched GEMM with 96 batches\n" + "$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96\n" + "\n" + "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n" + "$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96 --k=1024 --verbose=true\n" + "\n" + "# Execute batched GEMM and profile with NSight\n" + "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n" + "\n"; return out; } /// Compute performance in GFLOP/s - double gflops(double runtime_s) const { + double gflops(double runtime_s, bool batched) const { // Number of real-valued multiply-adds int64_t fmas = int64_t(); - fmas += problem_each.product() * batch_count; + fmas += problem_each.product() * (batched ? batch_count : 1); // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; @@ -267,28 +283,77 @@ struct Options { /////////////////////////////////////////////////////////////////////////////////////////////////// -template -class Testbed { -public: +namespace detail +{ + +/// Dimension-generic permutation loop +template +void permute_host_impl( + cutlass::TensorView const & input, + cutlass::TensorView const & output, + PermuteOp && permute, + Coord & coord +) { + static_assert(Layout::kRank == Coord::kRank, "Incompatible Layout and Coord types"); + if constexpr (I == Coord::kRank) { + output.at(permute(coord)) = input.at(coord); + } + else { + for (coord[I] = 0; coord[I] < input.extent(I); ++coord[I]) { + permute_host_impl(input, output, std::forward(permute), coord); + } + } +} - // - // Type definitions - // +} // namespace detail + +/// Perform a reference (host-based) permutation of an input tensor +template +void permute_host( + cutlass::TensorView const &input, + cutlass::TensorView const &output, + int batch_count) { + Layout layout = input.layout(); + cutlass::MatrixCoord extent = input.extent(); + + std::size_t num_elems = layout.capacity(extent) * batch_count; + std::vector h_input(num_elems); + cutlass::device_memory::copy_to_host(h_input.data(), input.data(), num_elems); + + std::vector h_output(num_elems); + + using Info = PermuteInfo; + using TensorLayout = typename Info::Layout; + + auto shape_orig = Info::original_shape(extent, batch_count); + auto shape_perm = Info::permute(shape_orig); + + cutlass::TensorView view_input(h_input.data(), TensorLayout::packed(shape_orig), shape_orig); + cutlass::TensorView view_output(h_output.data(), TensorLayout::packed(shape_perm), shape_perm); - using ElementA = typename GemmBatched::ElementA; - using ElementB = typename GemmBatched::ElementB; - using ElementC = typename GemmBatched::ElementC; - using ElementAccumulator = typename GemmBatched::ElementAccumulator; + decltype(shape_orig) coord; + detail::permute_host_impl<0>(view_input, view_output, Info::permute, coord); - using EpilogueOutputOp = typename GemmBatched::GemmKernel::Epilogue::OutputOp; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; + cutlass::device_memory::copy_to_device(output.data(), h_output.data(), num_elems); +} + +template +struct LayoutInfo; - using LayoutA = typename GemmBatched::LayoutA; - using LayoutB = typename GemmBatched::LayoutB; - using LayoutC = typename GemmBatched::LayoutC; +template<> +struct LayoutInfo { + static std::string name() { return "RowMajor"; } +}; + +template<> +struct LayoutInfo { + static std::string name() { return "ColumnMajor"; } +}; - using MatrixCoord = typename LayoutC::TensorCoord; +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Testbed { private: // @@ -323,63 +388,90 @@ public: ): options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } - /// Verbose BMM info - void print_BMM_info_() { - - // Print batched GEMM - std::cout << "Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor:\n"; +private: - auto problem = options.problem_each; - std::cout - << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() - << ", batch count: " << options.batch_count << "\n"; + /// Print permutation info for one tensor + template + void print_tensor_info( + std::ostream & os, + std::string const &tensor_name, + int row_dim, + int col_dim) { + + cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim)); + using Info = PermuteInfo; + + os << "tensor " << tensor_name << ": " << Info::desc() << "\n"; + os << " extent: [" << extent.row() << ", " << extent.column() << "]"; + if (Info::kBatched) { + os << ", batch count: " << options.batch_count; + } + os << "\n"; + if (!cutlass::layout::is_trivial_permute) { + auto shape_orig = Info::original_shape(extent, options.batch_count); + auto shape_perm = Info::permute(shape_orig); + os << " original: [" << shape_orig << "]\n"; + os << " permuted: [" << shape_perm << "]\n"; + } + } - std::cout << "output tensor shape: [" << options.batch_count << ", " << problem.m() << ", " - << problem.n() <<"]\n"; - std::cout << "reshaped as: [" << options.batch_count / D1 << ", " << D1 << ", " - << problem.m() << ", " << problem.n() <<"]\n"; - std::cout << "finally permuted as: [" << options.batch_count / D1 << ", " << problem.m() << ", " - << D1 << ", " << problem.n() <<"]\n"; + /// Check shape compatibility for one tensor + template + bool check_tensor_shape( + std::string const &tensor_name, + int row_dim, + int col_dim) { - std::cout << "----------------------------------------------------\n"; + cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim)); - } + using Info = PermuteInfo; - /// Verbose normal GEMM info - void print_GEMM_info_() { + auto rowAlign = cutlass::platform::is_same::value ? Alignment : 1; + auto colAlign = cutlass::platform::is_same::value ? Alignment : 1; - // Print batched GEMM - std::cout << "Normal GEMM with permute([2, 0, 3, 1, 4]):\n"; + auto rowFactor = Info::kRowFactor * rowAlign; + auto colFactor = Info::kColumnFactor * colAlign; - auto problem = options.problem_each; - std::cout - << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() << "\n"; + // Assumes row-major layout + bool const valid_row = extent.row() % rowFactor == 0; + if (!valid_row) { + std::cerr << "Invalid tensor " << tensor_name << " row size = " << extent.row() << ", " + "must be divisible by " << rowFactor << ", " + "required by " << Info::name() << + (rowAlign > 1 ? (" and alignment of " + std::to_string(rowAlign)) : "") << std::endl; + } - std::cout << "output tensor shape: [" << problem.m() << ", " << problem.n() <<"]" << std::endl; - std::cout << "reshaped as: [" << problem.m() / T1 << ", " << T1 << ", " - << T2 << ", " << T3 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; - std::cout << "finally permuted as: [" << T2 << ", " << problem.m() / T1 << ", " - << T3 << ", " << T1 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; + bool const valid_col = extent.column() % colFactor == 0; + if (!valid_col) { + std::cerr << "Invalid tensor " << tensor_name << " column size = " << extent.column() << ", " + "must be divisible by " << colFactor << ", " + "required by " << Info::name() << + (colAlign > 1 ? (" and alignment of " + std::to_string(colAlign)) : "") << std::endl; + } - std::cout << "----------------------------------------------------\n"; + bool const valid_bsz = options.batch_count % Info::kBatchFactor == 0; + if (!valid_bsz) { + std::cerr << "Invalid batch count = " << options.batch_count << ", " + "must be divisible by " << Info::kBatchFactor << ", " + "required by " << Info::name() << std::endl; + } + return valid_row && valid_col && valid_bsz; } -private: - /// Helper to initialize a tensor view template void initialize_tensor_( - Element *ptr, - size_t capacity, - cutlass::Distribution::Kind dist_kind, - uint32_t seed) { + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { if (dist_kind == cutlass::Distribution::Uniform) { Element scope_max, scope_min; int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; if (bits_input == 1) { scope_max = 2; @@ -424,13 +516,8 @@ private: } /// Initializes data structures - void initialize_(int batch_count) { + void initialize(int batch_count) { - // - // Choose random problem sizes - // - - // construct a few problems of random sizes srand(seed); int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count; @@ -438,19 +525,13 @@ private: int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count; int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count; - // - // Assign space - // - + // Allocate space block_A.reset(total_elements_A); block_B.reset(total_elements_B); block_C.reset(total_elements_C); block_D.reset(total_elements_D); - // - // Initialize the problems of the workspace - // - + // Initialize input tensors initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); @@ -459,668 +540,685 @@ private: block_D.get(), total_elements_D, ElementC(), ElementC()); } - /// Verifies the BMM GEMM result - bool verify_BMM_() { - bool passed = true; + /// Check device GEMM results against a reference implementation with separate host-based permutation + template + bool validate(Gemm const &gemm) { + + bool constexpr kBatched = PermuteInfo::kBatched + || PermuteInfo::kBatched + || PermuteInfo::kBatched; + + int const batch_count = kBatched ? options.batch_count : 1; cutlass::gemm::GemmCoord problem = options.problem_each; - LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); - LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); - LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); - LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); + cutlass::MatrixCoord extent_A{problem.m(), problem.k()}; + cutlass::MatrixCoord extent_B{problem.k(), problem.n()}; + cutlass::MatrixCoord extent_C{problem.m(), problem.n()}; - MatrixCoord extent_A{problem.m(), problem.k()}; - MatrixCoord extent_B{problem.k(), problem.n()}; - MatrixCoord extent_C{problem.m(), problem.n()}; + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + LayoutA layout_A(LayoutA::packed(extent_A)); + LayoutB layout_B(LayoutB::packed(extent_B)); + LayoutC layout_C(LayoutC::packed(extent_C)); + + auto size_A = layout_A.capacity(extent_A) * batch_count; + auto size_B = layout_B.capacity(extent_B) * batch_count; + auto size_C = layout_C.capacity(extent_C) * batch_count; cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); + cutlass::TensorView view_D(block_D.get(), layout_C, extent_C); + + cutlass::DeviceAllocation block_A_perm(size_A); + cutlass::DeviceAllocation block_B_perm(size_B); + + cutlass::TensorView view_A_perm(block_A_perm.get(), layout_A, extent_A); + cutlass::TensorView view_B_perm(block_B_perm.get(), layout_B, extent_B); + + permute_host(view_A.const_view(), view_A_perm, batch_count); + permute_host(view_B.const_view(), view_B_perm, batch_count); - cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C) * options.batch_count); - cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); + cutlass::DeviceAllocation block_D_ref(size_C); + cutlass::TensorView view_D_ref(block_D_ref.get(), layout_C, extent_C); + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; // Reference GEMM cutlass::reference::device::GemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - ElementCompute, ElementAccumulator + typename EpilogueOutputOp::ElementCompute, + typename Gemm::ElementAccumulator >( problem, options.alpha, - view_A, - GemmBatched::kTransformA, - view_B, - GemmBatched::kTransformB, + view_A_perm, + Gemm::kTransformA, + view_B_perm, + Gemm::kTransformB, options.beta, view_C, - view_Ref_device, + view_D_ref, ElementAccumulator(0), - options.batch_count, + batch_count, options.problem_each.m() * options.problem_each.k(), options.problem_each.n() * options.problem_each.k(), options.problem_each.m() * options.problem_each.n(), options.problem_each.m() * options.problem_each.n() ); - // Copy to host memory - std::vector matrix_D(layout_D.capacity(extent_C) * options.batch_count); - std::vector matrix_Ref(layout_D.capacity(extent_C) * options.batch_count); - - cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); - cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); - - // Print out the results and reference in 4D Tensor - // [options.batch_count, options.problem_each.m() * options.problem_each.n()] -> [D0, D1, D2, D3]. - // After permute Op, -> [D0, D2, D1, D3]. - int D0 = options.batch_count / D1; - int D2 = options.problem_each.m(); - int D3 = options.problem_each.n(); - - cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently - cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D2, D1, D3})), cutlass::Tensor4DCoord({D0, D2, D1, D3})); - - cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), - cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D1, D2, D3})), cutlass::Tensor4DCoord({D0, D1, D2, D3})); - - // Tensor Permute Op on reference tensor - cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor4DCoord({D0, D2, D1, D3})); - for (int n = 0; n < D0; ++n) { - for (int h = 0; h < D1; ++h) { - for (int w = 0; w < D2; ++w) { - for (int c = 0; c < D3; ++c) { - view_Ref_Permute_Tensor.at({n, w, h, c}) = view_Ref_Tensor.at({n, h, w, c}); - } - } - } - } + cutlass::DeviceAllocation block_D_perm(size_C); + cutlass::TensorView view_D_perm(block_D_perm.get(), layout_C, extent_C); + permute_host(view_D_ref.const_view(), view_D_perm, batch_count); // Reference check - passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); - - if (!passed) { - std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; - return passed; - } - - std::cout << "Passed verification" << std::endl; - return passed; - } - - bool verify_GEMM_normal_() { - - bool passed = true; - - cutlass::gemm::GemmCoord problem = options.problem_each; - - LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); - LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); - LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); - LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); - - MatrixCoord extent_A{problem.m(), problem.k()}; - MatrixCoord extent_B{problem.k(), problem.n()}; - MatrixCoord extent_C{problem.m(), problem.n()}; - - cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); - cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); - cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); + return cutlass::reference::device::BlockCompareEqual(view_D_perm.data(), view_D.data(), size_C); +} - cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); - cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); +public: - // Reference GEMM - cutlass::reference::device::GemmComplex< - ElementA, LayoutA, - ElementB, LayoutB, - ElementC, LayoutC, - ElementCompute, ElementAccumulator - >( - problem, - options.alpha, - view_A, - GemmBatched::kTransformA, - view_B, - GemmBatched::kTransformB, - options.beta, - view_C, - view_Ref_device, - ElementAccumulator(0) - ); + template + bool profile_GEMM_permute() { - // Copy to host memory - std::vector matrix_D(layout_D.capacity(extent_C)); - std::vector matrix_Ref(layout_D.capacity(extent_C)); - - cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); - cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); - - // Print out the results and reference in 5D Tensor - // [options.problem_each.m(), options.problem_each.n()] -> [T0, T1, T2, T3, T4]. - // options.problem_each.m() == T0 * T1 - // options.problem_each.n() == T2 * T3 * T4 - // After permute Op, -> [T2, T0, T3, T1, T4]. - int T0 = options.problem_each.m() / T1; - int T4 = options.problem_each.n() / T2 / T3; - - cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently - cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})), cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); - cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), - cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})), cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})); - - // Tensor Permute Op on reference tensor - cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); - for (int n = 0; n < T0; ++n) { - for (int d = 0; d < T1; ++d) { - for (int h = 0; h < T2; ++h) { - for (int w = 0; w < T3; ++w) { - for (int c = 0; c < T4; ++c) { - view_Ref_Permute_Tensor.at({h, n, w, d, c}) = view_Ref_Tensor.at({n, d, h, w, c}); // permute([2,0,3,1,4]) - } - } - } - } - } + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; - // Reference check - passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); + using PermuteALayout = typename Gemm::PermuteALayout; + using PermuteBLayout = typename Gemm::PermuteBLayout; + using PermuteDLayout = typename Gemm::PermuteDLayout; - if (!passed) { - std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; - return passed; - } + bool constexpr kBatched = PermuteInfo::kBatched + || PermuteInfo::kBatched + || PermuteInfo::kBatched; - std::cout << "Passed verification" << std::endl; - return passed; -} - -public: - /// Executes a conventional batched GEMM kernel. - Result profile_batched_kBatched() { + std::cout << "\n" + "====================================================\n" + << (kBatched ? "Batched" : "Normal") << " GEMM:" + << "\n A=" << LayoutInfo::name() << "," << PermuteInfo::name() + << "\n B=" << LayoutInfo::name() << "," << PermuteInfo::name() + << "\n D=" << LayoutInfo::name() << "," << PermuteInfo::name() + << "\n" + "====================================================\n"; - std::cout << "\n====================================================" << std::endl; - std::cout << "Batched GEMM (CUTLASS):\n" - << "====================================================" << std::endl; - if (options.verbose) { - print_BMM_info_(); + print_tensor_info(std::cout, "A", 0, 2); + print_tensor_info(std::cout, "B", 2, 1); + print_tensor_info(std::cout, "D", 0, 1); } + std::cout << std::endl; - Result result; + bool valid = true; + valid &= check_tensor_shape("A", 0, 2); + valid &= check_tensor_shape("B", 2, 1); + valid &= check_tensor_shape("D", 0, 1); + if (!valid) + { + std::cout << "Skipped test" << std::endl; + return true; + } - result.passed = false; + int const batch_count = kBatched ? options.batch_count : 1; // Initialize the problem - initialize_(options.batch_count); + initialize(batch_count); // Configure the GEMM arguments + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); // Please make sure all problem_sizes are the same for kBatched mode auto problem = options.problem_each; - // For regular BMM - int64_t batch_stride_C = problem.m() * problem.n(); - // For BMM permute output ---> make sure to set batch_stride_D to zero for BMM permute op - int64_t batch_stride_D = 0; + cutlass::MatrixCoord extent_A{problem.m(), problem.k()}; + cutlass::MatrixCoord extent_B{problem.k(), problem.n()}; + cutlass::MatrixCoord extent_C{problem.m(), problem.n()}; + + LayoutA layout_A(LayoutA::packed(extent_A)); + LayoutB layout_B(LayoutB::packed(extent_B)); + LayoutC layout_C(LayoutC::packed(extent_C)); // Configure GEMM arguments - typename GemmBatched::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kBatched, - options.problem_each, - options.batch_count, + typename Gemm::Arguments arguments{ + kBatched ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, epilogue_op, (void*)block_A.get(), (void*)block_B.get(), (void*)block_C.get(), (void*)block_D.get(), - problem.m() * problem.k(), - problem.n() * problem.k(), - batch_stride_C, - batch_stride_D, - problem.k(), - problem.n(), - problem.n(), - problem.n() + // For any non-trivial permute the batch stride must be set to 0 + cutlass::layout::is_trivial_permute ? layout_A.capacity(extent_A) : 0, + cutlass::layout::is_trivial_permute ? layout_B.capacity(extent_B) : 0, + layout_C.capacity(extent_C), + cutlass::layout::is_trivial_permute ? layout_C.capacity(extent_C) : 0, + layout_A.stride(0), + layout_B.stride(0), + layout_C.stride(0), + layout_C.stride(0), }; // Initialize the GEMM object - GemmBatched gemm; - - result.status = gemm.initialize(arguments, nullptr); - - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; - return result; - } + Gemm gemm_normal; - // Run the batched GEMM object - result.status = gemm.run(); + CHECK_CUTLASS_CALL(gemm_normal.initialize(arguments, nullptr), return false); - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; - return result; - } + // Run the normal GEMM object + CHECK_CUTLASS_CALL(gemm_normal.run(), return false); // Wait for completion - result.error = cudaDeviceSynchronize(); - - if (result.error != cudaSuccess) { - std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); - return result; - } + CHECK_CUDA_CALL(cudaDeviceSynchronize(), return false); // // Verify correctness // - result.passed = true; - if (options.reference_check) { - result.passed = verify_BMM_(); + if (validate(gemm_normal)) { + std::cout << "\nPassed verification\n" << std::endl; + } + else { + std::cerr << "\n*** Error - problem failed the QA check ***\n" << std::endl; + return false; + } } - // - // Warm-up run of the batched GEMM object - // - result.status = gemm.run(); - - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; - return result; - } + // Warm-up run of the normal GEMM object + CHECK_CUTLASS_CALL(gemm_normal.run(), return false); - // // Construct events - // - cudaEvent_t events[2]; - for (auto & event : events) { - result.error = cudaEventCreate(&event); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; - return -1; - } + CHECK_CUDA_CALL(cudaEventCreate(&event), return false); } // Record an event at the start of a series of GEMM operations - result.error = cudaEventRecord(events[0]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } + CHECK_CUDA_CALL(cudaEventRecord(events[0]), return false); - // // Run profiling loop - // - for (int iter = 0; iter < options.iterations; ++iter) { - gemm(); + gemm_normal(); } - // - // Stop profiling loop - // - // Record an event when the GEMM operations have been launched. - result.error = cudaEventRecord(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } + CHECK_CUDA_CALL(cudaEventRecord(events[1]), return false); // Wait for work on the device to complete. - result.error = cudaEventSynchronize(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } + CHECK_CUDA_CALL(cudaEventSynchronize(events[1]), return false); // Measure elapsed runtime - float runtime_ms = 0; - result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } + float runtime_total_ms = 0; + CHECK_CUDA_CALL(cudaEventElapsedTime(&runtime_total_ms, events[0], events[1]), return false); // Compute average runtime and GFLOPs. - result.runtime_ms = double(runtime_ms) / double(options.iterations); - result.gflops = options.gflops(result.runtime_ms / 1000.0); + double runtime_avg_ms = double(runtime_total_ms) / double(options.iterations); + double gflops = options.gflops(runtime_avg_ms / 1000.0, kBatched); - // // Cleanup - // - for (auto event : events) { - (void)cudaEventDestroy(event); + CHECK_CUDA_CALL(cudaEventDestroy(event), return false); } - std::cout << " " << 1 << " batched GEMMs launched\n"; + std::cout << " Runtime: " << runtime_avg_ms << " ms\n" + " GFLOPs: " << gflops << std::endl; - std::cout << std::endl; - std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms\n"; - std::cout << " " << "Batched GFLOPs: " << result.gflops << "\n"; - - return result; + return true; } +}; - Result profile_GEMM_permute() { +/// Shorthand alist for GEMM instantiations +template +using GemmPermute = cutlass::gemm::device::GemmUniversal< + ElementInput, LayoutA, + ElementInput, LayoutB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + AlignmentC, //128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, /*kStages*/ + AlignmentA, /*AlignmentA*/ + AlignmentB, /*AlignmentB*/ + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + false, /*GatherA*/ + false, /*GatherB*/ + false, /*ScatterD*/ + PermuteDLayout, /*PermuteDLayout*/ + typename cutlass::layout::InversePermute::type, /*PermuteALayout*/ + typename cutlass::layout::InversePermute::type /*PermuteBLayout*/ +>; - std::cout << "\n====================================================" << std::endl; - std::cout << "Normal GEMM (CUTLASS):\n" - << "====================================================" << std::endl; +/////////////////////////////////////////////////////////////////////////////////////////////////// - if (options.verbose) { - print_GEMM_info_(); - } +int main(int argc, char const **args) { - Result result; + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // - result.passed = false; + cudaDeviceProp props; - // Initialize the problem - initialize_(1); + CHECK_CUDA_CALL(cudaGetDeviceProperties(&props, 0), return EXIT_FAILURE); - // Configure the GEMM arguments - typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. + // - // Please make sure all problem_sizes are the same for kBatched mode - auto problem = options.problem_each; + std::cout << "CUTLASS's GEMM+Permute example requires a GPU of NVIDIA's Ampere Architecture " + "or later (compute capability 80 or greater).\n"; - // Configure GEMM arguments - typename GemmPermute::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - options.problem_each, - 1, - epilogue_op, - (void*)block_A.get(), - (void*)block_B.get(), - (void*)block_C.get(), - (void*)block_D.get(), - 0, - 0, - 0, - 0, - problem.k(), - problem.n(), - problem.n(), - problem.n() - }; + return EXIT_SUCCESS; + } - // Initialize the GEMM object - GemmPermute gemm_normal; + // + // Parse options + // - result.status = gemm_normal.initialize(arguments, nullptr); + Options options; + + options.parse(argc, args); - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; - return result; - } + if (options.help) { + options.print_usage(std::cout) << std::endl; + return EXIT_SUCCESS; + } - // Run the normal GEMM object - result.status = gemm_normal.run(); + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return EXIT_FAILURE; + } - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; - return result; - } + // + // Define GEMM types to test + // - // Wait for completion - result.error = cudaDeviceSynchronize(); + // + // TTT (Row-major) GEMMs + // - if (result.error != cudaSuccess) { - std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); - return result; - } + using TTTGemmNormalPermuteNone = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; - // - // Verify correctness - // - result.passed = true; + using TTTGemmNormalPermuteA = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; - if (options.reference_check) { - result.passed = verify_GEMM_normal_(); - } + using TTTGemmNormalPermuteAD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; - // - // Warm-up run of the normal GEMM object - // - result.status = gemm_normal.run(); + using TTTGemmNormalPermuteB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; - return result; - } + using TTTGemmNormalPermuteBD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; - // - // Construct events - // + using TTTGemmNormalPermuteD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; - cudaEvent_t events[2]; + using TTTGemmNormalPermuteAB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; - for (auto & event : events) { - result.error = cudaEventCreate(&event); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; - return -1; - } - } + using TTTGemmNormalPermuteABD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; - // Record an event at the start of a series of GEMM operations - result.error = cudaEventRecord(events[0]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } + // + // NNN (Col-major) GEMMs + // - // - // Run profiling loop - // + using NNNGemmNormalPermuteNone = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - for (int iter = 0; iter < options.iterations; ++iter) { - gemm_normal(); - } + using NNNGemmNormalPermuteA = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - // - // Stop profiling loop - // + using NNNGemmNormalPermuteAD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; - // Record an event when the GEMM operations have been launched. - result.error = cudaEventRecord(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } + using NNNGemmNormalPermuteB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - // Wait for work on the device to complete. - result.error = cudaEventSynchronize(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } + using NNNGemmNormalPermuteBD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; - // Measure elapsed runtime - float runtime_ms = 0; - result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } + using NNNGemmNormalPermuteD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; - // Compute average runtime and GFLOPs. - result.runtime_ms = double(runtime_ms) / double(options.iterations); - result.gflops = options.gflops(result.runtime_ms / 1000.0); + using NNNGemmNormalPermuteAB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - // - // Cleanup - // + using NNNGemmNormalPermuteABD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; - for (auto event : events) { - (void)cudaEventDestroy(event); - } + // + // NNT (Col-major inputs, row-major output) GEMMs + // - std::cout << std::endl; - std::cout << " " << "Normal Runtime: " << result.runtime_ms << " ms" << std::endl; - std::cout << " " << "Normal GFLOPs: " << result.gflops << "\n"; + using NNTGemmNormalPermuteNone = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; - return result; - } -}; + using NNTGemmNormalPermuteA = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; -/////////////////////////////////////////////////////////////////////////////////////////////////// + using NNTGemmNormalPermuteAD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; -int main(int argc, char const **args) { + using NNTGemmNormalPermuteB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using NNTGemmNormalPermuteBD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + using NNTGemmNormalPermuteD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + using NNTGemmNormalPermuteAB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using NNTGemmNormalPermuteABD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; // - // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // TTN (Row-major inputs, col-major output) GEMMs // - cudaDeviceProp props; + using TTNGemmNormalPermuteNone = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (error != cudaSuccess) { - std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return -1; - } + using TTNGemmNormalPermuteA = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { - - // - // This example requires an NVIDIA Ampere-architecture GPU. - // + using TTNGemmNormalPermuteAD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; - std::cout - << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " - << "later (compute capability 80 or greater).\n"; + using TTNGemmNormalPermuteB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - return 0; - } + using TTNGemmNormalPermuteBD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + using TTNGemmNormalPermuteD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + using TTNGemmNormalPermuteAB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using TTNGemmNormalPermuteABD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; // - // Parse options + // TTT (Row-major) BMMs // - Options options; - - options.parse(argc, args); + using TTTGemmBatchedPermuteA = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; - if (options.help) { - options.print_usage(std::cout) << std::endl; - return 0; - } + using TTTGemmBatchedPermuteAD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; - if (options.error) { - std::cerr << "Aborting execution." << std::endl; - return -1; - } + using TTTGemmBatchedPermuteB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; - // - // Define the GEMM types - // + using TTTGemmBatchedPermuteBD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; - using ElementOutput = cutlass::half_t; - using ElementAccumulator = float; + using TTTGemmBatchedPermuteD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; + + using TTTGemmBatchedPermuteAB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::RowMajor; - using LayoutC = cutlass::layout::RowMajor; + using TTTGemmBatchedPermuteABD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; // - // Define a conventional batched GEMM type + // NNN (Col-major) BMMs // - // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 - using GemmBatched = cutlass::gemm::device::GemmUniversal< - cutlass::half_t, LayoutA, - cutlass::half_t, LayoutB, - ElementOutput, LayoutC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - AlignmentC, //128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 4, - 8, /*alignmentA*/ - 8, /*alignmengB*/ - cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - false, /*GatherA*/ - false, /*GatherB*/ - false, /*ScatterD*/ - cutlass::layout::Tensor4DPermuteBMM0213 /*PermuteDLayout*/ + using NNNGemmBatchedPermuteA = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; - // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 - using GemmPermute = cutlass::gemm::device::GemmUniversal< - cutlass::half_t, LayoutA, - cutlass::half_t, LayoutB, - ElementOutput, LayoutC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - AlignmentC, //128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 4, - 8, /*alignmentA*/ - 8, /*alignmengB*/ - cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - false, /*GatherA*/ - false, /*GatherB*/ - false, /*ScatterD*/ - cutlass::layout::Tensor5DPermute20314 /*PermuteDLayout*/ + using NNNGemmBatchedPermuteAD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor >; - // - // Profile it - // + using NNNGemmBatchedPermuteB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - Testbed testbed(options); + using NNNGemmBatchedPermuteBD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor + >; - Result result; - result = testbed.profile_batched_kBatched(); - if (!result.passed) { - std::cout << "Profiling batched GEMM has failed.\n"; - std::cout << "\nFailed\n"; - } else { - std::cout << "\nPassed CUTLASS batched GEMM\n"; - } + using NNNGemmBatchedPermuteD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor + >; - result = testbed.profile_GEMM_permute(); - if (!result.passed) { - std::cout << "Profiling normal GEMM has failed.\n"; - std::cout << "\nFailed\n"; - } else { - std::cout << "\nPassed CUTLASS normal GEMM\n"; - } + using NNNGemmBatchedPermuteAB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; - std::cout << "\n====================================================" << std::endl; - std::cout << "Finished\n"; - std::cout << "====================================================" << std::endl; + using NNNGemmBatchedPermuteABD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor + >; + + // + // Profile it + // - return 0; + Testbed testbed(options); + + bool result = true; + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + std::cout << "\n" + "====================================================\n" + "Finished (" << (result ? "PASS" : "FAIL") << ")\n" + "====================================================" << std::endl; + + return result ? EXIT_SUCCESS : EXIT_FAILURE; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/39_gemm_permute/layouts.h b/examples/39_gemm_permute/layouts.h new file mode 100644 index 00000000..0a111137 --- /dev/null +++ b/examples/39_gemm_permute/layouts.h @@ -0,0 +1,510 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines additional layout functions used in Permute GEMM example to simplify + computing reference permutations of 4/5D tensors when source data is column-major. +*/ +#pragma once +#if defined(__CUDACC_RTC__) +#include +#else +#include "assert.h" +#endif +#include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/coord.h" +#include "cutlass/tensor_coord.h" + +namespace cutlass { +namespace layout { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D CWHN tensors. +class TensorCWHN { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [n, hn, whn] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHN(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHN( + typename Stride::Index stride_h, ///< number of elements between adjacent N coordinates + typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_c ///< number of elements between adjacent W coordinates + ): + stride_(make_Coord(stride_h, stride_w, stride_c)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCWHN(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorCWHN packed(TensorCoord const &extent) { + return TensorCWHN( + make_Coord( + extent.n(), + extent.h() * extent.n(), + extent.w() * extent.h() * extent.n() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.n() + + LongIndex(stride_[0] * coord.h()) + + LongIndex(stride_[1] * coord.w()) + + LongIndex(stride_[2] * coord.c()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.n() > stride_[0]) + || (extent.h() * stride_[0] > stride_[1]) + || (extent.w() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.c() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D NHCW tensors. +class TensorNHCW { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [w, cw, hcw] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNHCW(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNHCW( + typename Stride::Index stride_c, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_c, stride_h, stride_n)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNHCW(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorNHCW packed(TensorCoord const &extent) { + return TensorNHCW( + make_Coord( + extent.w(), + extent.c() * extent.w(), + extent.h() * extent.c() * extent.w() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.w() + + LongIndex(stride_[0] * coord.c()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.n()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.w() > stride_[0]) + || (extent.c() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.n() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D NHCW tensors. +class TensorNCWH { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [h, wh, cwh] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNCWH(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNCWH( + typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_c, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_w, stride_c, stride_n)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNCWH(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorNCWH packed(TensorCoord const &extent) { + return TensorNCWH( + make_Coord( + extent.h(), + extent.w() * extent.h(), + extent.c() * extent.w() * extent.h() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.h() + + LongIndex(stride_[0] * coord.w()) + + LongIndex(stride_[1] * coord.c()) + + LongIndex(stride_[2] * coord.n()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.h() > stride_[0]) + || (extent.w() * stride_[0] > stride_[1]) + || (extent.c() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.n() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 5-D CWHDN tensors. +class TensorCWHDN { +public: + /// Logical rank of tensor + static int const kRank = 5; + + /// Rank of stride vector + static int const kStrideRank = 4; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, d, h, w, c) + using TensorCoord = Tensor5DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [n, dn, hdn, whdn] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHDN(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHDN( + typename Stride::Index n, + typename Stride::Index dn, + typename Stride::Index hdn, + typename Stride::Index whdn): + stride_(make_Coord(n, dn, hdn, whdn)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCWHDN(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2]), + static_cast(stride[3])) + ) { } + + /// Helper returns a layout to a tightly packed CWHDN tensor. + CUTLASS_HOST_DEVICE + static TensorCWHDN packed(TensorCoord const &extent) { + return TensorCWHDN( + make_Coord( + extent.n(), + extent.d() * extent.n(), + extent.h() * extent.d() * extent.n(), + extent.w() * extent.h() * extent.d() * extent.n() + ) + ); + } + + /// Returns the offset of a coordinate (n, d, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.n() + + LongIndex(stride_[0] * coord.d()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.w()) + + LongIndex(stride_[3] * coord.c()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[3]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.n() > stride_[0]) + || (extent.d() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2]) + || (extent.w() * stride_[2] > stride_[3])) { + assert(0); + } + return extent.c() * stride_[3]; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layout +} // namespace cutlass diff --git a/examples/39_gemm_permute/permute_info.h b/examples/39_gemm_permute/permute_info.h new file mode 100644 index 00000000..99d21fb5 --- /dev/null +++ b/examples/39_gemm_permute/permute_info.h @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Contains additional metadata about layout permute functions used in the example. +*/ + +#include "cutlass/tensor_coord.h" +#include "cutlass/layout/permute.h" + +/// Additional permutation metadata to facilitate testing/printing +template +struct PermuteInfo; + +/// Specialization for default case (no permute). Other specializations must follow this template. +template<> +struct PermuteInfo { + + /// Whether this is a BMM or GEMM permutation (NoPermute can actually be either) + static bool constexpr kBatched = false; + + /// Minimal divisor for row extent + static int constexpr kRowFactor = 1; + + /// Minimum divisor for column extent + static int constexpr kColumnFactor = 1; + + /// Minimum divisor for batch size dimension + static int constexpr kBatchFactor = 1; + + /// Tensor layout used in permutation operation + using Layout = cutlass::layout::PackedVectorLayout; + + static std::string name() { + return "NoPermute"; + } + + /// User-friendly description of the permute operation + static std::string desc() { + return "no permutation"; + } + + /// Infer original higher-rank tensor shape from GEMM/BMM matrix extents. + /// For direct (output) permutations, must be a simple reshape of extent. + /// For inverse (input) permutations, must return shape *before* permute operation. + /// In case of NoPermute, simply use a linear (rank 1) view of the memory + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + return Layout::TensorCoord(extent.row() * extent.column() * batch_count); + } + + /// Compute the permuted higher-rank tensor shape from the original shape. + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return s; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = D1; + + using Layout = cutlass::layout::TensorNHWC; + + static std::string name() { + return "Tensor4DPermuteBMM0213<" + std::to_string(D1) + ">"; + } + + static std::string desc() { + return "batched GEMM permutation [0, 2, 1, 3]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count / D1; + int D2 = extent.row(); + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[2], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = D1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count; + int D2 = extent.row(); + int D3 = extent.column() / D1; + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = D1; + + using Layout = cutlass::layout::TensorNHCW; + + static std::string name() { + return "Tensor4DPermuteBMM0321<" + std::to_string(D1) + ">"; + } + + static std::string desc() { + return "batched GEMM permutation [0, 3, 2, 1]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count / D1; + int D2 = extent.row(); + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[3], s[2], s[1]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = D1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count; + int D2 = extent.row() / D1; + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = D1; + static int constexpr kColumnFactor = D2; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorNHWC; + + static std::string name() { + return "Tensor4DPermute0213<" + std::to_string(D1) + "," + std::to_string(D2) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [0, 2, 1, 3]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = extent.row() / D1; + int D3 = extent.column() / D2; + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[2], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = D2; + static int constexpr kColumnFactor = D1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = extent.row() / D2; + int D3 = extent.column() / D1; + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + using Layout = cutlass::layout::TensorCWHN; +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + using Layout = cutlass::layout::TensorCWHN; +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T1; + static int constexpr kColumnFactor = T2 * T3; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorNDHWC; + + static std::string name() { + return "Tensor5DPermute20314<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [2, 0, 3, 1, 4]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) + { + int const T0 = extent.row() / T1; + int const T4 = extent.column() / (T2 * T3); + return {T0, T1, T2, T3, T4}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) + { + return {s[2], s[0], s[3], s[1], s[4]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T2; + static int constexpr kColumnFactor = T1 * T3; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int const T0 = extent.row() / T2; + int const T4 = extent.column() / (T1 * T3); + return {T0, T1, T2, T3, T4}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T1; + static int constexpr kColumnFactor = T2 * T3; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorCWHDN; + + static std::string name() { + return "Tensor5DPermute02413<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [0, 2, 4, 1, 3]"; + } + + using Coord = cutlass::Tensor5DCoord; + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) + { + int const T0 = extent.row() / T1; + int const T4 = extent.column() / (T2 * T3); + return {T0, T1, T2, T3, T4}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) + { + return {s[0], s[2], s[4], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T2; + static int constexpr kColumnFactor = T1 * T3; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int const T0 = extent.row() / T2; + int const T4 = extent.column() / (T1 * T3); + return {T0, T1, T2, T3, T4}; + } +}; diff --git a/examples/40_cutlass_py/README.md b/examples/40_cutlass_py/README.md index 4d13ea3c..d33a6d53 100644 --- a/examples/40_cutlass_py/README.md +++ b/examples/40_cutlass_py/README.md @@ -1,10 +1,15 @@ -# CUTLASS Python Interface Examples -This directory contains examples of using CUTLASS's Python interface. It consists of two types of examples: +# PyCUTLASS Examples + +**NOTE:** This directory contains examples for PyCUTLASS, a Python library providing low-level +building blocks for emitting CUTLASS C++ kernels. For examples using CUTLASS's Pythonic interface, +see the [examples/python](/examples/python) directory. + +Two types of examples are provided: * _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations * [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel ## Setting up the Python interface -Please follow the instructions [here](/tools/library/scripts/pycutlass/README.md#installation) to set up the Python API. +Please follow the instructions [here](/python/README.md#installation) to set up the PyCUTLASS. ## Running examples Each of the basic examples can be run as follows: diff --git a/examples/40_cutlass_py/conv2d.py b/examples/40_cutlass_py/conv2d.py index 89565bb5..a21f9769 100644 --- a/examples/40_cutlass_py/conv2d.py +++ b/examples/40_cutlass_py/conv2d.py @@ -38,10 +38,11 @@ import numpy as np import sys -import cutlass -import pycutlass -from pycutlass import * -from pycutlass.utils.device import device_cc +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.reference_model import Conv2dReferenceModule +from cutlass.backend.utils.device import device_cc parser = argparse.ArgumentParser( @@ -76,11 +77,11 @@ pycutlass.compiler.nvcc() # Set up A, B, C and accumulator -A = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment) -B = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment) -C = TensorDescription(cutlass.float32, cutlass.TensorNHWC, alignment) -element_acc = cutlass.float32 -element_epilogue = cutlass.float32 +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.TensorNHWC, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 # Select instruction shape based on the Tensor Core instructions supported # by the device on which we are running @@ -89,12 +90,14 @@ elif cc == 75: instruction_shape = [16, 8, 8] else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 instruction_shape = [16, 8, 16] math_inst = MathInstruction( instruction_shape, A.element, B.element, element_acc, - cutlass.OpClass.TensorOp, + cutlass_bindings.OpClass.TensorOp, MathOperation.multiply_add ) @@ -108,8 +111,8 @@ epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, - iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.fprop, + iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=cc, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor @@ -125,20 +128,20 @@ # Randomly initialize tensors -problem_size = cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(args.n, args.h, args.c, args.w), - cutlass.Tensor4DCoord(args.k, args.r, args.s, args.c), - cutlass.Tensor4DCoord(0, 0, 0, 0), # Padding - cutlass.MatrixCoord(1, 1), # Strides - cutlass.MatrixCoord(1, 1), # Dilation - cutlass.conv.Mode.cross_correlation, +problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(args.n, args.h, args.c, args.w), + cutlass_bindings.Tensor4DCoord(args.k, args.r, args.s, args.c), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), # Padding + cutlass_bindings.MatrixCoord(1, 1), # Strides + cutlass_bindings.MatrixCoord(1, 1), # Dilation + cutlass_bindings.conv.Mode.cross_correlation, 1, # Split k slices 1 # Groups ) -tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size) -tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size) -tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size) +tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size) +tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size) +tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size) tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) diff --git a/examples/40_cutlass_py/customizable/conv2d.py b/examples/40_cutlass_py/customizable/conv2d.py index 4affff9a..6fb24944 100644 --- a/examples/40_cutlass_py/customizable/conv2d.py +++ b/examples/40_cutlass_py/customizable/conv2d.py @@ -30,11 +30,11 @@ # ################################################################################ import numpy as np -import pycutlass -from pycutlass import * -from pycutlass.conv2d_operation import * -from pycutlass.utils import reference_model -from pycutlass.utils.device import device_cc +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +from cutlass.backend.conv2d_operation import * +from cutlass.backend.utils.reference_model import Conv2dReferenceModule import sys import torch.nn.functional as F @@ -62,7 +62,7 @@ help='Data type of accumulator') parser.add_argument('-m', "--math", default="multiply_add", type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") -parser.add_argument('-op', "--opcode", default="simt", type=str, +parser.add_argument('-op', "--opcode", default="Simt", type=str, choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \ cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') @@ -156,12 +156,12 @@ np.random.seed(0) -element_a = getattr(cutlass, args.element_a) -element_b = getattr(cutlass, args.element_b) -element_c = getattr(cutlass, args.element_c) -element_acc = getattr(cutlass, args.element_acc) +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) math_operation = getattr(MathOperation, args.math) -opclass = getattr(cutlass.OpClass, args.opcode) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) math_inst = MathInstruction( args.instruction_shape, element_a, element_b, @@ -173,9 +173,9 @@ math_inst ) -layout_a = getattr(cutlass, args.layout_a) -layout_b = getattr(cutlass, args.layout_b) -layout_c = getattr(cutlass, args.layout_c) +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) A = TensorDescription( element_a, layout_a, args.alignment_a @@ -189,7 +189,7 @@ element_c, layout_c, args.alignment_c ) -element_epilogue = getattr(cutlass, args.element_epilogue) +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) if (args.activation_function == "identity" or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)): # @@ -200,10 +200,10 @@ getattr(pycutlass, args.activation_function)(element_epilogue), C.element, C.alignment, math_inst.element_accumulator, element_epilogue) -iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm) -swizzling_functor = getattr(cutlass, args.swizzling_functor) +iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, args.iterator_algorithm) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) stride_support = getattr(StrideSupport, args.stride_support) -conv_kind = getattr(cutlass.conv.Operator, args.conv_kind) +conv_kind = getattr(cutlass_bindings.conv.Operator, args.conv_kind) operation = Conv2dOperation( conv_kind=conv_kind, iterator_algorithm=iterator_algorithm, @@ -226,7 +226,7 @@ getattr(pycutlass, args.activation_function)(element_epilogue), C.element, C.alignment, math_inst.element_accumulator, element_epilogue) reduction_operation = ReductionOperation( - shape=cutlass.MatrixCoord(4, 32 * C.alignment), + shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment), C=C, element_accumulator=element_acc, element_compute=element_epilogue, epilogue_functor=epilogue_functor_reduction, @@ -236,34 +236,34 @@ pycutlass.compiler.add_module(operations) -problem_size = cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]), - cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]), - cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]), - cutlass.MatrixCoord(args.stride[0], args.stride[1]), - cutlass.MatrixCoord(args.dilation[0], args.dilation[1]), - cutlass.conv.Mode.cross_correlation, +problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]), + cutlass_bindings.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]), + cutlass_bindings.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]), + cutlass_bindings.MatrixCoord(args.stride[0], args.stride[1]), + cutlass_bindings.MatrixCoord(args.dilation[0], args.dilation[1]), + cutlass_bindings.conv.Mode.cross_correlation, args.split_k_slices, 1 ) # User-provide inputs -tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size( +tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size( conv_kind, problem_size ) -tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size( +tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size( conv_kind, problem_size ) if args.bias: - tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent( + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_extent( conv_kind, problem_size ).at(3) else: - tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size( + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( conv_kind, problem_size ) -tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size( +tensor_D_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( conv_kind, problem_size ) @@ -288,12 +288,12 @@ operation=operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), - split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode), + split_k_mode=getattr(cutlass_bindings.conv.SplitKMode, args.split_k_mode), split_k_slices=problem_size.split_k_slices ) if args.split_k_mode == "Parallel" and args.split_k_slices > 1: - implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size) + implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size) reduction_arguments = ReductionArguments( reduction_operation, problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], diff --git a/examples/40_cutlass_py/customizable/gemm.py b/examples/40_cutlass_py/customizable/gemm.py index 039c871a..745f6aac 100644 --- a/examples/40_cutlass_py/customizable/gemm.py +++ b/examples/40_cutlass_py/customizable/gemm.py @@ -30,10 +30,10 @@ # ################################################################################ import numpy as np -import pycutlass -from pycutlass import * -from pycutlass.utils.device import device_cc -import cutlass +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +import cutlass_bindings from bfloat16 import bfloat16 import sys @@ -62,7 +62,7 @@ help='Data type of accumulator') parser.add_argument('-m', "--math", default="multiply_add", type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") -parser.add_argument('-op', "--opcode", default="simt", type=str, +parser.add_argument('-op', "--opcode", default="Simt", type=str, choices=["Simt", 'TensorOp'], help="This option describes whether you want to use tensor \ cores (TensorOp) or regular SIMT cores (Simt) on GPU SM") @@ -147,12 +147,12 @@ np.random.seed(0) -element_a = getattr(cutlass, args.element_a) -element_b = getattr(cutlass, args.element_b) -element_c = getattr(cutlass, args.element_c) -element_acc = getattr(cutlass, args.element_acc) +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) math_operation = getattr(MathOperation, args.math) -opclass = getattr(cutlass.OpClass, args.opcode) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) math_inst = MathInstruction( args.instruction_shape, element_a, element_b, @@ -164,9 +164,9 @@ math_inst ) -layout_a = getattr(cutlass, args.layout_a) -layout_b = getattr(cutlass, args.layout_b) -layout_c = getattr(cutlass, args.layout_c) +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) A = TensorDescription( element_a, layout_a, args.alignment_a @@ -180,7 +180,7 @@ element_c, layout_c, args.alignment_c ) -element_epilogue = getattr(cutlass, args.element_epilogue) +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) if (args.activation_function == "identity" or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)): # @@ -191,7 +191,7 @@ getattr(pycutlass, args.activation_function)(element_epilogue), C.element, C.alignment, math_inst.element_accumulator, element_epilogue) -swizzling_functor = getattr(cutlass, args.swizzling_functor) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) visitor = args.epilogue_visitor is not None @@ -275,7 +275,7 @@ def __call__( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) reduction_operation = ReductionOperation( - shape=cutlass.MatrixCoord(4, 32 * C.alignment), + shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment), C=C, element_accumulator=element_acc, element_compute=element_epilogue, epilogue_functor=epilogue_functor_reduction, @@ -287,7 +287,7 @@ def __call__( # User-provide inputs -problem_size = cutlass.gemm.GemmCoord( +problem_size = cutlass_bindings.gemm.GemmCoord( args.problem_size[0], args.problem_size[1], args.problem_size[2]) tensor_a_size = args.batch * problem_size.m() * problem_size.k() @@ -384,7 +384,7 @@ def __call__( operation=operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, output_op=output_op, - gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode), + gemm_mode=getattr(cutlass_bindings.gemm.Mode, args.gemm_mode), split_k_slices=args.split_k_slices, batch=args.batch ) diff --git a/examples/40_cutlass_py/customizable/gemm_grouped.py b/examples/40_cutlass_py/customizable/gemm_grouped.py index 5a7551dc..0cecb328 100644 --- a/examples/40_cutlass_py/customizable/gemm_grouped.py +++ b/examples/40_cutlass_py/customizable/gemm_grouped.py @@ -30,9 +30,9 @@ # ################################################################################ import numpy as np -import pycutlass -from pycutlass import * -from pycutlass.utils.device import device_cc +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc import csv import sys @@ -61,7 +61,7 @@ help='Data type of accumulator') parser.add_argument('-m', "--math", default="multiply_add", type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") -parser.add_argument('-op', "--opcode", default="simt", type=str, +parser.add_argument('-op', "--opcode", default="Simt", type=str, choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \ cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') # tile description @@ -111,7 +111,7 @@ default="Device", type=str, choices=["Host", "Device"], help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)") # arguments -parser.add_argument("-p", "--problem_size_dir", type=str, +parser.add_argument("-p", "--problem_size_dir", type=str, default="grouped_gemm_problem_size.csv", help="path to the csv file contains the problem sizes") parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") @@ -139,12 +139,12 @@ np.random.seed(0) -element_a = getattr(cutlass, args.element_a) -element_b = getattr(cutlass, args.element_b) -element_c = getattr(cutlass, args.element_c) -element_acc = getattr(cutlass, args.element_acc) +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) math_operation = getattr(MathOperation, args.math) -opclass = getattr(cutlass.OpClass, args.opcode) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) math_inst = MathInstruction( args.instruction_shape, element_a, element_b, @@ -156,9 +156,9 @@ math_inst ) -layout_a = getattr(cutlass, args.layout_a) -layout_b = getattr(cutlass, args.layout_b) -layout_c = getattr(cutlass, args.layout_c) +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) A = TensorDescription( element_a, layout_a, args.alignment_a @@ -172,7 +172,7 @@ element_c, layout_c, args.alignment_c ) -element_epilogue = getattr(cutlass, args.element_epilogue) +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) if args.activation_function == "identity": epilogue_functor = getattr(pycutlass, args.epilogue_functor)( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) @@ -180,7 +180,7 @@ epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( getattr(pycutlass, args.activation_function)(element_epilogue), C.element, C.alignment, math_inst.element_accumulator, element_epilogue) -swizzling_functor = getattr(cutlass, args.swizzling_functor) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) precompute_mode = getattr(SchedulerMode, args.precompute_mode) operation = GemmOperationGrouped( @@ -203,7 +203,7 @@ reader = csv.reader(csv_file) for row in reader: problem_sizes.append( - cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2])) + cutlass_bindings.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2])) ) problem_count = len(problem_sizes) diff --git a/examples/40_cutlass_py/gemm.py b/examples/40_cutlass_py/gemm.py index db72d264..17b5d389 100644 --- a/examples/40_cutlass_py/gemm.py +++ b/examples/40_cutlass_py/gemm.py @@ -37,10 +37,10 @@ import numpy as np import sys -import cutlass -import pycutlass -from pycutlass import * -from pycutlass.utils.device import device_cc +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'") @@ -72,11 +72,11 @@ pycutlass.compiler.nvcc() # Set up A, B, C and accumulator -A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment) -B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment) -C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment) -element_acc = cutlass.float32 -element_epilogue = cutlass.float32 +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 # Select instruction shape based on the Tensor Core instructions supported # by the device on which we are running @@ -85,12 +85,14 @@ elif cc == 75: instruction_shape = [16, 8, 8] else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 instruction_shape = [16, 8, 16] math_inst = MathInstruction( instruction_shape, A.element, B.element, element_acc, - cutlass.OpClass.TensorOp, + cutlass_bindings.OpClass.TensorOp, MathOperation.multiply_add ) @@ -122,7 +124,7 @@ tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32) tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32) -problem_size = cutlass.gemm.GemmCoord(args.m, args.n, args.k) +problem_size = cutlass_bindings.gemm.GemmCoord(args.m, args.n, args.k) alpha = 1. beta = 0. diff --git a/examples/40_cutlass_py/gemm_grouped.py b/examples/40_cutlass_py/gemm_grouped.py index df23454f..16e25d0c 100644 --- a/examples/40_cutlass_py/gemm_grouped.py +++ b/examples/40_cutlass_py/gemm_grouped.py @@ -37,10 +37,10 @@ import numpy as np import sys -import cutlass -import pycutlass -from pycutlass import * -from pycutlass.utils.device import device_cc +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python") @@ -65,11 +65,11 @@ # Set up A, B, C and accumulator alignment = 1 -A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment) -B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment) -C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment) -element_acc = cutlass.float32 -element_epilogue = cutlass.float32 +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 # Select instruction shape based on the Tensor Core instructions supported # by the device on which we are running @@ -78,12 +78,14 @@ elif cc == 75: instruction_shape = [16, 8, 8] else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 instruction_shape = [16, 8, 16] math_inst = MathInstruction( instruction_shape, A.element, B.element, element_acc, - cutlass.OpClass.TensorOp, + cutlass_bindings.OpClass.TensorOp, MathOperation.multiply_add ) @@ -112,8 +114,8 @@ # Initialize tensors for each problem in the group problem_sizes = [ - cutlass.gemm.GemmCoord(128, 128, 64), - cutlass.gemm.GemmCoord(512, 256, 128) + cutlass_bindings.gemm.GemmCoord(128, 128, 64), + cutlass_bindings.gemm.GemmCoord(512, 256, 128) ] problem_count = len(problem_sizes) diff --git a/examples/45_dual_gemm/device/dual_gemm.h b/examples/45_dual_gemm/device/dual_gemm.h index 71f7973e..4939edc3 100644 --- a/examples/45_dual_gemm/device/dual_gemm.h +++ b/examples/45_dual_gemm/device/dual_gemm.h @@ -159,7 +159,7 @@ class DualGemm { using Mma0 = typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, - ThreadblockShape, WarpShape, + ThreadblockShape, WarpShape, InstructionShape, Stages, Operator>::ThreadblockMma; using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB, @@ -348,7 +348,7 @@ class DualGemm { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); diff --git a/examples/45_dual_gemm/dual_gemm.cu b/examples/45_dual_gemm/dual_gemm.cu index 75ef1502..ce7db2a1 100644 --- a/examples/45_dual_gemm/dual_gemm.cu +++ b/examples/45_dual_gemm/dual_gemm.cu @@ -167,10 +167,10 @@ bool run_nonfused_gemm_f16_sm80() { std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n"; bool pass = nonFusedGemm.run( - problem_size, - alpha0, - beta0, - alpha1, + problem_size, + alpha0, + beta0, + alpha1, beta1, true /* is_profiling */ ); @@ -248,10 +248,10 @@ bool run_fused_gemm_f16_sm80_shmem() { std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n"; bool passed = fusedGemm.run( - problem_size, - alpha0, - beta0, - alpha1, + problem_size, + alpha0, + beta0, + alpha1, beta1 ); @@ -301,11 +301,11 @@ bool run_batched_fused_gemm_f16_sm80_shmem() { std::cout << "Running Batched Fused FP16 TN GEMMs + Epilogue2...\n"; bool passed = fusedGemm.run( - batch_problem_size, - alpha0, - beta0, - alpha1, - beta1, + batch_problem_size, + alpha0, + beta0, + alpha1, + beta1, kBatchCount, false, /* broadcast_b1 */ false /* is_profiling */ @@ -358,11 +358,11 @@ bool run_broadcast_fused_gemm_f16_sm80_shmem() { std::cout << "Running Broadcast Fused FP16 TN GEMMs + Epilogue2...\n"; bool passed = fusedGemm.run( - problem_size, - alpha0, - beta0, - alpha1, - beta1, + problem_size, + alpha0, + beta0, + alpha1, + beta1, 1, /* batch_count */ true, /* broadcast_b1 */ true /* is_profiling */ @@ -415,11 +415,11 @@ bool run_batched_broadcast_fused_gemm_f16_sm80_shmem() { std::cout << "Running Batch Broadcast Fused FP16 TN GEMMs + Epilogue2...\n"; bool passed = fusedGemm.run( - batch_problem_size, - alpha0, - beta0, - alpha1, - beta1, + batch_problem_size, + alpha0, + beta0, + alpha1, + beta1, kBatchCount, true, /* broadcast_b1 */ false /* is_profiling */ @@ -444,11 +444,11 @@ int main() { }; std::string test_name = ( - "dual-gemm f16 bias=" + - std::to_string(kUseBias) + - " split_k_serial=" + + "dual-gemm f16 bias=" + + std::to_string(kUseBias) + + " split_k_serial=" + std::to_string(kSplitKSerial) + - " batch_count=" + + " batch_count=" + std::to_string(kBatchCount) ); diff --git a/examples/45_dual_gemm/dual_gemm_run.h b/examples/45_dual_gemm/dual_gemm_run.h index d6a52d58..bdaa6983 100644 --- a/examples/45_dual_gemm/dual_gemm_run.h +++ b/examples/45_dual_gemm/dual_gemm_run.h @@ -45,6 +45,7 @@ #include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/platform/platform.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/device/gemm_universal.h" @@ -356,13 +357,13 @@ struct NonFusedDualGemmRun for(int i = 0; i < runs; i++) { status = gemm_op_0(); - + CUTLASS_CHECK(status); } cudaEventRecord(stop1); for(int i = 0; i < runs; i++) { status = gemm_op_1(); - + CUTLASS_CHECK(status); } @@ -564,22 +565,22 @@ struct DualFusedGemmRun cutlass::HostTensor< typename DualGemm::ElementA, typename DualGemm::LayoutA> tensor_A0( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k())); cutlass::HostTensor< typename DualGemm::ElementB, typename DualGemm::LayoutB0> tensor_B0( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n())); cutlass::HostTensor< typename DualGemm::ElementC, typename DualGemm::LayoutC> tensor_C0( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); cutlass::HostTensor< @@ -589,22 +590,22 @@ struct DualFusedGemmRun cutlass::HostTensor< typename DualGemm::ElementC, typename DualGemm::LayoutC> tensor_D0( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); cutlass::HostTensor< typename DualGemm::ElementC, typename DualGemm::LayoutC> reference_D0( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); cutlass::HostTensor< typename DualGemm::ElementB, typename DualGemm::LayoutB1> tensor_B1( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n())); if (broadcast_b1) { tensor_B1.resize({problem_size.k(), batch_count}); @@ -613,8 +614,8 @@ struct DualFusedGemmRun cutlass::HostTensor< typename DualGemm::ElementC, typename DualGemm::LayoutC> tensor_C1( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); cutlass::HostTensor< @@ -624,29 +625,29 @@ struct DualFusedGemmRun cutlass::HostTensor< typename DualGemm::ElementC, typename DualGemm::LayoutC> tensor_D1( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); cutlass::HostTensor< typename DualGemm::ElementC, typename DualGemm::LayoutC> tensor_D2( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); cutlass::HostTensor< typename DualGemm::ElementC, typename DualGemm::LayoutC> reference_D1( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); cutlass::HostTensor< typename DualGemm::ElementC, typename DualGemm::LayoutC> reference_D2( - std::is_same::value ? - cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); @@ -712,16 +713,16 @@ struct DualFusedGemmRun ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}; } typename DualGemm::Arguments arguments{ - (batch_count > 1 ? - cutlass::gemm::DualGemmMode::kBatched : + (batch_count > 1 ? + cutlass::gemm::DualGemmMode::kBatched : cutlass::gemm::DualGemmMode::kGemm), problem_size, tensor_A0.device_ref(), tensor_B0.device_ref(), ref_B0, DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref, - (broadcast_b1 ? - typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) : + (broadcast_b1 ? + typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) : tensor_B1.device_ref()), ref_B1, DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref, @@ -793,15 +794,15 @@ struct DualFusedGemmRun using GemmUniversal0 = cutlass::gemm::device::GemmUniversal< typename DualGemm::ElementA, typename DualGemm::LayoutA, typename DualGemm::ElementB, typename DualGemm::LayoutB0, - typename DualGemm::ElementC, typename DualGemm::LayoutC, + typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementAccumulator >; GemmUniversal0 reference_gemm0; typename GemmUniversal0::Arguments args0 { - (batch_count > 1 ? - cutlass::gemm::GemmUniversalMode::kBatched : + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm), problem_size, batch_count, @@ -828,15 +829,15 @@ struct DualFusedGemmRun using GemmUniversal1 = cutlass::gemm::device::GemmUniversal< typename DualGemm::ElementA, typename DualGemm::LayoutA, typename DualGemm::ElementB, typename DualGemm::LayoutB1, - typename DualGemm::ElementC, typename DualGemm::LayoutC, + typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementAccumulator >; GemmUniversal1 reference_gemm1; typename GemmUniversal1::Arguments args1 { - (batch_count > 1 ? - cutlass::gemm::GemmUniversalMode::kBatched : + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm), problem_size, batch_count, @@ -861,7 +862,7 @@ struct DualFusedGemmRun CUTLASS_CHECK(status); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); cutlass::reference::device::TensorReLu(reference_D1.device_view()); } diff --git a/examples/45_dual_gemm/kernel/dual_gemm.h b/examples/45_dual_gemm/kernel/dual_gemm.h index 56ed9e7e..f0ad97db 100644 --- a/examples/45_dual_gemm/kernel/dual_gemm.h +++ b/examples/45_dual_gemm/kernel/dual_gemm.h @@ -300,7 +300,7 @@ struct DualGemm { int offset_k = 0; int problem_size_k = params.problem_size.k(); - ElementA *ptr_A0 = static_cast(params.ref_A0.data()); + ElementA *ptr_A0 = static_cast(params.ref_A0.data()); ElementB *ptr_B0 = static_cast(params.ref_B0.data()); ElementB *ptr_B1 = static_cast(params.ref_B1.data()); @@ -309,7 +309,7 @@ struct DualGemm { // if (params.mode == DualGemmMode::kGemm) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; @@ -413,11 +413,11 @@ struct DualGemm { int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - ElementC *ptr_C0 = static_cast(params.ref_C0.data()); - ElementC *ptr_C1 = static_cast(params.ref_C1.data()); - ElementC *ptr_D0 = static_cast(params.ref_D0.data()); - ElementC *ptr_D1 = static_cast(params.ref_D1.data()); - ElementC *ptr_D2 = static_cast(params.ref_D2.data()); + ElementC *ptr_C0 = static_cast(params.ref_C0.data()); + ElementC *ptr_C1 = static_cast(params.ref_C1.data()); + ElementC *ptr_D0 = static_cast(params.ref_D0.data()); + ElementC *ptr_D1 = static_cast(params.ref_D1.data()); + ElementC *ptr_D2 = static_cast(params.ref_D2.data()); // Construct the semaphore. Semaphore semaphore(params.semaphore + block_idx, thread_idx); @@ -425,7 +425,7 @@ struct DualGemm { if (params.mode == DualGemmMode::kGemm) { // If performing a reduction via split-K, fetch the initial synchronization if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - + // Fetch the synchronization lock initially but do not block. semaphore.fetch(); diff --git a/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu b/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu index 9a26e896..36bf775c 100644 --- a/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu +++ b/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu @@ -233,6 +233,17 @@ struct Options { return false; } + // Filter size passed through command line does not match filter size template parameter + if (filter_size.h() != FilterShape::kRow || filter_size.w() != FilterShape::kColumn) { + std::cerr << "Filter size passed in (" << filter_size.h() << "x" << filter_size.w() << ") " + << "must match the FilterShape template parameter of the convolution " + << "(" << FilterShape::kRow << "x" << FilterShape::kColumn << "). " + << "To use the filter shape passed in, change the FilterShape template " + << "parameter and recompile this example." + << std::endl; + return false; + } + return true; } @@ -319,9 +330,9 @@ struct Options { "table\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 " + << "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=32 " "--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n" - << "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 " + << "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=1 " "--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n"; return out; @@ -515,14 +526,13 @@ Result profile_convolution(Options const &options) { ElementOutput, LayoutOutput, ElementComputeEpilogue, - ElementAccumulator, - cutlass::NumericConverter >(problem_size, - tensor_a.host_ref(), - tensor_b.host_ref(), - tensor_c.host_ref(), - tensor_ref_d.host_ref(), - options.alpha, - options.beta); + ElementAccumulator >(problem_size, + tensor_a.host_ref(), + tensor_b.host_ref(), + tensor_c.host_ref(), + tensor_ref_d.host_ref(), + options.alpha, + options.beta); // Check if output from CUTLASS kernel and reference kernel are equal or not tensor_d.sync_host(); diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu index 599d1d50..2c7d0ba9 100644 --- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -60,6 +60,7 @@ #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" @@ -95,12 +96,13 @@ constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // M // C/D matrix configuration using ElementC = float; // Element type for C and D matrix operands using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) // Core kernel configurations using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TilesShape = Shape<_128,_128,_32>; // Threadblock-level tile size +using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder @@ -110,15 +112,20 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, - TilesShape, ClusterShape, + TileShape, ClusterShape, cutlass::gemm::collective::StageCountAuto, cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; -using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, // Indicates ProblemShape @@ -308,11 +315,8 @@ typename Gemm::Arguments args_from_options(const Options &options) typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {options.m, options.n, options.k}, - block_A.get(), - stride_A, - block_B.get(), - stride_B, - {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}} + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; return arguments; diff --git a/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu similarity index 78% rename from examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu rename to examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu index ccf74a65..001e8329 100644 --- a/examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder.cu +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -77,10 +77,27 @@ will fit in shared memory given the types of operands and the thread block shape, rather than simply using a single default value. - Note that one does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide - every template parameter to the gemm::collective::CollectiveMma. Specifying every template parameter in this - manner remains the primary API for using CUTLASS 3 kernels. The CollectiveBuilder is simply meant to be - a convenience interface. + CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of + the collective is specified via the schedule tags that corresond to the underlying collective's + dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto` + are special cases of these schedules that allow the builder to also decide the dispatch policy for you, + therefore letting the builder pick the collective specialization. + + CUTLASS builders make an attempt to pick the best schedule when `Auto` is provided such that the + assembled collctives have the best performance, but this is not a guarantee. A user relying on `Auto` + may get a free performance upgrade with newer CUTLASS releases in case we can provide more optimized + implementations that the builder can transparently assemble for `Auto`. + + If a user decides to let the builders pick the collective specialization via `Auto` schedules, + they must be used for both mainloop and epilogue alike to ensure compatibility between the + chosen collectives. Additionally, if a user chooses to opt in to a specific schedule, non-`Auto` + schedules must be used for both mainloop and epilogue builder schedules, and these schedules + must be compatible. + + One does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide + every template parameter to the `gemm::collective::CollectiveMma`. Specifying every template parameter + in this manner remains the primary API for using CUTLASS 3 kernels. `CollectiveBuilder`s are + simply meant to be a convenience interface. Note also that, while the selections made by CollectiveBuilder attempt to maximize performance, this is not a guarantee. Furthermore, the behavior of the CollectiveBuilder when `Auto` parameters are provided is subject @@ -94,7 +111,7 @@ extending the problem size with an additional tensor rank. Example usage: - $ ./examples/49_hopper_gemm_schedules_with_collective_builder/49_hopper_gemm_schedules_with_collective_builder \ + $ ./examples/49_hopper_with_collective_builder/49_collective_builder \ --m=2048 --n=2048 --k=2048 --l=2 */ @@ -108,6 +125,7 @@ #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" @@ -160,7 +178,7 @@ struct Options { /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "49_hopper_gemm_schedules_with_collective_builder\n\n" + out << "49_hopper_with_collective_builder\n\n" << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" << " performant kernels targeting NVIDIA's Hopper architecture.\n\n" << "Options:\n\n" @@ -212,14 +230,24 @@ bool initialize_block( // operation builders by specializing the GEMM only on the kernel schedule it will use and the // number of pipeline stages. // -// For either option, one can use a special `Auto` type that tells the CollectiveBuilder +// One can use a special `Auto` type that tells the CollectiveBuilder // to select an appropriate value on its own. The CollectiveBuilder will attempt to select -// values that will result in the most-performant kernel, but this is not a guarantee. Furthermore, -// the behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases +// configurations that will result in the most-performant kernel, but this is not a guarantee. +// +// If relying on 'Auto' schedules, all builders must use the 'Auto' schedule to ensure compatiblity. +// For example, if `KernelScheduleAuto` is used for the mainloop builder, `EpilogueScheduleAuto` must +// be used for the epilogue builder. +// +// Furthermore, if an override schedule is selected, both epilgoue and mainloop schedules must +// be specifically opt into a compatible selection. +// +// Behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases // -- do not rely on `Auto` if you require a specific scheduling policy. template < // Type of kernel schedule to generate - class KernelScheduleType = cutlass::gemm::collective::KernelScheduleAuto, + class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto, + // Type of epilogue schedule to generate + class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto, // Number of pipeline stages to use class StageCountType = cutlass::gemm::collective::StageCountAuto > @@ -230,23 +258,33 @@ struct ExampleRunner { using LayoutC = cutlass::layout::ColumnMajor; using LayoutD = cutlass::layout::ColumnMajor; - static constexpr int kAlignmentA = 8; - static constexpr int kAlignmentB = 8; + static constexpr int AlignmentA = 8; + static constexpr int AlignmentB = 8; + static constexpr int AlignmentC = 8; + static constexpr int AlignmentD = 8; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, AlignmentC, + cutlass::half_t, LayoutD, AlignmentD, + EpilogueScheduleType + >::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, kAlignmentA, - cutlass::half_t, LayoutB, kAlignmentB, + cutlass::half_t, LayoutA, AlignmentA, + cutlass::half_t, LayoutB, AlignmentB, float, Shape<_128,_128,_64>, Shape<_2,_1,_1>, - StageCountType, - KernelScheduleType + std::conditional_t, + cutlass::gemm::collective::StageCountAutoCarveout<(int)sizeof(typename CollectiveEpilogue::SharedStorage)>, + StageCountType>, + MainloopScheduleType >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveMainloop, @@ -262,10 +300,10 @@ struct ExampleRunner { using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; - using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); - using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B()); - using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); - using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); + using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t; + using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t; // // Data members @@ -356,11 +394,8 @@ struct ExampleRunner { typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - block_A.get(), - stride_A, - block_B.get(), - stride_B, - {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, hw_info }; @@ -477,42 +512,48 @@ int main(int argc, char const **args) { // selected and the maximum number of stages that can fit in shared memory will be selected. // // This example is equivalent to declaring - // ExampleRunner + // ExampleRunner< + // cutlass::gemm::collective::KernelScheduleAuto, + // cutlass::epilogue::collective::EpilogueScheduleAuto, + // cutlass::gemm::collective::StageCountAuto> // Each of the `Auto` types indicate that the CollectiveBuilder should determine the scheduling policy and // stage count. Note that the behavior of the CollectiveBuilder with `Auto` parameters is subject to change // -- do not rely on `Auto` if you require a specific scheduling policy. + // If you opt in to a non-'Auto' schedule, make sure all collectives are built using specific, compatible schedules. ExampleRunner<> auto_schedule_auto_stage_runner; passed = auto_schedule_auto_stage_runner.run(options, hw_info); print_result("Automatically-selected schedule and stage count", passed); // One can override the stage count used in the GEMM by replacing cutlass::gemm::collective::StageCountAuto // with the number of stages to use (5 in this case). - ExampleRunner auto_schedule_5_stage_runner; + ExampleRunner< + cutlass::gemm::collective::KernelScheduleAuto, + cutlass::epilogue::collective::EpilogueScheduleAuto, + _5> auto_schedule_5_stage_runner; + passed = auto_schedule_5_stage_runner.run(options, hw_info); print_result("Automatically-selected schedule with 5 stages", passed); // One can also override the scheduling policy to use. In this case, use the KernelTma scheduling - // policy, which specifies that the Hopper TMA feature should be used. - ExampleRunner tma_schedule_auto_stage_runner; + // policy, which specifies that the Hopper TMA feature should be used, and we also use an epilgoue + // that does not use any shared memory. + ExampleRunner tma_schedule_auto_stage_runner; passed = tma_schedule_auto_stage_runner.run(options, hw_info); print_result("TMA schedule with automatically-selected stage count", passed); // Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized - // scheduling policy. - // - // Note that, as of the CUTLASS 3.0 release, this is the default scheduling policy - // used by the CollectiveBuilder, so this declaration is equivalent to ExampleRunner<> and - // ExampleRunner. However, this default is subject to - // change in future releases -- do not rely on `Auto` if you require a specific scheduling policy. - ExampleRunner ws_schedule_auto_stage_runner; + // scheduling policy, and an epilgoue that does not use any shared memory. + ExampleRunner ws_schedule_auto_stage_runner; passed = ws_schedule_auto_stage_runner.run(options, hw_info); print_result("Warp-specialized TMA schedule with automatically-selected stage count", passed); // Finally, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized - // scheduling policy, leveraging persistent thread blocks. - ExampleRunner ws_persistent_schedule_auto_stage_runner; - passed = ws_persistent_schedule_auto_stage_runner.run(options, hw_info); - print_result("Persistent warp-specialized TMA schedule with automatically-selected stage count", passed); + // scheduling policy, TMA-based epilogue, leveraging persistent thread blocks. + ExampleRunner< + cutlass::gemm::KernelTmaWarpSpecializedPingpong, + cutlass::epilogue::TmaWarpSpecialized> ws_pingpong_schedule_auto_stage_runner; + passed = ws_pingpong_schedule_auto_stage_runner.run(options, hw_info); + print_result("Ping-pong warp-specialized TMA schedule with automatically-selected stage count", passed); #endif diff --git a/examples/49_hopper_gemm_schedules_with_collective_builder/CMakeLists.txt b/examples/49_hopper_gemm_with_collective_builder/CMakeLists.txt similarity index 93% rename from examples/49_hopper_gemm_schedules_with_collective_builder/CMakeLists.txt rename to examples/49_hopper_gemm_with_collective_builder/CMakeLists.txt index 30c6e5ea..53518d62 100644 --- a/examples/49_hopper_gemm_schedules_with_collective_builder/CMakeLists.txt +++ b/examples/49_hopper_gemm_with_collective_builder/CMakeLists.txt @@ -27,9 +27,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - +# Both filenames are shorter to avoid MAX_PATH issues on Windows. cutlass_example_add_executable( - 49_hopper_gemm_schedules_with_collective_builder - 49_hopper_gemm_schedules_with_collective_builder.cu + 49_collective_builder + 49_collective_builder.cu ) diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu index 7323cc39..f1595e5d 100644 --- a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu @@ -34,7 +34,7 @@ The following example shows how to assemble a custom GEMM kernel that spells out the Collectives directly instead of using a builder and, in the process, instance a more efficient Epilogue - (from `cutlass/epilogue/collective/epilogue.hpp`) instead of using the default epilogue. + (from `cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp`) instead of using the default epilogue. The GemmUniversal API takes 3 main template arguments: (1) the problem shape / extents @@ -65,7 +65,7 @@ #include "cute/tensor.hpp" #include "cutlass/util/command_line.h" #include "cutlass/tensor_ref.h" -#include "cutlass/epilogue/collective/epilogue.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" @@ -122,7 +122,7 @@ struct Options { /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "50_hopper_gemm_with_vectorized_epilogue\n\n" + out << "50_hopper_gemm_with_epilogue_swizzle\n\n" << "Hopper GEMM Example with Epilogue Swizzle.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement\n\n" @@ -286,11 +286,8 @@ struct ExampleRunner { typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - block_A.get(), - stride_A, - block_B.get(), - stride_B, - {block_C.get(), stride_C, block_D.get(), stride_D, {options.alpha, options.beta}}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, hw_info }; @@ -443,11 +440,11 @@ int main(int argc, char const **args) { cute::SM90_TMA_LOAD, cute::SM90_TMA_LOAD_MULTICAST>::type; - using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector< + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::ss_smem_selector< GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{})) >()); - using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector< + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::ss_smem_selector< GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{})) >()); @@ -494,14 +491,15 @@ int main(int argc, char const **args) { Stride<_16,_1>>, TileShapeS2R>; - using Epilogue = cutlass::epilogue::collective::Epilogue< + using Epilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, SmemLayout, Copy_Atom, TiledCopyS2R, - Copy_Atom>; + Copy_Atom>>; // // Assembling the GemmKernel diff --git a/examples/51_hopper_gett/gett_kernel.cuh b/examples/51_hopper_gett/gett_kernel.cuh index aa6b8357..8256a771 100644 --- a/examples/51_hopper_gett/gett_kernel.cuh +++ b/examples/51_hopper_gett/gett_kernel.cuh @@ -37,7 +37,7 @@ #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" namespace example { @@ -88,10 +88,12 @@ gett_kernel( cutlass::FloatRoundStyle::round_to_nearest, ElementC>; // No changes are required to the default epilogue - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< StrideC, StrideD, - EpilogueThreadOp>; + EpilogueThreadOp, + cutlass::gemm::EpilogueDefault>>; // CollectiveMma for GETTs can be built using the CollectiveBuilders using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -100,7 +102,7 @@ gett_kernel( ElementB, StrideB, 128 / cutlass::sizeof_bits::value, ElementAccumulator, TileShape, Shape<_1,_2,_1>, - cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::StageCountAutoCarveout, cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; @@ -115,9 +117,8 @@ gett_kernel( typename GettOperator::Arguments args { cutlass::gemm::GemmUniversalMode::kBatched, problem_shape_mnkl, - ptr_A, stride_a_mkl, - ptr_B, stride_b_nkl, - { ptr_C, stride_c_mnl, ptr_D, stride_d_mnl, {alpha, beta} } + { ptr_A, stride_a_mkl, ptr_B, stride_b_nkl }, + { {alpha, beta}, ptr_C, stride_c_mnl, ptr_D, stride_d_mnl } }; #if CUTLASS_DEBUG_TRACE_LEVEL > 0 diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index fe884a5b..aea1a89f 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -129,7 +129,7 @@ foreach(EXAMPLE 46_depthwise_simt_conv2dfprop 47_ampere_gemm_universal_streamk 48_hopper_warp_specialized_gemm - 49_hopper_gemm_schedules_with_collective_builder + 49_hopper_gemm_with_collective_builder 50_hopper_gemm_with_epilogue_swizzle 51_hopper_gett ) diff --git a/examples/common/helper.h b/examples/common/helper.h index b6357b24..61c63dc2 100644 --- a/examples/common/helper.h +++ b/examples/common/helper.h @@ -31,6 +31,7 @@ #pragma once #include "cuda_runtime.h" +#include /** * Panic wrapper for unwinding CUTLASS errors diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt index 97867ded..d27035fd 100644 --- a/examples/cute/tutorial/CMakeLists.txt +++ b/examples/cute/tutorial/CMakeLists.txt @@ -31,4 +31,3 @@ cutlass_example_add_executable( sgemm_nt_1 sgemm_nt_1.cu ) - diff --git a/examples/python/00_basic_gemm.ipynb b/examples/python/00_basic_gemm.ipynb new file mode 100644 index 00000000..f69f4d6e --- /dev/null +++ b/examples/python/00_basic_gemm.ipynb @@ -0,0 +1,340 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1ef96b3f", + "metadata": {}, + "source": [ + "# Basic example of using the CUTLASS Python interface\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n" + ] + }, + { + "cell_type": "markdown", + "id": "962324fd", + "metadata": {}, + "source": [ + "We first import various packages needed for the example and construct the input and output tensors that will be used in our example.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e324219", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import random\n", + "\n", + "import cutlass\n", + "\n", + "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", + "# omit this information.\n", + "print_module = True\n", + "\n", + "m = 128\n", + "n = m\n", + "k = m\n", + "\n", + "dtype = np.float16\n", + "type_A = np.float16\n", + "type_B = np.float16\n", + "type_C = np.float16\n", + "type_D = np.float16\n", + "\n", + "np.random.seed(1234)\n", + "random.seed(1234)\n", + "scope_min = -4\n", + "scope_max = 4\n", + "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", + "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", + "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", + "\n", + "alpha = np.float16(1.)\n", + "beta = np.float16(0.)\n", + "\n", + "tensor_D = np.zeros(tensor_C.shape).astype(type_D)" + ] + }, + { + "cell_type": "markdown", + "id": "f2c7bf48", + "metadata": {}, + "source": [ + "## Declaring and running a GEMM\n", + "To get started, one only needs to provide the tensors declared above to the `cutlass.op.Gemm` call.\n", + "This sets up a default GEMM operation for the given device on which you are running.\n", + "\n", + "Assuming that we are running on SM80, this default to using a GEMM that leverages FP16 Tensor Core operations.\n", + "\n", + "Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `true`, the C++ code that is emitted is printed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dfd8975", + "metadata": {}, + "outputs": [], + "source": [ + "# We specify `element_accumulator` here so as to match the kernel run by NumPy below. However,\n", + "# specifying `element_accumulator` is not required if it is the same as `element`\n", + "plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor, element_accumulator=np.float32)\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "4a5856de", + "metadata": {}, + "source": [ + "There are many other ways to construct a plan from `cutlass.op.Gemm` (e.g., by specifiying they types and layouts of each operand, by providing representative tensors as inputs). For more details on these, see the documentation in the `cutlass.op.Gemm` constructor." + ] + }, + { + "cell_type": "markdown", + "id": "945478ef", + "metadata": {}, + "source": [ + "We then compare the output to running the GEMM using NumPy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b669de6", + "metadata": {}, + "outputs": [], + "source": [ + "tensor_D_numpy = (alpha * (tensor_A @ tensor_B)) + (beta * tensor_C)\n", + "np.testing.assert_array_equal(tensor_D, tensor_D_numpy)" + ] + }, + { + "cell_type": "markdown", + "id": "ee5cbbbe", + "metadata": {}, + "source": [ + "Note that one could use the same kernel just declared for tensors provided by other frameworks beyond NumPy, such as PyTorch or CuPy." + ] + }, + { + "cell_type": "markdown", + "id": "b6c86493", + "metadata": {}, + "source": [ + "## Changing operation modes\n", + "By default, the CUTLASS Python interface will try to use Tensor Core operations whenever possible. If the configuration provided to `cutlass.op.Gemm` is not supported on Tensor Cores, the interface will fall back to using a SIMT kernel.\n", + "\n", + "The operation mode currently in use can be returned via the `plan.opclass` property. In this case Tensor Core operations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "529fda93", + "metadata": {}, + "outputs": [], + "source": [ + "print(plan.opclass)" + ] + }, + { + "cell_type": "markdown", + "id": "6d27c575", + "metadata": {}, + "source": [ + "Suppose that we don't want to use Tensor Cores for this GEMM. One can change to using CUTLASS's SIMT GEMMs by setting the plan's `opclass` field.\n", + "\n", + "As is shown in the printed output, the emitted kernel uses template parameters that fit CUTLASS's SIMT GEMMs.\n", + "\n", + "Also notice that, this time around, we provided tensor parameters to `plan.run()`. One is free to provide different parameters to `plan.run()` than were passed in at the initial call to `cutlass.op.Gemm`, provided that the passed-in tensors have the same data type and layout as those passed in on intialization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a44d35b", + "metadata": {}, + "outputs": [], + "source": [ + "tensor_D_simt = np.zeros(tensor_C.shape).astype(type_D)\n", + "plan.opclass = cutlass.OpcodeClass.Simt\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D_simt, alpha, beta, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "639dcb59", + "metadata": {}, + "source": [ + "If we compare the output of the Tensor Core and SIMT GEMMs we just ran we see that they are equal." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b480853", + "metadata": {}, + "outputs": [], + "source": [ + "np.testing.assert_array_equal(tensor_D, tensor_D_simt)" + ] + }, + { + "cell_type": "markdown", + "id": "0cce1eae", + "metadata": {}, + "source": [ + "## Running cached kernels\n", + "You may have noticed that the `plan.run()` calls for the previous two kernels took some time to execute. This is because the kernel being emitted had not yet been compiled.\n", + "\n", + "CUTLASS caches compiled binaries so that recompilation isn't necessary every time a kernel is run. For example, if we change modes back to using Tensor Cores and call `plan.run()` again (with a different set of tensor parameters), you'll find the call to return much faster." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8051e5e", + "metadata": {}, + "outputs": [], + "source": [ + "m = 2400\n", + "n = 3232\n", + "k = 4096\n", + "\n", + "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", + "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", + "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", + "tensor_D = np.zeros(tensor_C.shape).astype(type_D)\n", + "\n", + "alpha = np.float16(1.)\n", + "beta = np.float16(2.)\n", + "\n", + "plan.opclass = cutlass.OpcodeClass.TensorOp\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "52a4e318", + "metadata": {}, + "source": [ + "## Running non-default GEMMs\n", + "The previous examples showed how it is simple to get started running a default GEMM kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the GEMM?\n", + "\n", + "Under the hood, CUTLASS enumerates the different GEMM configuration parameters possible for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernels (e.g., cluster, threadblock, and warp shape)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c593be1", + "metadata": {}, + "outputs": [], + "source": [ + "tiles = plan.tile_descriptions()\n", + "print('{} tile descriptions returned'.format(len(tiles)))\n", + "num_print = 10\n", + "print('First {} tile descriptions are:'.format(num_print))\n", + "for td in tiles[:num_print]:\n", + " print(td)" + ] + }, + { + "cell_type": "markdown", + "id": "dc3ad875", + "metadata": {}, + "source": [ + "Next, we'll pick one of these configurations at random and compile and run it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8dc5287", + "metadata": {}, + "outputs": [], + "source": [ + "idx = random.randint(0, len(tiles)-1)\n", + "td = tiles[idx]\n", + "print('Tile description {} is: {}'.format(idx, td))\n", + "plan.compile(td)\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "c5a8b534", + "metadata": {}, + "source": [ + "One can also change the swizzling function used by the kernel. For example, one can modify the kernel to use the stream K feature of CUTLASS via:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5e88d17", + "metadata": {}, + "outputs": [], + "source": [ + "# Stream K is only supported pre-SM90 (at least when this example was written)\n", + "if plan.cc != 90:\n", + " plan.swizzling_functor = cutlass.swizzle.ThreadblockSwizzleStreamK\n", + " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "5a8ba2ba", + "metadata": {}, + "source": [ + "## Handling errors\n", + "The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n", + "\n", + "Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe7d0e42", + "metadata": {}, + "outputs": [], + "source": [ + "# td = tiles[0]\n", + "# td.stages = 8\n", + "# plan.compile(td)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "vscode": { + "interpreter": { + "hash": "0466d96796c9cd8f7a1cad264ff326ececc950ba2420e0256d5105fc1a3c6e70" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/01_epilogue.ipynb b/examples/python/01_epilogue.ipynb new file mode 100644 index 00000000..05ab60d6 --- /dev/null +++ b/examples/python/01_epilogue.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "5d24a692", + "metadata": {}, + "source": [ + "# Example of using elementwise activation functions in the CUTLASS Python interface\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "3ca993fe", + "metadata": {}, + "source": [ + "We first import various packages needed for the example and construct the input and output tensors that will be used in our example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63a70a3c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "import cutlass\n", + "\n", + "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", + "# omit this information.\n", + "print_module = True\n", + "\n", + "m = 256\n", + "n = m\n", + "k = m\n", + "\n", + "type_A = np.float16\n", + "type_B = np.float16\n", + "type_C = np.float16\n", + "type_D = np.float16\n", + "\n", + "np.random.seed(1234)\n", + "scope_min = -4\n", + "scope_max = 4\n", + "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", + "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", + "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", + "\n", + "alpha = np.float16(1.)\n", + "beta = np.float16(0.)\n", + "\n", + "tensor_D = np.zeros(tensor_C.shape).astype(type_D)" + ] + }, + { + "cell_type": "markdown", + "id": "1eb0d95b", + "metadata": {}, + "source": [ + "## Run a GEMM with an identity activation function\n", + "To begin, we simply run a default GEMM with an identity activation function. This performs the well-known operation `D = alpha * (A @ B) + beta * C`. This is the default activation function used, and does not need to be specified." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d257833", + "metadata": {}, + "outputs": [], + "source": [ + "plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "54961694", + "metadata": {}, + "source": [ + "## Run a GEMM with a ReLU element-wise activation function\n", + "CUTLASS makes it easy to support other element-wise activation functions. This results in performing an element-wise after the generic linear combination performed in a GEMM. If we call such an activation function `act`, the resulting formulation is:\n", + "```\n", + "D = alpha * (A @ B) + beta * C\n", + "D = act(D)\n", + "```\n", + "\n", + "Here, we will add a ReLU activation function. Given an input `x`, ReLU returns `max(x, 0)`.\n", + "\n", + "This is easy to do in CUTLASS. One only needs to set the plan's `activation` field." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fe49443", + "metadata": {}, + "outputs": [], + "source": [ + "tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)\n", + "plan.activation = cutlass.epilogue.relu\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "455d0a37", + "metadata": {}, + "source": [ + "We can now verify that the result of the GEMM that used a ReLU activation function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e32e7798", + "metadata": {}, + "outputs": [], + "source": [ + "relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D\n", + "np.testing.assert_array_equal(relu_ref, tensor_D_relu)" + ] + }, + { + "cell_type": "markdown", + "id": "cf959171", + "metadata": {}, + "source": [ + "## Other element-wise activation functions\n", + "CUTLASS supports a variety of widely-used element-wise activation functions. We can obtain a list of these functions via the `get_activations()` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e17d730", + "metadata": {}, + "outputs": [], + "source": [ + "activations = plan.activations()\n", + "for activation in activations:\n", + " print(activation)" + ] + }, + { + "cell_type": "markdown", + "id": "0e4599fa", + "metadata": {}, + "source": [ + "We can then run each of them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c3598c9", + "metadata": {}, + "outputs": [], + "source": [ + "for activation in activations:\n", + " print('=============================================================================================')\n", + " print(f'Compiling and running activation {activation}')\n", + " print('=============================================================================================')\n", + " plan.activation = activation\n", + " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "751f8d92", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/02_pytorch_extension_grouped_gemm.ipynb b/examples/python/02_pytorch_extension_grouped_gemm.ipynb new file mode 100644 index 00000000..567a583a --- /dev/null +++ b/examples/python/02_pytorch_extension_grouped_gemm.ipynb @@ -0,0 +1,264 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "6acbea5d", + "metadata": {}, + "source": [ + "# Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare\n", + "a grouped GEMM kernel and export it as a PyTorch CUDA extension.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n", + "\n", + "## Background on grouped GEMM\n", + "Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides)\n", + "in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM,\n", + "without the requirement that the sizes and strides of each GEMM be the same.\n", + "\n", + "For example, if one has `p` GEMMs with sizes:\n", + "```text\n", + "M_1 x N_1 x K_1\n", + "M_2 x N_2 x K_2\n", + "...\n", + "M_p x N_p x K_p\n", + "```\n", + "CUTLASS's grouped GEMM will execute these in a single CUDA kernel.\n", + "\n", + "Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would\n", + "insufficiently utilize the device in isolation.\n", + "\n", + "## Declaring a grouped GEMM via the CUTLASS Python interface\n", + "A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one\n", + "simply calls `cutlass.op.GroupedGemm`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdcf21d8", + "metadata": {}, + "outputs": [], + "source": [ + "import cutlass\n", + "import torch\n", + "\n", + "dtype = torch.float16\n", + "plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)" + ] + }, + { + "cell_type": "markdown", + "id": "514f40a4", + "metadata": {}, + "source": [ + "We can then compile and run this operation on a group of GEMMs. We'll first set up some utility functions to initialize GEMMs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2a7371e", + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "random.seed(2023)\n", + "\n", + "# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K\n", + "def initialize(dtype, M, N, K):\n", + " sizes = [(M, K), (K, N), (M, N), (M, N)]\n", + " return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]\n", + "\n", + "# Utility function to generate `problems` GEMMs of random sizes\n", + "def generate_problems(problems):\n", + " valid_sizes = [128, 256, 512, 1024]\n", + " As, Bs, Cs, Ds = [], [], [], []\n", + " for _ in range(problems):\n", + " M, N, K = [random.choice(valid_sizes) for _ in range(3)]\n", + " A, B, C, D = initialize(dtype, M, N, K)\n", + " As.append(A)\n", + " Bs.append(B)\n", + " Cs.append(C)\n", + " Ds.append(D)\n", + " return As, Bs, Cs, Ds" + ] + }, + { + "cell_type": "markdown", + "id": "590a3bc5", + "metadata": {}, + "source": [ + "We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "776c9233", + "metadata": {}, + "outputs": [], + "source": [ + "As, Bs, Cs, Ds, = generate_problems(50)\n", + "\n", + "plan.run(As, Bs, Cs, Ds, print_module=True)\n", + "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", + "\n", + "for d, d_torch in zip(Ds, Ds_torch):\n", + " assert torch.allclose(d, d_torch)" + ] + }, + { + "cell_type": "markdown", + "id": "766e4f03", + "metadata": {}, + "source": [ + "## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n", + "The procedure above allows one to quickly experiment with using a CUTLASS kernels However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This will avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.\n", + "\n", + "The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n", + "\n", + "To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a98dee6", + "metadata": {}, + "outputs": [], + "source": [ + "op = plan.construct()\n", + "grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)" + ] + }, + { + "cell_type": "markdown", + "id": "c8ca3991", + "metadata": {}, + "source": [ + "The `cutlass.emit.pytorch` function emits:\n", + "* `out/grouped_gemm_kernel.cu`: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors\n", + "* `out/grouped_gemm.cpp`: This file contains a C++ wrapper around the aforementioned CUTLASS kernel\n", + "* `setup.py`: This file contains the `setuptools` script for building and installing the generated extension\n", + "\n", + "The extension can be build from within the `module_output` directory by running:\n", + "```bash\n", + "TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py install\n", + "```\n", + "Where `TORCH_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run.\n", + "\n", + "See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details on this.\n", + "\n", + "The PyTorch CUDA extension could be built for this module by running:\n", + "```bash\n", + "cd out\n", + "TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py\n", + "```\n", + "(assuming that one is building for SM80)\n", + "\n", + "One could then use the kernel in a later PyTorch module by running:\n", + "\n", + "```python\n", + "import torch\n", + "import grouped_gemm\n", + "\n", + "grouped_gemm.run(As, Bs)\n", + "```\n", + "\n", + "In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n", + "Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n", + "and returns back the loaded extension.\n", + "\n", + "We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cecb26a4", + "metadata": {}, + "outputs": [], + "source": [ + "Ds = grouped_gemm.run(As, Bs)\n", + "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", + "for d, d_torch in zip(Ds, Ds_torch):\n", + " assert torch.allclose(d, d_torch)" + ] + }, + { + "cell_type": "markdown", + "id": "50db80e4", + "metadata": {}, + "source": [ + "Finally, we can profile our grouped GEMM extension:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b76805d3", + "metadata": {}, + "outputs": [], + "source": [ + "num_warmup = 20\n", + "num_profile = 100\n", + "\n", + "# Warmup iterations\n", + "for _ in range(num_warmup):\n", + " Ds = grouped_gemm.run(As, Bs)\n", + " Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", + " torch.cuda.synchronize()\n", + "\n", + "# Timing iterations\n", + "import time\n", + "grouped = 0\n", + "nongrouped = 0\n", + "for _ in range(num_profile):\n", + " start = time.time()\n", + " Ds = grouped_gemm.run(As, Bs)\n", + " torch.cuda.synchronize()\n", + " grouped += time.time() - start\n", + "\n", + " start = time.time()\n", + " Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", + " torch.cuda.synchronize()\n", + " nongrouped += time.time() - start\n", + "\n", + "print('Grouped: {:.3f} us'.format(grouped * 1e6/num_profile))\n", + "print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))\n", + "print('Speedup: {:.3f}'.format(nongrouped / grouped))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f22fc696", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/README.md b/examples/python/README.md new file mode 100644 index 00000000..fab167c9 --- /dev/null +++ b/examples/python/README.md @@ -0,0 +1,14 @@ +# Examples of using the CUTLASS Python interface + +* [00_basic_gemm](/examples/python/00_basic_gemm.ipynb) + + Shows how declare, configure, compile, and run a CUTLASS GEMM using the Python interface + +* [01_epilogue](/examples/python/01_epilogue.ipynb) + + Shows how to fuse elementwise activation functions to GEMMs via the Python interface + +* [02_pytorch_extension_grouped_gemm](/examples/python/02_pytorch_extension_grouped_gemm.ipynb) + + Shows how to declare, compile, and run a grouped GEMM operation via the Python interface, + along with how the emitted kernel can be easily exported to a PyTorch CUDA extension. diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index 04ceb051..65f80af8 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -171,7 +171,7 @@ copy_vec(Tensor const& src, { using SrcType = typename SrcEngine::value_type; using DstType = typename DstEngine::value_type; - if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType)) + if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType)) { /* @pre is_aligned(src.data()) && * is_aligned(dst.data()) @@ -259,4 +259,51 @@ copy(Copy_Atom const&, return copy(src, dst); } +////////////////////////////////////////// +// Special Auto-Vectorizing Overloads +////////////////////////////////////////// + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, CA_Args...> const& atom, + Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + static_assert(sizeof_bits::value == sizeof_bits::value); + static_assert((is_gmem::value && is_smem::value) || + (is_smem::value && is_gmem::value), + "Bulk Copy only supports gmem -> smem or smem -> gmem movement."); + // Do BulkCopy dispatch + using BULK_COPY_OP = conditional_t::value, + SM90_BULK_COPY_G2S, + SM90_BULK_COPY_S2G>; + + constexpr int N = decltype(max_common_vector(src, dst))::value; + + // Construct a new concrete Atom of the vector size + using N_BITS = Int::value>; + using COPY_ATOM = Copy_Atom, SrcType>; + auto bulk_atom = apply(atom.opargs_, [&](auto const&... args) { return COPY_ATOM{args...}; }); + + // Tile the src and dst to the Atom + auto tiler = right_inverse(dst.layout()).compose(Int{}); + +#if 0 + if (thread0()) { + print("copy -- found a max_common_vector of %d\n", N); + print(" "); print(src.data()); print(" o "); print(layout(src)); print("\n"); + print(" "); print(dst.data()); print(" o "); print(layout(dst)); print("\n"); + } +#endif + + return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler)); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + } // end namespace cute diff --git a/include/cute/algorithm/functional.hpp b/include/cute/algorithm/functional.hpp index e66cd975..ea17ecb9 100644 --- a/include/cute/algorithm/functional.hpp +++ b/include/cute/algorithm/functional.hpp @@ -30,10 +30,10 @@ **************************************************************************************************/ #pragma once -#include - #include +#include + /** C++14 extensions */ namespace cute { diff --git a/include/cute/algorithm/gemm.hpp b/include/cute/algorithm/gemm.hpp index 6e2ce612..329a1fe7 100644 --- a/include/cute/algorithm/gemm.hpp +++ b/include/cute/algorithm/gemm.hpp @@ -32,10 +32,12 @@ #include -#include +#include #include + +#include + #include -#include /** The gemm algorithm takes four (or three) tensors and computes * D += A * B + C @@ -281,17 +283,15 @@ gemm(MMA_Atom const& mma, CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); - // REGISTER .reuse OPTIMIZATIONS - auto M = size<1>(A); auto N = size<1>(B); - + // REGISTER .reuse OPTIMIZATIONS // 64-bit traversal specialization -- serpentine path - if (size<0>(A) * sizeof(typename Tensor::value_type) == 8 && - size<0>(B) * sizeof(typename Tensor::value_type) == 8) + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) { -#if 1 // NOTE: Must depend on the C-matrix order... (which we can test) - // Row-major iteration +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major serpentine iteration CUTE_UNROLL for (int m = 0; m < M; ++m) { CUTE_UNROLL @@ -301,7 +301,7 @@ gemm(MMA_Atom const& mma, } } #else - // Col-major iteration + // Col-major serpentine iteration CUTE_UNROLL for (int n = 0; n < N; ++n) { CUTE_UNROLL @@ -312,13 +312,12 @@ gemm(MMA_Atom const& mma, } #endif } else - // 32-bit traversal specialization -- kinked serpentine path - if (size<0>(A) * sizeof(typename Tensor::value_type) == 4 && - size<0>(B) * sizeof(typename Tensor::value_type) == 4) + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) { -#if 1 // NOTE: Must depend on the C-matrix order... (which we can test) - // Row-major iteration +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major kinked serpentine iteration CUTE_UNROLL for (int m = 0; m < M; m += 2) { CUTE_UNROLL @@ -332,7 +331,7 @@ gemm(MMA_Atom const& mma, } } #else - // Col-major iteration + // Col-major kinked serpentine iteration CUTE_UNROLL for (int n = 0; n < N; n += 2) { CUTE_UNROLL @@ -347,9 +346,36 @@ gemm(MMA_Atom const& mma, } } #endif - } else { - // Fallback to serpentine loop - // Col-major iteration + } else + // 64-bit + 32-bit traversal order -- keep A (64-bit) in the outer loop and serpentine B + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) { + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } + } else + // 32-bit + 64-bit traversal order -- keep B (64-bit) in the outer loop and serpentine A + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } else + // Fallback to serpentine loop + { + // Col-major serpentine iteration CUTE_UNROLL for (int n = 0; n < N; ++n) { CUTE_UNROLL @@ -504,9 +530,9 @@ gemm(ThrMMA const& thr_mma, using TypeB = typename TB::value_type; using TypeC = typename TC::value_type; - static_assert(std::is_same_v>, TypeA>, + static_assert(is_same_v>, TypeA>, "ALoadTransformOp functor must accept and return value of type TA::value_type"); - static_assert(std::is_same_v>, TypeB>, + static_assert(is_same_v>, TypeB>, "BLoadTransformOp functor must accept and return value of type TB::value_type"); // Original, static size of the problem diff --git a/include/cute/algorithm/prefer.hpp b/include/cute/algorithm/prefer.hpp index 700edff0..804896ce 100644 --- a/include/cute/algorithm/prefer.hpp +++ b/include/cute/algorithm/prefer.hpp @@ -34,7 +34,7 @@ namespace cute { // Infinite types that inherit from each other -template +template struct prefer : prefer {}; template <> diff --git a/include/cute/algorithm/tensor_algorithms.hpp b/include/cute/algorithm/tensor_algorithms.hpp index 258ddec6..5fac8f92 100644 --- a/include/cute/algorithm/tensor_algorithms.hpp +++ b/include/cute/algorithm/tensor_algorithms.hpp @@ -99,4 +99,25 @@ transform(Tensor&& tensor, UnaryOp&& op) return transform(tensor, std::forward(op)); } +// Similar to std::transform transforms one tensors and assigns it to another +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor& tensor_in, Tensor& tensor_out, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in); ++i) { + tensor_out(i) = static_cast(op)(tensor_in(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor&& tensor_in, Tensor&& tensor_out, UnaryOp&& op) +{ + return transform(tensor_in, tensor_out, std::forward(op)); +} + } // end namespace cute diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index 35b19f96..393a93af 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -32,11 +32,11 @@ #include +#include #include #include #include #include -#include /** Common algorithms on (hierarchical) tuples */ /** Style choice: @@ -150,7 +150,7 @@ CUTE_HOST_DEVICE constexpr auto for_each_leaf(T&& t, F&& f) { - if constexpr (is_tuple>::value) { + if constexpr (is_tuple>::value) { return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); } else { return f(static_cast(t)); @@ -205,6 +205,20 @@ transform_leaf(T const& t, F&& f) CUTE_GCC_UNREACHABLE; } +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t0, t1, [&](auto const& a, auto const& b) { return transform_leaf(a, b, f); }); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + // // find and find_if // @@ -258,25 +272,40 @@ find(T const& t, X const& x) } template +CUTE_HOST_DEVICE constexpr auto none_of(T const& t, F&& f) { - return cute::integral_constant::value>{}; + if constexpr (is_tuple::value) { + return cute::integral_constant::value>{}; + } else { + return not f(t); + } + + CUTE_GCC_UNREACHABLE; } template +CUTE_HOST_DEVICE constexpr auto all_of(T const& t, F&& f) { - auto not_f = [&](auto const& a) { return !f(a); }; - return cute::integral_constant::value>{}; + if constexpr (is_tuple::value) { + auto not_f = [&](auto const& a) { return not f(a); }; + return cute::integral_constant::value>{}; + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; } template +CUTE_HOST_DEVICE constexpr auto any_of(T const& t, F&& f) { - return cute::integral_constant{}; + return not none_of(t, f); } // @@ -340,7 +369,7 @@ CUTE_HOST_DEVICE constexpr auto fold(T&& t, V&& v, F&& f) { - if constexpr (is_tuple>::value) { + if constexpr (is_tuple>::value) { return detail::fold(static_cast(t), static_cast(v), f, @@ -357,11 +386,11 @@ CUTE_HOST_DEVICE constexpr decltype(auto) fold_first(T&& t, F&& f) { - if constexpr (is_tuple>::value) { + if constexpr (is_tuple>::value) { return detail::fold(static_cast(t), get<0>(static_cast(t)), f, - make_range<1,std::tuple_size>::value>{}); + make_range<1,tuple_size>::value>{}); } else { return static_cast(t); } @@ -753,12 +782,12 @@ escan(T const& t, V const& v, F&& f) namespace detail { -template +template CUTE_HOST_DEVICE constexpr auto -zip_(T const& t, seq) +zip_(Ts const&... ts) { - return cute::make_tuple(get(get(t))...); + return cute::make_tuple(get(ts)...); } template @@ -767,7 +796,7 @@ auto zip(T const& t, seq, seq) { static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); - return cute::make_tuple(detail::zip_(t, seq{})...); + return cute::make_tuple(zip_(get(t)...)...); } } // end namespace detail @@ -817,8 +846,8 @@ zip2_by(T const& t, TG const& guide, seq, seq) auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) - return cute::make_tuple(cute::make_tuple(get(split)...), - cute::make_tuple(get(split)..., get(t)...)); + return cute::make_tuple(cute::make_tuple(get<0>(get(split))...), + cute::make_tuple(get<1>(get(split))..., get(t)...)); } } // end namespace detail diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp index 6fd9edd3..e9c858d5 100644 --- a/include/cute/arch/cluster_sm90.hpp +++ b/include/cute/arch/cluster_sm90.hpp @@ -49,7 +49,7 @@ CUTE_DEVICE void cluster_arrive_relaxed() #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : ); #else - asm volatile ("brkpt;\n" ::); + CUTE_RUNTIME_ASSERT("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); #endif } @@ -58,7 +58,7 @@ CUTE_DEVICE void cluster_arrive() #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) asm volatile("barrier.cluster.arrive.aligned;\n" : : ); #else - asm volatile ("brkpt;\n" ::); + CUTE_RUNTIME_ASSERT("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); #endif } @@ -67,7 +67,7 @@ CUTE_DEVICE void cluster_wait() #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) asm volatile("barrier.cluster.wait.aligned;\n" : : ); #else - asm volatile ("brkpt;\n" ::); + CUTE_RUNTIME_ASSERT("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); #endif } @@ -77,7 +77,7 @@ CUTE_DEVICE void cluster_sync() cluster_arrive(); cluster_wait(); #else - asm volatile ("brkpt;\n" ::); + CUTE_RUNTIME_ASSERT("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); #endif } @@ -90,8 +90,13 @@ CUTE_DEVICE dim3 cluster_grid_dims() asm volatile("mov.u32 %0, %nclusterid.y;\n" : "=r"(y) : ); asm volatile("mov.u32 %0, %nclusterid.z;\n" : "=r"(z) : ); return {x, y, z}; -#else +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of gridDim with __CUDA_ARCH__. return gridDim; +#elif defined(_MSC_VER) + CUTE_RUNTIME_ASSERT("cluster_grid_dims() can only be called on device"); +#else + return {0, 0, 0}; #endif } @@ -104,8 +109,13 @@ CUTE_DEVICE dim3 cluster_id_in_grid() asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(y) : ); asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(z) : ); return {x, y, z}; -#else +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of blockIdx with __CUDA_ARCH__. return blockIdx; +#elif defined(_MSC_VER) + CUTE_RUNTIME_ASSERT("cluster_id_in_grid() can only be called on device"); +#else + return {0, 0, 0}; #endif } @@ -154,8 +164,8 @@ CUTLASS_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) { #if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) uint32_t result; - asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" - : "=r"(result) + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(result) : "r"(smemAddr), "r"(rank)); return result; #else @@ -187,4 +197,34 @@ CUTE_HOST_DEVICE uint32_t elect_one_sync() #endif } +struct ElectOneLaneIdReturnType { + uint32_t is_leader; + uint32_t leader_lane_id; +}; + +CUTE_HOST_DEVICE +ElectOneLaneIdReturnType +elect_one_leader_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %rx;\n" + ".reg .pred %px;\n" + " elect.sync %rx|%px, %2;\n" + "@%px mov.s32 %1, 1;\n" + " mov.s32 %0, %rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return {pred, laneid}; +#elif defined(__CUDA_ARCH__) + return {(threadIdx.x % 32) == 0, 0}; +#else + return {true, 0}; +#endif +} + } // end namespace cute diff --git a/include/cute/arch/copy_sm75.hpp b/include/cute/arch/copy_sm75.hpp index 58b42864..0929b327 100644 --- a/include/cute/arch/copy_sm75.hpp +++ b/include/cute/arch/copy_sm75.hpp @@ -37,7 +37,7 @@ // Config #if defined(__clang__) && defined(__CUDA__) // ldmatrix PTX instructions added in Clang 14: https://reviews.llvm.org/D107046 - // ... but broken until Clang 15: + // ... but will not work until Clang 15: // * https://reviews.llvm.org/D121666 // * https://reviews.llvm.org/D126846 #define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15) diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index ca8320f6..f69c2bd5 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -30,7 +30,10 @@ **************************************************************************************************/ #pragma once +#if !defined(__CUDACC_RTC__) #include +#include +#endif #include @@ -135,18 +138,18 @@ enum class SmemSwizzleBits : uint8_t { template inline CUtensorMapDataType to_CUtensorMapDataType() { - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else - if constexpr (std::is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else + if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } } diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index d6025e4a..412754c7 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -169,7 +169,7 @@ struct SM90_TMA_LOAD void const* const smem_ptr, int32_t const& crd0) { - return SM90_TMA_LOAD_1D::copy(desc_ptr, smem_mbar, smem_ptr, crd0); + return SM90_TMA_LOAD_1D::copy(desc_ptr, smem_mbar, smem_ptr, crd0); } CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, @@ -201,11 +201,138 @@ struct SM90_TMA_LOAD } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5}], [%2], {%6};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::copy(desc_ptr, smem_mbar, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::copy(desc_ptr, smem_mbar, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::copy(desc_ptr, smem_mbar, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// /// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory //////////////////////////////////////////////////////////////////////////////////////////////////// -struct SM90_TMA_LOAD_1D_MULTICAST +struct SM90_TMA_LOAD_MULTICAST_1D { CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, @@ -230,7 +357,7 @@ struct SM90_TMA_LOAD_1D_MULTICAST } }; -struct SM90_TMA_LOAD_2D_MULTICAST +struct SM90_TMA_LOAD_MULTICAST_2D { CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, @@ -255,7 +382,7 @@ struct SM90_TMA_LOAD_2D_MULTICAST } }; -struct SM90_TMA_LOAD_3D_MULTICAST +struct SM90_TMA_LOAD_MULTICAST_3D { CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, @@ -280,7 +407,7 @@ struct SM90_TMA_LOAD_3D_MULTICAST } }; -struct SM90_TMA_LOAD_4D_MULTICAST +struct SM90_TMA_LOAD_MULTICAST_4D { CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, @@ -305,7 +432,7 @@ struct SM90_TMA_LOAD_4D_MULTICAST } }; -struct SM90_TMA_LOAD_5D_MULTICAST +struct SM90_TMA_LOAD_MULTICAST_5D { CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, @@ -337,35 +464,174 @@ struct SM90_TMA_LOAD_MULTICAST void const* const smem_ptr, int32_t const& crd0) { - return SM90_TMA_LOAD_1D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0); + return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0); } CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1) { - return SM90_TMA_LOAD_2D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1); + return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1); } CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { - return SM90_TMA_LOAD_3D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2); + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2); } CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) { - return SM90_TMA_LOAD_4D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3); + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3); } CUTE_HOST_DEVICE static void copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) { - return SM90_TMA_LOAD_5D_MULTICAST::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4); + return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, smem_mbar, multicast_mask, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + uint16_t const& multicast_mask, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), + "h"(multicast_mask) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + uint16_t const& multicast_mask, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), + "h"(multicast_mask) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + uint16_t const& multicast_mask, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), + "h"(multicast_mask) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + uint16_t const& multicast_mask, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, smem_mbar, + multicast_mask, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + uint16_t const& multicast_mask, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, smem_mbar, + multicast_mask, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, uint64_t& smem_mbar, + uint16_t const& multicast_mask, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, smem_mbar, + multicast_mask, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); } }; @@ -533,7 +799,7 @@ tma_store_arrive() { } // Wait on prior N (Count) TMA_STORE instructions to complete -template +template CUTE_HOST_DEVICE static void tma_store_wait() { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) @@ -547,6 +813,49 @@ tma_store_wait() { #endif } +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// BULK_COPY : Copy a bulk of memory between shared memory and global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_BULK_COPY_G2S +{ + CUTE_HOST_DEVICE static void + copy(void const* const gmem_ptr, uint64_t& smem_mbar, + void const* const smem_ptr, int32_t load_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(&smem_mbar); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(load_bytes), "r"(smem_int_mbar) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_BULK_COPY_S2G +{ + CUTE_HOST_DEVICE static void + copy(void const* const smem_ptr, + void const* const gmem_ptr, int32_t store_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_BULK_COPY_AUTO {}; + //////////////////////////////////////////////////////////////////////////////////////////////////// } // end namespace cute diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp index 08fe2b28..42778c80 100644 --- a/include/cute/arch/mma_sm90.hpp +++ b/include/cute/arch/mma_sm90.hpp @@ -36,7 +36,7 @@ #include // Config -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) # define CUTE_ARCH_MMA_SM90_ENABLED #endif @@ -342,7 +342,7 @@ struct SM90_16x8x16_C64C64C64C64_TN namespace cute { namespace GMMA { -template< +template < class ElementA, class ElementB, class ElementC, @@ -362,9 +362,9 @@ ss_op_selector() auto Tile_N = size<1>(TileShape_MNK{}); // FP16 accumulator - if constexpr (std::is_same_v) { - static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); - static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); + if constexpr (is_same_v) { + static_assert(is_same_v, "Element types for AB must be half if ElementC is half."); + static_assert(is_same_v, "Element types for AB must be half if ElementC is half."); static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); // Dispatch against the Tile N mode size @@ -398,11 +398,11 @@ ss_op_selector() } // FP32 accumulator - else if constexpr (std::is_same_v) { + else if constexpr (is_same_v) { // FP16 inputs - if constexpr (std::is_same_v) { - static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); + if constexpr (is_same_v) { + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F32F16F16_SS{}; @@ -434,8 +434,8 @@ ss_op_selector() } // BF16 inputs - else if constexpr (std::is_same_v) { - static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); + else if constexpr (is_same_v) { + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { @@ -468,8 +468,8 @@ ss_op_selector() } // TF32 inputs - else if constexpr (std::is_same_v) { - static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); + else if constexpr (is_same_v) { + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); @@ -508,36 +508,36 @@ ss_op_selector() } // S32 accumulator - else if constexpr (std::is_same_v) { + else if constexpr (is_same_v) { static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); // ElementA == int8_t && ElementB == int8_t - if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (is_same_v && is_same_v) { if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8S8_SS_TN{}; + return SM90_64x256x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8S8_SS_TN{}; + return SM90_64x192x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8S8_SS_TN{}; + return SM90_64x128x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8S8_SS_TN{}; + return SM90_64x96x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8S8_SS_TN{}; + return SM90_64x64x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8S8_SS_TN{}; + return SM90_64x32x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8S8_SS_TN{}; + return SM90_64x16x32_S32S8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8S8_SS_TN{}; + return SM90_64x8x32_S32S8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -545,32 +545,32 @@ ss_op_selector() } // ElementA == int8_t && ElementB == uint8_t - else if constexpr (std::is_same_v && std::is_same_v) { + else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8U8_SS_TN{}; + return SM90_64x256x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8U8_SS_TN{}; + return SM90_64x192x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8U8_SS_TN{}; + return SM90_64x128x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8U8_SS_TN{}; + return SM90_64x96x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8U8_SS_TN{}; + return SM90_64x64x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8U8_SS_TN{}; + return SM90_64x32x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8U8_SS_TN{}; + return SM90_64x16x32_S32S8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8U8_SS_TN{}; + return SM90_64x8x32_S32S8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -578,32 +578,32 @@ ss_op_selector() } // ElementA == uint8_t && ElementB == int8_t - else if constexpr (std::is_same_v && std::is_same_v) { + else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8S8_SS_TN{}; + return SM90_64x256x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8S8_SS_TN{}; + return SM90_64x192x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8S8_SS_TN{}; + return SM90_64x128x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8S8_SS_TN{}; + return SM90_64x96x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8S8_SS_TN{}; + return SM90_64x64x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8S8_SS_TN{}; + return SM90_64x32x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8S8_SS_TN{}; + return SM90_64x16x32_S32U8S8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8S8_SS_TN{}; + return SM90_64x8x32_S32U8S8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -611,32 +611,32 @@ ss_op_selector() } // ElementA == uint8_t && ElementB == uint8_t - else if constexpr (std::is_same_v && std::is_same_v) { + else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8U8_SS_TN{}; + return SM90_64x256x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8U8_SS_TN{}; + return SM90_64x192x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8U8_SS_TN{}; + return SM90_64x128x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8U8_SS_TN{}; + return SM90_64x96x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8U8_SS_TN{}; + return SM90_64x64x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8U8_SS_TN{}; + return SM90_64x32x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8U8_SS_TN{}; + return SM90_64x16x32_S32U8U8_SS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8U8_SS_TN{}; + return SM90_64x8x32_S32U8U8_SS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -650,7 +650,7 @@ ss_op_selector() } } -template< +template < class ElementA, class ElementB, class ElementC, @@ -671,9 +671,9 @@ rs_op_selector() auto Tile_N = size<1>(TileShape_MNK{}); // FP16 accumulator - if constexpr (std::is_same_v) { - static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); - static_assert(std::is_same_v, "Element types for AB must be half if ElementC is half."); + if constexpr (is_same_v) { + static_assert(is_same_v, "Element types for AB must be half if ElementC is half."); + static_assert(is_same_v, "Element types for AB must be half if ElementC is half."); static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); // Dispatch against the Tile N mode size @@ -707,12 +707,12 @@ rs_op_selector() } // FP32 accumulator - else if constexpr (std::is_same_v) { - static_assert(std::is_same_v, "ElementA and ElementB must be the same type for this config."); + else if constexpr (is_same_v) { + static_assert(is_same_v, "ElementA and ElementB must be the same type for this config."); static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); // FP16 inputs - if constexpr (std::is_same_v) { + if constexpr (is_same_v) { if constexpr (Tile_N % 256 == 0) { return SM90_64x256x16_F32F16F16_RS{}; } @@ -743,7 +743,7 @@ rs_op_selector() } // BF16 inputs - else if constexpr (std::is_same_v) { + else if constexpr (is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); if constexpr (Tile_N % 256 == 0) { @@ -776,7 +776,7 @@ rs_op_selector() } // TF32 inputs - else if constexpr (std::is_same_v) { + else if constexpr (is_same_v) { static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); @@ -815,35 +815,35 @@ rs_op_selector() } // S32 accumulator - else if constexpr (std::is_same_v) { + else if constexpr (is_same_v) { static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); // ElementA == int8_t && ElementB == int8_t - if constexpr (std::is_same_v && std::is_same_v) { + if constexpr (is_same_v && is_same_v) { if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8S8_RS_TN{}; + return SM90_64x256x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8S8_RS_TN{}; + return SM90_64x192x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8S8_RS_TN{}; + return SM90_64x128x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8S8_RS_TN{}; + return SM90_64x96x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8S8_RS_TN{}; + return SM90_64x64x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8S8_RS_TN{}; + return SM90_64x32x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8S8_RS_TN{}; + return SM90_64x16x32_S32S8S8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8S8_RS_TN{}; + return SM90_64x8x32_S32S8S8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -851,32 +851,32 @@ rs_op_selector() } // ElementA == int8_t && ElementB == uint8_t - else if constexpr (std::is_same_v && std::is_same_v) { + else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32S8U8_RS_TN{}; + return SM90_64x256x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32S8U8_RS_TN{}; + return SM90_64x192x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32S8U8_RS_TN{}; + return SM90_64x128x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32S8U8_RS_TN{}; + return SM90_64x96x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32S8U8_RS_TN{}; + return SM90_64x64x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32S8U8_RS_TN{}; + return SM90_64x32x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32S8U8_RS_TN{}; + return SM90_64x16x32_S32S8U8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32S8U8_RS_TN{}; + return SM90_64x8x32_S32S8U8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -884,32 +884,32 @@ rs_op_selector() } // ElementA == uint8_t && ElementB == int8_t - else if constexpr (std::is_same_v && std::is_same_v) { + else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8S8_RS_TN{}; + return SM90_64x256x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8S8_RS_TN{}; + return SM90_64x192x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8S8_RS_TN{}; + return SM90_64x128x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8S8_RS_TN{}; + return SM90_64x96x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8S8_RS_TN{}; + return SM90_64x64x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8S8_RS_TN{}; + return SM90_64x32x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8S8_RS_TN{}; + return SM90_64x16x32_S32U8S8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8S8_RS_TN{}; + return SM90_64x8x32_S32U8S8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); @@ -917,32 +917,32 @@ rs_op_selector() } // ElementA == uint8_t && ElementB == uint8_t - else if constexpr (std::is_same_v && std::is_same_v) { + else if constexpr (is_same_v && is_same_v) { static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); if constexpr (Tile_N % 256 == 0) { - return SM90_64x256x32_S32U8U8_RS_TN{}; + return SM90_64x256x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 192 == 0) { - return SM90_64x192x32_S32U8U8_RS_TN{}; + return SM90_64x192x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 128 == 0) { - return SM90_64x128x32_S32U8U8_RS_TN{}; + return SM90_64x128x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 96 == 0) { - return SM90_64x96x32_S32U8U8_RS_TN{}; + return SM90_64x96x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 64 == 0) { - return SM90_64x64x32_S32U8U8_RS_TN{}; + return SM90_64x64x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 32 == 0) { - return SM90_64x32x32_S32U8U8_RS_TN{}; + return SM90_64x32x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 16 == 0) { - return SM90_64x16x32_S32U8U8_RS_TN{}; + return SM90_64x16x32_S32U8U8_RS_TN{}; } else if constexpr (Tile_N % 8 == 0) { - return SM90_64x8x32_S32U8U8_RS_TN{}; + return SM90_64x8x32_S32U8U8_RS_TN{}; } else { static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp index abac5170..dd4e1fb5 100644 --- a/include/cute/arch/mma_sm90_desc.hpp +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -31,13 +31,17 @@ #pragma once +#if !defined(__CUDACC_RTC__) +#include +#endif + #include #include // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -# define CUTE_ARCH_MMA_SM90_ENABLED +# define CUTE_ARCH_MMA_SM90A_ENABLED #endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -68,6 +72,7 @@ CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { return nullptr; } +#if !defined(__CUDACC_RTC__) // Output operator for all enums in this namespace CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { char const* s = to_string(t); @@ -78,6 +83,7 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { } return os; } +#endif // !defined(__CUDACC_RTC__) } // end namespace GMMA @@ -115,12 +121,14 @@ union GmmaDescriptor // Printer CUTE_HOST_DEVICE friend void print(GmmaDescriptor const& t) { - printf("GmmaDescriptor: 0x%016lx\n", t.desc_); + #if !defined(__CUDACC_RTC__) + printf("GmmaDescriptor: 0x%016" PRIx64 "\n", t.desc_); printf(" start_addr : 0x%04x\n", t.start_address_); printf(" leading_off: 0x%04x (%d)\n", t.leading_byte_offset_, t.leading_byte_offset_); printf(" stride_off : 0x%04x (%d)\n", t.stride_byte_offset_, t.stride_byte_offset_); printf(" base_offset: 0x%01x\n", t.base_offset_); printf(" layout_type: 0x%01x (%s)\n", t.layout_type_, to_string(static_cast(t.layout_type_))); + #endif } }; diff --git a/include/cute/arch/mma_sm90_gmma.hpp b/include/cute/arch/mma_sm90_gmma.hpp index 25a1d171..db4083ee 100644 --- a/include/cute/arch/mma_sm90_gmma.hpp +++ b/include/cute/arch/mma_sm90_gmma.hpp @@ -35,7 +35,7 @@ // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) -# define CUTE_ARCH_MMA_SM90_ENABLED +# define CUTE_ARCH_MMA_SM90A_ENABLED #endif namespace cute { @@ -47,10 +47,10 @@ CUTE_HOST_DEVICE void warpgroup_arrive() { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); #else - CUTE_RUNTIME_ASSERT("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } @@ -60,10 +60,10 @@ void warpgroup_wait() { static_assert(N >= 0 && N <= 7, "_warpgroup.wait {N}; must be in range [0, 7]"); -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); #else - CUTE_RUNTIME_ASSERT("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } @@ -72,23 +72,30 @@ CUTE_HOST_DEVICE void warpgroup_commit_batch() { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); #else - CUTE_RUNTIME_ASSERT("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } CUTE_HOST_DEVICE void warpgroup_fence_operand(uint32_t& reg) { + // MSVC emits a build error for 'asm volatile' + // even if it only occurs in a __device__ function. + // This prevents the error. +#if defined(__CUDA_ARCH__) asm volatile("" : "+r"(reg) :: "memory"); +#endif } CUTE_HOST_DEVICE void warpgroup_fence_operand(float& reg) { +#if defined(__CUDA_ARCH__) asm volatile("" : "+f"(reg) :: "memory"); +#endif } namespace GMMA { @@ -115,10 +122,9 @@ enum class ScaleIn { //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -132,21 +138,26 @@ struct SM90_64x8x16_F16F16F16_SS CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1) + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " "{%0, %1}," " %2," " %3," - " %4, %5, %6, %7, %8;\n" + " p, %5, %6, %7, %8;\n" + "}\n" : "+r"(d0), "+r"(d1) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -154,10 +165,9 @@ struct SM90_64x8x16_F16F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -174,21 +184,26 @@ struct SM90_64x8x16_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1) + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " "{%0, %1}," "{%2, %3, %4, %5}," " %6," - " %7, %8, %9, %10;\n" + " p, %8, %9, %10;\n" + "}\n" : "+r"(d0), "+r"(d1) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -196,10 +211,9 @@ struct SM90_64x8x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -213,21 +227,26 @@ struct SM90_64x16x16_F16F16F16_SS CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " "{%0, %1, %2, %3}," " %4," " %5," - " %6, %7, %8, %9, %10;\n" + " p, %7, %8, %9, %10;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -235,10 +254,9 @@ struct SM90_64x16x16_F16F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -255,21 +273,26 @@ struct SM90_64x16x16_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9, %10, %11, %12;\n" + " p, %10, %11, %12;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -277,10 +300,9 @@ struct SM90_64x16x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -295,22 +317,27 @@ struct SM90_64x32x16_F16F16F16_SS fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10, %11, %12, %13, %14;\n" + " p, %11, %12, %13, %14;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -318,10 +345,9 @@ struct SM90_64x32x16_F16F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -339,22 +365,27 @@ struct SM90_64x32x16_F16F16F16_RS fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13, %14, %15, %16;\n" + " p, %14, %15, %16;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -362,10 +393,9 @@ struct SM90_64x32x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -382,25 +412,30 @@ struct SM90_64x64x16_F16F16F16_SS uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18, %19, %20, %21, %22;\n" + " p, %19, %20, %21, %22;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -408,10 +443,9 @@ struct SM90_64x64x16_F16F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -431,25 +465,30 @@ struct SM90_64x64x16_F16F16F16_RS uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21, %22, %23, %24;\n" + " p, %22, %23, %24;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -457,10 +496,9 @@ struct SM90_64x64x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -479,17 +517,22 @@ struct SM90_64x96x16_F16F16F16_SS uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23) + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23}," " %24," " %25," - " %26, %27, %28, %29, %30;\n" + " p, %27, %28, %29, %30;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -498,9 +541,9 @@ struct SM90_64x96x16_F16F16F16_SS "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -508,10 +551,9 @@ struct SM90_64x96x16_F16F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -533,17 +575,22 @@ struct SM90_64x96x16_F16F16F16_RS uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23) + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23}," "{%24, %25, %26, %27}," " %28," - " %29, %30, %31, %32;\n" + " p, %30, %31, %32;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -552,9 +599,9 @@ struct SM90_64x96x16_F16F16F16_RS "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -562,10 +609,9 @@ struct SM90_64x96x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -586,10 +632,14 @@ struct SM90_64x128x16_F16F16F16_SS uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -597,7 +647,8 @@ struct SM90_64x128x16_F16F16F16_SS " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34, %35, %36, %37, %38;\n" + " p, %35, %36, %37, %38;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -608,9 +659,9 @@ struct SM90_64x128x16_F16F16F16_SS "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -618,10 +669,9 @@ struct SM90_64x128x16_F16F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -645,10 +695,14 @@ struct SM90_64x128x16_F16F16F16_RS uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -656,7 +710,8 @@ struct SM90_64x128x16_F16F16F16_RS " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37, %38, %39, %40;\n" + " p, %38, %39, %40;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -667,9 +722,9 @@ struct SM90_64x128x16_F16F16F16_RS "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -677,10 +732,9 @@ struct SM90_64x128x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -705,10 +759,14 @@ struct SM90_64x192x16_F16F16F16_SS uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -718,7 +776,8 @@ struct SM90_64x192x16_F16F16F16_SS " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50, %51, %52, %53, %54;\n" + " p, %51, %52, %53, %54;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -733,9 +792,9 @@ struct SM90_64x192x16_F16F16F16_SS "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -743,10 +802,9 @@ struct SM90_64x192x16_F16F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -774,10 +832,14 @@ struct SM90_64x192x16_F16F16F16_RS uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -787,7 +849,8 @@ struct SM90_64x192x16_F16F16F16_RS " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53, %54, %55, %56;\n" + " p, %54, %55, %56;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -802,9 +865,9 @@ struct SM90_64x192x16_F16F16F16_RS "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -812,10 +875,9 @@ struct SM90_64x192x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -844,10 +906,14 @@ struct SM90_64x256x16_F16F16F16_SS uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -859,7 +925,8 @@ struct SM90_64x256x16_F16F16F16_SS " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66, %67, %68, %69, %70;\n" + " p, %67, %68, %69, %70;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -878,9 +945,9 @@ struct SM90_64x256x16_F16F16F16_SS "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -888,10 +955,9 @@ struct SM90_64x256x16_F16F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x16 F16+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -923,10 +989,14 @@ struct SM90_64x256x16_F16F16F16_RS uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -938,7 +1008,8 @@ struct SM90_64x256x16_F16F16F16_RS " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69, %70, %71, %72;\n" + " p, %70, %71, %72;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -957,9 +1028,9 @@ struct SM90_64x256x16_F16F16F16_RS "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -967,10 +1038,9 @@ struct SM90_64x256x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -984,21 +1054,26 @@ struct SM90_64x8x16_F32F16F16_SS CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3) + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " "{%0, %1, %2, %3}," " %4," " %5," - " %6, %7, %8, %9, %10;\n" + " p, %7, %8, %9, %10;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1006,10 +1081,9 @@ struct SM90_64x8x16_F32F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1026,21 +1100,26 @@ struct SM90_64x8x16_F32F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3) + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9, %10, %11, %12;\n" + " p, %10, %11, %12;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1048,10 +1127,9 @@ struct SM90_64x8x16_F32F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1066,22 +1144,27 @@ struct SM90_64x16x16_F32F16F16_SS fma(uint64_t const& desc_a, uint64_t const& desc_b, float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7) + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10, %11, %12, %13, %14;\n" + " p, %11, %12, %13, %14;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1089,10 +1172,9 @@ struct SM90_64x16x16_F32F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1110,22 +1192,27 @@ struct SM90_64x16x16_F32F16F16_RS fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7) + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13, %14, %15, %16;\n" + " p, %14, %15, %16;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1133,10 +1220,9 @@ struct SM90_64x16x16_F32F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1153,25 +1239,30 @@ struct SM90_64x32x16_F32F16F16_SS float & d00, float & d01, float & d02, float & d03, float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15) + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18, %19, %20, %21, %22;\n" + " p, %19, %20, %21, %22;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1179,10 +1270,9 @@ struct SM90_64x32x16_F32F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1202,25 +1292,30 @@ struct SM90_64x32x16_F32F16F16_RS float & d00, float & d01, float & d02, float & d03, float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15) + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21, %22, %23, %24;\n" + " p, %22, %23, %24;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1228,10 +1323,9 @@ struct SM90_64x32x16_F32F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1252,10 +1346,14 @@ struct SM90_64x64x16_F32F16F16_SS float & d16, float & d17, float & d18, float & d19, float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31) + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1263,7 +1361,8 @@ struct SM90_64x64x16_F32F16F16_SS " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34, %35, %36, %37, %38;\n" + " p, %35, %36, %37, %38;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -1274,9 +1373,9 @@ struct SM90_64x64x16_F32F16F16_SS "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1284,10 +1383,9 @@ struct SM90_64x64x16_F32F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1311,10 +1409,14 @@ struct SM90_64x64x16_F32F16F16_RS float & d16, float & d17, float & d18, float & d19, float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31) + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1322,7 +1424,8 @@ struct SM90_64x64x16_F32F16F16_RS " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37, %38, %39, %40;\n" + " p, %38, %39, %40;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -1333,9 +1436,9 @@ struct SM90_64x64x16_F32F16F16_RS "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1343,10 +1446,9 @@ struct SM90_64x64x16_F32F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1371,10 +1473,14 @@ struct SM90_64x96x16_F32F16F16_SS float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47) + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1384,7 +1490,8 @@ struct SM90_64x96x16_F32F16F16_SS " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50, %51, %52, %53, %54;\n" + " p, %51, %52, %53, %54;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -1399,9 +1506,9 @@ struct SM90_64x96x16_F32F16F16_SS "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1409,10 +1516,9 @@ struct SM90_64x96x16_F32F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1440,10 +1546,14 @@ struct SM90_64x96x16_F32F16F16_RS float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47) + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1453,7 +1563,8 @@ struct SM90_64x96x16_F32F16F16_RS " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53, %54, %55, %56;\n" + " p, %54, %55, %56;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -1468,9 +1579,9 @@ struct SM90_64x96x16_F32F16F16_RS "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1478,10 +1589,9 @@ struct SM90_64x96x16_F32F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1510,10 +1620,14 @@ struct SM90_64x128x16_F32F16F16_SS float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63) + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1525,7 +1639,8 @@ struct SM90_64x128x16_F32F16F16_SS " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66, %67, %68, %69, %70;\n" + " p, %67, %68, %69, %70;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -1544,9 +1659,9 @@ struct SM90_64x128x16_F32F16F16_SS "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1554,10 +1669,9 @@ struct SM90_64x128x16_F32F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1589,10 +1703,14 @@ struct SM90_64x128x16_F32F16F16_RS float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63) + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1604,7 +1722,8 @@ struct SM90_64x128x16_F32F16F16_RS " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69, %70, %71, %72;\n" + " p, %70, %71, %72;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -1623,9 +1742,9 @@ struct SM90_64x128x16_F32F16F16_RS "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1633,10 +1752,9 @@ struct SM90_64x128x16_F32F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1673,10 +1791,14 @@ struct SM90_64x192x16_F32F16F16_SS float & d80, float & d81, float & d82, float & d83, float & d84, float & d85, float & d86, float & d87, float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95) + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1692,7 +1814,8 @@ struct SM90_64x192x16_F32F16F16_SS " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98, %99, %100, %101, %102;\n" + " p, %99, %100, %101, %102;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -1719,9 +1842,9 @@ struct SM90_64x192x16_F32F16F16_SS "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1729,10 +1852,9 @@ struct SM90_64x192x16_F32F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1772,10 +1894,14 @@ struct SM90_64x192x16_F32F16F16_RS float & d80, float & d81, float & d82, float & d83, float & d84, float & d85, float & d86, float & d87, float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95) + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1791,7 +1917,8 @@ struct SM90_64x192x16_F32F16F16_RS " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101, %102, %103, %104;\n" + " p, %102, %103, %104;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -1818,9 +1945,9 @@ struct SM90_64x192x16_F32F16F16_RS "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1828,10 +1955,9 @@ struct SM90_64x192x16_F32F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1876,10 +2002,14 @@ struct SM90_64x256x16_F32F16F16_SS float & d112, float & d113, float & d114, float & d115, float & d116, float & d117, float & d118, float & d119, float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127) + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -1899,7 +2029,8 @@ struct SM90_64x256x16_F32F16F16_SS " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130, %131, %132, %133, %134;\n" + " p, %131, %132, %133, %134;\n" + "}\n" : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), @@ -1934,9 +2065,9 @@ struct SM90_64x256x16_F32F16F16_SS "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -1944,10 +2075,9 @@ struct SM90_64x256x16_F32F16F16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x16 F32+=F16*F16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -1995,10 +2125,14 @@ struct SM90_64x256x16_F32F16F16_RS float & d112, float & d113, float & d114, float & d115, float & d116, float & d117, float & d118, float & d119, float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127) + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2018,7 +2152,8 @@ struct SM90_64x256x16_F32F16F16_RS " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133, %134, %135, %136;\n" + " p, %134, %135, %136;\n" + "}\n" : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), @@ -2053,9 +2188,9 @@ struct SM90_64x256x16_F32F16F16_RS "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2063,10 +2198,9 @@ struct SM90_64x256x16_F32F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2080,21 +2214,26 @@ struct SM90_64x8x16_F32BF16BF16_SS CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3) + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " "{%0, %1, %2, %3}," " %4," " %5," - " %6, %7, %8, %9, %10;\n" + " p, %7, %8, %9, %10;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2102,10 +2241,9 @@ struct SM90_64x8x16_F32BF16BF16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2122,21 +2260,26 @@ struct SM90_64x8x16_F32BF16BF16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3) + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9, %10, %11, %12;\n" + " p, %10, %11, %12;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2144,10 +2287,9 @@ struct SM90_64x8x16_F32BF16BF16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2162,22 +2304,27 @@ struct SM90_64x16x16_F32BF16BF16_SS fma(uint64_t const& desc_a, uint64_t const& desc_b, float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7) + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10, %11, %12, %13, %14;\n" + " p, %11, %12, %13, %14;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2185,10 +2332,9 @@ struct SM90_64x16x16_F32BF16BF16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2206,22 +2352,27 @@ struct SM90_64x16x16_F32BF16BF16_RS fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7) + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13, %14, %15, %16;\n" + " p, %14, %15, %16;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2229,10 +2380,9 @@ struct SM90_64x16x16_F32BF16BF16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2249,25 +2399,30 @@ struct SM90_64x32x16_F32BF16BF16_SS float & d00, float & d01, float & d02, float & d03, float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15) + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18, %19, %20, %21, %22;\n" + " p, %19, %20, %21, %22;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2275,10 +2430,9 @@ struct SM90_64x32x16_F32BF16BF16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2298,25 +2452,30 @@ struct SM90_64x32x16_F32BF16BF16_RS float & d00, float & d01, float & d02, float & d03, float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15) + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21, %22, %23, %24;\n" + " p, %22, %23, %24;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2324,10 +2483,9 @@ struct SM90_64x32x16_F32BF16BF16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2348,10 +2506,14 @@ struct SM90_64x64x16_F32BF16BF16_SS float & d16, float & d17, float & d18, float & d19, float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31) + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2359,7 +2521,8 @@ struct SM90_64x64x16_F32BF16BF16_SS " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34, %35, %36, %37, %38;\n" + " p, %35, %36, %37, %38;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -2370,9 +2533,9 @@ struct SM90_64x64x16_F32BF16BF16_SS "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2380,10 +2543,9 @@ struct SM90_64x64x16_F32BF16BF16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2407,10 +2569,14 @@ struct SM90_64x64x16_F32BF16BF16_RS float & d16, float & d17, float & d18, float & d19, float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31) + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2418,7 +2584,8 @@ struct SM90_64x64x16_F32BF16BF16_RS " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37, %38, %39, %40;\n" + " p, %38, %39, %40;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -2429,9 +2596,9 @@ struct SM90_64x64x16_F32BF16BF16_RS "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2439,10 +2606,9 @@ struct SM90_64x64x16_F32BF16BF16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2467,10 +2633,14 @@ struct SM90_64x96x16_F32BF16BF16_SS float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47) + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2480,7 +2650,8 @@ struct SM90_64x96x16_F32BF16BF16_SS " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50, %51, %52, %53, %54;\n" + " p, %51, %52, %53, %54;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -2495,9 +2666,9 @@ struct SM90_64x96x16_F32BF16BF16_SS "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2505,10 +2676,9 @@ struct SM90_64x96x16_F32BF16BF16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2536,10 +2706,14 @@ struct SM90_64x96x16_F32BF16BF16_RS float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47) + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2549,7 +2723,8 @@ struct SM90_64x96x16_F32BF16BF16_RS " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53, %54, %55, %56;\n" + " p, %54, %55, %56;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -2564,9 +2739,9 @@ struct SM90_64x96x16_F32BF16BF16_RS "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2574,10 +2749,9 @@ struct SM90_64x96x16_F32BF16BF16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2606,10 +2780,14 @@ struct SM90_64x128x16_F32BF16BF16_SS float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63) + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2621,7 +2799,8 @@ struct SM90_64x128x16_F32BF16BF16_SS " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66, %67, %68, %69, %70;\n" + " p, %67, %68, %69, %70;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -2640,9 +2819,9 @@ struct SM90_64x128x16_F32BF16BF16_SS "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2650,10 +2829,9 @@ struct SM90_64x128x16_F32BF16BF16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2685,10 +2863,14 @@ struct SM90_64x128x16_F32BF16BF16_RS float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63) + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2700,7 +2882,8 @@ struct SM90_64x128x16_F32BF16BF16_RS " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69, %70, %71, %72;\n" + " p, %70, %71, %72;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -2719,9 +2902,9 @@ struct SM90_64x128x16_F32BF16BF16_RS "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2729,10 +2912,9 @@ struct SM90_64x128x16_F32BF16BF16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2769,10 +2951,14 @@ struct SM90_64x192x16_F32BF16BF16_SS float & d80, float & d81, float & d82, float & d83, float & d84, float & d85, float & d86, float & d87, float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95) + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2788,7 +2974,8 @@ struct SM90_64x192x16_F32BF16BF16_SS " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98, %99, %100, %101, %102;\n" + " p, %99, %100, %101, %102;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -2815,9 +3002,9 @@ struct SM90_64x192x16_F32BF16BF16_SS "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2825,10 +3012,9 @@ struct SM90_64x192x16_F32BF16BF16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2868,10 +3054,14 @@ struct SM90_64x192x16_F32BF16BF16_RS float & d80, float & d81, float & d82, float & d83, float & d84, float & d85, float & d86, float & d87, float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95) + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2887,7 +3077,8 @@ struct SM90_64x192x16_F32BF16BF16_RS " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101, %102, %103, %104;\n" + " p, %102, %103, %104;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -2914,9 +3105,9 @@ struct SM90_64x192x16_F32BF16BF16_RS "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -2924,10 +3115,9 @@ struct SM90_64x192x16_F32BF16BF16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -2972,10 +3162,14 @@ struct SM90_64x256x16_F32BF16BF16_SS float & d112, float & d113, float & d114, float & d115, float & d116, float & d117, float & d118, float & d119, float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127) + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -2995,7 +3189,8 @@ struct SM90_64x256x16_F32BF16BF16_SS " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130, %131, %132, %133, %134;\n" + " p, %131, %132, %133, %134;\n" + "}\n" : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), @@ -3030,9 +3225,9 @@ struct SM90_64x256x16_F32BF16BF16_SS "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3040,10 +3235,9 @@ struct SM90_64x256x16_F32BF16BF16_SS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x16 F32+=BF16*BF16 -template< +template < GMMA::Major tnspA, GMMA::Major tnspB, - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3091,10 +3285,14 @@ struct SM90_64x256x16_F32BF16BF16_RS float & d112, float & d113, float & d114, float & d115, float & d116, float & d117, float & d118, float & d119, float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127) + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3114,7 +3312,8 @@ struct SM90_64x256x16_F32BF16BF16_RS " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133, %134, %135, %136;\n" + " p, %134, %135, %136;\n" + "}\n" : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), @@ -3149,9 +3348,9 @@ struct SM90_64x256x16_F32BF16BF16_RS "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3159,8 +3358,7 @@ struct SM90_64x256x16_F32BF16BF16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3174,21 +3372,26 @@ struct SM90_64x8x8_F32TF32TF32_SS_TN CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3) + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " "{%0, %1, %2, %3}," " %4," " %5," - " %6, %7, %8;\n" + " p, %7, %8;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3196,8 +3399,7 @@ struct SM90_64x8x8_F32TF32TF32_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x8x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3211,21 +3413,26 @@ struct SM90_64x8x8_F32TF32TF32_RS_TN CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3) + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9, %10, %11;\n" + " p, %10, %11;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3233,8 +3440,7 @@ struct SM90_64x8x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3249,22 +3455,27 @@ struct SM90_64x16x8_F32TF32TF32_SS_TN fma(uint64_t const& desc_a, uint64_t const& desc_b, float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7) + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10, %11, %12;\n" + " p, %11, %12;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3272,8 +3483,7 @@ struct SM90_64x16x8_F32TF32TF32_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x16x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3288,22 +3498,27 @@ struct SM90_64x16x8_F32TF32TF32_RS_TN fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7) + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13, %14, %15;\n" + " p, %14, %15;\n" + "}\n" : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3311,8 +3526,7 @@ struct SM90_64x16x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3329,25 +3543,30 @@ struct SM90_64x32x8_F32TF32TF32_SS_TN float & d00, float & d01, float & d02, float & d03, float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15) + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18, %19, %20;\n" + " p, %19, %20;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3355,8 +3574,7 @@ struct SM90_64x32x8_F32TF32TF32_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x32x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3373,25 +3591,30 @@ struct SM90_64x32x8_F32TF32TF32_RS_TN float & d00, float & d01, float & d02, float & d03, float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15) + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21, %22, %23;\n" + " p, %22, %23;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3399,8 +3622,7 @@ struct SM90_64x32x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3421,10 +3643,14 @@ struct SM90_64x64x8_F32TF32TF32_SS_TN float & d16, float & d17, float & d18, float & d19, float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31) + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3432,7 +3658,8 @@ struct SM90_64x64x8_F32TF32TF32_SS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34, %35, %36;\n" + " p, %35, %36;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -3443,9 +3670,9 @@ struct SM90_64x64x8_F32TF32TF32_SS_TN "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3453,8 +3680,7 @@ struct SM90_64x64x8_F32TF32TF32_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x64x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3475,10 +3701,14 @@ struct SM90_64x64x8_F32TF32TF32_RS_TN float & d16, float & d17, float & d18, float & d19, float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31) + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3486,7 +3716,8 @@ struct SM90_64x64x8_F32TF32TF32_RS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37, %38, %39;\n" + " p, %38, %39;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -3497,9 +3728,9 @@ struct SM90_64x64x8_F32TF32TF32_RS_TN "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3507,8 +3738,7 @@ struct SM90_64x64x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3533,10 +3763,14 @@ struct SM90_64x96x8_F32TF32TF32_SS_TN float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47) + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3546,7 +3780,8 @@ struct SM90_64x96x8_F32TF32TF32_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50, %51, %52;\n" + " p, %51, %52;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -3561,9 +3796,9 @@ struct SM90_64x96x8_F32TF32TF32_SS_TN "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3571,8 +3806,7 @@ struct SM90_64x96x8_F32TF32TF32_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x96x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3597,10 +3831,14 @@ struct SM90_64x96x8_F32TF32TF32_RS_TN float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47) + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3610,7 +3848,8 @@ struct SM90_64x96x8_F32TF32TF32_RS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53, %54, %55;\n" + " p, %54, %55;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -3625,9 +3864,9 @@ struct SM90_64x96x8_F32TF32TF32_RS_TN "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3635,8 +3874,7 @@ struct SM90_64x96x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3665,10 +3903,14 @@ struct SM90_64x128x8_F32TF32TF32_SS_TN float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63) + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3680,7 +3922,8 @@ struct SM90_64x128x8_F32TF32TF32_SS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66, %67, %68;\n" + " p, %67, %68;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -3699,9 +3942,9 @@ struct SM90_64x128x8_F32TF32TF32_SS_TN "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3709,8 +3952,7 @@ struct SM90_64x128x8_F32TF32TF32_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x128x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3739,10 +3981,14 @@ struct SM90_64x128x8_F32TF32TF32_RS_TN float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63) + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3754,7 +4000,8 @@ struct SM90_64x128x8_F32TF32TF32_RS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69, %70, %71;\n" + " p, %70, %71;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -3773,9 +4020,9 @@ struct SM90_64x128x8_F32TF32TF32_RS_TN "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3783,8 +4030,7 @@ struct SM90_64x128x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3821,10 +4067,14 @@ struct SM90_64x192x8_F32TF32TF32_SS_TN float & d80, float & d81, float & d82, float & d83, float & d84, float & d85, float & d86, float & d87, float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95) + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3840,7 +4090,8 @@ struct SM90_64x192x8_F32TF32TF32_SS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98, %99, %100;\n" + " p, %99, %100;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -3867,9 +4118,9 @@ struct SM90_64x192x8_F32TF32TF32_SS_TN "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3877,8 +4128,7 @@ struct SM90_64x192x8_F32TF32TF32_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x192x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -3915,10 +4165,14 @@ struct SM90_64x192x8_F32TF32TF32_RS_TN float & d80, float & d81, float & d82, float & d83, float & d84, float & d85, float & d86, float & d87, float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95) + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -3934,7 +4188,8 @@ struct SM90_64x192x8_F32TF32TF32_RS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101, %102, %103;\n" + " p, %102, %103;\n" + "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), @@ -3961,9 +4216,9 @@ struct SM90_64x192x8_F32TF32TF32_RS_TN "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -3971,8 +4226,7 @@ struct SM90_64x192x8_F32TF32TF32_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -4017,10 +4271,14 @@ struct SM90_64x256x8_F32TF32TF32_SS_TN float & d112, float & d113, float & d114, float & d115, float & d116, float & d117, float & d118, float & d119, float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127) + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4040,7 +4298,8 @@ struct SM90_64x256x8_F32TF32TF32_SS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130, %131, %132;\n" + " p, %131, %132;\n" + "}\n" : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), @@ -4075,9 +4334,9 @@ struct SM90_64x256x8_F32TF32TF32_SS_TN "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; @@ -4085,8 +4344,7 @@ struct SM90_64x256x8_F32TF32TF32_SS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// // GMMA 64x256x8 TN F32+=TF32*TF32 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One, +template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > @@ -4131,10 +4389,14 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN float & d112, float & d113, float & d114, float & d115, float & d116, float & d117, float & d118, float & d119, float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127) + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4154,7 +4416,8 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133, %134, %135;\n" + " p, %134, %135;\n" + "}\n" : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), @@ -4189,19 +4452,16 @@ struct SM90_64x256x8_F32TF32TF32_RS_TN "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=S8*S8 struct SM90_64x8x32_S32S8S8_SS_TN { using DRegisters = void; @@ -4212,31 +4472,33 @@ struct SM90_64x8x32_S32S8S8_SS_TN CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " "{%0, %1, %2, %3}," " %4," " %5," - " %6;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=S8*S8 struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; @@ -4247,31 +4509,33 @@ struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3}," " %4," " %5," - " %6;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=S8*S8 struct SM90_64x16x32_S32S8S8_SS_TN { using DRegisters = void; @@ -4283,32 +4547,34 @@ struct SM90_64x16x32_S32S8S8_SS_TN fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=S8*S8 struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; @@ -4320,32 +4586,34 @@ struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=S8*S8 struct SM90_64x32x32_S32S8S8_SS_TN { using DRegisters = void; @@ -4359,35 +4627,37 @@ struct SM90_64x32x32_S32S8S8_SS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=S8*S8 struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; @@ -4401,35 +4671,37 @@ struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=S8*S8 struct SM90_64x64x32_S32S8S8_SS_TN { using DRegisters = void; @@ -4447,10 +4719,14 @@ struct SM90_64x64x32_S32S8S8_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4458,7 +4734,8 @@ struct SM90_64x64x32_S32S8S8_SS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -4469,19 +4746,16 @@ struct SM90_64x64x32_S32S8S8_SS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=S8*S8 struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; @@ -4499,10 +4773,14 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4510,7 +4788,8 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -4521,19 +4800,16 @@ struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=S8*S8 struct SM90_64x96x32_S32S8S8_SS_TN { using DRegisters = void; @@ -4555,10 +4831,14 @@ struct SM90_64x96x32_S32S8S8_SS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4568,7 +4848,8 @@ struct SM90_64x96x32_S32S8S8_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -4583,19 +4864,16 @@ struct SM90_64x96x32_S32S8S8_SS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=S8*S8 struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; @@ -4617,10 +4895,14 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4630,7 +4912,8 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -4645,19 +4928,16 @@ struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=S8*S8 struct SM90_64x128x32_S32S8S8_SS_TN { using DRegisters = void; @@ -4683,10 +4963,14 @@ struct SM90_64x128x32_S32S8S8_SS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4698,7 +4982,8 @@ struct SM90_64x128x32_S32S8S8_SS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -4717,19 +5002,16 @@ struct SM90_64x128x32_S32S8S8_SS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=S8*S8 struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; @@ -4755,10 +5037,14 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4770,7 +5056,8 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -4789,19 +5076,16 @@ struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=S8*S8 struct SM90_64x192x32_S32S8S8_SS_TN { using DRegisters = void; @@ -4835,10 +5119,14 @@ struct SM90_64x192x32_S32S8S8_SS_TN uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4854,7 +5142,8 @@ struct SM90_64x192x32_S32S8S8_SS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -4881,19 +5170,16 @@ struct SM90_64x192x32_S32S8S8_SS_TN "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=S8*S8 struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; @@ -4927,10 +5213,14 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -4946,7 +5236,8 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -4973,19 +5264,16 @@ struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=S8*S8 struct SM90_64x256x32_S32S8S8_SS_TN { using DRegisters = void; @@ -5027,10 +5315,14 @@ struct SM90_64x256x32_S32S8S8_SS_TN uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5050,7 +5342,8 @@ struct SM90_64x256x32_S32S8S8_SS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -5085,19 +5378,16 @@ struct SM90_64x256x32_S32S8S8_SS_TN "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=S8*S8 struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; @@ -5139,10 +5429,14 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5162,7 +5456,8 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -5197,19 +5492,16 @@ struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=S8*S8 struct SM90_64x8x32_S32S8S8_RS_TN { using DRegisters = void; @@ -5220,31 +5512,33 @@ struct SM90_64x8x32_S32S8S8_RS_TN CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=S8*S8 struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; @@ -5255,31 +5549,33 @@ struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=S8*S8 struct SM90_64x16x32_S32S8S8_RS_TN { using DRegisters = void; @@ -5291,32 +5587,34 @@ struct SM90_64x16x32_S32S8S8_RS_TN fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=S8*S8 struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; @@ -5328,32 +5626,34 @@ struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=S8*S8 struct SM90_64x32x32_S32S8S8_RS_TN { using DRegisters = void; @@ -5367,35 +5667,37 @@ struct SM90_64x32x32_S32S8S8_RS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=S8*S8 struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; @@ -5409,35 +5711,37 @@ struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=S8*S8 struct SM90_64x64x32_S32S8S8_RS_TN { using DRegisters = void; @@ -5455,10 +5759,14 @@ struct SM90_64x64x32_S32S8S8_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5466,7 +5774,8 @@ struct SM90_64x64x32_S32S8S8_RS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -5477,19 +5786,16 @@ struct SM90_64x64x32_S32S8S8_RS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=S8*S8 struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; @@ -5507,10 +5813,14 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5518,7 +5828,8 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -5529,19 +5840,16 @@ struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=S8*S8 struct SM90_64x96x32_S32S8S8_RS_TN { using DRegisters = void; @@ -5563,10 +5871,14 @@ struct SM90_64x96x32_S32S8S8_RS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5576,7 +5888,8 @@ struct SM90_64x96x32_S32S8S8_RS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -5591,19 +5904,16 @@ struct SM90_64x96x32_S32S8S8_RS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=S8*S8 struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; @@ -5625,10 +5935,14 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5638,7 +5952,8 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -5653,19 +5968,16 @@ struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=S8*S8 struct SM90_64x128x32_S32S8S8_RS_TN { using DRegisters = void; @@ -5691,10 +6003,14 @@ struct SM90_64x128x32_S32S8S8_RS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5706,7 +6022,8 @@ struct SM90_64x128x32_S32S8S8_RS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -5725,19 +6042,16 @@ struct SM90_64x128x32_S32S8S8_RS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=S8*S8 struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; @@ -5763,10 +6077,14 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5778,7 +6096,8 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -5797,19 +6116,16 @@ struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=S8*S8 struct SM90_64x192x32_S32S8S8_RS_TN { using DRegisters = void; @@ -5843,10 +6159,14 @@ struct SM90_64x192x32_S32S8S8_RS_TN uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5862,7 +6182,8 @@ struct SM90_64x192x32_S32S8S8_RS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -5889,19 +6210,16 @@ struct SM90_64x192x32_S32S8S8_RS_TN "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=S8*S8 struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; @@ -5935,10 +6253,14 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -5954,7 +6276,8 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -5981,19 +6304,16 @@ struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=S8*S8 struct SM90_64x256x32_S32S8S8_RS_TN { using DRegisters = void; @@ -6035,10 +6355,14 @@ struct SM90_64x256x32_S32S8S8_RS_TN uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6058,7 +6382,8 @@ struct SM90_64x256x32_S32S8S8_RS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -6093,19 +6418,16 @@ struct SM90_64x256x32_S32S8S8_RS_TN "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=S8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=S8*S8 struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; @@ -6147,10 +6469,14 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6170,7 +6496,8 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -6205,19 +6532,16 @@ struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=S8*U8 struct SM90_64x8x32_S32S8U8_SS_TN { using DRegisters = void; @@ -6228,31 +6552,33 @@ struct SM90_64x8x32_S32S8U8_SS_TN CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " "{%0, %1, %2, %3}," " %4," " %5," - " %6;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=S8*U8 struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; @@ -6263,31 +6589,33 @@ struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3}," " %4," " %5," - " %6;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=S8*U8 struct SM90_64x16x32_S32S8U8_SS_TN { using DRegisters = void; @@ -6299,32 +6627,34 @@ struct SM90_64x16x32_S32S8U8_SS_TN fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=S8*U8 struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; @@ -6336,32 +6666,34 @@ struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=S8*U8 struct SM90_64x32x32_S32S8U8_SS_TN { using DRegisters = void; @@ -6375,35 +6707,37 @@ struct SM90_64x32x32_S32S8U8_SS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=S8*U8 struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; @@ -6417,35 +6751,37 @@ struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=S8*U8 struct SM90_64x64x32_S32S8U8_SS_TN { using DRegisters = void; @@ -6463,10 +6799,14 @@ struct SM90_64x64x32_S32S8U8_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6474,7 +6814,8 @@ struct SM90_64x64x32_S32S8U8_SS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -6485,19 +6826,16 @@ struct SM90_64x64x32_S32S8U8_SS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=S8*U8 struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; @@ -6515,10 +6853,14 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6526,7 +6868,8 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -6537,19 +6880,16 @@ struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=S8*U8 struct SM90_64x96x32_S32S8U8_SS_TN { using DRegisters = void; @@ -6571,10 +6911,14 @@ struct SM90_64x96x32_S32S8U8_SS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6584,7 +6928,8 @@ struct SM90_64x96x32_S32S8U8_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -6599,19 +6944,16 @@ struct SM90_64x96x32_S32S8U8_SS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=S8*U8 struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; @@ -6633,10 +6975,14 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6646,7 +6992,8 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -6661,19 +7008,16 @@ struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=S8*U8 struct SM90_64x128x32_S32S8U8_SS_TN { using DRegisters = void; @@ -6699,10 +7043,14 @@ struct SM90_64x128x32_S32S8U8_SS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6714,7 +7062,8 @@ struct SM90_64x128x32_S32S8U8_SS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -6733,19 +7082,16 @@ struct SM90_64x128x32_S32S8U8_SS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=S8*U8 struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; @@ -6771,10 +7117,14 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6786,7 +7136,8 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -6805,19 +7156,16 @@ struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=S8*U8 struct SM90_64x192x32_S32S8U8_SS_TN { using DRegisters = void; @@ -6851,10 +7199,14 @@ struct SM90_64x192x32_S32S8U8_SS_TN uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6870,7 +7222,8 @@ struct SM90_64x192x32_S32S8U8_SS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -6897,19 +7250,16 @@ struct SM90_64x192x32_S32S8U8_SS_TN "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=S8*U8 struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; @@ -6943,10 +7293,14 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -6962,7 +7316,8 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -6989,19 +7344,16 @@ struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=S8*U8 struct SM90_64x256x32_S32S8U8_SS_TN { using DRegisters = void; @@ -7043,10 +7395,14 @@ struct SM90_64x256x32_S32S8U8_SS_TN uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7066,7 +7422,8 @@ struct SM90_64x256x32_S32S8U8_SS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -7101,19 +7458,16 @@ struct SM90_64x256x32_S32S8U8_SS_TN "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=S8*U8 struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; @@ -7155,10 +7509,14 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7178,7 +7536,8 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -7213,19 +7572,16 @@ struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=S8*U8 struct SM90_64x8x32_S32S8U8_RS_TN { using DRegisters = void; @@ -7236,31 +7592,33 @@ struct SM90_64x8x32_S32S8U8_RS_TN CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=S8*U8 struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; @@ -7271,31 +7629,33 @@ struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=S8*U8 struct SM90_64x16x32_S32S8U8_RS_TN { using DRegisters = void; @@ -7307,32 +7667,34 @@ struct SM90_64x16x32_S32S8U8_RS_TN fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=S8*U8 struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; @@ -7344,32 +7706,34 @@ struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=S8*U8 struct SM90_64x32x32_S32S8U8_RS_TN { using DRegisters = void; @@ -7383,35 +7747,37 @@ struct SM90_64x32x32_S32S8U8_RS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=S8*U8 struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; @@ -7425,35 +7791,37 @@ struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=S8*U8 struct SM90_64x64x32_S32S8U8_RS_TN { using DRegisters = void; @@ -7471,10 +7839,14 @@ struct SM90_64x64x32_S32S8U8_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7482,7 +7854,8 @@ struct SM90_64x64x32_S32S8U8_RS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -7493,19 +7866,16 @@ struct SM90_64x64x32_S32S8U8_RS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=S8*U8 struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; @@ -7523,10 +7893,14 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7534,7 +7908,8 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -7545,19 +7920,16 @@ struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=S8*U8 struct SM90_64x96x32_S32S8U8_RS_TN { using DRegisters = void; @@ -7579,10 +7951,14 @@ struct SM90_64x96x32_S32S8U8_RS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7592,7 +7968,8 @@ struct SM90_64x96x32_S32S8U8_RS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -7607,19 +7984,16 @@ struct SM90_64x96x32_S32S8U8_RS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=S8*U8 struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; @@ -7641,10 +8015,14 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7654,7 +8032,8 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -7669,19 +8048,16 @@ struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=S8*U8 struct SM90_64x128x32_S32S8U8_RS_TN { using DRegisters = void; @@ -7707,10 +8083,14 @@ struct SM90_64x128x32_S32S8U8_RS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7722,7 +8102,8 @@ struct SM90_64x128x32_S32S8U8_RS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -7741,19 +8122,16 @@ struct SM90_64x128x32_S32S8U8_RS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=S8*U8 struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; @@ -7779,10 +8157,14 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7794,7 +8176,8 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -7813,19 +8196,16 @@ struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=S8*U8 struct SM90_64x192x32_S32S8U8_RS_TN { using DRegisters = void; @@ -7859,10 +8239,14 @@ struct SM90_64x192x32_S32S8U8_RS_TN uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7878,7 +8262,8 @@ struct SM90_64x192x32_S32S8U8_RS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -7905,19 +8290,16 @@ struct SM90_64x192x32_S32S8U8_RS_TN "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=S8*U8 struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; @@ -7951,10 +8333,14 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -7970,7 +8356,8 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -7997,19 +8384,16 @@ struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=S8*U8 struct SM90_64x256x32_S32S8U8_RS_TN { using DRegisters = void; @@ -8051,10 +8435,14 @@ struct SM90_64x256x32_S32S8U8_RS_TN uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8074,7 +8462,8 @@ struct SM90_64x256x32_S32S8U8_RS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -8109,19 +8498,16 @@ struct SM90_64x256x32_S32S8U8_RS_TN "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=S8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=S8*U8 struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; @@ -8163,10 +8549,14 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8186,7 +8576,8 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -8221,19 +8612,16 @@ struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=U8*S8 struct SM90_64x8x32_S32U8S8_SS_TN { using DRegisters = void; @@ -8244,31 +8632,33 @@ struct SM90_64x8x32_S32U8S8_SS_TN CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " "{%0, %1, %2, %3}," " %4," " %5," - " %6;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=U8*S8 struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; @@ -8279,31 +8669,33 @@ struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3}," " %4," " %5," - " %6;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=U8*S8 struct SM90_64x16x32_S32U8S8_SS_TN { using DRegisters = void; @@ -8315,32 +8707,34 @@ struct SM90_64x16x32_S32U8S8_SS_TN fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=U8*S8 struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; @@ -8352,32 +8746,34 @@ struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=U8*S8 struct SM90_64x32x32_S32U8S8_SS_TN { using DRegisters = void; @@ -8391,35 +8787,37 @@ struct SM90_64x32x32_S32U8S8_SS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=U8*S8 struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; @@ -8433,35 +8831,37 @@ struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=U8*S8 struct SM90_64x64x32_S32U8S8_SS_TN { using DRegisters = void; @@ -8479,10 +8879,14 @@ struct SM90_64x64x32_S32U8S8_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8490,7 +8894,8 @@ struct SM90_64x64x32_S32U8S8_SS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -8501,19 +8906,16 @@ struct SM90_64x64x32_S32U8S8_SS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=U8*S8 struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; @@ -8531,10 +8933,14 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8542,7 +8948,8 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -8553,19 +8960,16 @@ struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=U8*S8 struct SM90_64x96x32_S32U8S8_SS_TN { using DRegisters = void; @@ -8587,10 +8991,14 @@ struct SM90_64x96x32_S32U8S8_SS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8600,7 +9008,8 @@ struct SM90_64x96x32_S32U8S8_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -8615,19 +9024,16 @@ struct SM90_64x96x32_S32U8S8_SS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=U8*S8 struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; @@ -8649,10 +9055,14 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8662,7 +9072,8 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -8677,19 +9088,16 @@ struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=U8*S8 struct SM90_64x128x32_S32U8S8_SS_TN { using DRegisters = void; @@ -8715,10 +9123,14 @@ struct SM90_64x128x32_S32U8S8_SS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8730,7 +9142,8 @@ struct SM90_64x128x32_S32U8S8_SS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -8749,19 +9162,16 @@ struct SM90_64x128x32_S32U8S8_SS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=U8*S8 struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; @@ -8787,10 +9197,14 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8802,7 +9216,8 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -8821,19 +9236,16 @@ struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=U8*S8 struct SM90_64x192x32_S32U8S8_SS_TN { using DRegisters = void; @@ -8867,10 +9279,14 @@ struct SM90_64x192x32_S32U8S8_SS_TN uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8886,7 +9302,8 @@ struct SM90_64x192x32_S32U8S8_SS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -8913,19 +9330,16 @@ struct SM90_64x192x32_S32U8S8_SS_TN "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=U8*S8 struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; @@ -8959,10 +9373,14 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -8978,7 +9396,8 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -9005,19 +9424,16 @@ struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=U8*S8 struct SM90_64x256x32_S32U8S8_SS_TN { using DRegisters = void; @@ -9059,10 +9475,14 @@ struct SM90_64x256x32_S32U8S8_SS_TN uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9082,7 +9502,8 @@ struct SM90_64x256x32_S32U8S8_SS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -9117,19 +9538,16 @@ struct SM90_64x256x32_S32U8S8_SS_TN "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=U8*S8 struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; @@ -9171,10 +9589,14 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9194,7 +9616,8 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -9229,19 +9652,16 @@ struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=U8*S8 struct SM90_64x8x32_S32U8S8_RS_TN { using DRegisters = void; @@ -9252,31 +9672,33 @@ struct SM90_64x8x32_S32U8S8_RS_TN CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=U8*S8 struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; @@ -9287,31 +9709,33 @@ struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=U8*S8 struct SM90_64x16x32_S32U8S8_RS_TN { using DRegisters = void; @@ -9323,32 +9747,34 @@ struct SM90_64x16x32_S32U8S8_RS_TN fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=U8*S8 struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; @@ -9360,32 +9786,34 @@ struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=U8*S8 struct SM90_64x32x32_S32U8S8_RS_TN { using DRegisters = void; @@ -9399,35 +9827,37 @@ struct SM90_64x32x32_S32U8S8_RS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=U8*S8 struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; @@ -9441,35 +9871,37 @@ struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=U8*S8 struct SM90_64x64x32_S32U8S8_RS_TN { using DRegisters = void; @@ -9487,10 +9919,14 @@ struct SM90_64x64x32_S32U8S8_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9498,7 +9934,8 @@ struct SM90_64x64x32_S32U8S8_RS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -9509,19 +9946,16 @@ struct SM90_64x64x32_S32U8S8_RS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=U8*S8 struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; @@ -9539,10 +9973,14 @@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9550,7 +9988,8 @@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -9561,19 +10000,16 @@ struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=U8*S8 struct SM90_64x96x32_S32U8S8_RS_TN { using DRegisters = void; @@ -9595,10 +10031,14 @@ struct SM90_64x96x32_S32U8S8_RS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9608,7 +10048,8 @@ struct SM90_64x96x32_S32U8S8_RS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -9623,19 +10064,16 @@ struct SM90_64x96x32_S32U8S8_RS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=U8*S8 struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; @@ -9657,10 +10095,14 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9670,7 +10112,8 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -9685,19 +10128,16 @@ struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=U8*S8 struct SM90_64x128x32_S32U8S8_RS_TN { using DRegisters = void; @@ -9723,10 +10163,14 @@ struct SM90_64x128x32_S32U8S8_RS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9738,7 +10182,8 @@ struct SM90_64x128x32_S32U8S8_RS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -9757,19 +10202,16 @@ struct SM90_64x128x32_S32U8S8_RS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=U8*S8 struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; @@ -9795,10 +10237,14 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9810,7 +10256,8 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -9829,19 +10276,16 @@ struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=U8*S8 struct SM90_64x192x32_S32U8S8_RS_TN { using DRegisters = void; @@ -9875,10 +10319,14 @@ struct SM90_64x192x32_S32U8S8_RS_TN uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9894,7 +10342,8 @@ struct SM90_64x192x32_S32U8S8_RS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -9921,19 +10370,16 @@ struct SM90_64x192x32_S32U8S8_RS_TN "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=U8*S8 struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; @@ -9967,10 +10413,14 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -9986,7 +10436,8 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -10013,19 +10464,16 @@ struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=U8*S8 struct SM90_64x256x32_S32U8S8_RS_TN { using DRegisters = void; @@ -10067,10 +10515,14 @@ struct SM90_64x256x32_S32U8S8_RS_TN uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10090,7 +10542,8 @@ struct SM90_64x256x32_S32U8S8_RS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -10125,19 +10578,16 @@ struct SM90_64x256x32_S32U8S8_RS_TN "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=U8*S8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=U8*S8 struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; @@ -10179,10 +10629,14 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10202,7 +10656,8 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -10237,19 +10692,16 @@ struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=U8*U8 struct SM90_64x8x32_S32U8U8_SS_TN { using DRegisters = void; @@ -10260,31 +10712,33 @@ struct SM90_64x8x32_S32U8U8_SS_TN CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " "{%0, %1, %2, %3}," " %4," " %5," - " %6;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=U8*U8 struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; @@ -10295,31 +10749,33 @@ struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3}," " %4," " %5," - " %6;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=U8*U8 struct SM90_64x16x32_S32U8U8_SS_TN { using DRegisters = void; @@ -10331,32 +10787,34 @@ struct SM90_64x16x32_S32U8U8_SS_TN fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=U8*U8 struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; @@ -10368,32 +10826,34 @@ struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," - " %10;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=U8*U8 struct SM90_64x32x32_S32U8U8_SS_TN { using DRegisters = void; @@ -10407,35 +10867,37 @@ struct SM90_64x32x32_S32U8U8_SS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=U8*U8 struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; @@ -10449,35 +10911,37 @@ struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," - " %18;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=U8*U8 struct SM90_64x64x32_S32U8U8_SS_TN { using DRegisters = void; @@ -10495,10 +10959,14 @@ struct SM90_64x64x32_S32U8U8_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10506,7 +10974,8 @@ struct SM90_64x64x32_S32U8U8_SS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -10517,19 +10986,16 @@ struct SM90_64x64x32_S32U8U8_SS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=U8*U8 struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; @@ -10547,10 +11013,14 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10558,7 +11028,8 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE " %24, %25, %26, %27, %28, %29, %30, %31}," " %32," " %33," - " %34;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -10569,19 +11040,16 @@ struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=U8*U8 struct SM90_64x96x32_S32U8U8_SS_TN { using DRegisters = void; @@ -10603,10 +11071,14 @@ struct SM90_64x96x32_S32U8U8_SS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10616,7 +11088,8 @@ struct SM90_64x96x32_S32U8U8_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -10631,19 +11104,16 @@ struct SM90_64x96x32_S32U8U8_SS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=U8*U8 struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; @@ -10665,10 +11135,14 @@ struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10678,7 +11152,8 @@ struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " %50;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -10693,19 +11168,16 @@ struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=U8*U8 struct SM90_64x128x32_S32U8U8_SS_TN { using DRegisters = void; @@ -10731,10 +11203,14 @@ struct SM90_64x128x32_S32U8U8_SS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10746,7 +11222,8 @@ struct SM90_64x128x32_S32U8U8_SS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -10765,19 +11242,16 @@ struct SM90_64x128x32_S32U8U8_SS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=U8*U8 struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; @@ -10803,10 +11277,14 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10818,7 +11296,8 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE " %56, %57, %58, %59, %60, %61, %62, %63}," " %64," " %65," - " %66;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -10837,19 +11316,16 @@ struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=U8*U8 struct SM90_64x192x32_S32U8U8_SS_TN { using DRegisters = void; @@ -10883,10 +11359,14 @@ struct SM90_64x192x32_S32U8U8_SS_TN uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10902,7 +11382,8 @@ struct SM90_64x192x32_S32U8U8_SS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -10929,19 +11410,16 @@ struct SM90_64x192x32_S32U8U8_SS_TN "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=U8*U8 struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; @@ -10975,10 +11453,14 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -10994,7 +11476,8 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE " %88, %89, %90, %91, %92, %93, %94, %95}," " %96," " %97," - " %98;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -11021,19 +11504,16 @@ struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=U8*U8 struct SM90_64x256x32_S32U8U8_SS_TN { using DRegisters = void; @@ -11075,10 +11555,14 @@ struct SM90_64x256x32_S32U8U8_SS_TN uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11098,7 +11582,8 @@ struct SM90_64x256x32_S32U8U8_SS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -11133,19 +11618,16 @@ struct SM90_64x256x32_S32U8U8_SS_TN "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=U8*U8 struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; @@ -11187,10 +11669,14 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11210,7 +11696,8 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE " %120, %121, %122, %123, %124, %125, %126, %127}," " %128," " %129," - " %130;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -11245,19 +11732,16 @@ struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=U8*U8 struct SM90_64x8x32_S32U8U8_RS_TN { using DRegisters = void; @@ -11268,31 +11752,33 @@ struct SM90_64x8x32_S32U8U8_RS_TN CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x8x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x8x32 TN S32+=U8*U8 struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; @@ -11303,31 +11789,33 @@ struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE CUTE_HOST_DEVICE static void fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3) + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," - " %9;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=U8*U8 struct SM90_64x16x32_S32U8U8_RS_TN { using DRegisters = void; @@ -11339,32 +11827,34 @@ struct SM90_64x16x32_S32U8U8_RS_TN fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x16x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x16x32 TN S32+=U8*U8 struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; @@ -11376,32 +11866,34 @@ struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7) + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," - " %13;\n" + " p;\n" + "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=U8*U8 struct SM90_64x32x32_S32U8U8_RS_TN { using DRegisters = void; @@ -11415,35 +11907,37 @@ struct SM90_64x32x32_S32U8U8_RS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x32x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x32x32 TN S32+=U8*U8 struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; @@ -11457,35 +11951,37 @@ struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15) + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," " %20," - " %21;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=U8*U8 struct SM90_64x64x32_S32U8U8_RS_TN { using DRegisters = void; @@ -11503,10 +11999,14 @@ struct SM90_64x64x32_S32U8U8_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11514,7 +12014,8 @@ struct SM90_64x64x32_S32U8U8_RS_TN " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -11525,19 +12026,16 @@ struct SM90_64x64x32_S32U8U8_RS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x64x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x64x32 TN S32+=U8*U8 struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; @@ -11555,10 +12053,14 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31) + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11566,7 +12068,8 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE " %24, %25, %26, %27, %28, %29, %30, %31}," "{%32, %33, %34, %35}," " %36," - " %37;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -11577,19 +12080,16 @@ struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=U8*U8 struct SM90_64x96x32_S32U8U8_RS_TN { using DRegisters = void; @@ -11611,10 +12111,14 @@ struct SM90_64x96x32_S32U8U8_RS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11624,7 +12128,8 @@ struct SM90_64x96x32_S32U8U8_RS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -11639,19 +12144,16 @@ struct SM90_64x96x32_S32U8U8_RS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x96x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x96x32 TN S32+=U8*U8 struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; @@ -11673,10 +12175,14 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47) + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11686,7 +12192,8 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE " %40, %41, %42, %43, %44, %45, %46, %47}," "{%48, %49, %50, %51}," " %52," - " %53;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -11701,19 +12208,16 @@ struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=U8*U8 struct SM90_64x128x32_S32U8U8_RS_TN { using DRegisters = void; @@ -11739,10 +12243,14 @@ struct SM90_64x128x32_S32U8U8_RS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11754,7 +12262,8 @@ struct SM90_64x128x32_S32U8U8_RS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -11773,19 +12282,16 @@ struct SM90_64x128x32_S32U8U8_RS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x128x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x128x32 TN S32+=U8*U8 struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; @@ -11811,10 +12317,14 @@ struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63) + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11826,7 +12336,8 @@ struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " %69;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -11845,19 +12356,16 @@ struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=U8*U8 struct SM90_64x192x32_S32U8U8_RS_TN { using DRegisters = void; @@ -11891,10 +12399,14 @@ struct SM90_64x192x32_S32U8U8_RS_TN uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -11910,7 +12422,8 @@ struct SM90_64x192x32_S32U8U8_RS_TN " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -11937,19 +12450,16 @@ struct SM90_64x192x32_S32U8U8_RS_TN "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x192x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x192x32 TN S32+=U8*U8 struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; @@ -11983,10 +12493,14 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95) + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -12002,7 +12516,8 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE " %88, %89, %90, %91, %92, %93, %94, %95}," "{%96, %97, %98, %99}," " %100," - " %101;\n" + " p;\n" + "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), @@ -12029,19 +12544,16 @@ struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=U8*U8 struct SM90_64x256x32_S32U8U8_RS_TN { using DRegisters = void; @@ -12083,10 +12595,14 @@ struct SM90_64x256x32_S32U8U8_RS_TN uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -12106,7 +12622,8 @@ struct SM90_64x256x32_S32U8U8_RS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -12141,19 +12658,16 @@ struct SM90_64x256x32_S32U8U8_RS_TN "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 64x256x32 TN S32+=U8*U8 -template< - GMMA::ScaleOut scaleD = GMMA::ScaleOut::One -> +// GMMA 64x256x32 TN S32+=U8*U8 struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; @@ -12195,10 +12709,14 @@ struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127) + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90_ENABLED) +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " @@ -12218,7 +12736,8 @@ struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " %133;\n" + " p;\n" + "}\n" : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), @@ -12253,9 +12772,9 @@ struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "n"(int32_t(scaleD))); + "r"(int32_t(scale_D))); #else - CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90_ENABLED"); + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index 2e71db11..205951fe 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -42,17 +42,21 @@ // __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665 #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 14) #else - // ... but broken on Windows until Clang 15: https://reviews.llvm.org/D122897 + // ... but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897 #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER (__clang_major__ >= 15) #endif #endif #if defined(__NVCC__) || defined(__CUDACC_RTC__) // __cvta_generic_to_shared added in CUDA 11+ - #define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11)) + #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) + #define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif // __nvvm_get_smem_pointer added in CUDA 10.2 - #define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) + #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2 + #define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif #endif #define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED (CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED) @@ -172,6 +176,40 @@ explode(Fn fn, return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); } +template +CUTE_HOST_DEVICE constexpr +void +explode_with_d_scaling(Fn fn, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + ParamType&& p0) +{ + return fn(a[Ia]..., b[Ib]..., c[Ic]..., p0); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_with_d_scaling(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + ParamType&& p0) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., p0); +} + } // end namespace detail template - #include - #include -#include #include -namespace cute { +#include -// Generic copy_unpack for any Copy_Traits -template -CUTE_HOST_DEVICE constexpr -void -copy_unpack(Copy_Traits const&, - Tensor const& src, - Tensor & dst) -{ - // Specializations can generalize on these checks - //static_assert(is_smem::value, "Expected smem for this Copy_Traits"); - //static_assert(is_rmem::value, "Expected rmem for this Copy_Traits"); - - using RegistersSrc = typename Operation::SRegisters; - using RegistersDst = typename Operation::DRegisters; - using RegTypeSrc = typename std::remove_extent::type; - using RegTypeDst = typename std::remove_extent::type; - constexpr int RegNumSrc = std::extent::value; - constexpr int RegNumDst = std::extent::value; - - Tensor rS = recast(src); - Tensor rD = recast(dst); - - CUTE_STATIC_ASSERT_V(size(rS) == Int{}, - "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); - CUTE_STATIC_ASSERT_V(size(rD) == Int{}, - "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); - - detail::explode(Operation::copy, - rS, make_int_sequence{}, - rD, make_int_sequence{}); -} +#include +namespace cute +{ template struct Copy_Atom; @@ -110,33 +76,18 @@ struct Copy_Atom, T> // Additional Trait parameters/transformations template - CUTE_HOST_DEVICE + CUTE_HOST_DEVICE auto with(TraitsArgs&&... args) const { auto traits = Traits::with(std::forward(args)...); return Copy_Atom{traits}; } - // Print thread and data layouts for debugging - CUTE_HOST_DEVICE static - void - print_all() - { - print("ThrID: "); print(ThrID{}); print("\n"); - print("BitLayoutSrc: "); print(BitLayoutSrc{}); print("\n"); - print("BitLayoutDst: "); print(BitLayoutDst{}); print("\n"); - print("BitLayoutRef: "); print(BitLayoutRef{}); print("\n"); - print("ValLayoutSrc: "); print(ValLayoutSrc{}); print("\n"); - print("ValLayoutDst: "); print(ValLayoutDst{}); print("\n"); - print("ValLayoutRef: "); print(ValLayoutRef{}); print("\n"); - print("ValueType: %db", sizeof_bits::value); print("\n"); - } - // // Tensor call interfaces // - // Cast, check, and call + // Check and call instruction, or recurse template CUTE_HOST_DEVICE @@ -147,12 +98,19 @@ struct Copy_Atom, T> static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); - if constexpr (is_constant::value || is_constant::value) { + if constexpr (is_constant::value || + is_constant::value) { // Dispatch to unpack for instruction return copy_unpack(*this, src, dst); - } else { - // Recurse if needed by peeling the tensor mode + } else + if constexpr (is_tuple::value && + is_tuple::value) { + // If the size of the src/dst doesn't match the instruction, + // recurse this rank-1 layout by peeling off the mode + // ((A,B,C,...)) -> (A,B,C,...) return copy(*this, tensor<0>(src), tensor<0>(dst)); + } else { + static_assert(sizeof(TS) < 0, "No instruction match and no recursion possible."); } } @@ -172,6 +130,9 @@ struct Copy_Atom, T> // A tiling of copy atoms // +template +struct ThrCopy; + template coord [Need not be 2D...] class ShapeTile_MN> // coord space @@ -211,7 +172,13 @@ struct TiledCopy : Copy_Atom auto tidfrg_S(STensor&& stensor) { - return thrfrg(stensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); + constexpr int R = remove_cvref_t::rank; + static_assert(R >= rank_v, "Rank of tensor to be partitioned too small."); + // Generalize the dimension checks for arbitrary rank + //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + + return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); } // Tile a tensor or a layout from shape @@ -229,20 +196,24 @@ struct TiledCopy : Copy_Atom auto tidfrg_D(DTensor&& dtensor) { - return thrfrg(dtensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); + constexpr int R = remove_cvref_t::rank; + static_assert(R >= rank_v, "Rank of tensor to be partitioned too small."); + // Generalize the dimension checks for arbitrary rank + //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + + return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); } + // Tile a tensor or a layout from shape + // (Tile,(RestM,RestN,...)) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) template CUTE_HOST_DEVICE constexpr static auto - thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) + tile2thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) { - constexpr int R = remove_cvref_t::rank; - static_assert(R >= rank_v, "Rank of tensor to be partitioned too small."); - // Generalize the dimension checks for arbitrary rank - //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); - //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); - // Take the thrs/vals that the atom is interested in // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); @@ -259,12 +230,8 @@ struct TiledCopy : Copy_Atom /// ================== - // Tile the tensor for TiledLayout - auto t_tensor = zipped_divide(tensor, Tiler_MN{}); - // ((TileM,TileN,...),(RestM,RestN,...)) - // Transform the tile mode - auto tv_tensor = t_tensor.compose(thrval2mn, _); + auto tv_tensor = tensor.compose(thrval2mn, _); // ((thrid,val),(RM,RN,...)) // Unfold and return @@ -308,14 +275,22 @@ struct TiledCopy : Copy_Atom CUTE_HOST_DEVICE constexpr static auto - get_layoutS_MN() + get_layoutS_TV() { // (M,N) -> (M,N) - auto ref_S = make_layout(TiledShape_MN{}); + auto ref_S = make_layout(make_shape(TiledShape_MN{}, Int<1>{})); // (thr_idx,val_idx) -> (M,N) - auto layoutS_TV = tidfrg_S(ref_S); + return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_MN() + { + // (thr_idx,val_idx) -> (M,N) + auto layoutS_TV = get_layoutS_TV(); // (M,K) -> (thr_idx,val_idx) - auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(ref_S)); + auto layoutS_MK = right_inverse(layoutS_TV).with_shape(TiledShape_MN{}); // athrid = (v,m,k) -> thr_idx auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); @@ -325,24 +300,22 @@ struct TiledCopy : Copy_Atom CUTE_HOST_DEVICE constexpr static auto - get_layoutS_TV() + get_layoutD_TV() { // (M,N) -> (M,N) - auto ref_S = make_layout(TiledShape_MN{}); + auto ref_D = make_layout(make_shape(TiledShape_MN{}, Int<1>{})); // (thr_idx,val_idx) -> (M,N) - return tidfrg_S(ref_S)(_,_,Int<0>{}); + return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); } CUTE_HOST_DEVICE constexpr static auto get_layoutD_MN() { - // (M,N) -> (M,N) - auto ref_D = make_layout(TiledShape_MN{}); // (thr_idx,val_idx) -> (M,N) - auto layoutD_TV = tidfrg_D(ref_D); + auto layoutD_TV = get_layoutD_TV(); // (M,K) -> (thr_idx,val_idx) - auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(ref_D)); + auto layoutD_MK = right_inverse(layoutD_TV).with_shape(TiledShape_MN{}); // athrid = (v,m,k) -> thr_idx auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); @@ -350,70 +323,13 @@ struct TiledCopy : Copy_Atom return cute::make_tuple(layoutD_MK, thrID_D); } - CUTE_HOST_DEVICE constexpr static - auto - get_layoutD_TV() - { - // (M,N) -> (M,N) - auto ref_D = make_layout(TiledShape_MN{}); - // (thr_idx,val_idx) -> (M,N) - return tidfrg_D(ref_D)(_,_,Int<0>{}); - } - - template - struct ThrCopy : Copy_Atom - { - ThrIdx thr_idx_; - - CUTE_HOST_DEVICE - ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} - - template - CUTE_HOST_DEVICE - auto - partition_S(STensor&& stensor) { - //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), - // "Expected ValType for tiling SrcTensor."); - auto thr_tensor = make_tensor(std::forward(stensor).data(), tidfrg_S(stensor.layout())); - return thr_tensor(thr_idx_, _, repeat>(_)); - } - - template - CUTE_HOST_DEVICE - auto - partition_D(DTensor&& dtensor) { - //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), - // "Expected ValType for tiling DstTensor."); - auto thr_tensor = make_tensor(std::forward(dtensor).data(), tidfrg_D(dtensor.layout())); - return thr_tensor(thr_idx_, _, repeat>(_)); - } - - template - CUTE_HOST_DEVICE static - auto - retile_S(STensor&& stensor) { - static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), - "Expected ValType for tiling SrcTensor."); - return make_tensor(std::forward(stensor).data(), TiledCopy::retile(stensor.layout())); - } - - template - CUTE_HOST_DEVICE static - auto - retile_D(DTensor&& dtensor) { - static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename Copy_Atom::ValType), - "Expected ValType for tiling DstTensor."); - return make_tensor(std::forward(dtensor).data(), TiledCopy::retile(dtensor.layout())); - } - }; - template ::value)> - CUTE_HOST_DEVICE static + CUTE_HOST_DEVICE static auto get_slice(ThrIdx const& thr_idx) { - return ThrCopy(thr_idx); + return ThrCopy(thr_idx); } template +struct ThrCopy +{ + ThrIdx thr_idx_; + + CUTE_HOST_DEVICE + ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} + + template + CUTE_HOST_DEVICE + auto + partition_S(STensor&& stensor) { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + auto thr_tensor = make_tensor(std::forward(stensor).data(), TiledCopy::tidfrg_S(stensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE + auto + partition_D(DTensor&& dtensor) { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + auto thr_tensor = make_tensor(std::forward(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE static + auto + retile_S(STensor&& stensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + return make_tensor(std::forward(stensor).data(), TiledCopy::retile(stensor.layout())); + } + + template + CUTE_HOST_DEVICE static + auto + retile_D(DTensor&& dtensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + return make_tensor(std::forward(dtensor).data(), TiledCopy::retile(dtensor.layout())); + } +}; + template -CUTE_HOST_DEVICE + class Tiler> +CUTE_HOST_DEVICE auto make_tiled_copy_impl(Copy_Atom const& atom, LayoutCopy_TV const&, - Tile const&) + Tiler const&) { - return TiledCopy, LayoutCopy_TV, Tile>{atom}; + return TiledCopy, LayoutCopy_TV, Tiler>{atom}; } // @@ -445,7 +408,7 @@ make_tiled_copy_impl(Copy_Atom const& atom, template -CUTE_HOST_DEVICE +CUTE_HOST_DEVICE auto make_tiled_copy_A(Copy_Atom const& copy_atom, TiledMMA const& tiled_mma) @@ -456,7 +419,7 @@ make_tiled_copy_A(Copy_Atom const& copy_atom, template -CUTE_HOST_DEVICE +CUTE_HOST_DEVICE auto make_tiled_copy_B(Copy_Atom const& copy_atom, TiledMMA const& tiled_mma) @@ -467,7 +430,7 @@ make_tiled_copy_B(Copy_Atom const& copy_atom, template -CUTE_HOST_DEVICE +CUTE_HOST_DEVICE auto make_tiled_copy_C(Copy_Atom const& copy_atom, TiledMMA const& tiled_mma) @@ -476,10 +439,50 @@ make_tiled_copy_C(Copy_Atom const& copy_atom, return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{}))); } +// returns the smallest tiled copy that can retile LayoutC_TV +// for use with pipelined epilogues with subtiled stores +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_atom(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) +{ + // Truncate the V-layout to just the Copy_Atom, keep the V-order + auto layoutC_TV = tiled_mma.get_layoutC_TV(); + auto copy_V = Int::NumValSrc>{}; + CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV)); + auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V))); + + // Recompute tiler and restride the TV layout for the new tiler + + // Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them + // Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA + using MNK = typename TiledMMA::TiledShape_MNK; + auto mma_tiler = make_shape(size<0>(MNK{}),size<1>(MNK{})); + auto mma_zeros = repeat_like(mma_tiler, Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(mma_tiler, replace(mma_zeros, Int<1>{})), layout_TV)); + }); + + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // Apply the tiler to a reference and transform the codomain + // tile_coord -> mma_coord + auto tile2mma = composition(make_layout(mma_tiler), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2mma), layout_TV); + + + using MNK = typename TiledMMA::TiledShape_MNK; + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + template > -CUTE_HOST_DEVICE +CUTE_HOST_DEVICE auto make_tiled_copy(Copy_Atom const& copy_atom, ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx @@ -493,11 +496,10 @@ make_tiled_copy(Copy_Atom const& copy_atom, // Take the raked_products to compute the Layout_MN auto layout_mn = raked_product(thr_layout_mn, val_layout_mn); auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); - - //print("thr_layout: "); print(thr_layout_mn); print("\n"); - //print("val_layout: "); print(val_layout_mn); print("\n"); - //print("layout_mn : "); print(layout_mn); print("\n"); - //print("layout_tv : "); print(layout_tv); print("\n"); + // print("thr_layout: "); print(thr_layout_mn); print("\n"); + // print("val_layout: "); print(val_layout_mn); print("\n"); + // print("layout_mn : "); print(layout_mn); print("\n"); + // print("layout_tv : "); print(layout_tv); print("\n"); return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn))); } @@ -505,7 +507,7 @@ make_tiled_copy(Copy_Atom const& copy_atom, // Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy template -CUTE_HOST_DEVICE +CUTE_HOST_DEVICE auto make_tiled_copy_S(Copy_Atom const& copy_atom, TiledCopy const& tiled_copy) @@ -516,7 +518,7 @@ make_tiled_copy_S(Copy_Atom const& copy_atom, // Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy template -CUTE_HOST_DEVICE +CUTE_HOST_DEVICE auto make_tiled_copy_D(Copy_Atom const& copy_atom, TiledCopy const& tiled_copy) @@ -550,6 +552,40 @@ size(TiledCopy const&) // Display utilities // +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, T> const&) +{ + using Atom = Copy_Atom, T>; + print("Copy_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); + print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); + print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); + print(" ValueType: %db\n", int(sizeof_bits::value)); +} + +template +CUTE_HOST_DEVICE +void +print(TiledCopy const& copy, char const* pad = "") +{ + using Copy = TiledCopy; + print("TiledCopy\n"); + print(" Tiler_MN: "); print(typename Copy::Tiler_MN{}); print("\n"); + print(" TiledLayout_TV: "); print(typename Copy::TiledLayout_TV{}); print("\n"); + print(static_cast(copy)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrCopy const&) +{ + print(TiledCopy{}); +} + template CUTE_HOST_DEVICE auto @@ -655,7 +691,6 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and //////////////////////////////////////////////////////////////////////////////////////////////////// -#include #include #include #include diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp index 83cb0565..cea03c0f 100644 --- a/include/cute/atom/copy_traits.hpp +++ b/include/cute/atom/copy_traits.hpp @@ -32,11 +32,30 @@ #include -#include +#include namespace cute { +/** + * concept Copy_Traits + * { + * using ThrID = // Logical thread id (tid) -> tidx + * + * using SrcLayout = // (Logical src thread id (tid), Logical src value id (vid)) -> bit + * using DstLayout = // (Logical dst thread id (tid), Logical dst value id (vid)) -> bit + * using RefLayout = // (Logical ref thread id (tid), Logical ref value id (vid)) -> bit + * }; + * + * The abstract bit ordering of the Copy_Traits (the codomain of SrcLayout, DstLayout, and RefLayout) + * is arbitrary and only used to construct maps + * (ref-tid,ref-vid) -> (src-tid,src-vid) + * (ref-tid,ref-vid) -> (dst-tid,dst-vid) + * in TiledCopy. The Layout_TV in TiledCopy is in accordance with the RefLayout of a Traits, then mapped to + * the Src or Dst (tid,vid) representation on demand. + * + */ + template struct Copy_Traits { @@ -73,4 +92,40 @@ struct Copy_Traits using RefLayout = SrcLayout; }; +// +// Generic copy_unpack for any Copy_Traits +// +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const&, + Tensor const& src, + Tensor & dst) +{ + // Specializations can generalize on these checks + //static_assert(is_smem::value, "Expected smem for this Copy_Traits"); + //static_assert(is_rmem::value, "Expected rmem for this Copy_Traits"); + + using RegistersSrc = typename Operation::SRegisters; + using RegistersDst = typename Operation::DRegisters; + using RegTypeSrc = typename remove_extent::type; + using RegTypeDst = typename remove_extent::type; + constexpr int RegNumSrc = extent::value; + constexpr int RegNumDst = extent::value; + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + detail::explode(Operation::copy, + rS, make_int_sequence{}, + rD, make_int_sequence{}); +} + } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 4e98ea32..6f3f9d4d 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -30,14 +30,18 @@ **************************************************************************************************/ #pragma once +#if !defined(__CUDACC_RTC__) #include +#endif #include #include -#include #include +#include +#include + namespace cute { @@ -142,7 +146,7 @@ struct Copy_Traits return {tma_desc_, tma_mbar}; } - // Generate the TMA coord tensor + // Generate the TMA coord tensor template CUTE_HOST_DEVICE constexpr auto @@ -257,7 +261,7 @@ struct Copy_Traits return {tma_desc_, tma_load_mbar, multicast_mask}; } - // Generate the TMA coord tensor + // Generate the TMA coord tensor template CUTE_HOST_DEVICE constexpr auto @@ -300,6 +304,13 @@ struct Copy_Traits TmaDescriptor tma_desc_; GmemStrides g_stride_; + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + // Generate the TMA coord tensor template CUTE_HOST_DEVICE constexpr @@ -352,240 +363,199 @@ struct Copy_Traits } }; +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// BULK COPY ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBits::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_BULK_COPY_G2S arguments + // 0: uint64_t* bulk_load_memory_barrier + cute::tuple bulk_load_mbar_; + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_same, cute::tuple>::value, + "Extra arguments not set. Set .with() before use."); + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_COPY_G2S"); + static_assert(is_smem::value, "Expected smem dst for SM90_BULK_COPY_G2S"); + SM90_BULK_COPY_G2S::copy(src.data().get(), *get<0>(traits.bulk_load_mbar_), + dst.data().get(), int32_t(NumBits::value / 8)); + } + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {{&bulk_mbar}}; + } +}; + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBits::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_BULK_COPY_S2G"); + static_assert(is_gmem::value, "Expected gmem dst for SM90_BULK_COPY_S2G"); + SM90_BULK_COPY_S2G::copy(src.data().get(), dst.data().get(), int32_t(NumBits::value / 8)); + } +}; + +// +// Placeholder for the bulk copy algorithm's default, auto-vectorizing behavior +// + +template +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_UBULK_COPY arguments + // 0: uint64_t* bulk_load_memory_barrier [if this is a BULK_LOAD_G2S] + cute::tuple opargs_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {{&bulk_mbar}}; + } +}; + // // MAKE_TMA_COPY and related // +namespace detail +{ + template -TMA::SmemSwizzleBits -get_tma_swizzle_bits(ComposedLayout,Offset,SLayout>) +auto +get_swizzle_portion(ComposedLayout,Offset,SLayout>) { - static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); - static_assert(S == 3, "Unsupported layout swizzle"); - - switch (B) { - default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3. Unsupported layout swizzle."); - case 3: return TMA::SmemSwizzleBits::B128; - case 2: return TMA::SmemSwizzleBits::B64; - case 1: return TMA::SmemSwizzleBits::B32; - case 0: return TMA::SmemSwizzleBits::DISABLE; - } + return Swizzle{}; } template -TMA::SmemSwizzleBits -get_tma_swizzle_bits(Layout) +auto +get_swizzle_portion(Layout) { - return TMA::SmemSwizzleBits::DISABLE; + return Swizzle<0,4,3>{}; } template auto -get_nonswizzle_layout(ComposedLayout,Offset,SLayout> const& slayout) +get_nonswizzle_portion(ComposedLayout,Offset,SLayout> const& slayout) { return slayout.layout_fn(); } template auto -get_nonswizzle_layout(Layout const& slayout) +get_nonswizzle_portion(Layout const& slayout) { return slayout; } -/** Make a CuTe CTA-collective TiledCopy for a TMA operation. - * - * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE - * @param gtensor The GMEM Tensor to be involved in the TMA. - * @param slayout The SMEM Layout to be involved in the TMA. - * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. - * This is often the blk_shape that is used to tile the GMEM for CTAs: - * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor - * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 - * defining the multicast size (used to further partition the SMEM) - * Else, static-1 - * - * This code attempts to maximize the TMA box size. It does this by tracing - * the SMEM "vector" -- the inverse of the smem layout -- to find the largest - * contiguous array of smem that can be written to/from global memory given - * the constraints that the TMA instruction imposes. - * - * This is accomplished by assigning "basis" strides to the GMEM to track which - * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according - * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. - * - * Examples: - using T = float; - T* gptr = nullptr; - - { - // Simple 2D - Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM - auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); - } - - { - // GMMA 2D - Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM - auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); - } - - { - // 3D - Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM - auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); - } - - { - // cuTENSOR 4D - auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM - auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: - // Take 128-elem from m: m0 must divide 128, - // m-last may be predicated - // Take 32-elem from k0, 2-elem from k1 - auto slayout = make_layout(cta_tile); // Col-Major SMEM - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Swizzle) +{ + if constexpr (M == 4) { + switch (B) { + default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + case 3: return TMA::SmemSwizzleBits::B128; + case 2: return TMA::SmemSwizzleBits::B64; + case 1: return TMA::SmemSwizzleBits::B32; + case 0: return TMA::SmemSwizzleBits::DISABLE; } - * - * Check the TMA box size and desc: - print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - print("TMA desc : "); print(tma.tma_desc_); print("\n"); - * - * Usage: - Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor - Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA - Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor + } else + { + static_assert(M < 0, "Unsupported layout swizzle."); + } +} - auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning - Tensor tAgA = cta_tma.partition_S(gA); // Partition for src - Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Layout const& layout) +{ + return get_tma_swizzle_bits(get_swizzle_portion(layout)); +} - copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params - */ -template +#if !defined(__CUDACC_RTC__) +// Use a smem2gmode map to read through the GMEM tensor +// and construct a TMA Descriptor for the resulting instruction +template CUTE_HOST auto -make_tma_copy(CopyOp, - Tensor const& gtensor, - SLayout const& slayout, - CTA_Tile const& cta_tile, - Cluster_Size const& cluster_size) +make_tma_copy_desc(Tensor const& gtensor, // The original GMEM Tensor + Layout const& smem_inv, // smem_idx to flat gmode + Swizzle const& swizzle) // Swizzle fn on smem_idx { - static_assert((std::is_same::value && is_constant<1, Cluster_Size>::value) || - (std::is_same::value) || - (std::is_same::value && is_constant<1, Cluster_Size>::value)); - - using T = typename Tensor::value_type; - - // - // TMA parameter checking - // + using T = typename GEngine::value_type; auto flat_glayout = flatten(gtensor.layout()); - - CUTE_STATIC_ASSERT_V(rank(flatten(cta_tile)) <= Int<5>{}, - "CTA_Tile cannot have more than five modes, TMA arch restriction."); - CUTE_STATIC_ASSERT_V(rank(flat_glayout) <= Int<5>{} || rank(flatten(cta_tile)) <= Int<4>{}, - "If GTensor has more than five modes, then CTA_Tile cannot have more than four modes. TMA multimode."); - CUTE_STATIC_ASSERT_V(compatible(product_each(shape(slayout)), shape(cta_tile)), - "CTA_Tile must be compatible with SLayout."); - CUTE_STATIC_ASSERT_V(is_integral{} && has_single_bit(cluster_size) && cluster_size <= Int<16>{}, - "Expecting a pow2 integral Cluster_Size leq 16."); - CUTE_STATIC_ASSERT_V(size(slayout) % cluster_size == Int<0>{}, - "ClusterShape must divide domain size of slayout."); - - // - // TMA slayout manipulation - // + CUTE_STATIC_ASSERT_V(rank(flat_glayout) == rank(smem_inv)); + constexpr int rank_smem_inv = decltype(rank(smem_inv))::value; auto tma_multimode = rank(flat_glayout) > Int<5>{}; - - // Invert the smem to get the largest contiguous vector in the smem layout - auto inv_smem_layout = right_inverse(get_nonswizzle_layout(slayout)); - // trunc_smem_idx -> trunc_smem_coord - - // Map from smem idx to a gmem mode - auto sidx_to_gmode = flatten(composition(make_identity_layout(cta_tile), inv_smem_layout)); - - // Truncate any incompatibilities - auto smem_rank = find_if(stride(sidx_to_gmode), [](auto e){ - [[maybe_unused]] auto v = basis_value(e); - return not is_constant<1,decltype(v)>{}; - }); - static_assert(smem_rank > 0, "Could not find a common smem-gmem vectorization for TMA."); - constexpr int smem_tma_rank = cute::min(int(smem_rank), (tma_multimode ? 4 : 5)); - - // Keep only the static-1 basis modes into gmem - auto sidx_to_gmode_cluster_trunc = take<0,smem_tma_rank>(sidx_to_gmode); - // Keep only the portion each multicast CTA will be responsible for - auto sidx_to_gmode_cta_trunc = composition(sidx_to_gmode_cluster_trunc, shape_div(size(sidx_to_gmode_cluster_trunc), cluster_size)); - - // - // TMA gtensor manipulation - // - - // Generate a TupleBasis for the gtensor - auto flat_gbasis = make_basis_like(shape(flat_glayout)); - - // Fold the flat_gbasis into the glayout - auto glayout_basis = make_layout(shape(gtensor), - stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), flat_gbasis), - make_layout(repeat_like(shape(gtensor), Int<2>{}))))); - - // Tile the modes of gtensor with cta_tile - auto cta_glayout_basis = composition(glayout_basis, cta_tile); - - // Check that the cta_tile selects modes from gtensor properly - for_each(flatten(stride(cta_glayout_basis)), [](auto d) { - static_assert(is_constant<1, decltype(d.value())>::value, - "CTA_Tile does not faithfully partition the GMEM, it should select the number of elements from each mode of glayout."); - }); - - // Tile the modes of gtensor again with the truncated cta_tile o inv_smem_layout - auto tma_layout_cta_trunc = flatten(composition(glayout_basis, sidx_to_gmode_cta_trunc)); - - // Append any missing basis on the end as size-1 modes b/c they got truncated - auto missing_basis = fold(stride(tma_layout_cta_trunc), flat_gbasis, [](auto init, auto e){ - auto k = find(init, e); - return remove(init); - }); - - // The appended map from truncated smem codomain to gmem mode: trunc_smem_idx -> gmem_mode - auto tma_layout_cta = flatten(make_layout(tma_layout_cta_trunc, - make_layout(repeat(Int<1>{}), missing_basis))); - -#if 0 - print("g_layout : "); print(gtensor.layout()); print("\n"); - print("s_layout : "); print(slayout); print("\n"); - print("cta_tile : "); print(cta_tile); print("\n"); - print("cluster_size : "); print(cluster_size); print("\n"); - print("flat_gbasis : "); print(flat_gbasis); print("\n"); - print("cta_glayout : "); print(cta_glayout_basis); print("\n"); - print("inv_smem : "); print(inv_smem_layout); print("\n"); - print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); - print("missing_b : "); print(missing_basis); print("\n"); - print("tma_layout_cta: "); print(tma_layout_cta); print("\n"); -#endif + constexpr uint32_t tma_dim = cute::min(rank(flat_glayout), 5);; // // TMA gmem desc info // - constexpr int TmaRANK = cute::min(rank(flat_glayout), 5); void* gmem_address = (void*) gtensor.data(); - cute::array gmem_prob_shape = {1,1,1,1,1}; - cute::array gmem_prob_stride = {0,0,0,0,0}; - for_each(make_seq{}, [&](auto i) { - // NOTE : WAR g++-7.3.5, let it deduce e rather than fuse with below - auto e = stride(tma_layout_cta); + cute::array gmem_prob_shape = {1,1,1,1,1}; + cute::array gmem_prob_stride = {0,0,0,0,0}; + for_each(make_seq{}, [&](auto i) { + auto e = stride(smem_inv); // For g++-7.5, let it deduce e rather than fuse with below constexpr int j = decltype(e.mode())::value; constexpr int tma_i = i < 5 ? i : 4; @@ -634,10 +604,10 @@ make_tma_copy(CopyOp, // TMA smem desc info // - // TMA smem box size - cute::array smem_box_shape = {1,1,1,1,1}; - for_each(make_seq{}, [&](auto i) { - uint32_t shape_i = shape(tma_layout_cta); + cute::array smem_box_shape = {1,1,1,1,1}; + cute::array smem_box_stride = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + uint32_t shape_i = shape(smem_inv); constexpr int tma_i = i < 5 ? i : 4; if (tma_multimode && tma_i == 4) { // We're "reusing" this TMA mode and using it as a "multimode" @@ -647,9 +617,6 @@ make_tma_copy(CopyOp, } }); - // TMA smem mode strides - [[maybe_unused]] cute::array smem_box_stride = {1,1,1,1,1}; - assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 @@ -676,21 +643,19 @@ make_tma_copy(CopyOp, TmaDescriptor tma_desc = {0}; -#if (__CUDACC_VER_MAJOR__ >= 12) - // // TMA general info // - cuuint32_t tma_dim = TmaRANK; +#if (__CUDACC_VER_MAJOR__ >= 12) + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; // TMA smem swizzle type - CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(slayout)); - + CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle)); CUresult result = cuTensorMapEncodeTiled( &tma_desc, tma_format, @@ -721,16 +686,12 @@ make_tma_copy(CopyOp, std::cerr << "Error: Failed to initialize the TMA descriptor " << result << std::endl; assert(false); } -#endif // (__CUDACC_VER_MAJOR__ >= 12) - - // - // Construct the Copy_Traits - // +#endif // (__CUDACC_VER_MAJOR__ >= 12) // Finally, get the inverse permutation of the E bases for the mocked gmem stride - auto gmem_stride_bases_flat = transform(make_seq{}, [&](auto i) { - auto k = find(stride(tma_layout_cta), E{}); - // NOTE: gcc 7.3.5 WAR -- avoid if constexpr + auto gmem_stride_bases_flat = transform(make_seq{}, [&](auto i) { + auto k = find(stride(smem_inv), E{}); + // For gcc 7.5 -- avoid 'if constexpr' int32_t tma_coord_stride = int32_t(stride(flat_glayout) * sizeof(T) / (gmem_prob_stride[4] != 0 ? gmem_prob_stride[4] : 16)); return conditional_return(tma_multimode && (k >= Int<4>{}), E<4>{} * tma_coord_stride, // The 4th TMA mode is the multimode, use int32_t coord stride @@ -738,10 +699,120 @@ make_tma_copy(CopyOp, }); // Give that the profile of gtensor and fold it + // NOTE: This is the only reason we want the original gtensor shape rather than the more intuitive flattened shape auto gmem_stride_bases = stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), gmem_stride_bases_flat), make_layout(repeat_like(shape(gtensor), Int<2>{})))); - constexpr int num_bits = size(sidx_to_gmode_cta_trunc) * sizeof(T) * 8; + return make_tuple(tma_desc, gmem_stride_bases); +} + +template +CUTE_HOST +auto +make_tma_copy_tiled(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM + Layout const& cta_t_map, // T: CTA thr idx -> logical TMA tid + Layout const& cta_v_map) // V: CTA val idx -> gmem coord +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, + "Number of active CTAs in TMA must divide domain size of slayout."); + + // + // TMA slayout manipulation + // + + auto flat_glayout = flatten(gtensor.layout()); + + // Invert the smem to get the largest contiguous vector in the smem layout + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // trunc_smem_idx -> trunc_smem_coord + + // Map from smem idx to a gmem mode + auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); + + // Truncate any incompatibilities + auto smem_rank = find_if(stride(sidx_to_gmode), [](auto e) { + auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank > 0, "Could not find a common smem-gmem vectorization for TMA. Do they have a common majorness?"); + // TMA uses a maximum of 5 modes + // If the gtensor has more than 5 modes, we need to reserve the last TMA-mode as a "multimode" + constexpr int smem_tma_rank = cute::min(int(smem_rank), (rank(flat_glayout) > Int<5>{} ? 4 : 5)); + + // Keep only the static-1 basis modes into gmem + auto sidx_to_gmode_trunc = take<0,smem_tma_rank>(sidx_to_gmode); + + // Split according to the portion each multicast CTA will be responsible for + auto sidx_to_gmode_vt = logical_divide(sidx_to_gmode_trunc, shape_div(size(sidx_to_gmode_trunc), cosize(cta_t_map))); + +#if 0 + print("g_layout : "); print(gtensor.layout()); print("\n"); + print("s_layout : "); print(slayout); print("\n"); + print("cta_t_map : "); print(cta_t_map); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); + print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); + + print("sidx_to_gmode_trunc : "); print(sidx_to_gmode_trunc); print("\n"); + print("sidx_to_gmode_vt : "); print(sidx_to_gmode_vt); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // Generate a TupleBasis for the gtensor + auto flat_gbasis = make_basis_like(shape(flat_glayout)); + + // Fold the flat_gbasis into the glayout + auto glayout_basis = make_layout(shape(gtensor), + stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), flat_gbasis), + make_layout(repeat_like(shape(gtensor), Int<2>{}))))); + + // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc + auto tma_layout_v_trunc = flatten(composition(glayout_basis, layout<0>(sidx_to_gmode_vt))); + + // Append any missing basis on the end as size-1 modes b/c they got truncated + // NOTE This is essentially ArithmeticTuple complement... + auto missing_basis = fold(stride(tma_layout_v_trunc), flat_gbasis, [](auto init, auto e) { + auto k = find(init, e); + return remove(init); + }); + + // The appended map from truncated smem codomain to gmem mode: trunc_smem_idx -> gmem_mode + auto tma_layout_v = make_layout(flatten(cute::make_tuple(tma_layout_v_trunc.shape(), repeat(Int<1>{}))), + flatten(cute::make_tuple(tma_layout_v_trunc.stride(), missing_basis))); + +#if 0 + print("flat_gbasis : "); print(flat_gbasis); print("\n"); + print("missing_b : "); print(missing_basis); print("\n"); + print("tma_layout_v : "); print(tma_layout_v); print("\n"); +#endif + + // + // Construct the TMA Desc and GMEM mode ordering + // + + auto [tma_desc, gmem_stride_bases] = detail::make_tma_copy_desc(gtensor, tma_layout_v, get_swizzle_portion(slayout)); + + // + // Construct the Copy_Traits + // + + using T = typename GEngine::value_type; + constexpr int num_bits = decltype(size<0>(sidx_to_gmode_vt))::value * sizeof(T) * 8; using Traits = Copy_Traits, decltype(gmem_stride_bases)>; #if 0 @@ -749,30 +820,126 @@ make_tma_copy(CopyOp, print("g_stride_bases: "); print(gmem_stride_bases); print("\n"); #endif + Traits tma_traits{tma_desc, gmem_stride_bases}; + // // Construct the TiledCopy // - // The ThrVal layout for 1 TMA instruction within cta_tile - auto layout_tv_1 = composition(inv_smem_layout, make_layout(make_shape(cluster_size, size(sidx_to_gmode_cta_trunc)), GenRowMajor{})); - // The ThrVal layout for N TMA instructions within cta_tile - auto layout_tv = tile_to_shape(layout_tv_1, make_shape(cluster_size, size(cta_tile)/cluster_size)); + auto cta_tiler = product_each(shape(cta_v_map)); + + // (CTA V, CTA T) -> smem_coord + auto layout_vt = composition(inv_smem_layout, make_layout(shape(sidx_to_gmode_vt))); + // Scale that up to cover all of the smem_coords + auto layout_VT = tile_to_shape(layout_vt, make_shape(size(cta_v_map)/size<1>(layout_vt), size<1>(layout_vt))); + // Flip it and change the domain of the T from logical thr to thr_idx + auto layout_TV = make_layout(composition(layout<1>(layout_VT), cta_t_map), layout<0>(layout_VT)); #if 0 - print("layout_tv : "); print(layout_tv); print("\n"); + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_VT : "); print(layout_VT); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); #endif - // If CTA_Tile and SLayout are incompatible, product_each makes sure - // that the TiledCopy generates consistent accesses. - auto cta_tile_tiled = [&]() { - if constexpr (compatible(shape(CTA_Tile{}), shape(SLayout{}))) { - return cta_tile; - } else { - return product_each(cta_tile); + using T = typename GEngine::value_type; + return TiledCopy, decltype(layout_TV), decltype(cta_tiler)>{tma_traits}; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace detail + +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. + * This is often the blk_shape that is used to tile the GMEM for CTAs: + * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor + * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 + * defining the multicast size (used to further partition the SMEM) + * Else, static-1 + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. + * + * Examples: + using T = float; + T* gptr = nullptr; + + { + // Simple 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM + auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); } - }(); - return TiledCopy, decltype(layout_tv), decltype(cta_tile_tiled)>{tma_desc, gmem_stride_bases}; + { + // GMMA 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM + auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // 3D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM + auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // cuTENSOR 4D + auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM + auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto slayout = make_layout(cta_tile); // Col-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); + } + * + * Check the TMA box size and desc: + print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + print("TMA desc : "); print(tma.tma_desc_); print("\n"); + * + * Usage: + Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor + Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA + Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor + + auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning + Tensor tAgA = cta_tma.partition_S(gA); // Partition for src + Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst + + copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params + */ +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tile const& cta_tile, + Cluster_Size const& cluster_size) +{ + + return detail::make_tma_copy_tiled(copy_op, + gtensor, + slayout, + make_layout(cluster_size), + make_identity_layout(cta_tile)); } // Explicit defaulting @@ -797,9 +964,10 @@ auto make_tma_copy(CopyOp const& copy_op, Tensor const& gtensor, SLayout const& slayout, - Cluster_Size const& cluster_size) + Cluster_Size const& cluster_size) { return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); } +#endif // !defined(__CUDACC_RTC__) } // end namespace cute diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index c3025f50..109117ce 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -30,101 +30,16 @@ **************************************************************************************************/ #pragma once -#include - #include - #include -#include -#include -#include - -namespace cute { - -// Generic mma_unpack for any MMA_Traits -template -CUTE_HOST_DEVICE constexpr -void -mma_unpack(MMA_Traits const&, - Tensor & D, - Tensor const& A, - Tensor const& B, - Tensor const& C) -{ - static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); - - // Register value types from the MMA_Operation register arrays - using RegTypeD = typename std::remove_extent::type; - using RegTypeA = typename std::remove_extent::type; - using RegTypeB = typename std::remove_extent::type; - using RegTypeC = typename std::remove_extent::type; - constexpr int RegNumD = std::extent::value; - constexpr int RegNumA = std::extent::value; - constexpr int RegNumB = std::extent::value; - constexpr int RegNumC = std::extent::value; - - Tensor rA = recast(A); - Tensor rB = recast(B); - - CUTE_STATIC_ASSERT_V(size(rA) == Int{}); - CUTE_STATIC_ASSERT_V(size(rB) == Int{}); - - if constexpr (std::is_same::value) - { - static_assert(std::is_same::value, "GMMA C and D value_type must match."); - static_assert(std::is_same::value, "GMMA C and D layouts must match."); - // assert((void*)&C == (void*)&D); - - Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D - - //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); - - detail::explode(Operation::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); - } else - { - Tensor rD = recast(D); - Tensor rC = recast(C); - - CUTE_STATIC_ASSERT_V(size(rD) == Int{}); - CUTE_STATIC_ASSERT_V(size(rC) == Int{}); - - detail::explode(Operation::fma, - rD, make_int_sequence{}, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); - } -} - - -namespace detail { -template -struct FrgTypeA_or_Default { using type = typename X::ElementAVal; }; -template -struct FrgTypeA_or_Default> { using type = typename X::ElementAFrg; }; +#include -template -struct FrgTypeB_or_Default { using type = typename X::ElementBVal; }; -template -struct FrgTypeB_or_Default> { using type = typename X::ElementBFrg; }; +#include -template -struct FrgTypeC_or_Default { using type = typename X::ElementCVal; }; -template -struct FrgTypeC_or_Default> { using type = typename X::ElementCFrg; }; +#include -} // end namespace detail +namespace cute { template struct MMA_Atom; @@ -167,17 +82,6 @@ struct MMA_Atom> return MMA_Atom{traits}; } - // Print thread and data layouts for debugging - CUTE_HOST_DEVICE static - void - print_all() - { - print("ThrID: "); print(ThrID{}); print("\n"); - print("LayoutA_TV: "); print(LayoutA_TV{}); print("\n"); - print("LayoutB_TV: "); print(LayoutB_TV{}); print("\n"); - print("LayoutC_TV: "); print(LayoutC_TV{}); print("\n"); - } - // // Tensor call interfaces // @@ -232,7 +136,6 @@ struct MMA_Atom> // Check that this tensor is likely already partitioned CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<3>{}); // VMN CUTE_STATIC_ASSERT_V(size<0>(ctensor) == size<1>(LayoutC_TV{})); - // C is a bit special because we are after accumulators here // The input/output type doesn't have to match the accumulator type //static_assert(std::is_same::value_type>::value, "Expecting ValTypeC type"); @@ -249,12 +152,14 @@ struct MMA_Atom> // Check that this tensor is likely already partitioned CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<3>{}); // VMK CUTE_STATIC_ASSERT_V(size<0>(atensor) == size<1>(LayoutA_TV{})); - static_assert(std::is_same::value_type>::value, "Expecting ValTypeA type"); if constexpr (has_dereference::value) { - return recast(std::forward(atensor)); + // If the intended FrgTypeA is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value, "Expecting ValTypeA type"); + return make_tensor(std::forward(atensor)); } else { - return make_tensor(make_fragment_like(atensor.layout())); + // Else, the intended FrgTypeA is a value type, construct a new tensor with a fragment layout + return make_fragment_like(atensor); } CUTE_GCC_UNREACHABLE; @@ -268,12 +173,14 @@ struct MMA_Atom> // Check that this tensor is likely already partitioned CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<3>{}); // VNK CUTE_STATIC_ASSERT_V(size<0>(btensor) == size<1>(LayoutB_TV{})); - static_assert(std::is_same::value_type>::value, "Expecting ValTypeB type"); if constexpr (has_dereference::value) { - return recast(std::forward(btensor)); + // If the intended FrgTypeB is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value, "Expecting ValTypeB type"); + return make_tensor(std::forward(btensor)); } else { - return make_tensor(make_fragment_like(btensor.layout())); + // Else, the intended FrgTypeB is a value type, construct a new tensor with a fragment layout + return make_fragment_like(btensor); } CUTE_GCC_UNREACHABLE; @@ -607,7 +514,7 @@ struct ThrMMA : TiledMMA auto partition_C(CTensor&& ctensor) const { - auto thr_tensor = make_tensor(std::forward(ctensor).data(), thrfrg_C(ctensor.layout())); + auto thr_tensor = make_tensor(std::forward(ctensor).data(), TiledMMA::thrfrg_C(ctensor.layout())); auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); @@ -618,7 +525,7 @@ struct ThrMMA : TiledMMA auto partition_A(ATensor&& atensor) const { - auto thr_tensor = make_tensor(std::forward(atensor).data(), thrfrg_A(atensor.layout())); + auto thr_tensor = make_tensor(std::forward(atensor).data(), TiledMMA::thrfrg_A(atensor.layout())); auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); @@ -629,7 +536,7 @@ struct ThrMMA : TiledMMA auto partition_B(BTensor&& btensor) const { - auto thr_tensor = make_tensor(std::forward(btensor).data(), thrfrg_B(btensor.layout())); + auto thr_tensor = make_tensor(std::forward(btensor).data(), TiledMMA::thrfrg_B(btensor.layout())); auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); @@ -675,8 +582,8 @@ make_tiled_mma(MMA_Atom const&, MMAValLayout const& val_layout = {}, Permutations const& permutations = {}) { - auto thr_layout_mnk = append<3>(thr_layout, Layout<_1>{}); - auto val_layout_mnk = append<3>(val_layout, Layout<_1>{}); + auto thr_layout_mnk = append<3>(thr_layout, Layout<_1,_0>{}); + auto val_layout_mnk = append<3>(val_layout, Layout<_1,_0>{}); auto permutation_mnk = append<3>(permutations, _); return TiledMMA, @@ -707,19 +614,24 @@ make_tiled_mma(MMA_Op const&, template CUTE_HOST_DEVICE constexpr auto -partition_fragment_C(TiledMMA, Shape_MN shapeMN) +partition_shape_C(TiledMMA const& mma, Shape_MN const& shape_MN) { constexpr int R = rank_v; static_assert(R >= 2, "Must have at least rank-2"); - auto atomMNK = typename TiledMMA::AtomShape_MNK{}; - auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; - - auto V = size<1>(typename TiledMMA::AtomLayoutC_TV{}); - auto M = shape_div(size<0>(shapeMN), size<0>(atomMNK) * size<1>(thrVMNK)); - auto N = shape_div(size<1>(shapeMN), size<1>(atomMNK) * size<2>(thrVMNK)); - auto frg_shape = tuple_cat(make_shape(V,M,N), take<2,R>(shapeMN)); + auto atomMNK = typename TiledMMA::AtomShape_MNK{}; + auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; + auto V = shape<1>(typename TiledMMA::AtomLayoutC_TV{}); + auto M = shape_div(size<0>(shape_MN), size<0>(atomMNK) * size<1>(thrVMNK)); + auto N = shape_div(size<1>(shape_MN), size<1>(atomMNK) * size<2>(thrVMNK)); + return tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN)); +} - return make_tensor::FrgTypeC>(frg_shape); +template +CUTE_HOST_DEVICE constexpr +auto +partition_fragment_C(TiledMMA const& mma, Shape_MN const& shapeMN) +{ + return make_tensor::FrgTypeC>(partition_shape_C(mma, shapeMN)); } // partition_fragment_A and partition_fragment_B often depend on the @@ -727,6 +639,36 @@ partition_fragment_C(TiledMMA, Shape_MN shapeMN) // For these reasons, they should not be used in a static context. // See TiledMMA::get_slice(thr_idx).partition_fragment_A(tensorA) instead. +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_A(TiledMMA const& mma, Shape_MK const& shape_MK) +{ + constexpr int R = rank_v; + static_assert(R >= 2, "Must have at least rank-2"); + auto atomMNK = typename TiledMMA::AtomShape_MNK{}; + auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; + auto V = shape<1>(typename TiledMMA::AtomLayoutA_TV{}); + auto M = shape_div(size<0>(shape_MK), size<0>(atomMNK) * size<1>(thrVMNK)); + auto K = shape_div(size<1>(shape_MK), size<2>(atomMNK) * size<3>(thrVMNK)); + return tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_B(TiledMMA const& mma, Shape_NK const& shape_NK) +{ + constexpr int R = rank_v; + static_assert(R >= 2, "Must have at least rank-2"); + auto atomMNK = typename TiledMMA::AtomShape_MNK{}; + auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; + auto V = shape<1>(typename TiledMMA::AtomLayoutB_TV{}); + auto N = shape_div(size<0>(shape_NK), size<1>(atomMNK) * size<2>(thrVMNK)); + auto K = shape_div(size<1>(shape_NK), size<2>(atomMNK) * size<3>(thrVMNK)); + return tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK)); +} + // // Size // @@ -739,18 +681,62 @@ tile_size(TiledMMA const& mma) return size(typename TiledMMA::TiledShape_MNK{}); } -template +template +CUTE_HOST_DEVICE constexpr +auto +tile_shape(TiledMMA const& mma) +{ + return shape(typename TiledMMA::TiledShape_MNK{}); +} + +template CUTE_HOST_DEVICE constexpr auto size(TiledMMA const& mma) { - return size(typename TiledMMA::ThrLayoutVMNK{}); + return size(typename TiledMMA::ThrLayoutVMNK{}); } // // Display utilities // +template +CUTE_HOST_DEVICE +void +print(MMA_Atom> const&) +{ + using Atom = MMA_Atom>; + print("MMA_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); + print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); + print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledMMA const& mma) +{ + using MMA = TiledMMA; + print("TiledMMA\n"); + print(" TiledThr: "); print(TiledThr{}); print("\n"); + print(" TiledVal: "); print(TiledVal{}); print("\n"); + print(" TiledPerm: "); print(TiledPerm{}); print("\n"); + print(" TiledShape_MNK: "); print(typename MMA::TiledShape_MNK{}); print("\n"); + print(" ThrLayoutVMNK: "); print(typename MMA::ThrLayoutVMNK{}); print("\n"); + print(static_cast(mma)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrMMA const&) +{ + print(TiledMMA{}); +} + template CUTE_HOST_DEVICE auto @@ -992,9 +978,9 @@ print_latex_mma(Shape_MNK const& shape_mnk, printf(latex_header); - int M = size<0>(shape_mnk); - int N = size<1>(shape_mnk); - int K = size<2>(shape_mnk); + constexpr int M = size<0>(shape_mnk); + constexpr int N = size<1>(shape_mnk); + constexpr int K = size<2>(shape_mnk); // C starting at 0,0 bool c_filled[M][N] = {}; @@ -1070,12 +1056,10 @@ print_latex_mma(Shape_MNK const& shape_mnk, //////////////////////////////////////////////////////////////////////////////////////////////////// -#include #include #include #include #include #include #include - //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index a8c3323a..7242e2d4 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -32,11 +32,43 @@ #include -#include +#include namespace cute { +namespace detail { + +template +struct supports_output_scaling { static constexpr bool value = false; }; + +template +struct supports_output_scaling().accumulate_)>> { static constexpr bool value = true; }; + +} // end namespace detail + +/** + * concept MMA_Traits + * { + * using ElementDVal = // Logical A-value type + * using ElementAVal = // Logical B-value type + * using ElementBVal = // Logical C-value type + * using ElementCVal = // Logical D-value type (NOTE: Not used? Assumed == ElementDVal) + * + * using ElementAFrg = // A-type consumed by MMA (if ommitted, same as ElementAVal) + * using ElementBFrg = // B_type consumed by MMA (if ommitted, same as ElementBVal) + * using ElementCFrg = // C_type consumed by MMA (if ommitted, same as ElementCVal) + * + * using Shape_MNK = // Logical MxNxK shape of the MMA + * + * using ThrID = // Logical thread id (tid) -> tidx + * + * using ALayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MK-coord + * using BLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat NK-coord + * using CLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MN-coord + * }; + */ + template struct MMA_Traits { @@ -67,4 +99,110 @@ struct MMA_Traits> using CLayout = Layout>; }; +// +// Generic mma_unpack for any MMA_Traits +// +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + using MMATraits = MMA_Traits; + + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + Tensor rA = recast(A); + Tensor rB = recast(B); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + if constexpr (is_same::value) + { + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + // assert((void*)&C == (void*)&D); + + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + if constexpr (detail::supports_output_scaling::value) { + detail::explode_with_d_scaling(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + traits.accumulate_); + } + else { + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); + } + } + else { + Tensor rD = recast(D); + Tensor rC = recast(C); + + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + if constexpr (detail::supports_output_scaling::value) { + detail::explode_with_d_scaling(MMA_Op::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + traits.accumulate_); + } + else { + detail::explode(MMA_Op::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); + } + } +} + +namespace detail { + +template +struct FrgTypeA_or_Default { using type = typename X::ElementAVal; }; +template +struct FrgTypeA_or_Default> { using type = typename X::ElementAFrg; }; + +template +struct FrgTypeB_or_Default { using type = typename X::ElementBVal; }; +template +struct FrgTypeB_or_Default> { using type = typename X::ElementBFrg; }; + +template +struct FrgTypeC_or_Default { using type = typename X::ElementCVal; }; +template +struct FrgTypeC_or_Default> { using type = typename X::ElementCFrg; }; + +} // end namespace detail + } // namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index d390dafc..752023c2 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -37,6 +37,29 @@ namespace cute { +// Fence between the async destination accumulators of GMMA & source for their dependent use +template +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(Tensor& frg) { + CUTE_STATIC_ASSERT(is_static::value); + if constexpr (is_same_v) { + auto f32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(f32_frg); ++i) { + warpgroup_fence_operand(f32_frg(i)); + } + } + else { + CUTE_STATIC_ASSERT(is_rmem::value); + auto u32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(u32_frg); ++i) { + warpgroup_fence_operand(u32_frg(i)); + } + } +} + namespace GMMA { /////////////////////////////////////////// @@ -77,68 +100,22 @@ using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_S // With GMMA::Major param template -using Layout_INTER_Atom = typename std::conditional, Layout_K_INTER_Atom>::type; template -using Layout_SW32_Atom = typename std::conditional, Layout_K_SW32_Atom>::type; template -using Layout_SW64_Atom = typename std::conditional, Layout_K_SW64_Atom>::type; template -using Layout_SW128_Atom = typename std::conditional, Layout_K_SW128_Atom>::type; -// Helper for GMMA smem selection that considers a tensor TileShape: -// (BLK_MN, BLK_K) -// or hierarchically -// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) -// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 -template -CUTE_HOST_DEVICE constexpr -auto -smem_selector() -{ - auto BLK_MN0 = size<0>(BLK_MN{}); - auto BLK_K0 = size<0>(BLK_K{}); - - static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); - static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); - - - if constexpr (major == GMMA::Major::MN) { - if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { - return GMMA::Layout_MN_SW128_Atom{}; - } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { - return GMMA::Layout_MN_SW64_Atom{}; - } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { - return GMMA::Layout_MN_SW32_Atom{}; - } else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { - return GMMA::Layout_MN_INTER_Atom{}; - } else { - static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, - "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); - } - } else if constexpr (major == GMMA::Major::K) { - if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { - return GMMA::Layout_K_SW128_Atom{}; - } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { - return GMMA::Layout_K_SW64_Atom{}; - } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { - return GMMA::Layout_K_SW32_Atom{}; - } else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { - return GMMA::Layout_K_INTER_Atom{}; - } else { - static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, - "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); - } - } -} - // // Tensor to LayoutType utility // @@ -147,13 +124,13 @@ smem_selector() template CUTE_HOST_DEVICE constexpr LayoutType -layout_type(Tensor>>, +layout_type(Tensor>>, Layout> const&) { static_assert(M == 4, "Unsupported layout swizzle"); static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); static_assert(S == 3, "Unsupported layout swizzle"); - + switch (B) { case 0: return LayoutType::INTERLEAVE; case 1: return LayoutType::B32; @@ -167,7 +144,7 @@ layout_type(Tensor>> template CUTE_HOST_DEVICE constexpr LayoutType -layout_type(Tensor>, +layout_type(Tensor>, Layout> const&) { return LayoutType::INTERLEAVE; @@ -177,7 +154,7 @@ layout_type(Tensor>, // Construction method for GMMA Descriptors /////////////////////////////////////////////////////////////////////////////// -/** +/** * /////////////////////////////// * // make_gmma_desc // * /////////////////////////////// @@ -188,14 +165,14 @@ layout_type(Tensor>, * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) * -* where +* where * T : sizeof(uint128_t) / sizeof(value_type) * m : integer in [1,16] corresponding to GMMA shape * k : integer in [1,32] corresponding to GMMA shape * SBO: stride byte offset * LBO: leading byte offset * -* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. +* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. * For example, * auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); * is guaranteed to be accepted by make_gmma_desc for appropriate value_type. @@ -210,7 +187,7 @@ layout_type(Tensor>, * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) * -* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. +* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. * For example, * auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); * is guaranteed to be accepted by make_gmma_desc for appropriate value_type. @@ -279,7 +256,7 @@ make_gmma_desc(Tensor const& tensor) desc.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; desc.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_11 : stride_01; } - else if constexpr (MajorMode == GMMA::Major::K) + else if constexpr (MajorMode == GMMA::Major::K) { /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form * @@ -335,7 +312,7 @@ make_gmma_desc(Tensor const& tensor) // Higher level GMMA Descriptor utilities /////////////////////////////////////////////////////////////////////////////// -struct gmma_descriptor_iterator +struct DescriptorIterator { GmmaDescriptor desc_; @@ -351,7 +328,7 @@ struct gmma_descriptor_iterator // Return an advanced iterator template CUTE_HOST_DEVICE constexpr - gmma_descriptor_iterator operator+(Index const& offset) const + DescriptorIterator operator+(Index const& offset) const { // offset is in the units of uint128_t (4LSB of start_address not included) @@ -366,66 +343,43 @@ struct gmma_descriptor_iterator //return {desc}; // The above seems to not work for some reason... - return {desc_ + uint64_t(offset)}; + return { GmmaDescriptor {desc_ + uint64_t(offset)} }; } + + CUTE_HOST_DEVICE friend void + print(DescriptorIterator const&) { printf("GMMA::DescriptorIterator"); } }; +// The GMMA Traits below have custom fragment type flags for their smem desc tensors. +// These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. template -struct smem_desc : gmma_descriptor_iterator {}; - -template -CUTE_HOST_DEVICE constexpr -auto -make_gmma_desc_fragment(Tensor const& t) -{ - // Cast to a uint128_t tensor for GMMA Desc iteration - return make_tensor(gmma_descriptor_iterator{make_gmma_desc(tensor<0>(t))}, - recast(t).layout()); -} - -// Recast a tensor to a tensor of gmma_descriptor_iterator -template -CUTE_HOST_DEVICE constexpr -auto -recast(Tensor&& tensor, type_list>) -{ - return make_gmma_desc_fragment(tensor); -} +struct smem_desc : DescriptorIterator {}; -// Recast a gmma_descriptor_iterator Tensor to uint64_t, it's RegType +// Recast a DescriptorIterator Tensor to uint64_t, it's RegType template CUTE_HOST_DEVICE constexpr auto -recast(Tensor,TLayout> const& tensor, type_list) +recast(Tensor,TLayout> const& tensor, type_list) { - static_assert(std::is_same::value, "Can only cast descriptors to uint64_t."); + static_assert(is_same::value, "Can only cast descriptors to uint64_t."); return make_tensor(tensor.data(), Layout<_1,_0>{}); } } // end namespace GMMA -// Fence between the async destination accumulators of GMMA & source for their dependent use -template -CUTE_HOST_DEVICE -void -warpgroup_fence_operand(Tensor& frg) { - CUTE_STATIC_ASSERT(is_static::value); - if constexpr (std::is_same_v) { - auto f32_frg = recast(frg); - CUTE_UNROLL - for (int i = 0; i < size(f32_frg); ++i) { - warpgroup_fence_operand(f32_frg(i)); - } - } - else { - CUTE_STATIC_ASSERT(is_rmem::value); - auto u32_frg = recast(frg); - CUTE_UNROLL - for (int i = 0; i < size(u32_frg); ++i) { - warpgroup_fence_operand(u32_frg(i)); - } +// Customization point for creating a GMMA::smem_desc Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + return make_tensor(GMMA::DescriptorIterator{GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + recast(smem_tensor).layout()); } -} +}; /////////////////////////////////////////////////////////////////////////////// //////////////////////////// MMA_TRAITS /////////////////////////////////////// @@ -476,8 +430,8 @@ using ABLayout = Layout,Int>>, } // namespace GMMA -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -492,12 +446,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 8, 16>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -511,12 +467,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -531,12 +489,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 16, 16>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -550,12 +510,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -570,12 +532,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 32, 16>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -589,12 +553,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -609,12 +575,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 64, 16>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -628,12 +596,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -648,12 +618,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 96, 16>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -667,12 +639,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -687,12 +661,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<128, 16>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -706,12 +682,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -726,12 +704,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<192, 16>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -745,12 +725,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -765,12 +747,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<256, 16>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = half_t; using ElementAVal = half_t; @@ -784,12 +768,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -804,12 +790,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 8, 16>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -823,12 +811,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -843,12 +833,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 16, 16>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -862,12 +854,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -882,12 +876,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 32, 16>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -901,12 +897,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -921,12 +919,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 64, 16>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -940,12 +940,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -960,12 +962,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 96, 16>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -979,12 +983,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -999,12 +1005,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<128, 16>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -1018,12 +1026,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -1038,12 +1048,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<192, 16>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -1057,12 +1069,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -1077,12 +1091,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<256, 16>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = half_t; @@ -1096,12 +1112,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1116,12 +1134,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 8, 16>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1135,12 +1155,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1155,12 +1177,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 16, 16>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1174,12 +1198,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1194,12 +1220,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 32, 16>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1213,12 +1241,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1233,12 +1263,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 64, 16>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1252,12 +1284,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1272,12 +1306,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout< 96, 16>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1291,12 +1327,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1311,12 +1349,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<128, 16>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1330,12 +1370,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1350,12 +1392,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<192, 16>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1369,12 +1413,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1389,12 +1435,14 @@ struct MMA_Traits; using BLayout = GMMA::ABLayout<256, 16>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = bfloat16_t; @@ -1408,12 +1456,14 @@ struct MMA_Traits; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1428,12 +1478,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 8>; using BLayout = GMMA::ABLayout< 8, 8>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1447,12 +1499,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x8; using BLayout = GMMA::ABLayout< 8, 8>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1467,12 +1521,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 8>; using BLayout = GMMA::ABLayout< 16, 8>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1486,12 +1542,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x8; using BLayout = GMMA::ABLayout< 16, 8>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1506,12 +1564,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 8>; using BLayout = GMMA::ABLayout< 32, 8>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1525,12 +1585,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x8; using BLayout = GMMA::ABLayout< 32, 8>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1545,12 +1607,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 8>; using BLayout = GMMA::ABLayout< 64, 8>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1564,12 +1628,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x8; using BLayout = GMMA::ABLayout< 64, 8>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1584,12 +1650,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 8>; using BLayout = GMMA::ABLayout< 96, 8>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1603,12 +1671,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x8; using BLayout = GMMA::ABLayout< 96, 8>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1623,12 +1693,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 8>; using BLayout = GMMA::ABLayout<128, 8>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1642,12 +1714,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x8; using BLayout = GMMA::ABLayout<128, 8>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1662,12 +1736,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 8>; using BLayout = GMMA::ABLayout<192, 8>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1681,12 +1757,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x8; using BLayout = GMMA::ABLayout<192, 8>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1701,12 +1779,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 8>; using BLayout = GMMA::ABLayout<256, 8>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ElementDVal = float; using ElementAVal = tfloat32_t; @@ -1720,12 +1800,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x8; using BLayout = GMMA::ABLayout<256, 8>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1740,12 +1822,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 8, 32>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1760,12 +1844,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 16, 32>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1780,12 +1866,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 32, 32>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1800,12 +1888,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1820,12 +1910,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1840,12 +1932,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<128, 32>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1860,12 +1954,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<192, 32>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1880,12 +1976,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<256, 32>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1899,12 +1997,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 8, 32>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1918,12 +2018,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 16, 32>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1937,12 +2039,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 32, 32>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1956,12 +2060,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1975,12 +2081,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -1994,12 +2102,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<128, 32>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2013,12 +2123,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<192, 32>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2032,12 +2144,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<256, 32>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2052,12 +2166,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 8, 32>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2072,12 +2188,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 16, 32>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2092,12 +2210,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 32, 32>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2112,12 +2232,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2132,12 +2254,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2152,12 +2276,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<128, 32>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2172,12 +2298,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<192, 32>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2192,12 +2320,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<256, 32>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2211,12 +2341,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 8, 32>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2230,12 +2362,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 16, 32>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2249,12 +2383,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 32, 32>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2268,12 +2404,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2287,12 +2425,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2306,12 +2446,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<128, 32>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2325,12 +2467,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<192, 32>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = int8_t; @@ -2344,12 +2488,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<256, 32>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2364,12 +2510,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 8, 32>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2384,12 +2532,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 16, 32>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2404,12 +2554,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 32, 32>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2424,12 +2576,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2444,12 +2598,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2464,12 +2620,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<128, 32>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2484,12 +2642,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<192, 32>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2504,12 +2664,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<256, 32>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2523,12 +2685,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 8, 32>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2542,12 +2706,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 16, 32>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2561,12 +2727,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 32, 32>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2580,12 +2748,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2599,12 +2769,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2618,12 +2790,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<128, 32>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2637,12 +2811,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<192, 32>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2656,12 +2832,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<256, 32>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2676,12 +2854,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 8, 32>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2696,12 +2876,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 16, 32>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2716,12 +2898,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 32, 32>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2736,12 +2920,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2756,12 +2942,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2776,12 +2964,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<128, 32>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2796,12 +2986,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<192, 32>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2816,12 +3008,14 @@ struct MMA_Traits> using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout<256, 32>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2835,12 +3029,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 8, 32>; using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2854,12 +3050,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 16, 32>; using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2873,12 +3071,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 32, 32>; using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2892,12 +3092,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2911,12 +3113,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2930,12 +3134,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<128, 32>; using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2949,12 +3155,14 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<192, 32>; using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template <> +struct MMA_Traits { using ElementDVal = int32_t; using ElementAVal = uint8_t; @@ -2968,6 +3176,8 @@ struct MMA_Traits> using ALayout = GMMA::ALayout_64x32; using BLayout = GMMA::ABLayout<256, 32>; using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/config.hpp b/include/cute/config.hpp index b2f4de83..c6533510 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -40,9 +40,12 @@ # define CUTE_HOST inline #endif // CUTE_HOST_DEVICE, CUTE_DEVICE -#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +#if !defined(__CUDACC_RTC__) && (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) # define CUTE_UNROLL #pragma unroll # define CUTE_NO_UNROLL #pragma unroll 1 +#elif defined(__CUDACC_RTC__) +# define CUTE_UNROLL _Pragma("unroll") +# define CUTE_NO_UNROLL _Pragma("unroll 1") #else # define CUTE_UNROLL # define CUTE_NO_UNROLL @@ -54,6 +57,24 @@ # define CUTE_INLINE_CONSTANT static constexpr #endif +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTE_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTE_GRID_CONSTANT) +# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) +# define CUTE_GRID_CONSTANT __grid_constant__ +# else +# define CUTE_GRID_CONSTANT +# endif +#endif + // Some versions of GCC < 11 have trouble deducing that a // function with "auto" return type and all of its returns in an "if // constexpr ... else" statement must actually return. Thus, GCC @@ -72,17 +93,33 @@ # endif #endif +#ifdef _MSC_VER +// Provides support for alternative operators 'and', 'or', and 'not' +#include +#endif // _MSC_VER + +#if defined(__CUDACC_RTC__) +#define CUTE_STL_NAMESPACE cuda::std +#define CUTE_STL_NAMESPACE_IS_CUDA_STD +#else +#define CUTE_STL_NAMESPACE std +#endif + // // Assertion helpers // +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #define CUTE_STATIC_ASSERT static_assert #define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__) #if defined(__CUDA_ARCH__) -# define CUTE_RUNTIME_ASSERT(x) asm volatile ("brkpt;\n" ::: "memory") +# define CUTE_RUNTIME_ASSERT(x) __brkpt() #else # define CUTE_RUNTIME_ASSERT(x) assert(0 && x) #endif @@ -91,9 +128,11 @@ // IO // +#if !defined(__CUDACC_RTC__) #include #include #include +#endif // // Support diff --git a/include/cute/container/alignment.hpp b/include/cute/container/alignment.hpp index 49101fa7..36dfe765 100644 --- a/include/cute/container/alignment.hpp +++ b/include/cute/container/alignment.hpp @@ -42,7 +42,7 @@ namespace cute template CUTE_HOST_DEVICE constexpr bool -is_byte_aligned(void const* const ptr) +is_byte_aligned(void const* const ptr) { static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2 in alignment check"); return (reinterpret_cast(ptr) & (N-1)) == 0; @@ -54,7 +54,7 @@ is_byte_aligned(void const* const ptr) # define CUTE_ALIGNAS(n) alignas(n) #endif -template +template struct aligned_struct {}; template <> struct CUTE_ALIGNAS( 1) aligned_struct< 1> {}; diff --git a/include/cute/container/array.hpp b/include/cute/container/array.hpp index 571ac089..9e70e87f 100644 --- a/include/cute/container/array.hpp +++ b/include/cute/container/array.hpp @@ -30,20 +30,20 @@ **************************************************************************************************/ #pragma once -#include -#include - #include +#include +#include + namespace cute { -template +template struct array { using value_type = T; - using size_type = std::size_t; - using difference_type = std::ptrdiff_t; + using size_type = size_t; + using difference_type = ptrdiff_t; using reference = value_type&; using const_reference = const value_type&; using pointer = value_type*; @@ -184,7 +184,7 @@ struct array CUTE_HOST_DEVICE constexpr void swap(array& other) { - using std::swap; + using CUTE_STL_NAMESPACE::swap; for (size_type i = 0; i < size(); ++i) { swap((*this)[i], other[i]); } @@ -194,11 +194,11 @@ struct array }; -template +template CUTE_HOST_DEVICE constexpr bool operator==(array const& lhs, array const& rhs) { - for (std::size_t i = 0; i < N; ++i) { + for (size_t i = 0; i < N; ++i) { if (lhs[i] != rhs[i]) { return false; } @@ -206,21 +206,21 @@ bool operator==(array const& lhs, array const& rhs) return true; } -template +template CUTE_HOST_DEVICE constexpr void clear(array& a) { a.fill(T(0)); } -template +template CUTE_HOST_DEVICE constexpr void fill(array& a, T const& value) { a.fill(value); } -template +template CUTE_HOST_DEVICE constexpr void swap(array& a, array& b) { @@ -234,12 +234,16 @@ void swap(array& a, array& b) // Specialize tuple-related functionality for cute::array // +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif namespace cute { -template +template CUTE_HOST_DEVICE constexpr T& get(array& a) { @@ -247,15 +251,15 @@ T& get(array& a) return a[I]; } -template +template CUTE_HOST_DEVICE constexpr T const& get(array const& a) { - static_assert(I < N, "Index out of range"); + static_assert(I < N, "Index out of range"); return a[I]; } -template +template CUTE_HOST_DEVICE constexpr T&& get(array&& a) { @@ -265,18 +269,66 @@ T&& get(array&& a) } // end namespace cute +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD namespace std { -template +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template struct tuple_size> - : std::integral_constant + : cute::integral_constant {}; -template +template struct tuple_element> { using type = T; }; -} // end std +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namepsace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/array_aligned.hpp b/include/cute/container/array_aligned.hpp index b1b35727..6bf9da39 100644 --- a/include/cute/container/array_aligned.hpp +++ b/include/cute/container/array_aligned.hpp @@ -30,247 +30,13 @@ **************************************************************************************************/ #pragma once -#include - +#include #include -#include -#include namespace cute { -template -struct array_aligned - : public aligned_struct -{ - /// Make sure the Alignment makes sense wrt the size of elements. - static_assert(Alignment == 16 || Alignment >= sizeof(T), "Alignment is too small"); - /// Alignment must be a power of two - static_assert(has_single_bit(Alignment), "Alignment must be a power of two"); - - using value_type = T; - using size_type = std::size_t; - using difference_type = std::ptrdiff_t; - using reference = value_type&; - using const_reference = const value_type&; - using pointer = value_type*; - using const_pointer = const value_type*; - using iterator = pointer; - using const_iterator = const_pointer; - - CUTE_HOST_DEVICE constexpr - reference operator[](size_type pos) - { - return begin()[pos]; - } - - CUTE_HOST_DEVICE constexpr - const_reference operator[](size_type pos) const - { - return begin()[pos]; - } - - CUTE_HOST_DEVICE constexpr - reference front() - { - return *begin(); - } - - CUTE_HOST_DEVICE constexpr - const_reference front() const - { - return *begin(); - } - - CUTE_HOST_DEVICE constexpr - reference back() - { - // return *rbegin(); - return operator[](N-1); - } - - CUTE_HOST_DEVICE constexpr - const_reference back() const - { - // return *rbegin(); - return operator[](N-1); - } - - CUTE_HOST_DEVICE constexpr - T* data() - { - return reinterpret_cast(storage); - } - - CUTE_HOST_DEVICE constexpr - T const* data() const - { - return reinterpret_cast(storage); - } - - CUTE_HOST_DEVICE constexpr - iterator begin() - { - return data(); - } - - CUTE_HOST_DEVICE constexpr - const_iterator begin() const - { - return data(); - } - - CUTE_HOST_DEVICE constexpr - const_iterator cbegin() - { - return begin(); - } - - CUTE_HOST_DEVICE constexpr - const_iterator cbegin() const - { - return begin(); - } - - CUTE_HOST_DEVICE constexpr - iterator end() - { - return data() + size(); - } - - CUTE_HOST_DEVICE constexpr - const_iterator end() const - { - return data() + size(); - } - - CUTE_HOST_DEVICE constexpr - const_iterator cend() - { - return end(); - } - - CUTE_HOST_DEVICE constexpr - const_iterator cend() const - { - return end(); - } - - CUTE_HOST_DEVICE constexpr - bool empty() const - { - return size() == 0; - } - - CUTE_HOST_DEVICE constexpr - size_type size() const - { - return N; - } - - CUTE_HOST_DEVICE constexpr - size_type max_size() const - { - return size(); - } - - CUTE_HOST_DEVICE constexpr - void fill(T const& value) - { - for (auto& e : *this) { - e = value; - } - } - - CUTE_HOST_DEVICE constexpr - void clear() - { - fill(T(0)); - } - - // Not private, we want trivial type - //private: - - /// Storage type to use for Elements - using StorageType = typename uint_byte(Alignment)>::type; - - /// Ensure that there's enough storage for all elements - static_assert(sizeof(StorageType) <= Alignment, "StorageType is too big for given alignment"); - - /// Number of elements in the storage - static constexpr std::size_t storageN = (sizeof(T)*N + sizeof(StorageType) - 1) / sizeof(StorageType); - - /// The storage. - StorageType storage[storageN > 0 ? storageN : 1]; -}; - -// -// Operators -// - -template -CUTE_HOST_DEVICE constexpr -void clear(array_aligned& a) -{ - a.clear(); -} - -template -CUTE_HOST_DEVICE constexpr -void fill(array_aligned& a, T const& value) -{ - a.fill(value); -} - -} // end namespace cute - -// -// Specialize tuple-related functionality for cute::array -// - -#include - -namespace cute -{ - -template -CUTE_HOST_DEVICE constexpr -T& get(array_aligned& a) -{ - static_assert(I < N, "Index out of range"); - return a[I]; -} - -template -CUTE_HOST_DEVICE constexpr -T const& get(array_aligned const& a) -{ - static_assert(I < N, "Index out of range"); - return a[I]; -} - -template -CUTE_HOST_DEVICE constexpr -T&& get(array_aligned&& a) -{ - static_assert(I < N, "Index out of range"); - return std::move(a[I]); -} +template +struct CUTE_ALIGNAS(Alignment) array_aligned : cute::array {}; } // end namespace cute - -namespace std -{ - -template -struct tuple_size> - : std::integral_constant -{}; - -template -struct tuple_element> -{ - using type = T; -}; - -} // end std diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index a217a671..e3fd8ee4 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -38,6 +38,7 @@ #include #include // sizeof_bits +#include namespace cute { @@ -45,7 +46,7 @@ namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// /// Statically sized array for any data type -template +template class array_subbyte { public: @@ -54,22 +55,15 @@ class array_subbyte static constexpr int kSizeBits = sizeof_bits::value * N; /// Storage type - using Storage = typename std::conditional< - (kSizeBits % 32) == 0, - uint32_t, - typename std::conditional< - (kSizeBits % 16) == 0, - uint16_t, - uint8_t - >::type - >::type; - + using Storage = conditional_t<(kSizeBits % 32) == 0, uint32_t, + conditional_t<(kSizeBits % 16) == 0, uint16_t, + uint8_t>>; /// Number of logical elements per stored object static constexpr int kElementsPerStoredItem = sizeof_bits::value / sizeof_bits::value; /// Number of storage elements - static constexpr std::size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; + static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; /// Bitmask for covering one item static constexpr Storage bit_mask_ = ((Storage(1) << sizeof_bits::value) - 1); @@ -82,8 +76,8 @@ class array_subbyte using pointer = value_type*; using const_pointer = value_type const*; - using size_type = std::size_t; - using difference_type = std::ptrdiff_t; + using size_type = size_t; + using difference_type = ptrdiff_t; // // References @@ -110,7 +104,7 @@ class array_subbyte /// Assignment CUTE_HOST_DEVICE constexpr reference& operator=(T x) { - Storage item = (reinterpret_cast(x) & bit_mask_); + Storage item = (x & bit_mask_); Storage kUpdateMask = Storage(~(bit_mask_ << (idx_ * sizeof_bits::value))); *ptr_ = Storage((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits::value))); return *this; @@ -118,34 +112,21 @@ class array_subbyte CUTE_HOST_DEVICE constexpr T get() const { - Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); - return reinterpret_cast(item); + if constexpr (is_same::value) { + // Extract to bool -- potentially faster impl + return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); + } else { + // Extract to T + Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); + return reinterpret_cast(item); + } } - /// Extract to type T -- disable if T == bool - template ::value)> + /// Extract to type T CUTE_HOST_DEVICE constexpr operator T() const { return get(); } - - // Extract to bool -- potentially faster impl - CUTE_HOST_DEVICE constexpr - operator bool() const { - return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); - } - - /// Explicit cast to int - CUTE_HOST_DEVICE constexpr - explicit operator int() const { - return int(get()); - } - - /// Explicit cast to float - CUTE_HOST_DEVICE constexpr - explicit operator float() const { - return float(get()); - } }; /// Reference object extracts sub-byte items @@ -169,34 +150,21 @@ class array_subbyte CUTE_HOST_DEVICE constexpr const T get() const { - Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); - return reinterpret_cast(item); + if constexpr (is_same::value) { + // Extract to bool -- potentially faster impl + return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); + } else { + // Extract to T + Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); + return reinterpret_cast(item); + } } - /// Extract to type T -- disable if T == bool - template ::value)> + /// Extract to type T CUTE_HOST_DEVICE constexpr operator T() const { return get(); } - - // Extract to bool -- potentially faster impl - CUTE_HOST_DEVICE constexpr - operator bool() const { - return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); - } - - /// Explicit cast to int - CUTE_HOST_DEVICE constexpr - explicit operator int() const { - return int(get()); - } - - /// Explicit cast to float - CUTE_HOST_DEVICE constexpr - explicit operator float() const { - return float(get()); - } }; // @@ -543,14 +511,14 @@ class array_subbyte // Operators // -template +template CUTE_HOST_DEVICE constexpr void clear(array_subbyte& a) { a.clear(); } -template +template CUTE_HOST_DEVICE constexpr void fill(array_subbyte& a, T const& value) { @@ -565,12 +533,16 @@ void fill(array_subbyte& a, T const& value) // Specialize tuple-related functionality for cute::array_subbyte // +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif namespace cute { -template +template CUTE_HOST_DEVICE constexpr T& get(array_subbyte& a) { @@ -578,7 +550,7 @@ T& get(array_subbyte& a) return a[I]; } -template +template CUTE_HOST_DEVICE constexpr T const& get(array_subbyte const& a) { @@ -586,7 +558,7 @@ T const& get(array_subbyte const& a) return a[I]; } -template +template CUTE_HOST_DEVICE constexpr T&& get(array_subbyte&& a) { @@ -596,18 +568,66 @@ T&& get(array_subbyte&& a) } // end namespace cute +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD namespace std { -template +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template struct tuple_size> - : std::integral_constant + : cute::integral_constant {}; -template +template struct tuple_element> { using type = T; }; +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + } // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/array_view.hpp b/include/cute/container/array_view.hpp deleted file mode 100644 index 51b3ccc0..00000000 --- a/include/cute/container/array_view.hpp +++ /dev/null @@ -1,274 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include -#include - -#include - -namespace cute -{ - -template -struct array_view -{ - using value_type = T; - using size_type = std::size_t; - using difference_type = std::ptrdiff_t; - using reference = value_type&; - using const_reference = const value_type&; - using pointer = value_type*; - using const_pointer = const value_type*; - using iterator = pointer; - using const_iterator = const_pointer; - - array_view(array& a) - : __elems_(a.data()) {} - - CUTE_HOST_DEVICE - reference operator[](size_type pos) - { - return begin()[pos]; - } - - CUTE_HOST_DEVICE - const_reference operator[](size_type pos) const - { - return begin()[pos]; - } - - CUTE_HOST_DEVICE - reference front() - { - return *begin(); - } - - CUTE_HOST_DEVICE - const_reference front() const - { - return *begin(); - } - - CUTE_HOST_DEVICE - reference back() - { - // return *rbegin(); - return operator[](N-1); - } - - CUTE_HOST_DEVICE - const_reference back() const - { - // return *rbegin(); - return operator[](N-1); - } - - CUTE_HOST_DEVICE - T* data() - { - return __elems_; - } - - CUTE_HOST_DEVICE - const T* data() const - { - return __elems_; - } - - CUTE_HOST_DEVICE - iterator begin() - { - return data(); - } - - CUTE_HOST_DEVICE - const_iterator begin() const - { - return data(); - } - - CUTE_HOST_DEVICE - const_iterator cbegin() - { - return begin(); - } - - CUTE_HOST_DEVICE - const_iterator cbegin() const - { - return begin(); - } - - CUTE_HOST_DEVICE - iterator end() - { - return data() + size(); - } - - CUTE_HOST_DEVICE - const_iterator end() const - { - return data() + size(); - } - - CUTE_HOST_DEVICE - const_iterator cend() - { - return end(); - } - - CUTE_HOST_DEVICE - const_iterator cend() const - { - return end(); - } - - CUTE_HOST_DEVICE constexpr - bool empty() const - { - return size() == 0; - } - - CUTE_HOST_DEVICE constexpr - size_type size() const - { - return N; - } - - CUTE_HOST_DEVICE constexpr - size_type max_size() const - { - return size(); - } - - CUTE_HOST_DEVICE - void fill(const T& value) - { - for(auto& e : *this) - { - e = value; - } - } - - CUTE_HOST_DEVICE - void swap(array_view& other) - { - using std::swap; - swap(__elems_, other.__elems_); - } - - value_type* __elems_; -}; - - -template -CUTE_HOST_DEVICE -bool operator==(const array_view& lhs, const array_view& rhs) -{ - for(std::size_t i = 0; i < N; ++i) - { - if(lhs[i] != rhs[i]) return false; - } - - return true; -} - -template -CUTE_HOST_DEVICE -void clear(array_view& a) -{ - a.fill(T(0)); -} - -template -CUTE_HOST_DEVICE -void swap(array_view& a, array_view& b) -{ - a.swap(b); -} - -} // end cute - - -// -// Specialize tuple-related functionality for cute::array_view -// - -#include - -namespace cute -{ - -template -CUTE_HOST_DEVICE constexpr -T& -get(array_view& a) -{ - static_assert(I < N, "Index out of range"); - return a[I]; -} - -template -CUTE_HOST_DEVICE constexpr -const T& -get(const array_view& a) -{ - static_assert(I < N, "Index out of range"); - return a[I]; -} - -template -CUTE_HOST_DEVICE constexpr -T&& -get(array_view&& a) -{ - static_assert(I < N, "Index out of range"); - return std::move(a[I]); -} - -} // end namespace cute - -namespace std -{ - -template -struct tuple_size> - : std::integral_constant -{}; - -template -struct tuple_element> -{ - using type = T; -}; - -} // end std diff --git a/include/cute/container/bit_field.hpp b/include/cute/container/bit_field.hpp index 06b08754..5398e327 100644 --- a/include/cute/container/bit_field.hpp +++ b/include/cute/container/bit_field.hpp @@ -60,7 +60,7 @@ struct bit_field (BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64; using storage_type = cute::uint_bit_t; - static_assert(sizeof(OtherValueType) == sizeof(value_type) || std::is_same::value, + static_assert(sizeof(OtherValueType) == sizeof(value_type) || is_same::value, "sizeof(OtherValueType) must be same as sizeof(value_type)."); // Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits) diff --git a/include/cute/container/cuda_types.hpp b/include/cute/container/cuda_types.hpp new file mode 100644 index 00000000..057d7464 --- /dev/null +++ b/include/cute/container/cuda_types.hpp @@ -0,0 +1,175 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +namespace cute +{ + +// +// dim3 +// + +using dim3 = ::dim3; + +template +CUTE_HOST_DEVICE constexpr +uint32_t& get(dim3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t const& get(dim3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t&& get(dim3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return std::move(a.x); + } else if constexpr (I == 1) { + return std::move(a.y); + } else if constexpr (I == 2) { + return std::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +// +// uint3 +// + +using uint3 = ::uint3; + +template +CUTE_HOST_DEVICE constexpr +uint32_t& get(uint3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t const& get(uint3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t&& get(uint3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return std::move(a.x); + } else if constexpr (I == 1) { + return std::move(a.y); + } else if constexpr (I == 2) { + return std::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +} // end namespace cute diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index 1b3ffa42..ab6d37dc 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -30,34 +30,16 @@ **************************************************************************************************/ #pragma once -#include -#include - #include #include - #include // cute::true_type, cute::false_type -//#include // Advanced optimizations - -#if 0 -// -// Use of agency::tuple is functional, but is over-engineered for our purposes... -// This tends to result in slow compilation times and unintentionally propagated cvref types -// +#include -#include +#include -namespace cute -{ - -using agency::tuple; - -using agency::make_tuple; -using agency::tuple_cat; - -} // end namespace cute -#endif +//#include // Advanced optimizations +// // cute::tuple is like std::tuple, with two differences. // // 1. It works on both host and device. @@ -68,12 +50,12 @@ using agency::tuple_cat; // but do _not_ include references like int& or float&. // (See std::tie for an example of a tuple of references.) // -// This is simplified over the implementation in std:: and agency:: by ignoring much of +// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of // the conversion SFINAE, special overloading, and avoiding cvref template types. // Furthermore, the empty base optimization (EBO) is MORE aggressive by avoiding // construction calls, and ignoring any need for unique element addresses. // -// Over the agency::tuple implementation, this appears to accelerate compilation times by over 3x. +// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. namespace cute { @@ -91,7 +73,7 @@ namespace detail // EBO always "holds" a single value of type T. // N is like an array index that TupleBase uses // to access the desired tuple element. -template ::value> +template ::value> struct EBO; // Specialization for types T that have no data; @@ -99,7 +81,7 @@ struct EBO; // integral_constant, Int, // and any other semiregular type // for which std::is_empty_v is true. -template +template struct EBO { CUTE_HOST_DEVICE constexpr @@ -109,7 +91,7 @@ struct EBO EBO(T const&) {} }; -template +template CUTE_HOST_DEVICE constexpr T getv(EBO const&) { return {}; } @@ -117,7 +99,7 @@ CUTE_HOST_DEVICE constexpr T getv(EBO const&) // the "dynamic tuple leaf." Valid T here include int, // any other integral or floating-point type, // or any semiregular type for which std::is_empty_v is false. -template +template struct EBO { CUTE_HOST_DEVICE constexpr @@ -130,15 +112,15 @@ struct EBO T t_; }; -template +template CUTE_HOST_DEVICE constexpr T const& getv(EBO const& x) { return x.t_; } -template +template CUTE_HOST_DEVICE constexpr T& getv(EBO& x) { return x.t_; } -template +template CUTE_HOST_DEVICE constexpr T&& getv(EBO&& x) { return static_cast(x.t_); } @@ -152,8 +134,8 @@ struct TupleBase; // compile-time integer values in a single type. // We only ever use index_sequence<0, 1, ..., sizeof...(T)> in practice, // as the type alias TupleBase below indicates. -template -struct TupleBase, T...> +template +struct TupleBase, T...> : EBO... { CUTE_HOST_DEVICE constexpr @@ -166,39 +148,50 @@ struct TupleBase, T...> template CUTE_HOST_DEVICE constexpr - TupleBase(TupleBase, U...> const& u) + TupleBase(TupleBase, U...> const& u) : EBO(getv(static_cast const&>(u)))... {} }; } // end namespace detail -// make_index_sequence returns index_sequence<0, 1, ..., K-1>. -template -using TupleBase = detail::TupleBase, T...>; +// Attempting to use the following commented-out alias +// in the declaration of `struct tuple` causes MSVC 2022 build errors. +// +//template +//using TupleBase = detail::TupleBase, T...>; // This is the actual cute::tuple class. // The storage (if any) lives in TupleBase's EBO base classes. +// +// Inheriting from the above alias TupleBase +// causes MSVC 2022 build errors when assigning one tuple to another: +// +// illegal member initialization: +// 'TupleBase< /* template arguments */ >' is not a base or member +// +// Not using the alias or any kind of alias fixed the errors. +// In summary: this is verbose as a work-around for MSVC build errors. template -struct tuple : TupleBase +struct tuple : detail::TupleBase, T...> { CUTE_HOST_DEVICE constexpr tuple() {} template CUTE_HOST_DEVICE constexpr - tuple(U const&... u) : TupleBase(u...) {} + tuple(U const&... u) : detail::TupleBase, T...>(u...) {} template CUTE_HOST_DEVICE constexpr tuple(tuple const& u) - : TupleBase(static_cast const&>(u)) {} + : detail::TupleBase, T...>(static_cast, U...> const&>(u)) {} }; // // get for cute::tuple (just like std::get for std::tuple) // -template +template CUTE_HOST_DEVICE constexpr decltype(auto) get(tuple const& t) noexcept @@ -207,7 +200,7 @@ get(tuple const& t) noexcept return detail::getv(t); } -template +template CUTE_HOST_DEVICE constexpr decltype(auto) get(tuple& t) noexcept @@ -216,7 +209,7 @@ get(tuple& t) noexcept return detail::getv(t); } -template +template CUTE_HOST_DEVICE constexpr decltype(auto) get(tuple&& t) noexcept @@ -226,21 +219,19 @@ get(tuple&& t) noexcept } // -// Custom is_tuple trait simply checks the existence of std::tuple_size +// Custom is_tuple trait simply checks the existence of tuple_size // and assumes std::get(.), std::tuple_element // namespace detail { template -std::integral_constant::value >= 0> has_tuple_size(int); - -template -std::false_type has_tuple_size(...); +auto has_tuple_size( T*) -> integral_constant::value>; +auto has_tuple_size(...) -> false_type; } // end namespace detail template -struct is_tuple : decltype(detail::has_tuple_size(0)) {}; +struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; // // make_tuple (value-based implementation) @@ -265,11 +256,11 @@ make_tuple(T const&... t) namespace detail { template + size_t... I0, size_t... I1> CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, - std::index_sequence, std::index_sequence) + index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)...); } @@ -298,8 +289,8 @@ auto tuple_cat(T0 const& t0, T1 const& t1) { return detail::tuple_cat(t0, t1, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}); + make_index_sequence::value>{}, + make_index_sequence::value>{}); } template @@ -317,41 +308,41 @@ tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) namespace detail { template + size_t... I0, size_t... I1> CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, - std::index_sequence, std::index_sequence) + index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)...); } template + size_t... I0, size_t... I1, size_t... I2> CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, - std::index_sequence, std::index_sequence, std::index_sequence) + index_sequence, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); } template + size_t... I0, size_t... I1, size_t... I2, size_t... I3> CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, - std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) + index_sequence, index_sequence, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); } template + size_t... I0, size_t... I1, size_t... I2, size_t... I3, size_t... I4> CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, - std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) + index_sequence, index_sequence, index_sequence, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); } @@ -380,8 +371,8 @@ auto tuple_cat(T0 const& t0, T1 const& t1) { return detail::tuple_cat(t0, t1, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}); + make_index_sequence::value>{}, + make_index_sequence::value>{}); } template @@ -390,9 +381,9 @@ auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) { return detail::tuple_cat(t0, t1, t2, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}); + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); } template @@ -401,10 +392,10 @@ auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) { return detail::tuple_cat(t0, t1, t2, t3, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}); + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); } template @@ -413,11 +404,11 @@ auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) { return detail::tuple_cat(t0, t1, t2, t3, t4, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}); + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); } template @@ -434,24 +425,24 @@ tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, namespace detail { -template +template struct tuple_cat_helper { - static constexpr cute::array ns = {Ns...}; + static constexpr cute::array ns = {Ns...}; - static constexpr std::size_t total_size() { - std::size_t sum = 0; - for (std::size_t n : ns) sum += n; + static constexpr size_t total_size() { + size_t sum = 0; + for (size_t n : ns) sum += n; return sum; } - static constexpr std::size_t total_size_ = total_size(); + static constexpr size_t total_size_ = total_size(); static constexpr auto values() { - cute::array outer_inner = {}; + cute::array outer_inner = {}; - std::size_t idx = 0; - for (std::size_t i = 0; i < ns.size(); ++i) { - for (std::size_t j = 0; j < ns[i]; ++j, ++idx) { + size_t idx = 0; + for (size_t i = 0; i < ns.size(); ++i) { + for (size_t j = 0; j < ns[i]; ++j, ++idx) { outer_inner[idx][0] = i; outer_inner[idx][1] = j; } @@ -460,23 +451,23 @@ struct tuple_cat_helper } static constexpr auto outer_inner_ = values(); - using total_sequence = std::make_index_sequence; + using total_sequence = make_index_sequence; }; -template +template CUTE_HOST_DEVICE constexpr auto -tuple_cat(Tuple const& t, std::index_sequence) +tuple_cat(Tuple const& t, index_sequence) { return cute::make_tuple(get(get(t))...); } template + size_t... I0, size_t... I1> CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, - std::index_sequence, std::index_sequence) + index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)...); } @@ -505,8 +496,8 @@ auto tuple_cat(T0 const& t0, T1 const& t1) { return detail::tuple_cat(t0, t1, - std::make_index_sequence::value>{}, - std::make_index_sequence::value>{}); + make_index_sequence::value>{}, + make_index_sequence::value>{}); } template @@ -514,8 +505,8 @@ CUTE_HOST_DEVICE constexpr auto tuple_cat(Tuples const&... ts) { - using Helper = detail::tuple_cat_helper::value...>; - return detail::tuple_cat(make_tuple(ts...), typename Helper::total_sequence{}); + using Helper = detail::tuple_cat_helper::value...>; + return detail::tuple_cat(cute::make_tuple(ts...), typename Helper::total_sequence{}); } #endif @@ -525,14 +516,14 @@ tuple_cat(Tuples const&... ts) namespace detail { -template +template CUTE_HOST_DEVICE constexpr auto equal_impl(TupleA const& a, TupleB const& b) { - if constexpr (I == std::tuple_size::value) { + if constexpr (I == tuple_size::value) { return cute::true_type{}; // Terminal: TupleA is exhausted - } else if constexpr (I == std::tuple_size::value) { + } else if constexpr (I == tuple_size::value) { return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted } else { return (get(a) == get(b)) && equal_impl(a,b); @@ -596,24 +587,15 @@ operator!=(TupleT const& t, TupleU const& u) // That said, see int_tuple for more explicitly named common comparison ops. // -// -// Shortcuts -// - -//using std::get; -using std::tuple_size; -using std::tuple_element; -using std::tuple_element_t; - // // Display utilities // namespace detail { -template +template CUTE_HOST_DEVICE void print_tuple(Tuple const& t, - std::index_sequence, char s = '(', char e = ')') + index_sequence, char s = '(', char e = ')') { using eat = int[]; using cute::print; @@ -622,9 +604,10 @@ CUTE_HOST_DEVICE void print_tuple(Tuple const& t, (print(e), 0)}; } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, - std::index_sequence, char s = '(', char e = ')') + index_sequence, char s = '(', char e = ')') { using eat = int[]; (void) eat {(void(os << s), 0), @@ -632,6 +615,7 @@ CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, (void(os << e), 0)}; return os; } +#endif // !defined(__CUDACC_RTC__) } // end namespace detail @@ -639,33 +623,80 @@ template ::value)> CUTE_HOST_DEVICE void print(Tuple const& t) { - return detail::print_tuple(t, std::make_index_sequence::value>{}); + return detail::print_tuple(t, make_index_sequence::value>{}); } +#if !defined(__CUDACC_RTC__) template ::value)> CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) { - return detail::print_tuple_os(os, t, std::make_index_sequence::value>{}); + return detail::print_tuple_os(os, t, make_index_sequence::value>{}); } +#endif // !defined(__CUDACC_RTC__) } // end namespace cute +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + // -// std:: compatability +// std compatibility // +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD namespace std { +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + template struct tuple_size> - : std::integral_constant + : cute::integral_constant {}; -template +template struct tuple_element> - : std::tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> {}; -} // end std +} // end namepsace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index c082a6da..4c6ddc09 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -30,6 +30,8 @@ **************************************************************************************************/ #pragma once +#include + namespace cute { @@ -47,7 +49,12 @@ struct type_list {}; // Specialize tuple-related functionality for cute::type_list // +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif + #include namespace cute @@ -55,30 +62,75 @@ namespace cute template CUTE_HOST_DEVICE constexpr -std::tuple_element_t> +CUTE_STL_NAMESPACE::tuple_element_t> get(type_list&) noexcept { return {}; } template CUTE_HOST_DEVICE constexpr -std::tuple_element_t> +CUTE_STL_NAMESPACE::tuple_element_t> get(type_list const& t) noexcept { return {}; } } // end namespace cute +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : cute::type_c>::type> +{}; + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : cute::type_c>::type> +{}; + +} // end namespace std + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD namespace std { +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + template struct tuple_size> - : std::integral_constant + : cute::integral_constant {}; -template +template struct tuple_element> - : cute::type_c>::type> + : cute::type_c>::type> +{}; + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : cute::type_c>::type> {}; } // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index 492d08cc..b73e2ec7 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -54,7 +54,7 @@ make_int_tuple(Ts const&... t) /** if rank(int) == 1, then get<0>(int) should work too */ -template >::value)> +template >::value)> CUTE_HOST_DEVICE constexpr decltype(auto) get(T&& t) noexcept @@ -65,7 +65,7 @@ get(T&& t) noexcept /** Custom recursive get for anything that implements get(.) */ -template +template CUTE_HOST_DEVICE constexpr decltype(auto) get(Tuple&& t) noexcept @@ -96,7 +96,7 @@ rank(IntTuple const& t) } template -using rank_t = decltype(rank(std::declval())); +using rank_t = decltype(rank(declval())); template static constexpr int rank_v = rank_t::value; @@ -196,7 +196,7 @@ depth(IntTuple const& t) } template -using depth_t = decltype(depth(std::declval())); +using depth_t = decltype(depth(declval())); template static constexpr int depth_v = depth_t::value; @@ -219,72 +219,6 @@ product(IntTuple const& a) CUTE_GCC_UNREACHABLE; } -// Work-around for some compiler versions (e.g., GCC 8.x) -// incorrectly not being able to compile certain -// legal C++ fold expressions inside generic lambdas. -// Issue is known to exist in GCC 8.4 and GCC 8.5. -// Work-around should be valid portable CUDA C++. -#if ! defined(CUTE_FOLD_GENERIC_LAMBDA_WORKAROUND) -# if defined(__GNUC__) && __GNUC__ == 8 -# define CUTE_FOLD_GENERIC_LAMBDA_WORKAROUND 1 -# endif -#endif - -#if defined(CUTE_FOLD_GENERIC_LAMBDA_WORKAROUND) -namespace impl { - -template -struct SubrangeProductImpl { - // GCC 8.4 accepts the fold expression here. If that doesn't work, - // the other branch (recursive operator()) is known to build - // with GCC 8.4 as well. The code does not enable recursion by default, - // as fold expressions might be easier for compilers to optimize. -#if 1 - template - CUTE_HOST_DEVICE constexpr auto - operator()(Args const&... args) const - { - return (Int<1>{} * ... * product(args)); - } -#else - CUTE_HOST_DEVICE constexpr Int<1> - operator()() const - { - return Int<1>{}; - } - - template - CUTE_HOST_DEVICE constexpr auto - operator()(Head const& head, Tail const&... tail) const - { - return (*this)(tail...) * product(head); - } -#endif // 1 -}; - -} // namespace impl - -#endif // defined(CUTE_FOLD_GENERIC_LAMBDA_WORKAROUND) - -// Product of a subrange -template -CUTE_HOST_DEVICE constexpr -auto -product(Tuple const& a) -{ - // Work around some compiler versions that do not accept - // the generic lambda in the else branch, by replacing - // the lambda with a function object. The work-around - // is legal C++17, but the original code might be easier - // for non-broken compilers to optimize, so it remains. -#if defined(CUTE_FOLD_GENERIC_LAMBDA_WORKAROUND) - impl::SubrangeProductImpl function_object; - return detail::apply(a, function_object, make_range{}); -#else - return detail::apply(a, [](auto const&... v){ return (Int<1>{} * ... * product(v)); }, make_range{}); -#endif // defined(CUTE_FOLD_GENERIC_LAMBDA_WORKAROUND) -} - template CUTE_HOST_DEVICE constexpr auto @@ -309,7 +243,7 @@ size(IntTuple const& a) } template -static constexpr int size_v = decltype(size(std::declval()))::value; +static constexpr int size_v = decltype(size(declval()))::value; // // sum @@ -380,10 +314,10 @@ shape_div(IntTupleA const& a, IntTupleB const& b) if constexpr (is_tuple::value) { // tuple tuple static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); - } else { // tuple int - auto const [result, rest] = fold(a, make_tuple(make_tuple(), b), + } else { // tuple int + auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), [] (auto const& init, auto const& ai) { - return make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); + return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); }); return result; } @@ -436,12 +370,40 @@ CUTE_HOST_DEVICE constexpr auto congruent(IntTupleA const& a, IntTupleB const& b) { - return bool_constant::value>{}; } template -using is_congruent = decltype(congruent(std::declval(), std::declval())); +using is_congruent = decltype(congruent(declval(), declval())); + +/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division) + */ +template +CUTE_HOST_DEVICE constexpr +auto +weakly_congruent(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_congruent(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return true_type{}; + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return weakly_congruent(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_weakly_congruent = decltype(weakly_congruent(declval(), declval())); /** Test if Shape B is compatible with Shape A: * Any coordinate into A can also be used as a coordinate into B @@ -471,7 +433,36 @@ compatible(IntTupleA const& a, IntTupleB const& b) } template -using is_compatible = decltype(compatible(std::declval(), std::declval())); +using is_compatible = decltype(compatible(declval(), declval())); + +/** Test if Shape B is weakly compatible with Shape A: + * Shape B divides Shape A at some level of refinement + */ +template +CUTE_HOST_DEVICE constexpr +auto +weakly_compatible(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_compatible(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return a % size(b) == Int<0>{}; + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return weakly_compatible(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_weakly_compatible = decltype(weakly_compatible(declval(), declval())); /** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> */ @@ -565,13 +556,13 @@ Tuple make_int_tuple_from(Ts const&... ts) { Tuple result = Tuple{}; - fill_int_tuple_from(result, make_tuple(ts...)); + fill_int_tuple_from(result, cute::make_tuple(ts...)); return result; } /** Convert a tuple to a flat homogeneous array of type T * \code - * auto tup = make_tuple(Int<1>{}, make_tuple(6,3,Int<3>{}),4,Int<2>{}); + * auto tup = cute::make_tuple(Int<1>{}, cute::make_tuple(6,3,Int<3>{}),4,Int<2>{}); * cute::array result = to_array(tup); // [1,6,3,3,4,2] * \endcode */ @@ -625,7 +616,7 @@ elem_less(IntTupleA const& a, IntTupleB const& b); namespace detail { -template +template CUTE_HOST_DEVICE constexpr auto lex_less_impl(TupleA const& a, TupleB const& b) @@ -641,7 +632,7 @@ lex_less_impl(TupleA const& a, TupleB const& b) CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr auto colex_less_impl(TupleA const& a, TupleB const& b) @@ -651,15 +642,15 @@ colex_less_impl(TupleA const& a, TupleB const& b) } else if constexpr (I == tuple_size::value) { return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted } else { - constexpr std::size_t A = tuple_size::value - 1 - I; - constexpr std::size_t B = tuple_size::value - 1 - I; + constexpr size_t A = tuple_size::value - 1 - I; + constexpr size_t B = tuple_size::value - 1 - I; return colex_less(get(a), get(b)) || (get(a) == get(b) && colex_less_impl(a,b)); } CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr auto elem_less_impl(TupleA const& a, TupleB const& b) diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index fe937ee7..cdbbb5ac 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -80,33 +80,24 @@ make_coord(Ts const&... t) { } -template > +template > struct Layout - : private cute::tuple // EBO for static layouts + : private cute::tuple // EBO for static layouts { - // Avoid bad CTAD: - // Layout smem = GMMA::Layout_MN_SW128_Atom; - // Should fail because smem is a ComposedLayout (SwizzleLayout) and not a Layout - static_assert(is_integral::value || is_tuple::value); - // Expensive in compilation time... - //static_assert(is_congruent::value, - // "Shape and Stride must have the same hierarchical structure"); - //static_assert(is_integral::value || is_tuple::value); + //static_assert(is_congruent::value, "Shape and Stride must be congruent"); // NOTE: This defaults static Shapes/Strides correctly, but not dynamic CUTE_HOST_DEVICE constexpr - Layout(LogicalShape const& logical_shape = {}, - LogicalStride const& logical_stride = {}) - : cute::tuple(logical_shape, logical_stride) + Layout(Shape const& shape = {}, Stride const& stride = {}) + : cute::tuple(shape, stride) {} // // Accessors // - static constexpr int rank = rank_v ; + static constexpr int rank = rank_v; CUTE_HOST_DEVICE constexpr decltype(auto) @@ -124,28 +115,28 @@ struct Layout CUTE_HOST_DEVICE constexpr decltype(auto) shape() { - return get<0,I...>(static_cast&>(*this)); + return get<0,I...>(static_cast&>(*this)); } template CUTE_HOST_DEVICE constexpr decltype(auto) shape() const { - return get<0,I...>(static_cast const&>(*this)); + return get<0,I...>(static_cast const&>(*this)); } template CUTE_HOST_DEVICE constexpr decltype(auto) stride() { - return get<1,I...>(static_cast&>(*this)); + return get<1,I...>(static_cast&>(*this)); } template CUTE_HOST_DEVICE constexpr decltype(auto) stride() const { - return get<1,I...>(static_cast const&>(*this)); + return get<1,I...>(static_cast const&>(*this)); } // @@ -314,7 +305,6 @@ struct Layout #endif }; - template struct is_layout : false_type {}; template @@ -395,7 +385,15 @@ CUTE_HOST_DEVICE constexpr auto make_layout_like(Layout const& layout) { + auto any_zero = any_of(layout.stride(), [](auto d) { return is_constant<0, decltype(d)>{}; }); + if constexpr (any_zero) { + // If there are static-0 strides, then make a col-major layout that keeps those 0s + return make_layout(layout.shape(), + compact_col_major(filter_zeros(layout.stride(), layout.shape()))); + } else if constexpr (is_static::value && is_static::value) { + // If the layout is fully static, then make a layout that follows the same order as the strides + // Assumes the strides are unique return make_ordered_layout(layout.shape(), layout.stride()); } else { return make_layout(layout.shape()); @@ -404,17 +402,18 @@ make_layout_like(Layout const& layout) CUTE_GCC_UNREACHABLE; } +// // Make a layout of the same shape, -// with mode-0 being colmajor then following the the mode order in layout +// with mode-0 being colmajor then following the mode order in layout +// template CUTE_HOST_DEVICE constexpr auto make_fragment_like(Layout const& layout) { - auto shape = replace<0>(layout.shape(), size<0>(layout)); - auto order = replace<0>(layout.stride(), Int<0>{}); - if constexpr (is_static::value && is_static::value) { - return make_ordered_layout(shape, order); + constexpr int R = Layout::rank; + if constexpr (R > 1 && is_static::value && is_static::value) { + return tiled_product(make_layout(shape<0>(layout)), make_ordered_layout(take<1,R>(layout))); } else { return make_layout(layout.shape()); } @@ -422,6 +421,19 @@ make_fragment_like(Layout const& layout) CUTE_GCC_UNREACHABLE; } +template ::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Shape const& shape) +{ + return make_layout(shape); +} + +// +// Make an identity layout that maps a coordinate to itself +// + template CUTE_HOST_DEVICE constexpr auto @@ -434,7 +446,7 @@ make_identity_layout(Shape const& shape) // Operations to manipulate Layouts like a tuple of pairs // -template +template CUTE_HOST_DEVICE constexpr auto get(Layout const& layout) @@ -555,7 +567,7 @@ cosize(Layout const& layout) } template -using cosize_t = decltype(cosize(std::declval())); +using cosize_t = decltype(cosize(declval())); template static constexpr int cosize_v = cosize_t::value; @@ -840,7 +852,7 @@ composition(Layout const& lhs, // NOTE: Should only flatten once for efficiency auto flat_shape = flatten(lhs.shape()); - auto flat_stride = flatten(lhs.stride()); + [[maybe_unused]] auto flat_stride = flatten(lhs.stride()); [[maybe_unused]] constexpr int R = rank(flat_shape); if constexpr (is_constant<0, RStride>::value) { @@ -857,9 +869,9 @@ composition(Layout const& lhs, auto result_shape_0 = take<0,R-1>(flat_shape); // Mod out the rhs_shape from the lhs.shape() - auto const [result_shape_1, rest_shape] = fold(result_shape_0, make_tuple(make_tuple(), rhs_shape), + auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), [] (auto const& init, auto const& si) { - return make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + return cute::make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); }); // Jump into coalesce and append (rest_shape, get(lhs.stride()) @@ -871,18 +883,18 @@ composition(Layout const& lhs, auto result_stride_0 = take<0,R-1>(flat_stride); // Divide out the rhs_stride from the lhs.shape() - auto const [result_shape_1, rest_stride] = fold(result_shape_0, make_tuple(make_tuple(), rhs_stride), + auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), [] (auto const& init, auto const& di) { - return make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); + return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); }); // Apply any lhs.shape() changes to the stride auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); // Mod out the rhs_shape from the lhs.shape() - auto const [result_shape_2, rest_shape] = fold(result_shape_1, make_tuple(make_tuple(), rhs_shape), + auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), [] (auto const& init, auto const& si) { - return make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + return cute::make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); }); // Jump into coalesce and append (rest_shape, rest_stride * get(lhs.stride()) @@ -936,37 +948,37 @@ composition(Layout const& lhs, // @a result(i) != @a layout(j) // +namespace detail { + +// @pre @a layout has been filtered (flattened and no stride-0 or size-1 modes). template CUTE_HOST_DEVICE constexpr auto -complement(Layout const& layout, CoSizeHi const& cosize_hi) +complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) { - // Remove the stride-0 modes, the size-1 modes, and flatten the layout - auto flat_layout = filter(layout); - - if constexpr (is_constant<0, decltype(flat_layout.stride())>::value) { - // Special case for stride-0 layout + if constexpr (is_constant<0, Stride>::value) { + // Special case for irreducible rank-1 stride-0 layout return make_layout(cosize_hi); } else { // General case - constexpr int R = decltype(rank(flat_layout))::value; - static_assert(R == 1 || is_static::value, + constexpr int R = rank_v; + static_assert(R == 1 || is_static::value, "Dynamic-stride complement only for rank-1 layouts"); // Should just be a sort and a fold... // Then we could even handle dynamic strides (but they would destroy all static strides) auto result = fold(make_seq{}, - make_tuple(flat_layout.shape(), - flat_layout.stride(), - make_tuple(), - make_tuple(Int<1>{})), + cute::make_tuple(shape, + stride, + cute::make_tuple(), + cute::make_tuple(Int<1>{})), [](auto const& init, auto i) { auto curr_stride = cute::min(get<1>(init)); auto curr_idx = find(get<1>(init), curr_stride); auto curr_shape = get(get<0>(init)); - return make_tuple(remove(get<0>(init)), // Remove the curr shape + return cute::make_tuple(remove(get<0>(init)), // Remove the curr shape remove(get<1>(init)), // Remove the curr stride append(get<2>(init), curr_stride / get<3,i>(init)), // new shape = curr_stride / last_stride append(get<3>(init), curr_shape * curr_stride)); // new stride = curr_shape * curr_stride @@ -986,12 +998,25 @@ complement(Layout const& layout, CoSizeHi const& cosize_hi) CUTE_GCC_UNREACHABLE; } +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout, CoSizeHi const& cosize_hi) +{ + static_assert(cute::is_integral::value, "Expected integral codomain size in complement."); + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize_hi); +} + template CUTE_HOST_DEVICE constexpr auto complement(Layout const& layout) { - return complement(layout, cosize(layout)); + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize(filter_layout)); } // @@ -1067,7 +1092,7 @@ right_inverse(Underscore const& _) // // Build the left-inverse of a layout // @pre is_static -// @pre not has_int0 // @a layout has no 0-strides (is injective) +// @pre @a layout is an injective function // @result A layout @a result such that // @a result(@a layout(i)) == i for all i < size(@a layout) // @result A layout @a result such that @@ -1090,40 +1115,62 @@ left_inverse(Underscore const& _) } // -// Max Common Vector +// Max Common Layout // -/* Return Int such that N is the maximum number of continguous elements +/* Return a layout that points to the maximum number of contiguous elements * that logically correspond in the layouts of @a a and @a b. This is, - * the number of elements that could reasonably be "vectorized" in the layouts. + * the elements that could reasonably be "vectorized" in the layouts. * - * @returns Int with N >= 1 - * @post For all 0 <= n < N, a(b[n]) == n (NOTE: Problems with negative strides/coords in this post-condition) + * @returns Layout R + * @post For all 0 <= i < size(R), a(R(i)) == i && b(R(i)) == i */ template CUTE_HOST_DEVICE constexpr auto -max_common_vector(Layout const& a, Layout const& b) +max_common_layout(Layout const& a, + Layout const& b) { - if constexpr (is_static>::value && - is_static>::value) + if constexpr (is_static::value && is_static::value && + is_static::value && is_static::value) { - auto result = coalesce(composition(a, right_inverse(b))); + Layout inv_b = right_inverse(b); + Layout common = coalesce(composition(a, inv_b)); - if constexpr (is_constant<1, decltype(stride<0>(result))>::value) { - return shape<0>(result); + if constexpr (is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return composition(inv_b, layout<0>(common)); } else { - return Int<1>{}; + return Layout<_1,_0>{}; } } else { - // Dynamic case NOTE: could weaken if we assume dynamic strides are large and multiples of the vector - return Int<1>{}; + // CASE: One of the layouts is dynamic, can't prove alignment+vectorization is valid + // NOTE: Could weaken if we assume dynamic shapes/strides obey alignment requirements + // (i.e. are large and multiples of the vector) + return Layout<_1,_0>{}; } CUTE_GCC_UNREACHABLE; } +/* Return Int such that N is the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. This is, + * the number of elements that could reasonably be "vectorized" in the layouts. + * + * @returns Int with N >= 1 + * @post For all 0 <= n < N, a(b[n]) == n (NOTE: Problems with negative strides/coords in this post-condition) + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, + Layout const& b) +{ + return size(max_common_layout(a, b)); +} + // // Zip // @@ -1450,11 +1497,13 @@ CUTE_HOST_DEVICE void print(Layout const& layout) print(layout.shape()); print(":"); print(layout.stride()); } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& layout) { return os << shape(layout) << ":" << stride(layout); } +#endif // Generic 2D Layout to console table template diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 33471e4f..a7ce47f1 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -36,6 +36,7 @@ #include #include #include +#include namespace cute { @@ -361,28 +362,75 @@ operator+(ScaledBasis const& t, constant) { template CUTE_HOST_DEVICE void print(ScaledBasis const& e) { - printf("%d:", N); print(e.value()); + print(e.value()); printf("@%d", N); } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { - return os << N << ":" << e.value(); + return os << e.value() << "@" << N; } +#endif } // end namespace cute +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD namespace std { +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + template struct tuple_size> - : std::integral_constant + : cute::integral_constant {}; -template +template struct tuple_element> - : std::tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> {}; } // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/numeric/bfloat.hpp b/include/cute/numeric/bfloat.hpp index 94f64ab5..07bd7c66 100644 --- a/include/cute/numeric/bfloat.hpp +++ b/include/cute/numeric/bfloat.hpp @@ -43,9 +43,11 @@ using cutlass::bfloat16_t; // Display utilities // +#if !defined(__CUDACC_RTC__) CUTE_HOST std::ostream& operator<<(std::ostream& os, bfloat16_t const& v) { return os << float(v); } +#endif } // end namespace cute diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp index 3790ebd3..43e4dd63 100644 --- a/include/cute/numeric/complex.hpp +++ b/include/cute/numeric/complex.hpp @@ -30,7 +30,7 @@ **************************************************************************************************/ #pragma once -#include +#include //#if defined(__CUDA_ARCH__) //# include @@ -38,13 +38,37 @@ //# include //#endif -// With CUDA 11.4, builds show spurious "-Wconversion" warnings -// on line 656 of thrust/detail/type_traits.h. -// These pragmas suppress the warnings. +// Suppress warnings for code in Thrust headers. + +#if defined(_MSC_VER) + // We check for MSVC first, because MSVC also defines __GNUC__. + // It's common for non-GCC compilers that emulate GCC's behavior + // to define __GNUC__. + // + // thrust/complex.h triggers MSVC's warning on conversion + // from double to float (or const float) ("possible loss of data"). + // MSVC treats this as an error by default (at least with + // CUTLASS's default CMake configuration). +#pragma warning( push ) +#pragma warning( disable : 4244 ) +#elif defined(__GNUC__) + // With GCC + CUDA 11.4, builds show spurious "-Wconversion" + // warnings on line 656 of thrust/detail/type_traits.h. #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wconversion" +#endif + +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif + +#if defined(_MSC_VER) +#pragma warning( pop ) +#elif defined(__GNUC__) #pragma GCC diagnostic pop +#endif #include @@ -62,7 +86,11 @@ namespace cute //template //using complex = thrust::complex; +#if defined(__CUDACC_RTC__) +using cuda::std::complex; +#else using thrust::complex; +#endif template CUTE_HOST_DEVICE @@ -147,6 +175,7 @@ struct is_complex> { ////////////////////////////////////////////////////////////////////////////////////////////////// // Display utilities +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, complex const& z) { @@ -159,5 +188,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, complex const& z) return os << _r; } } +#endif } // end namespace cute diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index a08297f2..e2b29884 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -46,10 +46,12 @@ namespace cute // Signed integers // -using int8_t = std::int8_t; -using int16_t = std::int16_t; -using int32_t = std::int32_t; -using int64_t = std::int64_t; +using int2_t = cute::int2b_t; +using int4_t = cute::int4b_t; +using int8_t = CUTE_STL_NAMESPACE::int8_t; +using int16_t = CUTE_STL_NAMESPACE::int16_t; +using int32_t = CUTE_STL_NAMESPACE::int32_t; +using int64_t = CUTE_STL_NAMESPACE::int64_t; template struct int_bit; template <> struct int_bit< 2> { using type = cute::int2b_t; }; @@ -72,10 +74,14 @@ using int_byte_t = typename int_byte::type; // Unsigned integers // -using uint8_t = std::uint8_t; -using uint16_t = std::uint16_t; -using uint32_t = std::uint32_t; -using uint64_t = std::uint64_t; +using uint1_t = cute::uint1b_t; +using uint2_t = cute::uint2b_t; +using uint4_t = cute::uint4b_t; +using uint8_t = CUTE_STL_NAMESPACE::uint8_t; +using uint16_t = CUTE_STL_NAMESPACE::uint16_t; +using uint32_t = CUTE_STL_NAMESPACE::uint32_t; +using uint64_t = CUTE_STL_NAMESPACE::uint64_t; +using uint128_t = cute::uint128_t; template struct uint_bit; template <> struct uint_bit< 1> { using type = cute::uint1b_t; }; @@ -102,7 +108,7 @@ using uint_byte_t = typename uint_byte::type; template struct sizeof_bytes { - static constexpr std::size_t value = sizeof(T); + static constexpr size_t value = sizeof(T); }; template static constexpr int sizeof_bytes_v = sizeof_bytes::value; @@ -113,15 +119,15 @@ static constexpr int sizeof_bytes_v = sizeof_bytes::value; template struct sizeof_bits { - static constexpr std::size_t value = sizeof(T) * 8; + static constexpr size_t value = sizeof(T) * 8; }; template <> struct sizeof_bits { - static constexpr std::size_t value = 1; + static constexpr size_t value = 1; }; template struct sizeof_bits> { - static constexpr std::size_t value = Bits; + static constexpr size_t value = Bits; }; template static constexpr int sizeof_bits_v = sizeof_bits::value; diff --git a/include/cute/numeric/integer_sequence.hpp b/include/cute/numeric/integer_sequence.hpp index 73a83f76..1e2d6596 100644 --- a/include/cute/numeric/integer_sequence.hpp +++ b/include/cute/numeric/integer_sequence.hpp @@ -30,34 +30,46 @@ **************************************************************************************************/ #pragma once -#include // std::integer_sequence - #include +#include +#include namespace cute { -using std::integer_sequence; -using std::make_integer_sequence; +using CUTE_STL_NAMESPACE::integer_sequence; +using CUTE_STL_NAMESPACE::make_integer_sequence; namespace detail { template -struct make_integer_range_impl; +struct range_impl; template -struct make_integer_range_impl, Begin> { +struct range_impl, Begin> { using type = integer_sequence; }; +template +struct reverse_impl; + +template +struct reverse_impl> { + using type = integer_sequence; +}; + } // end namespace detail template -using make_integer_range = typename detail::make_integer_range_impl< +using make_integer_range = typename detail::range_impl< T, make_integer_sequence 0) ? (End-Begin) : 0>, Begin>::type; +template +using make_integer_sequence_reverse = typename detail::reverse_impl< + make_integer_sequence>::type; + // // Common aliases // @@ -70,19 +82,25 @@ using int_sequence = integer_sequence; template using make_int_sequence = make_integer_sequence; +template +using make_int_rsequence = make_integer_sequence_reverse; + template using make_int_range = make_integer_range; // index_sequence -template -using index_sequence = integer_sequence; +template +using index_sequence = integer_sequence; + +template +using make_index_sequence = make_integer_sequence; -template -using make_index_sequence = make_integer_sequence; +template +using make_index_rsequence = make_integer_sequence_reverse; -template -using make_index_range = make_integer_range; +template +using make_index_range = make_integer_range; // // Shortcuts @@ -94,46 +112,40 @@ using seq = int_sequence; template using make_seq = make_int_sequence; +template +using make_rseq = make_int_rsequence; + template using make_range = make_int_range; template -using tuple_seq = make_seq>::value>; - -} // end namespace cute +using tuple_seq = make_seq>::value>; +template +using tuple_rseq = make_rseq>::value>; // -// Specialize tuple-related functionality for cute::integer_sequence +// Specialize cute::tuple-traits for std::integer_sequence // -#include -#include +template +struct tuple_size> + : cute::integral_constant +{}; -namespace cute +template +struct tuple_element> { + constexpr static T idx[sizeof...(Is)] = {Is...}; + using type = cute::integral_constant; +}; -template +template CUTE_HOST_DEVICE constexpr -std::tuple_element_t> +tuple_element_t> get(integer_sequence) { static_assert(I < sizeof...(Ints), "Index out of range"); return {}; } } // end namespace cute - -namespace std -{ - -template -struct tuple_size> - : std::integral_constant -{}; - -template -struct tuple_element> - : std::tuple_element...>> -{}; - -} // end namespace std diff --git a/include/cute/numeric/integer_subbyte.hpp b/include/cute/numeric/integer_subbyte.hpp index 3d24a952..c20f485b 100644 --- a/include/cute/numeric/integer_subbyte.hpp +++ b/include/cute/numeric/integer_subbyte.hpp @@ -53,7 +53,7 @@ struct integer_subbyte static_assert(Bits <= 8*sizeof(Storage), "Require a subbyte of bits in integer_subbyte"); /// External type - using xint_t = typename std::conditional::type; + using xint_t = typename conditional::type; /// Bitmask for truncation from larger integers static constexpr Storage bits_mask_ = Storage((1 << Bits) - 1); @@ -166,7 +166,7 @@ using bin1_t = bool; #include -namespace std { +namespace CUTE_STL_NAMESPACE { template <> struct numeric_limits { @@ -230,4 +230,4 @@ struct numeric_limits { } // namespace std -#endif +#endif // !defined(__CUDACC_RTC__) diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index 106763df..a0b5b075 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -39,7 +39,7 @@ namespace cute { template -struct constant : std::integral_constant { +struct constant : CUTE_STL_NAMESPACE::integral_constant { static constexpr T value = v; using value_type = T; using type = constant; @@ -56,7 +56,7 @@ using bool_constant = constant; using true_type = bool_constant; using false_type = bool_constant; -// +// // Traits // @@ -64,14 +64,14 @@ using false_type = bool_constant; // Use cute::is_integral to match both built-in integral types AND constant template -struct is_integral : bool_constant::value> {}; +struct is_integral : bool_constant::value> {}; template struct is_integral> : true_type {}; // is_static detects if an (abstract) value is defined completely by it's type (no members) template -struct is_static : bool_constant::value> {}; +struct is_static : bool_constant::value> {}; // is_constant detects if a type is a constant and if v is equal to a value @@ -95,45 +95,51 @@ struct is_constant &&> : bool_constant {}; template using Int = constant; -using _m32 = Int<-32>; -using _m24 = Int<-24>; -using _m16 = Int<-16>; -using _m12 = Int<-12>; -using _m10 = Int<-10>; -using _m9 = Int<-9>; -using _m8 = Int<-8>; -using _m7 = Int<-7>; -using _m6 = Int<-6>; -using _m5 = Int<-5>; -using _m4 = Int<-4>; -using _m3 = Int<-3>; -using _m2 = Int<-2>; -using _m1 = Int<-1>; -using _0 = Int<0>; -using _1 = Int<1>; -using _2 = Int<2>; -using _3 = Int<3>; -using _4 = Int<4>; -using _5 = Int<5>; -using _6 = Int<6>; -using _7 = Int<7>; -using _8 = Int<8>; -using _9 = Int<9>; -using _10 = Int<10>; -using _12 = Int<12>; -using _16 = Int<16>; -using _24 = Int<24>; -using _32 = Int<32>; -using _64 = Int<64>; -using _96 = Int<96>; -using _128 = Int<128>; -using _192 = Int<192>; -using _256 = Int<256>; -using _512 = Int<512>; -using _1024 = Int<1024>; -using _2048 = Int<2048>; -using _4096 = Int<4096>; -using _8192 = Int<8192>; +using _m32 = Int<-32>; +using _m24 = Int<-24>; +using _m16 = Int<-16>; +using _m12 = Int<-12>; +using _m10 = Int<-10>; +using _m9 = Int<-9>; +using _m8 = Int<-8>; +using _m7 = Int<-7>; +using _m6 = Int<-6>; +using _m5 = Int<-5>; +using _m4 = Int<-4>; +using _m3 = Int<-3>; +using _m2 = Int<-2>; +using _m1 = Int<-1>; +using _0 = Int<0>; +using _1 = Int<1>; +using _2 = Int<2>; +using _3 = Int<3>; +using _4 = Int<4>; +using _5 = Int<5>; +using _6 = Int<6>; +using _7 = Int<7>; +using _8 = Int<8>; +using _9 = Int<9>; +using _10 = Int<10>; +using _12 = Int<12>; +using _16 = Int<16>; +using _24 = Int<24>; +using _32 = Int<32>; +using _64 = Int<64>; +using _96 = Int<96>; +using _128 = Int<128>; +using _192 = Int<192>; +using _256 = Int<256>; +using _512 = Int<512>; +using _1024 = Int<1024>; +using _2048 = Int<2048>; +using _4096 = Int<4096>; +using _8192 = Int<8192>; +using _16384 = Int<16384>; +using _32768 = Int<32768>; +using _65536 = Int<65536>; +using _131072 = Int<131072>; +using _262144 = Int<262144>; +using _524288 = Int<524288>; /***************/ /** Operators **/ @@ -198,7 +204,7 @@ CUTE_BINARY_OP(<=); // template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr constant operator*(constant, U) { @@ -206,7 +212,7 @@ operator*(constant, U) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr constant operator*(U, constant) { @@ -214,7 +220,7 @@ operator*(U, constant) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr constant operator/(constant, U) { @@ -222,7 +228,7 @@ operator/(constant, U) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr constant operator%(U, constant) { @@ -230,7 +236,7 @@ operator%(U, constant) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr constant operator%(U, constant) { @@ -238,7 +244,7 @@ operator%(U, constant) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr constant operator%(constant, U) { @@ -246,7 +252,7 @@ operator%(constant, U) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr constant operator&(constant, U) { @@ -254,7 +260,7 @@ operator&(constant, U) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr constant operator&(U, constant) { @@ -262,7 +268,7 @@ operator&(U, constant) { } template ::value && !bool(t))> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value && !bool(t))> CUTE_HOST_DEVICE constexpr constant operator&&(constant, U) { @@ -270,7 +276,7 @@ operator&&(constant, U) { } template ::value && !bool(t))> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value && !bool(t))> CUTE_HOST_DEVICE constexpr constant operator&&(U, constant) { @@ -278,7 +284,7 @@ operator&&(U, constant) { } template ::value && bool(t))> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value && bool(t))> CUTE_HOST_DEVICE constexpr constant operator||(constant, U) { @@ -286,7 +292,7 @@ operator||(constant, U) { } template ::value && bool(t))> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value && bool(t))> CUTE_HOST_DEVICE constexpr constant operator||(U, constant) { @@ -314,7 +320,7 @@ operator||(U, constant) { } \ \ template ::value)> \ + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> \ CUTE_HOST_DEVICE constexpr \ auto \ OP (constant, U u) { \ @@ -322,7 +328,7 @@ operator||(U, constant) { } \ \ template ::value)> \ + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> \ CUTE_HOST_DEVICE constexpr \ auto \ OP (T t, constant) { \ @@ -356,7 +362,7 @@ safe_div(constant, constant) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr auto safe_div(constant, U u) { @@ -364,7 +370,7 @@ safe_div(constant, U u) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr auto safe_div(T t, constant) { @@ -376,7 +382,7 @@ safe_div(T t, constant) { template CUTE_HOST_DEVICE constexpr decltype(auto) -conditional_return(std::true_type, TrueType&& t, FalseType&&) { +conditional_return(true_type, TrueType&& t, FalseType&&) { return static_cast(t); } @@ -385,7 +391,7 @@ conditional_return(std::true_type, TrueType&& t, FalseType&&) { template CUTE_HOST_DEVICE constexpr decltype(auto) -conditional_return(std::false_type, TrueType&&, FalseType&& f) { +conditional_return(false_type, TrueType&&, FalseType&& f) { return static_cast(f); } @@ -397,6 +403,18 @@ conditional_return(bool b, TrueType const& t, FalseType const& f) { return b ? t : f; } +// TrueType and FalseType don't require a common type +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(TrueType const& t, FalseType const& f) { + if constexpr (b) { + return t; + } else { + return f; + } +} + // // Display utilities // @@ -406,9 +424,11 @@ CUTE_HOST_DEVICE void print(integral_constant const&) { printf("_%d", N); } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant const&) { return os << "_" << N; } +#endif } // end namespace cute diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index 03e83799..a90716a6 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -30,16 +30,10 @@ **************************************************************************************************/ #pragma once -#include - -#if defined(__CUDACC_RTC__) -#include -#else -#include -#endif - #include +#include + namespace cute { @@ -48,8 +42,8 @@ namespace cute // template ::value && - std::is_arithmetic::value)> + __CUTE_REQUIRES(is_arithmetic::value && + is_arithmetic::value)> CUTE_HOST_DEVICE constexpr auto max(T const& t, U const& u) { @@ -57,8 +51,8 @@ max(T const& t, U const& u) { } template ::value && - std::is_arithmetic::value)> + __CUTE_REQUIRES(is_arithmetic::value && + is_arithmetic::value)> CUTE_HOST_DEVICE constexpr auto min(T const& t, U const& u) { @@ -66,11 +60,11 @@ min(T const& t, U const& u) { } template ::value)> + __CUTE_REQUIRES(is_arithmetic::value)> CUTE_HOST_DEVICE constexpr auto abs(T const& t) { - if constexpr (std::is_signed::value) { + if constexpr (is_signed::value) { return t < T(0) ? -t : t; } else { return t; @@ -85,8 +79,8 @@ abs(T const& t) { // Greatest common divisor of two integers template ::value && - std::is_integral::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value && + CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr auto gcd(T t, U u) { @@ -100,8 +94,8 @@ gcd(T t, U u) { // Least common multiple of two integers template ::value && - std::is_integral::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value && + CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr auto lcm(T const& t, U const& u) { @@ -133,11 +127,11 @@ template CUTE_HOST_DEVICE constexpr T bit_width(T x) { - static_assert(std::is_unsigned::value, "Only to be used for unsigned types."); - constexpr int N = (std::numeric_limits::digits == 64 ? 6 : - (std::numeric_limits::digits == 32 ? 5 : - (std::numeric_limits::digits == 16 ? 4 : - (std::numeric_limits::digits == 8 ? 3 : (assert(false),0))))); + static_assert(is_unsigned::value, "Only to be used for unsigned types."); + constexpr int N = (numeric_limits::digits == 64 ? 6 : + (numeric_limits::digits == 32 ? 5 : + (numeric_limits::digits == 16 ? 4 : + (numeric_limits::digits == 8 ? 3 : (assert(false),0))))); T r = 0; for (int i = N - 1; i >= 0; --i) { T shift = (x > ((T(1) << (T(1) << i))-1)) << i; @@ -193,7 +187,7 @@ template CUTE_HOST_DEVICE constexpr T rotl(T x, int s) { - constexpr int N = std::numeric_limits::digits; + constexpr int N = numeric_limits::digits; return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s); } @@ -202,7 +196,7 @@ template CUTE_HOST_DEVICE constexpr T rotr(T x, int s) { - constexpr int N = std::numeric_limits::digits; + constexpr int N = numeric_limits::digits; return s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s); } @@ -214,7 +208,7 @@ template CUTE_HOST_DEVICE constexpr T countl_zero(T x) { - return std::numeric_limits::digits - bit_width(x); + return numeric_limits::digits - bit_width(x); } // Counts the number of consecutive 1 bits, starting from the most significant bit @@ -236,7 +230,7 @@ template CUTE_HOST_DEVICE constexpr T countr_zero(T x) { - return x == 0 ? std::numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB + return x == 0 ? numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB } // Counts the number of consecutive 1 bits, starting from the least significant bit @@ -288,7 +282,7 @@ shiftr(T x, int s) { // Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. template ::value)> + __CUTE_REQUIRES(is_unsigned::value)> CUTE_HOST_DEVICE constexpr int signum(T const& x) { @@ -296,7 +290,7 @@ signum(T const& x) { } template ::value)> + __CUTE_REQUIRES(not is_unsigned::value)> CUTE_HOST_DEVICE constexpr int signum(T const& x) { @@ -307,8 +301,8 @@ signum(T const& x) { // @pre t % u == 0 // @result t / u template ::value && - std::is_integral::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value && + CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE constexpr auto safe_div(T const& t, U const& u) { diff --git a/include/cute/numeric/tfloat.hpp b/include/cute/numeric/tfloat.hpp index bb68b703..481a63e5 100644 --- a/include/cute/numeric/tfloat.hpp +++ b/include/cute/numeric/tfloat.hpp @@ -43,9 +43,11 @@ using cutlass::tfloat32_t; // Display utilities // +#if !defined(__CUDACC_RTC__) CUTE_HOST std::ostream& operator<<(std::ostream& os, tfloat32_t const& v) { return os << float(v); } +#endif } // end namespace cute diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 40ce5d1a..da32784f 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -44,11 +44,11 @@ namespace cute // template -struct has_dereference : std::false_type { +struct has_dereference : false_type { }; template -struct has_dereference())>> : std::true_type { +struct has_dereference())>> : true_type { }; // @@ -91,7 +91,7 @@ struct device_ptr DerivedType operator+(Index const& i) const { return {ptr_ + i}; } CUTE_HOST_DEVICE constexpr friend - std::ptrdiff_t operator-(device_ptr const& a, + ptrdiff_t operator-(device_ptr const& a, device_ptr const& b) { return a.ptr_ - b.ptr_; } @@ -301,6 +301,7 @@ CUTE_HOST_DEVICE void print(rmem_ptr const& ptr) printf("rmem_ptr_%db(%p)", int(8*sizeof(T)), ptr.get()); } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr const& ptr) { @@ -319,4 +320,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr const& ptr) return os << "rmem_ptr_" << int(8*sizeof(T)) << "b"; } +#endif // !defined(__CUDACC_RTC__) + } // end namespace cute diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp index 5fb0da8a..515bb7b3 100644 --- a/include/cute/stride.hpp +++ b/include/cute/stride.hpp @@ -192,7 +192,12 @@ idx2crd(Index const& idx, return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); } } else { // "int" "int" "int" - return (idx / stride) % shape; + if constexpr (is_constant<1, Shape>::value) { + // Skip potential stride-0 division + return Int<0>{}; + } else { + return (idx / stride) % shape; + } } } @@ -259,66 +264,88 @@ crd2crd(Coord const& coord, // Compact Major // -// General tag for common layouts and dispatching -struct GenColMajor {}; -struct GenRowMajor {}; +// Tags for common layouts and dispatching +struct LayoutLeft; // Col-major layout mapping; leftmost extent has stride 1 +using GenColMajor = LayoutLeft; // Alias -template , class Major = GenColMajor> -CUTE_HOST_DEVICE constexpr -auto -compact_major(Shape const& shape, - Current const& current = {}, - Major const& major = {}); +struct LayoutRight; // Row-major layout mapping; rightmost extent has stride 1 +using GenRowMajor = LayoutRight; // Alias namespace detail { -template +// GGC8.5 WAR -- Use of lambdas in unevaluated contexts. Instead use function objects. +template +struct CompactLambda; + +// @pre is_integral +// Return (result, current * product(shape)) to enable recurrence +template CUTE_HOST_DEVICE constexpr auto -compact_major_ti(Shape const& shape, - Current const& current, - GenColMajor const& major, seq) +compact(Shape const& shape, + Current const& current) { - return cute::make_tuple(compact_major(get(shape), current * product<0,Is>(shape), major)...); + if constexpr (is_tuple::value) { // Shape::tuple Current::int + using Lambda = CompactLambda; // Append or Prepend + using Seq = typename Lambda::template seq; // Seq or RSeq + return cute::detail::fold(shape, cute::make_tuple(cute::make_tuple(), current), Lambda{}, Seq{}); + } else { // Shape::int Current::int + if constexpr (is_constant<1, Shape>::value) { + return cute::make_tuple(Int<0>{}, current); // If current is dynamic, this could save a reg + } else { + return cute::make_tuple(current, current * shape); + } + } + + CUTE_GCC_UNREACHABLE; } -template -CUTE_HOST_DEVICE constexpr -auto -compact_major_ti(Shape const& shape, - Current const& current, - GenRowMajor const& major, seq) +// GCC8.5 WAR -- Specialization LayoutLeft +template <> +struct CompactLambda { - constexpr int E = tuple_size::value; - return cute::make_tuple(compact_major(get(shape), current * product(shape), major)...); -} + template + CUTE_HOST_DEVICE constexpr auto + operator()(Init const& init, Shape const& si) { + auto result = detail::compact(si, get<1>(init)); + return cute::make_tuple(append(get<0>(init), get<0>(result)), get<1>(result)); // Append + } + + template + using seq = tuple_seq; // Seq +}; + +// GCC8.5 WAR -- Specialization LayoutRight +template <> +struct CompactLambda +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Init const& init, Shape const& si) { + auto result = detail::compact(si, get<1>(init)); + return cute::make_tuple(prepend(get<0>(init), get<0>(result)), get<1>(result)); // Prepend + } + + template + using seq = tuple_rseq; // RSeq +}; } // end namespace detail -template +template , + __CUTE_REQUIRES(is_tuple::value || is_integral::value)> CUTE_HOST_DEVICE constexpr auto compact_major(Shape const& shape, - Current const& current, - Major const& major) + Current const& current = {}) { - if constexpr (is_tuple::value) { - if constexpr (is_tuple::value) { // tuple tuple - static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); - return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c,major); }); - } else { // tuple int - return detail::compact_major_ti(shape, current, major, tuple_seq{}); - } + if constexpr (is_tuple::value) { // Shape::tuple Current::tuple + static_assert(is_tuple::value, "Invalid parameters"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + // Recurse to apply to the terminals of current + return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c); }); } else { - if constexpr (is_tuple::value) { // int tuple - static_assert(sizeof(Shape) == 0, "Invalid parameters"); - } else { // int int - if constexpr (is_constant<1, Shape>::value) { - return Int<0>{}; // If current is dynamic, this could save a reg - } else { - return current; - } - } + return get<0>(detail::compact(shape, current)); } CUTE_GCC_UNREACHABLE; @@ -328,34 +355,38 @@ compact_major(Shape const& shape, // Compact Col Major // +struct LayoutLeft { + template + using Apply = decltype(compact_major(declval())); +}; + template > CUTE_HOST_DEVICE constexpr auto compact_col_major(Shape const& shape, Current const& current = {}) { - return compact_major(shape, current, GenColMajor{}); + return compact_major(shape, current); } -template -using ColMajor = decltype(compact_col_major(std::declval())); - // // Compact Row Major // +struct LayoutRight { + template + using Apply = decltype(compact_major(declval())); +}; + template > CUTE_HOST_DEVICE constexpr auto compact_row_major(Shape const& shape, Current const& current = {}) { - return compact_major(shape, current, GenRowMajor{}); + return compact_major(shape, current); } -template -using RowMajor = decltype(compact_row_major(std::declval())); - // // Compact Order -- compute a compact stride based on an ordering of the modes // @@ -397,7 +428,7 @@ CUTE_HOST_DEVICE constexpr auto compact_order(Shape const& shape, GenColMajor const& major) { - return compact_major(shape, Int<1>{}, major); + return compact_major(shape); } template @@ -405,7 +436,7 @@ CUTE_HOST_DEVICE constexpr auto compact_order(Shape const& shape, GenRowMajor const& major) { - return compact_major(shape, Int<1>{}, major); + return compact_major(shape); } } // end namespace cute diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp index 0a13e551..ec5ee818 100644 --- a/include/cute/swizzle.hpp +++ b/include/cute/swizzle.hpp @@ -476,6 +476,7 @@ CUTE_HOST_DEVICE void print(MixedBits const& m) printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, to_integral(m)); } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) { @@ -493,5 +494,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle const&) { return os << "S<" << B << "," << M << "," << S << ">"; } +#endif // !defined(__CUDACC_RTC__) } // end namespace cute diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index 1376a47d..8303731e 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -308,7 +308,7 @@ transfer_swizzle(Layout const& old_layout, auto active_Y = swizzle_active_bits & shiftr(swizzle_active_bits, -msk_sft) & yyy_msk; // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) - auto new_active_Z = new_layout(old_layout.get_1d_coord(active_Z)); + auto new_active_Z = new_layout(old_layout.get_1d_coord(active_Z)); auto new_active_Y = new_layout(old_layout.get_1d_coord(active_Y)); // Use this new swizzle identifier to construct the new swizzle for new_layout @@ -394,7 +394,7 @@ cosize(ComposedLayout const& layout) // Operations to manipulate Layouts like a tuple of pairs // -template +template CUTE_HOST_DEVICE constexpr auto get(ComposedLayout const& a) @@ -450,7 +450,7 @@ make_swizzle_strides(true_type, int_sequence) { // Below is an optimized/compressed version of: - //return make_tuple((swizzle(offset + Z*Int<(1 << I)>{}) - swizzle(offset))...); + //return cute::make_tuple((swizzle(offset + Z*Int<(1 << I)>{}) - swizzle(offset))...); // with knowledge of Swizzle, I... ranges for each B bits, // and the layout won't slice along z-bits that are already set @@ -471,7 +471,7 @@ make_swizzle_strides(false_type, int_sequence) { // Below is an optimized/compressed version of: - //return make_tuple((swizzle(offset + Y*Int<(1 << I)>{}) - swizzle(offset))...); + //return cute::make_tuple((swizzle(offset + Y*Int<(1 << I)>{}) - swizzle(offset))...); // with knowledge of Swizzle, I... ranges for each B bits, // and the layout won't slice along y-bits that are already set @@ -1001,10 +1001,12 @@ CUTE_HOST_DEVICE void print(ComposedLayout const& layout) print(layout.swizzle_fn()); print(" o "); print(layout.offset_fn()); print(" o "); print(layout.layout_fn()); } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) { return os << layout.swizzle_fn() << " o " << layout.offset_fn() << " o " << layout.layout_fn(); } +#endif } // end namespace cute diff --git a/include/cute/swizzle_ptr.hpp b/include/cute/swizzle_ptr.hpp index ed77acba..17ff3bcd 100644 --- a/include/cute/swizzle_ptr.hpp +++ b/include/cute/swizzle_ptr.hpp @@ -69,7 +69,7 @@ namespace cute template struct smem_ptr_swizzle { - static_assert(std::is_empty::value, "Swizzle can't have state."); + static_assert(is_empty::value, "Swizzle can't have state."); CUTE_HOST_DEVICE constexpr T* get() const @@ -86,7 +86,7 @@ struct smem_ptr_swizzle CUTE_HOST_DEVICE constexpr static T* apply_swizzle(T* ptr) { - return reinterpret_cast(Swizzle::apply(reinterpret_cast(ptr))); + return reinterpret_cast(Swizzle::apply(reinterpret_cast(ptr))); } CUTE_HOST_DEVICE constexpr @@ -200,7 +200,7 @@ recast(smem_ptr_swizzle const& ptr) } // -// Conversion with swizzle_layout +// Conversion with swizzle_layout // template @@ -273,10 +273,12 @@ CUTE_HOST_DEVICE void print(smem_ptr_swizzle> const& ptr) printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(8*sizeof(T)), ptr.get()); } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle> const&) { return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(8*sizeof(T)) << "b"; } +#endif } // end namespace cute diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index e88c22bc..1a8141bf 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -57,14 +57,14 @@ namespace cute // }; template -using ArrayEngine = typename std::conditional<(sizeof_bits::value % 8 == 0), +using ArrayEngine = typename conditional<(sizeof_bits::value % 8 == 0), array_aligned, array_subbyte>::type; template struct ViewEngine { - using value_type = typename cute::remove_cvref())>::type; + using value_type = typename cute::remove_cvref())>::type; using iterator = Iterator; iterator storage_; @@ -91,7 +91,7 @@ struct is_gmem> : is_gmem {}; template struct ConstViewEngine { - using value_type = typename cute::remove_cvref())>::type; + using value_type = typename cute::remove_cvref())>::type; using iterator = Iterator; iterator storage_; @@ -335,80 +335,134 @@ template struct is_smem> : is_smem {}; template struct is_gmem> : is_gmem {}; +// Customization point for creation of owning and non-owning Tensors +template +struct MakeTensor +{ + template ::value && + is_layout::value)> + CUTE_HOST_DEVICE constexpr auto + operator()(Layout const& layout) const + { + static_assert(is_static::value, "Dynamic owning tensors not supported"); + using Engine = ArrayEngine>; + return Tensor(); + } + + template ::value && + is_layout::value)> + CUTE_HOST_DEVICE constexpr auto + operator()(T const& iter, Layout const& layout) + { + using Engine = ViewEngine; + return Tensor(iter, layout); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr auto + operator()(LayoutArg const& arg, LayoutArgs const&... args) const + { + return operator()(make_layout(arg, args...)); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr auto + operator()(T const& iter, LayoutArg const& arg, LayoutArgs const&... args) + { + return operator()(iter, make_layout(arg, args...)); + } +}; + // -// Make an owning Tensor that will allocate a static array +// make_tensor // -template ::value)> +// Make an owning Tensor that will allocate a static array +// e.g. make_tensor(Int<12>{}) +template CUTE_HOST_DEVICE constexpr auto -make_tensor(Layout const& layout) +make_tensor(Args const&... args) { - static_assert(is_static::value, "Dynamic owning tensors not supported"); - using Engine = ArrayEngine>; - return Tensor(); + return MakeTensor{}(args...); } -// e.g. make_tensor(12) -template ::value)> +// Make a non-owning Tensor that will use a pointer (view) +// e.g. make_tensor(vec.data(), 12) +template CUTE_HOST_DEVICE constexpr auto -make_tensor(LayoutArg const& arg, LayoutArgs const&... args) +make_tensor(Iterator const& iter, Args const&... args) { - return make_tensor(make_layout(arg, args...)); + return MakeTensor{}(iter, args...); } // -// Make a non-owning Tensor that will use a pointer (view) +// make_tensor_like +// Make a register tensor the same type and shape and (if possible) order as another tensor // -template ::value && - is_layout::value)> +template CUTE_HOST_DEVICE constexpr auto -make_tensor(Iterator const& iter, Layout const& layout) +make_tensor_like(Layout const& layout) { - using Engine = ViewEngine; - return Tensor(iter, layout); + if constexpr (is_static::value) { + return make_tensor(make_ordered_layout(layout)); + } else { + return make_tensor(make_layout(layout.shape())); + } } -// e.g. make_tensor(vec.data(), 12) -template ::value)> +template CUTE_HOST_DEVICE constexpr auto -make_tensor(Iterator const& iter, LayoutArg const& arg, LayoutArgs const&... args) +make_tensor_like(Tensor const& tensor) { - return make_tensor(iter, make_layout(arg, args...)); + return make_tensor_like(tensor.layout()); } -// -// make_tensor_like -- make a register tensor the same type and shape as another -// - template CUTE_HOST_DEVICE constexpr auto make_tensor_like(Tensor const& tensor) { - using value_type = typename Tensor::value_type; - return make_tensor(tensor.shape()); + return make_tensor_like(tensor.layout()); } // -// make_fragment_like -- make a register tensor the same type, shape, and (if possible) order as another tensor +// make_fragment_like -- +// Make a tensor the same shape and (if possible) order as another tensor, with special +// consideration of the 0th mode. The 0th mode is commonly used for MMA_Atoms or Copy_Atoms +// so this allocates the 0th mode with LayoutLeft regardless of the reference layout. // +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + return make_tensor(make_fragment_like(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Tensor const& tensor) +{ + return make_fragment_like(tensor.layout()); +} + template CUTE_HOST_DEVICE constexpr auto make_fragment_like(Tensor const& tensor) { - using value_type = typename Tensor::value_type; - return make_tensor(make_layout_like(tensor.layout())); + return make_fragment_like(tensor.layout()); } // @@ -452,7 +506,7 @@ template >::value)> CUTE_HOST_DEVICE constexpr decltype(auto) -take(Tensor&& tensor) +take(Tensor&& tensor) { return make_tensor(std::forward(tensor).data(), take(tensor.layout())); } @@ -627,11 +681,11 @@ max_common_vector(Tensor const& a, if constexpr (// Should be the same value_types, else the copy is also performing a cast sizeof(SrcType) == sizeof(DstType) && // The types should be trivially copyable so that vectorization is valid - std::is_trivially_copyable::value && - std::is_trivially_copyable::value && + is_trivially_copyable::value && + is_trivially_copyable::value && // Should be load/storing real data, rather than implicit iterators or such - std::is_reference::value && - std::is_reference::value) + is_reference::value && + is_reference::value) { return max_common_vector(a.layout(), b.layout()); } else { @@ -833,6 +887,7 @@ CUTE_HOST_DEVICE void print(Tensor const& tensor) print_tensor(tensor); } +#if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) { @@ -879,6 +934,7 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const os << tensor.layout() << std::endl; return print_tensor_os(os, tensor); } +#endif // !defined(__CUDACC_RTC__) } // end namespace cute diff --git a/include/cute/underscore.hpp b/include/cute/underscore.hpp index d79b4ee8..155f5eb1 100644 --- a/include/cute/underscore.hpp +++ b/include/cute/underscore.hpp @@ -60,11 +60,11 @@ struct has_elem : false_type {}; template struct has_elem : true_type {}; template -struct has_elem::value> > +struct has_elem::value> > : has_elem > {}; template struct has_elem> - : disjunction, Elem>...> {}; + : disjunction, Elem>...> {}; // Tuple trait for detecting static member element template @@ -72,11 +72,11 @@ struct all_elem : false_type {}; template struct all_elem : true_type {}; template -struct all_elem::value> > +struct all_elem::value> > : all_elem > {}; template struct all_elem> - : conjunction, Elem>...> {}; + : conjunction, Elem>...> {}; // Tuple trait for detecting Underscore member template @@ -141,8 +141,10 @@ CUTE_HOST_DEVICE void print(Underscore const&) { printf("_"); } +#if !defined(__CUDACC_RTC__) CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) { return os << "_"; } +#endif } // end namespace cute diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp index 9a62143c..83e84294 100644 --- a/include/cute/util/debug.hpp +++ b/include/cute/util/debug.hpp @@ -99,8 +99,12 @@ namespace cute // A dummy function that uses compilation failure to print a type template -CUTE_HOST_DEVICE -void +CUTE_HOST_DEVICE void +print_type() { + static_assert(sizeof(T) < 0, "Printing type T."); +} +template +CUTE_HOST_DEVICE void print_type(T&&) { static_assert(sizeof(T) < 0, "Printing type T."); } @@ -113,13 +117,23 @@ print_type(T&&) { // if (block0()) print(...); // if (thread(42)) print(...); +CUTE_HOST_DEVICE +bool +block(int bid) +{ +#if defined(__CUDA_ARCH__) + return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid; +#else + return true; +#endif +} + CUTE_HOST_DEVICE bool thread(int tid, int bid) { #if defined(__CUDA_ARCH__) - return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) - && ( blockIdx.x + blockIdx.y* gridDim.x + blockIdx.z* gridDim.x* gridDim.y == bid); + return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid); #else return true; #endif @@ -129,7 +143,7 @@ CUTE_HOST_DEVICE bool thread(int tid) { - return thread(tid, 0); + return thread(tid,0); } CUTE_HOST_DEVICE @@ -143,11 +157,7 @@ CUTE_HOST_DEVICE bool block0() { -#if defined(__CUDA_ARCH__) - return !(blockIdx.x | blockIdx.y | blockIdx.z); -#else - return true; -#endif + return block(0); } } // end namespace cute diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp index ec774b00..320b4f5b 100644 --- a/include/cute/util/print.hpp +++ b/include/cute/util/print.hpp @@ -30,10 +30,11 @@ **************************************************************************************************/ #pragma once -#include - #include +#include +#include + // // CUDA compatible print and printf // @@ -123,7 +124,7 @@ print(char const& c) { } template ::value)> + __CUTE_REQUIRES(CUTE_STL_NAMESPACE::is_integral::value)> CUTE_HOST_DEVICE void print(T const& a) { diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index 4d37eb9e..28e53597 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -30,32 +30,97 @@ **************************************************************************************************/ #pragma once +#if defined(__CUDACC_RTC__) +#include +#include +#include +#include +#include +#else #include +#include // tuple_size, tuple_element +#include // ptrdiff_t +#include // uintptr_t +#include // numeric_limits +#endif #include -#define __CUTE_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr -#define __CUTE_REQUIRES_V(...) typename std::enable_if::type* = nullptr +namespace cute +{ + using CUTE_STL_NAMESPACE::enable_if; + using CUTE_STL_NAMESPACE::enable_if_t; +} + +#define __CUTE_REQUIRES(...) typename cute::enable_if<(__VA_ARGS__)>::type* = nullptr +#define __CUTE_REQUIRES_V(...) typename cute::enable_if::type* = nullptr namespace cute { -using std::conjunction; -using std::conjunction_v; +// +using CUTE_STL_NAMESPACE::conjunction; +using CUTE_STL_NAMESPACE::conjunction_v; + +using CUTE_STL_NAMESPACE::disjunction; +using CUTE_STL_NAMESPACE::disjunction_v; + +using CUTE_STL_NAMESPACE::negation; +using CUTE_STL_NAMESPACE::negation_v; + +using CUTE_STL_NAMESPACE::void_t; +using CUTE_STL_NAMESPACE::is_void_v; + +using CUTE_STL_NAMESPACE::is_base_of; +using CUTE_STL_NAMESPACE::is_base_of_v; + +// using CUTE_STL_NAMESPACE::true_type; +// using CUTE_STL_NAMESPACE::false_type; + +using CUTE_STL_NAMESPACE::conditional; +using CUTE_STL_NAMESPACE::conditional_t; + +using CUTE_STL_NAMESPACE::remove_cv_t; +using CUTE_STL_NAMESPACE::remove_reference_t; + +using CUTE_STL_NAMESPACE::extent; +using CUTE_STL_NAMESPACE::remove_extent; + +using CUTE_STL_NAMESPACE::decay; +using CUTE_STL_NAMESPACE::decay_t; + +using CUTE_STL_NAMESPACE::is_reference; +using CUTE_STL_NAMESPACE::is_trivially_copyable; + +using CUTE_STL_NAMESPACE::is_same; +using CUTE_STL_NAMESPACE::is_same_v; + +using CUTE_STL_NAMESPACE::is_arithmetic; +using CUTE_STL_NAMESPACE::is_unsigned; +using CUTE_STL_NAMESPACE::is_signed; +// using CUTE_STL_NAMESPACE::is_integral; + +using CUTE_STL_NAMESPACE::is_empty; -using std::disjunction; -using std::disjunction_v; +using CUTE_STL_NAMESPACE::invoke_result_t; -using std::negation; -using std::negation_v; +// +using CUTE_STL_NAMESPACE::declval; -using std::void_t; +// +using CUTE_STL_NAMESPACE::numeric_limits; + +// +using CUTE_STL_NAMESPACE::ptrdiff_t; + +// +using CUTE_STL_NAMESPACE::uintptr_t; // C++20 // using std::remove_cvref; template struct remove_cvref { - using type = std::remove_cv_t>; + using type = remove_cv_t>; }; // C++20 @@ -63,38 +128,79 @@ struct remove_cvref { template using remove_cvref_t = typename remove_cvref::type; +// +// dependent_false +// +// @brief An always-false value that depends on one or more template parameters. +// See +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf +// https://github.com/cplusplus/papers/issues/572 +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html +template +inline constexpr bool dependent_false = false; + +// +// tuple_size, tuple_element +// +// @brief CuTe-local tuple-traits to prevent conflicts with other libraries. +// For cute:: types, we specialize std::tuple-traits, which is explicitly allowed. +// cute::tuple, cute::array, cute::array_subbyte, etc +// But CuTe wants to treat some external types as tuples as well. For those, +// we specialize cute::tuple-traits to avoid polluting external traits. +// dim3, uint3, etc + +template +struct tuple_size; + +template +struct tuple_size::type>> : CUTE_STL_NAMESPACE::integral_constant::value> {}; + +// S = : std::integral_constant::value> {}; + +template +constexpr size_t tuple_size_v = tuple_size::value; + +template +struct tuple_element; + +template +struct tuple_element::type>> : CUTE_STL_NAMESPACE::tuple_element {}; + +template +using tuple_element_t = typename tuple_element::type; + // // is_valid // namespace detail { -template ()(std::declval()...))> -CUTE_HOST_DEVICE constexpr auto -is_valid_impl(int) { return std::true_type{}; } +template ()(declval()...))> +CUTE_HOST_DEVICE constexpr auto +is_valid_impl(int) { return CUTE_STL_NAMESPACE::true_type{}; } template -CUTE_HOST_DEVICE constexpr auto -is_valid_impl(...) { return std::false_type{}; } +CUTE_HOST_DEVICE constexpr auto +is_valid_impl(...) { return CUTE_STL_NAMESPACE::false_type{}; } template struct is_valid_fn { template - CUTE_HOST_DEVICE constexpr auto + CUTE_HOST_DEVICE constexpr auto operator()(Args&&...) const { return is_valid_impl(int{}); } }; } // end namespace detail template -CUTE_HOST_DEVICE constexpr auto -is_valid(F&&) { +CUTE_HOST_DEVICE constexpr auto +is_valid(F&&) { return detail::is_valid_fn{}; } template -CUTE_HOST_DEVICE constexpr auto -is_valid(F&&, Args&&...) { +CUTE_HOST_DEVICE constexpr auto +is_valid(F&&, Args&&...) { return detail::is_valid_impl(int{}); } diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index 043bfac9..2ad48951 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -86,7 +86,6 @@ struct Sm80 { struct Sm86 { static int const kMinComputeCapability = 86; }; - struct Sm90 { static int const kMinComputeCapability = 90; }; diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index 34f0b4ee..1d52141c 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -1,16 +1,30 @@ /*************************************************************************************************** - * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * - * Redistribution and use in source and binary forms, with or without modification, are not permit- - * ted. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -22,7 +36,6 @@ #include #include - namespace cutlass { /// @brief namespace arch { @@ -72,7 +85,7 @@ class NamedBarrier { static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -81,7 +94,7 @@ class NamedBarrier { static void arrive(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -94,15 +107,15 @@ class NamedBarrier { //////////////////////////////////////////////////////////////////////////////////////////////////// -// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide AW behaviour. -// This is an extension to the Ampere AW barriers -// Note : Ampere AW Barriers have a larger max-arrive count (2^30) than Hopper AW Barriers (2^20). +// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide arrive-wait behaviour. +// This is an extension to the Ampere arrive-wait barriers +// Note : Ampere arrive-wait Barriers have a larger max-arrive count (2^30) than Hopper arrive-wait Barriers (2^20). struct ClusterBarrier { using ValueType = uint64_t; protected: - // Can never be initializated - can only be aliased to smem + // Can never be initialized - can only be aliased to smem ValueType barrier_; public: @@ -120,6 +133,11 @@ struct ClusterBarrier { return ClusterBarrier::test_wait(&this->barrier_, phase, pred); } + CUTLASS_DEVICE + uint32_t try_wait(uint32_t phase) const { + return ClusterBarrier::try_wait(&this->barrier_, phase); + } + CUTLASS_DEVICE void wait(uint32_t phase) const { ClusterBarrier::wait(&this->barrier_, phase); @@ -150,7 +168,7 @@ struct ClusterBarrier { "}" : : "r"(arrive_count), "r"(smem_addr)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -174,7 +192,7 @@ struct ClusterBarrier { : : "r"(smem_addr), "r"(phase), "r"(ticks)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -197,7 +215,29 @@ struct ClusterBarrier { : "r"(smem_addr), "r"(phase), "r"(pred)); return waitComplete; -#else +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + return 0; + } + + CUTLASS_DEVICE + static uint32_t try_wait(ValueType const* smem_ptr, uint32_t phase) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t waitComplete; + + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_addr), "r"(phase)); + + return waitComplete; +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif return 0; @@ -218,7 +258,7 @@ struct ClusterBarrier { "}" : : "r"(smem_addr), "r"(cta_id), "r"(pred)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -235,7 +275,7 @@ struct ClusterBarrier { "}" : : "r"(smem_addr), "l"(state)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -250,7 +290,7 @@ struct ClusterBarrier { "}" : : "r"(smem_addr)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -303,7 +343,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -324,7 +364,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -340,7 +380,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -360,7 +400,7 @@ struct ClusterTransactionBarrier : public ClusterBarrier { "}" : : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); -#else +#elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); #endif } @@ -378,8 +418,8 @@ void fence_barrier_init() { "fence.mbarrier_init.release.cluster; \n" "}" ::); -#else - asm volatile ("brkpt;\n" ::); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); #endif } @@ -392,8 +432,8 @@ void fence_view_async_shared() { "fence.proxy.async.shared::cta; \n" "}" ::); -#else - asm volatile ("brkpt;\n" ::); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); #endif } diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h index b2a9468f..f4902b31 100644 --- a/include/cutlass/arch/memory.h +++ b/include/cutlass/arch/memory.h @@ -468,7 +468,7 @@ void shared_store<16>(uint32_t ptr, void const *src) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#include "memory_sm75.h" -#include "memory_sm80.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/memory_sm80.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index 537f215a..f35cdb34 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -76,6 +76,12 @@ struct OpMultiplyAddFastF32 {}; // Perform 3xTF32 or 4xTF32 for every complex output element struct OpMultiplyAddComplexFastF32 {}; +/// Helper for determining whether staged accumulation should be used for a given operator +template +struct UseStagedAccumulation { + static bool const value = platform::is_same::value || + platform::is_same::value; +}; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the complex multiply-add operation diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index 85a178ba..78777147 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -81,7 +81,7 @@ struct Barrier CUTLASS_DEVICE static void red_release(int *ptr, int val) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) #if (__CUDA_ARCH__ >= 700) /// SM70 and newer use memory consistency qualifiers @@ -104,7 +104,7 @@ struct Barrier CUTLASS_DEVICE static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; if (thread_idx == 0) @@ -122,7 +122,7 @@ struct Barrier CUTLASS_DEVICE static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; if (thread_idx == 0) @@ -138,7 +138,7 @@ struct Barrier /// Uses thread[0] to wait for the specified count of signals on the given flag counter CUTLASS_DEVICE static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; if (thread_idx == 0) @@ -156,7 +156,7 @@ struct Barrier CUTLASS_DEVICE static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; __syncthreads(); @@ -173,7 +173,7 @@ struct Barrier CUTLASS_DEVICE static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) int flag_idx = first_flag_idx + thread_idx; T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 48435407..b405e2e2 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -35,11 +35,16 @@ #pragma once -#include #include #include "cutlass/cutlass.h" #include "cutlass/trace.h" +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + #if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) # define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED #endif @@ -72,7 +77,7 @@ struct ClusterLauncher { // Check for hardware compatibility static inline __host__ - Status check_cluster_dims(dim3 const& grid, dim3 const& cluster) { + Status check_cluster_dims(dim3 grid, dim3 cluster) { if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { return Status::kSuccess; @@ -105,11 +110,11 @@ struct ClusterLauncher { // This is the method we expect to use going forward static inline __host__ Status launch( - dim3 const& grid_dims, - dim3 const& cluster_dims, - dim3 const& block_dims, - size_t const& smem_size, - cudaStream_t& cuda_stream, + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, void const* kernel, void** kernel_params) { #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) @@ -153,4 +158,78 @@ struct ClusterLauncher { } }; +namespace detail { + +template +void* checked_addressof(Arg&& arg) { + static_assert(! std::is_rvalue_reference_v || ! std::is_const_v, "You cannot take the address of a const rvalue reference (const T&&)."); + // We use std::addressof to ensure we get the address, + // in case the type has an overloaded operator&. + // Note that this precludes `const T&&` references. + return const_cast(reinterpret_cast(std::addressof(arg))); +} + +} // namespace detail + +//! Parameters for launch_on_cluster (see below). +struct ClusterLaunchParams { + //! Grid dimensions + dim3 grid_dims{1, 1, 1}; + + //! Block dimensions + dim3 block_dims{1, 1, 1}; + + //! Cluster dimensions + dim3 cluster_dims{1, 1, 1}; + + //! Number of bytes required for the kernel's shared memory. + int smem_size_in_bytes = 0; + + //! CUDA stream on which to launch the kernel. + cudaStream_t cuda_stream = nullptr; +}; + +/// @brief Launch the kernel on the stream using cluster launch. +/// +/// @param params Cluster launch parameters (see above). +/// @param kernel_ptr Pointer to the kernel function (see example). +/// @param args Zero or more arguments to pass to the kernel. +/// +/// @tparam Args Types of the arguments passed to the kernel. +/// Don't specify this/these template argument(s) explicitly. +/// +/// @return Status::Success on success, else an error code. +/// +/// @code +/// template +/// __global__ void kernel(A a, B b, C c); +/// +/// X x = get_x(); +/// Y y = get_y(); +/// Z z = get_z(); +/// +/// void const* kernel_ptr = +/// const_cast(reinterpret_cast( +/// &kernel)); +/// auto status = launch_on_cluster( +/// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)}, +/// kernel_ptr, x, y, z); +/// @endcode +template +__host__ cutlass::Status +launch_kernel_on_cluster(const ClusterLaunchParams& params, + void const* kernel_ptr, + Args&& ... args) +{ + // Unfortunately, we find ourselves needing to pass in + // the parameters as an array of raw pointers. + void* kernel_params[] = { + detail::checked_addressof(std::forward(args))... + }; + return cutlass::ClusterLauncher::launch( + params.grid_dims, params.cluster_dims, params.block_dims, + params.smem_size_in_bytes, params.cuda_stream, + kernel_ptr, kernel_params); +} + } // namespace cutlass diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 089f474d..a3f56e4b 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -179,16 +179,6 @@ class complex complex(cuDoubleComplex const &z) : _real(static_cast(cuCreal(z))), _imag(static_cast(cuCimag(z))) {} #endif - /// Assignment - template - CUTLASS_HOST_DEVICE - complex& operator=(complex const &z) - { - _real = static_cast(z.real()); - _imag = static_cast(z.imag()); - return *this; - } - /// Equality operator CUTLASS_HOST_DEVICE bool operator==(complex const &rhs) const { return this->real() == rhs.real() && this->imag() == rhs.imag(); diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index 2bc4eb07..e7d8360f 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -47,13 +47,6 @@ #pragma once - -#if defined(__CUDACC_RTC__) -#include -#else -#include -#endif - #include "cutlass/cutlass.h" #include "cutlass/tensor_coord.h" #include "cutlass/fast_math.h" diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index 0647edfb..7f800e4c 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -29,18 +29,18 @@ * **************************************************************************************************/ /*! \file - \brief + \brief -This file contains definitions and utility functions for describing convolution problem sizes in terms of -activation (NHWC), filter (KRSC), output (NPQK), pading (pad_h, pad_w), stride (stride_h, stride_w), -dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map cutlass' implicit gemm -tensor extents, sizes, data types to that of convolutions extents, sizes, and data types. +This file contains definitions and utility functions for describing convolution problem sizes in terms of +activation (NHWC), filter (KRSC), output (NPQK), padding (pad_h, pad_w), stride (stride_h, stride_w), and +dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map CUTLASS's implicit gemm +tensor extents, sizes, and data types to that of the convolution's extents, sizes, and data types. * Mapping convolutions to Gemm computation * -Cutlass employs ImplicitGemm algorithm to implement convolutions. ImplicitGemm algorithm runs gemm operation -on convolution tensors Activation, Filter, and Output . The underlying gemm operation follows the standard -gemm definition: +Cutlass implements convolutions with the Implicit Gemm algorithm. This algorithm performs a gemm +(general matrix-matrix multiply) on the convolution tensors Activation, Filter, and Output. +The underlying gemm operation follows the standard gemm definition: C = A * B + C @@ -48,22 +48,23 @@ gemm definition: C is source and output matrix -For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped on -to convolution tensors Activation, Filter and Output as per the below table: +For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped +to convolution tensors Activation, Filter and Output as described in the table below. ___________________________________________________________________________ - ConvolutionalOperator | A | B | C + ConvolutionalOperator | A | B | C ___________________________________________________________________________ | | | | | - | Fprop | Activation | Filter | Output | - | Dgrad | Output | Filter | Activation | - | Wgrad | Output | Activation | Filter | + | Fprop | Activation | Filter | Output | + | Dgrad | Output | Filter | Activation | + | Wgrad | Output | Activation | Filter | ___________________________________________________________________________ -In convolution codebase, DO NOT mix using (A, B, C) with (Acvitation, Filter, Output). +In convolution codebase, DO NOT mix using (A, B, C) with (Activation, Filter, Output). -For example, a convolution class/function with A, B, Output is confusing and error-prone. Instead use below -mapping functions and adhere to using either A, B, C or Acvitation, Filter, Output. +For example, it's confusing and error prone to document a convolution class or function +as operating on "A, B, Output." Instead, use the mapping functions below, +and adhere to using either A, B, C or Activation, Filter, Output. Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap @@ -83,20 +84,20 @@ namespace conv { //////////////////////////////////////////////////////////////////////////////////////////////////// /// Convolutional operator -enum class Operator { - kFprop, - kDgrad, - kWgrad +enum class Operator { + kFprop, + kDgrad, + kWgrad }; -/// Distinguishes convolution from cross correlation -enum class Mode { - kCrossCorrelation, - kConvolution +/// Distinguishes convolution from cross correlation +enum class Mode { + kCrossCorrelation, + kConvolution }; /// Selects among several implementation variants trading off performance with simplicity -enum class IteratorAlgorithm { +enum class IteratorAlgorithm { kAnalytic, ///< functionally correct in all cases but lower performance kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize) @@ -113,9 +114,9 @@ enum class StrideSupport { }; /// Identifies split-K mode -enum class SplitKMode { - kNone, - kSerial, +enum class SplitKMode { + kNone, + kSerial, kParallel }; diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h index 9ffb05e7..82226dbd 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -99,7 +99,7 @@ struct DefaultConv2dFpropWithBroadcast { AlignmentB >::Kernel; - // Replace epilogue + // Define epilogue using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< ArchTag, typename ImplicitGemmBase::Epilogue::Shape, diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h index 00b8c909..d31e2efe 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -100,7 +100,7 @@ struct DefaultConv2dFpropWithReduction { AlignmentB >::Kernel; - // Replace epilogue + // Define epilogue using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithReductionTensorOp< ArchTag, typename ImplicitGemmBase::Epilogue::Shape, diff --git a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h index cdd89e00..885ac638 100644 --- a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -222,6 +222,138 @@ struct DefaultConv2dGroupFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and +/// 2 stage pipeline that supports all GroupMode. + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + GroupMode, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA, + GroupMode + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + GroupMode + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage /// pipeline that supports GroupMode::kSingleGroup. template < diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h index e667ddd6..f9cc0f97 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h @@ -303,7 +303,7 @@ class Conv2dFpropActivationTileAccessIteratorAnalytic { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h index fb1fcfc3..a9f80acc 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -388,7 +388,7 @@ class Conv2dFpropActivationTileAccessIteratorOptimized { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index 5c7dbd78..7cc576cd 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -290,7 +290,7 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index a85c6205..4d543d17 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -288,7 +288,7 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h index c72356be..7798c35c 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -268,7 +268,7 @@ class Conv2dWgradActivationTileAccessIteratorAnalytic { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h index 16cd2564..724609ce 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h @@ -304,7 +304,7 @@ class Conv2dWgradActivationTileAccessIteratorOptimized { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h index 97fd31ef..34b7085d 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -243,7 +243,7 @@ class Conv2dWgradOutputGradientTileAccessIteratorAnalytic { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h index 6725ed47..6362f6ad 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -293,7 +293,7 @@ class Conv2dWgradOutputGradientTileAccessIteratorOptimized { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h index d9fe9ada..a5144cf2 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -270,7 +270,7 @@ class Conv3dWgradActivationTileAccessIteratorAnalytic { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.C % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h index 2d56341d..6f2c6796 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -300,7 +300,7 @@ class Conv3dWgradActivationTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.C % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h index c21d3f9a..2f39c687 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -248,7 +248,7 @@ class Conv3dWgradOutputGradientTileAccessIteratorAnalytic { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.K % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h index 7a79983d..6970d561 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -291,7 +291,7 @@ class Conv3dWgradOutputGradientTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.K % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h index 80dc435c..437ae6c1 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -134,6 +134,12 @@ class ImplicitGemmMultistage : /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. + static bool const kStagedAccumulation = arch::UseStagedAccumulation::value; }; private: @@ -387,10 +393,7 @@ class ImplicitGemmMultistage : FragmentC tmp_accum; - if (platform::is_same::value - || platform::is_same::value) { + if (Detail::kStagedAccumulation) { tmp_accum.clear(); } @@ -444,10 +447,7 @@ class ImplicitGemmMultistage : copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - if (platform::is_same::value - || platform::is_same::value) { + if (Detail::kStagedAccumulation) { warp_mma( tmp_accum, warp_transformed_frag_A[warp_mma_k % 2], @@ -518,10 +518,7 @@ class ImplicitGemmMultistage : } - if (platform::is_same::value - || platform::is_same::value) { + if (Detail::kStagedAccumulation) { accum = plus_accum(accum, tmp_accum); } diff --git a/include/cutlass/conv/threadblock/threadblock_swizzle.h b/include/cutlass/conv/threadblock/threadblock_swizzle.h index 3cbcc8b5..4b886049 100644 --- a/include/cutlass/conv/threadblock/threadblock_swizzle.h +++ b/include/cutlass/conv/threadblock/threadblock_swizzle.h @@ -107,7 +107,7 @@ struct StridedDgradHorizontalThreadblockSwizzle : // compute number of tiles in m dimension int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); - // compute number of tiles in n dimension + // compute number of tiles in n dimension int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); return gemm::GemmCoord( @@ -148,7 +148,7 @@ struct StridedDgradIdentityThreadblockSwizzle : // compute number of tiles in m dimension int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); - // compute number of tiles in n dimension + // compute number of tiles in n dimension int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); return gemm::GemmCoord( diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index 4d154320..c0438a17 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -59,7 +59,7 @@ inline std::ostream &operator<<(std::ostream &out, dim3 d) { /// Output operator for CUDA built-in error type inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) return out << cudaGetErrorString(error); #endif } diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index 12bc3a37..ab7b6c8d 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -68,23 +68,26 @@ CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) #endif -#if !defined(__CUDACC_RTC__) +#ifdef _MSC_VER +// Provides support for alternative operators 'and', 'or', and 'not' +#include +#endif // _MSC_VER +#if !defined(__CUDACC_RTC__) #include +#endif - #if defined(__CUDA_ARCH__) - #if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } - #else - #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } - #endif - +#if defined(__CUDA_ARCH__) + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } #else - #if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) - #else - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) - #endif + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } + #endif +#else + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) + #else + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) #endif #endif @@ -160,7 +163,7 @@ static char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// // CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__) #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") diff --git a/include/cutlass/detail/dependent_false.hpp b/include/cutlass/detail/dependent_false.hpp new file mode 100644 index 00000000..aa77bb12 --- /dev/null +++ b/include/cutlass/detail/dependent_false.hpp @@ -0,0 +1,86 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +/// @brief A bool constant that depends on one or more template parameters. +/// +/// For more detailed documentation and use cases, +/// please see `dependent_false` below. +template +inline constexpr bool dependent_bool_value = Value; + +/// @brief An always-false value that depends on one or more template parameters. +/// +/// This exists because `static_assert(false);` always fails, +/// even if it occurs in the `else` branch of an `if constexpr`. +/// The following example shows how to use `dependent_false` in that case. +/// +/// @code +/// template +/// void foo (T t) +/// { +/// if constexpr (std::is_integral_v) { +/// do_integer_stuff(t); +/// } +/// else if constexpr (std::is_floating_point_v) { +/// do_floating_point_stuff(t); +/// } +/// else { +/// static_assert(dependent_false, "T must be " +/// "an integral or floating-point type."); +/// } +/// } +/// @endcode +/// +/// This implements the C++ Standard Library proposal P1830R1. +/// +/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf +/// +/// That proposal is under review as of 2022/12/05. +/// The following link shows P1830's current review status. +/// +/// https://github.com/cplusplus/papers/issues/572 +/// +/// P2593R0 proposes an alternate solution to this problem, +/// that would change the C++ language itself. +/// +/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html +/// +/// For headers in this library, however, we only consider library solutions +/// as work-arounds for future C++ features. +template +inline constexpr bool dependent_false = dependent_bool_value; + +} // end namespace cutlass::detail diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index 68042e3f..cde9f1ff 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -99,7 +99,11 @@ void Kernel2(typename Operator::Params params) { /// Generic CUTLASS kernel template. template -__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +__global__ +#ifdef __CUDACC__ +// Enclosing this in __CUDACC__ suppresses MSVC warnings. +__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +#endif // __CUDACC__ void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) { // Dynamic shared memory base pointer diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl new file mode 100644 index 00000000..f0d51bb4 --- /dev/null +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -0,0 +1,536 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_traits_sm90.hpp" +#include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/copy_traits_sm90.hpp" + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the smem layout atom to be used for C or D matrix +template +constexpr auto +sm90_get_epilogue_smem_swizzle_layout_atom() { + using namespace cute; + + // ColMajor C/D (M-major) + if constexpr (size<0>(GmemStrideType{}) == 1) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + // RowMajor C/D (N-major) + else if constexpr (size<1>(GmemStrideType{}) == 1) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported gmem layout."); + } +} + +// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. +template +constexpr auto +sm90_compute_tile_shape_or_override() { + if constexpr (cute::is_same_v) { + + constexpr int SmemAlloc = 4096; + if constexpr (detail::sm90_is_cooperative_v) { + constexpr int M = 128; + constexpr int N = SmemAlloc / (M * sizeof(Element)); + + return make_shape(Int{}, Int{}); + } + else if constexpr (detail::sm90_is_warp_specialized_v) { + constexpr int M = 64; + constexpr int N = SmemAlloc / (M * sizeof(Element)); + + return make_shape(Int{}, Int{}); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported schedule."); + } + } + else if constexpr (cute::is_tuple::value) { + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + + static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(M == 64 && detail::sm90_is_warp_specialized_v || + M == 128 && detail::sm90_is_cooperative_v, "Unsupported tile shape"); + static_assert(N % 8 == 0, "Unsupported tile shape"); + + return epi_tile; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType."); + } +} + +// Selects the largest vectorized smem store atom available +template +constexpr auto +sm90_get_smem_store_op_for_accumulator() { + using namespace cute; + + if constexpr (sizeof(ElementD) == 2 && size<0>(GmemStrideTypeD{}) == 1) { + return SM90_U16x8_STSM_T{}; + } + else if constexpr (sizeof(ElementD) == 2 && size<1>(GmemStrideTypeD{}) == 1) { + return SM90_U32x4_STSM_N{}; + } + else { + // auto-vectorizing store + return DefaultCopy{}; + } +} + +// Selects the largest vectorized smem load atom available +template +constexpr auto +sm90_get_smem_load_op_for_source() { + using namespace cute; + + // Reuse the logic from smem store selector + using SmemStoreOp = decltype(sm90_get_smem_store_op_for_accumulator()); + + if constexpr (cute::is_same_v) { + return SM75_U16x8_LDSM_T{}; + } + else if constexpr (cute::is_same_v) { + return SM75_U32x4_LDSM_N{}; + } + else { + // auto-vectorizing load + return DefaultCopy{}; + } +} + +// Helper for building TMA warp-specialized collective epilogues, specialized by +// the thread-level epilogue operation performed and the dispatch policy to use. +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class ThreadOp, + class DispatchPolicy +> +struct TmaBuilderImpl { + using GmemStrideTypeC = gemm::TagToStrideC_t; + using GmemStrideTypeD = gemm::TagToStrideC_t; + + using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule>()); + + using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< + DispatchPolicy, + TileShape_MNK, + EpilogueTile_MN, + ElementC, + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + ThreadOp, + SM90_TMA_LOAD, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + SM90_TMA_STORE, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()) + >; +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////// + +// No-smem builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + cute::enable_if_t>> { + + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? + thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + using ThreadOp = thread::LinearCombination< + ElementD, 1, ElementAccumulator, ElementCompute, + ScaleType, FloatRoundStyle::round_to_nearest, ElementC>; + + using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueDefault> + >; +}; + +// Tma warp-specialized builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + cute::enable_if_t || + cute::is_same_v >> { +public: + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? + thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + using ThreadOp = thread::LinearCombination< + ElementD, AlignmentD, ElementAccumulator, ElementCompute, + thread::ScaleType::Default, FloatRoundStyle::round_to_nearest, ElementC>; + +private: + using Impl = detail::TmaBuilderImpl< + TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute, + ElementC, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, + Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecialized<1,2,true>>; + +public: + using CollectiveOp = typename Impl::CollectiveOp; +}; + +// Auto builder +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + cute::enable_if_t>> { + +private: + static constexpr bool IsTmaAligned = cutlass::gemm::collective::detail::is_aligned< + ElementC, AlignmentC, ElementD, AlignmentD, cutlass::gemm::collective::detail::tma_alignment_bytes>(); + + // Current TMA epilogues require sixteen-bit data types and epilogue tile M to be of size 64. + // Only dispatch to the TMA builder if these requirements are satisfied. + static constexpr bool IsSixteenBit = sizeof_bits::value == 16 && sizeof_bits::value == 16; + static constexpr bool IsEpiTileM64 = size<0>(shape(TileShape_MNK{})) == 64; + + using _CollectiveBuilder = CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + cute::conditional_t + >; + +public: + using ThreadOp = typename _CollectiveBuilder::ThreadOp; + using CollectiveOp = typename _CollectiveBuilder::CollectiveOp; +}; + +// Tma warp-specialized builder for elementwise fusion +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + cute::enable_if_t || + cute::is_base_of_v >> { + +public: + using ThreadOp = thread::LinearCombinationGeneric< + Schedule::ActivationFunctor, + ElementD, AlignmentD, + ElementAccumulator, ElementCompute, Schedule::Scale, + Schedule::Round>; + +private: + using Impl = detail::TmaBuilderImpl< + TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute, + ElementC, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, + Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecialized<1,2,true>>; + +public: + using CollectiveOp = typename Impl::CollectiveOp; +}; + +// Tma warp-specialized builder for bias + elementwise fusion +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + cute::enable_if_t || + cute::is_base_of_v >> { + +public: + using ThreadOp = thread::LinearCombinationBiasElementwise< + ElementC, ElementAccumulator, ElementCompute, ElementD, typename Schedule::ElementT, AlignmentD, + typename Schedule::ActivationFunctor, typename Schedule::BiasOp, + Schedule::StoreT, typename Schedule::ElementBias>; + +private: + using Impl = detail::TmaBuilderImpl< + TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute, + ElementC, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, + Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise<1,2>>; + +public: + using CollectiveOp = typename Impl::CollectiveOp; +}; + +// CollectiveBuilder that transposed epilogue below is used for sm90 gmma RS TT kernels +// since swapping NNN kernels input matrix and transposing its output at the same time then +// we can get TTN kernel. +template < + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + cute::enable_if_t>> { + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? + thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + using ThreadOp = thread::LinearCombination< + ElementD, 1, ElementAccumulator, ElementCompute, + ScaleType, FloatRoundStyle::round_to_nearest, ElementC>; + + using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueTransposed> + >; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp new file mode 100644 index 00000000..d71b7a30 --- /dev/null +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify epilogue subtile shape or dispatch to automatic computation of subtile shape +struct EpilogueTileAuto {}; + +// Used to let the builder pick the epilogue schedule automatically. +// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +struct EpilogueScheduleAuto {}; + +template < + class ArchTag, + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class Enable = void +> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, + "Could not build a collective epilogue for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "builders/sm90_builder.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index 5b1b9245..37bb79b0 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -24,6 +24,8 @@ **************************************************************************************************/ #pragma once +#include + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::epilogue::collective { @@ -34,8 +36,8 @@ template < class DispatchPolicy, class... Args > -struct CollectiveEpilogue { - static_assert(std::is_void_v, "Could not find an epilogue specialization."); +class CollectiveEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -44,6 +46,10 @@ struct CollectiveEpilogue { ///////////////////////////////////////////////////////////////////////////////////////////////// +#include "detail.hpp" #include "default_epilogue.hpp" -#include "epilogue.hpp" +#include "epilogue_tensor_broadcast.hpp" +#include "sm70_epilogue_vectorized.hpp" +#include "sm90_epilogue_tma_warpspecialized.hpp" +#include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index 71499b5d..a4e612db 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -35,6 +35,8 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" #include "cute/tensor.hpp" #include "cute/numeric/int.hpp" @@ -52,13 +54,16 @@ namespace collective { template < class StrideC_, class StrideD_, - class ThreadEpilogueOp_ + class ThreadEpilogueOp_, + class EpilogueSchedule_ > class DefaultEpilogue { public: // // Type Aliases // + using EpilogueSchedule = EpilogueSchedule_; + // derived types of output thread level operator using ThreadEpilogueOp = ThreadEpilogueOp_; using ElementOutput = typename ThreadEpilogueOp::ElementOutput; @@ -78,28 +83,40 @@ class DefaultEpilogue { struct SharedStorage { }; - // Params of epilogue::collective contain the epilogue::thread params - struct Params { + // Host side epilgoue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; ElementC const* ptr_C = nullptr; StrideC dC{}; ElementD* ptr_D = nullptr; StrideD dD{}; - typename ThreadEpilogueOp::Params thread_params{}; }; + // Device side epilogue params + using Params = Arguments; + // // Methods // - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { - (void) workspace; - return {args.epilogue_params}; + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; } CUTLASS_HOST_DEVICE - DefaultEpilogue(Params const& params_) : params(params_) { } + DefaultEpilogue(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } template< class ProblemShapeMNKL, @@ -118,7 +135,7 @@ class DefaultEpilogue { TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, - char* smem_buf) + [[maybe_unused]] char* smem_buf) { using namespace cute; using X = Underscore; @@ -128,17 +145,17 @@ class DefaultEpilogue { static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - (void) smem_buf; - ThreadEpilogueOp epilogue_op{params.thread_params}; - // Separate out problem shape for convenience auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + // Represent the full output tensor - Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) @@ -184,6 +201,7 @@ class DefaultEpilogue { private: Params params; + ThreadEpilogueOp epilogue_op; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp new file mode 100644 index 00000000..033f5ccc --- /dev/null +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -0,0 +1,211 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/int.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +namespace detail { + +template +static constexpr int elements_per_access_v = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; + +template +static constexpr bool sm90_is_cooperative_v = + std::is_base_of_v; + +template +static constexpr bool sm90_is_warp_specialized_v = + std::is_base_of_v; + +template +struct EmptyStorage { + CUTLASS_HOST_DEVICE + T* data() { return nullptr; } +}; + +template +CUTLASS_HOST_DEVICE +auto get_epilogue_stride(Stride stride){ + if constexpr (cute::is_base_of_v) { + return cute::make_stride(cute::get<1>(stride), cute::get<0>(stride), cute::get<2>(stride)); + } + else { + return stride; + } +} + +template +struct IsThreadEpilogueOpWithBias { + static constexpr bool value = false; + using type = typename ThreadEpilogueOp::ElementCompute; +}; + +template +struct IsThreadEpilogueOpWithBias > { + static constexpr bool value = true; + using type = typename ThreadEpilogueOp::ElementBias; +}; + +// IF_EPILOGUE_USES_TMA::value will be true only if: +// class T has member CopyOpS2G and T::CopyOpS2G is true +template +struct IF_EPILOGUE_USES_TMA { static constexpr bool value = false; }; + +template +struct IF_EPILOGUE_USES_TMA > +{ static constexpr bool value = true; }; + +// Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels +template +class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { +public: + using LoadPipeline = cutlass::PipelineTransactionAsync<0>; // 0 stage to disable smem alloc + using LoadPipelineState = cutlass::PipelineState<0>; + constexpr static uint32_t TmaTransactionBytes = 0; + + using StorePipeline = cutlass::PipelineTmaStore<1>; // tma store pipe has no smem alloc + using StorePipelineState = cutlass::PipelineState<1>; + + using TensorStorage = typename EpilogueOp::SharedStorage; + using PipelineStorage = typename LoadPipeline::SharedStorage; + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment([[maybe_unused]] TileShapeMNK) { + return 1; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment([[maybe_unused]] TileShapeMNK) { + return 1; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) + { + } + + // ctor inheritance + using EpilogueOp::EpilogueOp; + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE void + load( + [[maybe_unused]] LoadPipeline load_pipeline, + [[maybe_unused]] LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, + [[maybe_unused]] TileShapeMNK tile_shape_MNK, + [[maybe_unused]] TileCoordMNKL tile_coord_mnkl, + [[maybe_unused]] TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + [[maybe_unused]] TensorStorage& shared_tensors) + { + // source load is performed in epilogue operator + } + + CUTLASS_DEVICE void + load_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + [[maybe_unused]] LoadPipelineState load_pipe_producer_state) + { + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE void + store( + [[maybe_unused]] LoadPipeline load_pipeline, + [[maybe_unused]] LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors) + { + constexpr int BLK_M_RANK = rank<0>(tile_shape_MNK); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); + })); + + constexpr int BLK_N_RANK = rank<1>(tile_shape_MNK); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); + })); + + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + (*this)( + problem_shape_mnkl, + tile_shape_MNK, + tile_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + reinterpret_cast(&shared_tensors)); + } + +}; + +} // namespace detail +} // namespace collective +} // namespace epilogue +} // namespace cutlass diff --git a/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp similarity index 50% rename from include/cutlass/epilogue/collective/default_transposed_epilogue.hpp rename to include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp index 7e38acd7..512ce1a8 100644 --- a/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp +++ b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -28,83 +28,117 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + /*! \file - \brief Functor performing elementwise operations used by epilogues. + \brief Functor for performing tensor-tensor broadacasts atop existing epilogues. + + Concretely, the opeartion performed is the following: + UnaryOp( + BinaryOp1( + BinaryOp0( + Activation((alpha * A @ B) + bias), + beta * C0 + ), + beta * C1 + ) + ) + + where: + - C0 and C1 have the same extents as the output + - BinaryOp0 and BinaryOp1 perform elementwise binary operations + - UnaryOp is an elementwise operation */ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" #include "cute/tensor.hpp" -#include "cute/numeric/int.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace epilogue { namespace collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -using namespace cute; - ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Applies an element wise operation to all elements within the fragment -/// and writes them out to destination storage. +/// Collective epilogue that applies elementwise tensor-tensor operations atop other epilogues +/// template < class StrideC_, class StrideD_, - class ThreadEpilogueOp_ + class ThreadEpilogueOp_, + class EpilogueSchedule_ > -class DefaultTransposedEpilogue { - +class EpilogueTensorBroadcast { public: // // Type Aliases // + using EpilogueSchedule = EpilogueSchedule_; + // derived types of output thread level operator using ThreadEpilogueOp = ThreadEpilogueOp_; using ElementOutput = typename ThreadEpilogueOp::ElementOutput; using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; using ElementCompute = typename ThreadEpilogueOp::ElementCompute; using ElementScalar = ElementCompute; + using ElementBias = typename ThreadEpilogueOp::ElementBias; using ElementC = typename ThreadEpilogueOp::ElementC; using StrideC = StrideC_; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; - - static const int kOutputAlignment = ThreadEpilogueOp::kCount; - using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor; static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static constexpr bool IsBinaryOp0Enabled = ThreadEpilogueOp::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = ThreadEpilogueOp::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = ThreadEpilogueOp::IsUnaryOpEnabled; + struct SharedStorage { }; - // Params of epilogue::collective contain the epilogue::thread params - struct Params { - ElementC const* ptr_C = nullptr; + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; StrideC dC{}; ElementD* ptr_D = nullptr; StrideD dD{}; - typename ThreadEpilogueOp::Params thread_params{}; + ElementBias* ptr_Bias = nullptr; + ElementC* ptr_C0 = nullptr; + ElementC* ptr_C1 = nullptr; }; + // Device side epilogue params + using Params = Arguments; + // // Methods // - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { - (void) workspace; - return {args.epilogue_params}; + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; } CUTLASS_HOST_DEVICE - DefaultTransposedEpilogue(Params const& params_) : params(params_) { } + EpilogueTensorBroadcast(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source0_needed() || epilogue_op.is_source1_needed(); + } template< class ProblemShapeMNKL, @@ -123,7 +157,7 @@ class DefaultTransposedEpilogue { TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, - char* smem_buf) + [[maybe_unused]] char* smem_buf) { using namespace cute; using X = Underscore; @@ -131,67 +165,75 @@ class DefaultTransposedEpilogue { static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); static_assert(is_static::value, "ThreadBlock tile shape must be static"); static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - (void) smem_buf; - ThreadEpilogueOp epilogue_op{params.thread_params}; + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4"); // Separate out problem shape for convenience auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); - // Tranpose stride C/D. - auto stride_c = make_stride(get<1>(params.dC), get<0>(params.dC), get<2>(params.dC)); - auto stride_d = make_stride(get<1>(params.dD), get<0>(params.dD), get<2>(params.dD)); + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + auto stride_bias = detail::get_epilogue_stride(Stride<_1, _0, _0>{}); // Represent the full output tensor - Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) - Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor mC0_mnl = make_tensor(make_gmem_ptr(params.ptr_C0), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mC1_mnl = make_tensor(make_gmem_ptr(params.ptr_C1), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), stride_bias); // (m,n,l) + + Tensor gC0_mnl = local_tile(mC0_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gC1_mnl = local_tile(mC1_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - // Slice to get the tile this CTA is responsible for + // Slice to get the tile this thread block is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gC0 = gC0_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gC1 = gC1_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) // Partition source and destination tiles to match the accumulator partitioning auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) - Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) - - static_assert(is_static::value, "Accumulator layout must be static"); - CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC0 = thr_mma.partition_C(gC0); // (VEC,THR_M,THR_N) + Tensor tCgC1 = thr_mma.partition_C(gC1); // (VEC,THR_M,THR_N) + Tensor tCgBias = thr_mma.partition_C(gBias); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, + "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC0) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgC1) == size(tCgD), "Source and destination must have the same number of elements."); CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), "Accumulator count must have the same destination element count."); + CUTE_STATIC_ASSERT_V(size(tCgBias) == size(accumulators), + "Accumulator count must have the same destination element count."); - auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); Tensor tCcD = thr_mma.partition_C(cD); - // source is needed - if (epilogue_op.is_source_needed()) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); - } - } - } - // source is not needed, avoid load - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators); ++i) { - if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - tCgD(i) = epilogue_op(accumulators(i)); - } + bool bias_needed = params.ptr_Bias != nullptr; + bool c0_needed = (params.ptr_C0 != nullptr) && epilogue_op.is_source0_needed(); + bool c1_needed = (params.ptr_C1 != nullptr) && epilogue_op.is_source1_needed(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + ElementBias bias = bias_needed ? tCgBias(i) : ElementBias(0); + ElementC c0 = c0_needed ? tCgC0(i) : ElementC(0); + ElementC c1 = c1_needed ? tCgC1(i) : ElementC(0); + + tCgD(i) = epilogue_op(accumulators(i), c0, c1, bias); } } } private: Params params; + ThreadEpilogueOp epilogue_op; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/epilogue.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp similarity index 92% rename from include/cutlass/epilogue/collective/epilogue.hpp rename to include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index 565e752e..509f3b94 100644 --- a/include/cutlass/epilogue/collective/epilogue.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -95,28 +95,40 @@ class Epilogue { cute::array_aligned> smem_epilogue; }; - // Params of epilogue::collective contain the epilogue::thread params - struct Params { + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; ElementC const* ptr_C = nullptr; StrideC dC{}; ElementD* ptr_D = nullptr; StrideD dD{}; - typename ThreadEpilogueOp::Params thread_params{}; }; + // Device side epilogue params + using Params = Arguments; + // // Methods // - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { - (void) workspace; - return {args.epilogue_params}; + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; } CUTLASS_HOST_DEVICE - Epilogue(Params const& params_) : params(params_) { }; + Epilogue(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } template< class ProblemShapeMNKL, @@ -147,13 +159,11 @@ class Epilogue { // synchronizing function for smem reads/writes #if CUDA_BARRIER_ENABLED - auto synchronize = [] () { NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0); }; + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0); }; #else auto synchronize = [] () { __syncthreads(); }; #endif - ThreadEpilogueOp epilogue_op{this->params.thread_params}; - // Separate out problem shape for convenience auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); @@ -175,7 +185,8 @@ class Epilogue { Tensor sC = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) // Partition sC to match the accumulator partitioning - auto tC = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma).get_thread_slice(thread_idx); + auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); + auto tC = tiled_r2s.get_thread_slice(thread_idx); Tensor tCaC = tC.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor tCsC = tC.partition_D(sC); // ((Atom,AtomNum),PIPE_M,PIPE_N) @@ -185,7 +196,8 @@ class Epilogue { Tensor gDt = local_tile(gD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) // Partition sC, gC, and gD for the output - auto tD = TiledCopyS2R{}.get_thread_slice(thread_idx); + auto tiled_s2r = TiledCopyS2R{}; + auto tD = tiled_s2r.get_thread_slice(thread_idx); Tensor tDsC = tD.partition_S(sC); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tDgC = tD.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) Tensor tDgD = tD.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) @@ -239,7 +251,7 @@ class Epilogue { int mma_m = step_m * size<1>(tCsC) + pipe_m; int mma_n = step_n * size<2>(tCsC) + pipe_n; - copy(tC, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); + copy(tiled_r2s, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); } } @@ -247,7 +259,7 @@ class Epilogue { synchronize(); // Step 3. Copy from SMEM into a fragment - copy(tD, tDsC, tDrC); + copy(tiled_s2r, tDsC, tDrC); // Step 4. Wait for SMEM reads to complete synchronize(); @@ -310,6 +322,7 @@ class Epilogue { private: Params params; + ThreadEpilogueOp epilogue_op; }; diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp new file mode 100644 index 00000000..5654597e --- /dev/null +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,582 @@ +/*************************************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + bool DisableSmemReuseC_, + class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) per-collective + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using DispatchPolicy = Sm90TmaWarpSpecialized; + using BlockTileShape = BlockTileShape_; + using EpilogueTile = EpilogueTile_; + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; + constexpr static bool iskThreadEpilogueOpWithBias = detail::IsThreadEpilogueOpWithBias::value; + using AlignmentType = typename uint_bit::value * kOutputAlignment>::type; + + static_assert(sizeof(ElementC) == 2, "Only 16b source supported for now"); + static_assert(sizeof(ElementD) == 2, "Only 16b output supported for now"); + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(BlockTileShape{}) == 3, "BlockTileShape must be rank-3: [BLK_M,BLK_N,BLK_K]"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M,EPI_TILE_N]"); + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + +private: + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default || + ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + // internal optimization to reuse C shared memory for storing D + using SmemLayoutAtomBitsC = decltype(downcast::value>(SmemLayoutAtomC{})); + using SmemLayoutAtomBitsD = decltype(downcast::value>(SmemLayoutAtomD{})); + constexpr static bool ReuseSmemC = not DispatchPolicy::DisableSmemReuseC && + is_source_supported && + sizeof(ElementC) == sizeof(ElementD) && + StrideC{} == StrideD{} && + cute::is_same_v; + + // Find the max contiguous layout usable by TMA (if EpilogueTile is a by-mode tiler) + using SmemLayoutTmaD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{})))), + cute::conditional_t(StrideD{}) == 1, Step<_2,_1>, Step<_1,_2>>{} )); + +public: + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(BlockTileShape{}), size<1>(BlockTileShape{}), Int{}), + cute::conditional_t(StrideC{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutTmaD{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t(StrideD{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} )); + + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof(ElementC)); + + // TMA pipeline for storing D + using StorePipeline = cutlass::PipelineTmaStore; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage : aligned_struct<128> { + cute::conditional_t, + array_aligned> smem_C; + alignas(128) cute::conditional_t, + array_aligned> smem_D; + } tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread; + ElementC const* ptr_C; + StrideC dC; + ElementD const* ptr_D; + StrideD dD; + }; + + // Device side epilgoue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(static_cast(nullptr), + repeat_like(StrideC{}, int32_t(0)), StrideC{}), + SmemLayoutC{}(_,_,0))); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(static_cast(nullptr), + repeat_like(StrideD{}, int32_t(0)), StrideD{}), + SmemLayoutTmaD{})); + + typename ThreadEpilogueOp::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) + { + // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC)); + Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD)); + typename Params::TMA_C tma_load_c = make_tma_copy( + CopyOpG2S{}, + tensor_c, + SmemLayoutC{}(_,_,0)); + typename Params::TMA_D tma_store_d = make_tma_copy( + CopyOpS2G{}, + tensor_d, + SmemLayoutTmaD{}); + return { + args.thread, + tma_load_c, + tma_store_d + }; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of C subtiles (currently always one) + constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutC{}); + constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutC{}); + + return epi_m * epi_n; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + if constexpr (ReuseSmemC) { + return get_load_pipe_increment(tile_shape_MNK); + } + + // Compute number of D subtiles + constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutD{}); + constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutD{}); + + return epi_m * epi_n; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& epilogue_params) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE void + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + TensorStorage& shared_tensors) { + using namespace cute; + using X = Underscore; + + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // Represent the full source tensor + Tensor mC_mnl = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (TILE_M,TILE_N,m,n,l) + // Slice to get the gmem tile of C (gC) this CTA is currently responsible for + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) + // Get the corresponding smem tile of C (sC) + Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), SmemLayoutC{}); // (TILE_M,TILE_N,PIPE) + + // Prepare the thread(b)lock (G)mem to (S)mem TMA copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC); // (TMA,TMA_M,TMA_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC); // (TMA,TMA_M,TMA_N,PIPE) + + auto* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + uint16_t mcast_mask = 0; + + // Execute the TMA load for C + if (warp_idx_in_warp_group == 0 and lane_predicate) { + load_pipeline.producer_acquire(load_pipe_producer_state); + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), bGS_gC, bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_commit(load_pipe_producer_state); + } + } + + CUTLASS_DEVICE void + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + if (warp_idx_in_warp_group == 0 and lane_predicate) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE void + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors) { + using namespace cute; + using X = Underscore; + + static_assert(is_rmem::value, "Accumulator must be RF resident."); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{}); + auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{}); + auto epi_tile_m = size<0>(shape(EpilogueTile{})); + auto epi_tile_n = size<1>(shape(EpilogueTile{})); + + // Represent the full output tensor + Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (TILE_M,TILE_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) + + // Construct the smem tensors for source (sC) and output (sD) + Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), // (TILE_M,TILE_N) + SmemLayoutC{})(_,_,load_pipe_consumer_state.index()); + Tensor bEsD = make_tensor(make_smem_ptr(shared_tensors.smem_D.data()), // (EPI_TILE_M,EPI_TILE_N,PIPE) + SmemLayoutD{}); + + // Tile thread(b)lock tensors by (E)pilogue output tile shape (bE) + Tensor bEsC = local_tile(sC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor bEgD = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Partition for register to smem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = conditional_return( + thread_r2s.partition_D(recast(bEsC)), // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + thread_r2s.partition_D(bEsD) ); // (R2S,R2S_M,R2S_N,PIPE) + + // Allocate register tensors + auto tRS_rD_shape = take<0,3>(shape(thread_r2s.partition_S(bEsD))); // (R2S,R2S_M,R2S_N) + Tensor tRS_rC = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view for thread epilogue op + Tensor tRS_rAcc_frg = recast(tRS_rAcc); + Tensor tRS_rC_frg = recast(tRS_rC); + Tensor tRS_rD_frg = recast(tRS_rD); + + // Partition for smem to register copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_r2s); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(bEsC); // (S2R,S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // Partition for smem to gmem copy (tSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor tSG_sD = conditional_return( + thrblk_s2g.partition_S(recast(bEsC)), // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + thrblk_s2g.partition_S(bEsD) ); // (S2G,S2G_M,S2G_N,PIPE) + Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + CUTE_STATIC_ASSERT(size<0,0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly"); + CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M"); + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 0); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + if (epilogue_op.is_source_needed()) { + // Wait for epilogue load to fill smem buffer with C + load_pipeline.consumer_wait(load_pipe_consumer_state); + } + + // Delay issue of TMA store by 1 iteration to achieve better instruction pipelining + PipelineState store_pipe_producer_state_prev = store_pipe_producer_state; + int epi_m_prev = 0, epi_n_prev = 0; + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(bEgD); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(bEgD); ++epi_m) { + // The current tile in accumulator + int mma_m = epi_m; + int mma_n = (epi_n * epi_tile_n) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // Elementwise operation with conversion + int r2s_v = epi_n * size(tRS_rD_frg); + if (epilogue_op.is_source_needed()) { + // Copy source tile to register from smem + copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = epilogue_op(tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i)); + } + } + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = epilogue_op(tRS_rAcc_frg_mn(r2s_v + i)); + } + } + + if constexpr (ReuseSmemC) { + // Issue the TMA store of the previous iteration + if (not (epi_m == 0 && epi_n == 0)) { + // Make sure smem writes are visible to TMA + cutlass::arch::fence_view_async_shared(); + synchronize(); // ensure all threads have issued their async fence + + // Write the tile to gmem from smem with TMA + if (issue_tma_store) { + copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); + } + } + + // Copy output tile to smem from register + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n)); + } + else { + // Issue the TMA store of the previous iteration + if (not (epi_m == 0 && epi_n == 0)) { + // Make sure smem writes are visible to TMA + cutlass::arch::fence_view_async_shared(); + synchronize(); // ensure all threads have issued their async fence + + // Write the tile to gmem from smem with TMA + if (issue_tma_store) { + copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); + store_pipeline.producer_commit(store_pipe_producer_state_prev); + } + } + + // Wait for a smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + // Copy tile to smem from register + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + + // Advance pipeline state + store_pipe_producer_state_prev = store_pipe_producer_state; + ++store_pipe_producer_state; + } + + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + } + + if constexpr (ReuseSmemC) { + // Fence and issue the TMA store of the last iteration + cutlass::arch::fence_view_async_shared(); + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); + } + + // Arrive and advance pipeline state + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for a smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + // Let dma warp know smem buffer is consumed and empty + if (epilogue_op.is_source_needed()) { + load_pipeline.consumer_release(store_pipe_producer_state); + } + } + else { + // Fence and issue the TMA store of the last iteration + cutlass::arch::fence_view_async_shared(); + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); + store_pipeline.producer_commit(store_pipe_producer_state_prev); + } + + // Let dma warp know smem buffer is consumed and empty + if (epilogue_op.is_source_needed()) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + } + } + +private: + Params const& params; + ThreadEpilogueOp epilogue_op; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp new file mode 100644 index 00000000..31e1973c --- /dev/null +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -0,0 +1,558 @@ +/*************************************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing pipelined epilogues with bias add and elementwise activation functions. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) per-collective + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + Sm90TmaWarpSpecializedBiasElementwise, + BlockTileShape_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; + using BlockTileShape = BlockTileShape_; + using EpilogueTile = EpilogueTile_; + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; + using ElementT = typename ThreadEpilogueOp::ElementT; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor; + using BinaryOp = typename ThreadEpilogueOp::BinaryOp; + + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + constexpr static bool StoreT = ThreadEpilogueOp::kStoreT; + constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; + static_assert(detail::IsThreadEpilogueOpWithBias::value, + "Epilogue dispatch policy Sm90TmaWarpSpecializedBiasElementwise requires the use of a thread-level epiogue that supports bias calculation"); + constexpr static bool iskThreadEpilogueOpWithBias = true; + using AlignmentType = typename uint_bit::value * kOutputAlignment>::type; + + static_assert(sizeof(ElementC) == 2, "Only 16b source supported for now"); + static_assert(sizeof(ElementD) == 2, "Only 16b output supported for now"); + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(BlockTileShape{}) == 3, "BlockTileShape must be rank-3: [BLK_M,BLK_N,BLK_K]"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M,EPI_TILE_N]"); + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + +private: + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default || + ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + // Find the max contiguous layout usable by TMA (if EpilogueTile is a by-mode tiler) + using SmemLayoutTmaD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{})))), + cute::conditional_t(StrideD{}) == 1, Step<_2,_1>, Step<_1,_2>>{} )); + +public: + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(BlockTileShape{}), size<1>(BlockTileShape{}), Int{}), + cute::conditional_t(StrideC{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutTmaD{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t(StrideD{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} )); + + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof(ElementC)); + + // TMA pipeline for storing D and T + using StorePipeline = cutlass::PipelineTmaStore; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage : aligned_struct<128> { + cute::conditional_t, + array_aligned> smem_C; + alignas(128) array_aligned smem_D; + alignas(128) cute::conditional_t, + array_aligned> smem_T; + } tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + ElementT* ptr_T = nullptr; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(static_cast(nullptr), repeat_like(StrideC{}, int32_t(0)), StrideC{}), + SmemLayoutC{}(_,_,0))); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(static_cast(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD_{}), + SmemLayoutTmaD{})); + using TMA_T = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(static_cast(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD{}), + SmemLayoutTmaD{})); + typename ThreadEpilogueOp::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + TMA_T tma_store_t; + ElementBias const* ptr_Bias = nullptr; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) { + // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC)); + Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD)); + typename Params::TMA_C tma_load_c = make_tma_copy( + CopyOpG2S{}, + tensor_c, + SmemLayoutC{}(_,_,0)); + typename Params::TMA_D tma_store_d = make_tma_copy( + CopyOpS2G{}, + tensor_d, + SmemLayoutTmaD{}); + typename Params::TMA_T tma_store_t = [&]() { + if constexpr (StoreT) { + Tensor tensor_t = make_tensor(args.ptr_T, make_layout(make_shape(M,N,L), args.dD)); + return make_tma_copy( + CopyOpS2G{}, + tensor_t, + SmemLayoutTmaD{}); + } + else { + return typename Params::TMA_T{}; + } + }(); + + return { + args.thread, + tma_load_c, + tma_store_d, + tma_store_t, + args.ptr_Bias + }; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of C subtiles (currently always one) + constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutC{}); + constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutC{}); + + return epi_m * epi_n; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of D subtiles + constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutD{}); + constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutD{}); + + return epi_m * epi_n; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& epilogue_params) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + if constexpr (StoreT) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_t.get_tma_descriptor()); + } + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE void + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + TensorStorage& shared_tensors) { + using namespace cute; + using X = Underscore; + + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // Represent the full source tensor + Tensor mC_mnl = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (TILE_M,TILE_N,m,n,l) + // Slice to get the gmem tile of C (gC) this CTA is currently responsible for + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) + // Get the corresponding smem tile of C (sC) + Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), SmemLayoutC{}); // (TILE_M,TILE_N,PIPE) + + // Prepare the thread(b)lock (G)mem to (S)mem TMA copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC); // (TMA,TMA_M,TMA_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC); // (TMA,TMA_M,TMA_N,PIPE) + + auto* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + uint16_t mcast_mask = 0; + + // Execute the TMA load for C + if (warp_idx_in_warp_group == 0 and lane_predicate) { + load_pipeline.producer_acquire(load_pipe_producer_state); + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), bGS_gC, bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_commit(load_pipe_producer_state); + } + } + + CUTLASS_DEVICE void + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + if (warp_idx_in_warp_group == 0 and lane_predicate) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE void + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors) { + using namespace cute; + using X = Underscore; + + static_assert(is_rmem::value, "Accumulator must be RF resident."); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{}); + auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{}); + auto epi_tile_m = size<0>(shape(EpilogueTile{})); + auto epi_tile_n = size<1>(shape(EpilogueTile{})); + + // Represent the full output tensor + Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1, _1, X>{}); // (TILE_M,TILE_N,m,n,l) + Tensor mT_mnl = params.tma_store_t.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) + Tensor gT_mnl = local_tile(mT_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1, _1, X>{}); // (TILE_M,TILE_N,m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), Stride<_1, _0, _0>{}); // (m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1,X>{}); // (TILE_M,TILE_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) + Tensor gT = gT_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) + + // Construct the smem tensors for source (sC) and output (sD) + Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), // (TILE_M,TILE_N) + SmemLayoutC{})(_,_,load_pipe_consumer_state.index()); + Tensor bEsD = make_tensor(make_smem_ptr(shared_tensors.smem_D.data()), // (EPI_TILE_M,EPI_TILE_N,PIPE) + SmemLayoutD{}); + Tensor bEsT = make_tensor(make_smem_ptr(shared_tensors.smem_T.data()), // (EPI_TILE_M,EPI_TILE_N,PIPE) + SmemLayoutD{}); + + // Tile thread(b)lock tensors by (E)pilogue output tile shape (bE) + Tensor bEsC = local_tile(sC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor bEgD = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor bEgT = local_tile(gT, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor bEgBias = local_tile(gBias, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Partition for register to smem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(bEsD); // (R2S,R2S_M,R2S_N,PIPE) + Tensor tRS_sT = thread_r2s.partition_D(bEsT); // (R2S,R2S_M,R2S_N,PIPE) + + // Allocate register tensors + auto tRS_rD_shape = take<0,3>(shape(thread_r2s.partition_S(bEsD))); // (R2S,R2S_M,R2S_N) + Tensor tRS_rC = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) + Tensor tRS_rT = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) + + Tensor tRS_gBias = thread_r2s.partition_S(bEgBias); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + Tensor tRS_rBias = make_tensor(take<0,3>(shape(tRS_gBias))); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view for thread epilogue op + Tensor tRS_rAcc_frg = recast(tRS_rAcc); + Tensor tRS_rC_frg = recast(tRS_rC); + Tensor tRS_rD_frg = recast(tRS_rD); + Tensor tRS_rT_frg = recast(tRS_rT); + Tensor tRS_rBias_frg = recast(tRS_rBias); + + // Partition for smem to register copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_r2s); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(bEsC); // (S2R,S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // Partition for smem to gmem copy (tSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor tSG_sD = thrblk_s2g.partition_S(bEsD); // (S2G,S2G_M,S2G_N,PIPE) + Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + ThrCopy thrblk_s2g_t = params.tma_store_t.get_slice(Int<0>{}); + Tensor tSG_sT = thrblk_s2g_t.partition_S(bEsT); // (S2G,S2G_M,S2G_N,PIPE) + Tensor tSG_gT = thrblk_s2g_t.partition_D(bEgT); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + CUTE_STATIC_ASSERT(size<0,0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly"); + CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M"); + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 0); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + if (epilogue_op.is_source_needed()) { + // Wait for epilogue load to fill smem buffer with C + load_pipeline.consumer_wait(load_pipe_consumer_state); + } + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(bEgD); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(bEgD); ++epi_m) { + // The current tile in accumulator + int mma_m = epi_m; + int mma_n = (epi_n * epi_tile_n) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // Copy bias to registers from gmem + copy(tRS_gBias(_,_,_,epi_m,epi_n), tRS_rBias); + + // Elementwise operation with conversion + int r2s_v = epi_n * size(tRS_rD_frg); + if (epilogue_op.is_source_needed()) { + // Copy source tile to registers from smem + copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i), tRS_rBias_frg(i)); + } + } + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rBias_frg(i)); + } + } + + // Wait for a smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + // Copy tile to smem from register + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + + if constexpr (StoreT) { + copy(tiled_r2s, tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index())); + } + + // Make sure smem writes are visible to TMA + cutlass::arch::fence_view_async_shared(); + synchronize(); // ensure all threads have issued their async fence + + // Write the tile to gmem from smem with TMA + if (issue_tma_store) { + copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state.index()), tSG_gD(_,_,_,epi_m,epi_n)); + if constexpr (StoreT) { + copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state.index()), tSG_gT(_,_,_,epi_m,epi_n)); + } + store_pipeline.producer_commit(store_pipe_producer_state); + } + + // Advance pipeline state + ++store_pipe_producer_state; + } + } + + // Let dma warp know smem buffer is consumed and empty + if (epilogue_op.is_source_needed()) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + } + +private: + Params const& params; + ThreadEpilogueOp epilogue_op; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index de318d53..c3fb61ef 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -24,16 +24,113 @@ **************************************************************************************************/ #pragma once +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" + ////////////////////////////////////////////////////////////////////////////// namespace cutlass::epilogue { ////////////////////////////////////////////////////////////////////////////// +// Epilogue schedule types that can be used for categorical dispatch +struct NoSmemWarpSpecialized {}; +struct TmaWarpSpecialized {}; +struct TmaWarpSpecializedCooperative {}; + +struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; +struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; + +template < + template class ActivationFunctor_, + thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, + FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest +> +struct TmaWarpSpecializedElementwise : public TmaWarpSpecializedElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + static constexpr thread::ScaleType::Kind Scale = Scale_; + static constexpr FloatRoundStyle Round = Round_; +}; + +template < + template class ActivationFunctor_, + thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, + FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest +> +struct TmaWarpSpecializedCooperativeElementwise : public TmaWarpSpecializedCooperativeElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + static constexpr thread::ScaleType::Kind Scale = Scale_; + static constexpr FloatRoundStyle Round = Round_; +}; + +struct TmaWarpSpecializedBiasElementwiseBase : public TmaWarpSpecialized{}; +struct TmaWarpSpecializedCooperativeBiasElementwiseBase : public TmaWarpSpecializedCooperative {}; + +template < + template class ActivationFunctor_, + class ElementT_, + template class BiasOp_, + bool StoreT_, + class ElementBias_ +> +struct TmaWarpSpecializedBiasElementwise : public TmaWarpSpecializedBiasElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + using ElementT = ElementT_; + + template + using BiasOp = BiasOp_; + + static constexpr bool StoreT = StoreT_; + using ElementBias = ElementBias_; +}; + +template < + template class ActivationFunctor_, + class ElementT_, + template class BiasOp_, + bool StoreT_, + class ElementBias_ +> +struct TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedCooperativeBiasElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + + using ElementT = ElementT_; + + template + using BiasOp = BiasOp_; + + static constexpr bool StoreT = StoreT_; + using ElementBias = ElementBias_; +}; + // // Collective Epilogue Policies // +template< + int StagesC_, + int StagesD_, + bool DisableSmemReuseC_ +> +struct Sm90TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool DisableSmemReuseC = DisableSmemReuseC_; +}; + +template< + int StagesC_, + int StagesD_ +> +struct Sm90TmaWarpSpecializedBiasElementwise { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; +}; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue diff --git a/tools/library/scripts/pycutlass/src/cpp/cute.cpp b/include/cutlass/epilogue/thread/detail.hpp similarity index 72% rename from tools/library/scripts/pycutlass/src/cpp/cute.cpp rename to include/cutlass/epilogue/thread/detail.hpp index 8995159e..5f4aa079 100644 --- a/tools/library/scripts/pycutlass/src/cpp/cute.cpp +++ b/include/cutlass/epilogue/thread/detail.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -28,27 +28,25 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/* \file - \brief binding CuTe C++ APIs to Python +/*! \file + \brief Utilities for thread-level epilogues */ -#include -#include +#pragma once -#include "cute/arch/mma_sm90_gmma.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// -namespace py = pybind11; +namespace cutlass { +namespace epilogue { +namespace thread { +namespace detail { -PYBIND11_MODULE(cute, m) { +/// Class used to identify cases in which no operation is performed +template +struct NoOp {}; - // module doc - m.doc() = "CuTe C++ bindings"; - - py::enum_(m, "GMMAMajor", - R"pbdoc(classification of CuTe GMMA tensor major specification)pbdoc") - .value("K", cute::GMMA::Major::K, - R"pbdoc(Tensor is contiguous in reduction dimension)pbdoc") - .value("MN", cute::GMMA::Major::MN, - R"pbdoc(Tensor is contiguous in non-reduction dimension)pbdoc"); -} +} // namespace detail +} // namespace thread +} // namespace epilogue +} // namespace cutlass diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index 0c4b3849..918f3301 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -52,7 +52,7 @@ namespace thread { /// Applies a linear combination operator to an array of elements. /// -/// D = alpha * accumulator + beta * source + uniform +/// D = alpha * accumulator + beta * source /// template < typename ElementOutput_, ///< Data type used to load and store tensors @@ -69,6 +69,7 @@ class LinearCombination { public: using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; using ElementC = ElementSource_; @@ -77,14 +78,15 @@ class LinearCombination { static int const kCount = Count; static const ScaleType::Kind kScale = Scale; using FragmentOutput = Array; + using FragmentSource = Array; using FragmentAccumulator = Array; - using ComputeFragment = Array; + using FragmentCompute = Array; - using ParamsBase = LinearCombinationParams; static FloatRoundStyle const kRound = Round; /// Host-constructable parameters structure - struct Params : ParamsBase{ + struct Params + { ElementCompute alpha; ///< scales accumulators ElementCompute beta; ///< scales source tensor ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory @@ -92,10 +94,6 @@ class LinearCombination { CUTLASS_HOST_DEVICE Params(): - ParamsBase( - ElementCompute(1), - ElementCompute(0) - ), alpha(ElementCompute(1)), beta(ElementCompute(0)), alpha_ptr(nullptr), @@ -106,14 +104,12 @@ class LinearCombination { ElementCompute alpha, ElementCompute beta ): - ParamsBase(alpha, beta), alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute alpha ): - ParamsBase(alpha, ElementCompute(0)), alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { } CUTLASS_HOST_DEVICE @@ -121,28 +117,13 @@ class LinearCombination { ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr ): - ParamsBase(*alpha_ptr, *beta_ptr), alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr ): - ParamsBase(*alpha_ptr, ElementCompute(0)), alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { } - - CUTLASS_HOST_DEVICE - Params( - ParamsBase const& base - ): ParamsBase(base), alpha_ptr(nullptr), beta_ptr(nullptr) { - #if defined(__CUDA_ARCH__) - alpha = reinterpret_cast(base.alpha_data); - beta = reinterpret_cast(base.beta_data); - #else - memcpy( alpha, base.alpha_data, sizeof(ElementCompute) ); - memcpy( beta, base.alpha_data, sizeof(ElementCompute) ); - #endif - } }; private: @@ -183,30 +164,73 @@ class LinearCombination { } } - /// Computes linear scaling: D = alpha * accumulator + beta * source + /// Computes intermediate: X = beta * source + CUTLASS_HOST_DEVICE + FragmentCompute compute_intermediate( + FragmentSource const &source) const { + + // Convert source to internal compute numeric type + NumericArrayConverter source_converter; + FragmentCompute converted_source = source_converter(source); + + if (Scale == ScaleType::NoBetaScaling) { + return converted_source; + } + else { + multiplies mul_source; + return mul_source(beta_, converted_source); + } + } + + /// Computes linear scaling with intermediate: D = alpha * accumulator + X + CUTLASS_HOST_DEVICE + FragmentOutput with_intermediate( + FragmentAccumulator const& accumulator, + FragmentCompute const& intermediate) const { + + // Convert accumulator to internal compute numeric type + NumericArrayConverter accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + if (Scale == ScaleType::Nothing) { + return destination_converter(converted_accumulator); + } else { + // Perform binary operations + multiply_add mul_add_accumulator; + FragmentCompute computed_output = mul_add_accumulator(alpha_, converted_accumulator, intermediate); + + return destination_converter(computed_output); + } + } + + /// Computes linear scaling with source: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( - FragmentAccumulator const &accumulator, - FragmentOutput const &source) const { + FragmentAccumulator const &accumulator, + FragmentSource const &source) const { - // Convert source to interal compute numeric type - NumericArrayConverter source_converter; + // Convert source to internal compute numeric type + NumericArrayConverter source_converter; NumericArrayConverter accumulator_converter; // Convert to destination numeric type NumericArrayConverter destination_converter; - ComputeFragment converted_source = source_converter(source); - ComputeFragment converted_accumulator = accumulator_converter(accumulator); + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); if (Scale == ScaleType::Nothing) return destination_converter(converted_accumulator); // Perform binary operations - ComputeFragment intermediate; + FragmentCompute intermediate; - multiplies mul_add_source; - multiply_add mul_add_accumulator; + multiplies mul_add_source; + multiply_add mul_add_accumulator; if (Scale == ScaleType::NoBetaScaling) intermediate = converted_source; @@ -221,7 +245,7 @@ class LinearCombination { /// Computes linear scaling: D = alpha * accumulator CUTLASS_HOST_DEVICE FragmentOutput operator()( - FragmentAccumulator const &accumulator) const { + FragmentAccumulator const &accumulator) const { // Convert source to interal compute numeric type NumericArrayConverter accumulator_converter; @@ -229,14 +253,14 @@ class LinearCombination { // Convert to destination numeric type NumericArrayConverter destination_converter; - ComputeFragment converted_accumulator = accumulator_converter(accumulator); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); if (Scale == ScaleType::Nothing) return destination_converter(converted_accumulator); // Perform binary operations - ComputeFragment intermediate; - multiplies mul_accumulator; + FragmentCompute intermediate; + multiplies mul_accumulator; intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index d145bfa7..b1ea759f 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -42,6 +42,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -90,7 +91,13 @@ class LinearCombinationBiasElementwise { using FragmentZ = Array; using FragmentT = Array; + // Definitions needed for collective epilogue + using FragmentSource = FragmentC; using FragmentOutput = FragmentZ; + using ElementBias = ElementVector; + using FragmentBias = FragmentCompute; + using ActivationFunctor = ElementwiseOp; + static const ScaleType::Kind kScale = ScaleType::Default; static bool const kIsHeavy = ElementwiseOp::kIsHeavy; @@ -196,8 +203,8 @@ class LinearCombinationBiasElementwise { /// Applies the operation when is_source_needed() is true CUTLASS_HOST_DEVICE void operator()( - FragmentZ &frag_Z, - FragmentT &frag_T, + FragmentZ &frag_Z, + FragmentT &frag_T, FragmentAccumulator const &AB, FragmentC const &frag_C, FragmentCompute const &V) const { @@ -227,8 +234,8 @@ class LinearCombinationBiasElementwise { /// Applies the operation when is_source_needed() is false CUTLASS_HOST_DEVICE void operator()( - FragmentZ &frag_Z, - FragmentT &frag_T, + FragmentZ &frag_Z, + FragmentT &frag_T, FragmentAccumulator const &AB, FragmentCompute const &V) const { diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index fdfe171d..b0d445e0 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -87,6 +87,7 @@ class LinearCombinationClamp { using FragmentOutput = Array; using FragmentAccumulator = Array; using ComputeFragment = Array; + using FragmentSource = Array; static FloatRoundStyle const kRound = Round; diff --git a/include/cutlass/epilogue/thread/linear_combination_dgelu.h b/include/cutlass/epilogue/thread/linear_combination_dgelu.h index d026a8c3..a8254629 100644 --- a/include/cutlass/epilogue/thread/linear_combination_dgelu.h +++ b/include/cutlass/epilogue/thread/linear_combination_dgelu.h @@ -35,7 +35,7 @@ #pragma once -#include +#include "cutlass/half.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/array.h" diff --git a/include/cutlass/epilogue/thread/linear_combination_drelu.h b/include/cutlass/epilogue/thread/linear_combination_drelu.h index f05da6d8..44522d2d 100644 --- a/include/cutlass/epilogue/thread/linear_combination_drelu.h +++ b/include/cutlass/epilogue/thread/linear_combination_drelu.h @@ -34,7 +34,7 @@ #pragma once -#include +#include "cutlass/half.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/array.h" diff --git a/include/cutlass/epilogue/thread/linear_combination_generic.h b/include/cutlass/epilogue/thread/linear_combination_generic.h index 71ada3ff..9a762ae0 100644 --- a/include/cutlass/epilogue/thread/linear_combination_generic.h +++ b/include/cutlass/epilogue/thread/linear_combination_generic.h @@ -78,6 +78,7 @@ class LinearCombinationGeneric { using FragmentOutput = Array; using FragmentAccumulator = Array; + using FragmentSource = Array; using FragmentCompute = Array; static FloatRoundStyle const kRound = Round; diff --git a/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h b/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h index ebee6b48..9b8044cc 100644 --- a/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h @@ -72,6 +72,7 @@ class LinearCombinationLeakyRelu { using FragmentOutput = Array; using FragmentAccumulator = Array; using ComputeFragment = Array; + using FragmentSource = Array; static FloatRoundStyle const kRound = Round; diff --git a/include/cutlass/epilogue/thread/linear_combination_params.h b/include/cutlass/epilogue/thread/linear_combination_params.h index a3f825e0..4a64f64f 100644 --- a/include/cutlass/epilogue/thread/linear_combination_params.h +++ b/include/cutlass/epilogue/thread/linear_combination_params.h @@ -56,13 +56,13 @@ struct LinearCombinationParams { LinearCombinationParams(ElementCompute alpha, ElementCompute beta) : alpha_data {0lu, 0lu}, beta_data {0lu, 0lu} { - #if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) reinterpret_cast(alpha_data) = alpha; reinterpret_cast(beta_data) = beta; - #else +#else memcpy( alpha_data, &alpha, sizeof(ElementCompute) ); memcpy( beta_data, &beta, sizeof(ElementCompute) ); - #endif +#endif } }; diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index eb1b436c..8c9b7f4b 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -34,7 +34,7 @@ #pragma once -#include +#include "cutlass/half.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/array.h" @@ -90,6 +90,7 @@ class LinearCombinationRelu { using FragmentAccumulator = Array; using FragmentCompute = Array; using FragmentScaleBias = Array; + using FragmentSource = Array; static FloatRoundStyle const kRound = Round; @@ -321,6 +322,7 @@ class LinearCombinationRelu { using FragmentAccumulator = Array; using FragmentCompute = Array; using FragmentScaleBias = Array; + using FragmentSource = Array; static FloatRoundStyle const kRound = Round; diff --git a/include/cutlass/epilogue/thread/linear_combination_relu0.h b/include/cutlass/epilogue/thread/linear_combination_relu0.h index 3cffd93c..31a281d5 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu0.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu0.h @@ -37,7 +37,7 @@ #pragma once -#include +#include "cutlass/half.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/array.h" @@ -93,6 +93,7 @@ class LinearCombinationRelu0 { using FragmentAccumulator = Array; using FragmentCompute = Array; using FragmentScaleBias = Array; + using FragmentSource = Array; static FloatRoundStyle const kRound = Round; @@ -308,6 +309,7 @@ class LinearCombinationRelu0 { using FragmentAccumulator = Array; using FragmentCompute = Array; using FragmentScaleBias = Array; + using FragmentSource = Array; static FloatRoundStyle const kRound = Round; diff --git a/include/cutlass/epilogue/thread/linear_combination_residual_block.h b/include/cutlass/epilogue/thread/linear_combination_residual_block.h index 42d14662..ddb564b3 100644 --- a/include/cutlass/epilogue/thread/linear_combination_residual_block.h +++ b/include/cutlass/epilogue/thread/linear_combination_residual_block.h @@ -38,6 +38,7 @@ #include "cutlass/array.h" #include "cutlass/functional.h" #include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/detail.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -45,14 +46,6 @@ namespace cutlass { namespace epilogue { namespace thread { -namespace detail { - -/// Dummy class used to designate that the second binary operator in the epilogue is unsued -template -class NoOp {}; - -} - /// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2)) template +CUTLASS_HOST_DEVICE +bool is_binary_op_source_needed(ElementCompute scale) { + if constexpr (cute::is_same_v>) { + return false; + } + else if constexpr (cute::is_same_v> || cute::is_same_v>) { + // Cases for binary operators for which 0 is an identity element + if constexpr (Scale == ScaleType::NoBetaScaling) return true; + + if constexpr (Scale == ScaleType::OnlyAlphaScaling) return false; + + if constexpr (Scale == ScaleType::Nothing) return false; + + return scale != ElementCompute(0); + } + + return true; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** Compute a tensor-tensor broadcast epilogue. + * + * @param ElementOutput_ Data type used to load and store tensors + * @param ElementAccumulator_ Accumulator data type + * @param ElementCompute_ Data type used to compute linear combination + * @param ElementBias_ Data type of Bias elements + * @param ActivationFunctor_ Fused Activation + * @param BinaryOp0_ Binary operation to perform on O0 and C0. detail::NoOp means no operation + * @param BinaryOp1_ Binary operation to perform on O1 and C1. detail::NoOp means no operation + * @param UnaryOp_ Unary operation to perform on final result + * @param Scale Controls the type of Alpha and Beta scaling to perform + * @param Round How values should be rounded in conversions + * @param ElementSource_ Data type used for source operands + * + * Computes the following: + * O0 = alpha * accumulator + bias + * O1 = BinaryOp0(O0, beta * C0) + * O2 = BinaryOp1(O1, beta * C1) + * D = UnaryOp(O2) + */ +template < + class ElementOutput_, + class ElementAccumulator_ = ElementOutput_, + class ElementCompute_ = ElementOutput_, + class ElementBias_ = ElementCompute_, + template class ActivationFunctor_ = Identity, + template class BinaryOp0_ = plus, + template class BinaryOp1_ = detail::NoOp, + template class UnaryOp_ = Identity, + ScaleType::Kind Scale = ScaleType::Default, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, + class ElementSource_ = ElementOutput_ +> +class LinearCombinationTensorBroadcast { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementBias = ElementBias_; + using ElementC = ElementSource_; + using ElementD = ElementOutput_; + using ElementScalingFactor = ElementAccumulator_; + + using UnaryOp = UnaryOp_; + using BinaryOp0 = BinaryOp0_; + using BinaryOp1 = BinaryOp1_; + using ActivationFunctor = ActivationFunctor_; + + static constexpr int kCount = 1; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentBias = Array; + + static constexpr FloatRoundStyle kRound = Round; + using NoOpType = detail::NoOp; + static constexpr bool IsBinaryOp0Enabled = !cute::is_same_v; + static constexpr bool IsBinaryOp1Enabled = !cute::is_same_v; + static constexpr bool IsUnaryOpEnabled = !cute::is_same_v && !cute::is_same_v>; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha{}; ///< scales accumulators + ElementCompute beta{}; ///< scales source tensor + ElementCompute const* alpha_ptr = nullptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const* beta_ptr = nullptr; ///< pointer to source scalar - if not null, loads it from memory + + // + // Methods + // + Params() = default; + + CUTLASS_HOST_DEVICE + Params(ElementCompute const* alpha_ptr, ElementCompute const* beta_ptr) + : alpha_ptr(alpha_ptr), + beta_ptr(beta_ptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute const* alpha_ptr) + : alpha_ptr(alpha_ptr) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha, + ElementCompute beta) + : alpha(alpha), + beta(beta) {} + }; + +private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationTensorBroadcast(Params const& params) + : alpha_(params.alpha_ptr ? *params.alpha_ptr : params.alpha), + beta_(params.beta_ptr ? *params.beta_ptr : params.beta) {} + + /// Returns true if source 0 is needed + CUTLASS_HOST_DEVICE + bool is_source0_needed() const { + return detail::is_binary_op_source_needed(beta_); + } + + /// Returns true if source 1 is needed + CUTLASS_HOST_DEVICE + bool is_source1_needed() const { + return detail::is_binary_op_source_needed(beta_); + } + + // + // Specialization for scalar + // + CUTLASS_HOST_DEVICE + ElementD operator()(ElementAccumulator const accumulator, ElementC const source0, ElementC source1, ElementBias const bias) { + // Convert everything to Compute type, do compute, and then store to output type + NumericConverter accumulator_converter; + NumericConverter bias_converter; + NumericConverter source_converter; + NumericConverter destination_converter; + + ActivationFunctor act; + multiplies mul; + multiply_add madd; + + ElementCompute intermediate = accumulator_converter(accumulator); + intermediate = madd(alpha_, intermediate, bias_converter(bias)); + intermediate = act(intermediate); + + // Apply BinaryOp0, if needed + if constexpr (IsBinaryOp0Enabled) { + BinaryOp0 bin0; + ElementCompute converted_source = source_converter(source0); + intermediate = bin0(intermediate, mul(beta_, converted_source)); + } + + // Apply BinaryOp1, if needed + if constexpr (IsBinaryOp1Enabled) { + BinaryOp1 bin1; + ElementCompute converted_source = source_converter(source1); + intermediate = bin1(intermediate, mul(beta_, converted_source)); + } + + // Apply UnaryOp, if needed + if constexpr (IsUnaryOpEnabled) { + UnaryOp unary; + intermediate = unary(intermediate); + } + + return destination_converter(intermediate); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h index aac19b00..4df811e0 100644 --- a/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_with_elementwise.h @@ -35,7 +35,7 @@ #pragma once -#include +#include "cutlass/half.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/array.h" diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index 685b6bb4..c8b3c3bf 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -71,7 +71,7 @@ template < typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) typename Element_, ///< Element data type bool ScatterD = false, ///< Scatter D operand or not - typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not + typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not bool UseCUDAStore = false > class PredicatedTileIterator { @@ -93,6 +93,8 @@ class PredicatedTileIterator { static int const kThreads = ThreadMap::kThreads; static int const kIterations = ThreadMap::Count::kTile; + static bool constexpr PermuteD = !layout::is_trivial_permute; + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); @@ -202,11 +204,9 @@ class PredicatedTileIterator { /// Scatter indices int const *indices_; - /// Whether to perform Permute Op - bool PermuteD; /// PermuteDLayout - mutable PermuteDLayout permute_layout_; - + PermuteDLayout permute_layout_; + // // Static asserts about internal strides // @@ -237,7 +237,8 @@ class PredicatedTileIterator { TensorCoord threadblock_offset = TensorCoord(), int const *indices = nullptr ): - params_(params), indices_(indices) + params_(params), indices_(indices), + permute_layout_(PitchLinearCoord(extent.column(), extent.row()), params_.stride * kElementsPerAccess / sizeof(AccessType)) { TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; @@ -276,17 +277,7 @@ class PredicatedTileIterator { } // store_byte_pointer_ is set to be the same with byte_pointer_ unless PermuteD is used. - store_byte_pointer_ = byte_pointer_; - - // Initialize PermuteD. If PermuteD is true, store_byte_pointer_ is initialized accordingly. - if (platform::is_same::value) { - PermuteD = false; - }else{ - PermuteD = true; - store_byte_pointer_ = reinterpret_cast(pointer); - permute_layout_ = PermuteDLayout(extent, - params_.stride * kElementsPerAccess / sizeof(AccessType)); - } + store_byte_pointer_ = PermuteD ? reinterpret_cast(pointer) : byte_pointer_; // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; @@ -411,18 +402,17 @@ class PredicatedTileIterator { for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; - - int col_offset = column * ThreadMap::Delta::kColumn; if (PermuteD) { + + int col_offset = column * ThreadMap::Delta::kColumn; + int col = col_offset + thread_start_column_; int row = row_offset + thread_start_row_; - TensorCoord init_coord(row, col); - // Locate memory_pointer memory_pointer = reinterpret_cast(byte_pointer + byte_offset - + permute_layout_(init_coord) * sizeof(AccessType) / kElementsPerAccess); + + permute_layout_(PitchLinearCoord(col, row)) * sizeof(AccessType) / kElementsPerAccess); } if (UseCUDAStore) { diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h index ff7f659c..360d5f27 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h @@ -249,17 +249,21 @@ class SharedLoadIteratorMixed { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for int32_t x 16 => int8_t/int4b_t x 16 +/// Partial specialization for +/// int32_t x 16 => int8_t/int4b_t x 16 and +/// float x 16 => float_e4m3_t/float_e5m2_t x 16 template < - typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) + typename Element_, int OutputSizeBits_ ///< Size of output element in bits > -class SharedLoadIteratorMixed { +class SharedLoadIteratorMixed { public: using ThreadMap = ThreadMap_; using Shape = typename ThreadMap::Shape; - using Element = int32_t; + using Element = Element_; + static_assert(sizeof_bits::value == 32, "Element size in bits must be 32."); using Layout = layout::RowMajor; using TensorRef = TensorRef; @@ -414,17 +418,21 @@ class SharedLoadIteratorMixed int8_t/int4b_t x 8 +/// Partial specialization for: +/// int32_t x 8 => int8_t/int4b_t x 8 and +/// float x 8 => float_e4m3_t/float_e5m2_t x 8 template < - typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) + typename Element_, int OutputSizeBits_ > -class SharedLoadIteratorMixed { +class SharedLoadIteratorMixed { public: using ThreadMap = ThreadMap_; using Shape = typename ThreadMap::Shape; - using Element = int32_t; + using Element = Element_; + static_assert(sizeof_bits::value == 32, "Element size in bits must be 32."); using Layout = layout::RowMajor; using TensorRef = TensorRef; diff --git a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h index 74911b81..bda94740 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -716,7 +716,6 @@ class TileIteratorTensorOpMixed +// FP8 types are available starting CUDA 11.8+ +#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) +#define CUDA_FP8_ENABLED 1 +#endif -#include "cutlass/cutlass.h" +#if defined(__CUDA_ARCH__) +# if (__CUDA_ARCH__ >= 900) +# if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) +# define CUDA_PTX_FP8_CVT_ENABLED 1 +# endif // (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) +# endif // (__CUDA_ARCH__ >= 900) +#endif // defined(__CUDA_ARCH__) + +#ifdef __GNUC__ +// Ignore checks on reinterpret-casts that are being used for bitcasts. +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// #if defined(__CUDACC_RTC__) @@ -53,20 +69,14 @@ #include #endif -/////////////////////////////////////////////////////////////////////////////////////////////////// - -#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -#ifndef CUDA_PTX_FP8_CVT_ENABLED -#define CUDA_PTX_FP8_CVT_ENABLED 1 -#endif -#endif +#ifdef CUDA_FP8_ENABLED +#include #endif +#include -#ifdef __GNUC__ -// Ignore checks on reinterpret-casts that are being used for bitcasts. -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif +#include "cutlass/cutlass.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -435,20 +445,16 @@ struct alignas(1) float_e4m3_t : float8_base { // Methods // - /// Default constructor - CUTLASS_HOST_DEVICE - float_e4m3_t() : Base() { } + /// Constructor inheritance + using Base::Base; - /// Reinterpret cast from CUDA's FP8 type +#ifdef CUDA_FP8_ENABLED + /// Conversion from CUDA's FP8 type CUTLASS_HOST_DEVICE - float_e4m3_t(float_e4m3_t const& x) { - #if defined(__CUDA_ARCH__) - storage = reinterpret_cast(x); - #else - uint8_t raw = x.storage; - std::memcpy(&storage, &raw, sizeof(storage)); - #endif + explicit float_e4m3_t(__nv_fp8_e4m3 x) { + storage = x.__x; } +#endif /// Floating point conversion CUTLASS_HOST_DEVICE @@ -475,17 +481,14 @@ struct alignas(1) float_e4m3_t : float8_base { CUTLASS_HOST_DEVICE explicit float_e4m3_t(float_e5m2_t x); - /// Assignment +#ifdef CUDA_FP8_ENABLED + /// Assignment from CUDA's FP8 type CUTLASS_HOST_DEVICE - float_e4m3_t & operator=(float_e4m3_t const &x) { - #if defined(__CUDA_ARCH__) - storage = reinterpret_cast(x); - #else - uint8_t raw = x.storage; - std::memcpy(&storage, &raw, sizeof(storage)); - #endif + float_e4m3_t & operator=(__nv_fp8_e4m3 x) { + storage = x.__x; return *this; } +#endif /// Converts to float CUTLASS_HOST_DEVICE @@ -561,7 +564,6 @@ struct alignas(1) float_e4m3_t : float8_base { return int(storage & Base::FP8_MANTISSA_MASK); } }; - /////////////////////////////////////////////////////////////// /// /// floating-point 8 type : E5M2 @@ -645,20 +647,16 @@ struct alignas(1) float_e5m2_t : float8_base { // Methods // - /// Default constructor - CUTLASS_HOST_DEVICE - float_e5m2_t() : Base() { } + /// Constructor inheritance + using Base::Base; - /// Reinterpret cast from CUDA's FP8 type +#ifdef CUDA_FP8_ENABLED + /// Conversion from CUDA's FP8 type CUTLASS_HOST_DEVICE - float_e5m2_t(float_e5m2_t const& x) { - #if defined(__CUDA_ARCH__) - storage = reinterpret_cast(x); - #else - uint8_t raw = x.storage; - std::memcpy(&storage, &raw, sizeof(storage)); - #endif + explicit float_e5m2_t(__nv_fp8_e5m2 x) { + storage = x.__x; } +#endif /// Floating point conversion CUTLASS_HOST_DEVICE @@ -685,17 +683,14 @@ struct alignas(1) float_e5m2_t : float8_base { CUTLASS_HOST_DEVICE explicit float_e5m2_t(float_e4m3_t x); - /// Assignment +#ifdef CUDA_FP8_ENABLED + /// Assignment from CUDA's FP8 type CUTLASS_HOST_DEVICE - float_e5m2_t & operator=(float_e5m2_t const &x) { - #if defined(__CUDA_ARCH__) - storage = reinterpret_cast(x); - #else - uint8_t raw = x.storage; - std::memcpy(&storage, &raw, sizeof(storage)); - #endif + float_e5m2_t & operator=(__nv_fp8_e5m2 x) { + storage = x.__x; return *this; } +#endif /// Converts to float CUTLASS_HOST_DEVICE @@ -771,7 +766,6 @@ struct alignas(1) float_e5m2_t : float8_base { return int(storage & Base::FP8_MANTISSA_MASK); } }; - /////////////////////////////////////////////////////////////////////////////////////////////////// // // Arithmetic operators diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index f12515f2..b8b79e31 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -44,6 +44,11 @@ #include #endif // defined(CUTLASS_ARCH_WMMA_ENABLED) +#ifdef _MSC_VER +// Provides support for alternate operators such as 'and', 'or', ... +#include +#endif // _MSC_VER + namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -148,7 +153,11 @@ template struct maximum_with_nan_propogation { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { +#if defined(__CUDA_ARCH__) + return lhs > rhs or isnan(lhs) ? lhs : rhs; +#else return lhs > rhs or std::isnan(lhs) ? lhs : rhs; +#endif } }; @@ -159,6 +168,8 @@ struct maximum_with_nan_propogation { float res; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); +#elif defined(__CUDA_ARCH__) + res = lhs > rhs or isnan(lhs) ? lhs : rhs; #else res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; #endif diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 580d5ca2..9f922e3c 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -59,12 +59,11 @@ constexpr int sm90_smem_capacity_bytes = 232448; // Maps 2.x A matrix layout tag to respective GMMA major mode enum template constexpr cute::GMMA::Major -tag_to_gmma_major_A() { - // MN major mode is only valid for non-TF32 and non-int MMAs +gmma_ss_tag_to_major_A() { + // MN major mode is only valid for non-TF32, non-int if constexpr (cutlass::gemm::detail::is_mn_major_A() && - not std::is_same_v && - not std::is_same_v && - not std::is_same_v) { + not cute::is_same_v && + sizeof(ElementA) != 1) { return cute::GMMA::Major::MN; } else { @@ -75,12 +74,23 @@ tag_to_gmma_major_A() { // Maps 2.x B matrix layout tag to respective GMMA major mode enum template constexpr cute::GMMA::Major -tag_to_gmma_major_B() { - // MN major mode is only valid for non-TF32 and non-int MMAs +gmma_ss_tag_to_major_B() { + // MN major mode is only valid for non-TF32, non-int if constexpr (cutlass::gemm::detail::is_mn_major_B() && - not std::is_same_v && - not std::is_same_v && - not std::is_same_v) { + not cute::is_same_v && + sizeof(ElementB) != 1) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +template +constexpr cute::GMMA::Major +gmma_rs_tag_to_major_A() { + // MN major mode is only valid for non-TF32 and non-int MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_A()) { return cute::GMMA::Major::MN; } else { @@ -88,10 +98,21 @@ tag_to_gmma_major_B() { } } +template +constexpr cute::GMMA::Major +gmma_rs_tag_to_major_B() { + // MN major mode is only valid for non-TF32 and non-int MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_B()) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} // Maps a rank-1 cute::Shape<> representing the cluster shape on to the TMA atom that should be used with it template constexpr auto -cluster_shape_to_tma_atom(UnimodalClusterShape unimodal_cluster_shape) { +sm90_cluster_shape_to_tma_atom(UnimodalClusterShape unimodal_cluster_shape) { static_assert(cute::rank(unimodal_cluster_shape) == 1, "Use this function to figure out TMA for each mode individually."); @@ -113,6 +134,7 @@ make_cp_async_gmem_tiled_copy() { // Maximize the number of threads along the gmem major mode to promote coalesced reads // While making sure our thread layout tiles the threadblock tile evenly + if constexpr (cutlass::gemm::detail::is_k_major()) { // K major thread layout for K major gmem constexpr int threads_major = TileSizeK / Alignment; @@ -140,89 +162,333 @@ make_cp_async_gmem_tiled_copy() { Layout,_1>>{}); } else { - static_assert(std::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); + static_assert(cute::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); } } + // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. -template +template constexpr int -compute_stage_count_or_override(int KernelSmemCarveout = 0) { - if constexpr (std::is_same_v) { - // 32 bytes to account for barriers etc. - constexpr int stage_barrier_bytes = 32; - constexpr int a_bytes = static_cast(sizeof(ElementA)); - constexpr int b_bytes = static_cast(sizeof(ElementB)); - constexpr int stage_bytes = - (a_bytes * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + - (b_bytes * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + - stage_barrier_bytes; - - return (CapacityBytes - KernelSmemCarveout) / stage_bytes; +compute_stage_count_or_override(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(cute::integral_constant stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(StageCountAutoCarveout stage_count) { + // 32 bytes to account for barriers etc. + constexpr int stage_barrier_bytes = 32; + constexpr int a_bytes = static_cast(sizeof(ElementA)); + constexpr int b_bytes = static_cast(sizeof(ElementB)); + constexpr int stage_bytes = + (a_bytes * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + (b_bytes * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + stage_barrier_bytes; + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +// Helper for SS GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the optimal GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +constexpr auto +rs_smem_selector() { + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + if constexpr (major == GMMA::Major::MN) { + if constexpr (sizeof(ElementType) == 4){ + if constexpr (is_ws_transposed_B) { + // only optimized transpositionB(SW32 and SW128 for tf32) can be used, but prefer SW32 due to free bank conflict + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_SW32_Atom{})"); + } + } + else { + // Fall into SW32 due to free bank conflict + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + } + // Used for int8, fp16 and bf16 I/O kernels + else if constexpr (sizeof(ElementType) == 1 || sizeof(ElementType) == 2) { + if constexpr (sizeof(ElementType) == 1 && is_ws_transposed_B) { + // Only optimized transpositionB (SW32 for int8) can be used + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_128_Atom{})"); + } + } + else { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + } + else { + static_assert(cutlass::detail::dependent_false, "Smem selector does not support this element type"); + } } - else { - return StageCountType::value; + else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + +// Helper for SS GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +CUTE_HOST_DEVICE constexpr +auto +ss_smem_selector() +{ + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + + + if constexpr (major == GMMA::Major::MN) { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } } + else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + +template +constexpr bool +is_input_size_two_bytes() { + return (sizeof(ElementA) == 2 && sizeof(ElementB) == 2); } -// Kernel policy selection logic: auto dispatches to KernelTmaWarpSpecialized for now. Subject to change. +template +constexpr bool +is_use_rmem_A() { + constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); + constexpr bool IsLayoutAkBk = cutlass::gemm::detail::is_k_major_A() && + cutlass::gemm::detail::is_k_major_B(); + constexpr bool IsUseRmemA = !IsInputSizeTwoBytes && !IsLayoutAkBk; + return IsUseRmemA; +} + +template +constexpr bool +is_swapAB(){ + constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); + constexpr bool IsLayoutAkBmn = cutlass::gemm::detail::is_k_major_A() && + cutlass::gemm::detail::is_mn_major_B(); + constexpr bool SwapAB = !IsInputSizeTwoBytes && IsLayoutAkBmn; + return SwapAB; +} + +template +constexpr bool +is_aligned() { + return ((sizeof(ElementA) * AlignmentA) % RequiredAlignment == 0) && + ((sizeof(ElementB) * AlignmentB) % RequiredAlignment == 0); +} + +template +constexpr bool +is_warpspecialized_transpose_B(){ + constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); + constexpr bool IsLayoutAmnBmn = cutlass::gemm::detail::is_mn_major_A() && + cutlass::gemm::detail::is_mn_major_B(); + constexpr bool IsWarpSpecialized = cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v; + constexpr bool IsWarpSpecializedTransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn && IsWarpSpecialized; + return IsWarpSpecializedTransposeB; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS template < class ElementA, + class GmemLayoutA, + int AlignmentA, class ElementB, + class GmemLayoutB, + int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType, class KernelScheduleType > -constexpr auto -generate_gmma_dispatch_policy() { - if constexpr (std::is_base_of_v or - std::is_same_v) { - constexpr int PipelineStages = compute_stage_count_or_override< - sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK, StageCountType>(); - - if constexpr (std::is_same_v or - std::is_same_v) { - return MainloopSm90TmaGmmaWarpSpecialized{}; - } - else { - static_assert(sizeof(ElementA) == 0, "Invalid kernel schedule type."); - } - } +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + not detail::is_use_rmem_A()> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); - else if constexpr (std::is_base_of_v) { - // For the persistent kernel, assume that the epilogue uses 1 MN tile worth of smem - constexpr int EpilogueTileCarveout = sizeof(ElementAccumulator) * - (size<0>(TileShape_MNK{}) * size<1>(TileShape_MNK{})); - constexpr int PipelineStages = compute_stage_count_or_override< - sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK, StageCountType>(EpilogueTileCarveout); + // For fp32 types, map to tf32 MMA value type + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; - if constexpr (std::is_same_v) { - return MainloopSm90TmaGmmaWarpSpecialized{}; - } - else { - static_assert(sizeof(ElementA) == 0, "Invalid kernel schedule type."); - } - } + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - else if constexpr (std::is_base_of_v) { - constexpr int PipelineStages = compute_stage_count_or_override< - sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK, StageCountType>(); + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; - return MainloopSm90TmaGmma{}; - } + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); - else { - static_assert(sizeof(ElementA) == 0, "Invalid kernel schedule type."); - } -} + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); -} // namespace detail + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; ///////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA_TMA_SS +// GMMA_TMA_WS_RS template < class ElementA, class GmemLayoutA, @@ -250,45 +516,138 @@ struct CollectiveBuilder< ClusterShape_MNK, StageCountType, KernelScheduleType, - std::enable_if_t< - // TMA requires alignment be 16 bytes - ((sizeof(ElementA) * AlignmentA) % detail::tma_alignment_bytes == 0) && - ((sizeof(ElementB) * AlignmentB) % detail::tma_alignment_bytes == 0) && - not std::is_same_v && - // dispatch TN tf32 and int8 kernels only to TMA builder - ((sizeof(ElementA) == 2 && sizeof(ElementB) == 2) || - (cutlass::gemm::detail::is_k_major_A() && - cutlass::gemm::detail::is_k_major_B()))> + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + detail::is_use_rmem_A()> > { static_assert(is_static::value); static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool SwapAB = detail::is_swapAB(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B< + ElementA, GmemLayoutA, ElementB, GmemLayoutB, KernelScheduleType>(); + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; + + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); - #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(sizeof(ElementA) == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); - #endif + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_SS +template < + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t && + not detail::is_use_rmem_A()> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif // For fp32 types, map to tf32 MMA value type - using MmaElementA = std::conditional_t, tfloat32_t, ElementA>; - using MmaElementB = std::conditional_t, tfloat32_t, ElementB>; + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; - static constexpr cute::GMMA::Major GmmaMajorA = detail::tag_to_gmma_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::tag_to_gmma_major_B(); + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>())); - using GmemTiledCopyA = decltype(detail::cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector< - GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})) - >()); - using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector< - GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})) - >()); + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmma; - using DispatchPolicy = decltype(detail::generate_gmma_dispatch_policy< - MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>()); + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; using CollectiveOp = CollectiveMma< DispatchPolicy, @@ -300,18 +659,20 @@ struct CollectiveBuilder< TiledMma, GmemTiledCopyA, SmemLayoutAtomA, - void, // GMMA_SS does not need an SmemCopyAtom + SmemCopyAtomA, cute::identity, GmemTiledCopyB, SmemLayoutAtomB, - void, // GMMA_SS does not need an SmemCopyAtom + SmemCopyAtomB, cute::identity >; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA_CpAsync_SS +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_CpAsync template < class ElementA, class GmemLayoutA, @@ -339,34 +700,24 @@ struct CollectiveBuilder< ClusterShape_MNK, StageCountType, KernelScheduleType, - std::enable_if_t< - // Even if we could build a TMA kernel, let the user override and use cp_async instead - std::is_same_v || - // But always guard against invalid TMA alignments and dispatch to cp_async - ((sizeof(ElementA) * AlignmentA) % detail::tma_alignment_bytes != 0) || - ((sizeof(ElementB) * AlignmentB) % detail::tma_alignment_bytes != 0) || - // dispatch non-TN tf32 and int8 kernels only to cp_async builder - ((sizeof(ElementA) != 2 || sizeof(ElementB) != 2) && - (not cutlass::gemm::detail::is_k_major_A() || - not cutlass::gemm::detail::is_k_major_B()))> + cute::enable_if_t< + cute::is_same_v> > { static_assert(is_static::value); static_assert(is_static::value); - - #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(sizeof(ElementA) == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); - #endif +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif // For fp32 types, map to tf32 MMA value type - using MmaElementA = std::conditional_t, tfloat32_t, ElementA>; - using MmaElementB = std::conditional_t, tfloat32_t, ElementB>; + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; - static_assert((sizeof(ElementA) * AlignmentA) % detail::cp_async_min_alignment_bytes == 0 and - (sizeof(ElementB) * AlignmentB) % detail::cp_async_min_alignment_bytes == 0, + static_assert(detail::is_aligned(), "Minimum alignment required for cp.async is 4B."); - static constexpr cute::GMMA::Major GmmaMajorA = detail::tag_to_gmma_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::tag_to_gmma_major_B(); + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>())); @@ -374,21 +725,17 @@ struct CollectiveBuilder< using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< 128, ElementA, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< 128, ElementB, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector< - GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})) - >()); - - using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector< - GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})) - >()); + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); static constexpr int PipelineStages = detail::compute_stage_count_or_override< - detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB, TileShape_MNK, StageCountType>(); + detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB, TileShape_MNK>(StageCountType{}); using CollectiveOp = CollectiveMma< MainloopSm90CpAsyncGmma, @@ -400,17 +747,86 @@ struct CollectiveBuilder< TiledMma, GmemTiledCopyA, SmemLayoutAtomA, - void, // GMMA_SS does not need an SmemCopyAtom + void, cute::identity, GmemTiledCopyB, SmemLayoutAtomB, - void, // GMMA_SS does not need an SmemCopyAtom + void, cute::identity >; }; ///////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA auto kernel schedule +template < + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + +static constexpr bool IsTmaWarpSpecialized = detail::is_aligned< + ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(); + +#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))) + // Cooperative schedule performs best for CUDA Toolkits with version >= 12.1 + + // For TileShape_M == 64, choosing KernelTmaWarpSpecialized as the KernelSchedule + // Since KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128 + using KernelWarpSpecializedSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + KernelTmaWarpSpecialized, KernelTmaWarpSpecializedCooperative>; +#else + using KernelWarpSpecializedSchedule = KernelTmaWarpSpecialized; +#endif + + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + cute::conditional_t + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index 3cd68a41..3c0aa15d 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -40,7 +40,11 @@ namespace cutlass::gemm::collective { // Used to specify stage counts or dispatch to automatic computation of stage count template struct StageCount { static constexpr int value = num_stages; }; -struct StageCountAuto {}; + +template +struct StageCountAutoCarveout { static constexpr int bytes = carveout_bytes; }; + +using StageCountAuto = StageCountAutoCarveout<0>; // Used to automatically let the builder pick the kernel schedule. // Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index a2a90675..2a0ba6da 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -30,6 +30,8 @@ **************************************************************************************************/ #pragma once +#include "cutlass/detail/dependent_false.hpp" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { @@ -54,7 +56,7 @@ template < class TransformB > struct CollectiveMma { - static_assert(sizeof(ElementA) == 0, "Could not find a mainloop specialization."); + static_assert(cutlass::detail::dependent_false == 0, "Could not find a mainloop specialization."); }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -67,5 +69,6 @@ struct CollectiveMma { #include "sm80_mma_multistage.hpp" #include "sm90_mma_multistage_gmma_ss.hpp" #include "sm90_mma_tma_gmma_ss.hpp" +#include "sm90_mma_tma_gmma_rs_warpspecialized.hpp" #include "sm90_mma_tma_gmma_ss_warpspecialized.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp index 11e5515a..ffe1ea6d 100644 --- a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp +++ b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp @@ -121,24 +121,28 @@ struct CollectiveMma< cute::array_aligned> smem_b; }; - struct Params { + // Host side kernel arguments + struct Arguments { ElementA const* ptr_A; StrideA dA; ElementB const* ptr_B; StrideB dB; }; + // Device side kernel params + using Params = Arguments; + // // Methods // CollectiveMma() = default; - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { (void) workspace; - return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + return args; } /// Perform a threadblock-scoped matrix multiply-accumulate @@ -360,24 +364,28 @@ struct CollectiveMma< cute::array_aligned> smem_b; }; - struct Params { + // Host side kernel arguments + struct Arguments { ElementA const* ptr_A; StrideA dA; ElementB const* ptr_B; StrideB dB; }; + // Device side kernel params + using Params = Arguments; + // // Methods // CollectiveMma() = default; - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { (void) workspace; - return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + return args; } /// Perform a threadblock-scoped matrix multiply-accumulate diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp index 6ba6ccc0..dc98823c 100644 --- a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp +++ b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -124,24 +124,28 @@ struct CollectiveMma< cute::array_aligned> smem_b; }; - struct Params { + // Host side kernel arguments + struct Arguments { ElementA const* ptr_A; StrideA dA; ElementB const* ptr_B; StrideB dB; }; + // Device side kernel params + using Params = Arguments; + // // Methods // CollectiveMma() = default; - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { (void) workspace; - return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + return args; } /// Perform a collective-scoped matrix multiply-accumulate @@ -409,24 +413,28 @@ struct CollectiveMma< cute::array_aligned> smem_b; }; - struct Params { + // Host side kernel arguments + struct Arguments { ElementA const* ptr_A; StrideA dA; ElementB const* ptr_B; StrideB dB; }; + // Device side kernel params + using Params = Arguments; + // // Methods // CollectiveMma() = default; - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { (void) workspace; - return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + return args; } /// Perform a collective-scoped matrix multiply-accumulate diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp index 3b1921b9..57c07995 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp @@ -32,7 +32,7 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/pipeline.hpp" +#include "cutlass/pipeline/pipeline.hpp" #include "cute/arch/cluster_sm90.hpp" #include "cutlass/arch/reg_reconfig.h" @@ -120,8 +120,8 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(std::is_base_of::value && - std::is_base_of::value, + static_assert(cute::is_base_of::value && + cute::is_base_of::value, "MMA atom must source both A and B operand from smem_desc for this mainloop."); struct SharedStorage @@ -130,24 +130,26 @@ struct CollectiveMma< cute::array_aligned> smem_b; }; - struct Params { + struct Arguments { ElementA const* ptr_A; StrideA dA; ElementB const* ptr_B; StrideB dB; }; + using Params = Arguments; + // // Methods // CollectiveMma() = default; - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { (void) workspace; - return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + return args; } /// Perform a collective-scoped matrix multiply-accumulate @@ -180,13 +182,13 @@ struct CollectiveMma< static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(std::is_same::value, + static_assert(cute::is_same::value, "SM90 warpgroup MMA must specify transforms through MMA_Atom."); - static_assert(std::is_same::value, + static_assert(cute::is_same::value, "SM90 warpgroup MMA must specify transforms through MMA_Atom."); - static_assert(std::is_same::value, + static_assert(cute::is_same::value, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(std::is_same::value, + static_assert(cute::is_same::value, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); SharedStorage& storage = *reinterpret_cast(smem_buf); @@ -353,8 +355,8 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(std::is_base_of::value && - std::is_base_of::value, + static_assert(cute::is_base_of::value && + cute::is_base_of::value, "MMA atom must source both A and B operand from smem_desc for this mainloop."); struct SharedStorage @@ -363,24 +365,26 @@ struct CollectiveMma< cute::array_aligned> smem_b; }; - struct Params { + struct Arguments { ElementA const* ptr_A; StrideA dA; ElementB const* ptr_B; StrideB dB; }; + using Params = Arguments; + // // Methods // CollectiveMma() = default; - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { (void) workspace; - return {args.ptr_A, args.dA, args.ptr_B, args.dB}; + return args; } /// Perform a collective-scoped matrix multiply-accumulate @@ -413,13 +417,13 @@ struct CollectiveMma< static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(std::is_same::value, + static_assert(cute::is_same::value, "SM90 warpgroup MMA must specify transforms through MMA_Atom."); - static_assert(std::is_same::value, + static_assert(cute::is_same::value, "SM90 warpgroup MMA must specify transforms through MMA_Atom."); - static_assert(std::is_same::value, + static_assert(cute::is_same::value, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(std::is_same::value, + static_assert(cute::is_same::value, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); SharedStorage& storage = *reinterpret_cast(smem_buf); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp new file mode 100644 index 00000000..a80a6dbd --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -0,0 +1,608 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop that source A operand from registers +template < + int Stages, + class ClusterShape, + class KernelSchedule, + int PipelineAsyncMmaStages, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaRmemAWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + + // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only (e.g. tf32, Fp32, Int8 WGMMA) + static constexpr bool IsLayoutAkBmn = + cute::is_same_v, layout::RowMajor> && + cute::is_same_v, layout::RowMajor>; + + static constexpr bool IsInputSizeTwoBytes = sizeof(ElementA) == 2 && sizeof(ElementB) == 2; + static constexpr bool SwapAB = !IsInputSizeTwoBytes && IsLayoutAkBmn; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemCopyAtomA = cute::conditional_t; + using InternalSmemCopyAtomB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync< + DispatchPolicy::Stages, + typename DispatchPolicy::ClusterShape>; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(tile_to_shape( + InternalSmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + Step<_2,_1,_3>{})); + using SmemLayoutB = decltype(tile_to_shape( + InternalSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + Step<_2,_1,_3>{})); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major only (e.g. tf32, fp32, fp8, int8). + static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, layout::ColumnMajor> && + cute::is_same_v, layout::RowMajor>; + static constexpr bool TransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + + using GmmaSmemLayoutAtomB = decltype(transform::collective::detail::gmma_smem_transpose_or_passthrough< + TransposeB, InternalSmemLayoutAtomB, InternalElementB>()); + + // SmemLayoutB for GMMA is different from SmemLayoutB for TMA if TransposeB + using GmmaSmemLayoutB = decltype(tile_to_shape( + GmmaSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + Step<_2,_1,_3>{})); + + static_assert(!SwapAB || !TransposeB, "Cannot SwapAB and TransposeB at the same time."); + static_assert(TransposeB || (cute::is_same_v), + "Should be same layout if not TransposeB."); + static_assert(!TransposeB || size<1>(SmemLayoutB{}) * sizeof(InternalElementB) == 128, + "SmemLayoutB K must be 128bytes to be transposed."); + static_assert(!transform::collective::detail::use_universal_transposition(), + "Warp specialized ARF kernels have not supported universal B transposition yet."); + static_assert(!TransposeB || shape<0>(TileShape{}) == 64, "Optimized transpose RS kernel requires TileShape M = 64."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + if constexpr (SwapAB) { + M = get<1>(problem_shape_MNKL); + N = get<0>(problem_shape_MNKL); + } + + InternalElementA const* ptr_A; + InternalStrideA dA; + InternalElementB const* ptr_B; + InternalStrideB dB; + + if constexpr (not SwapAB) { + ptr_A = reinterpret_cast(args.ptr_A); + ptr_B = reinterpret_cast(args.ptr_B); + dA = args.dA; + dB = args.dB; + } + else { + ptr_A = reinterpret_cast(args.ptr_B); + ptr_B = reinterpret_cast(args.ptr_A); + dA = args.dB; + dB = args.dA; + } + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, + tma_load_b + }; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; + static_assert(K_PIPE_MMAS >= 1, "At least one MMA stage should be asynchronous for this mainloop."); + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(InternalElementA)))+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(InternalElementB))); + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TMA_LOAD_A, + class TensorB, class TMA_LOAD_B, + class KTileIterator + > + CUTLASS_DEVICE void + load( + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + TensorA const& gA, TMA_LOAD_A& tma_load_a, + TensorB const& gB, TMA_LOAD_B& tma_load_b, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors) + { + + using namespace cute; + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + if (warp_idx_in_warp_group == 0 and lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + dim3 cluster_local_block_id = cute::block_id_in_cluster(); + auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Issue the prologue loads + int k_tile_prologue = min(k_tile_count, K_PIPE_MAX); + CUTLASS_PRAGMA_UNROLL + for (int count = 0; count < k_tile_prologue; ++count) { + pipeline.producer_acquire(smem_pipe_write); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + ++smem_pipe_write; + } + k_tile_count -= k_tile_prologue; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (warp_idx_in_warp_group == 0 and lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) + { + using namespace cute; + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx(); + int warp_idx_in_warp_group = warp_idx % 4; + int warp_group_thread_idx = thread_idx % 128; + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // If TransposeB, GMMA will read from transposed B layout SMEM + Tensor gmma_sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), GmmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCsA = thread_mma.partition_A(sA); + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(gmma_sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_read); + + // copy smem->rmem for A operand + copy(smem_tiled_copy_A, tCsA(_,_,_,smem_pipe_read.index()), tCrA_copy_view); + // transpose B operand in SMEM + if constexpr (TransposeB) { + transform::collective::detail::transpose_b_operand( + sB, gmma_sB, smem_pipe_read, warp_idx_in_warp_group, warp_group_thread_idx, + tiled_mma, SmemLayoutB{}, InternalSmemLayoutAtomB{}, InternalElementB{}); + } // if TransposeB + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_read); + + // + // Compute on k_tile + // + // copy smem->rmem for A operand + copy(smem_tiled_copy_A, tCsA(_,_,_,smem_pipe_read.index()), tCrA_copy_view); + // transpose B operand in SMEM + if constexpr (TransposeB) { + transform::collective::detail::transpose_b_operand( + sB, gmma_sB, smem_pipe_read, warp_idx_in_warp_group, warp_group_thread_idx, + tiled_mma, SmemLayoutB{}, InternalSmemLayoutAtomB{}, InternalElementB{}); + } // if TransposeB + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp index 25eaffb7..cf0a050d 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -40,7 +40,7 @@ #include "cute/algorithm/gemm.hpp" #include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" -#include "cutlass/pipeline.hpp" +#include "cutlass/pipeline/pipeline.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -132,20 +132,20 @@ struct CollectiveMma< Step<_2,_1,_3>{})); static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(std::is_base_of::value && - std::is_base_of::value, + static_assert(cute::is_base_of::value && + cute::is_base_of::value, "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert(std::is_same_v || std::is_same_v, + static_assert(cute::is_same_v || cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(std::is_same_v || std::is_same_v, + static_assert(cute::is_same_v || cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); // TMA converts f32 input to tf32 when copying from GMEM to SMEM // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. - static constexpr bool ConvertF32toTF32A = std::is_same_v; - static constexpr bool ConvertF32toTF32B = std::is_same_v; - using InternalElementA = std::conditional_t>>; - using InternalElementB = std::conditional_t>>; + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; struct SharedStorage { @@ -156,22 +156,27 @@ struct CollectiveMma< alignas(16) PipelineStorage pipeline_storage; }; - struct Params { - InternalElementA const* ptr_A; + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; StrideA dA; - InternalElementB const* ptr_B; + ElementB const* ptr_B; StrideB dB; + }; + + // Device side kernel params + struct Params { // Assumption: StrideA is congruent with Problem_MK using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, - make_tensor(ptr_A, repeat_like(StrideA{}, int32_t(0)), dA), + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), SmemLayoutA{}(_,_,0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, - make_tensor(ptr_B, repeat_like(StrideB{}, int32_t(0)), dB), + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), SmemLayoutB{}(_,_,0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any @@ -183,22 +188,23 @@ struct CollectiveMma< // Methods // - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; + // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); + auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); auto K = get<2>(problem_shape_MNKL); auto L = get<3>(problem_shape_MNKL); - auto reinterpreted_ptr_A = reinterpret_cast(args.ptr_A); - auto reinterpreted_ptr_B = reinterpret_cast(args.ptr_B); + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); - Tensor tensor_a = make_tensor(reinterpreted_ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensor_b = make_tensor(reinterpreted_ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); typename Params::TMA_A tma_load_a = make_tma_copy( GmemTiledCopyA{}, tensor_a, @@ -212,10 +218,6 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any return { - reinterpreted_ptr_A, - args.dA, - reinterpreted_ptr_B, - args.dB, tma_load_a, tma_load_b }; @@ -253,9 +255,9 @@ struct CollectiveMma< static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(std::is_void_v, + static_assert(cute::is_void_v, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(std::is_void_v, + static_assert(cute::is_void_v, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); SharedStorage& storage = *reinterpret_cast(shared_memory); @@ -342,14 +344,14 @@ struct CollectiveMma< // Issue TmaLoads (Prologue fetches) if (warp_idx == 0 && lane_predicate == 1) { // Maps the tile -> block, value - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); } } - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); @@ -361,8 +363,8 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int stage = 0; stage < prologue_tma_count; ++stage) { pipeline.producer_acquire(smem_pipe_write); - using BarrierType = typename MainloopPipeline::ValueType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(stage); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,stage)); copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,stage)); @@ -397,6 +399,8 @@ struct CollectiveMma< __syncthreads(); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + warpgroup_fence_operand(accum); // Prologue MMAs CUTLASS_PRAGMA_UNROLL @@ -406,7 +410,14 @@ struct CollectiveMma< // WAIT on smem_pipe_read until it's data is available pipeline.consumer_wait(smem_pipe_read); warpgroup_arrive(); - cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); // (V,M,K) x (V,N,K) => (V,M,N) + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,smem_pipe_read.index()), tCrB(_,_,k_block,smem_pipe_read.index()), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); ++smem_pipe_read; --k_tile_count; @@ -429,7 +440,13 @@ struct CollectiveMma< warpgroup_fence_operand(accum); warpgroup_arrive(); - cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read.index()), tCrB(_,_,_,smem_pipe_read.index()), accum); // (V,M,K) x (V,N,K) => (V,M,N) + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,smem_pipe_read.index()), tCrB(_,_,k_block,smem_pipe_read.index()), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } warpgroup_commit_batch(); /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed @@ -446,8 +463,8 @@ struct CollectiveMma< if (warp_idx == 0 && lane_predicate == 1 && (k_tile_count_tma > 0) ) { pipeline.producer_acquire(smem_pipe_write); // LOCK wr stage, for _writing_ - using BarrierType = typename MainloopPipeline::ValueType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write.index()); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write.index())); copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write.index())); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index 41b0f13b..01638c52 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -40,7 +40,7 @@ #include "cute/algorithm/gemm.hpp" #include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" -#include "cutlass/pipeline.hpp" +#include "cutlass/pipeline/pipeline.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -132,47 +132,56 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), Step<_2,_1,_3>{})); - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(std::is_base_of::value && - std::is_base_of::value, + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert(std::is_same_v || std::is_same_v, + static_assert(cute::is_same_v || cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(std::is_same_v || std::is_same_v, + static_assert(cute::is_same_v || cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); // TMA converts f32 input to tf32 when copying from GMEM to SMEM // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. - static constexpr bool ConvertF32toTF32A = std::is_same_v; - static constexpr bool ConvertF32toTF32B = std::is_same_v; - using InternalElementA = std::conditional_t>>; - using InternalElementB = std::conditional_t>>; + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; struct SharedStorage { - cute::array_aligned> smem_A; - cute::array_aligned> smem_B; + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; using PipelineStorage = typename MainloopPipeline::SharedStorage; - alignas(16) PipelineStorage pipeline_storage; + PipelineStorage pipeline; }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; - struct Params { - InternalElementA const* ptr_A; + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; StrideA dA; - InternalElementB const* ptr_B; + ElementB const* ptr_B; StrideB dB; + }; + + // Device side kernel params + struct Params { // Assumption: StrideA is congruent with Problem_MK using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, - make_tensor(ptr_A, repeat_like(StrideA{}, int32_t(0)), dA), + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), SmemLayoutA{}(_,_,0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, - make_tensor(ptr_B, repeat_like(StrideB{}, int32_t(0)), dB), + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), SmemLayoutB{}(_,_,0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any @@ -184,22 +193,23 @@ struct CollectiveMma< // Methods // - template + template static constexpr Params - to_underlying_arguments(Args const& args, void* workspace) { + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; + // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); + auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); auto K = get<2>(problem_shape_MNKL); auto L = get<3>(problem_shape_MNKL); - auto reinterpreted_ptr_A = reinterpret_cast(args.ptr_A); - auto reinterpreted_ptr_B = reinterpret_cast(args.ptr_B); + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); - Tensor tensor_a = make_tensor(reinterpreted_ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensor_b = make_tensor(reinterpreted_ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); typename Params::TMA_A tma_load_a = make_tma_copy( GmemTiledCopyA{}, tensor_a, @@ -213,10 +223,6 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any return { - reinterpreted_ptr_A, - args.dA, - reinterpreted_ptr_B, - args.dB, tma_load_a, tma_load_b }; @@ -228,12 +234,6 @@ struct CollectiveMma< (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(ElementA)))+ (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); - CUTLASS_DEVICE - static MainloopPipeline make_pipeline(char* shared_memory, PipelineParams params){ - SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - return {shared_storage.pipeline_storage, params}; - } - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& mainloop_params) @@ -250,13 +250,14 @@ struct CollectiveMma< class KTileIterator > CUTLASS_DEVICE void - dma(MainloopPipeline pipeline, + load( + MainloopPipeline pipeline, PipelineState smem_pipe_write, TensorA const& gA, TMA_LOAD_A& tma_load_a, TensorB const& gB, TMA_LOAD_B& tma_load_b, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, - char* shared_memory) + TensorStorage& shared_tensors) { using namespace cute; @@ -265,9 +266,8 @@ struct CollectiveMma< int lane_predicate = cute::elect_one_sync(); if (warp_idx_in_warp_group == 0 and lane_predicate) { - SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // // Prepare the TMA loads for A and B @@ -289,14 +289,14 @@ struct CollectiveMma< // Issue TmaLoads // Maps the tile -> block, value - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); } } - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); @@ -308,10 +308,10 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int count = 0; count < k_tile_prologue; ++count) { pipeline.producer_acquire(smem_pipe_write); - int write_stage = smem_pipe_write.index(); - using BarrierType = typename MainloopPipeline::ValueType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(write_stage); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + int write_stage = smem_pipe_write.index(); copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); ++k_tile_iter; @@ -330,10 +330,10 @@ struct CollectiveMma< // Copy gmem to smem for *k_tile_iter // - int write_stage = smem_pipe_write.index(); - using BarrierType = typename MainloopPipeline::ValueType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(write_stage); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + int write_stage = smem_pipe_write.index(); copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); ++k_tile_iter; @@ -346,7 +346,8 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void - dma_epilogue(MainloopPipeline pipeline, + load_tail( + MainloopPipeline pipeline, PipelineState smem_pipe_write) { int warp_idx = canonical_warp_idx(); @@ -361,10 +362,7 @@ struct CollectiveMma< * then would just be acquired since the phase was * still inverted from make_producer_start_state */ - for (int count = 0; count < K_PIPE_MAX; ++count) { - pipeline.producer_acquire(smem_pipe_write); - ++smem_pipe_write; - } + pipeline.producer_tail(smem_pipe_write); } } @@ -379,23 +377,21 @@ struct CollectiveMma< FrgTensorC& accum, int k_tile_count, int thread_idx, - char* shared_memory, - Params const& mainloop_params - ) + TensorStorage& shared_tensors, + Params const& mainloop_params) { using namespace cute; static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(std::is_void_v, + static_assert(cute::is_void_v, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(std::is_void_v, + static_assert(cute::is_void_v, "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // // Define C accumulators and A/B partitioning @@ -424,12 +420,14 @@ struct CollectiveMma< static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); - // We release buffers to producer warps(dma) with some mmas in flight + // We release buffers to producer warps(dma load) with some mmas in flight PipelineState smem_pipe_release = smem_pipe_read; // Prologue GMMAs int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + warpgroup_fence_operand(accum); CUTLASS_PRAGMA_UNROLL for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) @@ -439,7 +437,14 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); warpgroup_arrive(); - cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); ++smem_pipe_read; @@ -462,23 +467,41 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); warpgroup_fence_operand(accum); warpgroup_arrive(); - cute::gemm(tiled_mma, tCrA(_,_,_,read_stage), tCrB(_,_,_,read_stage), accum); // (V,M,K) x (V,N,K) => (V,M,N) + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } warpgroup_commit_batch(); /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed warpgroup_wait(); warpgroup_fence_operand(accum); - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); // Advance smem_pipe_read and smem_pipe_release ++smem_pipe_read; ++smem_pipe_release; } + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + // Wait on all GMMAs to complete warpgroup_wait<0>(); - warpgroup_fence_operand(accum); for (int count = 0; count < prologue_mma_count; ++count) { pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it diff --git a/include/cutlass/gemm/device/base_grouped.h b/include/cutlass/gemm/device/base_grouped.h index e5b50f36..207266c7 100644 --- a/include/cutlass/gemm/device/base_grouped.h +++ b/include/cutlass/gemm/device/base_grouped.h @@ -157,7 +157,7 @@ class BaseGrouped { static void reorder_array(T* data, const std::vector& indices) { // For now, simply create a copy of the data and then copy over to the original. std::vector copy(indices.size()); - for (unsigned i = 0; i < indices.size(); ++i) { + for (size_t i = 0; i < indices.size(); ++i) { copy.at(i) = data[indices[i]]; } diff --git a/include/cutlass/gemm/device/default_gemm_configuration.h b/include/cutlass/gemm/device/default_gemm_configuration.h index 46ef274e..8d193c5b 100644 --- a/include/cutlass/gemm/device/default_gemm_configuration.h +++ b/include/cutlass/gemm/device/default_gemm_configuration.h @@ -763,9 +763,6 @@ struct DefaultGemmConfiguration< }; //////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - template struct DefaultGemmConfiguration class GemmUniversal : public GemmUniversalBase< @@ -161,7 +165,9 @@ class GemmUniversal : GatherA, GatherB, ScatterD, - PermuteDLayout + PermuteDLayout_, + PermuteALayout_, + PermuteBLayout_ >::GemmKernel > { @@ -176,6 +182,9 @@ class GemmUniversal : using EpilogueOutputOp = EpilogueOutputOp_; using ThreadblockSwizzle = ThreadblockSwizzle_; using Operator = Operator_; + using PermuteDLayout = PermuteDLayout_; + using PermuteALayout = PermuteALayout_; + using PermuteBLayout = PermuteBLayout_; static int const kStages = Stages; static int const kAlignmentA = AlignmentA; static int const kAlignmentB = AlignmentB; @@ -209,7 +218,9 @@ class GemmUniversal : GatherA, GatherB, ScatterD, - PermuteDLayout + PermuteDLayout_, + PermuteALayout_, + PermuteBLayout_ >::GemmKernel >; @@ -268,14 +279,19 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout + typename PermuteDLayout_, + /// Permute operand A + typename PermuteALayout_, + /// Permute operand B + typename PermuteBLayout_ > class GemmUniversal { + Operator_, TransformA, TransformB, GatherA, GatherB, ScatterD, + PermuteDLayout_, PermuteALayout_, PermuteBLayout_> { public: using ElementA = ElementA_; @@ -297,6 +313,9 @@ class GemmUniversal::Base; using GemmKernel = typename UnderlyingOperator::GemmKernel; diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 64adac33..819491a3 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -80,7 +80,7 @@ class GemmUniversalAdapter; template class GemmUniversalAdapter< GemmKernel_, - std::enable_if_t::value>> + cute::enable_if_t::value>> { public: using GemmKernel = GemmKernel_; @@ -88,6 +88,7 @@ class GemmUniversalAdapter< using ElementA = typename GemmKernel::ElementA; using ElementB = typename GemmKernel::ElementB; using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; using ElementAccumulator = typename GemmKernel::TiledMma::ValTypeC; using DispatchPolicy = typename GemmKernel::DispatchPolicy; using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; @@ -107,14 +108,14 @@ class GemmUniversalAdapter< using MathOperator = cutlass::arch::OpMultiplyAdd; // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! - using OperatorClass = std::conditional_t< + using OperatorClass = cute::conditional_t< (cute::size(typename GemmKernel::TiledMma::AtomThrID{}) > 1), cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; using ArchTag = typename GemmKernel::ArchTag; // NOTE: Assume identity swizzle for now - static_assert(std::is_void_v, + static_assert(cute::is_void_v, "CUTLASS 3.x kernel types do not support grid swizzle functors yet."); using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; @@ -142,7 +143,7 @@ class GemmUniversalAdapter< // But we can best approximate it by inspecting the TiledMma::TiledShape_MNK // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads - static constexpr int WarpsInMma = std::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32); + static constexpr int WarpsInMma = cute::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32); static constexpr int WarpsInMmaM = 4; static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); using WarpCount = cutlass::gemm::GemmShape; @@ -166,7 +167,7 @@ class GemmUniversalAdapter< using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; // Split-K preserves splits that are 128b aligned - static int constexpr kSplitKAlignment = std::max( + static int constexpr kSplitKAlignment = cute::max( 128 / sizeof_bits::value, 128 / sizeof_bits::value); /// Argument structure: User API @@ -208,8 +209,8 @@ class GemmUniversalAdapter< /// Computes the grid shape static dim3 - get_grid_shape(Arguments const& args) { - auto tmp_params = GemmKernel::to_underlying_arguments(args); + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); return GemmKernel::get_grid_shape(tmp_params); } @@ -397,14 +398,14 @@ class GemmUniversalAdapter< template class GemmUniversalAdapter< GemmKernel_, - std::enable_if_t::value>> + cute::enable_if_t::value>> { public: using GemmKernel = GemmKernel_; static bool const kInternalTranspose = - platform::is_same::value; + cute::is_same::value; using ThreadblockShape = typename GemmKernel::Mma::Shape; using WarpShape = typename GemmKernel::WarpShape; @@ -447,11 +448,15 @@ class GemmUniversalAdapter< using ElementC = typename GemmKernel::ElementC; using LayoutC = typename MapArguments::LayoutC; static int const kAlignmentC = GemmKernel::kAlignmentC; + + // C and D same type for 2.x kernel + using ElementD = ElementC; + using LayoutD = LayoutC; using TensorRefA = TensorRef; using TensorRefB = TensorRef; using TensorRefC = TensorRef; - using TensorRefD = TensorRef; + using TensorRefD = TensorRef; static int const kStages = GemmKernel::Mma::kStages; diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index a09afb4c..cdedd3fe 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -36,7 +36,11 @@ #pragma once +#if defined(__CUDACC_RTC__) +#include +#else #include +#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index a2cd9a11..8de19d46 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -48,7 +48,12 @@ using namespace cute; struct KernelMultistage { }; struct KernelTma { }; struct KernelTmaWarpSpecialized { }; -struct KernelTmaWarpSpecializedPersistent { }; +struct KernelTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedCooperative { }; + +// Policies for dispatch of epilogue +struct EpilogueDefault { }; +struct EpilogueTransposed { }; // // Collective Mainloop Policies @@ -130,7 +135,7 @@ struct MainloopSm90TmaGmma { template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelTmaWarpSpecialized + class KernelSchedule = KernelTmaWarpSpecializedCooperative > struct MainloopSm90TmaGmmaWarpSpecialized { constexpr static int Stages = Stages_; @@ -139,6 +144,27 @@ struct MainloopSm90TmaGmmaWarpSpecialized { using Schedule = KernelSchedule; }; +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule +// With GMMA's A data from registers. +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecialized, + int PipelineAsyncMmaStages_ = 1 +> +struct MainloopSm90TmaGmmaRmemAWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + constexpr static int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; + static_assert( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the warp specialized policies"); +}; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/gemm.h b/include/cutlass/gemm/gemm.h index 0b1fd25b..cd9a310e 100644 --- a/include/cutlass/gemm/gemm.h +++ b/include/cutlass/gemm/gemm.h @@ -523,8 +523,8 @@ template constexpr int get_alignment_count_from_gmem_tiled_copy() { // For TMA tiled copies, we know the alignment has to be 128 bits - if constexpr ( std::is_base_of_v - || std::is_base_of_v + if constexpr ( cute::is_base_of_v + || cute::is_base_of_v ) { return 128 / sizeof_bits::value; } @@ -595,11 +595,11 @@ is_k_major_B() { // The following two metafunctions are used to detect whether a `kernel::Gemm` or `kernel::GemmUniversal` // is implementing the CUTLASS 3.x API or not, by checking if the problem shape type is aliased within or not. template -struct IsCutlass3GemmKernel : std::false_type { }; +struct IsCutlass3GemmKernel : cute::false_type { }; template -struct IsCutlass3GemmKernel> - : std::true_type { }; +struct IsCutlass3GemmKernel> + : cute::true_type { }; /////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm.h b/include/cutlass/gemm/kernel/default_gemm.h index 4432008e..ab400a82 100644 --- a/include/cutlass/gemm/kernel/default_gemm.h +++ b/include/cutlass/gemm/kernel/default_gemm.h @@ -129,6 +129,10 @@ template < bool ScatterD = false, /// Permute result D typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute, /// typename Enable = void > @@ -180,19 +184,25 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultGemm { + Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, + PermuteDLayout, PermuteALayout, PermuteBLayout> { /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; + Operator, false, SharedMemoryClear, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; @@ -254,13 +264,18 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultGemm { + Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, + PermuteDLayout, PermuteALayout, PermuteBLayout> { static_assert((platform::is_same::value || platform::is_same>::value), @@ -271,7 +286,8 @@ struct DefaultGemm::ThreadblockMma; + Operator, false, SharedMemoryClear, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; @@ -337,7 +353,11 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultGemm< ElementA, LayoutA, kAlignmentA, @@ -358,7 +378,9 @@ struct DefaultGemm< GatherA, GatherB, ScatterD, - PermuteDLayout + PermuteDLayout, + PermuteALayout, + PermuteBLayout > { /// Define the threadblock-scoped matrix multiply-accumulate @@ -381,7 +403,9 @@ struct DefaultGemm< false, SharedMemoryClear, GatherA, - GatherB + GatherB, + PermuteALayout, + PermuteBLayout >::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; @@ -576,7 +600,11 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultGemm< ElementA, LayoutA, kAlignmentA, @@ -597,7 +625,9 @@ struct DefaultGemm< GatherA, GatherB, ScatterD, - PermuteDLayout + PermuteDLayout, + PermuteALayout, + PermuteBLayout > { /// Define the threadblock-scoped matrix multiply-accumulate @@ -620,7 +650,9 @@ struct DefaultGemm< false, SharedMemoryClear, GatherA, - GatherB + GatherB, + PermuteALayout, + PermuteBLayout >::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; @@ -685,7 +717,11 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultGemm< ElementA, @@ -712,6 +748,8 @@ struct DefaultGemm< GatherB, ScatterD, PermuteDLayout, + PermuteALayout, + PermuteBLayout, typename platform::enable_if< ! platform::is_same::value >::type > { static_assert((platform::is_same::value @@ -738,7 +776,9 @@ struct DefaultGemm< false, SharedMemoryClear, GatherA, - GatherB>::ThreadblockMma; + GatherB, + PermuteALayout, + PermuteBLayout>::ThreadblockMma; static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); @@ -814,7 +854,11 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultGemm { + PermuteDLayout, + PermuteALayout, + PermuteBLayout> { static_assert((platform::is_same::value || platform::is_same>::value), @@ -850,7 +896,8 @@ struct DefaultGemm, Stages, - Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; + Operator, false, SharedMemoryClear, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); @@ -921,14 +968,16 @@ struct DefaultGemm, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, - Operator, SharedMemoryClear, false, false, false> { + Operator, SharedMemoryClear, false, false, false, + layout::NoPermute, layout::NoPermute> { using InstructionShape = GemmShape<1, 1, 4>; using ElementA = int8_t; using ElementB = int8_t; using OperatorClass = arch::OpClassSimt; /// Define the threadblock-scoped matrix multiply-accumulate - using Mma = typename cutlass::gemm::threadblock::DefaultMma { /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaultMma< diff --git a/include/cutlass/gemm/kernel/default_gemm_universal.h b/include/cutlass/gemm/kernel/default_gemm_universal.h index 45a825d4..afccf875 100644 --- a/include/cutlass/gemm/kernel/default_gemm_universal.h +++ b/include/cutlass/gemm/kernel/default_gemm_universal.h @@ -114,6 +114,10 @@ template < bool ScatterD = false, /// Permute result D typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout_ = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout_ = layout::NoPermute, /// typename Enable = void > @@ -170,7 +174,11 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultGemmUniversal< ElementA, @@ -198,6 +206,8 @@ struct DefaultGemmUniversal< GatherB, ScatterD, PermuteDLayout, + PermuteALayout, + PermuteBLayout, typename platform::enable_if< ! cutlass::is_complex::value>::type > { @@ -225,7 +235,9 @@ struct DefaultGemmUniversal< GatherA, GatherB, ScatterD, - PermuteDLayout + PermuteDLayout, + PermuteALayout, + PermuteBLayout >::GemmKernel; /// Universal kernel without StreamkFeature member type @@ -326,6 +338,8 @@ struct DefaultGemmUniversal< false, false, layout::NoPermute, + layout::NoPermute, + layout::NoPermute, typename platform::enable_if::value>::type > { diff --git a/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h b/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h index dfe62d35..d83dcfd7 100644 --- a/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h +++ b/include/cutlass/gemm/kernel/default_gemm_with_broadcast.h @@ -114,7 +114,7 @@ struct DefaultGemmWithBroadcast { Operator >::GemmKernel; - // Replace epilogue + // Define epilogue using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp< typename GemmBase::Epilogue::Shape, typename GemmBase::Epilogue::WarpMmaOperator, @@ -214,7 +214,7 @@ struct DefaultGemmWithBroadcast< Operator >::GemmKernel; - // Replace epilogue + // Define epilogue using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp< typename GemmBase::Epilogue::Shape, typename GemmBase::Epilogue::WarpMmaOperator, diff --git a/include/cutlass/gemm/kernel/default_gemm_with_reduction.h b/include/cutlass/gemm/kernel/default_gemm_with_reduction.h index 789b4bde..6d19ee32 100644 --- a/include/cutlass/gemm/kernel/default_gemm_with_reduction.h +++ b/include/cutlass/gemm/kernel/default_gemm_with_reduction.h @@ -117,7 +117,7 @@ struct DefaultGemmWithReduction { SharedMemoryClearOption::kClearLastStage >::GemmKernel; - // Replace epilogue + // Define epilogue using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionTensorOp< typename GemmBase::Epilogue::Shape, typename GemmBase::Epilogue::WarpMmaOperator, @@ -218,7 +218,7 @@ struct DefaultGemmWithReduction< Operator >::GemmKernel; - // Replace epilogue + // Define epilogue using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp< typename GemmBase::Epilogue::Shape, typename GemmBase::Epilogue::WarpMmaOperator, diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index 3dbd422d..116d83db 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -70,7 +70,7 @@ class GemmUniversal< void, // 3.x kernels use the first template argument to define the ProblemShape tuple // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API - std::enable_if_t::value> + cute::enable_if_t::value> > { public: @@ -364,24 +364,24 @@ class GemmUniversal< { CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); - static int const kAlignmentA = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) ? 64 : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) ? 64 : Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) ? 64 : Epilogue::OutputTileIterator::kElementsPerAccess; @@ -390,30 +390,30 @@ class GemmUniversal< bool isBMisaligned = false; bool isCMisaligned = false; - if (platform::is_same::value) { + if (cute::is_same::value) { isAMisaligned = problem_size.k() % kAlignmentA; - } else if (platform::is_same::value) { + } else if (cute::is_same::value) { isAMisaligned = problem_size.m() % kAlignmentA; - } else if (platform::is_same>::value - || platform::is_same>::value) { + } else if (cute::is_same>::value + || cute::is_same>::value) { isAMisaligned = problem_size.k() % kAlignmentA; } - if (platform::is_same::value) { + if (cute::is_same::value) { isBMisaligned = problem_size.n() % kAlignmentB; - } else if (platform::is_same::value) { + } else if (cute::is_same::value) { isBMisaligned = problem_size.k() % kAlignmentB; - } else if (platform::is_same>::value - || platform::is_same>::value) { + } else if (cute::is_same>::value + || cute::is_same>::value) { isBMisaligned = problem_size.k() % kAlignmentB; } - if (platform::is_same::value) { + if (cute::is_same::value) { isCMisaligned = problem_size.n() % kAlignmentC; - } else if (platform::is_same::value) { + } else if (cute::is_same::value) { isCMisaligned = problem_size.m() % kAlignmentC; - } else if (platform::is_same>::value - || platform::is_same>::value) { + } else if (cute::is_same>::value + || cute::is_same>::value) { isCMisaligned = problem_size.n() % kAlignmentC; } diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index cdac6ca4..7bee6bbd 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -68,5 +68,6 @@ class GemmUniversal; #include "cutlass/gemm/kernel/sm70_gemm.hpp" #include "cutlass/gemm/kernel/sm90_gemm_tma.hpp" #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp" -#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h index f41e8130..5ef25d78 100644 --- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h +++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h @@ -198,7 +198,7 @@ struct GemmWithFusedEpilogue { lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt) { CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); CUTLASS_TRACE_HOST(" ldt: " << this->ldt); @@ -304,7 +304,7 @@ struct GemmWithFusedEpilogue { batch_stride_Tensor(args.batch_stride_Tensor) { CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); CUTLASS_TRACE_HOST(" ldt: " << args.ldt); @@ -335,7 +335,7 @@ struct GemmWithFusedEpilogue { output_op = args.epilogue; CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); } @@ -1055,7 +1055,7 @@ struct GemmWithFusedEpilogue { output_op = args.epilogue; CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Vector: " << (void *)this->ptr_Vector); CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); } diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h index aee9c71c..92cc2a73 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h @@ -138,7 +138,7 @@ i = i_macro j = j_macro - Handling cases with grid dimensions that aren't multiples of each other + Handling cases with grid dimensions that aren't multiples of eachother ---------------------------------------------------------------------- Even though threadblock shapes M and N are typically multiples of one another, the grid for a given problem may not have dimensions of the same ratio as that of the threadblock. diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp index efe51e23..830a77df 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -52,7 +52,7 @@ class GemmUniversal< CollectiveMainloop_, CollectiveEpilogue_, GridSwizzle_, - std::enable_if_t>> + cute::enable_if_t>> { public: // @@ -74,6 +74,7 @@ class GemmUniversal< using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; // Epilogue derived types @@ -82,8 +83,9 @@ class GemmUniversal< using StrideC = typename CollectiveEpilogue::StrideC; using ElementD = typename CollectiveEpilogue::ElementD; using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(std::is_same_v, + static_assert(cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); static constexpr int SharedStorageSize = cute::max( @@ -97,12 +99,9 @@ class GemmUniversal< struct Arguments { GemmUniversalMode mode{}; ProblemShape problem_shape{}; - ElementA const* ptr_A = nullptr; - StrideA dA{}; - ElementB const* ptr_B = nullptr; - StrideB dB{}; - EpilogueParams epilogue_params{}; - KernelHardwareInfo hw_info; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; }; // Kernel entry point API @@ -125,8 +124,8 @@ class GemmUniversal< return { args.mode, args.problem_shape, - CollectiveMainloop::to_underlying_arguments(args, workspace), - CollectiveEpilogue::to_underlying_arguments(args, workspace) + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) }; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 305654d8..768c780c 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -35,6 +35,7 @@ #include "cutlass/kernel_hardware_info.hpp" #include "cute/arch/cluster_sm90.hpp" #include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" @@ -71,7 +72,7 @@ class GemmUniversal< CollectiveMainloop_, CollectiveEpilogue_, GridSwizzle_, - std::enable_if_t>> + cute::enable_if_t>> { public: // @@ -94,6 +95,7 @@ class GemmUniversal< using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -103,8 +105,9 @@ class GemmUniversal< using StrideC = typename CollectiveEpilogue::StrideC; using ElementD = typename CollectiveEpilogue::ElementD; using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Params; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(std::is_same_v, + static_assert(cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); static constexpr int SharedStorageSize = cute::max( @@ -118,12 +121,9 @@ class GemmUniversal< struct Arguments { GemmUniversalMode mode{}; ProblemShape problem_shape{}; - ElementA const* ptr_A = nullptr; - StrideA dA{}; - ElementB const* ptr_B = nullptr; - StrideB dB{}; - EpilogueParams epilogue_params{}; - KernelHardwareInfo hw_info; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; }; // Kernel entry point API @@ -152,16 +152,38 @@ class GemmUniversal< return { args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args, workspace), - CollectiveEpilogue::to_underlying_arguments(args, workspace) + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) }; } CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) { - return args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Size don't meet the requirements.\n"); + return implementable; + } + static constexpr int tma_alignment_bits = 128; + static constexpr int min_tma_aligned_elements = tma_alignment_bits / cutlass::sizeof_bits::value; + auto M = get<0>(args.problem_shape); + auto N = get<1>(args.problem_shape); + auto K = get<2>(args.problem_shape); + // Contiguous dimension for the TMA tensor should be 128b aligned + implementable = std::is_same_v, layout::RowMajor> ? + K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0; + implementable = implementable && (std::is_same_v, layout::RowMajor> ? + N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0); + implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value || + (cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value && + std::is_same_v, layout::RowMajor> ? + N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0)); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; } static @@ -251,8 +273,6 @@ class GemmUniversal< TiledMma tiled_mma; Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - clear(accumulators); - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); auto k_tile_count = size<2>(gA); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index f3a4a55c..d6619feb 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -36,10 +36,11 @@ #include "cute/arch/cluster_sm90.hpp" #include "cutlass/arch/reg_reconfig.h" #include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/pipeline.hpp" +#include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" /////////////////////////////////////////////////////////////////////////////// @@ -59,7 +60,7 @@ class GemmUniversal< CollectiveMainloop_, CollectiveEpilogue_, GridSwizzle_, - std::enable_if_t>> + cute::enable_if_t>> { public: // @@ -82,6 +83,7 @@ class GemmUniversal< using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -91,29 +93,44 @@ class GemmUniversal< using StrideC = typename CollectiveEpilogue::StrideC; using ElementD = typename CollectiveEpilogue::ElementD; using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(std::is_same_v, + static_assert(cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); - static constexpr int SharedStorageSize = cute::max( - sizeof(typename CollectiveMainloop::SharedStorage), - sizeof(typename CollectiveEpilogue::SharedStorage)); + // Kernel level shared memory storage + struct SharedStorage { + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - static constexpr uint32_t NumDmaWarpGroups = 1; + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = 1; - static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumDmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; // Device side arguments struct Arguments { GemmUniversalMode mode{}; ProblemShape problem_shape{}; - ElementA const* ptr_A = nullptr; - StrideA dA{}; - ElementB const* ptr_B = nullptr; - StrideB dB{}; - EpilogueParams epilogue_params{}; - KernelHardwareInfo hw_info; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; }; // Kernel entry point API @@ -142,16 +159,38 @@ class GemmUniversal< return { args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args, workspace), - CollectiveEpilogue::to_underlying_arguments(args, workspace) + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) }; } CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) { - return args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Size don't meet the requirements.\n"); + return implementable; + } + static constexpr int tma_alignment_bits = 128; + static constexpr int min_tma_aligned_elements = tma_alignment_bits / cutlass::sizeof_bits::value; + auto M = get<0>(args.problem_shape); + auto N = get<1>(args.problem_shape); + auto K = get<2>(args.problem_shape); + // Contiguous dimension for the TMA tensor should be 128b aligned + implementable = std::is_same_v, layout::RowMajor> ? + K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0; + implementable = implementable && (std::is_same_v, layout::RowMajor> ? + N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0); + implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value || + (cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value && + std::is_same_v, layout::RowMajor> ? + N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0)); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; } static @@ -196,6 +235,9 @@ class GemmUniversal< Consumer = 1, }; + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + int thread_idx = int(threadIdx.x); int warp_idx = canonical_warp_idx(); int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; @@ -205,24 +247,54 @@ class GemmUniversal< // Issue Tma Descriptor Prefetch from a single thread if ((warp_idx == 0) && lane_predicate) { CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); } - using Pipeline = typename CollectiveMainloop::MainloopPipeline; - - using PipelineParams = typename CollectiveMainloop::PipelineParams; - PipelineParams params_pipeline; - params_pipeline.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; if (warp_group_role == WarpGroupRole::Producer) { - params_pipeline.role = Pipeline::ThreadCategory::Producer; + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } - else { - params_pipeline.role = Pipeline::ThreadCategory::Consumer; + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; } - params_pipeline.is_leader = warp_group_thread_idx == 0; - params_pipeline.num_consumers = NumThreadsPerWarpGroup; - - // Initialize pipeline and setup starting pipeline state for the collectives - Pipeline pipeline = CollectiveMainloop::make_pipeline(smem_buf, params_pipeline); + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = 1; // 1 thread issues TMA load + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); auto cluster_wait_fn = [&] () { // We need this to guarantee that the Pipeline init is visible @@ -258,89 +330,99 @@ class GemmUniversal< // Get the appropriate blocks for this thread block -- potential for thread block locality auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice + TiledMma tiled_mma; - // Make tiled views - Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) // Compute m_coord, n_coord, and l_coord with their post-tiled shapes auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); - auto output_tile_coord = make_coord(m_coord, n_coord, _, l_coord); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); // Slice with m_coord and n_coord Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + // Get pipeline iterators and increments from tensor shapes auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); auto k_tile_count = size<2>(gA); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); // Wait for all thread blocks in the Cluster cluster_wait_fn(); - // In a warp specialized kernel, CollectiveMainloop exposes data movement and compute operations separately + // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue}; if (warp_group_role == WarpGroupRole::Producer) { - // For the DMA (prologue) - we start with an opposite phase - since we skip all waits - // i.e., we know that the buffer is indeed empty - typename CollectiveMainloop::PipelineState smem_pipe_write = cutlass::make_producer_start_state(); - collective_mainloop.dma( - pipeline, - smem_pipe_write, + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, gA, params.mainloop.tma_load_a, gB, params.mainloop.tma_load_b, k_tile_iter, k_tile_count, thread_idx, - smem_buf + shared_storage.tensors.mainloop ); - // Update starting pipeline state for the next tile - smem_pipe_write.advance(k_tile_count); - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.dma_epilogue(pipeline, smem_pipe_write); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_source_needed()) { + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + // Update starting load pipeline state for the pipeline drain + epi_load_pipe_producer_state.advance(c_tile_count); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } } else if (warp_group_role == WarpGroupRole::Consumer) { - typename CollectiveMainloop::PipelineState smem_pipe_read; - TiledMma tiled_mma; Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - clear(accumulators); collective_mainloop.mma( - pipeline, - smem_pipe_read, + mainloop_pipeline, + mainloop_pipe_consumer_state, accumulators, k_tile_count, thread_idx, - smem_buf, + shared_storage.tensors.mainloop, params.mainloop ); - constexpr int BLK_M_RANK = rank<0>(blk_shape); - bool m_oob = int(blockIdx.x) >= size<2>(gA_mkl); - auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); - })); - - constexpr int BLK_N_RANK = rank<1>(blk_shape); - bool n_oob = int(blockIdx.y) >= size<2>(gB_nkl); - auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); - })); - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); // Epilogue and write to gD - CollectiveEpilogue epilogue{params.epilogue}; - epilogue( + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, - output_tile_coord, + blk_coord, accumulators, tiled_mma, - residue_mnk, warp_group_thread_idx, - smem_buf + shared_storage.tensors.epilogue ); } } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp new file mode 100644 index 00000000..3b8e61e0 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -0,0 +1,496 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class GridSwizzle_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + GridSwizzle_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using GridSwizzle = GridSwizzle_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + using PersistentTileSchedulerParams = typename detail::PersistentTileSchedulerSm90::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + KernelHardwareInfo hw_info; + PersistentTileSchedulerParams scheduler; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::IF_SWAP_AB::value) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + {args.hw_info.device_id, sm_count}, + detail::PersistentTileSchedulerSm90::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}) + }; + } + + CUTLASS_HOST_DEVICE static + bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Size don't meet the requirements.\n"); + return implementable; + } + static constexpr int tma_alignment_bits = 128; + static constexpr int min_tma_aligned_elements = tma_alignment_bits / cutlass::sizeof_bits::value; + auto M = get<0>(args.problem_shape); + auto N = get<1>(args.problem_shape); + auto K = get<2>(args.problem_shape); + // Contiguous dimension for the TMA tensor should be 128b aligned + implementable = std::is_same_v, layout::RowMajor> ? + K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0; + implementable = implementable && (std::is_same_v, layout::RowMajor> ? + N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0); + implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value || + (cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value && + std::is_same_v, layout::RowMajor> ? + N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0)); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static + int + get_workspace_size(Arguments const& args) { + return 0; + } + + // Computes the kernel launch grid shape based on runtime parameters + static constexpr + dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + return detail::PersistentTileSchedulerSm90::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + // Preconditions + static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_idx = canonical_warp_idx(); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + int lane_predicate = cute::elect_one_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = size(TiledMma{}); + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = 1; // 1 thread issues TMA load + epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + + // Separate out problem shape for convenience + // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + detail::PersistentTileSchedulerSm90 scheduler; + auto work_tile_info = scheduler.get_current_work(params.scheduler); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue}; + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + + while (work_tile_info.is_valid_tile) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with our work tile coordinates to construct mainloop tensor views + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + k_tile_iter, k_tile_count, + thread_idx, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(k_tile_count); + + if (collective_epilogue.is_source_needed()) { + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + // Update starting pipeline state for the next tile + epi_load_pipe_producer_state.advance(c_tile_count); + } + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(params.scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + if (collective_epilogue.is_source_needed()) { + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + cutlass::arch::warpgroup_reg_alloc(); + + while (work_tile_info.is_valid_tile) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Allocate the the accumulators for the (M,N) blk_shape + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(k_tile_count); + + // Epilogue and write to gD + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue + ); + // Update starting load/store pipeline states for the next tile + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(params.scheduler); + } // Scheduler work fetch loop + } // Consumer Warp Groups End + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp similarity index 55% rename from include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp rename to include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index 6fa93945..af619b96 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -36,11 +36,12 @@ #include "cute/arch/cluster_sm90.hpp" #include "cutlass/arch/reg_reconfig.h" #include "cutlass/arch/mma_sm90.h" -#include "cutlass/pipeline.hpp" -#include "cutlass/trace.h" +#include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" #include "cute/tensor.hpp" @@ -61,7 +62,7 @@ class GemmUniversal< CollectiveMainloop_, CollectiveEpilogue_, GridSwizzle_, - std::enable_if_t>> + cute::enable_if_t>> { public: // @@ -84,7 +85,9 @@ class GemmUniversal< using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; + using PersistentTileSchedulerParams = typename detail::PersistentTileSchedulerSm90::Params; static_assert(ArchTag::kMinComputeCapability >= 90); // Epilogue derived types @@ -93,33 +96,44 @@ class GemmUniversal< using StrideC = typename CollectiveEpilogue::StrideC; using ElementD = typename CollectiveEpilogue::ElementD; using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(std::is_same_v, + static_assert(cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); - static constexpr uint32_t NumDmaWarpGroups = 1; + static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = 2; static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - /// Register requirement for DMA and MATH WGs - static constexpr uint32_t DmaRegisterRequirement = 40; + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; static constexpr uint32_t MmaRegisterRequirement = 232; - /* Order Sequence barrier with two stages: one for Mainloop and one for Epilogue */ + // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue static constexpr uint32_t StagesPerMathWarpGroup = 2; using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< StagesPerMathWarpGroup, NumMmaWarpGroups>; // Kernel level shared memory storage struct SharedStorage { - using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; - using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; - using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; - - MainloopSharedStorage mainloop; - EpilogueSharedStorage epilogue; - alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order_barrier_storage; + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + } pipelines; }; static constexpr int SharedStorageSize = sizeof(SharedStorage); @@ -128,12 +142,9 @@ class GemmUniversal< struct Arguments { GemmUniversalMode mode{}; ProblemShape problem_shape{}; - ElementA const* ptr_A = nullptr; - StrideA dA{}; - ElementB const* ptr_B = nullptr; - StrideB dB{}; - EpilogueParams epilogue_params{}; - KernelHardwareInfo hw_info; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; }; // Kernel entry point API @@ -143,6 +154,7 @@ class GemmUniversal< MainloopParams mainloop; EpilogueParams epilogue; KernelHardwareInfo hw_info; + PersistentTileSchedulerParams scheduler; }; // @@ -162,6 +174,7 @@ class GemmUniversal< get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); } + auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; @@ -175,25 +188,39 @@ class GemmUniversal< return { args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args, workspace), - CollectiveEpilogue::to_underlying_arguments(args, workspace), - {args.hw_info.device_id, sm_count} + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + {args.hw_info.device_id, sm_count}, + detail::PersistentTileSchedulerSm90::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}) }; } CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) { - bool implementable = args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); - - // Number of blocks per problem (without batch) must not exceed 2^31 for the persistent scheduler to calculate using FastDivmod - auto problem_shape_MNKL = append<4>(args.problem_shape, Int<1>{}); - auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = - detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(problem_shape_MNKL, TileShape{}, ClusterShape{}); - uint64_t problem_blocks = problem_blocks_m * problem_blocks_n * problem_blocks_l; - implementable = implementable && (problem_blocks < (uint64_t(1) << 31)); - + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Size don't meet the requirements.\n"); + return implementable; + } + static constexpr int tma_alignment_bits = 128; + static constexpr int min_tma_aligned_elements = tma_alignment_bits / cutlass::sizeof_bits::value; + auto M = get<0>(args.problem_shape); + auto N = get<1>(args.problem_shape); + auto K = get<2>(args.problem_shape); + // Contiguous dimension for the TMA tensor should be 128b aligned + implementable = std::is_same_v, layout::RowMajor> ? + K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0; + implementable = implementable && (std::is_same_v, layout::RowMajor> ? + N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0); + implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value || + (cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value && + std::is_same_v, layout::RowMajor> ? + N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0)); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } return implementable; } @@ -207,40 +234,8 @@ class GemmUniversal< static constexpr dim3 get_grid_shape(Params const& params) { - int sm_count = params.hw_info.sm_count; - CUTLASS_TRACE_HOST("get_grid_shape(): Persistent schedule grid plan using SM count = " << sm_count); - - // Compute the total number of output tiles our problem has - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = - detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(problem_shape_MNKL, TileShape{}, ClusterShape{}); - int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks_l; - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - dim3 launch_grid(1, cute::size<1>(ClusterShape{}), 1); - - // The else path is generic, however, we can avoid some divs if we know Cluster size is 1 - if constexpr (size(ClusterShape{}) == 1) { - launch_grid.x = std::min(sm_count, problem_blocks_total); - } - else { - /* - * Optimal grid size calculation is based on - * GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU - * Hence, maximum SMs per GPC = 18 - */ - constexpr int max_sm_per_gpc = 18; - // Provided SM count could possibly be less than the assumed maximum SMs per GPC - int min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; - int max_blk_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % size(ClusterShape{})); - int blk_per_device = min_num_gpc * max_blk_occupancy_per_gpc; - - launch_grid.x = std::min( - blk_per_device / size<1>(ClusterShape{}), - problem_blocks_total / size<1>(ClusterShape{})); - } - - return launch_grid; + return detail::PersistentTileSchedulerSm90::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info); } static constexpr @@ -287,30 +282,60 @@ class GemmUniversal< // Issue Tma Descriptor Prefetch from a single thread if ((warp_idx == 0) && lane_predicate) { CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); } - using Pipeline = typename CollectiveMainloop::MainloopPipeline; - using PipelineParams = typename CollectiveMainloop::PipelineParams; - PipelineParams params_pipeline; - params_pipeline.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; if (warp_group_role == WarpGroupRole::Producer) { - params_pipeline.role = Pipeline::ThreadCategory::Producer; + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } - else { - params_pipeline.role = Pipeline::ThreadCategory::Consumer; + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; } - params_pipeline.is_leader = warp_group_thread_idx == 0; - params_pipeline.num_consumers = NumThreadsPerWarpGroup; - - // Initialize pipeline and setup starting pipeline state for the collectives - Pipeline pipeline = CollectiveMainloop::make_pipeline(smem_buf, params_pipeline); - typename CollectiveMainloop::PipelineState collective_start_state_pipe; + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = 1; // 1 thread issues TMA load + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; - // DMA WG will not participate in these Ordered Barrier syncs + // DMA Load WG will not participate in these Ordered Barrier syncs params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group - MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.math_wg_order_barrier_storage, params_math_wg_order_barrier); + MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); auto cluster_wait_fn = [&] () { // We need this to guarantee that the Pipeline init is visible @@ -339,38 +364,40 @@ class GemmUniversal< Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - auto blk_coord = make_coord(_,_,_); // (m,n,k) -- defer the slice - // Slice to get the tiles this thread block is responsible for - Tensor gA_mkl = local_tile(mA_mkl, blk_shape, blk_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, blk_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - // Get iterations along k-dimension + // Get pipeline stage increments from tensor shapes auto k_tile_count = size<3>(gA_mkl); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - detail::PersistentTileSchedulerSm90 scheduler(problem_shape_MNKL, blk_shape, ClusterShape{}); + detail::PersistentTileSchedulerSm90 scheduler; if (warp_group_role == WarpGroupRole::Consumer1) { - /* Advance 2nd Math WG to the next work tile for the startup */ + // Advance 2nd Math WG to the next work tile for the startup scheduler.advance_to_next_work(); - /* Advance 2nd Math WG pipeline state to the end of 1st Math WG */ - collective_start_state_pipe.advance(k_tile_count); + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); } - auto work_tile_info = scheduler.get_current_work(); + auto work_tile_info = scheduler.get_current_work(params.scheduler); - // Perform the collective scoped MMA + // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue}; // Wait for all thread blocks in the Cluster cluster_wait_fn(); if (warp_group_role == WarpGroupRole::Producer) { - cutlass::arch::warpgroup_reg_dealloc(); + cutlass::arch::warpgroup_reg_dealloc(); - // For the DMA (prologue) - we start with an opposite phase - since we skip all waits - // i.e., we know that the buffer is indeed empty - typename CollectiveMainloop::PipelineState smem_pipe_write = cutlass::make_producer_start_state(); while (work_tile_info.is_valid_tile) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); @@ -384,27 +411,46 @@ class GemmUniversal< auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - collective_mainloop.dma( - pipeline, - smem_pipe_write, + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, gA, params.mainloop.tma_load_a, gB, params.mainloop.tma_load_b, k_tile_iter, k_tile_count, thread_idx, - reinterpret_cast(&shared_storage.mainloop) + shared_storage.tensors.mainloop ); // Update starting pipeline state for the next tile - smem_pipe_write.advance(k_tile_count); + mainloop_pipe_producer_state.advance(k_tile_count); + + if (collective_epilogue.is_source_needed()) { + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + // Update starting pipeline state for the next tile + epi_load_pipe_producer_state.advance(c_tile_count); + } + + // Get next work tile scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); + work_tile_info = scheduler.get_current_work(params.scheduler); } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.dma_epilogue(pipeline, smem_pipe_write); + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + if (collective_epilogue.is_source_needed()) { + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } } // Producer Warp Group End else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { - // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape cutlass::arch::warpgroup_reg_alloc(); while (work_tile_info.is_valid_tile) { @@ -414,69 +460,64 @@ class GemmUniversal< auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - // Slice with our work tile coordinates to construct mainloop tensor views - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - - TiledMma tiled_mma; + // Allocate the the accumulators for the (M,N) blk_shape Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - clear(accumulators); - /* Order two Math WG's MMA one after the other, helps hide Epilogue */ + // Order two Math WG's MMA one after the other, helps hide Epilogue math_wg_order_barrier.wait(); collective_mainloop.mma( - pipeline, - collective_start_state_pipe, + mainloop_pipeline, + mainloop_pipe_consumer_state, accumulators, k_tile_count, thread_idx, - reinterpret_cast(&shared_storage.mainloop), + shared_storage.tensors.mainloop, params.mainloop ); - /* Cue for next Math WG's MMA to start */ + // Cue for next Math WG's MMA to start math_wg_order_barrier.arrive(); - /* Order two Math WG's Epilogue one after the other */ - math_wg_order_barrier.wait(); - - constexpr int BLK_M_RANK = rank<0>(blk_shape); - bool m_oob = int(work_tile_info.M_idx) >= size<2>(gA_mkl); - auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return m_oob ? 0 : get(M) - get<0,i>(blk_shape) * get(m_coord); - })); + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); - constexpr int BLK_N_RANK = rank<1>(blk_shape); - bool n_oob = int(work_tile_info.N_idx) >= size<2>(gB_nkl); - auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { - return n_oob ? 0 : get(N) - get<1,i>(blk_shape) * get(n_coord); - })); - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + // Order two Math WG's Epilogue one after the other + math_wg_order_barrier.wait(); // Epilogue and write to gD - CollectiveEpilogue epilogue{params.epilogue}; - epilogue( + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators, tiled_mma, - residue_mnk, warp_group_thread_idx, - reinterpret_cast(&shared_storage.epilogue) + shared_storage.tensors.epilogue ); + // Update starting load/store pipeline states for the next tile + epi_load_pipe_consumer_state.advance(c_tile_count * NumMmaWarpGroups); + epi_store_pipe_producer_state.advance(d_tile_count * NumMmaWarpGroups); - /* Cue for next Math WG's Epilogue to start */ - math_wg_order_barrier.arrive(); + // Wait for all TMA stores to complete + epi_store_pipeline.producer_tail(epi_store_pipe_producer_state); - // Update starting pipeline state for the next tile - collective_start_state_pipe.advance(k_tile_count * NumMmaWarpGroups); + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + // Get next work tile scheduler.advance_to_next_work(NumMmaWarpGroups); - work_tile_info = scheduler.get_current_work(); + work_tile_info = scheduler.get_current_work(params.scheduler); } // Scheduler work fetch loop } // Consumer Warp Groups End } diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index 496d5e07..c1c47020 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -31,6 +31,7 @@ #pragma once #include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" #include "cute/layout.hpp" namespace cutlass::gemm::kernel::detail { @@ -44,13 +45,8 @@ class PersistentTileSchedulerSm90 { // private: - uint32_t blocks_per_problem_; - uint32_t current_work_linear_idx_; - uint32_t grid_blocks_total_; - - FastDivmod divmod_batch_; - FastDivmod divmod_grid_y_; - FastDivmod divmod_blk_m_; + uint64_t current_work_linear_idx_{static_cast((int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y))}; + uint64_t grid_blocks_total_{static_cast(int(gridDim.x) * int(gridDim.y))}; struct WorkTileInfo { int32_t M_idx = 0; @@ -65,9 +61,17 @@ class PersistentTileSchedulerSm90 { public: - template - CUTLASS_DEVICE - PersistentTileSchedulerSm90(ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, ClusterShape cluster_shape) { + struct Params { + FastDivmodU64 divmod_batch_{}; + FastDivmodU64 divmod_grid_y_{}; + FastDivmodU64 divmod_blk_m_{}; + + uint64_t blocks_per_problem_ = 0; + }; + + template + static Params + to_underlying_arguments(ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, ClusterShape cluster_shape) { // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic static_assert(is_static::value); static_assert(is_static::value); @@ -76,32 +80,32 @@ class PersistentTileSchedulerSm90 { auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_blk_shape_mnl( problem_shape_mnkl, tile_shape, cluster_shape); - blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks_l; - current_work_linear_idx_ = (int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y); - grid_blocks_total_ = int(gridDim.x) * int(gridDim.y); - - // Pre-compute our fast div/mods for rasterization so we don't have to pay for DIVs - divmod_batch_ = FastDivmod(problem_blocks_m * problem_blocks_n); - divmod_grid_y_ = FastDivmod(size<1>(cluster_shape)); - divmod_blk_m_ = FastDivmod(problem_blocks_m); + return { + FastDivmodU64(problem_blocks_m * problem_blocks_n), + FastDivmodU64(size<1>(cluster_shape)), + FastDivmodU64(problem_blocks_m), + problem_blocks_m * problem_blocks_n * problem_blocks_l + }; } + PersistentTileSchedulerSm90() = default; + CUTLASS_DEVICE WorkTileInfo - get_current_work() const { + get_current_work(Params const& scheduler_params) const { // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices - int work_idx_l, remainder; - divmod_batch_(work_idx_l, remainder, current_work_linear_idx_); + uint64_t work_idx_l, remainder; + scheduler_params.divmod_batch_(work_idx_l, remainder, current_work_linear_idx_); - int blk_per_grid_dim, dontcare; - divmod_grid_y_(blk_per_grid_dim, dontcare, remainder); + uint64_t blk_per_grid_dim, dontcare; + scheduler_params.divmod_grid_y_(blk_per_grid_dim, dontcare, remainder); - int block_idx_m, block_idx_n; - divmod_blk_m_(block_idx_n, block_idx_m, blk_per_grid_dim); - int work_idx_m = block_idx_m; - int work_idx_n = (block_idx_n * gridDim.y) + blockIdx.y; + uint64_t block_idx_m, block_idx_n; + scheduler_params.divmod_blk_m_(block_idx_n, block_idx_m, blk_per_grid_dim); + int32_t work_idx_m = static_cast(block_idx_m); + int32_t work_idx_n = static_cast((block_idx_n * gridDim.y) + blockIdx.y); - return {work_idx_m, work_idx_n, work_idx_l, current_work_linear_idx_ < blocks_per_problem_}; + return {work_idx_m, work_idx_n, static_cast(work_idx_l), current_work_linear_idx_ < scheduler_params.blocks_per_problem_}; } CUTLASS_DEVICE @@ -128,6 +132,45 @@ class PersistentTileSchedulerSm90 { int problem_blocks_l = int(cute::size<3>(problem_shape_mnkl)); return {uint32_t(problem_blocks_m), uint32_t(problem_blocks_n), uint32_t(problem_blocks_l)}; } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE constexpr static + dim3 + get_grid_shape(ProblemShapeMNKL problem_shape_mnk, BlockShape blk_shape, ClusterShape cluster_shape, KernelHardwareInfo hw_info) { + int const sm_count = hw_info.sm_count; + CUTLASS_TRACE_HOST("get_grid_shape(): Persistent schedule grid plan using SM count = " << sm_count); + // Compute the total number of output tiles our problem has + auto problem_shape_MNKL = append<4>(problem_shape_mnk, Int<1>{}); + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = + get_tiled_blk_shape_mnl(problem_shape_MNKL, blk_shape, cluster_shape); + int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks_l; + + dim3 launch_grid(1, cute::size<1>(cluster_shape), 1); + + // The else path is generic, however, we can avoid some divs if we know Cluster size is 1 + if constexpr (size(cluster_shape) == 1) { + launch_grid.x = std::min(sm_count, problem_blocks_total); + } + else { + /* + * Optimal grid size calculation is based on + * GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU + * Hence, maximum SMs per GPC = 18 + */ + constexpr int max_sm_per_gpc = 18; + // Provided SM count could possibly be less than the assumed maximum SMs per GPC + int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; + int const max_blk_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % size(cluster_shape)); + int blk_per_device = min_num_gpc * max_blk_occupancy_per_gpc; + blk_per_device = sm_count < blk_per_device ? sm_count : blk_per_device; + + launch_grid.x = std::min( + blk_per_device / size<1>(cluster_shape), + problem_blocks_total / size<1>(cluster_shape)); + } + return launch_grid; + } }; } // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/threadblock/default_mma.h b/include/cutlass/gemm/threadblock/default_mma.h index 7e0b206c..69e3f0d2 100644 --- a/include/cutlass/gemm/threadblock/default_mma.h +++ b/include/cutlass/gemm/threadblock/default_mma.h @@ -40,6 +40,7 @@ #include "cutlass/arch/wmma.h" #include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" @@ -100,7 +101,11 @@ template < /// Gather operand A by using an index array bool GatherA = false, /// Gather operand B by using an index array - bool GatherB = false + bool GatherB = false, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute > struct DefaultMma; @@ -137,13 +142,17 @@ template < /// Gather operand A by using an index array bool GatherA, /// Gather operand B by using an index array - bool GatherB + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultMma { + GatherA, GatherB, PermuteALayout, PermuteBLayout> { static_assert(platform::is_same::value || platform::is_same>::value, @@ -159,13 +168,15 @@ struct DefaultMma, - ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, + GatherA, PermuteALayout>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< cutlass::MatrixShape, - ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, + GatherB, PermuteBLayout>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< @@ -207,13 +218,17 @@ template < /// Gather operand A by using an index array bool GatherA, /// Gather operand B by using an index array - bool GatherB + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultMma { + GatherA, GatherB, PermuteALayout, PermuteBLayout> { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, @@ -225,14 +240,14 @@ struct DefaultMma, ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, - GatherA>; + GatherA, PermuteALayout>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< cutlass::MatrixShape, ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, - GatherB>; + GatherB, PermuteBLayout>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< @@ -265,13 +280,17 @@ template < /// Gather operand A by using an index array bool GatherA, /// Gather operand B by using an index array - bool GatherB + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultMma { + GatherA, GatherB, PermuteALayout, PermuteBLayout> { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, @@ -282,13 +301,15 @@ struct DefaultMma, - float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA, + GatherA, PermuteALayout>; // Define iterators over tiles from the B operand using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< cutlass::MatrixShape, - float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB, + GatherB, PermuteBLayout>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined< @@ -333,7 +354,8 @@ struct DefaultMma, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, - Operator, true, SharedMemoryClearOption::kNone, false, false> { + Operator, true, SharedMemoryClearOption::kNone, false, false, + layout::NoPermute, layout::NoPermute> { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, @@ -400,13 +422,17 @@ template < /// Gather operand A by using an index array bool GatherA, /// Gather operand B by using an index array - bool GatherB + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultMma { + GatherA, GatherB, PermuteALayout, PermuteBLayout> { static_assert(platform::is_same::value || platform::is_same>::value, @@ -424,7 +450,7 @@ struct DefaultMma, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA>; + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; @@ -432,7 +458,7 @@ struct DefaultMma, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB>; + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< @@ -479,13 +505,17 @@ template < /// Gather operand A by using an index array bool GatherA, /// Gather operand B by using an index array - bool GatherB + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout > struct DefaultMma { + GatherA, GatherB, PermuteALayout, PermuteBLayout> { static_assert(platform::is_same::value || platform::is_same>::value, @@ -513,7 +543,7 @@ struct DefaultMma, - ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA>; + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; // Define iterators over tiles from the B operand using ThreadMapB = typename MmaCore::IteratorThreadMapB; @@ -521,7 +551,7 @@ struct DefaultMma, - ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB>; + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; // Define the threadblock-scoped multistage matrix multiply using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< @@ -569,7 +599,8 @@ struct DefaultMma, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, - Stages, Operator, true, SharedMemoryClearOption::kNone, false, false> { + Stages, Operator, true, SharedMemoryClearOption::kNone, + false, false, layout::NoPermute, layout::NoPermute> { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, @@ -626,7 +657,8 @@ template < struct DefaultMma, 2, - Operator, false, SharedMemoryClearOption::kNone, false, false> { + Operator, false, SharedMemoryClearOption::kNone, + false, false, layout::NoPermute, layout::NoPermute> { using InstructionShape = GemmShape<1, 1, 4>; using ElementA = int8_t; using ElementB = int8_t; @@ -695,7 +727,7 @@ struct DefaultMma { + false, false, layout::NoPermute, layout::NoPermute> { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, @@ -755,7 +787,7 @@ struct DefaultMma { + false, false, layout::NoPermute, layout::NoPermute> { // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, diff --git a/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h b/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h index 0345084f..91fa4495 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h @@ -186,7 +186,7 @@ template < int kAccessSizeInBits_, /// Operation performed by GEMM typename Operator_> -struct DefaultMmaCoreWithAccessSize>::type, ElementA_, +struct DefaultMmaCoreWithAccessSize>::type, ElementA_, layout::ColumnMajor, ElementB_, layout::RowMajor, ElementC_, LayoutC_, arch::OpClassSimt, kAccessSizeInBits_, 2, Operator_ > { diff --git a/include/cutlass/gemm/threadblock/mma_multistage.h b/include/cutlass/gemm/threadblock/mma_multistage.h index 5f6f8521..34d97487 100644 --- a/include/cutlass/gemm/threadblock/mma_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_multistage.h @@ -157,10 +157,7 @@ class MmaMultistage : // accuracy, where each mainloop iteration first accumulates into a temporary // set of freshly-cleared accumulators, which are subsequently added to the // final accumulator set. - static bool const kStagedAccumulation = - platform::is_same::value || - platform::is_same::value; - + static bool const kStagedAccumulation = arch::UseStagedAccumulation::value; }; private: diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h index 3f8d7a8d..a1d522bc 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h @@ -50,7 +50,6 @@ #endif - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -632,7 +631,7 @@ struct ThreadblockSwizzleStreamK { // Guards needed for PyCUTLASS library generation -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) // // Device-side interface @@ -795,7 +794,7 @@ struct ThreadblockSwizzleStreamK { return get_sk_block_idx(iter); } -#endif // defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#endif // !defined(CUTLASS_PYTHON_HOST_CC) }; diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op.h b/include/cutlass/gemm/warp/default_mma_tensor_op.h index 3421de95..3bb65a43 100644 --- a/include/cutlass/gemm/warp/default_mma_tensor_op.h +++ b/include/cutlass/gemm/warp/default_mma_tensor_op.h @@ -118,6 +118,6 @@ struct DefaultMmaTensorOp { ///////////////////////////////////////////////////////////////////////////////////////////////// -#include "default_mma_tensor_op_sm80.h" +#include "cutlass/gemm/warp/default_mma_tensor_op_sm80.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_complex_tensor_op.h index 7bcf7fe0..58d0c01a 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op.h @@ -819,8 +819,13 @@ class MmaComplexTensorOp< // Define conversions from source type to instruction operands' type // + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + FloatRoundStyle const kRoundA = FloatRoundStyle::round_to_nearest; + FloatRoundStyle const kRoundB = FloatRoundStyle::round_to_nearest; + #else FloatRoundStyle const kRoundA = FloatRoundStyle::round_half_ulp_trunc_dntz; FloatRoundStyle const kRoundB = FloatRoundStyle::round_half_ulp_trunc_dntz; + #endif detail::UnpackComplexConvertAndPackForMma < RealElementA, diff --git a/include/cutlass/layout/permute.h b/include/cutlass/layout/permute.h index 693425b6..73c81707 100644 --- a/include/cutlass/layout/permute.h +++ b/include/cutlass/layout/permute.h @@ -35,7 +35,7 @@ data to describe strides between elements. Permute layout functions must implement all members in the interface of NoPermute<> defined in this file. Address offset - computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_permute_} as new addresses after permute op. + computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_} as new addresses after permute op. */ #pragma once #if defined(__CUDACC_RTC__) @@ -53,25 +53,89 @@ namespace cutlass { namespace layout { -class NoPermute { +// template +// struct PermuteSelect { +// // Try to give a reasonable error message to the user +// static_assert(!platform::is_same::value, // aka always_false +// "You've tried to use a layout permutation for which the implementation is not availble. " +// "In order to provide an implementation for a particular combination of matrix layout " +// "and direction (direct/inverse), please specialize PermuteSelect trait."); +// }; + +// Base template for defining specializations of permutation inverses +template +struct InversePermute +{ + // Try to give a reasonable error message to the user + static_assert(!platform::is_same::value, // aka always_false + "To apply permutation to a GEMM input operand (A or B), an inverse permutation for the desired " + "permute class must be defined and enabled by specializing cutlass::layout::InversePermute trait."); +}; + +class PermuteBase { public: /// Index type used for coordinates using Index = int32_t; /// Long index type used for offsets using LongIndex = int64_t; +}; -private: +class NoPermute : public PermuteBase { +public: // - // Data members + // Methods // - MatrixCoord extent_; + /// Constructor from matrix extent + CUTLASS_HOST_DEVICE + NoPermute(MatrixCoord extent, Index stride) { }; + + /// Constructor from pitch-linear extent + CUTLASS_HOST_DEVICE + NoPermute(PitchLinearCoord extent, Index stride) { }; + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { return 0; } // not correct but should never be called + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { return 0; } // not correct but should never be called +}; + +template<> +struct InversePermute { + using type = NoPermute; +}; + +/// Helper trait to detect if permute operation is a noop +template +bool constexpr is_trivial_permute = platform::is_same::value; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Defines permute layouts of various tensor formats. +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Tensor4DPermute0213 +///////////////////////////////////////////////////////////////////////////////////////////////// - Index stride_unit_; // sizeof(AccessType) / kElementsPerAccess in epilogue's predicated_tile_iterator +/// Permute layout function for 4-D permuted tensors with matrix (dimensions [M, N]) reshaped +/// as [M/D1, D1, D2, N/D2]. Then perform permute([0, 2, 1, 3]) on the corresponding tensor. +template +class Tensor4DPermute0213RowMajor : public PermuteBase { +private: + // + // Data members + // - Index stride_permute_; + Index D3_; + Index stride_; + public: // // Methods @@ -79,42 +143,73 @@ class NoPermute { /// Constructor CUTLASS_HOST_DEVICE - NoPermute() { } + Tensor4DPermute0213RowMajor(MatrixCoord extent, Index stride) { + + assert(extent.row() % D1 == 0); + assert(extent.column() % D2 == 0); + + D3_ = extent.column() / D2; + + stride_ = stride * D1 / D2; + } /// Constructor CUTLASS_HOST_DEVICE - NoPermute(MatrixCoord extent, Index stride_init): extent_(extent) { } + Tensor4DPermute0213RowMajor(PitchLinearCoord extent, Index stride) + : Tensor4DPermute0213RowMajor(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { + + // [i,j,k,l] -> [i,k,j,l] + Index l = coord.column() % D3_; + Index k = coord.column() / D3_; + Index j = coord.row() % D1; + Index i = coord.row() / D1; - /// Computes the address offset after Permute Op in Bytes + MatrixCoord permuted{k + i * D2, l + j * D3_}; + + return LongIndex(permuted.row()) * LongIndex(stride_) + LongIndex(permuted.column()); + } + + /// Computes the offset after Permute Op in logical elements CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord offset_init) { return 0; } + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.strided(), coord.contiguous())); + } }; -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Defines permute layouts of various tensor formats. -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Permute layout function for 4-D permuted tensors with output matrix (dimension as [M, N]) reshaped -/// as [M/D1, D1, D2, N/D2]. Then perform permute([0, 2, 1, 3]) on the corresponding output tensor. +// Inverse for Tensor4DPermute0213 can be implemented by simply swapping D1 and D2 template -class Tensor4DPermute0213 { +class Tensor4DPermute0213RowMajorInverse : public Tensor4DPermute0213RowMajor { public: - /// Index type used for coordinates - using Index = int32_t; + using Base = Tensor4DPermute0213RowMajor; + using Base::Base; +}; - /// Long index type used for offsets - using LongIndex = int64_t; +template +struct InversePermute> { + using type = Tensor4DPermute0213RowMajorInverse; +}; + +template +struct InversePermute> { + using type = Tensor4DPermute0213RowMajor; +}; +/// Permute layout function for 4-D permuted tensors with matrix (dimensions [M, N]) reshaped +/// as [M/D1, D1, D2, N/D2]. Then perform permute([0, 2, 1, 3]) on the corresponding tensor. +template +class Tensor4DPermute0213ColumnMajor : public PermuteBase { private: // // Data members // - MatrixCoord extent_; + Index D0_; - Index stride_permute_; + Index stride_; public: // @@ -123,74 +218,216 @@ class Tensor4DPermute0213 { /// Constructor CUTLASS_HOST_DEVICE - Tensor4DPermute0213() { } + Tensor4DPermute0213ColumnMajor(MatrixCoord extent, Index stride) { + + assert(extent.row() % D1 == 0); + assert(extent.column() % D2 == 0); + + D0_ = extent.row() / D1; + + stride_ = stride * D2 / D1; + } /// Constructor CUTLASS_HOST_DEVICE - Tensor4DPermute0213(MatrixCoord extent, Index stride_init): extent_(extent) { + Tensor4DPermute0213ColumnMajor(PitchLinearCoord extent, Index stride) + : Tensor4DPermute0213ColumnMajor(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { - /// Update stride_permute with stride_init - stride_permute_ = stride_init / D2 * D1; // stride in Elements + // [i,j,k,l] -> [i,k,j,l] + Index l = coord.column() / D2; + Index k = coord.column() % D2; + Index j = coord.row() / D0_; + Index i = coord.row() % D0_; + MatrixCoord permuted{i + k * D0_, j + l * D1}; + + return LongIndex(permuted.row()) + LongIndex(permuted.column()) * LongIndex(stride_); } - - /// Computes the address offset after Permute Op in Bytes + + /// Computes the offset after Permute Op in logical elements CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord offset_init) { - // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X - // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3]. - assert(extent_.row() % D1 == 0); - assert(extent_.column() % D2 == 0); + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.contiguous(), coord.strided())); + } +}; + +// Inverse for Tensor4DPermute0213 can be implemented by simply swapping D1 and D2 +template +class Tensor4DPermute0213ColumnMajorInverse : public Tensor4DPermute0213ColumnMajor { +public: + using Base = Tensor4DPermute0213ColumnMajor; + using Base::Base; +}; + +template +struct InversePermute> { + using type = Tensor4DPermute0213ColumnMajorInverse; +}; + +template +struct InversePermute> { + using type = Tensor4DPermute0213ColumnMajor; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Tensor4DPermuteBMM0213 +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Permute layout function for 4-D permuted tensors for BMM with BMM tensor (dimensions [B, M, N]) reshaped +/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM tensor. +template +class Tensor4DPermuteBMM0213RowMajor : public PermuteBase { +private: + // + // Data members + // - int D3 = extent_.column() / D2; + Index D3_; - Index col_init = offset_init.column(); - Index row_init = offset_init.row(); + Index stride_; + + Index batch_stride_; + +public: + // + // Methods + // - int l = col_init % D3; - int k = col_init / D3; - int j = row_init % D1; - int i = row_init / D1; + /// Constructor + CUTLASS_HOST_DEVICE + Tensor4DPermuteBMM0213RowMajor(MatrixCoord extent, Index stride) { - // After the Permute Op - Index col_permute = l + j * D3; - Index row_permute = k + i * D2; + Index D2 = extent.row(); + D3_ = extent.column(); - return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); + stride_ = stride * D1; + batch_stride_ = D2 * stride_; } - /// Return D1 + /// Constructor CUTLASS_HOST_DEVICE - Index d1() const { - return D1; + Tensor4DPermuteBMM0213RowMajor(PitchLinearCoord extent, Index stride) + : Tensor4DPermuteBMM0213RowMajor(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { + + // The batch index for BMM + Index BMM_batch_idx = blockIdx.z; + + // [i,j,k,l] -> [i,k,j,l] + Index l = coord.column(); + Index k = coord.row(); + Index j = BMM_batch_idx % D1; + Index i = BMM_batch_idx / D1; + + Index pbatch = i; + MatrixCoord pcoord{k, l + j * D3_}; + + return pbatch * LongIndex(batch_stride_) + pcoord.row() * LongIndex(stride_) + pcoord.column(); } - /// Return D2 + /// Computes the offset after Permute Op in logical elements CUTLASS_HOST_DEVICE - Index d2() const { - return D2; + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.strided(), coord.contiguous())); } }; -/// Permute layout function for 4-D permuted tensors for BMM with BMM output tensor (dimension as [B, M, N]) reshaped -/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM output tensor. template -class Tensor4DPermuteBMM0213 { +class Tensor4DPermuteBMM0213RowMajorInverse : public PermuteBase { +private: + // + // Data members + // + + Index D3_; + + Index stride_; + + Index batch_stride_; + public: - /// Index type used for coordinates - using Index = int32_t; + // + // Methods + // - /// Long index type used for offsets - using LongIndex = int64_t; + /// Constructor + CUTLASS_HOST_DEVICE + Tensor4DPermuteBMM0213RowMajorInverse(MatrixCoord extent, Index stride) { + + assert(extent.column() % D1 == 0); + + Index D2 = extent.row(); + D3_ = extent.column() / D1; + + stride_ = stride / D1; + + batch_stride_ = D2 * stride_; + } + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor4DPermuteBMM0213RowMajorInverse(PitchLinearCoord extent, Index stride) + : Tensor4DPermuteBMM0213RowMajorInverse(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { + + // The batch index for BMM + Index BMM_batch_idx = blockIdx.z; + + // TODO: semantics of the original Tensor4DPermuteBMM0213 are unclear. + // The following assumes grouping [(D0)->batch, (D2)->row, (D1,D3)->col] + Index l = coord.column() % D3_; + Index j = coord.column() / D3_; + Index k = coord.row(); + Index i = BMM_batch_idx; + + // compute original [batch, row, col] index + Index pbatch = j + i * D1; + MatrixCoord pcoord{k, l}; + + return pbatch * LongIndex(batch_stride_) + pcoord.row() * LongIndex(stride_) + pcoord.column(); + } + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.strided(), coord.contiguous())); + } +}; + +template +struct InversePermute> { + using type = Tensor4DPermuteBMM0213RowMajorInverse; +}; + +template +struct InversePermute> { + using type = Tensor4DPermuteBMM0213RowMajor; +}; + +/// Permute layout function for 4-D permuted tensors for BMM with BMM tensor (dimensions [B, M, N]) reshaped +/// as [B/D1, D1, M, N]. Then perform permute([0, 3, 2, 1]) on the corresponding whole BMM tensor. +template +class Tensor4DPermuteBMM0321ColumnMajor : public PermuteBase { private: // // Data members // - MatrixCoord extent_; + Index D2_; + + Index stride_; - Index stride_permute_; + Index batch_stride_; public: // @@ -199,70 +436,199 @@ class Tensor4DPermuteBMM0213 { /// Constructor CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0213() { } + Tensor4DPermuteBMM0321ColumnMajor(MatrixCoord extent, Index stride) { + + D2_ = extent.row(); + Index D3 = extent.column(); + + stride_ = stride * D1; + batch_stride_ = stride_ * D3; + } /// Constructor CUTLASS_HOST_DEVICE - Tensor4DPermuteBMM0213(MatrixCoord extent, Index stride_init): extent_(extent) { + Tensor4DPermuteBMM0321ColumnMajor(PitchLinearCoord extent, Index stride) + : Tensor4DPermuteBMM0321ColumnMajor(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { + + Index BMM_batch_idx = blockIdx.z; + + // [i,j,k,l] -> [i,k,j,l] + Index l = coord.column(); + Index k = coord.row(); + Index j = BMM_batch_idx % D1; + Index i = BMM_batch_idx / D1; - /// Update stride_permute with stride_init - stride_permute_ = stride_init * D1; // stride in Elements + Index pbatch = i; + MatrixCoord pcoord{k + j * D2_, l}; + return pbatch * LongIndex(batch_stride_) + pcoord.row() + pcoord.column() * LongIndex(stride_); } + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.contiguous(), coord.strided())); + } +}; + +template +class Tensor4DPermuteBMM0321ColumnMajorInverse : public PermuteBase { +private: + // + // Data members + // + + Index D2_; + + Index stride_; + + Index batch_stride_; - /// Computes the address offset after Permute Op in Bytes +public: + // + // Methods + // + + /// Constructor CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord offset_init) { + Tensor4DPermuteBMM0321ColumnMajorInverse(MatrixCoord extent, Index stride) { - // The batch index for BMM - Index BMM_batch_idx = blockIdx.z; - - // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X - // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3]. - int D2 = extent_.row(); - int D3 = extent_.column(); + assert(extent.row() % D1 == 0); + + D2_ = extent.row() / D1; + Index D3 = extent.column(); + + stride_ = stride / D1; + batch_stride_ = stride_ * D3; + } - Index col_init = offset_init.column(); - Index row_init = offset_init.row(); + /// Constructor + CUTLASS_HOST_DEVICE + Tensor4DPermuteBMM0321ColumnMajorInverse(PitchLinearCoord extent, Index stride) + : Tensor4DPermuteBMM0321ColumnMajorInverse(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { - int l = col_init; - int k = row_init; - int j = BMM_batch_idx % D1; - int i = BMM_batch_idx / D1; + Index BMM_batch_idx = blockIdx.z; + + // The following assumes grouping [(D0)->batch, (D1,D2)->row, (D3)->col] + Index l = coord.column(); + Index k = coord.row() % D2_; + Index j = coord.row() / D2_; + Index i = BMM_batch_idx; - // After the Permute Op - Index col_permute = l + j * D3; - Index row_permute = k + i * D2; + Index pbatch = i * D1 + j; + MatrixCoord pcoord{k, l}; - return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); + return pbatch * LongIndex(batch_stride_) + pcoord.row() + pcoord.column() * LongIndex(stride_); } - /// Return D1 + /// Computes the offset after Permute Op in logical elements CUTLASS_HOST_DEVICE - Index d1() const { - return D1; + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.contiguous(), coord.strided())); } }; +template +struct InversePermute> { + using type = Tensor4DPermuteBMM0321ColumnMajorInverse; +}; + +template +struct InversePermute> { + using type = Tensor4DPermuteBMM0321ColumnMajor; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Tensor5DPermute20314 +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped /// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. template -class Tensor5DPermute20314 { +class Tensor5DPermute20314RowMajor : public PermuteBase { +private: + // + // Data members + // + + Index T0_; + + Index T4_; + + Index stride_; + public: - /// Index type used for coordinates - using Index = int32_t; + // + // Methods + // - /// Long index type used for offsets - using LongIndex = int64_t; + /// Constructor + CUTLASS_HOST_DEVICE + Tensor5DPermute20314RowMajor(MatrixCoord extent, Index stride) { + + assert(extent.row() % T1 == 0); + assert(extent.column() % (T2 * T3) == 0); + + T0_ = extent.row() / T1; + T4_ = extent.column() / (T2 * T3); + + /// Update stride_permute with stride + stride_ = stride / T2 * T1; // stride in Elements + } + /// Constructor + CUTLASS_HOST_DEVICE + Tensor5DPermute20314RowMajor(PitchLinearCoord extent, Index stride) + : Tensor5DPermute20314RowMajor(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} + + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { + + // Permute as torch.permute(X1, [2, 0, 3, 1, 4]) -> 5D Tensor indices as [i,j,k,l,m], the dimension of X + // is [T0, T1, T2, T3, T4], after permutation the dim of X1 is [T2, T0, T3, T1, T4]. + + Index m = coord.column() % T4_; + Index l = (coord.column() / T4_) % T3; + Index k = (coord.column() / T4_) / T3; + Index j = coord.row() % T1; + Index i = coord.row() / T1; + + MatrixCoord permuted{i + k * T0_, m + j * T4_ + l * T1 * T4_}; + + return LongIndex(permuted.row()) * LongIndex(stride_) + LongIndex(permuted.column()); + } + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.strided(), coord.contiguous())); + } +}; + +/// Inverse for Tensor5DPermute20314 (could also be given a proper name, e.g. Tensor5DPermute13024). +template +class Tensor5DPermute20314RowMajorInverse : public PermuteBase { private: // // Data members // - MatrixCoord extent_; + Index T0_; - Index stride_permute_; + Index T4_; + + // Permuted stride in units of elements + Index stride_; public: // @@ -271,41 +637,190 @@ class Tensor5DPermute20314 { /// Constructor CUTLASS_HOST_DEVICE - Tensor5DPermute20314() { } + Tensor5DPermute20314RowMajorInverse(MatrixCoord extent, Index stride) { + + assert(extent.row() % T2 == 0); + assert(extent.column() % (T1 * T3) == 0); + + T0_ = extent.row() / T2; + T4_ = extent.column() / (T1 * T3); + + stride_ = stride / T1 * T2; + } /// Constructor CUTLASS_HOST_DEVICE - Tensor5DPermute20314(MatrixCoord extent, Index stride_init): extent_(extent) { + Tensor5DPermute20314RowMajorInverse(PitchLinearCoord extent, Index stride) + : Tensor5DPermute20314RowMajorInverse(MatrixCoord(extent.strided(), extent.contiguous()), stride) {} - /// Update stride_permute with stride_init - stride_permute_ = stride_init / T2 * T1; // stride in Elements + /// Computes the offset after the inverse of permute operation in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { + Index m = coord.column() % T4_; + Index j = (coord.column() / T4_) % T1; + Index l = (coord.column() / T4_) / T1; + Index i = coord.row() % T0_; + Index k = coord.row() / T0_; + + MatrixCoord permuted{j + i * T1, m + l * T4_ + k * T3 * T4_}; + + return LongIndex(permuted.row()) * LongIndex(stride_) + LongIndex(permuted.column()); + } + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.strided(), coord.contiguous())); } +}; + +template +struct InversePermute> { + using type = Tensor5DPermute20314RowMajorInverse; +}; + +template +struct InversePermute> { + using type = Tensor5DPermute20314RowMajor; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Tensor5DPermute02413 +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Permute layout function for 5-D permuted tensors with matrix (dimensions [M, N]) reshaped +/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([0, 2, 4, 1, 3]) on the corresponding tensor. +template +class Tensor5DPermute02413ColumnMajor : public PermuteBase { +private: + // + // Data members + // + + Index T0_; + + Index T4_; + + Index stride_; - /// Computes the address offset after Permute Op in Bytes +public: + // + // Methods + // + + /// Constructor CUTLASS_HOST_DEVICE - LongIndex operator()(MatrixCoord offset_init) { + Tensor5DPermute02413ColumnMajor(MatrixCoord extent, Index stride) { + + assert(extent.row() % T1 == 0); + assert(extent.column() % (T2 * T3) == 0); + + T0_ = extent.row() / T1; + T4_ = extent.column() / (T2 * T3); + + /// Update stride_permute with stride + stride_ = stride / T1 * T2; // stride in Elements + } + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor5DPermute02413ColumnMajor(PitchLinearCoord extent, Index stride) + : Tensor5DPermute02413ColumnMajor(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { // Permute as torch.permute(X1, [2, 0, 3, 1, 4]) -> 5D Tensor indices as [i,j,k,l,m], the dimension of X - // is [T0, T1, T2, T3, T4], after permutation the dim of X1 is [T2, T0, T3, T1, T4]. - int T0 = extent_.row() / T1; - int T4 = extent_.column() / T2 / T3; + // is [T0, T1, T2, T3, T4], after permutation the dim of X1 is [T0, T2, T4, T1, T3]. + + Index m = (coord.column() / T2) / T3; + Index l = (coord.column() / T2) % T3; + Index k = coord.column() % T2; + Index j = coord.row() / T0_; + Index i = coord.row() % T0_; - Index col_init = offset_init.column(); - Index row_init = offset_init.row(); + MatrixCoord permuted{i + k * T0_, m + j * T4_ + l * T4_ * T1}; - int m = col_init % T4; - int l = int(col_init / T4) % T3; - int k = int(col_init / T4) / T3; - int j = row_init % T1; - int i = row_init / T1; + return LongIndex(permuted.row()) + LongIndex(permuted.column()) * LongIndex(stride_); + } + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.contiguous(), coord.strided())); + } +}; + +/// Inverse for Tensor5DPermute02413ColumnMajor +template +class Tensor5DPermute02413ColumnMajorInverse : public PermuteBase { +private: + // + // Data members + // + + Index T0_; + + Index T4_; - // After the Permute Op - Index col_permute = m + j * T4 + l * T1 * T4; - Index row_permute = i + k * T0; + // Permuted stride in units of elements + Index stride_; + +public: + // + // Methods + // - return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); + /// Constructor + CUTLASS_HOST_DEVICE + Tensor5DPermute02413ColumnMajorInverse(MatrixCoord extent, Index stride) { + + assert(extent.row() % T2 == 0); + assert(extent.column() % (T1 * T3) == 0); + + T0_ = extent.row() / T2; + T4_ = extent.column() / (T1 * T3); + + stride_ = stride / T2 * T1; } + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor5DPermute02413ColumnMajorInverse(PitchLinearCoord extent, Index stride) + : Tensor5DPermute02413ColumnMajorInverse(MatrixCoord(extent.contiguous(), extent.strided()), stride) {} + + /// Computes the offset after the inverse of permute operation in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord coord) const { + + Index m = coord.column() % T4_; + Index j = (coord.column() / T4_) % T1; + Index l = (coord.column() / T4_) / T1; + Index i = coord.row() % T0_; + Index k = coord.row() / T0_; + + MatrixCoord permuted{i + j * T0_, k + l * T2 + m * T2 * T3}; + + return LongIndex(permuted.row()) + LongIndex(permuted.column()) * LongIndex(stride_); + } + + /// Computes the offset after Permute Op in logical elements + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return operator()(MatrixCoord(coord.contiguous(), coord.strided())); + } +}; + +template +struct InversePermute> { + using type = Tensor5DPermute02413ColumnMajorInverse; +}; + +template +struct InversePermute> { + using type = Tensor5DPermute02413ColumnMajor; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 3095cec8..68c259cc 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -545,6 +545,9 @@ struct NumericConverter { unsigned storage = reinterpret_cast(s); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("cvt.rn.tf32.f32 %0, %1;" : "=r"(storage) : "r"(storage)); +#else if ((storage & 0x7f800000) != 0x7f800000) { bool mantissa_bit = ((storage & (1 << 13)) != 0); @@ -570,6 +573,7 @@ struct NumericConverter { else if (storage & ~0xff800000) { storage = 0x7fffffff; } +#endif return tfloat32_t::bitcast(storage); } @@ -716,6 +720,24 @@ struct NumericConverterClamp { } }; +// This converter is needed to enable half_t output types when using int32_t accumulators. +// Since floating-point types do not require a clamp, this converter simply casts from +// the source type to half_t. +template < + typename S +> +struct NumericConverterClamp { + + using result_type = cutlass::half_t; + using source_type = S; + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return static_cast(s); + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// // // Conversion operator for Array @@ -1989,7 +2011,7 @@ struct NumericArrayConverter { ///////////////////////////////////////////////////////////////////////////////////////////////// // -// Partial specialziations for: +// Partial specializations for: // Array <=> Array // Array <=> Array // using packed converter under the hood @@ -2414,11 +2436,13 @@ struct PreferredRoundingMode { static FloatRoundStyle const kRound = FloatRoundStyle::round_to_nearest; }; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 900 /// Defines preferred rounding mode for a pair of types template <> struct PreferredRoundingMode { static FloatRoundStyle const kRound = FloatRoundStyle::round_half_ulp_truncate; }; +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline.hpp b/include/cutlass/pipeline.hpp deleted file mode 100644 index 67538aea..00000000 --- a/include/cutlass/pipeline.hpp +++ /dev/null @@ -1,529 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are not permit- - * ted. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/numeric/integral_constant.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/barrier.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -using namespace arch; -using namespace cute; - -// Circular Buffer Index + Associated Phase -// Assumes only one operation possible - i.e., ++ -template -struct PipelineState { - - static constexpr uint32_t Stages = Stages_; - -private: - int index_ = 0; - uint32_t phase_ = 0; - -public: - CUTLASS_DEVICE - PipelineState(): index_{}, phase_{} {} - - CUTLASS_DEVICE - PipelineState(int index, uint32_t phase) - : index_(index) - , phase_(phase){} - - CUTLASS_DEVICE - int index() const { - return index_; - } - - CUTLASS_DEVICE - uint32_t phase() const { - return phase_; - } - - CUTLASS_DEVICE - void operator++() { - ++index_; - if (index_ == Stages) { - index_ = 0; - phase_ ^= 1; - } - } - - CUTLASS_DEVICE - PipelineState& operator=(const PipelineState& other) { - index_ = other.index(); - phase_ = other.phase(); - return *this; - } - - CUTLASS_DEVICE - PipelineState advance(uint32_t num_iterations) { - // Number of iterations cross over the stage boundary => flipped phase - if ((num_iterations < Stages) && (index_ + num_iterations) >= Stages ) { - phase_ ^= 1; - } - // How many times number of iterations cross over the stage boundary and - // end up on a odd number => flipped phase - if ((num_iterations >= Stages) && (((index_ + num_iterations) / Stages) % 2) == 1) { - phase_ ^= 1; - } - index_ = (index_ + num_iterations) % Stages; - return *this; - } - - CUTLASS_DEVICE - static PipelineState make_pipeline_state(PipelineState start_state, uint32_t num_iterations) { - return start_state.advance(num_iterations); - } -}; - -template -CUTLASS_DEVICE -PipelineState make_producer_start_state() -{ - // Producer starts with an opposite phase as the buffer are initially empty - constexpr int InitialProducerStage = 0; - constexpr uint32_t InitialProducerPhase = 1; - return {InitialProducerStage, InitialProducerPhase}; -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// TMA (producer) Async Pipeline class -// -/////////////////////////////////////////////////////////////////////////////////////////////////// -// Assumptions : Constructor is Visible Cluster-wide (as it needs a Cluster-Sync) -// We have exactly one thread elected in the Producer as the "leader" -// Currently, it is optional to elect a leader for the Consumers -template -class PipelineTmaAsync { -public : - using ClusterShape = ClusterShape_; - using FullBarrier = ClusterTransactionBarrier; - using EmptyBarrier = ClusterBarrier; - using ValueType = FullBarrier::ValueType; - static constexpr uint32_t Stages = Stages_; - - struct SharedStorage { - FullBarrier full_barrier_[Stages]; - EmptyBarrier empty_barrier_[Stages]; - }; - - enum class ThreadCategory { - NonParticipant, - Producer, - Consumer, - ProducerConsumer - }; - - struct Params { - uint32_t transaction_bytes = 0; - ThreadCategory role = ThreadCategory::NonParticipant; - uint32_t is_leader = 0; - uint32_t num_consumers = 0; - }; - -private : - // - // Data Members - // - uint32_t dst_blockid_ = 0; - uint32_t is_signalling_thread_ = 0; - FullBarrier *full_barrier_ptr_ = nullptr; - EmptyBarrier *empty_barrier_ptr_ = nullptr; - Params params_; - - // - // Methods - // - -public: - // Constructor - CUTLASS_DEVICE - PipelineTmaAsync(SharedStorage& storage, Params params) - : params_(params) - , full_barrier_ptr_(&storage.full_barrier_[0]) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) { - - int warp_idx = canonical_warp_idx(); - int lane_predicate = cute::elect_one_sync(); - auto cluster_shape = ClusterShape{}; - - if (warp_idx == 0 && lane_predicate == 1) { - // Barrier FULL init - for (int i = 0; i < Stages; ++i) { - full_barrier_ptr_[i].init(1); - } - - // Barrier EMPTY init - uint32_t const num_consumers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; - for (int i = 0; i < Stages; ++i) { - empty_barrier_ptr_[i].init(num_consumers); - } - } - - // Logic to optimally schedule Empty Arrives - // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) - dim3 block_id = block_id_in_cluster(); - auto cluster_size = cute::size(cluster_shape); - static constexpr int MaxClusterSize = 16; - static_assert(cluster_size <= MaxClusterSize, "ERROR : Cluster size too large !" ); - - // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) - if (params_.num_consumers == 128) { - int thread_idx = threadIdx.x % 128; - is_signalling_thread_ = (thread_idx % (128 / MaxClusterSize)) == 0; - auto layout = cute::composition(Swizzle<2,0,-2>{}, - Layout,Stride<_4, _1>>{}); - uint32_t thread_row = warp_idx % 4; - uint32_t thread_col = (thread_idx / 8) % 4; - dst_blockid_ = layout(thread_row, thread_col); - } - else if (params_.num_consumers == 32){ - int thread_idx = threadIdx.x % 32; - is_signalling_thread_ = (thread_idx % (32 / MaxClusterSize)) == 0; - auto layout = Layout,Stride<_4, _1>>{}; - uint32_t thread_row = thread_idx / 8; - uint32_t thread_col = (thread_idx % 8) / 2; - dst_blockid_ = layout(thread_row, thread_col); - } - else { - is_signalling_thread_ = 0; - } - - // STEP 2: Find if this dst block-id needs an arrival for this problem - is_signalling_thread_ &= dst_blockid_ < cluster_size; - is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); - - cutlass::arch::fence_barrier_init(); - } - - CUTLASS_DEVICE - void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait = false) { - // 1. Wait for empty barrier to be ready - // 2. Set the transaction bytes set to occur on the Full barrier - uint32_t done = empty_barrier_ptr_[stage].test_wait(phase, (!skip_wait)); - if ((!done) && (!skip_wait)){ - empty_barrier_ptr_[stage].wait(phase); - } - - if (params_.is_leader) { - full_barrier_ptr_[stage].arrive_and_reset_bytes(params_.transaction_bytes); - } - - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state) { - producer_acquire(state.index(), state.phase()); - } - - // NOP for TMA based mainloop - CUTLASS_DEVICE - void producer_commit(uint32_t stage, uint32_t bytes) { - // Below code is used only for unit-testing (in the absennce of TMA commit) - #if CUTLASS_UNIT_TEST_PIPELINE - if (params_.is_leader) { - // STEP 1 : Commit to self - full_barrier_ptr_[stage].commit(bytes); - - // STEP 2 : Commit to other blocks in our cluster - auto cluster_shape = ClusterShape{}; - Layout block_layout_in_cluster = make_layout(cluster_shape); - dim3 local_block_id = cute::block_id_in_cluster(); - - CUTLASS_PRAGMA_UNROLL - for(int n = 0; n < size<1>(block_layout_in_cluster); ++n) { - uint32_t dst_block_id = block_layout_in_cluster(local_block_id.x,n,Int<0>{}); - full_barrier_ptr_[stage].commit(dst_block_id, bytes, n!=local_block_id.y); - } - - CUTLASS_PRAGMA_UNROLL - for(int m = 0; m < size<0>(block_layout_in_cluster); ++m) { - uint32_t dst_block_id = block_layout_in_cluster(m,local_block_id.y,Int<0>{}); - full_barrier_ptr_[stage].commit(dst_block_id, bytes, m!=local_block_id.x); - } - } - #endif - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state, uint32_t bytes) { - producer_commit(state.index(), bytes); - } - - - // Wait for producer to commit transactions (done by TMA) - CUTLASS_DEVICE - void consumer_wait(uint32_t stage, uint32_t phase) { - uint32_t done = full_barrier_ptr_[stage].test_wait(phase); - if (!done){ - full_barrier_ptr_[stage].wait(phase); - } - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state) { - consumer_wait(state.index(), state.phase()); - } - - // Consumer signalling Producer of completion - // Ensures all blocks in the Same Row and Column get notifed. - CUTLASS_DEVICE - void consumer_release(uint32_t stage, uint32_t skip = false) { - empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signalling_thread_ & (!skip)); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index()); - } - - CUTLASS_DEVICE - ValueType* producer_get_barrier(uint32_t stage) { - return reinterpret_cast(&full_barrier_ptr_[stage]); - } - - CUTLASS_DEVICE - bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { - return ((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x || - (dst_block_id / cute::size<0>(cluster_shape)) == block_id.y); - } -}; - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Simple producer-consumer async Pipeline class -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -// *Count Signifies the number of producers / consumers who will announce their completion - -template -class PipelineAsync { -public : - using FullBarrier = ClusterBarrier; - using EmptyBarrier = ClusterBarrier; - using ProducerBarrierType = FullBarrier::ValueType; - static constexpr uint32_t Stages = Stages_; - - struct SharedStorage { - FullBarrier full_barrier_[Stages]; - EmptyBarrier empty_barrier_[Stages]; - }; - - enum class ThreadCategory { - NonParticipant, - Producer, - Consumer, - ProducerConsumer - }; - - struct Params { - ThreadCategory role = ThreadCategory::NonParticipant; - uint32_t producer_arv_count = 1; - uint32_t consumer_arv_count = 1; - uint32_t dst_blockid = cute::block_rank_in_cluster(); - }; - -private: - // - // Data Members - // - Params params_; - FullBarrier *full_barrier_ptr_; - EmptyBarrier *empty_barrier_ptr_; - -public: - - // Default assumption when only storage is passed is : - // => single producer, single consumer & they are in the same block (within the Cluster) - CUTLASS_DEVICE - PipelineAsync(SharedStorage& storage) - : PipelineAsync(storage, {}) {} - - CUTLASS_DEVICE - PipelineAsync( - SharedStorage& storage, - Params const& params) : - params_(params), - full_barrier_ptr_(&storage.full_barrier_[0]), - empty_barrier_ptr_(&storage.empty_barrier_[0]) { - - int warp_idx = canonical_warp_idx(); - int lane_predicate = cute::elect_one_sync(); - - // Barrier FULL, EMPTY init - // Init is done only by thread 0 of the block - if (warp_idx == 0 && lane_predicate == 1) { - for (int i = 0; i < Stages; ++i) { - full_barrier_ptr_[i].init(params.producer_arv_count); - empty_barrier_ptr_[i].init(params.consumer_arv_count); - } - } - - cutlass::arch::fence_barrier_init(); - } - - CUTLASS_DEVICE - void producer_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait = false) { - uint32_t done = empty_barrier_ptr_[stage].test_wait(phase, (!skip_wait)); - if ((!done) && (!skip_wait)){ - empty_barrier_ptr_[stage].wait(phase); - } - } - - CUTLASS_DEVICE - void producer_acquire(PipelineState state) { - producer_acquire(state.index(), state.phase()); - } - - CUTLASS_DEVICE - void producer_commit(uint32_t stage) { - full_barrier_ptr_[stage].arrive(); - } - - CUTLASS_DEVICE - void producer_commit(PipelineState state) { - producer_commit(state.index()); - } - - CUTLASS_DEVICE - void consumer_wait(uint32_t stage, uint32_t phase) { - uint32_t done = full_barrier_ptr_[stage].test_wait(phase); - if (!done){ - full_barrier_ptr_[stage].wait(phase); - } - } - - CUTLASS_DEVICE - void consumer_wait(PipelineState state) { - consumer_wait(state.index(), state.phase()); - } - - CUTLASS_DEVICE - void consumer_release(uint32_t stage, uint32_t skip = false) { - empty_barrier_ptr_[stage].arrive(params_.dst_blockid, (not skip)); - } - - CUTLASS_DEVICE - void consumer_release(PipelineState state) { - consumer_release(state.index()); - } - - CUTLASS_DEVICE - ProducerBarrierType* get_producer_barrier(uint32_t stage) { - return reinterpret_cast(&full_barrier_ptr_[stage]); - } -}; - - - -/////////////////////////////////////////////////////////////////////////////////////////////////// -// -// Barrier to ensure an Ordered Sequence between -// SequenceLength number of groups (each with group_size participants) executing SequenceDepth Stages -// i.e., for all i < j - only after id "i" arrives at a particular stage "m" -// will the wait() for id "j" succeed for the same stage -// -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class OrderedSequenceBarrier { -public : - using Barrier = ClusterBarrier; - - struct SharedStorage { - Barrier barrier_[SequenceDepth][SequenceLength]; - }; - - struct Params { - uint32_t group_id; - uint32_t group_size; - }; - -private : - // - // Data Members - // - - // In future this Params object can be replaced easily with a CG object - Params params_; - Barrier *barrier_ptr_; - PipelineState stage_; - - static constexpr int Depth = SequenceDepth; - static constexpr int Length = SequenceLength; - -public: - OrderedSequenceBarrier() = delete; - OrderedSequenceBarrier(const OrderedSequenceBarrier&) = delete; - OrderedSequenceBarrier(OrderedSequenceBarrier&&) = delete; - OrderedSequenceBarrier& operator=(const OrderedSequenceBarrier&) = delete; - OrderedSequenceBarrier& operator=(OrderedSequenceBarrier&&) = delete; - ~OrderedSequenceBarrier() = default; - - CUTLASS_DEVICE - OrderedSequenceBarrier(SharedStorage& storage, Params const& params) : - params_(params), - barrier_ptr_(&storage.barrier_[0][0]), - // Group 0 - starts with an opposite phase - stage_({0, params.group_id == 0}) { - - int warp_idx = canonical_warp_idx(); - int lane_predicate = cute::elect_one_sync(); - - // Barrier FULL, EMPTY init - // Init is done only by the one elected thread of the block - if (warp_idx == 0 && lane_predicate == 1) { - for (int d = 0; d < Depth; ++d) { - for (int l = 0; l < Length; ++l) { - barrier_ptr_[d * Length + l].init(params.group_size); - } - } - } - - cutlass::arch::fence_barrier_init(); - } - - // Wait on a stage to be unlocked - CUTLASS_DEVICE - void wait() { - get_barrier_for_current_stage(params_.group_id).wait(stage_.phase()); - } - - // Signal completion of Stage and move to the next stage - // (group_id) signals to (group_id+1) - CUTLASS_DEVICE - void arrive() { - int signalling_id = (params_.group_id + 1) % Length; - get_barrier_for_current_stage(signalling_id).arrive(); - ++stage_; - } - -private: - - CUTLASS_DEVICE - Barrier& get_barrier_for_current_stage(int group_id) { - return barrier_ptr_[stage_.index() * Length + group_id]; - } -}; - -} // end namespace cutlass diff --git a/include/cutlass/pipeline/pipeline.hpp b/include/cutlass/pipeline/pipeline.hpp new file mode 100644 index 00000000..246e6fa4 --- /dev/null +++ b/include/cutlass/pipeline/pipeline.hpp @@ -0,0 +1,36 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/pipeline/sm90_pipeline.hpp" +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp new file mode 100644 index 00000000..d90a7f14 --- /dev/null +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -0,0 +1,989 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/numeric/integral_constant.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; + +enum class BarrierStatus : uint32_t { + WaitAgain = 0u, + WaitDone = 1u +}; + +class ArrivalToken { +public: + CUTLASS_HOST_DEVICE + ArrivalToken(BarrierStatus barrier_status) : barrier_status_(barrier_status) {} + + CUTLASS_HOST_DEVICE + ArrivalToken() = delete; + + CUTLASS_HOST_DEVICE + BarrierStatus get() const { + return barrier_status_;; + } + + CUTLASS_HOST_DEVICE + bool operator==(ArrivalToken const& other) const { + return barrier_status_ == other.get(); + } + +private: + BarrierStatus barrier_status_; + + CUTLASS_HOST_DEVICE + friend bool operator==(const ArrivalToken& left, const BarrierStatus& right) { + return left.get() == right; + } + + CUTLASS_HOST_DEVICE + friend bool operator==(const BarrierStatus& left, const ArrivalToken& right) { + return left == right.get(); + } +}; + +class ProducerToken : public ArrivalToken { + using ArrivalToken::ArrivalToken; +}; + +class ConsumerToken : public ArrivalToken { + using ArrivalToken::ArrivalToken; +}; + +// Circular Buffer Index + Associated Phase +// Assumes only one operation possible - i.e., ++ +template +struct PipelineState { + + static constexpr uint32_t Stages = Stages_; + +private: + int index_ = 0; + uint32_t phase_ = 0; + uint32_t phase_count_ = 0; + +public: + CUTLASS_DEVICE + PipelineState(): index_{}, phase_{}, phase_count_{} {} + + CUTLASS_DEVICE + PipelineState(int index, uint32_t phase, uint32_t phase_count) + : index_(index) + , phase_(phase) + , phase_count_(phase_count) {} + + CUTLASS_DEVICE + int index() const { + return index_; + } + + CUTLASS_DEVICE + uint32_t phase() const { + return phase_; + } + + CUTLASS_DEVICE + uint32_t phase_count() const { + return phase_count_; + } + + CUTLASS_DEVICE + void operator++() { + if constexpr (Stages > 0) { + ++index_; + if (index_ == Stages) { + index_ = 0; + phase_ ^= 1; + ++phase_count_; + } + } + } + + CUTLASS_DEVICE + PipelineState& operator=(const PipelineState& other) { + index_ = other.index(); + phase_ = other.phase(); + phase_count_ = other.phase_count(); + return *this; + } + + CUTLASS_DEVICE + PipelineState advance(uint32_t num_iterations) { + if constexpr (Stages > 0) { + // Number of iterations cross over the stage boundary => flipped phase + if ((num_iterations < Stages) && (index_ + num_iterations) >= Stages ) { + phase_ ^= 1; + } + // How many times number of iterations cross over the stage boundary and + // end up on a odd number => flipped phase + if ((num_iterations >= Stages) && (((index_ + num_iterations) / Stages) % 2) == 1) { + phase_ ^= 1; + } + phase_count_ += (index_ + num_iterations) / Stages; + index_ = (index_ + num_iterations) % Stages; + } + return *this; + } + + CUTLASS_DEVICE + static PipelineState make_pipeline_state(PipelineState start_state, uint32_t num_iterations) { + return start_state.advance(num_iterations); + } +}; + +template +CUTLASS_DEVICE +PipelineState make_producer_start_state() { + // Producer starts with an opposite phase as the buffers are initially empty + constexpr int InitialProducerStage = 0; + constexpr uint32_t InitialProducerPhase = 1; + constexpr uint32_t InitialProducerPhaseCount = 0; + return {InitialProducerStage, InitialProducerPhase, InitialProducerPhaseCount}; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA load (producer) Async Pipeline class +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Assumptions : Constructor is visible Cluster-wide (as it needs a Cluster-Sync) +// We have exactly one thread elected in the Producer as the "leader" +// Currently, it is optional to elect a leader for the Consumers +template < + int Stages_, + class ClusterShape_ +> +class PipelineTmaAsync { +public : + using ClusterShape = ClusterShape_; + using FullBarrier = cutlass::arch::ClusterTransactionBarrier; + using EmptyBarrier = cutlass::arch::ClusterBarrier; + using ProducerBarrierType = FullBarrier::ValueType; + using ConsumerBarrierType = EmptyBarrier::ValueType; + static constexpr uint32_t Stages = Stages_; + + struct SharedStorage { + FullBarrier full_barrier_[Stages]; + EmptyBarrier empty_barrier_[Stages]; + }; + + enum class ThreadCategory { + NonParticipant, + Producer, + Consumer, + ProducerConsumer + }; + + struct Params { + uint32_t transaction_bytes = 0; + ThreadCategory role = ThreadCategory::NonParticipant; + uint32_t is_leader = 0; + uint32_t num_consumers = 0; + }; + + // Constructor + CUTLASS_DEVICE + PipelineTmaAsync(SharedStorage& storage, Params params) + : params_(params) + , full_barrier_ptr_(&storage.full_barrier_[0]) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + int warp_idx = canonical_warp_idx(); + int lane_predicate = cute::elect_one_sync(); + auto cluster_shape = ClusterShape{}; + + if (warp_idx == 0 && lane_predicate == 1) { + // Barrier FULL init + for (int i = 0; i < Stages; ++i) { + full_barrier_ptr_[i].init(1); + } + // Barrier EMPTY init + uint32_t const num_consumer_warpgroups_per_cluster = params_.num_consumers / NumThreadsPerWarpGroup; + uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * + num_consumer_warpgroups_per_cluster; + for (int i = 0; i < Stages; ++i) { + empty_barrier_ptr_[i].init(multicast_consumer_arrival_count); + } + } + + // Logic to optimally schedule Empty Arrives + // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) + dim3 block_id = cute::block_id_in_cluster(); + auto cluster_size = cute::size(cluster_shape); + static constexpr int MaxClusterSize = 16; + static_assert(cluster_size <= MaxClusterSize, "ERROR : Cluster size too large !" ); + + // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) + if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { + int thread_idx = threadIdx.x % NumThreadsPerWarpGroup; + is_signalling_thread_ = (thread_idx % (NumThreadsPerWarpGroup / MaxClusterSize)) == 0; + auto layout = cute::composition(Swizzle<2,0,-2>{}, + Layout,Stride<_4,_1>>{}); + uint32_t thread_row = warp_idx % 4; + uint32_t thread_col = (thread_idx / 8) % 4; + dst_blockid_ = layout(thread_row, thread_col); + } + else if (params_.num_consumers == 32) { + int thread_idx = threadIdx.x % 32; + is_signalling_thread_ = (thread_idx % (32 / MaxClusterSize)) == 0; + auto layout = Layout,Stride<_4, _1>>{}; + uint32_t thread_row = thread_idx / 8; + uint32_t thread_col = (thread_idx % 8) / 2; + dst_blockid_ = layout(thread_row, thread_col); + } + else { + is_signalling_thread_ = 0; + #ifndef NDEBUG + asm volatile ("brkpt;\n" ::); + #endif + } + + // STEP 2: Find if this dst block-id needs an arrival for this problem + is_signalling_thread_ &= dst_blockid_ < cluster_size; + is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); + + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { + return ((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x || + (dst_block_id / cute::size<0>(cluster_shape)) == block_id.y); + } + + //////////////////// + // Producer APIs + //////////////////// + // Four member functions are always used in pairs: + // + // * producer_try_acquire and producer_acquire, and + // * consumer_try_wait and consumer_wait. + // + // The two functions with "try" in their names are called "try" functions, + // and the other two are conceptually "finalize" functions. + // The "try" function in each pair starts the process of waiting on the barrier to flip. + // It opportunistically waits for an implementation-dependent timeout. + // Whether or not the barrier has flipped yet, the try function will return a token. + // If the token indicates that the barrier has not flipped, + // then the token must be passed into the corresponding "finalize" function. + // The finalize function will then block until the barrier has flipped. + // If the token indicates that the barrier _has_ flipped, + // then it is still correct to pass it into the finalize function. + // The finalize function will return immediately in that case. + + CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return producer_try_acquire(state.index(), state.phase(), skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire(state.index(), state.phase(), barrier_token); + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state, uint32_t bytes) { + producer_commit(state.index(), bytes); + } + + // Prevents early exit of producer blocks in Cluster. + // This should be called once before kernel exits. + CUTLASS_DEVICE + void producer_tail(PipelineState state) { + for (int count = 0; count < Stages; ++count) { + producer_acquire(state); + ++state; + } + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return producer_get_barrier(state.index()); + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return consumer_try_wait(state.index(), state.phase(), skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state) { + consumer_wait(state.index(), state.phase()); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token) { + consumer_wait(state.index(), state.phase(), barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + +private : + uint32_t dst_blockid_ = 0; + uint32_t is_signalling_thread_ = 0; + FullBarrier *full_barrier_ptr_ = nullptr; + EmptyBarrier *empty_barrier_ptr_ = nullptr; + Params params_; + + CUTLASS_DEVICE + ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + if (skip_wait) { + return {BarrierStatus::WaitDone}; + } + uint32_t barrier_status = empty_barrier_ptr_[stage].try_wait(phase); + return {static_cast(barrier_status)}; + } + + CUTLASS_DEVICE + void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + if (barrier_token == BarrierStatus::WaitAgain) { + empty_barrier_ptr_[stage].wait(phase); + } + + if (params_.is_leader) { + full_barrier_ptr_[stage].arrive_and_reset_bytes(params_.transaction_bytes); + } + #ifndef NDEBUG + if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { + asm volatile ("brkpt;\n" ::); + } + + // Most likely you have elected more than one leader + if (params_.is_leader && (threadIdx.x % 32 != 0)) { + asm volatile ("brkpt;\n" ::); + } + #endif + } + + // NOP for TMA based mainloop + CUTLASS_DEVICE + void producer_commit(uint32_t stage, uint32_t bytes) { + // Below code is used only for unit-testing (in the absence of TMA commit) + #if CUTLASS_UNIT_TEST_PIPELINE + if (params_.is_leader) { + // STEP 1 : Commit to self + full_barrier_ptr_[stage].commit(bytes); + + // STEP 2 : Commit to other blocks in our cluster + auto cluster_shape = ClusterShape{}; + Layout block_layout_in_cluster = make_layout(cluster_shape); + dim3 local_block_id = cute::block_id_in_cluster(); + + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < size<1>(block_layout_in_cluster); ++n) { + uint32_t dst_block_id = block_layout_in_cluster(local_block_id.x,n,Int<0>{}); + full_barrier_ptr_[stage].commit(dst_block_id, bytes, n!=local_block_id.y); + } + + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < size<0>(block_layout_in_cluster); ++m) { + uint32_t dst_block_id = block_layout_in_cluster(m,local_block_id.y,Int<0>{}); + full_barrier_ptr_[stage].commit(dst_block_id, bytes, m!=local_block_id.x); + } + } + #endif + } + + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + if (skip_wait) { + return {BarrierStatus::WaitDone}; + } + uint32_t barrier_status = full_barrier_ptr_[stage].try_wait(phase); + return {static_cast(barrier_status)}; + } + + // Wait for producer to commit transactions (done by TMA) + CUTLASS_DEVICE + void consumer_wait(uint32_t stage, uint32_t phase) { + uint32_t done = full_barrier_ptr_[stage].test_wait(phase); + if (not done) { + full_barrier_ptr_[stage].wait(phase); + } + } + + // Wait for producer to commit transactions (done by TMA) + CUTLASS_DEVICE + void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + if (barrier_token == BarrierStatus::WaitAgain) { + consumer_wait(stage, phase); + } + } + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip = false) { + empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signalling_thread_ & (!skip)); + #ifndef NDEBUG + if (params_.role == ThreadCategory::Producer || params_.role == ThreadCategory::NonParticipant) { + asm volatile ("brkpt;\n" ::); + } + #endif + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(uint32_t stage) { + return reinterpret_cast(&full_barrier_ptr_[stage]); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA store (consumer) pipeline class +// producer-only class, no async barriers between threads because consumer is TMA unit +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template < + int Stages_ +> +class PipelineTmaStore { +public: + static constexpr uint32_t Stages = Stages_; + + struct Params { + bool always_wait = false; + }; + + CUTLASS_DEVICE + PipelineTmaStore(Params params = {}) : params_(params) {} + + //////////////////// + // Producer APIs + //////////////////// + // Wait for the least recently committed batch of TMA stores to complete + CUTLASS_DEVICE + void producer_acquire(PipelineState state) { + producer_acquire(state.index(), state.phase_count()); + } + + // Commit the most recently issued batch of TMA stores + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + producer_commit(state.index(), state.phase_count()); + } + + // Wait for all TMA stores to complete + CUTLASS_DEVICE + void producer_tail([[maybe_unused]] PipelineState state) { + tma_store_wait<0>(); + } + +private: + Params params_; + + // Wait for the least recently committed batch of TMA stores to complete + CUTLASS_DEVICE + void producer_acquire([[maybe_unused]] uint32_t stage, uint32_t phase_count) { + if (params_.always_wait || phase_count > 0) { + tma_store_wait(); + } + } + + // Commit the most recently issued batch of TMA stores + CUTLASS_DEVICE + void producer_commit([[maybe_unused]] uint32_t stage, [[maybe_unused]] uint32_t phase_count) { + tma_store_arrive(); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Simple producer-consumer async Pipeline class using producer transaction barriers +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class PipelineTransactionAsync { +public : + using FullBarrier = cutlass::arch::ClusterTransactionBarrier; + using EmptyBarrier = cutlass::arch::ClusterBarrier; + using ProducerBarrierType = FullBarrier::ValueType; + using ConsumerBarrierType = EmptyBarrier::ValueType; + static constexpr uint32_t Stages = Stages_; + + struct SharedStorage { + FullBarrier full_barrier_[Stages]; + EmptyBarrier empty_barrier_[Stages]; + }; + + enum class ThreadCategory { + NonParticipant, + Producer, + Consumer, + ProducerConsumer + }; + + struct Params { + ThreadCategory role = ThreadCategory::NonParticipant; + uint32_t transaction_bytes = 0; + uint32_t producer_arv_count = 1; + uint32_t consumer_arv_count = 1; + uint32_t dst_blockid = cute::block_rank_in_cluster(); + }; + + // Constructor + CUTLASS_DEVICE + PipelineTransactionAsync(SharedStorage& storage, Params const& params) + : params_(params) + , full_barrier_ptr_(&storage.full_barrier_[0]) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + int warp_idx = canonical_warp_idx(); + int lane_predicate = cute::elect_one_sync(); + + // Barrier FULL, EMPTY init + // Init is done only by thread 0 of the block + if (warp_idx == 0 && lane_predicate == 1) { + for (int i = 0; i < Stages; ++i) { + full_barrier_ptr_[i].init(params.producer_arv_count); + empty_barrier_ptr_[i].init(params.consumer_arv_count); + } + } + + cutlass::arch::fence_barrier_init(); + } + + //////////////////// + // Producer APIs + //////////////////// + // Four member functions are always used in pairs: + // + // * producer_try_acquire and producer_acquire, and + // * consumer_try_wait and consumer_wait. + // + // The two functions with "try" in their names are called "try" functions, + // and the other two are conceptually "finalize" functions. + // The "try" function in each pair starts the process of waiting on the barrier to flip. + // It opportunistically waits for an implementation-dependent timeout. + // Whether or not the barrier has flipped yet, the try function will return a token. + // If the token indicates that the barrier has not flipped, + // then the token must be passed into the corresponding "finalize" function. + // The finalize function will then block until the barrier has flipped. + // If the token indicates that the barrier _has_ flipped, + // then it is still correct to pass it into the finalize function. + // The finalize function will return immediately in that case. + CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return producer_try_acquire(state.index(), state.phase(), skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire(state.index(), state.phase(), barrier_token); + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + producer_commit(state.index()); + } + + // Prevents early exit of producer blocks in Cluster. + // This should be called once before kernel exits. + CUTLASS_DEVICE + void producer_tail(PipelineState state) { + for (int count = 0; count < Stages; ++count) { + producer_acquire(state); + ++state; + } + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return producer_get_barrier(state.index()); + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return consumer_try_wait(state.index(), state.phase(), skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + consumer_wait(state.index(), state.phase(), barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + +protected: + FullBarrier *full_barrier_ptr_ = nullptr; + EmptyBarrier *empty_barrier_ptr_ = nullptr; + Params params_; + + CUTLASS_DEVICE + ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + if (skip_wait) { + return {BarrierStatus::WaitDone}; + } + uint32_t barrier_status = empty_barrier_ptr_[stage].try_wait(phase); + return {static_cast(barrier_status)}; + } + + CUTLASS_DEVICE + void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + if (barrier_token == BarrierStatus::WaitAgain) { + empty_barrier_ptr_[stage].wait(phase); + } + + full_barrier_ptr_[stage].arrive_and_reset_bytes(params_.transaction_bytes, params_.dst_blockid); + } + + CUTLASS_DEVICE + void producer_commit([[maybe_unused]] uint32_t stage) { + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(uint32_t stage) { + return reinterpret_cast(&full_barrier_ptr_[stage]); + } + + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + if (skip_wait) { + return {BarrierStatus::WaitDone}; + } + uint32_t barrier_status = full_barrier_ptr_[stage].try_wait(phase); + return {static_cast(barrier_status)}; + } + + CUTLASS_DEVICE + void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + if (barrier_token == BarrierStatus::WaitAgain) { + full_barrier_ptr_[stage].wait(phase); + } + } + + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip = false) { + empty_barrier_ptr_[stage].arrive(params_.dst_blockid, (not skip)); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Simple producer-consumer async Pipeline class +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class PipelineAsync { +public : + using FullBarrier = cutlass::arch::ClusterBarrier; + using EmptyBarrier = cutlass::arch::ClusterBarrier; + using ProducerBarrierType = FullBarrier::ValueType; + using ConsumerBarrierType = EmptyBarrier::ValueType; + static constexpr uint32_t Stages = Stages_; + + struct SharedStorage { + FullBarrier full_barrier_[Stages]; + EmptyBarrier empty_barrier_[Stages]; + }; + + enum class ThreadCategory { + NonParticipant, + Producer, + Consumer, + ProducerConsumer + }; + + struct Params { + ThreadCategory role = ThreadCategory::NonParticipant; + uint32_t producer_arv_count = 1; + uint32_t consumer_arv_count = 1; + uint32_t dst_blockid = cute::block_rank_in_cluster(); + }; + + // Default assumption when only storage is passed is : + // => single producer, single consumer & they are in the same block (within the Cluster) + CUTLASS_DEVICE + PipelineAsync(SharedStorage& storage) + : PipelineAsync(storage, {}) {} + + CUTLASS_DEVICE + PipelineAsync( + SharedStorage& storage, + Params const& params) : + params_(params), + full_barrier_ptr_(&storage.full_barrier_[0]), + empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + int warp_idx = canonical_warp_idx(); + int lane_predicate = cute::elect_one_sync(); + + // Barrier FULL, EMPTY init + // Init is done only by thread 0 of the block + if (warp_idx == 0 && lane_predicate == 1) { + for (int i = 0; i < Stages; ++i) { + full_barrier_ptr_[i].init(params.producer_arv_count); + empty_barrier_ptr_[i].init(params.consumer_arv_count); + } + } + + cutlass::arch::fence_barrier_init(); + } + + //////////////////// + // Producer APIs + //////////////////// + // Four member functions are always used in pairs: + // + // * producer_try_acquire and producer_acquire, and + // * consumer_try_wait and consumer_wait. + // + // The two functions with "try" in their names are called "try" functions, + // and the other two are conceptually "finalize" functions. + // The "try" function in each pair starts the process of waiting on the barrier to flip. + // It opportunistically waits for an implementation-dependent timeout. + // Whether or not the barrier has flipped yet, the try function will return a token. + // If the token indicates that the barrier has not flipped, + // then the token must be passed into the corresponding "finalize" function. + // The finalize function will then block until the barrier has flipped. + // If the token indicates that the barrier _has_ flipped, + // then it is still correct to pass it into the finalize function. + // The finalize function will return immediately in that case. + CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return producer_try_acquire(state.index(), state.phase(), skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire(state.index(), state.phase(), barrier_token); + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + producer_commit(state.index()); + } + + // Prevents early exit of producer blocks in Cluster. + // This should be called once before kernel exits. + CUTLASS_DEVICE + void producer_tail(PipelineState state) { + for (int count = 0; count < Stages; ++count) { + producer_acquire(state); + ++state; + } + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return producer_get_barrier(state.index()); + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return consumer_try_wait(state.index(), state.phase(), skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + consumer_wait(state.index(), state.phase(), barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index()); + } + +private: + Params params_; + FullBarrier *full_barrier_ptr_; + EmptyBarrier *empty_barrier_ptr_; + + CUTLASS_DEVICE + ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + if (skip_wait) { + return {BarrierStatus::WaitDone}; + } + uint32_t barrier_status = empty_barrier_ptr_[stage].try_wait(phase); + return {static_cast(barrier_status)}; + } + + CUTLASS_DEVICE + void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + if (barrier_token == BarrierStatus::WaitAgain) { + empty_barrier_ptr_[stage].wait(phase); + } + } + + CUTLASS_DEVICE + void producer_commit(uint32_t stage) { + full_barrier_ptr_[stage].arrive(); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(uint32_t stage) { + return reinterpret_cast(&full_barrier_ptr_[stage]); + } + + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + if (skip_wait) { + return {BarrierStatus::WaitDone}; + } + uint32_t barrier_status = full_barrier_ptr_[stage].try_wait(phase); + return {static_cast(barrier_status)}; + } + + CUTLASS_DEVICE + void consumer_wait(uint32_t stage, uint32_t phase) { + uint32_t done = full_barrier_ptr_[stage].test_wait(phase); + if (!done) { + full_barrier_ptr_[stage].wait(phase); + } + } + + CUTLASS_DEVICE + void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + if (barrier_token == BarrierStatus::WaitAgain) { + full_barrier_ptr_[stage].wait(phase); + } + } + + CUTLASS_DEVICE + void consumer_release(uint32_t stage) { + empty_barrier_ptr_[stage].arrive(params_.dst_blockid); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Barrier to ensure an Ordered Sequence between +// SequenceLength number of groups (each with group_size participants) executing SequenceDepth Stages +// i.e., for all i < j - only after id "i" arrives at a particular stage "m" +// will the wait() for id "j" succeed for the same stage +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class OrderedSequenceBarrier { +public : + using Barrier = cutlass::arch::ClusterBarrier; + + struct SharedStorage { + Barrier barrier_[SequenceDepth][SequenceLength]; + }; + + struct Params { + uint32_t group_id; + uint32_t group_size; + }; + +private : + // In future this Params object can be replaced easily with a CG object + Params params_; + Barrier *barrier_ptr_; + PipelineState stage_; + + static constexpr int Depth = SequenceDepth; + static constexpr int Length = SequenceLength; + +public: + OrderedSequenceBarrier() = delete; + OrderedSequenceBarrier(const OrderedSequenceBarrier&) = delete; + OrderedSequenceBarrier(OrderedSequenceBarrier&&) = delete; + OrderedSequenceBarrier& operator=(const OrderedSequenceBarrier&) = delete; + OrderedSequenceBarrier& operator=(OrderedSequenceBarrier&&) = delete; + ~OrderedSequenceBarrier() = default; + + CUTLASS_DEVICE + OrderedSequenceBarrier(SharedStorage& storage, Params const& params) : + params_(params), + barrier_ptr_(&storage.barrier_[0][0]), + // Group 0 - starts with an opposite phase + stage_({0, params.group_id == 0, 0}) { + + int warp_idx = canonical_warp_idx(); + int lane_predicate = cute::elect_one_sync(); + + // Barrier FULL, EMPTY init + // Init is done only by the one elected thread of the block + if (warp_idx == 0 && lane_predicate == 1) { + for (int d = 0; d < Depth; ++d) { + for (int l = 0; l < Length; ++l) { + barrier_ptr_[d * Length + l].init(params.group_size); + } + } + } + + cutlass::arch::fence_barrier_init(); + } + + // Wait on a stage to be unlocked + CUTLASS_DEVICE + void wait() { + get_barrier_for_current_stage(params_.group_id).wait(stage_.phase()); + } + + // Signal completion of Stage and move to the next stage + // (group_id) signals to (group_id+1) + CUTLASS_DEVICE + void arrive() { + int signalling_id = (params_.group_id + 1) % Length; + get_barrier_for_current_stage(signalling_id).arrive(); + ++stage_; + } + +private: + + CUTLASS_DEVICE + Barrier& get_barrier_for_current_stage(int group_id) { + return barrier_ptr_[stage_.index() * Length + group_id]; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cutlass diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 96bb8f64..f0582f04 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -879,6 +879,7 @@ struct numeric_limits { static constexpr bool is_integer = true; }; +#if !defined(__CUDACC_RTC__) template <> struct numeric_limits { CUTLASS_HOST_DEVICE @@ -886,6 +887,7 @@ struct numeric_limits { static constexpr bool is_integer = false; static constexpr bool has_infinity = true; }; +#endif } // namespace platform } // namespace cutlass diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index 4736e28c..00e73792 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -56,7 +56,11 @@ template CUTLASS_HOST_DEVICE bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else using std::abs; +#endif T abs_A = abs(a); T abs_B = abs(b); @@ -157,6 +161,18 @@ bool relatively_equal(uint64_t a, uint64_t b, uint64_t, uint64_t) { ///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e4m3_t a, float_e4m3_t b, float_e4m3_t epsilon, float_e4m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e5m2_t a, float_e5m2_t b, float_e5m2_t epsilon, float_e5m2_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + template <> CUTLASS_HOST_DEVICE bool relatively_equal(half_t a, half_t b, half_t epsilon, half_t nonzero_floor) { diff --git a/include/cutlass/semaphore.h b/include/cutlass/semaphore.h index ed8a179e..27343f9e 100644 --- a/include/cutlass/semaphore.h +++ b/include/cutlass/semaphore.h @@ -89,7 +89,7 @@ class Semaphore { /// Waits until the semaphore is equal to the given value CUTLASS_DEVICE void wait(int status = 0) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) while( __syncthreads_and(state != status) ) { fetch(); } @@ -101,7 +101,7 @@ class Semaphore { /// Updates the lock with the given result CUTLASS_DEVICE void release(int status = 0) { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if !defined(CUTLASS_PYTHON_HOST_CC) __syncthreads(); if (wait_thread) { diff --git a/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp b/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp new file mode 100644 index 00000000..63de6726 --- /dev/null +++ b/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp @@ -0,0 +1,336 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing how threads are mapped to a given tile. +*/ + +#pragma once + +#include "cute/arch/mma_sm90_gmma.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +using namespace cute; + +template +constexpr auto +gmma_smem_transpose_or_passthrough() { + if constexpr (Transpose) { + if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported Layout_SW_Atom for B SMEM transposition"); + } + } + else { + return SmemLayoutAtom{}; + } +} + +template +constexpr auto +use_universal_transposition() { + if constexpr (sizeof(ElementType) == 1) { + return !cute::is_same_v, SmemCopyAtom>; + } + else if constexpr (sizeof(ElementType) == 4){ + // Only universal transposition can handle SW64 and Non swizzle SMEM layout + if constexpr (cute::is_same_v, SmemCopyAtom> || + cute::is_same_v, SmemCopyAtom>) { + return true; + } + else { + return false; + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ElementType for B SMEM transposition"); + } +} + +/// Transpose B operand in SMEM +template < + class TensorSmemB, + class TensorTransposedSmemB, + class PipelineState, + class TiledMma, + class SmemLayoutB, + class SmemLayoutAtomB, + class ElementB> +CUTLASS_DEVICE void +transpose_b_operand ( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + PipelineState const& smem_pipe_read, + int warp_idx, int warp_group_thread_idx, + TiledMma, SmemLayoutB, SmemLayoutAtomB, ElementB) +{ + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Important terms: + /// WarpgroupTileSize : The warp_group_tile size (WarpgroupTileSize x WarpgroupTileSize) a warp group would transpose + /// WarpTileSize : The warp_tile size (WarpTile x WarpTile) a warp would transpose + /// Step : The number of steps a warp group takes to complete the entire warp_group_tile transposition. + /// WarpTileNCoordLUT : The look up table to store the n-dim coords used by the warps + /// WarpTileKCoordLUT : The look up table to store the k-dim coords used by the warps + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + static_assert(size(TiledMma{}) == NumThreadsPerWarpGroup, "Wrong thread number for TransposeB"); + constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. + constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; + + constexpr int BytesPerSmemSwizzleUnit = 16; + constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Optimized transposition, less regs per thread than universal approach, need warp sync between load and store + /// TF32/FP32 would use the 2-steps approach. Fp8/Int8 would use 8-steps approach. + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + if constexpr (!detail::use_universal_transposition()) { + constexpr int Steps = sizeof(ElementB) == 1 ? 8 : 2; + constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); + + constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." ); + constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); + constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// A warp group uses 2 or 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. + /// In each step, one warp would hold two warp_tiles. + /// Step 0: Step 1: + /// W0 W1 W2 W3 -- -- -- -- + /// W1 W0 -- -- -- -- W3 W2 + /// W2 -- -- -- -- W3 W0 W1 + /// W3 -- -- -- -- W2 W1 W1 + /// OR: + /// Divide a warp_group_tile into 8x8 warp_tiles to futher reduce the reg usage. + /// Step 0: Step 1: Step 2: Step 3: + /// W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W1 W0 -- -- -- -- -- -- -- -- W3 W2 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W2 -- -- -- -- -- -- -- -- W3 W0 W1 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W3 -- -- -- -- -- -- -- -- W2 W1 W0 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W1 W0 -- -- -- -- -- -- -- -- W3 W2 + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W3 W0 W1 + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W2 W1 W0 + /// + /// Step 4: Step 5: Step 6: Step 7: + /// -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 + /// W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- + /// W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- + /// W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- + /// W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- + /// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// + /// Fully static coord LUT to avoid extra register use. + /// [warp_id][step][warp_tile][n / k] + /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7 + /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0 + /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1 + /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2 + /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3 + /// + /// Encoding the coord of warp tile0 into two int64_t values. + /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern. + /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0. + /// The 2-step transposition and the 8-step transposition share the same encoding. + /// + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + constexpr int64_t WarpTileNCoordLUT = 06723763275316420; + constexpr int64_t WarpTileKCoordLUT = 05410541064206420; + constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. + constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, + constexpr int NumBitsPerStep = 3; + constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) + constexpr int NumBitsPerWarp = 12; + + const int current_warp_tile_n_coord_LUT = (WarpTileNCoordLUT >> (warp_idx * NumBitsPerWarp)) & MaskPerWarp; + const int current_warp_tile_k_coord_LUT = (WarpTileKCoordLUT >> (warp_idx * NumBitsPerWarp)) & MaskPerWarp; + + // Number of warp_group_tiles + static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + "Copy size must evenly divide SMEM tile."); + constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + + // Divide entire SMEM to multiple warp_tiles + constexpr auto WarpTileShape = make_shape(Int(), Int()); + Tensor s_tile = zipped_divide( sB(_,_,smem_pipe_read.index()), WarpTileShape); + Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,smem_pipe_read.index()), WarpTileShape); + + // Get copy tile + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}), "Wrong thread number in TiledCopy."); + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx + + // Construct fragments for transposition + Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{})))); + decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = { + make_fragment_like(tmp_tCsB), + make_fragment_like(tmp_tCsB) + }; + + CUTLASS_PRAGMA_NO_UNROLL + for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { + int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT; + int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT; + + CUTLASS_PRAGMA_NO_UNROLL + for (int step = 0; step < Steps; ++step) { + // decoding the warp tile coord. + int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx; + int warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; + int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; + int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; + + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep; + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep; + + // [warp_tile][n/k] + const int warp_tile_coord[TilesPerWarp][2] = { + // n k + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0 + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1 + }; + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB = sB_thr_copy.partition_S( + flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); + } + + // Make sure elements in two 8x8 warp tiles are all consumed + __syncwarp(); + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB_transposed = sB_thr_copy.partition_D( + flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); + } + + } // lock step + } // loop warp_group_tile + } // if not use universal transposition + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Universal transposition, need warp_group sync between load and store. + /// The number of reg used depends on the input elementB. + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + else { + /* + In one copy step, a warp group would load WarpgroupTileSize * WarpgroupTileSize tile then store to transposed location. + In warp_group_tile, each warp holds Four WarpTileSize x WarpTileSize elements: + K + ------------ + | W0 W1 W2 W3 --- + | W0 W1 W2 W3 | + | W0 W1 W2 W3 | --> Copy Step 0 + | W0 W1 W2 W3 --- + .... + | W0 W1 W2 W3 --- + | W0 W1 W2 W3 | + | W0 W1 W2 W3 | --> Copy Step n + | W0 W1 W2 W3 --- + */ + static_assert((NumThreadsPerWarpGroup % WarpThreadShapeN == 0), "Unsupported warp thread layout."); + constexpr auto WarpgroupThreadLayout = make_layout(make_shape(Int{}, Int{})); + + // Get copy tile and partition to each thread + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpgroupThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + static_assert(size(sB_tiled_copy) == size(TiledMma{}), "Wrong thread number in TiledCopy."); + + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx); + Tensor tCsB = sB_thr_copy.partition_S( sB(_,_,smem_pipe_read.index())); // (CPY, CPY_N, CPY_K) + Tensor tCsB_transposed = sB_thr_copy.partition_D(gmma_sB(_,_,smem_pipe_read.index())); // (CPY, CPY_N, CPY_K) + + // Divide partitioned tile to limit register usage + constexpr int CopySteps = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + constexpr auto CopyTileShape = make_shape(size<0>(tCsB), Int< size<1>(tCsB) / CopySteps >{}, size<2>(tCsB)); + static_assert(size<1>(tCsB) % CopySteps == 0, "CopySteps must evenly divide rank 1 size of partitioned SMEM."); + + Tensor tCsB_copy_tile = zipped_divide(tCsB, CopyTileShape); + Tensor tCsB_copy_tile_transposed = zipped_divide(tCsB_transposed, CopyTileShape); + auto transpose_fragment = make_fragment_like(tCsB_copy_tile(_,_0{})); + + CUTLASS_PRAGMA_NO_UNROLL + for (int step = 0; step < CopySteps; ++step) { + copy(sB_tiled_copy, tCsB_copy_tile(_,step), transpose_fragment); + + // Make sure all elements are read before being overwritten + __syncthreads(); + + copy(sB_tiled_copy, transpose_fragment, tCsB_copy_tile_transposed(_,step)); + } + } // if use universal transposition + + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); +} + +}; // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index 1026bad2..ee3aa58e 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -48,6 +48,7 @@ #include "cutlass/coord.h" #include "cutlass/cutlass.h" #include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/matrix_shape.h" #include "cutlass/predicate_vector.h" @@ -314,7 +315,8 @@ class PredicatedTileAccessIteratorPredicates { /// PredicatedTileAccessIterator /// template + typename ThreadMap, typename AccessType, bool Gather = false, + typename PermuteLayout = layout::NoPermute> class PredicatedTileAccessIterator; //////////////////////////////////////////////////////////////////////////////// @@ -322,9 +324,11 @@ class PredicatedTileAccessIterator; /// Specialization of PredicatedTileAccessIterator for pitch-linear data. /// template + typename ThreadMap_, typename AccessType_, bool Gather, + typename PermuteLayout> class PredicatedTileAccessIterator { + AdvanceRank, ThreadMap_, AccessType_, Gather, + PermuteLayout> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, @@ -356,6 +360,9 @@ class PredicatedTileAccessIterator::value + && !platform::is_same>::value; + using Mask = typename UnderlyingPredicates::Mask; /// Uses a non-template class @@ -402,12 +409,18 @@ class PredicatedTileAccessIterator( - const_cast(pointer))), - the_predicates(extent), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), is_residue_tile_(true), - indices_(indices) { + indices_(indices), + permute_layout_(TensorCoord(extent.contiguous(), extent.strided()), params.stride_) { the_predicates.set_predicates(thread_id, threadblock_offset); + if (Gather) { + assert(indices_); + } + // update internal pointers Layout layout(params_.stride_); - if (!Gather) { + if (!Gather && !Permute) { add_pointer_offset(layout(the_predicates.thread_offset_)); } else { - gather_offset_strided = the_predicates.thread_offset_.strided(); - add_pointer_offset(layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + coord_offset_ = the_predicates.thread_offset_; + if (!Permute) { + add_pointer_offset(layout(make_Coord(coord_offset_.contiguous(), 0))); + } } } @@ -499,7 +519,7 @@ class PredicatedTileAccessIterator::value / 8) + the_predicates.iteration_vector_; - int strided_index = gather_offset_strided + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; - - LongIndex strided_offset = indices_[strided_index] * LongIndex(params_.stride_) * sizeof_bits::value / 8; + Index coord_contig = (Permute ? coord_offset_.contiguous() : 0) + the_predicates.iteration_contiguous_ * ThreadMap::Delta::kContiguous + the_predicates.iteration_vector_ * AccessType::kElements; + Index coord_strided = coord_offset_.strided() + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + if (Gather) { + coord_strided = indices_[coord_strided]; + } - return reinterpret_cast(pointer_ + contiguous_offset + strided_offset); + LongIndex offset = Permute ? permute_layout_(TensorCoord(coord_contig, coord_strided)) : (coord_strided * LongIndex(params_.stride_) + coord_contig); + return reinterpret_cast(pointer_ + OffsetBytes(offset)); } return reinterpret_cast( @@ -580,13 +603,12 @@ class PredicatedTileAccessIterator + typename ThreadMap_, typename AccessType_, bool Gather, + typename PermuteLayout> class PredicatedTileAccessIterator { + AdvanceRank, ThreadMap_, AccessType_, Gather, + PermuteLayout> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, @@ -687,7 +711,8 @@ class PredicatedTileAccessIterator, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType, Gather>; + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType, + Gather, PermuteLayout>; /// Predicate vector stores mask to guard accesses using Mask = typename UnderlyingIterator::Mask; @@ -846,9 +871,11 @@ class PredicatedTileAccessIterator + typename ThreadMap_, typename AccessType_, bool Gather, + typename PermuteLayout> class PredicatedTileAccessIterator { + AdvanceRank, ThreadMap_, AccessType_, Gather, + PermuteLayout> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, @@ -874,7 +901,8 @@ class PredicatedTileAccessIterator, Element, - layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType, Gather>; + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType, + Gather, PermuteLayout>; static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; @@ -1035,7 +1063,8 @@ class PredicatedTileAccessIterator class PredicatedTileAccessIterator, - AdvanceRank, ThreadMap_, AccessType_, false> { + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, @@ -1342,7 +1371,8 @@ class PredicatedTileAccessIterator, template class PredicatedTileAccessIterator { + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, @@ -1524,7 +1554,8 @@ class PredicatedTileAccessIterator class PredicatedTileAccessIterator { + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, @@ -1709,7 +1740,8 @@ template class PredicatedTileAccessIterator, - AdvanceRank, ThreadMap_, AccessType_, false> { + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, @@ -1899,7 +1931,8 @@ template class PredicatedTileAccessIterator, - AdvanceRank, ThreadMap_, AccessType_, false> { + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, diff --git a/include/cutlass/transform/threadblock/predicated_tile_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_iterator.h index 839d8f50..ada679fd 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_iterator.h @@ -137,7 +137,8 @@ template < int AdvanceRank, typename ThreadMap, int AccessSize = ThreadMap::kElementsPerAccess, - bool Gather = false + bool Gather = false, + typename PermuteLayout = layout::NoPermute > class PredicatedTileIterator; @@ -151,9 +152,9 @@ class PredicatedTileIterator; /// MaskedTileIteratorConcept /// template + typename ThreadMap_, int AccessSize, bool Gather, typename PermuteLayout> class PredicatedTileIterator { + ThreadMap_, AccessSize, Gather, PermuteLayout> { public: static_assert( AdvanceRank == 0 || AdvanceRank == 1, @@ -182,7 +183,7 @@ class PredicatedTileIterator; + ThreadMap, AccessType, Gather, PermuteLayout>; static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; @@ -395,7 +396,7 @@ class PredicatedTileIterator -class PredicatedTileIterator { +class PredicatedTileIterator { public: static_assert(AdvanceRank == 0 || AdvanceRank == 1, @@ -440,7 +443,8 @@ class PredicatedTileIterator; using AccessType = typename UnderlyingIterator::AccessType; @@ -610,7 +614,7 @@ class PredicatedTileIterator -class PredicatedTileIterator { +class PredicatedTileIterator { public: static_assert(AdvanceRank == 0 || AdvanceRank == 1, @@ -655,7 +661,8 @@ class PredicatedTileIterator; using AccessType = typename UnderlyingIterator::AccessType; diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h index f761cddd..27ce2cb4 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h @@ -277,7 +277,7 @@ class RegularTileIterator diff --git a/media/docs/cute/02_layout_operations.md b/media/docs/cute/02_layout_operations.md index f9c9734a..b14735b5 100644 --- a/media/docs/cute/02_layout_operations.md +++ b/media/docs/cute/02_layout_operations.md @@ -52,12 +52,6 @@ For example, `print_layout` can display a rank-2 layout in a table It has an overload taking a rank-2 matrix layout and a thread layout, that displays a table with the mapping between threads and values. -Some CuTe types might not have overloads for `print`, -but there are other ways to print their contents. -For example, copy atoms and mma atoms -(see elsewhere in this tutorial) -have a `print_all()` member function. - ### Printing LaTeX output The `cute::print_latex` function works like `cute::print`, @@ -261,7 +255,7 @@ The complement B of a layout A with respect to an integer M satisfies the follow 1. $A$ and $B$ are *disjoint*: $A(x) \neq B(x)$ for all $x \neq 0$ in the domain of $A$. -2. B is *ordered*: $`B(x-1) < B(x)`$ for all $x$ in $\{0, 1, \dots, size(B) - 1\}$. +2. B is *ordered*: $B(x-1) \lt B(x)$ for all $x$ in $\{0, 1, \dots, size(B) - 1\}$. 3. B is *bounded* by M: $size(B) \geq M / size(A)$, and $cosize(B) \leq floor(M / cosize(A)) * cosize(A)$. diff --git a/media/docs/cute/0t_mma_atom.md b/media/docs/cute/0t_mma_atom.md index 7bdc4074..f1880464 100644 --- a/media/docs/cute/0t_mma_atom.md +++ b/media/docs/cute/0t_mma_atom.md @@ -24,8 +24,8 @@ and an `MMA_Traits` struct templated on the Operation struct type. An "Operation" struct exposes the PTX instruction for that specific operation. It defines the arguments and interface it expects. -Operation structs have minimal software dependencies -- -it does not use layouts, tensors, or non-standard numeric data types. +Operation structs have minimal software dependencies -- +they do not use layouts, tensors, or non-standard numeric data types. Different structs have different names that describe what the MMA instruction does. We will explain the naming scheme below. diff --git a/media/docs/efficient_gemm.md b/media/docs/efficient_gemm.md index ba91dcf2..ddb9043c 100644 --- a/media/docs/efficient_gemm.md +++ b/media/docs/efficient_gemm.md @@ -226,13 +226,22 @@ as part of the kernel design. A thread block is partitioned into two sets of war [*Producer* warp group (DMA)](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) waits for the shared memory buffers to be signaled as [empty](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) by the *consumer* warp group using the newly added **Async Pipeline class** ([refer](/media/docs/pipeline.md)). Once the data is written into the shared memory, TMA is also updates the barrier associated with that stage to notify affected threads that the buffer has been [filled](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp). The [*Consumer* warp group (MMA)](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) on the other hand waits for the *producer* warp group to signal that the buffer is [filled](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) and then launches tensor core MMA operations. Finally, the *consumer* warp group [releases](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) the buffers for the next set of TMA loads to happens. -**Warp-Specialized Persistent kernel design** +**Warp-Specialized Persistent Cooperative kernel design** -Another flavor of Warp Specialized kernel design being introduced starting with Hopper is the [*Warp-Specialized Persistent*](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp) kernel. Like Warp Specialized kernel the concepts of warp groups and barrier synchronization between warp groups remain the same in the persistent design. The distinctive feature of the Warp-Specialized Persistent kernel are the following : +Another flavor of Warp-Specialized kernel design being introduced starting with Hopper is the [*Warp-Specialized Persistent Cooperative*](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel. Like the Warp-Specialized kernel, the concepts of warp groups and barrier synchronization between warp groups remain the same in the cooperative design. +The distinctive feature of the Warp-Specialized Persistent Cooperative kernel are the following : * Persistent thread blocks launched to occupy as many SMs as mentioned in the [KernelHardwareInfo](/include/cutlass/kernel_hardware_info.hpp) struct. These persistent thread blocks are used to tile the output and thus (potentially) compute multiple output tiles through their lifetime. The main benefit this adds is amortization of the thread-block launch and kernel prologue overheads which are typical of all kernels. -* Presence of one two *consumer* warp groups which allows for *epilogue* of one *consumer* warp group to be overlapped with the math operations of the other *consumer* warp group - thus maximizing tensor core utilization. +* Presence of two *consumer* warp groups cooperating on the same output tile by splitting the tile in half across the M dimension. This allows for larger tile sizes to be enabled - since the register pressure per *consumer* warp group is reduced - and hence improving performance. -Each *consumer* warp group is assigned a different output tile. The *producer* warp group synchronizes using the [Ordered Sequence Barrier](/include/cutlass/pipeline.hpp) to fill buffers of the two *consumer* warp groups one after the other in order. Since each thread block now computes multiple output tiles, the shape of the grid launch and the scheduling of tiles to the thread blocks is managed using the new [*Tile Scheduler*](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). The *Tile Scheduler* considers the shape of the *clusters* as well as the available number of available SMs to compute a valid scheduling of the output tiles to launched thread blocks. +Since each thread block now computes multiple output tiles, the shape of the grid launch and the scheduling of tiles to the thread blocks is managed using the new [*Tile Scheduler*](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). The *Tile Scheduler* considers the shape of the *clusters* as well as the available number of available SMs to compute a valid scheduling of the output tiles to launched thread blocks. + +**Warp-Specialized Persistent Ping-Pong kernel design** + +The third kernel design is the [*Warp-Specialized Persistent Ping-Pong*](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel. +Like the Warp-Specialized Persistent Cooperative, kernel the concepts of warp groups, barrier synchronization between warp groups, and the shape of the grid launch remain the same in the persistent ping-pong design. +The distinctive feature of the Warp-Specialized Persistent Ping-Pong kernel is the following : +* The two *consumer* warp groups are assigned a different output tile using the Tile Scheduler. This allows for *epilogue* of one *consumer* warp group to be overlapped with the math operations of the other *consumer* warp group - thus maximizing tensor core utilization. +* The *producer* warp group synchronizes using the [Ordered Sequence Barrier](/include/cutlass/pipeline.hpp) to fill buffers of the two *consumer* warp groups one after the other in order. # Resources diff --git a/media/docs/gemm_api_3x.md b/media/docs/gemm_api_3x.md index c4a45489..ecaa3615 100644 --- a/media/docs/gemm_api_3x.md +++ b/media/docs/gemm_api_3x.md @@ -277,7 +277,7 @@ warp-specialized mainloop implementation: template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelTmaWarpSpecialized + class KernelSchedule = KernelTmaWarpSpecializedCooperative > struct MainloopSm90TmaGmmaWarpSpecialized { constexpr static int Stages = Stages_; @@ -299,7 +299,8 @@ it needs to be run, or exposes a template API that lets the user pick a subset o struct KernelMultistage { }; struct KernelTma { }; struct KernelTmaWarpSpecialized { }; -struct KernelTmaWarpSpecializedPersistent { }; +struct KernelTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedCooperative { }; ``` - A single kernel schedule can support multiple mainloop implementations. For example, @@ -308,7 +309,7 @@ architectures such as `MainloopSm70TwoStage`, `MainloopSm80CpAsyncUnpredicated`, - A single mainloop can be composed with multiple possible kernel schedules. For example, the `MainloopSm90TmaGmmaWarpSpecialized` can be -composed with either the `KernelTmaWarpSpecialized` or `KernelTmaWarpSpecializedPersistent` +composed with any of the `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative` kernel schedules. As [discussed in the CUTLASS 3.0 design documentation](cutlass_3x_design.md), adopting tag @@ -487,7 +488,7 @@ any of various `include/cutlass/gemm/kernel/{arch_tag}*.hpp` files in the direct Which specialization to dispatch to is decided through the dispatch policy's `Schedule` type. For example, the header file -[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_persistent.hpp) +[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) has a specialization of `kernel::GemmUniversal` for Hopper that uses a warp-specialized mainloop with a persistent scheduling algorithm, while the header file diff --git a/media/docs/implicit_gemm_convolution.md b/media/docs/implicit_gemm_convolution.md index ed2c84e8..4418b95a 100644 --- a/media/docs/implicit_gemm_convolution.md +++ b/media/docs/implicit_gemm_convolution.md @@ -19,16 +19,20 @@ This release of CUTLASS contains several artifacts related to convolution. # Implicit GEMM Algorithm -2-D convolution may be mapped to matrix multiply by forming a _convolution matrix_ containing -elements of the activations tensor then multiplying this by a matrix formed from the filters tensor. -The earliest form of this algorithm construct the convolution matrix explicitly via an operation +2-D convolution may be mapped to matrix multiply +by first forming a _convolution matrix_ containing elements of the activations tensor, +then multiplying this by a matrix formed from the filters tensor. +The earliest form of this algorithm constructs the convolution matrix explicitly via an operation conventionally referred to as `im2col`. The resulting matrix replicates each activation element by a factor equal to the filter size, consuming additional storage capacity and memory bandwidth. -The _implicit GEMM_ algorithm is a variation on the blocked, hierarchical GEMM computation in CUDA -that instead forms tiles of the convolution matrix on the fly as data is loaded from global memory -into Shared Memory by carefully updating pointers and predicates. Once the convolution matrix is -formed in Shared Memory, the existing components computing warp-level GEMM accumulate the result of +The _implicit GEMM_ algorithm is a variation on the blocked, hierarchical GEMM computation in CUDA. +Instead of constructing the convolution matrix explicitly, +it forms tiles of the convolution matrix on the fly +as data are loaded from global memory into Shared Memory +by carefully updating pointers and predicates. +Once the convolution matrix is formed in Shared Memory, +the existing warp-level GEMM components accumulate the result of convolution and update the output tensor. This section describes the structure of an efficient Implicit GEMM Convolution CUDA kernel @@ -158,7 +162,7 @@ To get the best performance, the following parameters are recommended. - Channel count (C) is a multiple of 32 elements - Filter count (K) is a multiple of 32 elements -This enables 128-bit vector memory acceses which lead to efficient CUDA kernels. Smaller alignment is supported even on tensor cores by setting AlignmentA and AlignmentB in conv::kernel::DefaultConv2dFprop, but the performance is lower than 128-bit aligned tesnors. +This enables 128-bit vector memory acceses which lead to efficient CUDA kernels. Smaller alignment is supported even on tensor cores by setting AlignmentA and AlignmentB in `conv::kernel::DefaultConv2dFprop`, but the performance is lower than 128-bit aligned tensors. # CUTLASS Device-level Convolution Operator @@ -187,12 +191,12 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< SwizzleThreadBlock, // optional function to reorder threadblocks for locality NumStages, // number of pipeline stages in threadblock-scoped GEMM cutlass::arch::OpMultiplyAddSaturate, // math operation on data of element a and b - cutlass::conv::IteratorAlgorithm::kOptimized // globabl memory iterator algorithm + cutlass::conv::IteratorAlgorithm::kOptimized // global memory iterator algorithm >::Kernel ``` This template is intended to be generic and cover all feasible configurations. The example specifies -the following concrete data types, layouts, and tile sizes. +the following concrete data types, layouts, and tile shapes. ```c++ /// Define an Implicit GEMM convolution forward propagation (fprop) kernel @@ -219,7 +223,7 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< SwizzleThreadBlock, // optional function to reorder threadblocks for locality 2, // number of pipeline stages in threadblock-scoped GEMM cutlass::arch::OpMultiplyAddSaturate, // math operation on data of element a and b - cutlass::conv::IteratorAlgorithm::kOptimized // globabl memory iterator algorithm + cutlass::conv::IteratorAlgorithm::kOptimized // global memory iterator algorithm >::Kernel ``` @@ -227,7 +231,7 @@ That is, this computes 2D convolutional forward propagation with 4-bit integer i Internal accumulation is performed using 32-bit integers (`int32_t`), and an elementwise linear combination operation is performed on the output in single-precision floating point (`float`). -The threadblock and warp-level tile sizes refer to the hierarhically blocked GEMM computation +The threadblock and warp-level tile shapes refer to the hierarchically blocked GEMM computation [described here](/media/docs/gemm_api.md). Larger tiles achieve greater reuse of data loaded through shared memory but launch fewer CTAs and may not fully occupy the GPU for small problem sizes. Smaller tile configurations achieve lower peak utilizations but may better match the number of SMs within the GPU for real-world workloads. @@ -344,13 +348,13 @@ creating GEMM-A tile in shared memory. - [conv2d_fprop_filter_tile_access_iterator_optimized.h](/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h) optimizes iterating over global memory and creating GEMM-B tile in shared memory. -The improvements covered by optimized iterators are: -- (a) Precomputing kernel-invariant pointer deltas on the host -- (b) Computing cta-invariant mask predicates on device-side iterator ctors -- (c) Use of [fast divmod](/include/cutlass/fast_math.h) to map GEMM dimensions to convolution tensors. -For example, _optimized_ activation iterator uses fast divmod to map GEMM _M_ to NPQ -for activation iterator +The improvements covered by optimized iterators are: +a. Precomputing kernel-invariant pointer deltas on the host +b. Computing cta-invariant mask predicates on device-side iterator ctors +c. Use of [fast divmod](/include/cutlass/fast_math.h) to map GEMM dimensions to convolution tensors. + +For example, an _optimized_ activation iterator uses fast divmod to map GEMM _M_ to NPQ. **Pipelined mainloop** loads threadblock-scoped tiles from global memory into shared memory and then applies CUTLASS warp-level GEMM operations to load from Shared Memory and issue instructions to Turing Tensor Cores. @@ -483,7 +487,7 @@ inc_next[2] = ( } ``` -This allows only a simple lookup from the _delta table_ performed in device code in `Conv2dFpropActivationTileAccessIteratorOptimized::advance()` +This allows only a simple lookup from the _delta table_ performed in device code in `Conv2dFpropActivationTileAccessIteratorOptimized::advance()`. ```c++ // cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -516,17 +520,17 @@ void advance() { ``` -### Utilizing Tensor Cores +### Making use of Tensor Cores Turing Tensor Cores compute matrix multiply-accumulate operations efficiently by sharing data among all threads within a warp. The following operations are supported. -|**Shape**|**A**|**B**|**C**| -|---------|-----|-----|-----| -| 8x8x32 | int4b_t | int4b_t | int32_t | -| 8x8x16 | int8b_t | int8b_t | int32_t | -| 16x8x8 | half | half | half | -| 16x8x8 | half | half | float | +| **Shape** | **A** | **B** | **C** | +|-----------|---------|---------|---------| +| 8x8x32 | int4b_t | int4b_t | int32_t | +| 8x8x16 | int8b_t | int8b_t | int32_t | +| 16x8x8 | half | half | half | +| 16x8x8 | half | half | float | Functionally, the Turing 8x8x32 matrix multiply operation distributes the _A_, _B_, and _C_ matrix across 32 threads within a warp according to the following illustration. @@ -551,7 +555,7 @@ asm volatile( : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); ``` -To efficiently load data from Shared Memory into registers with the distribution among +To load data efficiently from Shared Memory into registers with the distribution among warps matching the above, the Turing GPU architecture introduces [`ldmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix). `ldmatrix` is the ultimate warp-cooperative instruction, as all threads contribute addresses to up to 32 row vectors of @@ -652,8 +656,11 @@ CUTLASS structures this as several components: ## Unit Tests Unit tests verify the functional behavior of each of the above components in a standalone CUDA kernel. This provides a -convenient environment to (a.) inspect the template definition, (b.) showcase instantiation of use of these templates -in device code, and (c.) assert functional correctness. +convenient environment to + +a. inspect the template definition, +b. showcase instantiation of use of these templates in device code, and +c. assert functional correctness. **Convolution unit tests** - Device-wide convolution operator: [conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu](/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm75.cu) diff --git a/media/docs/pipeline.md b/media/docs/pipeline.md index ccf83859..1107b820 100644 --- a/media/docs/pipeline.md +++ b/media/docs/pipeline.md @@ -149,7 +149,7 @@ if (thread_idx == 0 or thread_idx == 1) { // If any memory operations are involved, then we also need // to guarantee that writes are completed and visible to consumer(s). - pipeline.producer_commit(smem_pipe_write.index()); + pipeline.producer_commit(smem_pipe_write); ++smem_pipe_write; } } diff --git a/media/docs/profiler.md b/media/docs/profiler.md index 8f41a730..cd4a1f0f 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -181,8 +181,7 @@ $ ./tools/profiler/cutlass_profiler --operation=gemm --help GEMM - [enum] --gemm_kind Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array) - [enum] --split_k_mode Variant of split K mode(serial, parallel) + [enum] --gemm_kind Variant of GEMM (e.g. universal, gemm, planar_complex, planar_complex_array) [int] --m,--problem-size::m M dimension of the GEMM problem space [int] --n,--problem-size::n N dimension of the GEMM problem space [int] --k,--problem-size::k K dimension of the GEMM problem space @@ -191,9 +190,10 @@ GEMM [tensor] --C Tensor storing the C operand [scalar] --alpha,--epilogue::alpha Epilogue scalar alpha [scalar] --beta,--epilogue::beta Epilogue scalar beta + [enum] --split_k_mode,--split-k-mode Variant of split K mode(serial, parallel) [int] --split_k_slices,--split-k-slices Number of partitions of K dimension [int] --batch_count,--batch-count Number of GEMMs computed in one batch - [enum] --op_class,--opcode-class Class of math instruction (simt, tensorop, wmmatensorop, wmma) + [enum] --op_class,--opcode-class Class of math instruction (simt, tensorop, wmmatensorop, wmma). [enum] --accum,--accumulator-type Math instruction accumulator data type [int] --cta_m,--threadblock-shape::m Threadblock shape in the M dimension [int] --cta_n,--threadblock-shape::n Threadblock shape in the N dimension @@ -225,9 +225,6 @@ Schmoo over accumulator types: Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t): $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row -Profile a particular problem size with split K and parallel reduction: - $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128 - Using various input value distribution: $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3 $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3 diff --git a/media/docs/programming_guidelines.md b/media/docs/programming_guidelines.md index 8e454fa4..aba270aa 100644 --- a/media/docs/programming_guidelines.md +++ b/media/docs/programming_guidelines.md @@ -39,33 +39,33 @@ and function inlining. ### Constant Memory -Several CUTLASS template classes exhibit a pattern in which problem-specific internal state is known at kernel -launch time and remains invariant throughout the execution of a kernel. For example, tile iterators compute several -offsets based on the strides of the input tensor that is added to an internal pointer when loading the elements -of a tile. These are computed from the tensor stride and never updated; the per-thread internal state consists +Several CUTLASS template classes exhibit a pattern in which problem-specific internal state is known at kernel +launch time and remains invariant throughout the execution of a kernel. For example, tile iterators compute several +offsets based on the strides of the input tensor that is added to an internal pointer when loading the elements +of a tile. These are computed from the tensor stride and never updated; the per-thread internal state consists only of the internal global memory pointer. -CUTLASS can take advantage of this CUDA grid-invariant property by constructing the object in host code and passing -a composed parameters structure to the kernel. This confers two benefits: (1.) invariant state is held in constant +CUTLASS can take advantage of this CUDA grid-invariant property by constructing the object in host code and passing +a composed parameters structure to the kernel. This confers two benefits: (1.) invariant state is held in constant memory, and (2.) there is no overhead to compute the initial state by each thread. -The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class -which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` -structure should also include a data member corresponding to each data member in the parent class, so these too can -be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as +The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class +which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` +structure should also include a data member corresponding to each data member in the parent class, so these too can +be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as its first argument. ### Composable Shared Memory -Shared memory requires explicit effort by the programmer to allocate and de-allocate. CUTLASS follows the paradigm -introduced by [CUB](https://nvlabs.github.io/cub/) to define composed structures for storing data intended to be held -in shared memory. Any object requiring shared memory storage for itself or its data members should define a child -structure called `SharedStorage`. This holds data needed by the class and also instantiates `SharedStorage` +Shared memory requires explicit effort by the programmer to allocate and de-allocate. CUTLASS follows the paradigm +introduced by [CUB](https://nvlabs.github.io/cub/) to define composed structures for storing data intended to be held +in shared memory. Any object requiring shared memory storage for itself or its data members should define a child +structure called `SharedStorage`. This holds data needed by the class and also instantiates `SharedStorage` objects for each data member. -To be consistent, this pattern defines a convention in which classes define internal shared memory storage requirements. -Classes should consider all SharedStorage structures to be opaque other than their own child class. When the lifetimes +To be consistent, this pattern defines a convention in which classes define internal shared memory storage requirements. +Classes should consider all SharedStorage structures to be opaque other than their own child class. When the lifetimes of child objects are known to be non-overlapping, `union`s may be used to alias multiple SharedStorage objects to the same shared memory region and reduce overall shared memory capacity. Developers should carefully note that C++ `union` rules require that they only access the most recently written ("active") member of the `union`; this differs from C rules. @@ -80,7 +80,7 @@ Consequently, most loops within the CUTLASS GEMM implementation are specified by is able to unroll the loop bodies, map array elements to registers, and construct an efficient instruction schedule. All loops expected to be unrolled should be annotated with `CUTLASS_PRAGMA_UNROLL` to explicitly direct the compiler -to unroll them. +to unroll them. ```c++ int const kN = 8; @@ -89,7 +89,7 @@ Array x; // Array we would like to store in reg CUTLASS_PRAGMA_UNROLL // Directs the CUDA compiler to unroll this loop. for (int idx = 0; idx < kN; ++idx) { // Loop has constant number of iterations. - x[i] = float(idx); // Indirect access by induction variable results in + x[i] = float(idx); // Indirect access by induction variable results in // direct register access. } ``` @@ -159,16 +159,13 @@ void possibly_an_unusually_long_function_name( std::uint32_t const* bar, TypeA a, TypeB b, - TypeC c) -{ + TypeC c) { // ... the function's body ... } ``` -For function definitions only, -break the line between the parenthesis -that closes the function's parameters, -and the curly bracket +A newline should not be inserted between the parenthesis +that closes the function's parameters and the curly bracket that opens the function's body. #### If-else brackets and spacing @@ -302,9 +299,9 @@ struct Bar { #ifdef BAD_CUTLASS_SWAP namespace cutlass { +// don't do this template -void swap(T& a, T& b) // don't do this -{ +void swap(T& a, T& b) { T tmp = a; a = b; b = tmp; @@ -324,8 +321,7 @@ using cutlass::swap; // and that T is constrained via // std::enable_if or a requires clause. template -void foo(T& a, T& b) -{ +void foo(T& a, T& b) { // The usual idiom for using std::swap is the "swap two-step": // // 1. import std::swap into the current scope, then @@ -340,8 +336,7 @@ void foo(T& a, T& b) } // namespace other -int main() -{ +int main() { int x = 42; int y = 43; other::foo(x, y); @@ -415,8 +410,7 @@ struct my_computation_result { my_computation_result my_computation(float tolerance); -void foo(float tolerance) -{ +void foo(float tolerance) { // Approach 1: Use structured binding. The names // you choose on the left-hand side have nothing // to do with the struct, so it's up to you @@ -523,8 +517,7 @@ struct foo_result { bool success = false; }; -foo_result foo(std::span input) -{ +foo_result foo(std::span input) { // ... code ... // Prefer this. We know what type the function returns. @@ -539,8 +532,7 @@ However, note that this won't work if the function returns `auto`. The general rule is to avoid code duplication. ```c++ -auto foo(std::span input) -{ +auto foo(std::span input) { // ... code ... if constexpr (some_condition) { @@ -619,7 +611,7 @@ Members within classes and structures should be organized as follows: This convention follows the [CUB library](https://nvlabs.github.io/cub/) -and is also described by +and is also described by [Howard Hinnant](https://howardhinnant.github.io/classdecl.html). It also approximates the usual ordering of chapters in a typical Systems and Controls textbook. @@ -772,7 +764,7 @@ Use `#pragma once` to guard all headers. #### CUDA Built-in Variables Avoid direct access to CUDA built-in variables `threadIdx`, `blockIdx`, `blockDim`, and `gridDim` within -CUTLASS components except in special circumstances. +CUTLASS components except in special circumstances. Using built-in global variables directly within resuable components necessitates that all components use them consistently which may not be possible if CUTLASS components are used in other contexts. diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 1b8e827f..1f92a91a 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -587,9 +587,8 @@ To instantiate all operations supporting all tile sizes, data types, and alignme ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS='70;75;80' -DCUTLASS_LIBRARY_KERNELS=all ``` - The above command line generates about twenty thousand kernels targeting NVIDIA Ampere, Turing, and Volta architectures. -Compiling thousands of kernels for three different architectures is time consuming. Additionaly, this would also result +Compiling thousands of kernels for three different architectures is time-consuming. Additionally, this would also result in a large binary size and on some platforms linker to fail on building the library. Enabling the "unity build" instantiates multiple kernel instances in each compilation unit, thereby reducing binary size diff --git a/python/README.md b/python/README.md new file mode 100644 index 00000000..0926ff80 --- /dev/null +++ b/python/README.md @@ -0,0 +1,180 @@ +![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") + +# CUTLASS Python Interface +The CUTLASS Python interface enables one to compile and run CUTLASS operations from within Python. + +```python +import cutlass +import numpy as np + +plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor) +A, B, C, D = [np.ones((4096, 4096), dtype=np.float16) for i in range(4)] +plan.run(A, B, C, D) +``` + +**NOTE** The CUTLASS Python interface is currently an experimental release. The API may change in the future. +We welcome feedback from the community. + +## Overview +The CUTLASS Python interface aims to provide an ease-of-use interface for using CUTLASS via Python. Toward this goal, +the CUTLASS Python interface attempts to: + +* Present high-level interfaces for operators that require only few parameters +* Select sensible default configurations for an operator given the parameters that have been specified +* Enumerate configurations for users that are known to work in a given setting +* Reduce the occurrence of C++ compile-time errors in favor of descriptive Python exceptions +* Make it easy to export CUTLASS kernels to framework extensions (e.g., PyTorch CUDA extensions) + +### Non-goals +The CUTLASS Python interface does not intended to: + +**Select optimal kernel configurations.** +As an ease-of-use interface, the default selections for operator parameters made by the CUTLASS Python interface may +not achieve the highest possible performance in all scenarios. Users wishing to achieve the highest performance possible +should consider profile different combinations of configuration parameters, or use a library such as [cuBLAS](https://developer.nvidia.com/cublas) +that contains heuristics for selecting kernels. + +**Act as a fast container for CUTLASS kernels.** +The CUTLASS Python interface does not strive to minimize overhead in its Python functions surrounding the running of a kernel. +Those wishing to deploy a CUTLASS kernel should consider either using the C++ emitted by the Python interface directly, or using +one of the CUTLASS emitters for automatically creating a framework extension for the kernel (e.g., a PyTorch CUDA extension). + +**Act as a Python-to-CUDA-kernel JIT compilation engine.** +The CUTLASS Python interface intends to enable one to use CUTLASS via Python. It can be used by frameworks for JIT compiling +Python to CUDA kernels, but does not set out to be such a framework. + +### Comparison to PyCUTLASS +The CUTLASS Python interface builds atop CUTLASS's [PyCUTLASS](https://github.com/NVIDIA/cutlass/tree/v3.0.0/tools/library/scripts/pycutlass) library. PyCUTLASS enables +one to declare, compile, and run GEMMs, convolutions, and grouped GEMM operators with nearly the same configuration +space as CUTLASS's C++ interface. While this flexibility enables one to achieve the similar levels of functionality +as available in CUTLASS's C++ interface, it comes with the burden of needing to specify many configuration parameters +to operators -- similar to what one must do in specifying template parameters to operations in CUTLASS's C++ interface. + +In contrast, the CUTLASS Python interface aims to provide a higher-level API for declaring, emitting, and compiling +kernels that does not require exhaustively defining template parameters. + +#### Transitioning from PyCUTLASS +At present, existing PyCUTLASS functionality remains available via the CUTLASS Python interface. One can +continue to use PyCUTLASS by replacing references to the PyCUTLASS `cutlass` module with `cutlass_bindings` +and the PyCUTLASS `pycutlass` module with `cutlass.backend`. + +For example, the following code using PyCUTLASS: +```python +import pycutlass +import cutlass + +math_inst = pycutlass.MathInstruction( + [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32, + cutlass.OpClass.Simt, pycutlass.MathOperation.multiply_add +) +``` + +can work with the Python interface via: +```python +import cutlass.backend as pycutlass +import cutlass_bindings + +math_inst = pycutlass.MathInstruction( + [1, 1, 1], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, + cutlass_bindings.OpClass.Simt, pycutlass.MathOperation.multiply_add +) +``` + +**NOTE:** backwards compatibility of `cutlass.backend` with `pycutlass` will not be maintained moving forward. + +## Current functionality +The CUTLASS Python interface currently supports the following operations: +* GEMMs +* GEMMs with fused elementwise epilogues (e.g., ReLU) (for pre-SM90 kernels) +* Stream K swizzling (for pre-SM90 kernels) +* Grouped GEMM (for pre-SM90 kernels) + +## Getting started +We recommend using the CUTLASS Python interface via one of the Docker images located in the [docker](/python/docker) directory. + +```bash +docker build -t cutlass-cuda12.0:latest -f docker/Dockerfile-cuda12.0-pytorch . +docker run --gpus all -it --rm cutlass-cuda12.0:latest +``` + +The CUTLASS Python interface has been tested with CUDA 11.8 and CUDA 12.0 on Python 3.8.10 and 3.9.7. + +### Optional environment variables +Prior to installing the CUTLASS Python interface, one may optionally set the following environment variables: +* `CUTLASS_PATH`: the path to the cloned CUTLASS repository +* `CUDA_INSTALL_PATH`: the path to the installation of CUDA + +If these environment variables are not set, the installation process will infer them to be the following: +* `CUTLASS_PATH`: one directory level above the current directory (i.e., `$(pwd)/..`) +* `CUDA_INSTALL_PATH`: the directory holding `/bin/nvcc` for the first version of `nvcc` on `$PATH` (i.e., `which nvcc | awk -F'/bin/nvcc' '{print $1}'`) + +**NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`. + +### Installation +The CUTLASS Python interface can currently be installed via: +```bash +python setup.py develop --user +``` +This will allow changes to the Python interface source to be reflected when using the Python interface. + +We plan to add support for installing via `python setup.py install` in a future release. + +## Examples +Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python). + +To launch these notebooks from this directory, run: +```bash +jupyter-lab ../examples/python +``` + +## Building documentation +The CUTLASS Python interface uses [Sphinx](https://www.sphinx-doc.org/en/master/) for documentation. + +Building the documentation requires additional packages. These can be installed via: +```bash +sudo apt-get install pandoc +pip install --upgrade Sphinx furo pandoc myst-parser sphinx-copybutton nbsphinx nbsphinx-link sphinx-inline-tabs +``` + +To build documentation, you must first have installed the CUTLASS Python interface via the +[installation instructions](#installation). + +Documentation can then be built via the following commands: +```bash +sphinx-apidoc -o docs_src/source/ cutlass/ cutlass/backend* +cd docs_src +make html +mv _build/* ../docs +``` + +# Copyright + +Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py new file mode 100644 index 00000000..03f96437 --- /dev/null +++ b/python/cutlass/__init__.py @@ -0,0 +1,117 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import logging +import os +import sys + + +def _cutlass_path_from_dir() -> str: + cutlass_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../') + if not os.path.isdir(cutlass_path): + raise Exception(f'Environment variable "CUTLASS_PATH" is not defined, ' + f'and default path of {cutlass_path} does not exist.') + return cutlass_path + + +def _cuda_install_path_from_nvcc() -> str: + import subprocess + # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC + result = subprocess.run(['which', 'nvcc'], capture_output=True) + if result.returncode != 0: + raise Exception(f'Unable to find nvcc via `which` utility.') + + cuda_install_path = result.stdout.decode('utf-8').split('/bin/nvcc')[0] + if not os.path.isdir(cuda_install_path): + raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, ' + f'and default path of {cuda_install_path} does not exist.') + + return cuda_install_path + + +CUTLASS_PATH = os.getenv("CUTLASS_PATH", _cutlass_path_from_dir()) +CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc()) +CACHE_FILE = "compiled_cache.db" + +# Add the path to the CUTLASS profiler generation/manifest scripts to PYTHONPATH +sys.path.insert(0, os.path.join(CUTLASS_PATH, "tools/library/scripts/")) + +# Import types/methods from the CUTLASS utility libraries for profiler generation/emission under +from library import ( + ArchitectureNames, + DataType, + DataTypeSize, + EpilogueFunctor, + GemmKind, + LayoutTag, + LayoutType, + KernelScheduleSuffixes, + KernelScheduleType, + KernelScheduleTag, + MathInstruction, + MathOperation, + OpcodeClass, + OperationKind, + SharedMemPerCC, + SwizzlingFunctor, + TensorDescription, + TileDescription, +) + +this = sys.modules[__name__] +this.logger = logging.getLogger(__name__) + +def set_log_level(level: int): + """ + Sets the log level + + :param log_level: severity of logging level to use. See https://docs.python.org/3/library/logging.html#logging-levels for options + :type log_level: int + """ + this.logger.setLevel(level) + +set_log_level(logging.ERROR) + +from cutlass.library_defaults import OptionRegistry +from cutlass.backend.utils.device import device_cc + +this.option_registry = OptionRegistry(device_cc()) + +this.__version__ = '3.1.0' + +from cutlass.backend import get_memory_pool +from cutlass.emit.pytorch import pytorch +from cutlass.op.gemm import Gemm +from cutlass.op.gemm_grouped import GroupedGemm +from cutlass.op.op import OperationBase + +get_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32) diff --git a/python/cutlass/backend/__init__.py b/python/cutlass/backend/__init__.py new file mode 100644 index 00000000..92db1479 --- /dev/null +++ b/python/cutlass/backend/__init__.py @@ -0,0 +1,27 @@ +# module-wide variables +import os + +from cutlass.backend.arguments import * +from cutlass.backend.c_types import * +from cutlass.backend.compiler import ArtifactManager +from cutlass.backend.conv2d_operation import * +from cutlass.backend.epilogue import * +from cutlass.backend.frontend import * +from cutlass.backend.gemm_operation import * +from cutlass.backend.library import * +from cutlass.backend.memory_manager import PoolMemoryManager +from cutlass.backend.operation import * +from cutlass.backend.parser import * +from cutlass.backend.reduction_operation import * +from cutlass.backend.tensor_ref import * +from cutlass.backend.type_hint import * +from cutlass.backend.utils import * +from cutlass.backend.utils.device import device_cc +from cutlass.backend.utils.software import ( + CheckPackages, + SubstituteTemplate, + device_sm_count, + get_memory_pool, +) + +compiler = ArtifactManager() diff --git a/tools/library/scripts/pycutlass/src/pycutlass/arguments.py b/python/cutlass/backend/arguments.py similarity index 79% rename from tools/library/scripts/pycutlass/src/pycutlass/arguments.py rename to python/cutlass/backend/arguments.py index 8329c40a..68c8638d 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +++ b/python/cutlass/backend/arguments.py @@ -29,38 +29,37 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################################# -from .frontend import CupyFrontend -from typeguard import typechecked -from pycutlass.frontend import * + from typing import Union + +from cuda import cuda, cudart import numpy as np -from cuda import cuda -try: + +from cutlass.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend +from cutlass.backend.utils.software import CheckPackages + +torch_available = CheckPackages().check_torch() +if torch_available: import torch - torch_available = True -except ImportError: - torch_available = False -from cuda import cudart -try: + +cupy_available = CheckPackages().check_cupy() +if cupy_available: import cupy as cp - cupy_available = True -except ImportError: - cupy_available = False -# @typechecked class ArgumentBase: """ Base class for operation arguments """ - def __init__(self, - A: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', - B: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', - C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', - D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', - **kwargs) -> None: - + def __init__( + self, + A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", + B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", + C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", + D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]", + **kwargs, + ) -> None: # tensor_C can be interpreted as the bias with bias=True in keyword args if "bias" in kwargs.keys(): self.bias = kwargs["bias"] @@ -93,7 +92,7 @@ def __init__(self, self.ptr_B = B self.ptr_C = C self.ptr_D = D - + elif cupy_available and isinstance(A, cp.ndarray): self.ptr_A = CupyFrontend.argument(A) self.ptr_B = CupyFrontend.argument(B) @@ -102,17 +101,19 @@ def __init__(self, # number of elements in C self.tensor_c_numel = C.size else: - raise TypeError( - "Unsupported Frontend. Only support numpy and torch") + raise TypeError("Unsupported Frontend. Only support numpy and torch") def sync(self, stream_sync=True): if stream_sync: - err, = cudart.cudaDeviceSynchronize() + (err,) = cudart.cudaDeviceSynchronize() if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) if hasattr(self, "host_D"): - err, = cuda.cuMemcpyDtoH( - self.host_D, self.ptr_D, self.host_D.size * self.host_D.itemsize) + (err,) = cuda.cuMemcpyDtoH( + self.host_D, + self.ptr_D, + self.host_D.size * self.host_D.itemsize, + ) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py b/python/cutlass/backend/c_types.py similarity index 63% rename from tools/library/scripts/pycutlass/src/pycutlass/c_types.py rename to python/cutlass/backend/c_types.py index f625da8f..7212e414 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -31,7 +31,13 @@ ################################################################################################# import ctypes -from pycutlass.library import * + +import cutlass_bindings +from cutlass import ( + DataType, + KernelScheduleType +) +from cutlass.backend.library import DataTypeSizeBytes class GemmCoord_(ctypes.Structure): @@ -51,6 +57,7 @@ class GemmCoordBatched_(ctypes.Structure): Wrapper around a GemmCoord that also contains batch count. This is used for encoding batched GEMM inputs to CUTLASS 3 GEMMs. """ + _fields_ = [ ("m", ctypes.c_int), ("n", ctypes.c_int), @@ -92,37 +99,122 @@ class StrideBatched_(ctypes.Structure): dtype2ctype = { - cutlass.float16: ctypes.c_uint16, - cutlass.float32: ctypes.c_float, - cutlass.float64: ctypes.c_double, - cutlass.int32: ctypes.c_int32 + cutlass_bindings.float16: ctypes.c_uint16, + cutlass_bindings.float32: ctypes.c_float, + cutlass_bindings.float64: ctypes.c_double, + cutlass_bindings.int32: ctypes.c_int32, } -def get_gemm_arguments_3x(epilogue_functor): +class GenericMainloopArguments3x_(ctypes.Structure): + """ + Structure representing the superset of possible mainloop arguments. + This structure should not be passed to kernels directly, but, rather, + be used as an input to one of the more specific schedule arguments, which + will each select those arguments relevant to the particular schedule. + """ + _fields_ = [ + ("ptr_A", ctypes.c_void_p), + ("stride_A", StrideBatched_), + ("ptr_B", ctypes.c_void_p), + ("stride_B", StrideBatched_), + ] - _EpilogueOutputOpParams = epilogue_functor.epilogue_type - class _GemmArguments(ctypes.Structure): +def get_mainloop_arguments_3x( + kernel_schedule: KernelScheduleType, + element_A, + element_B, + alignment_A: int, + alignment_B: int) -> ctypes.Structure: + """ + Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters. + + :param kernel_schedule: type of kernel schedule to be used in the mainloop + :type kerel_schedule: cutlass.KernelScheduleType + :param element_A: data type of operand A + :param element_B: data type of operand B + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + + :returns: ctypes structure to be used for the 3.x kernel's mainloop parameters + :rtype: ctypes.Structure + """ + class _MainloopArgumentsTma(ctypes.Structure): + _fields_ = [ + ("ptr_A", ctypes.c_void_p), + ("stride_A", StrideBatched_), + ("ptr_B", ctypes.c_void_p), + ("stride_B", StrideBatched_), + ] + + @staticmethod + def from_generic_mainloop_args(args: GenericMainloopArguments3x_): + return _MainloopArgumentsTma( + args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, + ) + + class _MainloopArgumentsMultistage(ctypes.Structure): _fields_ = [ - ("mode", ctypes.c_int), - ("problem_size", GemmCoordBatched_), ("ptr_A", ctypes.c_void_p), ("stride_A", StrideBatched_), ("ptr_B", ctypes.c_void_p), ("stride_B", StrideBatched_), + ] + + @staticmethod + def from_generic_mainloop_args(args: GenericMainloopArguments3x_): + return _MainloopArgumentsMultistage( + args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, + ) + + tma_alignment_bytes = 16 + is_tma_aligned_A = ((DataTypeSizeBytes[element_A] * alignment_A) % tma_alignment_bytes) == 0 + is_tma_aligned_B = ((DataTypeSizeBytes[element_B] * alignment_B) % tma_alignment_bytes) == 0 + is_tma_aligned = is_tma_aligned_A and is_tma_aligned_B + + if kernel_schedule == KernelScheduleType.Multistage: + return _MainloopArgumentsMultistage + elif kernel_schedule == KernelScheduleType.ScheduleAuto: + if is_tma_aligned: + return _MainloopArgumentsTma + else: + return _MainloopArgumentsMultistage + else: + if is_tma_aligned: + return _MainloopArgumentsTma + else: + raise Exception(f"Specified a kernel schedule using TMA ({kernel_schedule}), but " + "the provided data types and alignments are not properly aligned for " + "using TMA.") + + +def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + class _EpilogueArguments(ctypes.Structure): + _fields_ = [ + ("epilogue", _EpilogueOutputOpParams), ("ptr_C", ctypes.c_void_p), ("stride_C", StrideBatched_), ("ptr_D", ctypes.c_void_p), ("stride_D", StrideBatched_), - ("epilogue", _EpilogueOutputOpParams), ] - return _GemmArguments, _EpilogueOutputOpParams + class _GemmArguments(ctypes.Structure): + _fields_ = [ + ("mode", ctypes.c_int), + ("problem_size", GemmCoordBatched_), + ("mainloop", mainloop_arguments), + ("epilogue", _EpilogueArguments) + ] + return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams -def get_gemm_arguments(epilogue_functor): +def get_gemm_arguments(epilogue_functor): _EpilogueOutputOpParams = epilogue_functor.epilogue_type class _GemmArguments(ctypes.Structure): @@ -157,10 +249,42 @@ class _GemmArguments(ctypes.Structure): return _GemmArguments, _EpilogueOutputOpParams +def get_gemm_arguments_streamk(epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + class _GemmArguments(ctypes.Structure): + _fields_ = [ + ("mode", ctypes.c_int), + ("problem_size", GemmCoord_), + ("batch_count", ctypes.c_int), + ("epilogue", _EpilogueOutputOpParams), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("batch_stride_A", ctypes.c_longlong), + ("batch_stride_B", ctypes.c_longlong), + ("batch_stride_C", ctypes.c_longlong), + ("batch_stride_D", ctypes.c_longlong), + ("stride_a", ctypes.c_longlong), + ("stride_b", ctypes.c_longlong), + ("stride_c", ctypes.c_longlong), + ("stride_d", ctypes.c_longlong), + ("lda", ctypes.c_longlong), + ("ldb", ctypes.c_longlong), + ("ldc", ctypes.c_longlong), + ("ldd", ctypes.c_longlong), + ("avail_sms", ctypes.c_int) + ] + + return _GemmArguments, _EpilogueOutputOpParams + + ########################################################################################### # GEMM Grouped ########################################################################################### + def get_gemm_grouped_arguments(epilogue_functor): _EpilogueOutputOpParams = epilogue_functor.epilogue_type @@ -183,10 +307,12 @@ class _GEMMGroupedArguments(ctypes.Structure): return _GEMMGroupedArguments, _EpilogueOutputOpParams + ############################################################################################ # Convolution2D ############################################################################################ + class Conv2DProblemSize(ctypes.Structure): _fields_ = [ ("N", ctypes.c_int), @@ -215,9 +341,7 @@ def __init__(self, problem_size) -> None: class Layout4D(ctypes.Structure): - _fields_ = [ - ("stride", ctypes.c_int * 3) - ] + _fields_ = [("stride", ctypes.c_int * 3)] def __init__(self, tensor_ref): stride = tensor_ref.stride() @@ -247,13 +371,13 @@ def get_conv2d_arguments(epilogue_functor): class _Conv2dArguments(ctypes.Structure): _fields_ = [ - ("problem_size", Conv2DProblemSize), # 0 - ("ref_A", TensorRef_), # 72 - ("ref_B", TensorRef_), # 96 - ("ref_C", TensorRef_), # 120 - ("ref_D", TensorRef_), # 144 - ("output_op", _EpilogueOutputOpParams), # 168 - ("split_k_mode", ctypes.c_int) # 192 + ("problem_size", Conv2DProblemSize), + ("ref_A", TensorRef_), + ("ref_B", TensorRef_), + ("ref_C", TensorRef_), + ("ref_D", TensorRef_), + ("output_op", _EpilogueOutputOpParams), + ("split_k_mode", ctypes.c_int) ] return _Conv2dArguments, _EpilogueOutputOpParams @@ -263,6 +387,7 @@ class _Conv2dArguments(ctypes.Structure): # Reduction ############################################################################################ + def get_reduction_params(epilogue_functor): _EpilogueOutputParams = epilogue_functor.epilogue_type @@ -274,6 +399,7 @@ class _ReductionParams(ctypes.Structure): ("workspace", TensorRef2D_), ("destination", TensorRef2D_), ("source", TensorRef2D_), - ("output_op", _EpilogueOutputParams) + ("output_op", _EpilogueOutputParams), ] + return _ReductionParams, _EpilogueOutputParams diff --git a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py b/python/cutlass/backend/compiler.py similarity index 66% rename from tools/library/scripts/pycutlass/src/pycutlass/compiler.py rename to python/cutlass/backend/compiler.py index 76711391..c5b1caea 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +++ b/python/cutlass/backend/compiler.py @@ -29,32 +29,31 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################################# -import pycutlass -from pycutlass import * -import cutlass -from cuda import cuda -from cuda import nvrtc -import tempfile -import os -import ctypes -# +import ctypes import json +import os import sqlite3 +import tempfile +from cuda import cuda, nvrtc +import cutlass_bindings -IncludeTemplate = r'''#include "${include}" -''' +from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH +from cutlass.backend.gemm_operation import GemmOperationUniversal +from cutlass.backend.library import ApiVersion +from cutlass.backend.utils.device import device_cc +from cutlass.backend.utils.software import SubstituteTemplate -# +IncludeTemplate = r"""#include "${include}" +""" class CompilationOptions: - ''' + """ Compilation options. - ''' + """ - # def __init__(self, flags, arch, include_paths=[]): self.includes = [] self.include_paths = include_paths @@ -68,16 +67,15 @@ def get_str(self): options += " " + flag for incl in self.include_paths: - options += ' --include-path=%s' % incl + options += " --include-path=%s" % incl arch_flag = " -arch=sm_%d" % self.arch if self.arch == 90: - arch_flag += 'a' + arch_flag += "a" options += arch_flag return options - # def get(self): options = [] @@ -85,11 +83,11 @@ def get(self): options.append(bytes(str.encode(flag))) for incl in self.include_paths: - options.append(bytes(str.encode('--include-path=%s' % incl))) + options.append(bytes(str.encode("--include-path=%s" % incl))) arch_flag = " -arch=sm_%d" % self.arch if self.arch == 90: - arch_flag += 'a' + arch_flag += "a" options.append(bytes(str.encode(arch_flag))) @@ -97,16 +95,15 @@ def get(self): def convertToBinaryData(filename): - with open(filename, 'rb') as file: + with open(filename, "rb") as file: blobData = file.read() return blobData def CDLLBin(host_binary): tempfile.tempdir = "./" - temp_so = tempfile.NamedTemporaryFile( - prefix='host_func', suffix='.so', delete=True) - with open(temp_so.name, 'wb') as file: + temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True) + with open(temp_so.name, "wb") as file: file.write(host_binary) host_lib = ctypes.CDLL(temp_so.name) return host_lib @@ -118,32 +115,36 @@ class ArtifactManager: """ def __init__(self) -> None: - try: - connection = sqlite3.connect("./compiled_cache.db") - cursor = connection.cursor() - sqlite_create_table_query = """CREATE TABLE compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)""" - cursor.execute(sqlite_create_table_query) - connection.commit() - cursor.close() - except: - pass + connection = sqlite3.connect(CACHE_FILE) + cursor = connection.cursor() + # Create the table if it does not already exist + sqlite_create_table_query = """ + CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE, + cubin BLOB NOT NULL, + hostbin BLOB NOT NULL, + op_name TEXT NOT NULL, + op_attrs TEXT NOT NULL) + """ + cursor.execute(sqlite_create_table_query) + connection.commit() + cursor.close() self.nvcc() - self.compiled_cache_device = cutlass.CompileCache() - self.compiled_cache_host = cutlass.CompileCache() - + self.compiled_cache_device = cutlass_bindings.CompileCache() + self.compiled_cache_host = cutlass_bindings.CompileCache() + def nvrtc(self): self.backend = "nvrtc" - self.default_compile_options = [ - '-std=c++17', '-default-device' - ] + self.default_compile_options = ["-std=c++17", "-default-device"] def nvcc(self): self.backend = "nvcc" self.default_compile_options = [ - '-std=c++17', '--expt-relaxed-constexpr', '-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored' + "-std=c++17", + "--expt-relaxed-constexpr", + "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored", ] def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): - connection = sqlite3.connect("./compiled_cache.db") + connection = sqlite3.connect(CACHE_FILE) cursor = connection.cursor() sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)""" @@ -156,11 +157,10 @@ def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): cursor.close() def load_operation(self, op_key, extra_funcs): - connection = sqlite3.connect("./compiled_cache.db") + connection = sqlite3.connect(CACHE_FILE) cursor = connection.cursor() sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?""" - # try: - cursor.execute(sqlite_fetch_blob_query, (op_key, )) + cursor.execute(sqlite_fetch_blob_query, (op_key,)) record = cursor.fetchall() if len(record) == 0: return False @@ -169,27 +169,26 @@ def load_operation(self, op_key, extra_funcs): op_attr = json.loads(op_attr) err, module = cuda.cuModuleLoadData(cubin_image) if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('Cuda Error: {}'.format(err)) + raise RuntimeError("Cuda Error: {}".format(err)) - err, kernel = cuda.cuModuleGetFunction( - module, bytes(str.encode(operation_name))) + err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name))) self.compiled_cache_device.insert(key, kernel) compiled_host_fns = {} host_lib = CDLLBin(host_binary) - func_name = operation_name + '_get_params' + func_name = operation_name + "_get_params" func = getattr(host_lib, func_name) func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0]) - compiled_host_fns['get_args'] = func + compiled_host_fns["get_args"] = func - func_name = operation_name + '_shared_memory_size' + func_name = operation_name + "_shared_memory_size" func = getattr(host_lib, func_name) - compiled_host_fns['shared_memory_capacity'] = func() + compiled_host_fns["shared_memory_capacity"] = func() for attr in op_attr: if isinstance(attr, str): - func_name = operation_name + '_' + attr + func_name = operation_name + "_" + attr func = getattr(host_lib, func_name) # Set the return type of the function @@ -214,29 +213,33 @@ def emit_compile_(self, operation_list, compilation_options, requires_nvcc_hostl if incl not in includes: includes.append(incl) - includes_host = [ - "builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes + includes_host = ["builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes for incl in includes: source_buffer_device += SubstituteTemplate( - IncludeTemplate, {'include': incl}) + IncludeTemplate, + {"include": incl}, + ) for incl in includes_host: if "/device/" not in incl: source_buffer_host += SubstituteTemplate( - IncludeTemplate, {'include': incl}) + IncludeTemplate, + {"include": incl}, + ) # 2. Operations for operation in operation_list: source_buffer_device += operation.emit() source_buffer_host += operation.emit() values = { - 'operation_name': operation.name(), - 'operation_suffix': operation.emitter.operation_suffix + "operation_name": operation.name(), + "operation_suffix": operation.emitter.operation_suffix, } source_buffer_device += SubstituteTemplate( - operation.KernelTemplate, values) - source_buffer_host += SubstituteTemplate( - operation.HostTemplate, values) + operation.KernelTemplate, + values, + ) + source_buffer_host += SubstituteTemplate(operation.HostTemplate, values) if self.backend == "nvrtc": # 3. compile @@ -246,94 +249,104 @@ def emit_compile_(self, operation_list, compilation_options, requires_nvcc_hostl 0, [], []) if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) + raise RuntimeError("NVRTC Error: {}".format(err)) # Compile program options = compilation_options.get() err, = nvrtc.nvrtcCompileProgram(program, len(options), options) if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - - error_string = 'NVRTC Error: {}\n'.format(err) + error_string = "NVRTC Error: {}\n".format(err) # Get log from compilation err, logSize = nvrtc.nvrtcGetProgramLogSize(program) if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) + raise RuntimeError("NVRTC Error: {}".format(err)) - log = b' ' * logSize + log = b" " * logSize err, = nvrtc.nvrtcGetProgramLog(program, log) if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) + raise RuntimeError("NVRTC Error: {}".format(err)) - raise RuntimeError( - error_string + log.decode() + source_buffer_device) + raise RuntimeError(error_string + log.decode() + source_buffer_device) # Get data from compilation err, dataSize = nvrtc.nvrtcGetCUBINSize(program) if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) + raise RuntimeError("NVRTC Error: {}".format(err)) - cubin_image = b' ' * dataSize - err, = nvrtc.nvrtcGetCUBIN(program, cubin_image) + cubin_image = b" " * dataSize + (err,) = nvrtc.nvrtcGetCUBIN(program, cubin_image) if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError('NVRTC Error: {}'.format(err)) + raise RuntimeError("NVRTC Error: {}".format(err)) else: # with nvcc backend # emit code tempfile.tempdir = "./" temp_cu = tempfile.NamedTemporaryFile( - prefix='kernel', suffix='.cu', delete=True) + prefix="kernel", suffix=".cu", delete=True) temp_cubin = tempfile.NamedTemporaryFile( - prefix='kernel', suffix='.cubin', delete=True) - with open(temp_cu.name, 'w') as file: + prefix="kernel", suffix=".cubin", delete=True) + with open(temp_cu.name, "w") as file: file.write(source_buffer_device) # compile with nvcc - cuda_install_path = os.getenv('CUDA_INSTALL_PATH') - assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}" values = { - "cuda_install_path": cuda_install_path, + "cuda_install_path": CUDA_INSTALL_PATH, "options": compilation_options.get_str(), "srcfile": temp_cu.name, - "tarfile": temp_cubin.name + "tarfile": temp_cubin.name, } cmd = SubstituteTemplate(cmd_template, values) os.system(cmd) # load the cubin image - with open(temp_cubin.name, 'rb') as file: + with open(temp_cubin.name, "rb") as file: cubin_image = file.read() # Set up the host-side library code if requires_nvcc_hostlib_compilation: - cuda_install_path = os.getenv('CUDA_INSTALL_PATH') - assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." - cmd_template = "echo '%s'|${cuda_install_path}/bin/nvcc -x cu -Xcompiler=\"-fpermissive -w -fPIC\" ${options}" % source_buffer_host + cmd_template = ( + "echo '%s'|${cuda_install_path}/bin/nvcc -x cu -Xcompiler=\"-fpermissive -w -fPIC\" ${options}" + % source_buffer_host + ) cmd = SubstituteTemplate( cmd_template, { - "cuda_install_path": cuda_install_path, - "options": compilation_options.get_str() - }) + "cuda_install_path": CUDA_INSTALL_PATH, + "options": compilation_options.get_str(), + }, + ) else: options = compilation_options.get() - cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host - filtered_opts = ['-default-device', '-Xcicc', '-Xllc', '--expt-relaxed-constexpr', '-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored'] + cmd = ( + "echo '%s'|g++ -x c++ -fpermissive -w -fPIC -DCUTLASS_PYTHON_HOST_CC=1" + % source_buffer_host + ) + filtered_opts = [ + "-default-device", + "-Xcicc", + "-Xllc", + "--expt-relaxed-constexpr", + "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored", + ] for opt in options: opt = opt.decode("utf-8") - if opt not in filtered_opts and '-arch=sm_' not in opt: - if '--include-path=' in opt: - cmd += " " + opt.replace('--include-path=', '-I') + if opt not in filtered_opts and "-arch=sm_" not in opt: + if "--include-path=" in opt: + cmd += " " + opt.replace( + "--include-path=", + "-I", + ) else: cmd += " " + opt tempfile.tempdir = "./" temp = tempfile.NamedTemporaryFile( - prefix='host_func', suffix='.so', delete=True) + prefix="host_func", suffix=".so", delete=True) - cmd += ' - -shared -o %s -lcudart -lcuda' % temp.name + cmd += " - -shared -o %s -lcudart -lcuda" % temp.name os.system(cmd) host_lib = ctypes.CDLL(temp.name) @@ -344,19 +357,15 @@ def add_module(self, operations, compile_options=None): Insert a new compiled device module """ if compile_options is None: - cutlass_path = os.getenv('CUTLASS_PATH') - assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." - cuda_install_path = os.getenv('CUDA_INSTALL_PATH') - assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." include_paths = [ - cuda_install_path + '/include', - cutlass_path + '/include', - cutlass_path + '/tools/util/include', - cutlass_path + '/tools/library/scripts/pycutlass/src/cpp/include' + CUDA_INSTALL_PATH + "/include", + CUTLASS_PATH + "/include", + CUTLASS_PATH + "/tools/util/include", + CUTLASS_PATH + "/python/cutlass/cpp/include", ] - if pycutlass.DEVICE_CC is not None: - arch = pycutlass.DEVICE_CC + if device_cc() is not None: + arch = device_cc() else: # Find the maximum arch tag among the provided operations and compile for that target. # Since we are compiling to .cubin files, only one architecture may be specified. @@ -374,7 +383,7 @@ def add_module(self, operations, compile_options=None): compiled_kernel = self.compiled_cache_device.at(key) if compiled_kernel is None: - hit = self.load_operation(key, getattr(operation.rt_module, 'extra_funcs', {})) + hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {})) if hit: compiled_kernel = self.compiled_cache_device.at(key) assert compiled_kernel is not None @@ -391,9 +400,9 @@ def add_module(self, operations, compile_options=None): # Creating the Params structures for certain 3.0 kernels currently requires CUDA. For these cases, use NVCC to generate # the PyCUTLASS host-side library. Otherwise, g++ will be used. - if isinstance(operation, pycutlass.gemm_operation.GemmOperationUniversal) and operation.api == pycutlass.library.ApiVersion.v3x: + if isinstance(operation, GemmOperationUniversal) and operation.api == ApiVersion.v3x: if self.backend == "nvrtc": - raise RuntimeError('CUTLASS 3 kernels currently require NVCC for compilation.') + raise RuntimeError("CUTLASS 3 kernels currently require NVCC for compilation.") requires_nvcc_hostlib_compilation = True @@ -403,7 +412,7 @@ def add_module(self, operations, compile_options=None): err, module = cuda.cuModuleLoadData(cubin_image) if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('Cuda Error: {}'.format(err)) + raise RuntimeError("Cuda Error: {}".format(err)) operation_name = [] operation_attr = [] @@ -420,22 +429,22 @@ def add_module(self, operations, compile_options=None): op_attr = [] # get param size - func_name = operation.name() + '_get_param_size' + func_name = operation.name() + "_get_param_size" func = getattr(host_lib, func_name) param_size = func() - func_name = operation.name() + '_get_params' + func_name = operation.name() + "_get_params" func = getattr(host_lib, func_name) func.argtype = operation.argtype func.restype = ctypes.POINTER(ctypes.c_char * param_size) - setattr(operation, 'get_args', func) - compiled_host_fns['get_args'] = func + setattr(operation, "get_args", func) + compiled_host_fns["get_args"] = func # set shared memory size - func_name = operation.name() + '_shared_memory_size' + func_name = operation.name() + "_shared_memory_size" func = getattr(host_lib, func_name) - setattr(operation, 'shared_memory_capacity', func()) - compiled_host_fns['shared_memory_capacity'] = func() + setattr(operation, "shared_memory_capacity", func()) + compiled_host_fns["shared_memory_capacity"] = func() # set the maximum dynamic shared size operation.initialize() @@ -443,8 +452,8 @@ def add_module(self, operations, compile_options=None): op_attr.append(param_size) if hasattr(operation, "extra_funcs"): - for suffix, ret_type in operation.extra_funcs.items(): - func_name = operation.name() + '_' + suffix + for suffix, ret_type in operation.extra_funcs.items(): + func_name = operation.name() + "_" + suffix func = getattr(host_lib, func_name) if ret_type is not None: func.restype = ret_type @@ -455,6 +464,6 @@ def add_module(self, operations, compile_options=None): operation_attr.append(op_attr) self.compiled_cache_host.insert(key, compiled_host_fns) - for key, operation_name, operation_attr in zip(operation_key, operation_name, operation_attr): + for (key, operation_name, operation_attr,) in zip(operation_key, operation_name, operation_attr): self.insert_operation( key, cubin_image, host_file.name, operation_name, operation_attr) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py b/python/cutlass/backend/conv2d_operation.py similarity index 63% rename from tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py rename to python/cutlass/backend/conv2d_operation.py index 0c4713cd..8dc55a25 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +++ b/python/cutlass/backend/conv2d_operation.py @@ -29,78 +29,102 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################ -from typeguard import typechecked -from cuda import cuda +# from typeguard import typechecked + +import ctypes from typing import Union -import numpy as np -from typeguard import typechecked +from cuda import cuda +import cutlass_bindings +import numpy as np -from pycutlass import * +from cutlass.backend.arguments import ArgumentBase +from cutlass.backend.c_types import Conv2DProblemSize, TensorRef_, get_conv2d_arguments +from cutlass.backend.library import ( + ConvKindNames, + ConvKindTag, + DataTypeNames, + DataTypeSize, + DataTypeTag, + IteratorAlgorithmNames, + IteratorAlgorithmTag, + LayoutTag, + MathOperation, + MathOperationTag, + OpcodeClassNames, + OpcodeClassTag, + OperationKind, + ShortDataTypeNames, + ShortLayoutTypeNames, + StrideSupport, + StrideSupportTag, + TensorDescription, + TileDescription, + get_complex_from_real, +) +from cutlass.backend.memory_manager import device_mem_alloc +from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass.backend.tensor_ref import TensorRef +from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate + +if CheckPackages().check_torch(): + import torch # @typechecked class Conv2dArguments(ArgumentBase): """ - Argument wrapper for Conv2d. It encodes problem information and + Argument wrapper for Conv2d. It encodes problem information and user-provide tensors into the kernel's argument. :param operation: the Conv2d operation to take the argument - :type operation: :class:`pycutlass.Conv2dOperation` - + :type operation: :class:`cutlass.backend.Conv2dOperation` :param problem_size: the Conv2d problem size - :type problem_size: :class:`cutlass.conv.Conv2dProblemSize` - + :type problem_size: :class:`cutlass_bindings.conv.Conv2dProblemSize` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - :param B: tensor B :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - :param C: tensor C :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - :param D: tensor D :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray - - :param split_k_mode: conv2d split K mode, defaults to - cutlass.conv.SplitKMode.Serial - :type split_k_mode: cutlass.conv.SplitKMode, optional - + :param split_k_mode: conv2d split K mode, defaults to cutlass_bindings.conv.SplitKMode.Serial + :type split_k_mode: cutlass_bindings.conv.SplitKMode, optional :param output_op: output operator, optional - :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` - + :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` """ - def __init__(self, operation: 'Conv2dOperation', - problem_size: 'cutlass.conv.Conv2dProblemSize', - A: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', - B: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', - C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', - D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', - split_k_mode: 'cutlass.conv.SplitKMode' - = cutlass.conv.SplitKMode.Serial, **kwargs) -> None: - + def __init__( + self, + operation: "Conv2dOperation", + problem_size: "cutlass_bindings.conv.Conv2dProblemSize", + A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + split_k_mode: "cutlass_bindings.conv.SplitKMode" = cutlass_bindings.conv.SplitKMode.Serial, + **kwargs, + ) -> None: self.operation = operation #: convolution kind - self.conv_kind: cutlass.conv.Operator = operation.conv_kind - self.layout_A: cutlass.layout = operation.A.layout - self.layout_B: cutlass.layout = operation.B.layout - self.layout_C: cutlass.layout = operation.C.layout + self.conv_kind: cutlass_bindings.conv.Operator = operation.conv_kind + self.layout_A: cutlass_bindings.layout = operation.A.layout + self.layout_B: cutlass_bindings.layout = operation.B.layout + self.layout_C: cutlass_bindings.layout = operation.C.layout self.element_A = operation.A.element self.element_B = operation.B.element self.element_C = operation.C.element - if self.layout_C == cutlass.TensorNC32HW32: + if self.layout_C == cutlass_bindings.TensorNC32HW32: B = self.reorder_tensor_B(B, problem_size) super().__init__(A, B, C, D, **kwargs) # preprocessing output ops - - if 'output_op' in kwargs.keys() and \ - split_k_mode != cutlass.conv.SplitKMode.Parallel: - self.output_op = kwargs['output_op'] + + if "output_op" in kwargs.keys() and split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel: + self.output_op = kwargs["output_op"] else: self.output_op = self.operation.epilogue_type(1.0, 0.0) @@ -108,18 +132,17 @@ def __init__(self, operation: 'Conv2dOperation', self.split_k_mode = split_k_mode self.split_k_slices = kwargs["split_k_slices"] else: - self.split_k_mode = cutlass.conv.SplitKMode.Serial + self.split_k_mode = cutlass_bindings.conv.SplitKMode.Serial self.split_k_slices = 1 #: problem_size - self.problem_size: cutlass.conv.Conv2dProblemSize = problem_size + self.problem_size: cutlass_bindings.conv.Conv2dProblemSize = problem_size self.problem_size.split_k_slices = self.split_k_slices if hasattr(self, "tensor_c_numel"): - c_coord = cutlass.conv.implicit_gemm_tensor_c_extent( + c_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent( self.conv_kind, problem_size) - if (self.tensor_c_numel == c_coord.at(3) and - self.tensor_c_numel < c_coord.size()): + if self.tensor_c_numel == c_coord.at(3) and self.tensor_c_numel < c_coord.size(): self.bias = True # @@ -128,15 +151,15 @@ def __init__(self, operation: 'Conv2dOperation', self.initialize() # @typechecked - def reorder_tensor_B(self, tensor_B: 'np.ndarray', - problem_size: 'cutlass.conv.Conv2dProblemSize'): + def reorder_tensor_B(self, tensor_B: "np.ndarray", + problem_size: "cutlass_bindings.conv.Conv2dProblemSize"): """ Reorder tensor_B for interleaved layout :param tensor_B: input tensor B :type tensor_B: numpy.ndarray :param problem_size: Conv2d problem size - :type problem_size: :class:`cutlass.conv.Conv2dProblemSize` + :type problem_size: :class:`cutlass_bindings.conv.Conv2dProblemSize` :return: reordered tensor B :rtype: numpy.ndarray @@ -145,9 +168,8 @@ def reorder_tensor_B(self, tensor_B: 'np.ndarray', tensor_ref_B = self.get_tensor_ref( tensor_B, self.element_B, self.layout_B, problem_size, "b") reordered_tensor_ref_B = self.get_tensor_ref( - reordered_tensor_B, self.element_B, - self.layout_B, problem_size, "b") - cutlass.conv.host.reorder_convK( + reordered_tensor_B, self.element_B, self.layout_B, problem_size, "b") + cutlass_bindings.conv.host.reorder_convK( reordered_tensor_ref_B, tensor_ref_B, self.conv_kind, problem_size) return reordered_tensor_B @@ -155,19 +177,19 @@ def reorder_tensor_B(self, tensor_B: 'np.ndarray', def get_tensor_ref( self, tensor, dtype, tensor_layout, problem_size, operand): if operand == "a": - tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent( + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent( self.conv_kind, problem_size) elif operand == "b": - tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent( + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent( self.conv_kind, problem_size) elif operand in ["c", "d"]: - tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent( + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent( self.conv_kind, problem_size) else: raise ValueError("unknown operand: " + operand) # Zero stride trick if operand == "c" and self.bias: - tensor_coord = cutlass.Tensor4DCoord(0, 0, 0, 0) + tensor_coord = cutlass_bindings.Tensor4DCoord(0, 0, 0, 0) layout = tensor_layout.packed(tensor_coord) @@ -185,25 +207,16 @@ def get_arguments(self, semaphore): self.c_arguments = self.operation.argument_type( Conv2DProblemSize(self.problem_size), - ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode - ) + ref_A, ref_B, ref_C, ref_D, self.output_op, self.split_k_mode) self.semaphore = semaphore def initialize(self): - """ - Initialize the kernel arguments handling following stuffs - 1. get kernel launch configuration including grid, cta size, - and dynamic shared memory capacity - 2. allocate and initialize device workspace - 3. get kernel params as bytearray for NVRTC input - """ - # get launch configuration + # Get launch configuration self.launch_config = self.operation.rt_module.plan(self) - # allocate and initialize device workspace - device_workspace_size = \ - self.operation.rt_module.get_device_workspace_size(self) + # Allocate and initialize device workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) if device_workspace_size > 0: self.workspace_buffer = device_mem_alloc(device_workspace_size) @@ -213,25 +226,25 @@ def initialize(self): else: workspace_ptr = None - # get kernel params as bytearray + # Get kernel params as a bytearray semaphore = 0 - if workspace_ptr is not None and \ - self.split_k_mode == cutlass.conv.SplitKMode.Parallel: + if (workspace_ptr is not None + and self.split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel): self.ptr_D = workspace_ptr - elif workspace_ptr is not None and \ - self.split_k_mode == cutlass.conv.SplitKMode.Serial: + elif (workspace_ptr is not None + and self.split_k_mode == cutlass_bindings.conv.SplitKMode.Serial): semaphore = workspace_ptr self.get_arguments(semaphore) - params_ = self.operation.rt_module.get_args(ctypes.byref( - self.c_arguments), ctypes.c_void_p(int(self.semaphore))) + params_ = self.operation.rt_module.get_args( + ctypes.byref(self.c_arguments), ctypes.c_void_p(int(self.semaphore))) self.host_workspace = bytearray(params_.contents) self.device_workspace = None def sync(self): """ - Synchronize the arguments. If the input tensor is in host, + Synchronize the arguments. If the input tensor is in host, copy it from device to host. """ return super().sync() @@ -242,7 +255,8 @@ class Conv2dRT(ExecutableOperation): """ Conv2dRT manages the CUTLASS runtime components """ - KernelTemplate = r''' + + KernelTemplate = r""" extern "C" __global__ void ${operation_name}(${operation_name}${operation_suffix}::Params params) { @@ -258,9 +272,9 @@ class Conv2dRT(ExecutableOperation): op(params, *shared_storage); } - ''' + """ - HostTemplate = r''' + HostTemplate = r""" extern "C" { // Get the size of params in bytes int ${operation_name}_get_param_size(){ @@ -286,9 +300,9 @@ class Conv2dRT(ExecutableOperation): } } - ''' + """ - def __init__(self, operation: 'Conv2dOperation'): + def __init__(self, operation: "Conv2dOperation"): super().__init__(operation) self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor) self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p] @@ -296,7 +310,7 @@ def __init__(self, operation: 'Conv2dOperation'): self.operation: Conv2dOperation = operation - self.emitter = EmitConv2dInstance('_type') + self.emitter = EmitConv2dInstance("_type") self.threads: int = operation.tile_description.num_threads @@ -305,7 +319,6 @@ def __init__(self, operation: 'Conv2dOperation'): def emit(self): return self.emitter.emit(self.operation) - # @typechecked def get_device_workspace_size(self, arguments: Conv2dArguments): workspace_bytes = 0 @@ -313,13 +326,14 @@ def get_device_workspace_size(self, arguments: Conv2dArguments): self.conv_kind = self.operation.conv_kind - if arguments.split_k_mode == cutlass.conv.SplitKMode.Parallel: + if arguments.split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel: problem_size = arguments.problem_size workspace_bytes = DataTypeSize[self.operation.C.element] \ - * launch_config.grid[2] * cutlass.conv.implicit_gemm_tensor_c_size( + * launch_config.grid[2] * cutlass_bindings.conv.implicit_gemm_tensor_c_size( self.conv_kind, problem_size ) // 8 - elif arguments.split_k_mode == cutlass.conv.SplitKMode.Serial and \ + + elif arguments.split_k_mode == cutlass_bindings.conv.SplitKMode.Serial and \ arguments.split_k_slices > 1: workspace_bytes = launch_config.grid[0] * launch_config.grid[1] * 4 @@ -327,20 +341,20 @@ def get_device_workspace_size(self, arguments: Conv2dArguments): # @typechecked def plan(self, arguments: Conv2dArguments): - tile_size = cutlass.gemm.GemmCoord( + tile_size = cutlass_bindings.gemm.GemmCoord( self.operation.tile_description.threadblock_shape[0], self.operation.tile_description.threadblock_shape[1], - self.operation.tile_description.threadblock_shape[2] + self.operation.tile_description.threadblock_shape[2], ) grid = self.swizzle_functor.get_grid_shape( self.swizzle_functor.get_tiled_shape( - self.conv_kind, arguments.problem_size, + self.conv_kind, arguments.problem_size, tile_size, arguments.split_k_slices ) ) return LaunchConfiguration( - [grid.x, grid.y, grid.z], [self.threads, 1, 1], + [grid.x, grid.y, grid.z], [self.threads, 1, 1], self.shared_memory_capacity) def initialize(self): @@ -349,9 +363,7 @@ def initialize(self): attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, value=self.shared_memory_capacity) if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('Cuda Error: {}'.format(err)) - -# + raise RuntimeError("Cuda Error: {}".format(err)) class Conv2dOperation: @@ -359,53 +371,56 @@ class Conv2dOperation: CUTLASS Conv2d operation description. :param conv_kind: convolution operator - :type conv_kind: :class:`cutlass.conv.Operator` + :type conv_kind: :class:`cutlass_bindings.conv.Operator` :param iterator_algorithm: Selects among several implementation variants trading off performance with simplicity - :type iterator_algorithm: :class:`cutlass.conv.IteratorAlgorithm` + :type iterator_algorithm: :class:`cutlass_bindings.conv.IteratorAlgorithm` :param arch: GPU compute capability (sm_xx) :type arch: int :param tile_description: tile description - :type tile_description: :class:`pycutlass.TileDescription` + :type tile_description: :class:`cutlass.backend.TileDescription` :param A: tensor A description - :type A: :class:`pycutlass.TensorDescription` + :type A: :class:`cutlass.backend.TensorDescription` :param B: tensor B description - :type B: :class:`pycutlass.TensorDescription` + :type B: :class:`cutlass.backend.TensorDescription` :param C: tensor C description - :type C: :class:`pycutlass.TensorDescription` + :type C: :class:`cutlass.backend.TensorDescription` :param D: tensor D description - :type D: :class:`pycutlass.TensorDescription` + :type D: :class:`cutlass.backend.TensorDescription` :param element_epilogue: element type for computation in epilogue \ - :type element_epilogue: cutlass.int8 | cutlass.int32 | cutlass.float16 | \ - cutlass.bfloat16 | cutlass.float32 | cutlass.float64 + :type element_epilogue: cutlass_bindings.int8 | cutlass_bindings.int32 | cutlass_bindings.float16 | \ + cutlass_bindings.bfloat16 | cutlass_bindings.float32 | cutlass_bindings.float64 :param stride_support: distinguish among partial specializations that \ accelerate certain problems where convolution stride is unit \ - :type stride_support: :class:`cutlass.conv.StrideSupport` + :type stride_support: :class:`cutlass_bindings.conv.StrideSupport` :param epilogue_functor: convolution epilogue functor :type epilogue_functor: :class:`EpilogueFunctor` :param swizzling_functor: threadblock swizzling functor """ - # - - def __init__(self, - conv_kind: cutlass.conv.Operator, - iterator_algorithm: cutlass.conv.IteratorAlgorithm, - arch: int, tile_description: TileDescription, - A: TensorDescription, B: TensorDescription, C: TensorDescription, - stride_support, epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1): - + def __init__( + self, + conv_kind: cutlass_bindings.conv.Operator, + iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm, + arch: int, + tile_description: TileDescription, + A: TensorDescription, + B: TensorDescription, + C: TensorDescription, + stride_support, + epilogue_functor, + swizzling_functor=cutlass_bindings.IdentitySwizzle1 + ): self.operation_kind: OperationKind = OperationKind.Conv2d self.arch: int = arch self.tile_description: TileDescription = tile_description @@ -427,17 +442,18 @@ def run(self, arguments: Conv2dArguments) -> cuda.CUresult: Launch the cuda kernel with input arguments :param arguments: conv2d arguments - :type arguments: :class:`pycutlass.Conv2dArguments` + :type arguments: :class:`cutlass.backend.Conv2dArguments` """ # launch the kernel err = self.rt_module.run( arguments.host_workspace, arguments.device_workspace, - arguments.launch_config) + arguments.launch_config, + ) if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('CUDA Error %s' % str(err)) + raise RuntimeError("CUDA Error %s" % str(err)) return err @@ -446,20 +462,23 @@ def run(self, arguments: Conv2dArguments) -> cuda.CUresult: # def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + """The full procedural name indicates architecture, extended name, tile size, and layout.""" return self.configuration_name() + # def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + """The full procedural name indicates architecture, extended name, tile size, and layout.""" - opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + opcode_class_name = OpcodeClassNames[ + self.tile_description.math_instruction.opcode_class + ] threadblock = "%dx%d_%dx%d" % ( self.tile_description.threadblock_shape[0], self.tile_description.threadblock_shape[1], self.tile_description.threadblock_shape[2], - self.tile_description.stages + self.tile_description.stages, ) if self.stride_support == StrideSupport.Unity: @@ -470,18 +489,18 @@ def configuration_name(self): return SubstituteTemplate( configuration_name, { - 'arch': str(self.arch), - 'opcode_class': opcode_class_name, - 'extended_name': self.extended_name(), - 'threadblock': threadblock, - 'layout': self.layout_name(), - 'alignment': "%d" % self.A.alignment, - } + "arch": str(self.arch), + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment": "%d" % self.A.alignment + }, ) # def extended_name(self): - ''' Append data types if they differ from compute type. ''' + """Append data types if they differ from compute type.""" if self.C.element != self.tile_description.math_instruction.element_accumulator and \ self.A.element != self.tile_description.math_instruction.element_accumulator: extended_name = "${element_c}_${core_name}_${element_a}" @@ -492,9 +511,9 @@ def extended_name(self): extended_name = "${core_name}" extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), }) return extended_name @@ -505,27 +524,32 @@ def layout_name(self): # def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + """The basic operation kind is prefixed with a letter indicating the accumulation type.""" - intermediate_type = '' + intermediate_type = "" - if self.tile_description.math_instruction.opcode_class == cutlass.OpClass.TensorOp: + if self.tile_description.math_instruction.opcode_class == cutlass_bindings.OpClass.TensorOp: inst_shape = "%dx%dx%d" % tuple( self.tile_description.math_instruction.instruction_shape) if self.tile_description.math_instruction.element_a != self.A.element and \ self.tile_description.math_instruction.element_a != self.accumulator_type(): intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] else: - inst_shape = '' - - return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], - inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) + inst_shape = "" + + return "%s%s%s%s_%s" % ( + ShortDataTypeNames[self.accumulator_type()], + inst_shape, + intermediate_type, + ConvKindNames[self.conv_kind], + IteratorAlgorithmNames[self.iterator_algorithm] + ) # def is_complex(self): complex_operators = [ MathOperation.multiply_add_complex, - MathOperation.multiply_add_complex_gaussian + MathOperation.multiply_add_complex_gaussian, ] return self.tile_description.math_instruction.math_operation in complex_operators @@ -545,8 +569,9 @@ def accumulator_type(self): # ################################################################################################### + class EmitConv2dInstance: - def __init__(self, operation_suffix=''): + def __init__(self, operation_suffix=""): self.operation_suffix = operation_suffix self.includes = [ "cutlass/cutlass.h", @@ -586,7 +611,6 @@ def __init__(self, operation_suffix=''): """ def emit(self, operation): - warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] @@ -594,39 +618,38 @@ def emit(self, operation): operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) values = { - 'operation_name': operation.procedural_name(), - 'operation_suffix': self.operation_suffix, - 'conv_kind': ConvKindTag[operation.conv_kind], - 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[operation.A.layout], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[operation.B.layout], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[operation.C.layout], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_vector_length': str(epilogue_vector_length), - 'epilogue_functor': operation.epilogue_functor.emit(), - 'swizzling_functor': operation.swizzling_functor.tag(), - 'stages': str(operation.tile_description.stages), - 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], - 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), - 'stride_support': StrideSupportTag[operation.stride_support], - 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else - MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "conv_kind": ConvKindTag[operation.conv_kind], + "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "epilogue_vector_length": str(epilogue_vector_length), + "epilogue_functor": operation.epilogue_functor.emit(), + "swizzling_functor": operation.swizzling_functor.tag(), + "stages": str(operation.tile_description.stages), + "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm], + "iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + "stride_support": StrideSupportTag[operation.stride_support], + "math_operator": "cutlass::arch::OpMultiplyAddComplex" if operation.is_complex() else MathOperationTag[operation.tile_description.math_instruction.math_operation], + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), } return SubstituteTemplate(self.template, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py b/python/cutlass/backend/epilogue.py similarity index 87% rename from tools/library/scripts/pycutlass/src/pycutlass/epilogue.py rename to python/cutlass/backend/epilogue.py index de6d5391..8cf2c728 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +++ b/python/cutlass/backend/epilogue.py @@ -30,26 +30,24 @@ # ################################################################################ -from ast import Num -from audioop import mul -from pipes import Template -import struct -from pycutlass.library import DataTypeTag -from pycutlass import * -import cutlass -from scipy.special import erf +import ctypes -from pycutlass.c_types import MatrixCoord_ -from pycutlass.frontend import NumpyFrontend +from cuda import cuda, cudart +import cutlass_bindings +import numpy as np +from scipy.special import erf -from cuda import cuda -from cuda import cudart +from cutlass.backend.c_types import MatrixCoord_ +from cutlass.backend.frontend import NumpyFrontend +from cutlass.backend.library import DataTypeTag +from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate dtype2ctype = { - cutlass.float16: ctypes.c_uint16, - cutlass.float32: ctypes.c_float, - cutlass.float64: ctypes.c_double, - cutlass.int32: ctypes.c_int32 + cutlass_bindings.int8: ctypes.c_int8, + cutlass_bindings.float16: ctypes.c_uint16, + cutlass_bindings.float32: ctypes.c_float, + cutlass_bindings.float64: ctypes.c_double, + cutlass_bindings.int32: ctypes.c_int32 } @@ -59,13 +57,15 @@ # ################################################################################################# + class EpilogueFunctorBase: """ Base class for thread-level epilogue functors """ + def __init__(self) -> None: pass - + def emit(self, tag, template_argument): template = """${tag}<${arguments}>""" arguments = "" @@ -75,11 +75,10 @@ def emit(self, tag, template_argument): arguments += ", " values = { "tag": tag, - "arguments": arguments + "arguments": arguments, } return SubstituteTemplate(template, values) - class LinearCombination(EpilogueFunctorBase): @@ -88,8 +87,8 @@ class LinearCombination(EpilogueFunctorBase): D = alpha * accumulator + beta * source :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. + + :param epilogue_vector_length: number of elements computed per operation. Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes when there are not enough data to store @@ -97,25 +96,29 @@ class LinearCombination(EpilogueFunctorBase): :param element_epilogue: data type used to compute linear combination """ + tag = "cutlass::epilogue::thread::LinearCombination" + def __init__( - self, element_output, epilogue_vector_length, - element_accumulator=None, element_epilogue=None) -> None: # TODO bind ScaleType + self, element_output, epilogue_vector_length, + element_accumulator=None, element_epilogue=None) -> None: super().__init__() if element_accumulator is None: element_accumulator = element_output if element_epilogue is None: element_epilogue = element_output - + self.element_output = element_output self.element_accumulator = element_accumulator self.element_epilogue = element_epilogue self.epilogue_vector_length = epilogue_vector_length self.template_arguments = [ - DataTypeTag[element_output], str(epilogue_vector_length), - DataTypeTag[element_accumulator], DataTypeTag[element_epilogue] + DataTypeTag[element_output], + str(epilogue_vector_length), + DataTypeTag[element_accumulator], + DataTypeTag[element_epilogue], ] # get epilogue output op type @@ -124,32 +127,32 @@ def __init__( class _EpilogueOutputOpParams(ctypes.Structure): _fields_ = [ - ("alpha_data", ctypes.c_longlong*2), - ("beta_data", ctypes.c_longlong*2), ("alpha", c_element_epilogue), ("beta", c_element_epilogue), ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p) ] + def __init__(self, alpha, beta, *args) -> None: self.alpha = element_epilogue(alpha).storage self.beta = element_epilogue(beta).storage + self.epilogue_type = _EpilogueOutputOpParams - + def emit(self): return super().emit(self.tag, self.template_arguments) class LinearCombinationClamp(LinearCombination): """ - Applies a linear combination operator to an array of elements then clamps + Applies a linear combination operator to an array of elements then clamps the output before converting to the output element type. D = alpha * accumulator + beta * source + uniform :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. + + :param epilogue_vector_length: number of elements computed per operation. Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes when there are not enough data to store @@ -157,15 +160,20 @@ class LinearCombinationClamp(LinearCombination): :param element_epilogue: data type used to compute linear combination """ + tag = "cutlass::epilogue::thread::LinearCombinationClamp" + def __init__( - self, element_output, epilogue_vector_length, + self, element_output, epilogue_vector_length, element_accumulator=None, element_epilogue=None) -> None: # Base constructor super().__init__( - element_output, epilogue_vector_length, - element_accumulator, element_epilogue) - + element_output, + epilogue_vector_length, + element_accumulator, + element_epilogue, + ) + c_element_epilogue = dtype2ctype[self.element_epilogue] element_epilogue = self.element_epilogue @@ -176,10 +184,12 @@ class _EpilogueOutputOpParams(ctypes.Structure): ("alpha_ptr", ctypes.c_void_p), ("beta_ptr", ctypes.c_void_p), ] + def __init__(self, alpha, beta, *args) -> None: self.alpha = element_epilogue(alpha).storage self.beta = element_epilogue(beta).storage - self.epilogue_type = _EpilogueOutputOpParams + + self.epilogue_type = _EpilogueOutputOpParams class FastLinearCombinationClamp(EpilogueFunctorBase): @@ -194,12 +204,14 @@ class FastLinearCombinationClamp(EpilogueFunctorBase): above. :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. + + :param epilogue_vector_length: number of elements computed per operation. Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes when there are not enough data to store """ + tag = "cutlass::epilogue::thread::FastLinearCombinationClamp" + def __init__(self, element_output, epilogue_vector_length, *args) -> None: super().__init__() @@ -207,8 +219,8 @@ def __init__(self, element_output, epilogue_vector_length, *args) -> None: DataTypeTag[element_output], str(epilogue_vector_length) ] - self.element_accumulator = cutlass.int32 - self.element_epilogue = cutlass.float32 + self.element_accumulator = cutlass_bindings.int32 + self.element_epilogue = cutlass_bindings.float32 # get epilogue output op c_element_epilogue = dtype2ctype[self.element_epilogue] @@ -221,18 +233,20 @@ class _EpilogueOutputOpParams(ctypes.Structure): ("alpha_ptr", ctypes.c_void_p), ("beta_ptr", ctypes.c_void_p), ] + def __init__(self, alpha, beta, *args) -> None: self.alpha = element_epilogue(alpha).storage self.beta = element_epilogue(beta).storage - self.epilogue_type = _EpilogueOutputOpParams - + + self.epilogue_type = _EpilogueOutputOpParams + def emit(self): return super().emit(self.tag, self.template_arguments) class LinearCombinationGeneric(LinearCombination): """ - Applies a linear combination operator followed by an activation function + Applies a linear combination operator followed by an activation function to an array of elements. D = activation(alpha * accumulator + beta * source) @@ -240,8 +254,8 @@ class LinearCombinationGeneric(LinearCombination): :param activation_functor: input activation functor :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. + + :param epilogue_vector_length: number of elements computed per operation. Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes when there are not enough data to store @@ -249,21 +263,26 @@ class LinearCombinationGeneric(LinearCombination): :param element_epilogue: data type used to compute linear combination """ + tag = "cutlass::epilogue::thread::LinearCombinationGeneric" + def __init__( self, activation_functor, - element_output, epilogue_vector_length, + element_output, epilogue_vector_length, element_accumulator=None, element_epilogue=None) -> None: super().__init__( - element_output, epilogue_vector_length, - element_accumulator, element_epilogue) - + element_output, + epilogue_vector_length, + element_accumulator, + element_epilogue, + ) + self.template_arguments = [ - activation_functor.emit(),] + self.template_arguments - + activation_functor.emit()] + self.template_arguments + self.activation_functor = activation_functor self.element_epilogue = element_epilogue - + # get epilogue output op self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue) @@ -272,15 +291,17 @@ class ActivationFunctor: """ Base class for frequently used activation functions """ + def __init__(self, element_compute) -> None: pass + @staticmethod def numpy(x: np.ndarray): raise NotImplementedError() def emit(self): return self.tag - + @staticmethod def epilogue_output_op(element_epilogue): c_element_epilogue = dtype2ctype[element_epilogue] @@ -292,60 +313,74 @@ class _EpilogueOutputOpParams(ctypes.Structure): ("alpha_ptr", ctypes.c_void_p), ("beta_ptr", ctypes.c_void_p), ] + def __init__(self, alpha, beta, *args) -> None: self.alpha = element_epilogue(alpha).storage self.beta = element_epilogue(beta).storage + return _EpilogueOutputOpParams + # identity operator class identity(ActivationFunctor): + tag = "cutlass::epilogue::thread::Identity" + def numpy(x: np.ndarray): return x -# ReLu operator, + +# ReLu operator, class relu(ActivationFunctor): tag = "cutlass::epilogue::thread::ReLu" def __init__(self, element_compute): super().__init__(element_compute) + class _Arguments(ctypes.Structure): _fields_ = [ ("threshold", dtype2ctype[element_compute]) ] - def __init__(self, threshold=0.) -> None: + + def __init__(self, threshold=0.0) -> None: self.threshold = element_compute(threshold).storage + self.argument_type = _Arguments - + def emit_visitor(self): return "cutlass::ReLUVisitor" - + @staticmethod def numpy(x: np.ndarray): return np.maximum(x, 0) + # Leaky ReLu operator class leaky_relu(ActivationFunctor): tag = "cutlass::epilogue::thread::LeakyReLU" def __init__(self, element_compute) -> None: super().__init__(element_compute) + class _Arguments(ctypes.Structure): _fields_ = [ ("leaky_alpha", dtype2ctype[element_compute]) ] + def __init__(self, leaky_alpha) -> None: self.leaky_alpha = element_compute(leaky_alpha).storage + self.argument_type = _Arguments - + def emit_visitor(self): return "cutlass::LeakyReLUVisitor" @staticmethod def numpy(x: np.ndarray, leaky_alpha): return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha - + def epilogue_output_op(self, element_epilogue): c_element_epilogue = dtype2ctype[element_epilogue] + class _EpilogueOutputOpParams(ctypes.Structure): _fields_ = [ ("alpha", c_element_epilogue), @@ -354,28 +389,32 @@ class _EpilogueOutputOpParams(ctypes.Structure): ("beta_ptr", ctypes.c_void_p), ("leaky_alpha", c_element_epilogue) ] + def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None: self.alpha = element_epilogue(alpha).storage self.beta = element_epilogue(beta).storage self.alpha_ptr = 0 self.beta_ptr = 0 self.leaky_alpha = element_epilogue(leaky_alpha).storage + return _EpilogueOutputOpParams + # Tanh operator class tanh(ActivationFunctor): tag = "cutlass::epilogue::thread::Tanh" def __init__(self, element_compute) -> None: super().__init__(element_compute) + class _Arguments(ctypes.Structure): - _fields_ = [ - ("tmp", ctypes.c_int) - ] + _fields_ = [("tmp", ctypes.c_int)] + def __init__(self, *args) -> None: self.tmp = 0 + self.argument_type = _Arguments - + def emit_visitor(self): return "cutlass::TanhVisitor" @@ -383,8 +422,10 @@ def emit_visitor(self): def numpy(x: np.ndarray): return np.tanh(x) + def sigmoid_op(x: np.ndarray): - return 1. / (1. + np.exp(-x)) + return 1.0 / (1.0 + np.exp(-x)) + # Sigmoid operator class sigmoid(ActivationFunctor): @@ -394,6 +435,7 @@ class sigmoid(ActivationFunctor): def numpy(x: np.ndarray): return sigmoid_op(x) + # SiLu operator class silu(ActivationFunctor): tag = "cutlass::epilogue::thread::SiLu" @@ -402,14 +444,16 @@ class silu(ActivationFunctor): def numpy(x: np.ndarray): return x * sigmoid_op(x) + # Hardswish operator class hardswish(ActivationFunctor): tag = "cutlass::epilogue::thread::HardSwish" @staticmethod def numpy(x: np.ndarray): - relu6 = np.minimum(np.maximum(x + 3., 0), 6.) - return x * relu6 / 6. + relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0) + return x * relu6 / 6.0 + # GELU operator class gelu(ActivationFunctor): @@ -417,7 +461,8 @@ class gelu(ActivationFunctor): @staticmethod def numpy(x: np.ndarray): - return 0.5 * x * (1 + erf(x / np.sqrt(2.))) + return 0.5 * x * (1 + erf(x / np.sqrt(2.0))) + # reduction operator def reduction_op(tensor, direction, math, factor): @@ -426,7 +471,7 @@ def reduction_op(tensor, direction, math, factor): if direction == "row": num_cta_n = (n + factor - 1) // factor reduction = np.transpose( - np.sum(tensor.reshape(batch, m, num_cta_n, factor), axis=-1), + np.sum(tensor.reshape(batch, m, num_cta_n, factor), axis=-1), axes=[0, 2, 1]).flatten() elif direction == "column": num_cta_m = (m + factor - 1) // factor @@ -434,17 +479,10 @@ def reduction_op(tensor, direction, math, factor): tensor.reshape(batch, num_cta_m, factor, n), axis=-2).flatten() else: raise NotImplementedError - return reduction + return reduction else: raise NotImplementedError -# # GELU operator implemented using the taylor series approximation -# class GELU_taylor(ActivationFunctor): -# tag = "cutlass::epilogue::thread::GELU_taylor" - -# # Computes backwards pass for GELU operator -# class dGELU(ActivationFunctor): -# tag = "cutlass::epilogue::thread::dGELU" ################################################################################ # Epilogue Visitor @@ -457,8 +495,8 @@ class LayerNorm(EpilogueFunctorBase): D = alpha * accumulator + beta * source :param element_output: data type used to load and store tensors - - :param epilogue_vector_length: number of elements computed per operation. + + :param epilogue_vector_length: number of elements computed per operation. Usually it is 128/sizeof_bits, but we use 64 and 32 sometimes when there are not enough data to store @@ -466,6 +504,7 @@ class LayerNorm(EpilogueFunctorBase): :param element_epilogue: data type used to compute linear combination """ + KernelTemplate = """ cutlass::epilogue::threadblock::EpilogueVisitorLayerNorm< @@ -480,18 +519,21 @@ class LayerNorm(EpilogueFunctorBase): ${epilogue_functor}, ${shifted_k}>; """ - headers = ["gemm/gemm_universal_with_visitor.h", - "epilogue/epilogue_visitor_with_layernorm.h"] + headers = [ + "gemm/gemm_universal_with_visitor.h", + "epilogue/epilogue_visitor_with_layernorm.h" + ] + def __init__( self, elementwise_functor, - element_variance=None, element_mean=None, - element_layer_norm_compute=None, shifted_k=True) -> None: # TODO bind ScaleType + element_variance=None, element_mean=None, + element_layer_norm_compute=None, shifted_k=True, ) -> None: super().__init__() self.elementwise_functor = elementwise_functor self.element_compute = elementwise_functor.element_epilogue self.element_output = elementwise_functor.element_output - + if element_variance is None: self.element_variance = self.element_output if element_mean is None: @@ -502,18 +544,19 @@ def __init__( self.shifted_k = "true" else: self.shifted_k = "false" - + # get epilogue output op elementwise_params_type = self.elementwise_functor.epilogue_type - + class _EpilogueVisitorParams(ctypes.Structure): _fields_ = [ ("element_wise", elementwise_params_type), ("ptr_Variance", ctypes.c_void_p), ("ptr_Mean_", ctypes.c_void_p), ("ptr_Shifted_K_", ctypes.c_void_p), - ("extent", MatrixCoord_) + ("extent", MatrixCoord_), ] + def __init__(self, elementwise_params, variance, mean, shift_k, extent) -> None: self.element_wise = elementwise_params if isinstance(variance, np.ndarray): @@ -528,50 +571,52 @@ def __init__(self, elementwise_params, variance, mean, shift_k, extent) -> None: self.host_variance = variance self.host_mean = mean self.host_shift_k = shift_k - + def sync(self, stream_sync=True): if stream_sync: err, = cudart.cudaDeviceSynchronize() if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) - - # if hasattr(self, "host_variance"): + err, = cuda.cuMemcpyDtoH( - self.host_variance, cuda.CUdeviceptr(self.ptr_Variance), + self.host_variance, + cuda.CUdeviceptr(self.ptr_Variance), self.host_variance.size * self.host_variance.itemsize) err, = cuda.cuMemcpyDtoH( - self.host_mean, cuda.CUdeviceptr(self.ptr_Mean_), + self.host_mean, + cuda.CUdeviceptr(self.ptr_Mean_), self.host_mean.size * self.host_mean.itemsize) err, = cuda.cuMemcpyDtoH( - self.host_shift_k, cuda.CUdeviceptr(self.ptr_Shifted_K_), + self.host_shift_k, + cuda.CUdeviceptr(self.ptr_Shifted_K_), self.host_shift_k.size * self.host_shift_k.itemsize) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) - self.epilogue_type = _EpilogueVisitorParams + self.epilogue_type = _EpilogueVisitorParams def emit(self, operation): values = { - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'operation_name': operation.procedural_name(), - 'element_compute': DataTypeTag[self.element_compute], - 'element_variance': DataTypeTag[self.element_variance], - 'element_mean': DataTypeTag[self.element_mean], - 'element_layer_norm_compute': DataTypeTag[self.element_layer_norm_compute], - 'epilogue_functor': self.elementwise_functor.emit(), - 'shifted_k': self.shifted_k + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "operation_name": operation.procedural_name(), + "element_compute": DataTypeTag[self.element_compute], + "element_variance": DataTypeTag[self.element_variance], + "element_mean": DataTypeTag[self.element_mean], + "element_layer_norm_compute": DataTypeTag[self.element_layer_norm_compute], + "epilogue_functor": self.elementwise_functor.emit(), + "shifted_k": self.shifted_k, } return SubstituteTemplate(self.KernelTemplate, values) - class AccumulatorOp: Template = """ using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpAccumulator<${element_accumulator}, ${elements_per_access}>; """ counter = 0 + def __init__(self, element_accumulator, elements_per_access) -> None: self.element_accumulator = element_accumulator self.elements_per_access = elements_per_access @@ -579,21 +624,19 @@ def __init__(self, element_accumulator, elements_per_access) -> None: self.instance_name = "AccumulatorOp%d" % AccumulatorOp.counter AccumulatorOp.counter += 1 - class _Arguments(ctypes.Structure): - _fields_ = [ - ("tmp", ctypes.c_int) - ] + _fields_ = [("tmp", ctypes.c_int)] + def __init__(self): self.tmp = 0 - + self.argument_type = _Arguments - + def emit(self, *args): values = { "instance_name": self.instance_name, "element_accumulator": DataTypeTag[self.element_accumulator], - "elements_per_access": str(self.elements_per_access) + "elements_per_access": str(self.elements_per_access), } return SubstituteTemplate(self.Template, values) @@ -609,15 +652,15 @@ class LinearCombinationOp: ${elements_per_access}, ${visitor_a_name}, ${visitor_b_name}>; """ counter = 0 + def __init__(self, element_accumulator, element_compute, elements_per_access, visitor_a, visitor_b) -> None: - # self.element_accumulator = element_accumulator self.element_compute = element_compute self.elements_per_access = elements_per_access self.visitor_a = visitor_a self.visitor_b = visitor_b - + self.instance_name = "LinearCombinationOp%d" % LinearCombinationOp.counter LinearCombinationOp.counter += 1 @@ -628,14 +671,15 @@ class _Arguments(ctypes.Structure): ("visitor_a", self.visitor_a.argument_type), ("visitor_b", self.visitor_b.argument_type) ] + def __init__(self, alpha, beta, visitor_a_arg, visitor_b_arg) -> None: self.alpha = element_compute(alpha).storage self.beta = element_compute(beta).storage self.visitor_a = visitor_a_arg self.visitor_b = visitor_b_arg - + self.argument_type = _Arguments - + def emit(self, operation): values = { "instance_name": self.instance_name, @@ -649,32 +693,34 @@ def emit(self, operation): } return SubstituteTemplate(self.Template, values) + class VectorAdd: def __init__(self, *args) -> None: class _Arguments(ctypes.Structure): - _fields_ = [ - ("tmp", ctypes.c_int) - ] + _fields_ = [("tmp", ctypes.c_int)] + def __init__(self, *args) -> None: self.tmp = 0 + self.argument_type = _Arguments def emit(self): return "cutlass::VectorAdd" + class VectorMult: def __init__(self, *args) -> None: class _Arguments(ctypes.Structure): - _fields_ = [ - ("tmp", ctypes.c_int) - ] + _fields_ = [("tmp", ctypes.c_int)] + def __init__(self, *args) -> None: self.tmp = 0 + self.argument_type = _Arguments def emit(self): return "cutlass::VectorMult" - + class BinaryOp: Template = """ @@ -687,9 +733,9 @@ class BinaryOp: ${elements_per_access}, ${visitor_a_name}, ${visitor_b_name}, ${binary_op}>; """ counter = 0 - def __init__(self, element_accumulator, element_compute, + + def __init__(self, element_accumulator, element_compute, elements_per_access, visitor_a, visitor_b, binary_op) -> None: - # self.element_accumulator = element_accumulator self.element_compute = element_compute self.elements_per_access = elements_per_access @@ -706,12 +752,14 @@ class _Arguments(ctypes.Structure): ("visitor_a", self.visitor_a.argument_type), ("visitor_b", self.visitor_b.argument_type) ] + def __init__(self, binary_param, visitor_a_arg, visitor_b_arg) -> None: self.binary_param = binary_param self.visitor_a = visitor_a_arg self.visitor_b = visitor_b_arg - + self.argument_type = _Arguments + def emit(self, operation): values = { "instance_name": self.instance_name, @@ -733,14 +781,16 @@ class _Arguments(ctypes.Structure): _fields_ = [ ("alpha", dtype2ctype[element_compute]) ] + def __init__(self, alpha) -> None: self.alpha = element_compute(alpha).storage - + self.argument_type = _Arguments - + def emit_visitor(self): return "cutlass::Mult" + class UnaryOp: Template = """ ${visitor} @@ -750,9 +800,9 @@ class UnaryOp: ${elements_per_access}, ${visitor_name}, ${unary_op}>; """ counter = 0 + def __init__(self, element_accumulator, element_compute, elements_per_access, visitor, unary_op) -> None: - # self.element_accumulator = element_accumulator self.element_compute = element_compute self.elements_per_access = elements_per_access @@ -765,14 +815,15 @@ def __init__(self, element_accumulator, element_compute, class _Arguments(ctypes.Structure): _fields_ = [ ("unary_param", unary_op.argument_type), - ("visitor_arg", self.visitor.argument_type) + ("visitor_arg", self.visitor.argument_type), ] + def __init__(self, unary_param, visitor_arg) -> None: self.unary_param = unary_param self.visitor_arg = visitor_arg - + self.argument_type = _Arguments - + def emit(self, operation): values = { "instance_name": self.instance_name, @@ -781,36 +832,37 @@ def emit(self, operation): "elements_per_access": str(self.elements_per_access), "visitor_name": self.visitor.instance_name, "unary_op": self.unary_op.emit_visitor(), - "visitor": self.visitor.emit(operation) + "visitor": self.visitor.emit(operation), } return SubstituteTemplate(self.Template, values) - class RowBroadcastOp: Template = """ using ${instance_name} = cutlass::epilogue::threadblock::VisitorOpRowBroadcast< ${element_accumulator}, ${element_fragment}, ${input_tile_iterator}>; """ counter = 0 + def __init__(self, element_accumulator, element_fragment) -> None: self.element_accumulator = element_accumulator self.element_fragment = element_fragment self.instance_name = "RowBroadcastOp%d" % RowBroadcastOp.counter RowBroadcastOp.counter += 1 - + class _Arguments(ctypes.Structure): _fields_ = [ ("broadcast_ptr", ctypes.c_void_p), ("batch_stride", ctypes.c_longlong) ] + def __init__(self, broadcast_ptr, batch_stride=0): self.broadcast_ptr = int(broadcast_ptr) self.batch_stride = batch_stride - + self.argument_type = _Arguments - + def emit(self, operation): values = { "instance_name": self.instance_name, @@ -827,24 +879,26 @@ class ColumnBroadcastOp: ${element_accumulator}, ${element_fragment}, ${input_tile_iterator}>; """ counter = 0 + def __init__(self, element_accumulator, element_fragment) -> None: self.element_accumulator = element_accumulator self.element_fragment = element_fragment self.instance_name = "ColumnBroadcastOp%d" % ColumnBroadcastOp.counter ColumnBroadcastOp.counter += 1 - + class _Arguments(ctypes.Structure): _fields_ = [ ("broadcast_ptr", ctypes.c_void_p), ("batch_stride", ctypes.c_longlong) ] + def __init__(self, broadcast_ptr, batch_stride=0): self.broadcast_ptr = int(broadcast_ptr) self.batch_stride = batch_stride - + self.argument_type = _Arguments - + def emit(self, operation): values = { "instance_name": self.instance_name, @@ -861,25 +915,27 @@ class TensorInputOp: ${element_accumulator}, ${input_tile_iterator}>; """ counter = 0 + def __init__(self, element_accumulator) -> None: self.element_accumulator = element_accumulator self.instance_name = "TensorInputOp%d" % TensorInputOp.counter TensorInputOp.counter += 1 - + class _Arguments(ctypes.Structure): _fields_ = [ ("input_ptr", ctypes.c_void_p), ("ldt", ctypes.c_int), ("batch_stride", ctypes.c_longlong) ] + def __init__(self, input_ptr, ldt, batch_stride=0) -> None: self.input_ptr = int(input_ptr) self.ldt = ldt self.batch_stride = batch_stride - + self.argument_type = _Arguments - + def emit(self, operation): values = { "instance_name": self.instance_name, @@ -888,6 +944,7 @@ def emit(self, operation): } return SubstituteTemplate(self.Template, values) + class TensorOutputOp: Template = """ ${visitor} @@ -896,6 +953,7 @@ class TensorOutputOp: ${element_accumulator}, ${output_tile_iterator}, ${visitor_name}>; """ counter = 0 + def __init__(self, element_accumulator, visitor) -> None: self.element_accumulator = element_accumulator self.visitor = visitor @@ -910,21 +968,22 @@ class _Arguments(ctypes.Structure): ("batch_stride", ctypes.c_longlong), ("visitor_arg", self.visitor.argument_type) ] + def __init__(self, output_ptr, ldt, visitor_arg, batch_stride=0) -> None: self.output_ptr = int(output_ptr) self.ldt = int(ldt) self.visitor_arg = visitor_arg self.batch_stride = batch_stride - + self.argument_type = _Arguments - + def emit(self, operation): values = { "instance_name": self.instance_name, "element_accumulator": DataTypeTag[self.element_accumulator], "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator", "visitor_name": self.visitor.instance_name, - "visitor": self.visitor.emit(operation) + "visitor": self.visitor.emit(operation), } return SubstituteTemplate(self.Template, values) @@ -939,7 +998,8 @@ class ColumnReductionOp: ${output_tile_iterator}, ${visitor_name}>; """ counter = 0 - def __init__(self, element_accumulator, element_reduction, + + def __init__(self, element_accumulator, element_reduction, element_reduction_accumulator, visitor) -> None: self.element_accumulator = element_accumulator self.element_reduction = element_reduction @@ -955,25 +1015,26 @@ class _Arguments(ctypes.Structure): ("batch_stride", ctypes.c_longlong), ("visitor_arg", self.visitor.argument_type) ] + def __init__(self, reduction_ptr, visitor_arg, batch_stride=0) -> None: self.reduction_ptr = reduction_ptr self.batch_stride = batch_stride self.visitor_arg = visitor_arg - + self.argument_type = _Arguments - + def emit(self, operation): values = { "instance_name": self.instance_name, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), "element_accumulator": DataTypeTag[self.element_accumulator], "element_reduction": DataTypeTag[self.element_reduction], "element_reduction_accumulator": DataTypeTag[self.element_reduction_accumulator], "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator", "visitor_name": self.visitor.instance_name, - "visitor": self.visitor.emit(operation) + "visitor": self.visitor.emit(operation), } return SubstituteTemplate(self.Template, values) @@ -988,7 +1049,8 @@ class RowReductionOp: ${output_tile_iterator}, ${visitor_name}>; """ counter = 0 - def __init__(self, element_accumulator, element_reduction, + + def __init__(self, element_accumulator, element_reduction, element_reduction_accumulator, visitor) -> None: self.element_accumulator = element_accumulator self.element_reduction = element_reduction @@ -1004,24 +1066,25 @@ class _Arguments(ctypes.Structure): ("batch_stride", ctypes.c_longlong), ("visitor_arg", self.visitor.argument_type) ] + def __init__(self, reduction_ptr, visitor_arg, batch_stride=0) -> None: self.reduction_ptr = reduction_ptr self.visitor_arg = visitor_arg self.batch_stride = batch_stride - + self.argument_type = _Arguments - + def emit(self, operation): values = { "instance_name": self.instance_name, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), "element_accumulator": DataTypeTag[self.element_accumulator], "element_reduction": DataTypeTag[self.element_reduction], "element_reduction_accumulator": DataTypeTag[self.element_reduction_accumulator], "output_tile_iterator": operation.procedural_name() + "_default::Epilogue::OutputTileIterator", "visitor_name": self.visitor.instance_name, - "visitor": self.visitor.emit(operation) + "visitor": self.visitor.emit(operation), } return SubstituteTemplate(self.Template, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/frontend.py b/python/cutlass/backend/frontend.py similarity index 85% rename from tools/library/scripts/pycutlass/src/pycutlass/frontend.py rename to python/cutlass/backend/frontend.py index 10ecaeaa..b4e0be07 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +++ b/python/cutlass/backend/frontend.py @@ -30,25 +30,17 @@ # ################################################################################ -import numpy as np from cuda import cuda -from pycutlass.memory_manager import * -from typing import TYPE_CHECKING -try: +import numpy as np + +from cutlass.backend.memory_manager import device_mem_alloc, todevice +from cutlass.backend.utils.software import CheckPackages + +if CheckPackages().check_torch(): import torch - torch_available = True -except ImportError: - torch_available = False - if TYPE_CHECKING: - import torch -try: +if CheckPackages().check_cupy(): import cupy as cp - cupy_available = True -except ImportError: - cupy_available = False - if TYPE_CHECKING: - import cupy as cp class NumpyFrontend: @@ -57,7 +49,7 @@ class NumpyFrontend: """ @staticmethod - def argument(np_tensor: 'np.ndarray', is_output: 'bool') -> cuda.CUdeviceptr: + def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr: """Convert the input numpy tensor to CUDA device pointer :param np_tensor: input numpy nd array @@ -78,7 +70,7 @@ class TorchFrontend: """ @staticmethod - def argument(torch_tensor: 'torch.Tensor') -> cuda.CUdeviceptr: + def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr: """Convert the input torch tensor to CUDA device pointer :param torch_tensor: input torch tensor @@ -100,5 +92,5 @@ class CupyFrontend: """ @staticmethod - def argument(cupy_ndarray: 'cp.ndarray'): + def argument(cupy_ndarray: "cp.ndarray"): return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr)) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py b/python/cutlass/backend/gemm_operation.py similarity index 52% rename from tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py rename to python/cutlass/backend/gemm_operation.py index 75b43862..c10056df 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -30,14 +30,66 @@ # ################################################################################ -import enum import copy +import ctypes +import enum + +from cuda import cuda, cudart +import cutlass_bindings import numpy as np -from typeguard import typechecked -import cutlass -from pycutlass import * -import pycutlass.builder.collective_op_builder as collective_op_builder -from cuda import cuda +import rmm + +from cutlass import KernelScheduleSuffixes, KernelScheduleTag, KernelScheduleType +from cutlass.backend.arguments import ArgumentBase +from cutlass.backend.c_types import ( + GemmCoord_, + GemmCoordBatched_, + GenericMainloopArguments3x_, + StrideBatched_, + dim3_, + get_gemm_arguments, + get_gemm_arguments_3x, + get_gemm_arguments_streamk, + get_gemm_grouped_arguments, + get_mainloop_arguments_3x +) +from cutlass.backend.library import ( + ApiVersion, + ComplexTransformTag, + DataTypeNames, + DataTypeSize, + DataTypeTag, + GemmKind, + GemmKindNames, + LayoutTag, + MathOperation, + MathOperationTag, + OpcodeClassNames, + OpcodeClassTag, + OperationKind, + SchedulerMode, + SchedulerModeTag, + ShortComplexLayoutNames, + ShortDataTypeNames, + ShortLayoutTypeNames, + TensorDescription, + TileDescription, + api_version, + enum_auto, + get_complex_from_real, +) +from cutlass.backend.memory_manager import device_mem_alloc, todevice +from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass.backend.tensor_ref import TensorRef +from cutlass.backend.type_hint import GemmOperation, Tensor +from cutlass.backend.utils.software import ( + CheckPackages, + SubstituteTemplate, + device_sm_count, +) + +if CheckPackages().check_torch(): + import torch ################################################################################ @@ -47,27 +99,26 @@ ################################################################################ -def transpose_layout(layout: cutlass.layout): - if layout == cutlass.ColumnMajor: - return cutlass.RowMajor - elif layout == cutlass.RowMajor: - return cutlass.ColumnMajor +def transpose_layout(layout: cutlass_bindings.layout): + if layout == cutlass_bindings.ColumnMajor: + return cutlass_bindings.RowMajor + elif layout == cutlass_bindings.RowMajor: + return cutlass_bindings.ColumnMajor else: raise ValueError("unsupported Layout {}".format(layout)) -# @typechecked class GemmArguments2x(ArgumentBase): """ - Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and + Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and user-provide tensors into the kernel's argument :param operation: the GEMM operation to take the argument - :type operation: :class:`pycutlass.GemmOperationUniversal` | - :class:`pycutlass.GemmOperationGrouped` - + :type operation: :class:`cutlass.backend.GemmOperationUniversal` | + :class:`cutlass.backend.GemmOperationGrouped` + :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass.gemm.GemmCoord` + :type operation: :class:`cutlass_bindings.gemm.GemmCoord` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray @@ -82,47 +133,46 @@ class GemmArguments2x(ArgumentBase): :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass.gemm.Mode` + :type gemm_mode: :class:`cutlass_bindings.gemm.Mode` :param output_op: output operator, optional - :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` """ def __init__( - self, operation: 'GemmOperation', problem_size: 'cutlass.gemm.GemmCoord', - A: 'Tensor', B: 'Tensor', C: 'Tensor', D: 'Tensor', - gemm_mode: 'cutlass.gemm.Mode'=cutlass.gemm.Mode.Gemm, **kwargs): - + self, operation: "GemmOperation", problem_size: "cutlass_bindings.gemm.GemmCoord", + A: "Tensor", B: "Tensor", C: "Tensor", D: "Tensor", + gemm_mode: "cutlass_bindings.gemm.Mode" = cutlass_bindings.gemm.Mode.Gemm, **kwargs): self.operation = operation - self.layout_A: cutlass.layout = operation.A.layout - self.layout_B: cutlass.layout = operation.B.layout - self.layout_C: cutlass.layout = operation.C.layout + self.layout_A: cutlass_bindings.layout = operation.A.layout + self.layout_B: cutlass_bindings.layout = operation.B.layout + self.layout_C: cutlass_bindings.layout = operation.C.layout self.element_A = operation.A.element self.element_B = operation.B.element self.element_C = operation.C.element - if (operation.C.layout in - [cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32]): + if (operation.C.layout in + [cutlass_bindings.RowMajorInterleaved32, cutlass_bindings.ColumnMajorInterleaved32]): # reorder tensor B for interleaved layout output B = self.reorder_tensor_B(B, problem_size) super().__init__(A, B, C, D, **kwargs) if operation.switched: - self.problem_size = cutlass.gemm.GemmCoord( + self.problem_size = cutlass_bindings.gemm.GemmCoord( problem_size.n(), problem_size.m(), problem_size.k()) self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A else: - self.problem_size = cutlass.gemm.GemmCoord( + self.problem_size = cutlass_bindings.gemm.GemmCoord( problem_size.m(), problem_size.n(), problem_size.k()) - + # if the number of elements in C = problem_size.n # C is treated as the bias if hasattr(self, "tensor_c_numel"): - if (self.tensor_c_numel == self.problem_size.n() and - self.problem_size.m() != 1): self.bias = True + if self.tensor_c_numel == self.problem_size.n() and self.problem_size.m() != 1: + self.bias = True # get the leading dimension self.lda = operation.A.layout.packed(self.problem_size.mk()).stride() @@ -134,24 +184,23 @@ def __init__( if self.bias: self.ldc = 0 - if 'output_op' in kwargs.keys() and \ - gemm_mode != cutlass.gemm.Mode.GemmSplitKParallel: - self.output_op = kwargs['output_op'] + if "output_op" in kwargs.keys() and gemm_mode != cutlass_bindings.gemm.Mode.GemmSplitKParallel: + self.output_op = kwargs["output_op"] else: self.output_op = self.operation.epilogue_type(1.0, 0.0) # get number of slices on k dimension self.gemm_mode = gemm_mode - if gemm_mode in [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]: - if 'split_k_slices' in kwargs.keys(): - self.batch_count = kwargs['split_k_slices'] + if gemm_mode in [cutlass_bindings.gemm.Mode.Gemm, cutlass_bindings.gemm.Mode.GemmSplitKParallel]: + if "split_k_slices" in kwargs.keys(): + self.batch_count = kwargs["split_k_slices"] else: self.batch_count = 1 self.split_k_slices = self.batch_count - if gemm_mode in [cutlass.gemm.Mode.Batched, cutlass.gemm.Mode.Array]: - if 'batch' in kwargs.keys(): - self.batch_count = kwargs['batch'] + if gemm_mode in [cutlass_bindings.gemm.Mode.Batched, cutlass_bindings.gemm.Mode.Array]: + if "batch" in kwargs.keys(): + self.batch_count = kwargs["batch"] else: self.batch_count = 1 @@ -163,7 +212,7 @@ def __init__( self.batched_stride_C = self.problem_size.n() # support GEMM Mode Array - if gemm_mode == cutlass.gemm.Mode.Array: + if gemm_mode == cutlass_bindings.gemm.Mode.Array: self.ptr_A_array = [] self.ptr_B_array = [] self.ptr_C_array = [] @@ -188,7 +237,7 @@ def __init__( ptr_B_addr += stride_B ptr_C_addr += stride_C ptr_D_addr += stride_D - + self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64) self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64) self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64) @@ -197,15 +246,15 @@ def __init__( if isinstance(self.operation, GemmOperationUniversal): self.initialize() - def reorder_tensor_B(self, tensor_B: 'np.ndarray', - problem_size: 'cutlass.gemm.GemmCoord'): + def reorder_tensor_B(self, tensor_B: "np.ndarray", + problem_size: "cutlass_bindings.gemm.GemmCoord"): """ Reorder tensor_B for interleaved layout :param tensor_B: input tensor B :type tensor_B: numpy.ndarray :param problem_size: GEMM problem size - :type problem_size: :class:`cutlass.gemm.GemmCoord` + :type problem_size: :class:`cutlass_bindings.gemm.GemmCoord` :return: reordered tensor B :rtype: numpy.ndarray @@ -217,12 +266,12 @@ def reorder_tensor_B(self, tensor_B: 'np.ndarray', reordered_tensor_ref_B = self.get_tensor_ref( reordered_tensor_B, self.element_B, self.layout_B, problem_size, "b" ) - cutlass.gemm.host.reorder_column( + cutlass_bindings.gemm.host.reorder_column( tensor_ref_B, reordered_tensor_ref_B, problem_size) return reordered_tensor_B def get_tensor_ref( - self, tensor, dtype, tensor_layout, problem_size, operand): + self, tensor, dtype, tensor_layout, problem_size, operand): if operand == "a": tensor_coord = problem_size.mk() elif operand == "b": @@ -231,7 +280,7 @@ def get_tensor_ref( tensor_coord = problem_size.mn() else: raise ValueError("unknown operand: " + operand) - + layout = tensor_layout.packed(tensor_coord) return TensorRef(tensor, dtype, layout).tensor_ref @@ -239,18 +288,22 @@ def get_tensor_ref( def get_arguments(self): problem_size_ = GemmCoord_(self.problem_size) grid_tiled_shape_ = GemmCoord_( - cutlass.gemm.GemmCoord( - self.grid_tiled_shape.x, self.grid_tiled_shape.y, + cutlass_bindings.gemm.GemmCoord( + self.grid_tiled_shape.x, + self.grid_tiled_shape.y, self.grid_tiled_shape.z ) ) - if self.gemm_mode == cutlass.gemm.Mode.Array: + if self.gemm_mode == cutlass_bindings.gemm.Mode.Array: arguments = self.operation.argument_type( # Arguments from UniversalArgumentsBase - self.gemm_mode, problem_size_, self.batch_count, 0, + self.gemm_mode, + problem_size_, + self.batch_count, + 0, # Remaining arguments self.output_op, - int(self.ptr_A_array_buffer.ptr), + int(self.ptr_A_array_buffer.ptr), int(self.ptr_B_array_buffer.ptr), int(self.ptr_C_array_buffer.ptr), int(self.ptr_D_array_buffer.ptr), @@ -264,9 +317,14 @@ def get_arguments(self): # Arguments from UniversalArgumentsBase self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D, # Remaining arguments - self.output_op, - int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D), - self.batched_stride_A, self.batched_stride_B, self.batched_stride_C, + self.output_op, + int(self.ptr_A), + int(self.ptr_B), + int(self.ptr_C), + int(self.ptr_D), + self.batched_stride_A, + self.batched_stride_B, + self.batched_stride_C, self.lda, self.ldb, self.ldc, self.ldd, self.lda, self.ldb, self.ldc, self.ldd, 0, 0, 0 @@ -278,9 +336,8 @@ def initialize(self): # get launch configuration launch_config = self.operation.rt_module.plan(self) - # get the host and device workspace - device_workspace_size = \ - self.operation.rt_module.get_device_workspace_size(self) + # get the host and evice workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) if device_workspace_size > 0: self.workspace_buffer = device_mem_alloc(device_workspace_size) @@ -291,13 +348,11 @@ def initialize(self): workspace_ptr = None device_workspace = 0 - if (workspace_ptr is not None and - self.gemm_mode == cutlass.gemm.Mode.GemmSplitKParallel): - # in GEMM split-K parallel, the D pointer is redirected + if workspace_ptr is not None and self.gemm_mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: + # in GEMM splik-K parallel, the D pointer is redirected # to the workspace self.ptr_D = cuda.CUdeviceptr(workspace_ptr) - elif (workspace_ptr is not None and - self.gemm_mode == cutlass.gemm.Mode.Gemm): + elif workspace_ptr is not None and self.gemm_mode == cutlass_bindings.gemm.Mode.Gemm: # in GEMM split-K serial device_workspace = workspace_ptr @@ -314,17 +369,132 @@ def initialize(self): self.device_workspace = device_workspace self.launch_config = launch_config + +class GemmArguments2xStreamK(GemmArguments2x): + """ + Argument wrapper for stream-K GEMMs in CUTLASS 2. It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass.backend.GemmOperationUniversal` | + :class:`cutlass.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_bindings.gemm.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass_bindings.gemm.Mode` + + :param output_op: output operator, optional + :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` + """ + + def __init__( + self, operation: "GemmOperation", problem_size: "cutlass_bindings.gemm.GemmCoord", + A: "Tensor", B: "Tensor", C: "Tensor", D: "Tensor", + gemm_mode: "cutlass_bindings.gemm.Mode" = cutlass_bindings.gemm.Mode.Gemm, **kwargs): + if gemm_mode not in [cutlass_bindings.gemm.Mode.Gemm, cutlass_bindings.gemm.Mode.Batched]: + raise Exception("Unsupporged GEMM mode {}.".format(gemm_mode)) + + super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + + def get_arguments(self): + batch_stride_A = self.problem_size.m() * self.problem_size.k() + batch_stride_B = self.problem_size.k() * self.problem_size.n() + batch_stride_C = self.problem_size.m() * self.problem_size.n() + batch_stride_D = self.problem_size.m() * self.problem_size.n() + + arguments = self.operation.argument_type( + self.gemm_mode, + GemmCoord_(self.problem_size), + self.batch_count, + self.output_op, + int(self.ptr_A), + int(self.ptr_B), + int(self.ptr_C), + int(self.ptr_D), + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D, + self.lda, self.ldb, self.ldc, self.ldd, # strides + self.lda, self.ldb, self.ldc, self.ldd, + -1, # avail_sms + ) + return arguments + + def initialize(self): + # get the host and device workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) + + device_workspace_size = 10 << 20 + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if workspace_ptr is not None and self.gemm_mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: + # in GEMM splik-K parallel, the D pointer is redirected + # to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif workspace_ptr is not None and self.gemm_mode == cutlass_bindings.gemm.Mode.Gemm: + # in GEMM split-K serial + device_workspace = workspace_ptr + + arguments = self.get_arguments() + + res_arg = self.operation.rt_module.get_args( + ctypes.byref(arguments), + ctypes.c_void_p(int(device_workspace)), + device_sm_count(), + self.operation.rt_module.occupancy + ) + host_workspace = bytearray(res_arg.contents) + + grid = self.operation.rt_module.get_grid_shape( + ctypes.byref(arguments), + device_sm_count(), + self.operation.rt_module.occupancy + ) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = LaunchConfiguration( + [grid.m, grid.n, grid.k], + [self.operation.rt_module.threads, 1, 1], + self.operation.rt_module.shared_memory_capacity + ) + + class GemmArguments3x(GemmArguments2x): """ - Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and + Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and user-provide tensors into the kernel's argument :param operation: the GEMM operation to take the argument - :type operation: :class:`pycutlass.GemmOperationUniversal` | - :class:`pycutlass.GemmOperationGrouped` - + :type operation: :class:`cutlass.backend.GemmOperationUniversal` | + :class:`cutlass.backend.GemmOperationGrouped` + :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass.gemm.GemmCoord` + :type operation: :class:`cutlass_bindings.gemm.GemmCoord` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray @@ -339,18 +509,18 @@ class GemmArguments3x(GemmArguments2x): :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass.gemm.Mode` + :type gemm_mode: :class:`cutlass_bindings.gemm.Mode` :param output_op: output operator, optional - :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` """ def __init__( - self, operation: 'GemmOperation', problem_size: 'cutlass.gemm.GemmCoord', - A: 'Tensor', B: 'Tensor', C: 'Tensor', D: 'Tensor', - gemm_mode: 'cutlass.gemm.Mode'=cutlass.gemm.Mode.Gemm, **kwargs): - if gemm_mode not in [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.Batched]: - raise Exception("Unsupported GEMM mode {}.".format(gemm_mode)) + self, operation: "GemmOperation", problem_size: "cutlass_bindings.gemm.GemmCoord", + A: "Tensor", B: "Tensor", C: "Tensor", D: "Tensor", + gemm_mode: "cutlass_bindings.gemm.Mode" = cutlass_bindings.gemm.Mode.Gemm, **kwargs): + if gemm_mode not in [cutlass_bindings.gemm.Mode.Gemm, cutlass_bindings.gemm.Mode.Batched]: + raise Exception("Unsupporged GEMM mode {}.".format(gemm_mode)) super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) @@ -372,24 +542,36 @@ def get_arguments(self): stride_C = StrideBatched_(self.ldc, bsC) stride_D = StrideBatched_(self.ldd, bsD) - self.arguments = self.operation.argument_type( - self.gemm_mode, - problem_size_, + # Superset of potential mainloop arguments + generic_args = GenericMainloopArguments3x_( int(self.ptr_A), stride_A, int(self.ptr_B), stride_B, + ) + + # Set of mainloop arguments needed for this kernel + mainloop = self.operation.rt_module.mainloop_args.from_generic_mainloop_args(generic_args) + + epilogue = self.operation.rt_module.epilogue_args( + self.output_op, int(self.ptr_C), stride_C, int(self.ptr_D), stride_D, - self.output_op, ) + self.arguments = self.operation.argument_type( + self.gemm_mode, + problem_size_, + mainloop, + epilogue, + ) + return self.arguments + def initialize(self): - # get the host and device workspace - device_workspace_size = \ - self.operation.rt_module.get_device_workspace_size(self) + # get the host and evice workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) if device_workspace_size > 0: self.workspace_buffer = device_mem_alloc(device_workspace_size) @@ -400,46 +582,52 @@ def initialize(self): workspace_ptr = None device_workspace = 0 - if (workspace_ptr is not None and - self.gemm_mode == cutlass.gemm.Mode.GemmSplitKParallel): - # in GEMM split-K parallel, the D pointer is redirected + if workspace_ptr is not None and self.gemm_mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: + # in GEMM splik-K parallel, the D pointer is redirected # to the workspace self.ptr_D = cuda.CUdeviceptr(workspace_ptr) - elif (workspace_ptr is not None and - self.gemm_mode == cutlass.gemm.Mode.Gemm): + elif workspace_ptr is not None and self.gemm_mode == cutlass_bindings.gemm.Mode.Gemm: # in GEMM split-K serial device_workspace = workspace_ptr self.get_arguments() res_arg = self.operation.rt_module.get_args( - ctypes.byref(self.arguments), ctypes.c_void_p(int(device_workspace))) + ctypes.byref(self.arguments), + ctypes.c_void_p(int(device_workspace)), + ) host_workspace = bytearray(res_arg.contents) grid = self.operation.rt_module.get_grid_shape( - ctypes.byref(self.arguments), ctypes.c_void_p(int(device_workspace))) + ctypes.byref(self.arguments), + ctypes.c_void_p(int(device_workspace)), + ) block = self.operation.rt_module.get_block_shape() device_workspace = None self.host_workspace = host_workspace self.device_workspace = device_workspace - self.launch_config = LaunchConfiguration([grid.x, grid.y, grid.z], - [block.x, block.y, block.z], - self.operation.rt_module.shared_memory_capacity) + self.launch_config = LaunchConfiguration( + [grid.x, grid.y, grid.z], + [block.x, block.y, block.z], + self.operation.rt_module.shared_memory_capacity, + ) -def GemmArguments(operation: 'GemmOperation', problem_size: 'cutlass.gemm.GemmCoord', - A: 'Tensor', B: 'Tensor', C: 'Tensor', D: 'Tensor', - gemm_mode: 'cutlass.gemm.Mode'=cutlass.gemm.Mode.Gemm, **kwargs): + +def GemmArguments( + operation: "GemmOperation", problem_size: "cutlass_bindings.gemm.GemmCoord", + A: "Tensor", B: "Tensor", C: "Tensor", D: "Tensor", + gemm_mode: "cutlass_bindings.gemm.Mode" = cutlass_bindings.gemm.Mode.Gemm, **kwargs): """ Argument wrapper for GEMM in CUTLASS 2 or 3. It returns either 2x arguments or 3x arguments depending on the `arch` field specified in `operation`. :param operation: the GEMM operation to take the argument - :type operation: :class:`pycutlass.GemmOperationUniversal` | - :class:`pycutlass.GemmOperationGrouped` - + :type operation: :class:`cutlass.backend.GemmOperationUniversal` | + :class:`cutlass.backend.GemmOperationGrouped` + :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass.gemm.GemmCoord` + :type operation: :class:`cutlass_bindings.gemm.GemmCoord` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray @@ -454,25 +642,30 @@ def GemmArguments(operation: 'GemmOperation', problem_size: 'cutlass.gemm.GemmCo :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass.gemm.Mode` + :type gemm_mode: :class:`cutlass_bindings.gemm.Mode` :param output_op: output operator, optional - :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` """ - ArgClass = GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x + if isinstance(operation.swizzling_functor, cutlass_bindings.ThreadblockSwizzleStreamK): + if operation.api == ApiVersion.v3x: + raise Exception("Stream K is currently only supported in CUTLASS 2.x") + ArgClass = GemmArguments2xStreamK + else: + ArgClass = GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x return ArgClass(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) class GemmGroupedArguments: """ - Argument wrapper for GEMM Grouped. It encodes problem information and + Argument wrapper for GEMM Grouped. It encodes problem information and user-provide tensors into the kernel's argument :param operation: the GEMM Grouped operation to take the argument - :type operation: :class:`pycutlass.GemmOperationGrouped` + :type operation: :class:`cutlass.backend.GemmOperationGrouped` :param problem_size: list of GEMM problem size gemm(M, N, K) - :type operation: list[:class:`cutlass.gemm.GemmCoord`] + :type operation: list[:class:`cutlass_bindings.gemm.GemmCoord`] :param A: list of tensor A :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] @@ -487,14 +680,12 @@ class GemmGroupedArguments: :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] :param output_op: output operator, optional - :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` """ - def __init__( - self, operation: 'GemmOperationGrouped', - problem_sizes: 'list[cutlass.gemm.GemmCoord]', - A: 'list[Tensor]', B: 'list[Tensor]', C: 'list[torch.Tensor]', - D: 'list[Tensor]', **kwargs): + def __init__( + self, operation: "GemmOperationGrouped", problem_sizes: "list[cutlass_bindings.gemm.GemmCoord]", + A: "list[Tensor]", B: "list[Tensor]", C: "list[torch.Tensor]", D: "list[Tensor]", **kwargs): # get number of problems in the group self.problem_count = len(problem_sizes) @@ -521,8 +712,11 @@ def __init__( # get the threadblock threadblock_shape = operation.tile_description.threadblock_shape - self.threadblock_shape = cutlass.gemm.GemmCoord( - threadblock_shape[0], threadblock_shape[1], threadblock_shape[2]) + self.threadblock_shape = cutlass_bindings.gemm.GemmCoord( + threadblock_shape[0], + threadblock_shape[1], + threadblock_shape[2], + ) self.threadblock_swizzle = operation.swizzling_functor self.total_tiles = 0 @@ -533,10 +727,9 @@ def __init__( for idx, problem_size in enumerate(problem_sizes): M, N, K = problem_size.m(), problem_size.n(), problem_size.k() temp_argument = GemmArguments2x( - operation=operation, - problem_size=cutlass.gemm.GemmCoord(M, N, K), - A=A[idx], B=B[idx], C=C[idx], D=D[idx], - ) + operation=operation, + problem_size=cutlass_bindings.gemm.GemmCoord(M, N, K), + A=A[idx], B=B[idx], C=C[idx], D=D[idx]) self.gemm_arguments.append(temp_argument) problem_size_host.append( @@ -560,8 +753,8 @@ def __init__( # get number of tiles grid = self.threadblock_swizzle.get_grid_shape( self.threadblock_swizzle.get_tiled_shape( - temp_argument.problem_size, self.threadblock_shape, - temp_argument.batch_count) + temp_argument.problem_size, self.threadblock_shape, + temp_argument.batch_count) ) self.total_tiles += grid.x * grid.y * grid.z @@ -576,22 +769,20 @@ def __init__( self.ldc_buffer = todevice(ldc_host, np.int64) self.ldd_buffer = todevice(ldd_host, np.int64) - if 'output_op' in kwargs.keys(): - self.alpha = kwargs['output_op'].alpha - self.beta = kwargs['output_op'].beta + if "output_op" in kwargs.keys(): + self.alpha = kwargs["output_op"].alpha + self.beta = kwargs["output_op"].beta else: self.alpha = 1.0 self.beta = 0.0 - - if 'output_op' in kwargs.keys(): - self.output_op = kwargs['output_op'] + + if "output_op" in kwargs.keys(): + self.output_op = kwargs["output_op"] else: self.output_op = self.operation.epilogue_type(1.0, 0.0) - # get host problem size - self.host_problem_size_ptr = np.array( - problem_size_host, dtype=np.int32).__array_interface__['data'][0] + self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0] self.arguments = self.get_arguments() @@ -599,20 +790,27 @@ def __init__( def get_arguments(self): return self.operation.argument_type( - self.problem_size_buffer.ptr, self.problem_count, self.total_tiles, - self.output_op, self.ptr_A_buffer.ptr, self.ptr_B_buffer.ptr, - self.ptr_C_buffer.ptr, self.ptr_D_buffer.ptr, self.lda_buffer.ptr, - self.ldb_buffer.ptr, self.ldc_buffer.ptr, self.ldd_buffer.ptr, - ctypes.c_void_p(int(self.host_problem_size_ptr)) + self.problem_size_buffer.ptr, + self.problem_count, + self.total_tiles, + self.output_op, + self.ptr_A_buffer.ptr, + self.ptr_B_buffer.ptr, + self.ptr_C_buffer.ptr, + self.ptr_D_buffer.ptr, + self.lda_buffer.ptr, + self.ldb_buffer.ptr, + self.ldc_buffer.ptr, + self.ldd_buffer.ptr, + ctypes.c_void_p(int(self.host_problem_size_ptr)), ) def initialize(self): # get launch configuration launch_config = self.operation.rt_module.plan(self) - # get the host and device workspace - device_workspace_size = \ - self.operation.rt_module.get_device_workspace_size(self) + # get the host and evice workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) if device_workspace_size > 0: self.workspace_buffer = device_mem_alloc(device_workspace_size) @@ -624,13 +822,14 @@ def initialize(self): if self.operation.precompute_mode == SchedulerMode.Host: device_workspace_ptr = self.operation.rt_module.host_precompute( - self, self.operation.rt_module.get_workspace_size(self)) + self, self.operation.rt_module.get_workspace_size(self),) else: device_workspace_ptr = 0 result = self.operation.rt_module.get_args( - ctypes.byref(self.arguments), self.total_tiles, - ctypes.c_void_p(int(device_workspace_ptr)) + ctypes.byref(self.arguments), + self.total_tiles, + ctypes.c_void_p(int(device_workspace_ptr)), ) host_workspace = bytearray(result.contents) @@ -639,7 +838,7 @@ def initialize(self): self.host_workspace = host_workspace self.device_workspace = device_workspace self.launch_config = launch_config - + def sync(self): err, = cudart.cudaDeviceSynchronize() if err != cuda.CUresult.CUDA_SUCCESS: @@ -652,12 +851,13 @@ def sync(self): # Base class for GEMM runtime module ################################################################################ + class GemmRTbase(ExecutableOperation): """ GemmRT manages the CUTLASS runtime components """ - KernelTemplate = r''' + KernelTemplate = r""" extern "C" __global__ void ${operation_name}(${operation_name}${operation_suffix}::Params params) { @@ -669,48 +869,43 @@ class GemmRTbase(ExecutableOperation): ${operation_name}${operation_suffix}::SharedStorage *shared_storage = reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); - ${operation_name}${operation_suffix} op; - - op(params, *shared_storage); + ${operation_name}${operation_suffix}::invoke(params, *shared_storage); } - ''' + """ - def __init__(self, operation: 'GemmOperation'): + def __init__(self, operation: "GemmOperation"): super().__init__(operation) self.operation = operation threadblock_shape = operation.tile_description.threadblock_shape - self.threadblock_shape = cutlass.gemm.GemmCoord( + self.threadblock_shape = cutlass_bindings.gemm.GemmCoord( threadblock_shape[0], threadblock_shape[1], threadblock_shape[2]) self.threadblock_swizzle = operation.swizzling_functor - #: number of threads per threadblock - self.threads: int = operation.tile_description.num_threads + # Threads per threadblock + self.threads = operation.tile_description.num_threads - # def emit(self): return self.emitter.emit(self.operation) - # def can_implement(self, configuration, arguments): raise NotImplementedError() - # def get_host_workspace_size(self, arguments): raise NotImplementedError() - # def get_device_workspace_size(self, arguments): return 0 - # def initialize(self): err, = cuda.cuFuncSetAttribute( self.kernel, attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, value=self.shared_memory_capacity) if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('Cuda Error: {}'.format(err)) + raise RuntimeError( + f"CUDA error on call to cuFuncSetAttribute: {cuda.cuGetErrorString(err)[1]}" + ) ################################################################################ @@ -722,7 +917,8 @@ class GemmRTUniversal(GemmRTbase): """ GemmRTUniversal manages the CUTLASS runtime components """ - HostTemplate = r''' + + HostTemplate = r""" extern "C" { // Get the size of params in bytes int ${operation_name}_get_param_size(){ @@ -753,29 +949,26 @@ class GemmRTUniversal(GemmRTbase): return output; } } - ''' + """ - def __init__(self, operation: 'GemmOperation'): + def __init__(self, operation: "GemmOperation"): super(GemmRTUniversal, self).__init__(operation) self.emitter = EmitGemmUniversalInstance( - '_type', operation.direct_store, operation.visitor) - + "_type", operation.direct_store, operation.visitor) + self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor) self.argtype = [ - ctypes.POINTER(self.argument_type), + ctypes.POINTER(self.argument_type), ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p ] def plan(self, arguments): - grid = self.threadblock_swizzle.get_tiled_shape( arguments.problem_size, self.threadblock_shape, arguments.batch_count ) gemm_k_size = arguments.problem_size.k() - if (arguments.gemm_mode in - [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]): - # + if arguments.gemm_mode in [cutlass_bindings.gemm.Mode.Gemm, cutlass_bindings.gemm.Mode.GemmSplitKParallel]: alignk = max(max(128 // DataTypeSize[self.operation.A.element], 128 // DataTypeSize[self.operation.B.element]), 1) @@ -783,41 +976,105 @@ def plan(self, arguments): arguments.batch_count + alignk - 1) // alignk) * alignk if gemm_k_size: - grid_z = (arguments.problem_size.k() + - gemm_k_size - 1) // gemm_k_size - grid = cutlass.gemm.GemmCoord(grid.m(), grid.n(), grid_z) + grid_z = (arguments.problem_size.k() + gemm_k_size - 1) // gemm_k_size + grid = cutlass_bindings.gemm.GemmCoord(grid.m(), grid.n(), grid_z) - arguments.grid_tiled_shape = cutlass.dim3(grid.m(), grid.n(), grid.k()) + arguments.grid_tiled_shape = cutlass_bindings.dim3(grid.m(), grid.n(), grid.k()) grid = self.threadblock_swizzle.get_grid_shape(grid) arguments.gemm_k_size = gemm_k_size return LaunchConfiguration( - [grid.x, grid.y, grid.z], - [self.threads, 1, 1], + [grid.x, grid.y, grid.z], + [self.threads, 1, 1], self.shared_memory_capacity) - # def get_device_workspace_size(self, arguments: GemmArguments): workspace_bytes = 0 - if arguments.gemm_mode == cutlass.gemm.Mode.GemmSplitKParallel: + if arguments.gemm_mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: workspace_bytes = (DataTypeSize[arguments.operation.C.element] * arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8) - elif (arguments.gemm_mode == cutlass.gemm.Mode.Gemm and + elif (arguments.gemm_mode == cutlass_bindings.gemm.Mode.Gemm and arguments.split_k_slices > 1): - # workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y return workspace_bytes +class GemmRTUniversalStreamK(GemmRTUniversal): + """ + Manages the CUTLASS runtime components for 2.x stream K kernels + """ + + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + using GemmType = ${operation_name}_base; + + // Get the params as byte array + char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace, + int sm_count, int occupancy) { + GemmType::Params* params; + params = new GemmType::Params(*argument, sm_count, occupancy); + + params->init_workspace(workspace); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(GemmType::Params)]; + for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + // Get the grid shape + dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int device_sms, int sm_occupancy) { + typename GemmType::Params params(*args, device_sms, sm_occupancy); + return params.get_grid_dims(); + } +} + """ + + def __init__(self, operation: "GemmOperation"): + super(GemmRTUniversalStreamK, self).__init__(operation) + self.extra_funcs = { + "get_grid_shape": GemmCoord_, + } + self._occupancy = None + self.argument_type, self.epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor) + + @property + def occupancy(self): + if self._occupancy is None: + err, self._occupancy = cuda.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + self.kernel, self.threads, self.shared_memory_capacity, + cuda.CUoccupancy_flags.CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError( + "CUDA error on call to cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: " + f"{cuda.cuGetErrorString(err)[1]}") + return self._occupancy + + ################################################################################ # Runtime module for GEMM Universal within CUTLASS 3 ################################################################################ + class GemmRTUniversal3x(GemmRTUniversal): """ - GemmRTUniversal manages the CUTLASS runtime components + Manages the CUTLASS runtime components for 3.x kernels """ - KernelTemplate = r''' + + KernelTemplate = r""" using Operator = ${operation_name}${operation_suffix}; extern "C" @@ -830,8 +1087,8 @@ class GemmRTUniversal3x(GemmRTUniversal): Operator op; op(params, smem); } - ''' - HostTemplate = r''' + """ + HostTemplate = r""" extern "C" { // Get the size of params in bytes int ${operation_name}_get_param_size(){ @@ -857,6 +1114,15 @@ class GemmRTUniversal3x(GemmRTUniversal): return output; } + // Get the total number of blocks for a persistent kernel + uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) { + auto problem_shape_MNKL = append<4>(problem, Int<1>{}); + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = + cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( + problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{}); + return problem_blocks_m * problem_blocks_n * problem_blocks_l; + } + // Get the grid shape dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) { auto tmp_params = GemmType::to_underlying_arguments(*args, workspace); @@ -868,107 +1134,172 @@ class GemmRTUniversal3x(GemmRTUniversal): return GemmType::get_block_shape(); } } - ''' + """ - def __init__(self, operation: 'GemmOperation'): + def __init__(self, operation: "GemmOperation"): super(GemmRTUniversal3x, self).__init__(operation) self.extra_funcs = { - 'get_grid_shape': dim3_, - 'get_block_shape': dim3_ + "get_grid_shape": dim3_, + "get_block_shape": dim3_, + "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64 } - self.emitter = EmitGemmUniversalInstance3x('_type') - self.argument_type, self.epilogue_type = get_gemm_arguments_3x(operation.epilogue_functor) + self.emitter = EmitGemmUniversalInstance3x("_type") + self.mainloop_args = get_mainloop_arguments_3x( + operation.tile_description.kernel_schedule, + operation.A.element, + operation.B.element, + operation.A.alignment, + operation.B.alignment + ) + self.argument_type, self.epilogue_args, self.epilogue_type = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor) class EmitGemmUniversalInstance3x: - ''' Responsible for emitting a CUTLASS 3 template definition''' + """Responsible for emitting a CUTLASS 3 template definition""" - def __init__(self, operation_suffix=''): + def __init__(self, operation_suffix=""): self.operation_suffix = operation_suffix self.includes = [ "cutlass/cutlass.h", "cute/tensor.hpp", "cute/atom/mma_atom.hpp", "cutlass/numeric_types.h", - "cutlass/gemm/kernel/gemm_universal.hpp", "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/gemm/kernel/sm90_tile_scheduler.hpp", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", "cutlass/epilogue/collective/default_epilogue.hpp", "cutlass/epilogue/thread/linear_combination.h" ] - self.gemm_template = """ + self.gemm_template_kernel = """ using namespace cute; -${collective_op} - -using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t<${layout_c}>, - cutlass::gemm::TagToStrideC_t<${layout_c}>, - ${epilogue_functor} - >; +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + ${stage_count_type}, + ${kernel_schedule} + >::CollectiveOp; // Gemm operator ${operation_name} using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< Shape, - CollectiveOp, - EpilogueOp + CollectiveMainloop, + CollectiveEpilogue >; // Define named type -struct ${operation_name}${operation_suffix} : +struct ${operation_name}${operation_suffix} : public ${operation_name}_base { }; """ - # - def emit(self, operation): + self.gemm_template_device = self.gemm_template_kernel + """ - instance_layout_A, instance_layout_B, instance_layout_C = \ +// Define device-level operator +using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}${operation_suffix}>; +""" + + def emit(self, operation): + instance_layout_A, instance_layout_B, instance_layout_C, = \ (operation.A.layout, operation.B.layout, operation.C.layout) # Support built-in epilogue functors or user-defined functions epilogue_functor = operation.epilogue_functor.emit() - collective_op = collective_op_builder.build(operation) + if operation.tile_description.stages is None or operation.tile_description.stages == 0: + stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout" + else: + stage_count_type = "_" + str(operation.tile_description.stages) + + if operation.emission_type == EmissionType.Kernel: + gemm_template = self.gemm_template_kernel + else: + gemm_template = self.gemm_template_device + + schedule = KernelScheduleType.ScheduleAuto + if operation.tile_description.kernel_schedule is not None: + schedule = operation.tile_description.kernel_schedule values = { - 'operation_name': operation.procedural_name(), - 'operation_suffix': self.operation_suffix, - 'collective_op': collective_op, - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[instance_layout_A], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[instance_layout_B], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[instance_layout_C], - 'epilogue_functor': epilogue_functor, - 'element_output': DataTypeTag[operation.epilogue_functor.element_output], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'element_epilogue': DataTypeTag[operation.epilogue_functor.element_epilogue], - 'epilogue_vector_length': str(operation.epilogue_functor.epilogue_vector_length), - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'cluster_shape_m': str(operation.tile_description.cluster_shape[0]), - 'cluster_shape_n': str(operation.tile_description.cluster_shape[1]), - 'cluster_shape_k': str(operation.tile_description.cluster_shape[2]), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment) + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[instance_layout_A], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[instance_layout_B], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[instance_layout_C], + "element_d": DataTypeTag[operation.epilogue_functor.element_output], + "layout_d": LayoutTag[instance_layout_C], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue], + "epilogue_vector_length": str(operation.epilogue_functor.epilogue_vector_length), + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "cluster_m": str(operation.tile_description.cluster_shape[0]), + "cluster_n": str(operation.tile_description.cluster_shape[1]), + "cluster_k": str(operation.tile_description.cluster_shape[2]), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.C.alignment), + "stage_count_type": stage_count_type, + "kernel_schedule": KernelScheduleTag[schedule], } - values['epilogue_functor'] = operation.epilogue_functor.emit() - return SubstituteTemplate(self.gemm_template, values) + values["epilogue_functor"] = operation.epilogue_functor.emit() + return SubstituteTemplate(gemm_template, values) ################################################################################################### # Runtime module for GEMM Grouped ################################################################################################### + class GemmRTGrouped(GemmRTbase): """ GemmRTGrouped manages the CUTLASS runtime components """ - HostTemplate = r''' + + KernelTemplate = r""" +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + """ + + HostTemplate = r""" extern "C" { // precompute scheduling information @@ -1006,24 +1337,25 @@ class GemmRTGrouped(GemmRTbase): return output; } } - ''' + """ - def __init__(self, operation: 'GemmOperation'): + def __init__(self, operation: "GemmOperation"): super(GemmRTGrouped, self).__init__(operation) - self.extra_funcs = {'precompute': None} + self.extra_funcs = {"precompute": None} - self.emitter = EmitGemmGroupedInstance('_type') + self.emitter = EmitGemmGroupedInstance("_type") self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor) self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p] def host_precompute(self, arguments, workspace_bytes): self.precompute.argtype = [ self.argtype[0], ctypes.c_int, ctypes.c_longlong] - self.precompute.restype = ctypes.POINTER( - ctypes.c_byte * workspace_bytes) + self.precompute.restype = ctypes.POINTER(ctypes.c_byte * workspace_bytes) - problem_info = self.precompute(ctypes.byref( - arguments.arguments), arguments.total_tiles, workspace_bytes) + problem_info = self.precompute( + ctypes.byref(arguments.arguments), + arguments.total_tiles, + workspace_bytes) problem_info_array = bytearray(problem_info.contents) # copy to device memory @@ -1031,8 +1363,10 @@ def host_precompute(self, arguments, workspace_bytes): def plan(self, arguments): return LaunchConfiguration( - [arguments.total_tiles, 1, 1], - [self.threads, 1, 1], self.shared_memory_capacity) + [arguments.total_tiles, 1, 1], + [self.threads, 1, 1], + self.shared_memory_capacity, + ) def get_workspace_size(self, arguments): if self.operation.precompute_mode == SchedulerMode.Device: @@ -1044,54 +1378,46 @@ def get_workspace_size(self, arguments): ################################################################################ -# Runtime module for GEMM Grouped +# Runtime module for GEMM and grouped GEMM ################################################################################ -# + +class EmissionType(enum.Enum): + """ + Tags for whether to emit a kernel- or device-level operation + """ + + Kernel = enum_auto() + Device = enum_auto() + + class GemmOperationBase: """ CUTLASS GEMM operation """ - # def __init__( - self, gemm_kind, arch, tile_description: TileDescription, - A: TensorDescription, B: TensorDescription, C: TensorDescription, - epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1, api=False, **kwargs): - - #: operation kind + self, gemm_kind, arch, tile_description: TileDescription, + A: TensorDescription, B: TensorDescription, C: TensorDescription, + epilogue_functor, swizzling_functor=cutlass_bindings.IdentitySwizzle1, + api=ApiVersion.v2x, emission_type=EmissionType.Kernel, **kwargs): self.operation_kind: OperationKind = OperationKind.Gemm - #: compute capability self.arch: int = arch - #: tile description object self.tile_description: TileDescription = tile_description - #: gemm kind self.gemm_kind: GemmKind = gemm_kind self.api = api self.prefix = "3x" if self.api == ApiVersion.v3x else "" + self.emission_type = emission_type - # use deep copy to avoid overwriting the original TensorDescription - if self.api != ApiVersion.v3x and C.layout == cutlass.ColumnMajor: - #: Operand A - self.A: TensorDescription = copy.deepcopy(B) - #: Operand B - self.B: TensorDescription = copy.deepcopy(A) - #: Operand C - self.C: TensorDescription = copy.deepcopy(C) - self.A.layout = transpose_layout(self.A.layout) - self.B.layout = transpose_layout(self.B.layout) - self.C.layout = transpose_layout(self.C.layout) - self.switched = True - else: - #: Operand A - self.A: TensorDescription = copy.deepcopy(A) - #: Operand B - self.B: TensorDescription = copy.deepcopy(B) - #: Operand C - self.C: TensorDescription = copy.deepcopy(C) - self.switched = False + # Optionally swap the TensorDescriptions for operands A and B and transpose their + # layouts. This is needed to mimic the transpose performed by device::GemmUniversal. + # The code below uses deep copy to avoid overwritting the original TensorDescription + self.switched = (self.api != ApiVersion.v3x and + self.emission_type == EmissionType.Kernel and + C.layout == cutlass_bindings.ColumnMajor) + + self.A, self.B, self.C = GemmOperationBase.get_operands(A, B, C, self.switched) self.epilogue_functor = epilogue_functor self.swizzling_functor = swizzling_functor() @@ -1105,17 +1431,50 @@ def __init__( else: self.visitor = False + @staticmethod + def get_operands(A: TensorDescription, B: TensorDescription, C: TensorDescription, swap: bool): + """ + Makes copies of A, B, and C, and possibly transposes their order. If ``swap`` is set, + A and B are swapped, and the layout of A, B, and C are transposed. + + :param A: description of operand A + :type A: TensorDescription + :param B: description of operand B + :type B: TensorDescription + :param C: description of operand C + :type C: TensorDescription + + :return: descriptions of operands A, B, and C + :rtype: tuple[TileDescription] + """ + if swap: + A_out = copy.deepcopy(B) + B_out = copy.deepcopy(A) + C_out = copy.deepcopy(C) + A_out.layout = transpose_layout(A_out.layout) + B_out.layout = transpose_layout(B_out.layout) + C_out.layout = transpose_layout(C_out.layout) + else: + A_out = copy.deepcopy(A) + B_out = copy.deepcopy(B) + C_out = copy.deepcopy(C) + return A_out, B_out, C_out + def run(self, arguments: GemmArguments) -> cuda.CUresult: """ Configure and launch the cuda kernel with input arguments """ + if self.emission_type == EmissionType.Device: + raise Exception('Running a kernel via PyCUTLASS is only enabled with emission type "Kernel"') + err = self.rt_module.run( arguments.host_workspace, arguments.device_workspace, - arguments.launch_config) + arguments.launch_config, + ) if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('CUDA Error %s' % str(err)) + raise RuntimeError("CUDA Error %s" % str(err)) return err @@ -1123,20 +1482,17 @@ def free(self): if hasattr(self, "workspace_buffer"): del self.workspace_buffer - # def is_complex(self): complex_operators = [ MathOperation.multiply_add_complex, MathOperation.multiply_add_complex_gaussian, - MathOperation.multiply_add_complex_fast_f32 + MathOperation.multiply_add_complex_fast_f32, ] return self.tile_description.math_instruction.math_operation in complex_operators - # def is_planar_complex(self): return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) - # def accumulator_type(self): accum = self.tile_description.math_instruction.element_accumulator @@ -1145,31 +1501,26 @@ def accumulator_type(self): return accum - # def short_math_name(self): if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: return "g%s" % ShortDataTypeNames[self.accumulator_type()] return ShortDataTypeNames[self.accumulator_type()] - # - def core_name(self): - ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + """The basic operation kind is prefixed with a letter indicating the accumulation type.""" - inst_shape = '' - inst_operation = '' - intermediate_type = '' + inst_shape = "" + inst_operation = "" + intermediate_type = "" math_operations_map = { - MathOperation.xor_popc: 'xor', + MathOperation.xor_popc: "xor", } - if self.tile_description.math_instruction.opcode_class == cutlass.OpClass.TensorOp or \ - self.tile_description.math_instruction.opcode_class == cutlass.OpClass.WmmaTensorOp: - + if (self.tile_description.math_instruction.opcode_class == cutlass_bindings.OpClass.TensorOp or + self.tile_description.math_instruction.opcode_class == cutlass_bindings.OpClass.WmmaTensorOp): math_op = self.tile_description.math_instruction.math_operation - math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys( - ) else '' + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else "" if self.tile_description.math_instruction.instruction_shape is not None: inst_shape = "%dx%dx%d" % tuple( @@ -1178,54 +1529,49 @@ def core_name(self): inst_shape = "Default" inst_shape += math_op_string - if self.tile_description.math_instruction.element_a != self.A.element and \ - self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + if (self.tile_description.math_instruction.element_a != self.A.element and + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator): intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) - # def extended_name(self): - ''' Append data types if they differ from compute type. ''' + """Append data types if they differ from compute type.""" if self.is_complex(): extended_name = "${core_name}" else: - if self.C.element != self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: + if (self.C.element != self.tile_description.math_instruction.element_accumulator and + self.A.element != self.tile_description.math_instruction.element_accumulator): extended_name = "${element_c}_${core_name}_${element_a}" - elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ - self.A.element != self.tile_description.math_instruction.element_accumulator: + elif (self.C.element == self.tile_description.math_instruction.element_accumulator and + self.A.element != self.tile_description.math_instruction.element_accumulator): extended_name = "${core_name}_${element_a}" else: extended_name = "${core_name}" extended_name = SubstituteTemplate(extended_name, { - 'element_a': DataTypeNames[self.A.element], - 'element_c': DataTypeNames[self.C.element], - 'core_name': self.core_name() + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), }) return extended_name - # def extended_name_3x(self): - '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + """Generates a string representing the MMA atom. Assumes accumulator type is C type.""" extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}".format( - element_a = DataTypeNames[self.A.element], - element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], - element_c = DataTypeNames[self.C.element], - core_name = self.core_name()) + element_a=DataTypeNames[self.A.element], + element_b=DataTypeNames[self.B.element], + element_acc=DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_c=DataTypeNames[self.C.element], + core_name=self.core_name()) return extended_name - # def layout_name(self): if self.is_complex() or self.is_planar_complex(): return "%s%s" % ( - ShortComplexLayoutNames[( - self.A.layout, self.A.complex_transform)], - ShortComplexLayoutNames[( - self.B.layout, self.B.complex_transform)] + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] ) return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) @@ -1233,7 +1579,7 @@ def layout_name(self): def layout_name_3x(self): if self.is_complex() or self.is_planar_complex(): return "{}{}{}".format( - ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) else: @@ -1242,81 +1588,129 @@ def layout_name_3x(self): ShortLayoutTypeNames[self.B.layout], ShortLayoutTypeNames[self.C.layout]) - # + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name_3x(self): + if self.tile_description.kernel_schedule is None: + return KernelScheduleSuffixes[KernelScheduleType.ScheduleAuto] + else: + return KernelScheduleSuffixes[self.tile_description.kernel_schedule] + def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + """The full procedural name indicates architecture, extended name, tile size, and layout.""" opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] if self.api == ApiVersion.v3x and self.arch >= 90: - kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}" + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}" return kernel_name_template.format( - p = self.prefix, - ar = self.arch, - op = opcode_class_name, - ex = self.extended_name_3x(), - tbm = self.tile_description.threadblock_shape[0], - tbn = self.tile_description.threadblock_shape[1], - tbk = self.tile_description.threadblock_shape[2], - cm = self.tile_description.cluster_shape[0], - cn = self.tile_description.cluster_shape[1], - ck = self.tile_description.cluster_shape[2], - l = self.tile_description.stages, - s = self.layout_name_3x(), - al = str(self.A.alignment)) + p=self.prefix, + ar=self.arch, + op=opcode_class_name, + ex=self.extended_name_3x(), + tbm=self.tile_description.threadblock_shape[0], + tbn=self.tile_description.threadblock_shape[1], + tbk=self.tile_description.threadblock_shape[2], + cm=self.tile_description.cluster_shape[0], + cn=self.tile_description.cluster_shape[1], + ck=self.tile_description.cluster_shape[2], + l=self.tile_description.stages, + s=self.layout_name_3x(), + al=str(self.A.alignment), + k=self.kernel_schedule_name_3x() + ) else: threadblock = self.tile_description.procedural_name() return "cutlass{p}_sm{ar}_{op}_{ex}_{tb}_{l}_align{a}".format( - p = self.prefix, - ar = self.arch, - op = opcode_class_name, - ex = self.extended_name(), - tb = threadblock, - l = self.layout_name(), - a = str(self.A.alignment)) - - # + p=self.prefix, + ar=self.arch, + op=opcode_class_name, + ex=self.extended_name(), + tb=threadblock, + l=self.layout_name(), + a=str(self.A.alignment) + ) + def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + """The full procedural name indicates architecture, extended name, tile size, and layout.""" return self.procedural_name() class GemmOperationUniversal(GemmOperationBase): def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, - epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + epilogue_functor, swizzling_functor=cutlass_bindings.IdentitySwizzle1, **kwargs): api = api_version(arch, tile_description.math_instruction.opcode_class, A.element) super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description, A, B, C, epilogue_functor, swizzling_functor, - api=api, **kwargs) + api=api, **kwargs, ) if api == ApiVersion.v3x: + if swizzling_functor == cutlass_bindings.ThreadblockSwizzleStreamK: + raise Exception("Stream K is currently only supported for CUTLASS 2.x kernels") self.rt_module = GemmRTUniversal3x(self) else: - self.rt_module = GemmRTUniversal(self) + if swizzling_functor == cutlass_bindings.ThreadblockSwizzleStreamK: + self.rt_module = GemmRTUniversalStreamK(self) + else: + self.rt_module = GemmRTUniversal(self) self.argument_type = self.rt_module.argument_type self.epilogue_type = self.rt_module.epilogue_type + def device_op(self): + """ + Returns a new GemmOperationUniversal object that is constructed with emission type + ``EmissionType.Device``. Since the device-emitted kernel does not require swapping, + any swappng performed by the kernel-emitted operation is reversed. + + :return: operation ready for device-level code emission + :rtype: GemmUniversalOperation + """ + A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched) + return GemmOperationUniversal(self.arch, self.tile_description, A, B, C, + self.epilogue_functor, type(self.swizzling_functor), + emission_type=EmissionType.Device, direct_store=self.direct_store, + visitor=self.visitor) + class GemmOperationGrouped(GemmOperationBase): def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, - epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + epilogue_functor, swizzling_functor=cutlass_bindings.IdentitySwizzle1, **kwargs): super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description, A, B, C, epilogue_functor, swizzling_functor, **kwargs) - assert "precompute_mode" in kwargs.keys( - ), "missing keyword argument 'precompute_mode'." + assert "precompute_mode" in kwargs.keys(), "missing keyword arguement 'precompute_mode'." self.precompute_mode = kwargs["precompute_mode"] self.rt_module = GemmRTGrouped(self) self.argument_type = self.rt_module.argument_type self.epilogue_type = self.rt_module.epilogue_type + def device_op(self): + """ + Returns a new GemmOperationGrouped object that is constructed with emission type + ``EmissionType.Device``. Since the device-emitted kernel does not require swapping, + any swappng performed by the kernel-emitted operation is reversed. + + :return: operation ready for device-level code emission + :rtype: GemmOperationGrouped + """ + A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched) + return GemmOperationGrouped( + self.arch, self.tile_description, A, B, C, self.epilogue_functor, + type(self.swizzling_functor), emission_type=EmissionType.Device, + direct_store=self.direct_store, precompute_mode=self.precompute_mode, ) + + ################################################################################################### # # Emits single instances of a CUTLASS device-wide operator # ################################################################################################### -# + class EmitGemmUniversalInstance: - ''' Responsible for emitting a CUTLASS template definition''' + """Responsible for emitting a CUTLASS template definition""" - def __init__(self, operation_suffix='', direct_store=False, visitor=False): + def __init__( + self, + operation_suffix="", + direct_store=False, + visitor=False, + ): self.operation_suffix = operation_suffix self.direct_store = direct_store self.visitor = visitor @@ -1334,14 +1728,15 @@ def __init__(self, operation_suffix='', direct_store=False, visitor=False): self.includes += [ "gemm/gemm_universal_with_visitor.h", "epilogue/epilogue_visitor_with_layernorm.h", - "epilogue/epilogue_visitor_generic.h" + "epilogue/epilogue_visitor_generic.h", ] if self.direct_store: self.includes.append( - "cutlass/epilogue/threadblock/default_epilogue_direct_store.h") - self.gemm_template_interleaved = """ + "cutlass/epilogue/threadblock/default_epilogue_direct_store.h" + ) + self.gemm_template_kernel = """ // Gemm operator ${operation_name} -using ${operation_name}_base = +using ${operation_name}_base = typename cutlass::gemm::kernel::DefaultGemmUniversal< ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, @@ -1362,6 +1757,43 @@ def __init__(self, operation_suffix='', direct_store=False, visitor=False): struct ${operation_name}${operation_suffix} : public ${operation_name}_base { }; """ + + self.gemm_template_device = """ +// Gemm operator ${operation_name} +using DeviceKernel = + typename cutlass::gemm::device::GemmUniversal< + // Data type and layout of operand A + ${element_a}, ${layout_a}, + // Data type and layout of operand B + ${element_b}, ${layout_b}, + // Data type and layout of operand C + ${element_c}, ${layout_c}, + // Data type of accumulator + ${element_accumulator}, + // Class of operation + ${opcode_class}, + // Compute capability of the target kernel + ${arch}, + // Threadblock tile shape + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + // Warp tile shape + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + // Instruction shape + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + // Epilogue functor + ${epilogue_functor}, + // Swizzling function + ${swizzling_functor}, + // Number of pipeline stages + ${stages}, + // Alignment of operands A and B + ${align_a}, ${align_b}, + // Type of math operation + ${math_operation}, + // Complex transform types of operands A and B + ${transform_a}, ${transform_b} + >; +""" self.gemm_template_direct_store = """ // Gemm operator ${operation_name} using ${operation_name}_default = @@ -1431,7 +1863,6 @@ def __init__(self, operation_suffix='', direct_store=False, visitor=False): public ${operation_name}_base { }; """ - # def instance_template(self): return """ ${compile_guard_start} @@ -1441,87 +1872,68 @@ def instance_template(self): ${compile_guard_end} """ - # def emit(self, operation): - threadblock_shape = operation.tile_description.threadblock_shape warp_count = operation.tile_description.warp_count - warp_shape = [threadblock_shape[idx] // warp_count[idx] - for idx in range(3)] - - # transpose_layouts = { - # cutlass.layout.ColumnMajorcutlass.layout.ColumnMajor, - # cutlass.layout.RowMajorcutlass.layout.RowMajor - # } + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - # if operation.A.layout in transpose_layouts.keys() and \ - # operation.B.layout in transpose_layouts.keys() and \ - # operation.C.layout in transpose_layouts.keys(): - - # instance_layout_A = transpose_layouts[operation.A.layout] - # instance_layout_B = transpose_layouts[operation.B.layout] - # instance_layout_C = transpose_layouts[operation.C.layout] - - # gemm_template = self.gemm_template - # else: instance_layout_A, instance_layout_B, instance_layout_C = \ (operation.A.layout, operation.B.layout, operation.C.layout) - if self.direct_store: - gemm_template = self.gemm_template_direct_store - elif self.visitor: - gemm_template = self.gemm_template_visitor + + if operation.emission_type == EmissionType.Kernel: + if self.direct_store: + gemm_template = self.gemm_template_direct_store + elif self.visitor: + gemm_template = self.gemm_template_visitor + else: + gemm_template = self.gemm_template_kernel else: - gemm_template = self.gemm_template_interleaved - # + gemm_template = self.gemm_template_device values = { - 'operation_name': operation.procedural_name(), - 'operation_suffix': self.operation_suffix, - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[instance_layout_A], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[instance_layout_B], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[instance_layout_C], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'swizzling_functor': operation.swizzling_functor.tag(), - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[instance_layout_A], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[instance_layout_B], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[instance_layout_C], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "swizzling_functor": operation.swizzling_functor.tag(), + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation], } if self.visitor: - values['epilogue_visitor'] = operation.epilogue_functor.emit(operation) - values['elementwise_epilogue_functor'] = operation.epilogue_functor.elementwise_functor.emit() + values["epilogue_visitor"] = operation.epilogue_functor.emit(operation) + values["elementwise_epilogue_functor"] = operation.epilogue_functor.elementwise_functor.emit() else: - values['epilogue_functor'] = operation.epilogue_functor.emit() + values["epilogue_functor"] = operation.epilogue_functor.emit() return SubstituteTemplate(gemm_template, values) -################################################################################################### - -# - class EmitGemmGroupedInstance: - ''' Responsible for emitting a CUTLASS template definition''' + """Responsible for emitting a CUTLASS template definition""" - def __init__(self, operation_suffix=''): + def __init__(self, operation_suffix=""): self.operation_suffix = operation_suffix self.includes = [ "cutlass/cutlass.h", @@ -1530,9 +1942,9 @@ def __init__(self, operation_suffix=''): "cutlass/arch/mma.h", "cutlass/layout/matrix.h", "cutlass/gemm/kernel/gemm_grouped.h", - "cutlass/gemm/kernel/default_gemm_grouped.h" + "cutlass/gemm/kernel/default_gemm_grouped.h", ] - self.gemm_template = """ + self.gemm_template_kernel = """ // Gemm operator ${operation_name} using ${operation_name}_base = typename cutlass::gemm::kernel::DefaultGemmGrouped< @@ -1556,8 +1968,13 @@ def __init__(self, operation_suffix=''): struct ${operation_name}${operation_suffix} : public ${operation_name}_base { }; """ + self.gemm_template_device = ( + self.gemm_template_kernel + + """ +using DeviceKernel = cutlass::gemm::device::GemmGrouped<${operation_name}_base>; +""" + ) - # def instance_template(self): return """ ${compile_guard_start} @@ -1567,52 +1984,53 @@ def instance_template(self): ${compile_guard_end} """ - # def emit(self, operation): - threadblock_shape = operation.tile_description.threadblock_shape warp_count = operation.tile_description.warp_count - warp_shape = [threadblock_shape[idx] // warp_count[idx] - for idx in range(3)] + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] instance_layout_A, instance_layout_B, instance_layout_C = \ (operation.A.layout, operation.B.layout, operation.C.layout) - # # Support built-in epilogue functors or user-defined functions epilogue_functor = operation.epilogue_functor.emit() - + values = { - 'operation_name': operation.procedural_name(), - 'operation_suffix': self.operation_suffix, - 'element_a': DataTypeTag[operation.A.element], - 'layout_a': LayoutTag[instance_layout_A], - 'element_b': DataTypeTag[operation.B.element], - 'layout_b': LayoutTag[instance_layout_B], - 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[instance_layout_C], - 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], - 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), - 'warp_shape_m': str(warp_shape[0]), - 'warp_shape_n': str(warp_shape[1]), - 'warp_shape_k': str(warp_shape[2]), - 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), - 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), - 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), - 'epilogue_functor': epilogue_functor, - 'swizzling_functor': operation.swizzling_functor.tag(), - 'stages': str(operation.tile_description.stages), - 'align_a': str(operation.A.alignment), - 'align_b': str(operation.B.alignment), - 'transform_a': ComplexTransformTag[operation.A.complex_transform], - 'transform_b': ComplexTransformTag[operation.B.complex_transform], - 'precompute_mode': SchedulerModeTag[operation.precompute_mode], - 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[instance_layout_A], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[instance_layout_B], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[instance_layout_C], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "epilogue_functor": epilogue_functor, + "swizzling_functor": operation.swizzling_functor.tag(), + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "precompute_mode": SchedulerModeTag[operation.precompute_mode], + "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation], } - return SubstituteTemplate(self.gemm_template, values) + if operation.emission_type == EmissionType.Kernel: + gemm_template = self.gemm_template_kernel + else: + gemm_template = self.gemm_template_device + + return SubstituteTemplate(gemm_template, values) diff --git a/python/cutlass/backend/library.py b/python/cutlass/backend/library.py new file mode 100644 index 00000000..7760f6ed --- /dev/null +++ b/python/cutlass/backend/library.py @@ -0,0 +1,714 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Common data types and string names for them. This file is similar to /tools/library/scripts/library.py, +but uses the Pybind-bound CUTLASS data types as many keys to the dictionary. +""" + +import enum + +import cutlass_bindings +from cutlass import KernelScheduleType + + +# The following block implements enum.auto() for Python 3.5 variants that don't include it such +# as the default 3.5.2 on Ubuntu 16.04. +# +# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility + +try: + from enum import auto as enum_auto +except ImportError: + __cutlass_library_auto_enum = 0 + + def enum_auto() -> int: + global __cutlass_library_auto_enum + i = __cutlass_library_auto_enum + __cutlass_library_auto_enum += 1 + return i + + +ShortDataTypeNames = { + cutlass_bindings.int32: "i", + cutlass_bindings.float16: "h", + cutlass_bindings.float32: "s", + cutlass_bindings.float64: "d", + cutlass_bindings.dtype.cf32: "c", + cutlass_bindings.dtype.cf64: "z", +} + + +DataTypeNames = { + cutlass_bindings.dtype.b1: "b1", + cutlass_bindings.dtype.u4: "u4", + cutlass_bindings.dtype.u8: "u8", + cutlass_bindings.dtype.u16: "u16", + cutlass_bindings.dtype.u32: "u32", + cutlass_bindings.dtype.u64: "u64", + cutlass_bindings.dtype.s4: "s4", + cutlass_bindings.int8: "s8", + cutlass_bindings.dtype.s16: "s16", + cutlass_bindings.int32: "s32", + cutlass_bindings.dtype.s64: "s64", + cutlass_bindings.float16: "f16", + cutlass_bindings.bfloat16: "bf16", + cutlass_bindings.float32: "f32", + cutlass_bindings.tfloat32: "tf32", + cutlass_bindings.float64: "f64", + cutlass_bindings.dtype.cf16: "cf16", + cutlass_bindings.dtype.cbf16: "cbf16", + cutlass_bindings.dtype.cf32: "cf32", + cutlass_bindings.dtype.ctf32: "ctf32", + cutlass_bindings.dtype.cf64: "cf64", + cutlass_bindings.dtype.cu4: "cu4", + cutlass_bindings.dtype.cu8: "cu8", + cutlass_bindings.dtype.cu16: "cu16", + cutlass_bindings.dtype.cu32: "cu32", + cutlass_bindings.dtype.cu64: "cu64", + cutlass_bindings.dtype.cs4: "cs4", + cutlass_bindings.dtype.cs8: "cs8", + cutlass_bindings.dtype.cs16: "cs16", + cutlass_bindings.dtype.cs32: "cs32", + cutlass_bindings.dtype.cs64: "cs64", +} + + +DataTypeTag = { + cutlass_bindings.dtype.b1: "cutlass::uint1b_t", + cutlass_bindings.dtype.u4: "cutlass::uint4b_t", + cutlass_bindings.dtype.u8: "uint8_t", + cutlass_bindings.dtype.u16: "uint16_t", + cutlass_bindings.dtype.u32: "uint32_t", + cutlass_bindings.dtype.u64: "uint64_t", + cutlass_bindings.dtype.s4: "cutlass::int4b_t", + cutlass_bindings.int8: "int8_t", + cutlass_bindings.dtype.s16: "int16_t", + cutlass_bindings.int32: "int32_t", + cutlass_bindings.dtype.s64: "int64_t", + cutlass_bindings.float16: "cutlass::half_t", + cutlass_bindings.bfloat16: "cutlass::bfloat16_t", + cutlass_bindings.float32: "float", + cutlass_bindings.tfloat32: "cutlass::tfloat32_t", + cutlass_bindings.float64: "double", + cutlass_bindings.dtype.cf16: "cutlass::complex", + cutlass_bindings.dtype.cbf16: "cutlass::complex", + cutlass_bindings.dtype.cf32: "cutlass::complex", + cutlass_bindings.dtype.ctf32: "cutlass::complex", + cutlass_bindings.dtype.cf64: "cutlass::complex", + cutlass_bindings.dtype.cu4: "cutlass::complex", + cutlass_bindings.dtype.cu8: "cutlass::complex", + cutlass_bindings.dtype.cu16: "cutlass::complex", + cutlass_bindings.dtype.cu32: "cutlass::complex", + cutlass_bindings.dtype.cu64: "cutlass::complex", + cutlass_bindings.dtype.cs4: "cutlass::complex", + cutlass_bindings.dtype.cs8: "cutlass::complex", + cutlass_bindings.dtype.cs16: "cutlass::complex", + cutlass_bindings.dtype.cs32: "cutlass::complex", + cutlass_bindings.dtype.cs64: "cutlass::complex", +} + + +DataTypeSize = { + cutlass_bindings.dtype.b1: 1, + cutlass_bindings.dtype.u4: 4, + cutlass_bindings.dtype.u8: 8, + cutlass_bindings.dtype.u16: 16, + cutlass_bindings.dtype.u32: 32, + cutlass_bindings.dtype.u64: 64, + cutlass_bindings.dtype.s4: 4, + cutlass_bindings.int8: 8, + cutlass_bindings.dtype.s16: 16, + cutlass_bindings.int32: 32, + cutlass_bindings.dtype.s64: 64, + cutlass_bindings.float16: 16, + cutlass_bindings.bfloat16: 16, + cutlass_bindings.float32: 32, + cutlass_bindings.tfloat32: 32, + cutlass_bindings.float64: 64, + cutlass_bindings.dtype.cf16: 32, + cutlass_bindings.dtype.cbf16: 32, + cutlass_bindings.dtype.cf32: 64, + cutlass_bindings.dtype.ctf32: 32, + cutlass_bindings.dtype.cf64: 128, + cutlass_bindings.dtype.cu4: 8, + cutlass_bindings.dtype.cu8: 16, + cutlass_bindings.dtype.cu16: 32, + cutlass_bindings.dtype.cu32: 64, + cutlass_bindings.dtype.cu64: 128, + cutlass_bindings.dtype.cs4: 8, + cutlass_bindings.dtype.cs8: 16, + cutlass_bindings.dtype.cs16: 32, + cutlass_bindings.dtype.cs32: 64, + cutlass_bindings.dtype.cs64: 128, +} + + +class DataTypeSizeBytes: + """ + Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the + data type key is less than a full byte or a non-integer number of bytes. + """ + + @staticmethod + def __class_getitem__(datatype): + """ + Returns the number of bytes in size the data type is. Raises an exception if the data type + is either less than a full byte or a non-integer number of bytes in size. + + :param datatype: data type to query + + :return: number of bytes the data type occupies + :rtype: int + """ + bits = DataTypeSize[datatype] + if bits < 8: + raise Exception( + "Data type {} is less than one byte in size.".format(datatype) + ) + elif bits % 8 != 0: + raise Exception( + "Data type {} is not an integer number of bytes.".format(datatype) + ) + return bits // 8 + + +ComplexTransformTag = { + cutlass_bindings.complex_transform.none: "cutlass::ComplexTransform::kNone", + cutlass_bindings.complex_transform.conj: "cutlass::ComplexTransform::kConjugate", +} + + +RealComplexBijection = [ + (cutlass_bindings.float16, cutlass_bindings.dtype.cf16), + (cutlass_bindings.float32, cutlass_bindings.dtype.cf32), + (cutlass_bindings.float64, cutlass_bindings.dtype.cf64), +] + + +def is_complex(data_type): + for r, c in RealComplexBijection: + if data_type == c: + return True + return False + + +def get_complex_from_real(real_type): + for r, c in RealComplexBijection: + if real_type == r: + return c + return cutlass_bindings.dtype.invalid + + +def get_real_from_complex(complex_type): + for r, c in RealComplexBijection: + if complex_type == c: + return r + return cutlass_bindings.dtype.invalid + + +class ComplexMultiplyOp(enum.Enum): + multiply_add = enum_auto() + gaussian = enum_auto() + + +class MathOperation(enum.Enum): + multiply_add = enum_auto() + multiply_add_saturate = enum_auto() + xor_popc = enum_auto() + multiply_add_fast_bf16 = enum_auto() + multiply_add_fast_f16 = enum_auto() + multiply_add_fast_f32 = enum_auto() + multiply_add_complex_fast_f32 = enum_auto() + multiply_add_complex = enum_auto() + multiply_add_complex_gaussian = enum_auto() + + +MathOperationNames = { + MathOperation.multiply_add: "multiply_add", + MathOperation.multiply_add_saturate: "multiply_add_saturate", + MathOperation.xor_popc: "xor_popc", + MathOperation.multiply_add_fast_bf16: "multiply_add_fast_bf16", + MathOperation.multiply_add_fast_f16: "multiply_add_fast_f16", + MathOperation.multiply_add_fast_f32: "multiply_add_fast_f32", + MathOperation.multiply_add_complex_fast_f32: "multiply_add_complex_fast_f32", + MathOperation.multiply_add_complex: "multiply_add_complex", + MathOperation.multiply_add_complex_gaussian: "multiply_add_complex_gaussian", +} + + +MathOperationTag = { + MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", + MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", + MathOperation.xor_popc: "cutlass::arch::OpXorPopc", + MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16", + MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16", + MathOperation.multiply_add_fast_f32: "cutlass::arch::OpMultiplyAddFastF32", + MathOperation.multiply_add_complex_fast_f32: "cutlass::arch::OpMultiplyAddComplexFastF32", + MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex", + MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex", +} + + +LayoutTag = { + cutlass_bindings.ColumnMajor: "cutlass::layout::ColumnMajor", + cutlass_bindings.RowMajor: "cutlass::layout::RowMajor", + cutlass_bindings.layout.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>", + cutlass_bindings.layout.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>", + cutlass_bindings.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>", + cutlass_bindings.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>", + cutlass_bindings.layout.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>", + cutlass_bindings.layout.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>", + cutlass_bindings.TensorNHWC: "cutlass::layout::TensorNHWC", + cutlass_bindings.layout.TensorNDHWC: "cutlass::layout::TensorNDHWC", + cutlass_bindings.layout.TensorNCHW: "cutlass::layout::TensorNCHW", + cutlass_bindings.layout.TensorNGHWC: "cutlass::layout::TensorNGHWC", + cutlass_bindings.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>", + cutlass_bindings.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>", + cutlass_bindings.layout.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>", + cutlass_bindings.layout.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>", +} + + +TransposedLayout = { + cutlass_bindings.ColumnMajor: cutlass_bindings.RowMajor, + cutlass_bindings.RowMajor: cutlass_bindings.ColumnMajor, + cutlass_bindings.layout.ColumnMajorInterleaved2: cutlass_bindings.layout.RowMajorInterleaved2, + cutlass_bindings.layout.RowMajorInterleaved2: cutlass_bindings.layout.ColumnMajorInterleaved2, + cutlass_bindings.ColumnMajorInterleaved32: cutlass_bindings.RowMajorInterleaved32, + cutlass_bindings.RowMajorInterleaved32: cutlass_bindings.ColumnMajorInterleaved32, + cutlass_bindings.layout.ColumnMajorInterleaved64: cutlass_bindings.layout.RowMajorInterleaved64, + cutlass_bindings.layout.RowMajorInterleaved64: cutlass_bindings.layout.ColumnMajorInterleaved64, + cutlass_bindings.TensorNHWC: cutlass_bindings.TensorNHWC, +} + + +ShortLayoutTypeNames = { + cutlass_bindings.ColumnMajor: "n", + cutlass_bindings.layout.ColumnMajorInterleaved2: "n2", + cutlass_bindings.ColumnMajorInterleaved32: "n32", + cutlass_bindings.layout.ColumnMajorInterleaved64: "n64", + cutlass_bindings.RowMajor: "t", + cutlass_bindings.layout.RowMajorInterleaved2: "t2", + cutlass_bindings.RowMajorInterleaved32: "t32", + cutlass_bindings.layout.RowMajorInterleaved64: "t64", + cutlass_bindings.TensorNHWC: "nhwc", + cutlass_bindings.layout.TensorNDHWC: "ndhwc", + cutlass_bindings.layout.TensorNCHW: "nchw", + cutlass_bindings.layout.TensorNGHWC: "nghwc", + cutlass_bindings.TensorNC32HW32: "nc32hw32", + cutlass_bindings.layout.TensorNC64HW64: "nc64hw64", + cutlass_bindings.TensorC32RSK32: "c32rsk32", + cutlass_bindings.layout.TensorC64RSK64: "c64rsk64", +} + + +ShortComplexLayoutNames = { + (cutlass_bindings.ColumnMajor, cutlass_bindings.complex_transform.none): "n", + (cutlass_bindings.ColumnMajor, cutlass_bindings.complex_transform.conj): "c", + (cutlass_bindings.RowMajor, cutlass_bindings.complex_transform.none): "t", + (cutlass_bindings.RowMajor, cutlass_bindings.complex_transform.conj): "h", +} + + +OpcodeClassNames = { + cutlass_bindings.OpClass.Simt: "simt", + cutlass_bindings.OpClass.TensorOp: "tensorop", + cutlass_bindings.OpClass.WmmaTensorOp: "wmma_tensorop", + cutlass_bindings.OpClass.SparseTensorOp: "sptensorop", +} + + +OpcodeClassTag = { + cutlass_bindings.OpClass.Simt: "cutlass::arch::OpClassSimt", + cutlass_bindings.OpClass.TensorOp: "cutlass::arch::OpClassTensorOp", + cutlass_bindings.OpClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp", + cutlass_bindings.OpClass.SparseTensorOp: "cutlass::arch::OpClassSparseTensorOp", +} + + +class OperationKind(enum.Enum): + Gemm = enum_auto() + Conv2d = enum_auto() + Conv3d = enum_auto() + + +OperationKindNames = { + OperationKind.Gemm: "gemm", + OperationKind.Conv2d: "conv2d", + OperationKind.Conv3d: "conv3d", +} + + +ArchitectureNames = { + 50: "maxwell", + 60: "pascal", + 61: "pascal", + 70: "volta", + 75: "turing", + 80: "ampere", + 90: "hopper", +} + + +SharedMemPerCC = { + 70: 96 << 10, # 96KB of SMEM + 72: 96 << 10, # 96KB of SMEM + 75: 64 << 10, # 64KB of SMEM + 80: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver + 86: 100 << 10, # 100KB of SMEM + 87: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver + 89: 100 << 10, # 100KB of SMEM + 90: 227 << 10, # 228KB of SMEM - 1KB reserved for the driver +} + + +class GemmKind(enum.Enum): + Gemm = enum_auto() + Sparse = enum_auto() + Universal = enum_auto() + PlanarComplex = enum_auto() + PlanarComplexArray = enum_auto() + Grouped = enum_auto() + + +GemmKindNames = { + GemmKind.Gemm: "gemm", + GemmKind.Sparse: "spgemm", + GemmKind.Universal: "gemm", + GemmKind.PlanarComplex: "gemm_planar_complex", + GemmKind.PlanarComplexArray: "gemm_planar_complex_array", + GemmKind.Grouped: "gemm_grouped", +} + + +class SwizzlingFunctor(enum.Enum): + Identity1 = enum_auto() + Identity2 = enum_auto() + Identity4 = enum_auto() + Identity8 = enum_auto() + Horizontal = enum_auto() + BatchedIdentity1 = enum_auto() + StridedDgradIdentity1 = enum_auto() + StridedDgradIdentity4 = enum_auto() + StridedDgradHorizontal = enum_auto() + + +SwizzlingFunctorTag = { + cutlass_bindings.IdentitySwizzle1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>", + SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>", + SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>", + SwizzlingFunctor.Horizontal: "cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle", + SwizzlingFunctor.BatchedIdentity1: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle", + SwizzlingFunctor.StridedDgradIdentity1: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>", + SwizzlingFunctor.StridedDgradIdentity4: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>", + SwizzlingFunctor.StridedDgradHorizontal: "cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle", +} + + +class SchedulerMode(enum.Enum): + Device = (enum_auto(),) + Host = enum_auto() + + +SchedulerModeTag = { + SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly", + SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute", +} + + +ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"} + + +ConvKindTag = { + cutlass_bindings.conv.Operator.fprop: "cutlass::conv::Operator::kFprop", + cutlass_bindings.conv.Operator.dgrad: "cutlass::conv::Operator::kDgrad", + cutlass_bindings.conv.Operator.wgrad: "cutlass::conv::Operator::kWgrad", +} + + +ConvKindNames = { + cutlass_bindings.conv.Operator.fprop: "fprop", + cutlass_bindings.conv.Operator.dgrad: "dgrad", + cutlass_bindings.conv.Operator.wgrad: "wgrad", +} + + +IteratorAlgorithmTag = { + cutlass_bindings.conv.IteratorAlgorithm.analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic", + cutlass_bindings.conv.IteratorAlgorithm.optimized: "cutlass::conv::IteratorAlgorithm::kOptimized", + cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: "cutlass::conv::IteratorAlgorithm::kFixedChannels", + cutlass_bindings.conv.IteratorAlgorithm.few_channels: "cutlass::conv::IteratorAlgorithm::kFewChannels", +} + + +IteratorAlgorithmNames = { + cutlass_bindings.conv.IteratorAlgorithm.analytic: "analytic", + cutlass_bindings.conv.IteratorAlgorithm.optimized: "optimized", + cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: "fixed_channels", + cutlass_bindings.conv.IteratorAlgorithm.few_channels: "few_channels", +} + + +class StrideSupport(enum.Enum): + Strided = enum_auto() + Unity = enum_auto() + + +StrideSupportTag = { + StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided", + StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity", +} + + +StrideSupportNames = { + StrideSupport.Strided: "", + StrideSupport.Unity: "unity_stride", +} + + +class ConvMode(enum.Enum): + CrossCorrelation = enum_auto() + Convolution = enum_auto() + + +ConvModeTag = { + ConvMode.CrossCorrelation: "cutlass::conv::Mode::kCrossCorrelation", + ConvMode.Convolution: "cutlass::conv::Mode::kConvolution", +} + + +class MathInstruction: + """ + Description of a the lowest-level matrix-multiply-accumulate operation to be used in a kernel + """ + + def __init__( + self, + instruction_shape, + element_a, + element_b, + element_accumulator, + opcode_class=cutlass_bindings.OpClass.Simt, + math_operation=MathOperation.multiply_add, + ): + """ + :param instruction_shape: size of the [M, N, K] dimensions of the instruction + :type instruction_shape: list or tuple + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_accumulator: data type used in accumulation + :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core) + :type opcode_class: cutlass_bindings.OpClass + :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate) + :type math_operation: MathOperation + """ + self.instruction_shape = instruction_shape + self.element_a = element_a + self.element_b = element_b + self.element_accumulator = element_accumulator + self.opcode_class = opcode_class + self.math_operation = math_operation + + +class TileDescription: + """ + Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes, + stage count, and math instruction specification + """ + + def __init__( + self, + threadblock_shape, + stages, + warp_count, + math_instruction, + cluster_shape=[1, 1, 1], + kernel_schedule: KernelScheduleType = None + ): + """ + :param threadblock_shape: shape of a threadblock tyle + :type threadblock_shape: list or tuple + :param stages: number of pipline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum + number of stages that can be supported for an operation on a given architecture will be computed at a later time + :type stages: int or None + :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile + :type warp_count: list, tuple, or None + :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands + :type math_instruction: MathInstruction + :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster + :param kernel_schedule: type of kernel schedule to use (only available for SM90+) + :type kernel_schedule: cutlass.backend.KernelScheduleType + """ + self.threadblock_shape = threadblock_shape + self.cluster_shape = cluster_shape + self.kernel_schedule = kernel_schedule + self.stages: int = stages + + self.math_instruction = math_instruction + + # Number of warps along x, y, z directions + self.warp_count = warp_count + + @property + def num_threads(self): + """ + Returns the number of threads in the threadblock + + :return: number of threads in the threadblock + :rtype: int or None (if warp count is None) + """ + if self.warp_count is not None: + threads = 32 + for cnt in self.warp_count: + threads *= cnt + return threads + return None + + def procedural_name(self): + """ + Returns a name identifying the tile description + + :return: name identifying the tile description + :rtype: int + """ + emit_stages = 0 if self.stages is None else self.stages + name = "%dx%dx%d_%dx%d_%dx%d" % ( + self.cluster_shape[0], + self.cluster_shape[1], + self.cluster_shape[2], + self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + emit_stages + ) + + return name + + def __str__(self): + """ + Returns a string with containing each of the tile description's values + + :return: contents of tile description + :rtype: str + """ + schedule = KernelScheduleType.ScheduleAuto + if self.kernel_schedule is not None: + schedule = self.kernel_schedule + return f""" +{{ + ClusterShape: {self.cluster_shape} + ThreadblockShape: {self.threadblock_shape} + WarpCount: {self.warp_count} + Stages: {self.stages if self.stages is not None else 'Auto'} + Kernel schedule: {schedule.name} +}}""" + + +class TensorDescription: + def __init__(self, element, layout, alignment=1, + complex_transform=cutlass_bindings.complex_transform.none): + self.element = element + self.layout = layout + self.alignment = min(128 // DataTypeSize[self.element], alignment) + self.complex_transform = complex_transform + + +def CalculateSmemUsagePerStage(operation): + """ + Returns the amount of shared memory in bytes consumed in a single stage of a kernel. + + :param op: operation for which the maximum stages should be computed. If stages are + set via the `op.tile_description.stages` parameter, this setting is ignored + in the present calculation + :type op: cutlass.backend.Operation + + :return: number of bytes of shared memory consumed by a single stage + :rtype: int + """ + m, n, k = operation.tile_description.threadblock_shape + + if operation.operation_kind == OperationKind.Gemm: + stage_barrier_bytes = 32 + return ( + (DataTypeSize[operation.A.element] * m * k // 8) + + (DataTypeSize[operation.B.element] * k * n // 8) + + stage_barrier_bytes + ) + else: + raise Exception("Unsupported operation kind {}.".format(operation.operation_kind)) + + +def CalculateSmemUsage(operation): + """ + Returns the amount of shared memory in bytes consumed by a kernel. + + :param op: operation for which the maximum stages should be computed. If stages are + set via the `op.tile_description.stages` parameter, this setting is ignored + in the present calculation + :type op: cutlass.backend.Operation + + :return: int + """ + return operation.tile_description.stages * CalculateSmemUsagePerStage(operation) + + +class ApiVersion(enum.Enum): + """ + Differentiate between CUTLASS 2.x and 3.x API versions + """ + + v2x = enum_auto() + v3x = enum_auto() + + +def api_version(arch, opclass, datatype): + """ + Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x + or 3.x for code emission. + + :param arch: compute capability of device on which to run + :type arch: int + :param opclass: class of the operation being performed + :type opclass: cutlass_bindings.OpClass + :param datatype: data type to be used in operation (assumes that ElementA and ElementB are the same) + + :return: API version to be used in code emission + :rtype: ApiVersion + """ + if (arch >= 90 and + opclass == cutlass_bindings.OpClass.TensorOp and + (datatype != cutlass_bindings.float64)): + return ApiVersion.v3x + else: + return ApiVersion.v2x diff --git a/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py b/python/cutlass/backend/memory_manager.py similarity index 100% rename from tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py rename to python/cutlass/backend/memory_manager.py index fa554744..7c759e64 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +++ b/python/cutlass/backend/memory_manager.py @@ -30,8 +30,8 @@ # ################################################################################################# -import rmm import numpy as np +import rmm class PoolMemoryManager: diff --git a/tools/library/scripts/pycutlass/src/pycutlass/operation.py b/python/cutlass/backend/operation.py similarity index 80% rename from tools/library/scripts/pycutlass/src/pycutlass/operation.py rename to python/cutlass/backend/operation.py index 9184e514..8a4d57d6 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/operation.py +++ b/python/cutlass/backend/operation.py @@ -31,19 +31,15 @@ ################################################################################ import ctypes -from cuda import cuda -from pycutlass.utils.device import device_cc -from cuda import __version__ as __cuda_version__ -_version_splits = [int(x) for x in __cuda_version__.split('.')] -supports_cluster_launch = device_cc() >= 90 and (_version_splits[0] > 11 or (_version_splits[0] == 11 and _version_splits[1] >= 8)) +from cuda import __version__, cuda +from cutlass.backend.utils.device import device_cc -################################################################################ -# -# Launch configuration -# -################################################################################ +_version_splits = [int(x) for x in __version__.split("rc")[0].split(".")] +supports_cluster_launch = device_cc() >= 90 and ( + _version_splits[0] > 11 or (_version_splits[0] == 11 and _version_splits[1] >= 8) +) class LaunchConfiguration: @@ -53,53 +49,35 @@ def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0): self.shared_memory_capacity = smem -################################################################################ -# -# Base class for an executable operation -# -# ############################################################################## - class ExecutableOperation: - ''' - ''' - def __init__(self, operation): self.operation = operation self.module = None self.kernel = None - # def name(self): return self.operation.procedural_name() - # def emit(self): - return '' + return "" - # def can_implement(self, configuration, arguments): raise NotImplementedError() - # def get_host_workspace_size(self, arguments): raise NotImplementedError() - # def get_device_workspace_size(self, arguments): raise NotImplementedError() - # def plan(self, arguments): raise NotImplementedError() - # def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=cuda.CUstream(0)): raise NotImplementedError() - - # def run_with_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)): - if hasattr(self.operation, 'tile_description') and hasattr(self.operation.tile_description, 'cluster_shape'): + if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"): attr = cuda.CUlaunchAttribute() attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION @@ -122,11 +100,10 @@ def run_with_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0 config.attrs = attrs config.numAttrs = len(attrs) - err, = cuda.cuLaunchKernelEx(config, f=self.kernel, kernelParams=kernel_params, extra=0) + err, = cuda.cuLaunchKernelEx( + config, f=self.kernel, kernelParams=kernel_params, extra=0) return err - - # def run_without_clusters(self, launch_config, kernel_params, stream=cuda.CUstream(0)): err, = cuda.cuLaunchKernel( self.kernel, @@ -139,11 +116,8 @@ def run_without_clusters(self, launch_config, kernel_params, stream=cuda.CUstrea return err - - # def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstream(0)): - cArg = (ctypes.c_char * len(host_workspace) - ).from_buffer(host_workspace) + cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace) packed = (ctypes.c_void_p * 1)() packed[0] = ctypes.addressof(cArg) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/parser.py b/python/cutlass/backend/parser.py similarity index 62% rename from tools/library/scripts/pycutlass/src/pycutlass/parser.py rename to python/cutlass/backend/parser.py index 6eb02bfb..be28bc0c 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/parser.py +++ b/python/cutlass/backend/parser.py @@ -30,16 +30,30 @@ # ################################################################################ +import ast +import ctypes +import inspect +import textwrap from typing import Generic, TypeVar -from treelib import Tree -import numpy as np -from pycutlass import * -import pycutlass +from cuda import cuda, cudart +import numpy as np +from treelib import Tree -import ast -import textwrap -import inspect +from cutlass.backend.epilogue import ( + AccumulatorOp, + BinaryOp, + ColumnBroadcastOp, + ColumnReductionOp, + RowBroadcastOp, + RowReductionOp, + TensorInputOp, + TensorOutputOp, + UnaryOp, +) +from cutlass.backend.frontend import NumpyFrontend +from cutlass.backend.utils.software import SubstituteTemplate +import cutlass.backend as backend ################################################################################ # Type annotation for input arguments @@ -48,9 +62,11 @@ Ttype = TypeVar("Ttype") Dtype = TypeVar("Dtype") + class NDArray(np.ndarray, Generic[Ttype, Dtype]): pass + ################################################################################ # Operations ################################################################################ @@ -59,18 +75,25 @@ class NDArray(np.ndarray, Generic[Ttype, Dtype]): ast.Add: "Add", ast.Div: "Div", ast.Eq: "Equal", - ast.Mult: "Mult" + ast.Mult: "Mult", } + ################################################################################ # AST Node abstractions ################################################################################ class UnaryNode: cnt = 0 + # Concept: this is created by the BinOp Node in python ast - def __init__(self, - element_accumulator, element_compute, elements_per_access, - node, args) -> None: + def __init__( + self, + element_accumulator, + element_compute, + elements_per_access, + node, + args, + ) -> None: if isinstance(node, BinOpNode): self.op = node.op elif isinstance(node, ast.Call): @@ -89,34 +112,46 @@ def __init__(self, self.type = "tensor" - self.epilogue_op = getattr(pycutlass, self.op)(element_compute) + self.epilogue_op = getattr(backend, self.op)(element_compute) # data types self.element_accumulator = element_accumulator self.element_compute = element_compute self.elements_per_access = elements_per_access - + def get_epilogue_node(self, visitors): self.epilogue_node = UnaryOp( - self.element_accumulator, self.element_compute, - self.elements_per_access, *visitors, self.epilogue_op) - + self.element_accumulator, + self.element_compute, + self.elements_per_access, + *visitors, + self.epilogue_op, + ) + def get_argument(self, visitor_args, kwargs): epilogue_ops = [] for arg in self.args: try: epilogue_ops.append(kwargs[arg]) except: - epilogue_ops.append(arg) # direct arguments like constant - self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(*epilogue_ops), *visitor_args) + epilogue_ops.append(arg) # direct arguments like constant + self.argument = self.epilogue_node.argument_type( + self.epilogue_op.argument_type(*epilogue_ops), + *visitor_args, + ) class BinOpNode: cnt = 0 + # Concept: this is created by the BinOp Node in python ast - def __init__(self, - element_accumulator, element_compute, elements_per_access, - node) -> None: + def __init__( + self, + element_accumulator, + element_compute, + elements_per_access, + node, + ) -> None: self.op = operators[type(node.op)] self.tag = "Binary" + self.op + str(BinOpNode.cnt) self.id = self.op + str(BinOpNode.cnt) @@ -125,20 +160,27 @@ def __init__(self, self.type = "tensor" - self.epilogue_op = getattr(pycutlass, "Vector"+self.op)(element_compute) + self.epilogue_op = getattr(backend, "Vector" + self.op)(element_compute) # data types self.element_accumulator = element_accumulator self.element_compute = element_compute self.elements_per_access = elements_per_access - + def get_epilogue_node(self, visitors): self.epilogue_node = BinaryOp( - self.element_accumulator, self.element_compute, - self.elements_per_access, *visitors, self.epilogue_op) - + self.element_accumulator, + self.element_compute, + self.elements_per_access, + *visitors, + self.epilogue_op, + ) + def get_argument(self, visitor_args, kwargs): - self.argument = self.epilogue_node.argument_type(self.epilogue_op.argument_type(self.args), *visitor_args) + self.argument = self.epilogue_node.argument_type( + self.epilogue_op.argument_type(self.args), + *visitor_args, + ) class NameNode: @@ -150,6 +192,7 @@ def __init__(self, node) -> None: self.id = node.targets[0].id self.tag = self.id + class ScalarInputNode(NameNode): # Concept: scalar def __init__(self, node) -> None: @@ -157,10 +200,15 @@ def __init__(self, node) -> None: self.tag = "Scalar:" + self.tag self.type = "scalar" + class AccumulatorNode(NameNode): # Concept: VisitorOpAccumulator - def __init__(self, - element_accumulator, elements_per_access, node) -> None: + def __init__( + self, + element_accumulator, + elements_per_access, + node, + ) -> None: super().__init__(node) self.tag = "Accum:" + self.tag self.type = "tensor" @@ -170,11 +218,14 @@ def __init__(self, def get_epilogue_node(self, visitors): self.epilogue_node = AccumulatorOp( - self.element_accumulator, self.elements_per_access) - + self.element_accumulator, + self.elements_per_access, + ) + def get_argument(self, visitor_args, kwargs): self.argument = self.epilogue_node.argument_type() + class TensorInputNode(NameNode): # Concept: VisitorOpTensorInput def __init__(self, element_accumulator, node) -> None: @@ -182,47 +233,72 @@ def __init__(self, element_accumulator, node) -> None: self.tag = "TensorInput:" + self.tag self.type = "tensor" self.element_accumulator = element_accumulator - + def get_epilogue_node(self, *args): self.epilogue_node = TensorInputOp(self.element_accumulator) - + def get_argument(self, visitor_args, kwargs): self.argument = self.epilogue_node.argument_type( - kwargs[self.id + "_ptr"], kwargs["problem_size"][1], - kwargs["problem_size"][0] * kwargs["problem_size"][1]) + kwargs[self.id + "_ptr"], + kwargs["problem_size"][1], + kwargs["problem_size"][0] * kwargs["problem_size"][1], + ) + class RowBroadcastNode(NameNode): # Concept: VisitorOpRowBroadcast - def __init__(self, element_accumulator, element_fragment, node) -> None: + def __init__( + self, + element_accumulator, + element_fragment, + node, + ) -> None: super().__init__(node) # self.tag = "RowBroadcast:" + self.tag self.type = "tensor" self.element_accumulator = element_accumulator self.element_fragment = element_fragment - + def get_epilogue_node(self, *args): self.epilogue_node = RowBroadcastOp( - self.element_accumulator, self.element_fragment) - + self.element_accumulator, + self.element_fragment, + ) + def get_argument(self, visitor_args, kwargs): - self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1]) + self.argument = self.epilogue_node.argument_type( + kwargs[self.id + "_ptr"], + kwargs["problem_size"][1], + ) + class ColumnBroadcastNode(NameNode): # Concept: VisitorOpColumnBroadcast - def __init__(self, element_accumulator, element_fragment, node) -> None: + def __init__( + self, + element_accumulator, + element_fragment, + node, + ) -> None: super().__init__(node) self.tag = "ColumnBroadcast:" + self.tag self.type = "tensor" self.element_accumulator = element_accumulator self.element_fragment = element_fragment - + def get_epilogue_node(self, *args): self.epilogue_node = ColumnBroadcastOp( - self.element_accumulator, self.element_fragment) - + self.element_accumulator, + self.element_fragment, + ) + def get_argument(self, visitor_args, kwargs): - self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][0]) + self.argument = self.epilogue_node.argument_type( + kwargs[self.id + "_ptr"], + kwargs["problem_size"][0], + ) + class TensorOutputNode(NameNode): # Concept: VisitorOpTensorOutput @@ -234,14 +310,26 @@ def __init__(self, element_accumulator, node) -> None: def get_epilogue_node(self, visitors): self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors) - + def get_argument(self, visitor_args, kwargs): - self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], kwargs["problem_size"][1], *visitor_args, kwargs["problem_size"][0] * kwargs["problem_size"][1]) + self.argument = self.epilogue_node.argument_type( + kwargs[self.id + "_ptr"], + kwargs["problem_size"][1], + *visitor_args, + kwargs["problem_size"][0] * kwargs["problem_size"][1], + ) + class RowReductionNode: # Concept: RowReductionOp - def __init__(self, element_accumulator, element_reduction, - element_reduction_accumulator, id, factor) -> None: + def __init__( + self, + element_accumulator, + element_reduction, + element_reduction_accumulator, + id, + factor, + ) -> None: # self.id = id self.tag = "RowReduction:" + self.id @@ -250,22 +338,36 @@ def __init__(self, element_accumulator, element_reduction, self.element_reduction = element_reduction self.element_reduction_accumulator = element_reduction_accumulator self.factor = factor - + def get_epilogue_node(self, visitors): self.epilogue_node = RowReductionOp( - self.element_accumulator, self.element_reduction, - self.element_reduction_accumulator, *visitors) - + self.element_accumulator, + self.element_reduction, + self.element_reduction_accumulator, + *visitors, + ) + def get_batch_stride(self, problem_size): return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor) - + def get_argument(self, visitor_args, kwargs): - self.argument = self.epilogue_node.argument_type(kwargs[self.id + "_ptr"], *visitor_args, self.get_batch_stride(kwargs["problem_size"])) + self.argument = self.epilogue_node.argument_type( + kwargs[self.id + "_ptr"], + *visitor_args, + self.get_batch_stride(kwargs["problem_size"]), + ) + class ColumnReductionNode: # Concept: ColumnReductionOp - def __init__(self, element_accumulator, element_reduction, - element_reduction_accumulator, id, factor) -> None: + def __init__( + self, + element_accumulator, + element_reduction, + element_reduction_accumulator, + id, + factor, + ) -> None: # self.id = id self.tag = "ColumnReduction:" + self.id @@ -274,28 +376,41 @@ def __init__(self, element_accumulator, element_reduction, self.element_reduction = element_reduction self.element_reduction_accumulator = element_reduction_accumulator self.factor = factor - + def get_epilogue_node(self, visitors): self.epilogue_node = ColumnReductionOp( - self.element_accumulator, self.element_reduction, - self.element_reduction_accumulator, *visitors) - + self.element_accumulator, + self.element_reduction, + self.element_reduction_accumulator, + *visitors, + ) + def get_batch_stride(self, problem_size): return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor) - + def get_argument(self, visitor_args, kwargs): - self.argument = self.epilogue_node.argument_type(kwargs[self.id + '_ptr'], *visitor_args, self.get_batch_stride(kwargs["problem_size"])) + self.argument = self.epilogue_node.argument_type( + kwargs[self.id + "_ptr"], + *visitor_args, + self.get_batch_stride(kwargs["problem_size"]), + ) + ################################################################################ # Epilogue parser function ################################################################################ class EpilogueAST(ast.NodeVisitor): - def __init__(self, epilogue, + def __init__( + self, + epilogue, tile_description, - element_accumulator, elements_per_access, - element_compute, element_output) -> None: + element_accumulator, + elements_per_access, + element_compute, + element_output, + ) -> None: # - + self.tile_description = tile_description self.element_accumulator = element_accumulator self.elements_per_access = elements_per_access @@ -307,7 +422,6 @@ def __init__(self, epilogue, self.ast_tree = ast.parse(self.source) self.epilogue_tree = Tree() - # print(ast.dump(self.ast_tree, indent=4)) # For Debug purpose # input arguments @@ -332,26 +446,48 @@ def visit_Name(self, node): # accum is produced from accumulator node if node.id == "accum": name_node = AccumulatorNode( - self.element_accumulator, self.elements_per_access, node) + self.element_accumulator, + self.elements_per_access, + node, + ) else: # for input nodes if node.id in self.input_args.keys(): type = self.input_args[node.id][0] if type == "tensor": - name_node = TensorInputNode(self.element_accumulator, node) + name_node = TensorInputNode( + self.element_accumulator, + node, + ) elif type == "row": - name_node = RowBroadcastNode(self.element_accumulator, self.element_compute, node) + name_node = RowBroadcastNode( + self.element_accumulator, + self.element_compute, + node, + ) elif type == "column": - name_node = ColumnBroadcastNode(self.element_accumulator, self.element_compute, node) + name_node = ColumnBroadcastNode( + self.element_accumulator, + self.element_compute, + node, + ) elif type == "scalar": name_node = ScalarInputNode(node) else: raise ValueError(type) # for output nodes else: - name_node = TensorOutputNode(self.element_accumulator, node) - self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node, parent=self.stack[-1]) - + name_node = TensorOutputNode( + self.element_accumulator, + node, + ) + self.epilogue_tree.create_node( + name_node.tag, + name_node.id, + data=name_node, + parent=self.stack[-1], + ) + def visit_Assign(self, node): pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id) if pre_assign_node is None: @@ -364,23 +500,36 @@ def visit_Assign(self, node): func_type = node.value.func.value.id else: raise TypeError - if func_type == 'reduction_op': - self.reduction_source[node.value.args[0].id] = [node.value.args[1].value, node.value.args[2].value, node.targets[0].id] + if func_type == "reduction_op": + self.reduction_source[node.value.args[0].id] = [ + node.value.args[1].value, + node.value.args[2].value, + node.targets[0].id, + ] return name_node = TensorOutputNode(self.element_accumulator, node) - self.epilogue_tree.create_node(name_node.tag, name_node.id, data=name_node) + self.epilogue_tree.create_node( + name_node.tag, + name_node.id, + data=name_node, + ) self.stack.append(name_node.id) else: - if node.targets[0].id in self.returns or node.targets[0].id in self.reduction_source.keys(): + if ( + node.targets[0].id in self.returns + or node.targets[0].id in self.reduction_source.keys() + ): self.stack.append(node.targets[0].id) else: - self.stack.append(pre_assign_node.predecessor(self.epilogue_tree.identifier)) + self.stack.append( + pre_assign_node.predecessor(self.epilogue_tree.identifier) + ) self.epilogue_tree.remove_node(node.targets[0].id) - + # get child tag self.visit(node.value) self.stack.pop() - + def visit_Call(self, node): if isinstance(node.func, ast.Name): func_type = node.func.id @@ -393,7 +542,8 @@ def visit_Call(self, node): else: arg_list = [] for idx, arg in enumerate(node.args): - if idx == 0: continue + if idx == 0: + continue if isinstance(arg, ast.Constant): arg_list.append(arg.value) elif isinstance(arg, ast.Name): @@ -401,37 +551,60 @@ def visit_Call(self, node): else: raise TypeError - unary_node = UnaryNode(self.element_accumulator, self.element_compute, self.elements_per_access, node, arg_list) - self.epilogue_tree.create_node(unary_node.tag, unary_node.id, parent=self.stack[-1], data=unary_node) + unary_node = UnaryNode( + self.element_accumulator, + self.element_compute, + self.elements_per_access, + node, + arg_list, + ) + self.epilogue_tree.create_node( + unary_node.tag, + unary_node.id, + parent=self.stack[-1], + data=unary_node, + ) self.stack.append(unary_node.id) self.visit(node.args[0]) self.stack.pop() - + def visit_BinOp(self, node): - binop = BinOpNode(self.element_accumulator, self.element_compute, - self.elements_per_access, node) - self.epilogue_tree.create_node(binop.tag, binop.id, data=binop, parent=self.stack[-1]) + binop = BinOpNode( + self.element_accumulator, + self.element_compute, + self.elements_per_access, + node, + ) + self.epilogue_tree.create_node( + binop.tag, + binop.id, + data=binop, + parent=self.stack[-1], + ) self.stack.append(binop.id) self.visit(node.left) self.visit(node.right) self.stack.pop() - + def visit_Return(self, node): self.stack.append("return") self.visit(node.value) self.stack.pop() - + # # A function definition def visit_FunctionDef(self, node: ast.FunctionDef): # visit args for arg in node.args.args: - if arg.arg == "self": continue + if arg.arg == "self": + continue if isinstance(arg.annotation, ast.Constant): - self.input_args[arg.arg] = [arg.annotation.value, ] + self.input_args[arg.arg] = [ + arg.annotation.value, + ] # visit the assign in the reverse order for idx in range(len(node.body)): - self.visit(node.body[-1-idx]) - + self.visit(node.body[-1 - idx]) + # # Tree optimization pass # @@ -447,29 +620,39 @@ def pass_binary_2_unary(self, tree, nid): if left_type == "scalar" and right_type == "tensor": node.data = UnaryNode( - self.element_accumulator, self.element_compute, + self.element_accumulator, + self.element_compute, self.elements_per_access, - node.data, [lhs_node.data.id,]) + node.data, + [ + lhs_node.data.id, + ], + ) node.tag = node.data.tag tree.remove_node(lhs_node.data.id) self.pass_binary_2_unary(tree, rhs_node.data.id) - + elif left_type == "tensor" and right_type == "scalar": node.data = UnaryNode( - self.element_accumulator, self.element_compute, + self.element_accumulator, + self.element_compute, self.elements_per_access, - node.data, [rhs_node.id,]) + node.data, + [ + rhs_node.id, + ], + ) node.tag = node.data.tag tree.remove_node(rhs_node.data.id) self.pass_binary_2_unary(tree, lhs_node.data.id) - + else: self.pass_binary_2_unary(tree, lhs_node.data.id) self.pass_binary_2_unary(tree, rhs_node.data.id) else: for child in node.successors(tree.identifier): self.pass_binary_2_unary(tree, child) - + # pass 2: inject reduction nodes def pass_inject_reduction(self, tree, nid): node = tree.get_node(nid) @@ -477,14 +660,22 @@ def pass_inject_reduction(self, tree, nid): if node.data.id in self.reduction_source.keys(): direction = self.reduction_source[node.data.id][0] target = self.reduction_source[node.data.id][-1] - if direction == 'row': + if direction == "row": reduction_node = RowReductionNode( - self.element_accumulator, self.element_output, - self.element_accumulator, target, self.tile_description.threadblock_shape[1]) + self.element_accumulator, + self.element_output, + self.element_accumulator, + target, + self.tile_description.threadblock_shape[1], + ) elif direction == "column": reduction_node = ColumnReductionNode( - self.element_accumulator, self.element_output, - self.element_accumulator, target, self.tile_description.threadblock_shape[0]) + self.element_accumulator, + self.element_output, + self.element_accumulator, + target, + self.tile_description.threadblock_shape[0], + ) else: raise ValueError(direction) child_nid = node.successors(tree.identifier)[0] @@ -497,8 +688,16 @@ def pass_inject_reduction(self, tree, nid): # if this output node is also a tensor output, inject reduction as its children else: # get child node - tree.create_node(reduction_node.tag, reduction_node.id, data=reduction_node, parent=node.data.id) - tree.move_node(child_nid, reduction_node.id) + tree.create_node( + reduction_node.tag, + reduction_node.id, + data=reduction_node, + parent=node.data.id, + ) + tree.move_node( + child_nid, + reduction_node.id, + ) child = tree.get_node(child_nid) for grand_child in child.successors(tree.identifier): self.pass_inject_reduction(tree, grand_child) @@ -514,7 +713,7 @@ def pass_inject_epilogue_op(self, tree, nid): visitors = [] for child in node.successors(tree.identifier): visitors.append(self.pass_inject_epilogue_op(tree, child)) - + node.data.get_epilogue_node(visitors) return node.data.epilogue_node @@ -523,19 +722,27 @@ def get_arguments(self, tree, nid, kwargs): visitor_args = [] for child in node.successors(tree.identifier): visitor_args.append(self.get_arguments(tree, child, kwargs)) - + node.data.get_argument(visitor_args, kwargs) return node.data.argument + class EpilogueVisitTree: KernelTemplate = """ ${visitor} using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>; -""" - def __init__(self, elementwise_functor, tile_description, - element_accumulator, elements_per_access, - element_compute, element_output) -> None: +""" + + def __init__( + self, + elementwise_functor, + tile_description, + element_accumulator, + elements_per_access, + element_compute, + element_output, + ) -> None: # # data types self.tile_description = tile_description @@ -545,70 +752,126 @@ def __init__(self, elementwise_functor, tile_description, self.element_output = element_output self.elementwise_functor = elementwise_functor pass - + def initialize(self): - function = EpilogueAST(self, self.tile_description, - self.element_accumulator, self.elements_per_access, - self.element_compute, self.element_output) + function = EpilogueAST( + self, + self.tile_description, + self.element_accumulator, + self.elements_per_access, + self.element_compute, + self.element_output, + ) # tree = function.epilogue_tree self.tree = tree function.pass_binary_2_unary(self.tree, self.tree.root) function.pass_inject_reduction(self.tree, self.tree.root) - function.pass_inject_epilogue_op(self.tree,self.tree.root) + function.pass_inject_epilogue_op(self.tree, self.tree.root) visitor = self.tree.get_node(self.tree.root).data.epilogue_node self.visitor = visitor class _Argument(ctypes.Structure): _fields_ = [ - ("visitor_arg", visitor.argument_type) + ( + "visitor_arg", + visitor.argument_type, + ) ] + def __init__(self, **kwargs) -> None: # process input args _kwargs = {} for input_key in function.input_args.keys(): if input_key == "accum": continue - if function.input_args[input_key][0] == "scalar": + if function.input_args[input_key][0] == "scalar": continue # tensor input else: - setattr(self, "buffer_tensor_" + input_key, NumpyFrontend.argument(kwargs[input_key], False)) - setattr(self, input_key + "_ptr", int(getattr(self, "buffer_tensor_" + input_key).ptr)) - _kwargs[input_key+"_ptr"] = getattr(self, input_key + "_ptr") + setattr( + self, + "buffer_tensor_" + input_key, + NumpyFrontend.argument( + kwargs[input_key], + False, + ), + ) + setattr( + self, + input_key + "_ptr", + int( + getattr( + self, + "buffer_tensor_" + input_key, + ).ptr + ), + ) + _kwargs[input_key + "_ptr"] = getattr( + self, + input_key + "_ptr", + ) # process the return args for ret in function.returns: - setattr(self, "buffer_tensor_" + ret, NumpyFrontend.argument(kwargs[ret], True)) - setattr(self, ret + "_ptr", int(getattr(self, "buffer_tensor_" + ret).ptr)) - _kwargs[ret+"_ptr"] = getattr(self, ret + "_ptr") - setattr(self, "host_tensor_" + ret, kwargs[ret]) - + setattr( + self, + "buffer_tensor_" + ret, + NumpyFrontend.argument(kwargs[ret], True), + ) + setattr( + self, + ret + "_ptr", + int( + getattr( + self, + "buffer_tensor_" + ret, + ).ptr + ), + ) + _kwargs[ret + "_ptr"] = getattr(self, ret + "_ptr") + setattr( + self, + "host_tensor_" + ret, + kwargs[ret], + ) + _kwargs.update(kwargs) function.get_arguments(tree, tree.root, _kwargs) self.visitor_arg = tree.get_node(tree.root).data.argument - + def sync(self, stream_sync=True): if stream_sync: - err, = cudart.cudaDeviceSynchronize() + (err,) = cudart.cudaDeviceSynchronize() if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) - + for ret in function.returns: - err, = cuda.cuMemcpyDtoH( - getattr(self, "host_tensor_" + ret), cuda.CUdeviceptr(getattr(self, ret + "_ptr")), - getattr(self, "host_tensor_" + ret).size * getattr(self, "host_tensor_" + ret).itemsize + (err,) = cuda.cuMemcpyDtoH( + getattr( + self, + "host_tensor_" + ret, + ), + cuda.CUdeviceptr(getattr(self, ret + "_ptr")), + getattr( + self, + "host_tensor_" + ret, + ).size + * getattr( + self, + "host_tensor_" + ret, + ).itemsize, ) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) pass - + self.epilogue_type = _Argument - + def emit(self, operation): values = { - 'visitor': self.visitor.emit(operation), - 'operation_name': operation.procedural_name(), - 'visitor_name': self.visitor.instance_name + "visitor": self.visitor.emit(operation), + "operation_name": operation.procedural_name(), + "visitor_name": self.visitor.instance_name, } return SubstituteTemplate(self.KernelTemplate, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py b/python/cutlass/backend/reduction_operation.py similarity index 62% rename from tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py rename to python/cutlass/backend/reduction_operation.py index d6d2d6a4..0542f6c5 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +++ b/python/cutlass/backend/reduction_operation.py @@ -29,18 +29,28 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################ -from pycutlass import * -from pycutlass.c_types import get_reduction_params -import cutlass -from cuda import cuda -try: - import torch - torch_available = True -except ImportError: - torch_available = False -import numpy as np + + from typing import Union -from cuda import cudart + +import ctypes +from cuda import cuda, cudart +import cutlass_bindings +import numpy as np + +from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params +from cutlass.backend.frontend import NumpyFrontend, TorchFrontend +from cutlass.backend.library import ( + DataTypeNames, + DataTypeSize, + DataTypeTag, + TensorDescription, +) +from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate + +if CheckPackages().check_torch(): + import torch class ReductionOperation: @@ -52,12 +62,16 @@ class ReductionArguments: Arguments of reduction """ - def __init__(self, operation: ReductionOperation, - problem_size: 'list[int]', partitions: int, - workspace: cuda.CUdeviceptr, - destination: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', - source: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', **kwargs) -> None: - + def __init__( + self, + operation: ReductionOperation, + problem_size: "list[int]", + partitions: int, + workspace: cuda.CUdeviceptr, + destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + **kwargs, + ) -> None: # tensor_C can be interpreted as the bias with bias=True in keyword args if "bias" in kwargs.keys(): self.bias = kwargs["bias"] @@ -76,10 +90,9 @@ def __init__(self, operation: ReductionOperation, self.host_D = destination self.destination_buffer = NumpyFrontend.argument(destination, True) self.source_buffer = NumpyFrontend.argument(source, False) - self.ptr_destination = cuda.CUdeviceptr( - self.destination_buffer.ptr) + self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr) self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr) - elif torch_available and isinstance(destination, torch.Tensor): + elif CheckPackages().check_torch() and isinstance(destination, torch.Tensor): self.ptr_destination = TorchFrontend.argument(destination) self.ptr_source = TorchFrontend.argument(source) elif isinstance(destination, cuda.CUdeviceptr): @@ -88,15 +101,14 @@ def __init__(self, operation: ReductionOperation, else: raise TypeError("unknown Type") - self.problem_size = MatrixCoord_( - problem_size[0], problem_size[1] - ) + self.problem_size = MatrixCoord_(problem_size[0], problem_size[1]) - self.partition_stride = problem_size[0] * \ - problem_size[1] * DataTypeSize[operation.C.element] // 8 + self.partition_stride = ( + problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8 + ) if "output_op" in kwargs.keys(): - self.output_op = kwargs['output_op'] + self.output_op = kwargs["output_op"] else: self.output_op = self.operation.epilogue_type(1.0, 0.0) @@ -104,49 +116,74 @@ def __init__(self, operation: ReductionOperation, self.get_arguments() @staticmethod - def get_tensor_ref(extent: 'tuple[int]', device_ptr: cuda.CUdeviceptr, layout: cutlass.layout): - if layout == cutlass.RowMajor: + def get_tensor_ref( + extent: "tuple[int]", + device_ptr: cuda.CUdeviceptr, + layout: cutlass_bindings.layout, + ): + if layout == cutlass_bindings.RowMajor: return TensorRef2D_(int(device_ptr), extent[1]) else: raise ValueError("unknown layout type") def get_arguments(self): ref_workspace = ReductionArguments.get_tensor_ref( - extent=[self.problem_size.row, self.problem_size.column], - device_ptr=self.ptr_workspace, layout=cutlass.RowMajor) + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + device_ptr=self.ptr_workspace, + layout=cutlass_bindings.RowMajor, + ) if self.bias: ref_source = ReductionArguments.get_tensor_ref( extent=[0, 0], - device_ptr=self.ptr_source, layout=cutlass.RowMajor) + device_ptr=self.ptr_source, + layout=cutlass_bindings.RowMajor, + ) else: ref_source = ReductionArguments.get_tensor_ref( - extent=[self.problem_size.row, self.problem_size.column], - device_ptr=self.ptr_source, layout=cutlass.RowMajor) + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + device_ptr=self.ptr_source, + layout=cutlass_bindings.RowMajor, + ) ref_destination = ReductionArguments.get_tensor_ref( - extent=[self.problem_size.row, self.problem_size.column], - device_ptr=self.ptr_destination, layout=cutlass.RowMajor) - + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + device_ptr=self.ptr_destination, + layout=cutlass_bindings.RowMajor, + ) self.c_arguments = self.operation.argument_type( - self.problem_size, self.partitions, - self.partition_stride, ref_workspace, - ref_destination, ref_source, - self.output_op + self.problem_size, + self.partitions, + self.partition_stride, + ref_workspace, + ref_destination, + ref_source, + self.output_op, ) - params_ = self.operation.rt_module.get_args( - ctypes.byref(self.c_arguments)) + params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments)) self.host_workspace = bytearray(params_.contents) def sync(self): - err, = cudart.cudaDeviceSynchronize() + (err,) = cudart.cudaDeviceSynchronize() if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) if hasattr(self, "host_D"): - err, = cuda.cuMemcpyDtoH( - self.host_D, self.ptr_destination, self.host_D.size * self.host_D.itemsize) + (err,) = cuda.cuMemcpyDtoH( + self.host_D, + self.ptr_destination, + self.host_D.size * self.host_D.itemsize, + ) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) @@ -161,7 +198,8 @@ class ReductionRT(ExecutableOperation): """ ReductionRT manages the CUTLASS runtime components for reduction """ - KernelTemplate = r''' + + KernelTemplate = r""" extern "C" __global__ void ${operation_name}(${operation_name}${operation_suffix}::Params params) { @@ -177,8 +215,8 @@ class ReductionRT(ExecutableOperation): op(params, *shared_storage); } - ''' - HostTemplate = r''' + """ + HostTemplate = r""" extern "C" { // Get the size of params in bytes int ${operation_name}_get_param_size(){ @@ -200,40 +238,51 @@ class ReductionRT(ExecutableOperation): return output; } } - ''' + """ def __init__(self, operation: ReductionOperation): super().__init__(operation) self.operation: ReductionOperation = operation - self.emitter = EmitReductionInstance('_type') + self.emitter = EmitReductionInstance("_type") self.elements_per_access = self.operation.count - self.argument_type, self.epilogue_type = get_reduction_params(operation.epilogue_functor) + ( + self.argument_type, + self.epilogue_type, + ) = get_reduction_params(operation.epilogue_functor) self.argtype = [ctypes.POINTER(self.argument_type)] def emit(self): return self.emitter.emit(self.operation) def plan(self, arguments: ReductionArguments): - block_shape = [self.operation.shape.column( - ) // self.elements_per_access, self.operation.shape.row(), 1] + block_shape = [ + self.operation.shape.column() // self.elements_per_access, + self.operation.shape.row(), + 1, + ] grid_shape = [ - (arguments.problem_size.row + self.operation.shape.row() - - 1) // self.operation.shape.row(), - (arguments.problem_size.column + self.operation.shape.column() - - 1) // self.operation.shape.column(), - 1 + (arguments.problem_size.row + self.operation.shape.row() - 1) + // self.operation.shape.row(), + (arguments.problem_size.column + self.operation.shape.column() - 1) + // self.operation.shape.column(), + 1, ] - return LaunchConfiguration(grid_shape, block_shape, self.shared_memory_capacity) + return LaunchConfiguration( + grid_shape, + block_shape, + self.shared_memory_capacity, + ) def initialize(self): - err, = cuda.cuFuncSetAttribute( + (err,) = cuda.cuFuncSetAttribute( self.kernel, attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - value=self.shared_memory_capacity) + value=self.shared_memory_capacity, + ) if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('Cuda Error: {}'.format(err)) + raise RuntimeError("Cuda Error: {}".format(err)) class ReductionOperation: @@ -244,12 +293,18 @@ class ReductionOperation: r """ - def __init__(self, shape: cutlass.MatrixCoord, C: TensorDescription, - element_accumulator, element_workspace=None, - element_compute=None, epilogue_functor=None, - count: int = 1, partitions_per_stage: int = 4) -> None: - """ Constructor - """ + def __init__( + self, + shape: cutlass_bindings.MatrixCoord, + C: TensorDescription, + element_accumulator, + element_workspace=None, + element_compute=None, + epilogue_functor=None, + count: int = 1, + partitions_per_stage: int = 4, + ) -> None: + """Constructor""" self.shape = shape #: epilogue functor (default: LinearCombination) @@ -291,36 +346,38 @@ def __init__(self, shape: cutlass.MatrixCoord, C: TensorDescription, def extended_name(self): extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}" - return SubstituteTemplate(extend_name, - { - 'element_workspace': DataTypeNames[self.element_workspace], - 'element_accumulator': DataTypeNames[self.element_accumulator], - 'element_compute': DataTypeNames[self.element_compute], - 'element_output': DataTypeNames[self.element_output] - }) + return SubstituteTemplate( + extend_name, + { + "element_workspace": DataTypeNames[self.element_workspace], + "element_accumulator": DataTypeNames[self.element_accumulator], + "element_compute": DataTypeNames[self.element_compute], + "element_output": DataTypeNames[self.element_output], + }, + ) # def configuration_name(self): - ''' The full procedural name indicates architecture, extended name, tile size''' + """The full procedural name indicates architecture, extended name, tile size""" configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}" threadblock = "%dx%d" % ( self.shape.row(), - self.shape.column() + self.shape.column(), ) return SubstituteTemplate( configuration_name, { - 'extended_name': self.extended_name(), - 'threadblock': threadblock - } + "extended_name": self.extended_name(), + "threadblock": threadblock, + }, ) # def procedural_name(self): - ''' The full procedural name indicates architecture, extended name, tile size''' + """The full procedural name indicates architeture, extended name, tile size""" return self.configuration_name() def run(self, arguments: ReductionArguments) -> cuda.CUresult: @@ -336,16 +393,19 @@ def run(self, arguments: ReductionArguments) -> cuda.CUresult: # launch the kernel err = self.rt_module.run( - host_workspace, device_workspace, launch_config) + host_workspace, + device_workspace, + launch_config, + ) if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('CUDA Error %s' % str(err)) + raise RuntimeError("CUDA Error %s" % str(err)) return err class EmitReductionInstance: - def __init__(self, operation_suffix='') -> None: + def __init__(self, operation_suffix="") -> None: self.operation_suffix = operation_suffix self.includes = [ "cutlass/cutlass.h", @@ -357,7 +417,7 @@ def __init__(self, operation_suffix='') -> None: "cutlass/gemm/device/gemm_universal_adapter.h", "cutlass/gemm/kernel/default_gemm_universal.h", "cutlass/reduction/kernel/reduce_split_k.h", - "cutlass/reduction/thread/reduction_operators.h" + "cutlass/reduction/thread/reduction_operators.h", ] self.template = """ // Reduction kernel instance @@ -376,23 +436,27 @@ def __init__(self, operation_suffix='') -> None: """ def emit(self, operation: ReductionOperation): - - epilogue_vector_length = int(min( - operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + epilogue_vector_length = int( + min( + operation.C.alignment * DataTypeSize[operation.C.element], + 128, + ) + / DataTypeSize[operation.C.element] + ) values = { - 'operation_name': operation.configuration_name(), - 'operation_suffix': self.operation_suffix, - 'shape_row': str(operation.shape.row()), - 'shape_column': str(operation.shape.column()), - 'epilogue_functor': operation.epilogue_functor.emit(), - 'element_output': DataTypeTag[operation.element_output], - 'epilogue_vector_length': str(epilogue_vector_length), - 'element_accumulator': DataTypeTag[operation.element_accumulator], - 'element_compute': DataTypeTag[operation.element_compute], - 'element_workspace': DataTypeTag[operation.element_workspace], - 'count': str(operation.count), - 'partition_per_stage': str(operation.partitions_per_stage) + "operation_name": operation.configuration_name(), + "operation_suffix": self.operation_suffix, + "shape_row": str(operation.shape.row()), + "shape_column": str(operation.shape.column()), + "epilogue_functor": operation.epilogue_functor.emit(), + "element_output": DataTypeTag[operation.element_output], + "epilogue_vector_length": str(epilogue_vector_length), + "element_accumulator": DataTypeTag[operation.element_accumulator], + "element_compute": DataTypeTag[operation.element_compute], + "element_workspace": DataTypeTag[operation.element_workspace], + "count": str(operation.count), + "partition_per_stage": str(operation.partitions_per_stage), } return SubstituteTemplate(self.template, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py b/python/cutlass/backend/tensor_ref.py similarity index 83% rename from tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py rename to python/cutlass/backend/tensor_ref.py index 733232c3..9f7aa9da 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +++ b/python/cutlass/backend/tensor_ref.py @@ -30,33 +30,32 @@ # ################################################################################ -from typeguard import typechecked -import numpy as np -try: - import torch - torch_available = True -except ImportError: - torch_available = False from cuda import cuda -try: +import cutlass_bindings +import numpy as np + +from cutlass.backend.utils.software import CheckPackages + +cupy_available = CheckPackages().check_cupy() +if cupy_available: import cupy as cp - cupy_available = True -except ImportError: - cupy_available = False -import cutlass + +torch_available = CheckPackages().check_torch() +if torch_available: + import torch -# @typechecked class TensorRef: """ - Python Wrapper for cutlass.TensorRef + Python Wrapper for cutlass_bindings.TensorRef """ + def __init__(self, tensor, dtype, layout) -> None: if isinstance(tensor, np.ndarray): - ptr = cuda.CUdeviceptr(tensor.__array_interface__['data'][0]) + ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0]) elif torch_available and isinstance(tensor, torch.Tensor): ptr = cuda.CUdeviceptr(tensor.data_ptr()) - elif cupy_available and isinstance(tensor, cp.ndarray): + elif torch_available and isinstance(tensor, cp.ndarray): ptr = cuda.CUdeviceptr(int(tensor.data.ptr)) elif isinstance(tensor, cuda.CUdeviceptr): ptr = tensor @@ -64,7 +63,7 @@ def __init__(self, tensor, dtype, layout) -> None: ptr = cuda.CUdeviceptr(tensor) else: raise NotImplementedError(tensor) - - # the dtype(0) is used to overload between different data types + + # the dtype(0) is used to overload between different data types # with the same layout - self.tensor_ref = cutlass.get_tensor_ref(int(ptr), dtype(0), layout) + self.tensor_ref = cutlass_bindings.get_tensor_ref(int(ptr), dtype(0), layout) diff --git a/tools/library/scripts/pycutlass/build.sh b/python/cutlass/backend/test/__init__.py similarity index 88% rename from tools/library/scripts/pycutlass/build.sh rename to python/cutlass/backend/test/__init__.py index 5dbda5d5..03f54cec 100644 --- a/tools/library/scripts/pycutlass/build.sh +++ b/python/cutlass/backend/test/__init__.py @@ -1,4 +1,4 @@ -################################################################################################# +################################################################################ # # Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause @@ -28,9 +28,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -################################################################################################# +################################################################################ -pip install -U pybind11 -git clone https://github.com/google/googletest.git -python setup.py develop --user -python setup.py rmm +from cutlass.backend.test.conv2d_testbed import * +from cutlass.backend.test.gemm_grouped_testbed import * +from cutlass.backend.test.gemm_testbed import * +from cutlass.backend.test.profiler import * diff --git a/python/cutlass/backend/test/conv2d_testbed.py b/python/cutlass/backend/test/conv2d_testbed.py new file mode 100644 index 00000000..b6c96055 --- /dev/null +++ b/python/cutlass/backend/test/conv2d_testbed.py @@ -0,0 +1,783 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import re +import subprocess +from time import sleep + +from bfloat16 import bfloat16 +import cutlass_bindings +import numpy as np + +from cutlass.backend.compiler import ArtifactManager +from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation +from cutlass.backend.library import DataTypeSize, ShortDataTypeNames, StrideSupport +from cutlass.backend.memory_manager import get_allocated_size +from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation +from cutlass.backend.test.profiler import GpuTimer +from cutlass.backend.utils.software import SubstituteTemplate + + +def getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand): + ptr = tensor.__array_interface__["data"][0] + if operand == "a": + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent( + conv_kind, problem_size + ) + elif operand == "b": + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent( + conv_kind, problem_size + ) + elif operand in ["c", "d"]: + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent( + conv_kind, problem_size + ) + else: + raise ValueError("unknown operand: " + operand) + + layout = tensor_layout.packed(tensor_coord) + + if tensor.dtype == np.float64: + return cutlass_bindings.TensorRefF64NHWC(ptr, layout) + elif tensor.dtype == np.float32: + return cutlass_bindings.TensorRefF32NHWC(ptr, layout) + elif tensor.dtype == np.float16: + return cutlass_bindings.TensorRefF16NHWC(ptr, layout) + if tensor.dtype == bfloat16: + return cutlass_bindings.TensorRefBF16NHWC(ptr, layout) + elif tensor.dtype == np.int32: + return cutlass_bindings.TensorRefS32NHWC(ptr, layout) + elif tensor.dtype == np.int8: + if tensor_layout == cutlass_bindings.TensorNC32HW32: + return cutlass_bindings.TensorRefS8NC32HW32(ptr, layout) + elif tensor_layout == cutlass_bindings.TensorC32RSK32: + return cutlass_bindings.TensorRefS8C32RSK32(ptr, layout) + else: + return cutlass_bindings.TensorRefS8NHWC(ptr, layout) + else: + raise ValueError("unsupported data type") + + +def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand): + tensor_ref = getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand) + + if operand == "a": + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_a_extent( + conv_kind, problem_size + ) + elif operand == "b": + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_b_extent( + conv_kind, problem_size + ) + elif operand in ["c", "d"]: + tensor_coord = cutlass_bindings.conv.implicit_gemm_tensor_c_extent( + conv_kind, problem_size + ) + else: + raise ValueError("unknown operand: " + operand) + + if tensor.dtype == np.float64: + return cutlass_bindings.TensorViewF64NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == np.float32: + return cutlass_bindings.TensorViewF32NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == np.float16: + return cutlass_bindings.TensorViewF16NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == bfloat16: + return cutlass_bindings.TensorViewBF16NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == np.int32: + return cutlass_bindings.TensorViewS32NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == np.int8: + if tensor_layout == cutlass_bindings.TensorNC32HW32: + return cutlass_bindings.TensorViewS8NC32HW32(tensor_ref, tensor_coord) + elif tensor_layout == cutlass_bindings.TensorC32RSK32: + return cutlass_bindings.TensorViewS8C32RSK32(tensor_ref, tensor_coord) + else: + return cutlass_bindings.TensorViewS8NHWC(tensor_ref, tensor_coord) + + else: + raise ValueError("unsupported data type") + + +# @typechecked +class Conv2dLauncher: + """ + Launcher that runs the operation on given problem size + """ + + def __init__( + self, + operation: "Conv2dOperation", + seed: int = 2080, + interleaved=False, + verification=True, + profiling=False, + warmup_iterations=500, + iterations=500, + **kwargs, + ) -> None: + self.enable_cached_results = True + self.interleaved = interleaved + + # create the reduction kernel + self.reduction_operation = ReductionOperation( + shape=cutlass_bindings.MatrixCoord(4, 32 * operation.C.alignment), + C=operation.C, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_compute=operation.epilogue_functor.element_epilogue, + epilogue_functor=operation.epilogue_functor, + count=operation.C.alignment, + ) + + #: verify the output result + self.verification = verification + #: profile the kernel's runtime + self.profiling = profiling + + self.timer = GpuTimer() + + self.warmup_iterations = warmup_iterations + self.iterations = iterations + + if "sleep" in kwargs.keys(): + self.sleep_time = kwargs["sleep"] + else: + self.sleep_time = 0 + + # + # Compile the operator + # + + ArtifactManager().add_module([operation, self.reduction_operation]) + + self.operation = operation + + self.dtype_A = Conv2dLauncher.numpy_type(operation.A.element) + self.layout_A = operation.A.layout + self.dtype_B = Conv2dLauncher.numpy_type(operation.B.element) + self.layout_B = operation.B.layout + self.dtype_C = Conv2dLauncher.numpy_type(operation.C.element) + self.layout_C = operation.C.layout + self.dtype_D = Conv2dLauncher.numpy_type(operation.C.element) + self.layout_D = operation.C.layout + + accumulator_size = DataTypeSize[ + operation.tile_description.math_instruction.element_accumulator + ] + element_size = DataTypeSize[operation.A.element] + + if element_size <= 8: + self.scope = 1 + elif element_size == 16: + if accumulator_size <= 16: + self.scope = 2 + else: + self.scope = 4 + else: + self.scope = 7 + + # Seed + self.seed = seed + + self.conv_kind = operation.conv_kind + + # + # Get the host reference function + # + + self.element_compute = operation.epilogue_functor.element_epilogue + + self.host_conv2d = cutlass_bindings.test.conv.host.conv2d + + self.timer = GpuTimer() + + @staticmethod + def numpy_type(type): + if type == cutlass_bindings.float64: + return np.float64 + elif type == cutlass_bindings.float32: + return np.float32 + elif type == cutlass_bindings.float16: + return np.float16 + elif type == cutlass_bindings.bfloat16: + return bfloat16 + elif type == cutlass_bindings.int32: + return np.int32 + elif type == cutlass_bindings.int8: + return np.int8 + else: + raise ValueError("unsupported type: %s" % ShortDataTypeNames[type]) + + def print_problem_size(self, p, split_k_mode=1): + print( + "nhwc_%dx%dx%dx%d_krsc_%dx%dx%dx%d_padding_%dx%d_stride_%dx%d_dilation_%dx%d_splitkslices_%d_splitkmode_%d" + % ( + p.N, + p.H, + p.W, + p.C, + p.K, + p.R, + p.S, + p.C, + p.pad_h, + p.pad_w, + p.stride_h, + p.stride_w, + p.dilation_h, + p.dilation_w, + p.split_k_slices, + split_k_mode, + ) + ) + + def uniform_init(self, size, dtype): + if dtype in [np.float32, np.float16, bfloat16, np.float64]: + return np.ceil( + np.random.uniform( + low=-self.scope - 0.5, high=self.scope - 0.5, size=size + ).astype(dtype) + ) + else: + return np.random.uniform( + low=-self.scope - 1, high=self.scope + 1, size=size + ).astype(dtype) + + def eq_gemm_size(self, problem_size): + n = problem_size.N + p = problem_size.P + q = problem_size.Q + k = problem_size.K + r = problem_size.R + s = problem_size.S + c = problem_size.C + h = problem_size.H + w = problem_size.W + if self.conv_kind == cutlass_bindings.conv.Operator.fprop: + return cutlass_bindings.gemm.GemmCoord(n * p * q, k, r * s * c) + elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + return cutlass_bindings.gemm.GemmCoord(n * h * w, c, k * r * s) + else: + return cutlass_bindings.gemm.GemmCoord(k, r * s * c, n * p * q) + + def bytes(self, problem_size, alpha, beta): + mnk = self.eq_gemm_size(problem_size) + + bytes_ = ( + (DataTypeSize[self.operation.A.element] * mnk.m() // 8) * mnk.k() + + (DataTypeSize[self.operation.B.element] * mnk.n() // 8) * mnk.k() + + (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n() + ) + + if beta != 0: + bytes_ += (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n() + + return bytes_ + + def flops(self, problem_size): + mnk = self.eq_gemm_size(problem_size) + + flops_mainloop_ = mnk.m() * mnk.n() * mnk.k() * 2 + flops_epilogue_ = mnk.m() * mnk.n() * 2 + + # Adjust mainloop flop for dgrad stride + if self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + flops_mainloop_ = flops_mainloop_ // ( + problem_size.stride_h * problem_size.stride_w + ) + + flops_total_ = flops_mainloop_ + flops_epilogue_ + + return flops_total_ + + def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): + if self.element_compute == cutlass_bindings.float16: + alpha = cutlass_bindings.float16(alpha) + beta = cutlass_bindings.float16(beta) + elif self.element_compute == cutlass_bindings.int32: + alpha = int(alpha) + beta = int(beta) + else: + alpha = alpha + beta = beta + + # if cached result is loaded + cached_result_loaded = False + + if self.enable_cached_results: + # get problem key + cached_test_key = cutlass_bindings.test.conv.host.CreateCachedConv2dTestKey( + self.conv_kind, + problem_size, + alpha, + beta, + getTensorView( + tensor_A, self.layout_A, self.conv_kind, problem_size, "a" + ), + getTensorView( + tensor_B, self.layout_B, self.conv_kind, problem_size, "b" + ), + getTensorView( + tensor_C, self.layout_C, self.conv_kind, problem_size, "c" + ), + ) + + cached_test_result = cutlass_bindings.test.conv.host.CachedTestResult() + + conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % ( + self.operation.arch, + self.seed, + ) + + cached_results = cutlass_bindings.test.conv.host.CachedTestResultListing( + conv2d_result_cache_name + ) + # CachedTestResultListing cached_results(conv2d_result_cache_name); + cached = cached_results.find(cached_test_key) + cached_result_loaded = cached[0] + if cached_result_loaded: + cached_test_result = cached[1] + + if not cached_result_loaded: + # compute the conv2d on host + tensor_D_ref = np.ones_like(tensor_C) + tensor_ref_A = getTensorRef( + tensor_A, self.layout_A, self.conv_kind, problem_size, "a" + ) + tensor_ref_B = getTensorRef( + tensor_B, self.layout_B, self.conv_kind, problem_size, "b" + ) + tensor_ref_C = getTensorRef( + tensor_C, self.layout_C, self.conv_kind, problem_size, "c" + ) + tensor_ref_D_ref = getTensorRef( + tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d" + ) + + self.host_conv2d( + self.conv_kind, + problem_size, + tensor_ref_A, + tensor_ref_B, + tensor_ref_C, + tensor_ref_D_ref, + alpha, + beta, + ) + + tensor_view_D_ref = getTensorView( + tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d" + ) + + if self.enable_cached_results: + cached_test_result.D = cutlass_bindings.test.conv.host.TensorHash( + tensor_view_D_ref + ) + cached_results = ( + cutlass_bindings.test.conv.host.CachedTestResultListing( + conv2d_result_cache_name + ) + ) + cached_results.append(cached_test_key, cached_test_result) + cached_results.write(conv2d_result_cache_name) + else: + return tensor_D_ref + + return cached_test_result.D + + def equal(self, tensor_D, tensor_D_ref, problem_size): + if self.enable_cached_results: + tensor_view_D = getTensorView( + tensor_D, self.layout_D, self.conv_kind, problem_size, "d" + ) + tensor_D_hash = cutlass_bindings.test.conv.host.TensorHash(tensor_view_D) + + return tensor_D_hash == tensor_D_ref + else: + tensor_view_D = getTensorView( + tensor_D, self.layout_D, self.conv_kind, problem_size, "d" + ) + tensor_view_D_ref = getTensorView( + tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d" + ) + return cutlass_bindings.test.conv.host.equals( + tensor_view_D, tensor_view_D_ref + ) + + def run_cutlass_profiler( + self, + problem_size, + split_k_mode=cutlass_bindings.conv.SplitKMode.Serial, + alpha=1.0, + beta=0.0, + ): + if split_k_mode == cutlass_bindings.conv.SplitKMode.Serial: + split_k_mode_ = "serial" + else: + split_k_mode_ = "parallel" + + cutlass_path = os.getenv("CUTLASS_PATH") + assert ( + cutlass_path is not None + ), "Environment variable 'CUTLASS_PATH' is not defined." + + values = { + "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler", + "kernel_name": self.operation.procedural_name(), + "verification_providers": "device", + "provider": "cutlass", + "n": str(problem_size.N), + "h": str(problem_size.H), + "w": str(problem_size.W), + "c": str(problem_size.C), + "k": str(problem_size.K), + "r": str(problem_size.R), + "s": str(problem_size.S), + "p": str(problem_size.P), + "q": str(problem_size.Q), + "pad_h": str(problem_size.pad_h), + "pad_w": str(problem_size.pad_w), + "stride_h": str(problem_size.stride_h), + "stride_w": str(problem_size.stride_w), + "dilation_h": str(problem_size.dilation_h), + "dilation_w": str(problem_size.dilation_w), + "split_k_slices": str(problem_size.split_k_slices), + "split_k_mode": split_k_mode_, + "alpha": str(alpha), + "beta": str(beta), + "warmup": str(self.warmup_iterations), + "profile": str(self.iterations), + } + + cmd_template = ( + "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" + " --providers=${provider} --n=${n} --h=${h} --w=${w} --c=${c} --k=${k} --r=${r} --s=${s} --p=${p}" + " --q=${q} --pad_h=${pad_h} --pad_w=${pad_w} --stride_h={stride_h} --stride_w=${stride_w}" + " --dilation_h=${dilation_h} --dilation_w=${dilation_w} --warmup-iterations=${warmup} --profiling-iterations=${profile}" + " --split_k_slices=${split_k_slices} --alpha=${alpha} --beta=${beta} --split_k_mode=${split_k_mode}" + ) + + cmd = SubstituteTemplate(cmd_template, values) + result = subprocess.getoutput(cmd) + + m = re.search(r"Runtime:\s+(?P\d+.\d+)", result) + runtime = float(m.group("runtime")) + + m = re.search(r"Bytes:\s+(?P\d+)", result) + bytes = int(m.group("bytes")) + + m = re.search(r"FLOPs:\s+(?P\d+)", result) + flops = int(m.group("flops")) + + # check if the problem size matches + assert bytes == self.bytes(problem_size, alpha, beta) + assert flops == self.flops(problem_size) + + return runtime + + def run( + self, + problem_size, + split_k_mode=cutlass_bindings.conv.SplitKMode.Serial, + alpha=1.0, + beta=0.0, + ): + assert get_allocated_size() == 0, ( + "%d byte of pool memory is not released in previous run" + % get_allocated_size() + ) + + # + # Initialize input and output tensors + # + tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size( + self.conv_kind, problem_size + ) + tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size( + self.conv_kind, problem_size + ) + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( + self.conv_kind, problem_size + ) + + np.random.seed(self.seed) + + tensor_A = self.uniform_init(size=(tensor_A_size,), dtype=self.dtype_A) + tensor_B = self.uniform_init(size=(tensor_B_size,), dtype=self.dtype_B) + tensor_C = self.uniform_init(size=(tensor_C_size,), dtype=self.dtype_C) + tensor_D = np.zeros(shape=(tensor_C_size,), dtype=self.dtype_D) + + # + # Launch kernel + # + + arguments = Conv2dArguments( + operation=self.operation, + problem_size=problem_size, + A=tensor_A, + B=tensor_B, + C=tensor_C, + D=tensor_D, + output_op=self.operation.epilogue_type(alpha, beta), + split_k_slices=problem_size.split_k_slices, + split_k_mode=split_k_mode, + ) + + if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel: + implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size( + self.operation.conv_kind, arguments.problem_size + ) + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], + partitions=problem_size.split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op=self.reduction_operation.epilogue_type(alpha, beta), + ) + + self.operation.run(arguments) + if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel: + self.reduction_operation.run(reduction_arguments) + + passed = True + if self.verification: + if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel: + reduction_arguments.sync() + else: + arguments.sync() + + tensor_D_ref = self.host_reference( + problem_size, tensor_A, tensor_B, tensor_C, alpha, beta + ) + + passed = self.equal(tensor_D, tensor_D_ref, problem_size) + + try: + assert passed + except AssertionError: + self.print_problem_size(problem_size, split_k_mode) + + if self.profiling: + sleep(self.sleep_time) + for _ in range(self.warmup_iterations): + self.operation.run(arguments) + if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel: + self.reduction_operation.run(reduction_arguments) + + self.timer.start() + for _ in range(self.warmup_iterations): + self.operation.run(arguments) + if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel: + self.reduction_operation.run(reduction_arguments) + self.timer.stop_and_wait() + runtime = self.timer.duration(self.iterations) + + # free memory + del arguments + if split_k_mode == cutlass_bindings.conv.SplitKMode.Parallel: + del reduction_arguments + + assert get_allocated_size() == 0, ( + "%d byte of pool memory is not released after current run" + % get_allocated_size() + ) + if self.profiling: + return runtime + return passed + + +######################################################################################################## +# TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +# TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +# Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +# (conv_blacklist_sizes) +############################################################################################################ + + +def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes=[], interleaved=False): + passed = True + # + # Testbed object + # + + testbed = Conv2dLauncher(operation, interleaved=interleaved) + + # + # Get conv problem sizes to run conv operator + # + + conv_problems = cutlass_bindings.test.conv.TestbedConv2dProblemSizes(64) + + # Vector of conv2d problem sizes to avoid duplicate runs + conv_tested_sizes = [] + + # Flatten 2D problem_vectors into a 1D problem sizes + problem_sizes = conv_problems.conv2d_default_sizes + + problem_sizes = [conv_problem for conv_problem in problem_sizes] + conv_test_sizes + + # Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slices=1, alpha=1.0, beta=0.0) + for conv_problem in problem_sizes: + if conv_problem in conv_tested_sizes: + continue + + # skip channel dimension % 32 != 0 for interleaved case + if interleaved: + if conv_problem.K % 32 != 0 or conv_problem.C % 32 != 0: + continue + + # + # Procedurally disable certain cases + # + + # CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ( + operation.conv_kind == cutlass_bindings.conv.Operator.dgrad + and operation.stride_support == StrideSupport.Unity + ): + if not ((conv_problem.stride_h == 1) and (conv_problem.stride_w == 1)): + continue + + if not interleaved: + # Fixed channels algorithm requires channel count to match access size + if ( + operation.iterator_algorithm + == cutlass_bindings.conv.IteratorAlgorithm.fixed_channels + ): + if conv_problem.C != operation.A.alignment: + continue + + # Few channels algorithm requires channel count to match access size + if ( + operation.iterator_algorithm + == cutlass_bindings.conv.IteratorAlgorithm.few_channels + ): + if conv_problem.C % operation.A.alignment: + continue + + # CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} + # Although strided dgrad works for all stride combinations, we are only going + # to run strided dgrad for non-unity strides + + if ( + operation.conv_kind == cutlass_bindings.conv.Operator.dgrad + and operation.stride_support == StrideSupport.Strided + ): + if (conv_problem.stride_h == 1) and (conv_problem.stride_w == 1): + continue + + # + # Test + # + + # push back tested problem size to avoid re-running duplicates + conv_tested_sizes.append(conv_problem) + + passed = testbed.run(conv_problem) + + if not passed: + return False + + if interleaved: + return True + # + # filter the cases for split K + # + + # Small-channels convolution can't run here. + if operation.iterator_algorithm in [ + cutlass_bindings.conv.IteratorAlgorithm.fixed_channels, + cutlass_bindings.conv.IteratorAlgorithm.few_channels, + ]: + return True + + # CUTLASS DGRAD's *stride* specialization does not support split-k mode + if ( + operation.conv_kind == cutlass_bindings.conv.Operator.dgrad + and operation.stride_support == StrideSupport.Strided + ): + conv_problem = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 56, 56, 8), + cutlass_bindings.Tensor4DCoord(8, 1, 1, 8), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, + 1, + ) + passed = testbed.run(conv_problem) + + return passed + + # Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + # a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + # which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep + # alpha and beta for local testing, but only runs one value for alpha and beta. + + conv2d_split_k_test_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 17, 11, 288), + cutlass_bindings.Tensor4DCoord(160, 3, 3, 288), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, + 1, + ) + + split_k_modes = [ + cutlass_bindings.conv.SplitKMode.Parallel, + cutlass_bindings.conv.SplitKMode.Serial, + ] + + split_k_slices = [1, 2, 3, 4, 201] + problem_alpha = [ + 2.0, + ] + problem_beta = [ + 2.0, + ] + + for split_k_mode in split_k_modes: + for split_k_slice in split_k_slices: + for alpha in problem_alpha: + for beta in problem_beta: + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + alpha, + beta, + ) + + return passed diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py b/python/cutlass/backend/test/gemm_grouped_testbed.py similarity index 66% rename from tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py rename to python/cutlass/backend/test/gemm_grouped_testbed.py index 6cf14f32..95f22a0c 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +++ b/python/cutlass/backend/test/gemm_grouped_testbed.py @@ -30,18 +30,20 @@ # ################################################################################################# -import pycutlass -from pycutlass.test.gemm_testbed import getTensorRef, getTensorView, transpose -from pycutlass import * -import numpy as np -import cutlass from bfloat16 import bfloat16 +import cutlass_bindings +import numpy as np + +from cutlass.backend import compiler +from cutlass.backend.gemm_operation import GemmGroupedArguments, GemmOperationGrouped +from cutlass.backend.library import DataTypeSize, ShortDataTypeNames +from cutlass.backend.memory_manager import get_allocated_size +from cutlass.backend.test.gemm_testbed import getTensorRef, getTensorView, transpose class TestbedGrouped: def __init__(self, operation: GemmOperationGrouped, seed: int = 2080) -> None: - - pycutlass.compiler.add_module([operation]) + compiler.add_module([operation]) self.seed = seed @@ -70,21 +72,23 @@ def __init__(self, operation: GemmOperationGrouped, seed: int = 2080) -> None: #: compute type self.compute_type = operation.epilogue_functor.element_epilogue - self.accumulator_type = operation.tile_description.math_instruction.element_accumulator + self.accumulator_type = ( + operation.tile_description.math_instruction.element_accumulator + ) @staticmethod def numpy_type(type): - if type == cutlass.float64: + if type == cutlass_bindings.float64: return np.float64 - elif type == cutlass.float32: + elif type == cutlass_bindings.float32: return np.float32 - elif type == cutlass.float16: + elif type == cutlass_bindings.float16: return np.float16 - elif type == cutlass.bfloat16: + elif type == cutlass_bindings.bfloat16: return bfloat16 - elif type == cutlass.int32: + elif type == cutlass_bindings.int32: return np.int32 - elif type == cutlass.int8: + elif type == cutlass_bindings.int8: return np.int8 else: raise ValueError("unsupported type: %s" % ShortDataTypeNames[type]) @@ -93,24 +97,26 @@ def uniform_init(self, size, dtype): if dtype in [np.float32, np.float16, bfloat16, np.float64]: return np.ceil( np.random.uniform( - low=self.scope_min - 0.5, high=self.scope_max - 0.5, - size=size).astype(dtype) + low=self.scope_min - 0.5, high=self.scope_max - 0.5, size=size + ).astype(dtype) ) else: return np.random.uniform( - low=self.scope_min - 1, high=self.scope_max + 1, - size=size).astype(dtype) + low=self.scope_min - 1, high=self.scope_max + 1, size=size + ).astype(dtype) def print_problem_size(self, p): problem_size = "problem: %d, %d, %d\n" % (p.m(), p.n(), p.k()) print(problem_size) def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool: - - assert get_allocated_size( - ) == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size() + assert get_allocated_size() == 0, ( + "%d byte of pool memory is not released in previous run" + % get_allocated_size() + ) # initialize + passed = False np.random.seed(self.seed) # generate the problem sizes @@ -124,59 +130,61 @@ def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool for i in range(problem_count): if self.dtype_A == np.int8: if i == 0: - problem_size = cutlass.gemm.GemmCoord(48, 16, 32) + problem_size = cutlass_bindings.gemm.GemmCoord(48, 16, 32) else: - problem_size = cutlass.gemm.GemmCoord( + problem_size = cutlass_bindings.gemm.GemmCoord( + 16 * np.random.randint(0, 64) + 48, 16 * np.random.randint(0, 64) + 48, 16 * np.random.randint(0, 64) + 48, - 16 * np.random.randint(0, 64) + 48 ) else: if i == 0: - problem_size = cutlass.gemm.GemmCoord(48, 16, 8) + problem_size = cutlass_bindings.gemm.GemmCoord(48, 16, 8) else: - problem_size = cutlass.gemm.GemmCoord( + problem_size = cutlass_bindings.gemm.GemmCoord( + 8 * np.random.randint(0, 64) + 24, 8 * np.random.randint(0, 64) + 24, 8 * np.random.randint(0, 64) + 24, - 8 * np.random.randint(0, 64) + 24 ) tensor_As.append( self.uniform_init( - size=(problem_size.m() * problem_size.k(),), - dtype=self.dtype_A) + size=(problem_size.m() * problem_size.k(),), dtype=self.dtype_A + ) ) tensor_Bs.append( self.uniform_init( - size=(problem_size.n() * problem_size.k(),), - dtype=self.dtype_B) + size=(problem_size.n() * problem_size.k(),), dtype=self.dtype_B + ) ) tensor_Cs.append( self.uniform_init( - size=(problem_size.m() * problem_size.n(),), - dtype=self.dtype_C) + size=(problem_size.m() * problem_size.n(),), dtype=self.dtype_C + ) ) tensor_Ds.append( np.zeros( - shape=(problem_size.m() * problem_size.n(),), - dtype=self.dtype_D + shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D ) ) tensor_D_refs.append( np.ones( - shape=(problem_size.m() * problem_size.n(),), - dtype=self.dtype_D + shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D ) ) problem_sizes.append(problem_size) arguments = GemmGroupedArguments( - operation=self.operation, problem_sizes=problem_sizes, - A=tensor_As, B=tensor_Bs, C=tensor_Cs, D=tensor_Ds, - output_op=self.operation.epilogue_type(alpha, beta) + operation=self.operation, + problem_sizes=problem_sizes, + A=tensor_As, + B=tensor_Bs, + C=tensor_Cs, + D=tensor_Ds, + output_op=self.operation.epilogue_type(alpha, beta), ) self.operation.run(arguments) @@ -193,34 +201,65 @@ def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool for idx, problem_size in enumerate(problem_sizes): if self.operation.switched: tensor_ref_A = getTensorRef( - tensor_As[idx], problem_size, "a", transpose(self.operation.B.layout)) + tensor_As[idx], + problem_size, + "a", + transpose(self.operation.B.layout), + ) tensor_ref_B = getTensorRef( - tensor_Bs[idx], problem_size, "b", transpose(self.operation.A.layout)) + tensor_Bs[idx], + problem_size, + "b", + transpose(self.operation.A.layout), + ) tensor_ref_C = getTensorRef( - tensor_Cs[idx], problem_size, "c", transpose(self.operation.C.layout)) + tensor_Cs[idx], + problem_size, + "c", + transpose(self.operation.C.layout), + ) tensor_ref_D_ref = getTensorRef( - tensor_D_refs[idx], problem_size, "d", transpose(self.operation.C.layout)) + tensor_D_refs[idx], + problem_size, + "d", + transpose(self.operation.C.layout), + ) else: tensor_ref_A = getTensorRef( - tensor_As[idx], problem_size, "a", self.operation.A.layout) + tensor_As[idx], problem_size, "a", self.operation.A.layout + ) tensor_ref_B = getTensorRef( - tensor_Bs[idx], problem_size, "b", self.operation.B.layout) + tensor_Bs[idx], problem_size, "b", self.operation.B.layout + ) tensor_ref_C = getTensorRef( - tensor_Cs[idx], problem_size, "c", self.operation.C.layout) + tensor_Cs[idx], problem_size, "c", self.operation.C.layout + ) tensor_ref_D_ref = getTensorRef( - tensor_D_refs[idx], problem_size, "d", self.operation.C.layout) + tensor_D_refs[idx], problem_size, "d", self.operation.C.layout + ) tensor_view_D_ref = getTensorView( - tensor_D_refs[idx], problem_size, "d", self.operation.C.layout) + tensor_D_refs[idx], problem_size, "d", self.operation.C.layout + ) - cutlass.test.gemm.host.gemm(problem_size, alpha, tensor_ref_A, - tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) + cutlass_bindings.test.gemm.host.gemm( + problem_size, + alpha, + tensor_ref_A, + tensor_ref_B, + beta, + tensor_ref_C, + tensor_ref_D_ref, + init_acc, + ) tensor_view_D = getTensorView( - tensor_Ds[idx], problem_size, "d", self.operation.C.layout) + tensor_Ds[idx], problem_size, "d", self.operation.C.layout + ) - passed = cutlass.test.gemm.host.equals( - tensor_view_D, tensor_view_D_ref) + passed = cutlass_bindings.test.gemm.host.equals( + tensor_view_D, tensor_view_D_ref + ) try: assert passed @@ -229,7 +268,9 @@ def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool del arguments - assert get_allocated_size( - ) == 0, "%d byte of pool memory is not released after current run" % get_allocated_size() + assert get_allocated_size() == 0, ( + "%d byte of pool memory is not released after current run" + % get_allocated_size() + ) return passed diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py b/python/cutlass/backend/test/gemm_testbed.py similarity index 56% rename from tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py rename to python/cutlass/backend/test/gemm_testbed.py index ab3ae5ad..3790f170 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +++ b/python/cutlass/backend/test/gemm_testbed.py @@ -30,31 +30,50 @@ # ################################################################################################# +import os +import re +import subprocess from time import sleep -import pycutlass -from pycutlass import * -import pycutlass.utils.datatypes as datatypes -import cutlass -from cuda import cudart -from cuda import cuda + from bfloat16 import bfloat16 -from .profiler import GpuTimer -import subprocess +from cuda import cuda, cudart +import cutlass_bindings +import numpy as np + +from cutlass.backend import compiler +from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass.backend.library import ( + DataTypeSize, + DataTypeSizeBytes, + MathOperation, + ShortDataTypeNames, +) +from cutlass.backend.memory_manager import get_allocated_size +from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation +from cutlass.backend.test.profiler import GpuTimer +from cutlass.backend.utils.datatypes import to_cutlass +from cutlass.backend.utils.software import SubstituteTemplate def transpose(layout): - if layout == cutlass.RowMajor: - return cutlass.ColumnMajor - elif layout == cutlass.ColumnMajor: - return cutlass.RowMajor - elif layout == cutlass.ColumnMajorInterleaved32: - return cutlass.RowMajorInterleaved32 - elif layout == cutlass.RowMajorInterleaved32: - return cutlass.ColumnMajorInterleaved32 - - -def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: cutlass.layout, batch_offset: int = 0): - ptr = tensor.__array_interface__['data'][0] + if layout == cutlass_bindings.RowMajor: + return cutlass_bindings.ColumnMajor + elif layout == cutlass_bindings.ColumnMajor: + return cutlass_bindings.RowMajor + elif layout == cutlass_bindings.ColumnMajorInterleaved32: + return cutlass_bindings.RowMajorInterleaved32 + elif layout == cutlass_bindings.RowMajorInterleaved32: + return cutlass_bindings.ColumnMajorInterleaved32 + + +def getTensorRef( + tensor: np.ndarray, + problem_size: cutlass_bindings.gemm.GemmCoord, + operand: str, + layout: cutlass_bindings.layout, + batch_offset: int = 0, +): + ptr = tensor.__array_interface__["data"][0] if operand == "a": tensor_coord = problem_size.mk() batch_stride = problem_size.m() * problem_size.k() @@ -67,20 +86,20 @@ def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, opera else: raise ValueError("Unknown operand: " + operand) - elt_size = DataTypeSizeBytes[datatypes.to_cutlass(tensor.dtype)] + elt_size = DataTypeSizeBytes[to_cutlass(tensor.dtype)] ptr += batch_offset * batch_stride * elt_size - if layout == cutlass.RowMajor: - layout = cutlass.RowMajor.packed(tensor_coord) + if layout == cutlass_bindings.RowMajor: + layout = cutlass_bindings.RowMajor.packed(tensor_coord) layout_tag = "RowMajor" - elif layout == cutlass.ColumnMajor: - layout = cutlass.ColumnMajor.packed(tensor_coord) + elif layout == cutlass_bindings.ColumnMajor: + layout = cutlass_bindings.ColumnMajor.packed(tensor_coord) layout_tag = "ColumnMajor" - elif layout == cutlass.ColumnMajorInterleaved32: - layout = cutlass.ColumnMajorInterleaved32.packed(tensor_coord) + elif layout == cutlass_bindings.ColumnMajorInterleaved32: + layout = cutlass_bindings.ColumnMajorInterleaved32.packed(tensor_coord) layout_tag = "ColumnMajorInterleaved32" - elif layout == cutlass.RowMajorInterleaved32: - layout = cutlass.RowMajorInterleaved32.packed(tensor_coord) + elif layout == cutlass_bindings.RowMajorInterleaved32: + layout = cutlass_bindings.RowMajorInterleaved32.packed(tensor_coord) layout_tag = "RowMajorInterleaved32" else: raise ValueError("unsupported layout") @@ -97,13 +116,18 @@ def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, opera elif tensor.dtype == np.int32: ref_name = "TensorRefS32" + layout_tag else: - raise ValueError("unsupported datatype %s" % - ShortDataTypeNames[tensor.dtype]) + raise ValueError("unsupported datatype %s" % ShortDataTypeNames[tensor.dtype]) - return getattr(cutlass, ref_name)(ptr, layout) + return getattr(cutlass_bindings, ref_name)(ptr, layout) -def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: str, batch_offset: int = 0): +def getTensorView( + tensor: np.ndarray, + problem_size: cutlass_bindings.gemm.GemmCoord, + operand: str, + layout: str, + batch_offset: int = 0, +): tensor_ref = getTensorRef(tensor, problem_size, operand, layout, batch_offset) if operand == "a": @@ -115,13 +139,13 @@ def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, oper else: raise ValueError("Unknown operand: " + operand) - if layout == cutlass.RowMajor: + if layout == cutlass_bindings.RowMajor: layout_tag = "RowMajor" - elif layout == cutlass.ColumnMajor: + elif layout == cutlass_bindings.ColumnMajor: layout_tag = "ColumnMajor" - elif layout == cutlass.ColumnMajorInterleaved32: + elif layout == cutlass_bindings.ColumnMajorInterleaved32: layout_tag = "ColumnMajorInterleaved32" - elif layout == cutlass.RowMajorInterleaved32: + elif layout == cutlass_bindings.RowMajorInterleaved32: layout_tag = "RowMajorInterleaved32" else: raise ValueError("unsupported layout") @@ -140,18 +164,29 @@ def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, oper else: raise ValueError("unsupported datatype") - return getattr(cutlass, ref_name)(tensor_ref, tensor_coord) + return getattr(cutlass_bindings, ref_name)(tensor_ref, tensor_coord) class GemmUniversalLauncher: - def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interleaved=False, - verification=True, profiling=False, warmup_iterations=500, iterations=500, **kwargs) -> None: + def __init__( + self, + operation: "GemmOperationUniversal", + seed: int = 2080, + interleaved=False, + verification=True, + profiling=False, + warmup_iterations=500, + iterations=500, + **kwargs, + ) -> None: # create the reduction kernel self.reduction_operation: ReductionOperation = ReductionOperation( - shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment), - C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator, - element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor, - count=operation.C.alignment + shape=cutlass_bindings.MatrixCoord(4, 32 * operation.C.alignment), + C=operation.C, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_compute=operation.epilogue_functor.element_epilogue, + epilogue_functor=operation.epilogue_functor, + count=operation.C.alignment, ) self.math_operation = operation.tile_description.math_instruction.math_operation @@ -180,7 +215,7 @@ def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interl # Split K via Python is currently only supported for pre-SM90 kernels op_list.append(self.reduction_operation) - pycutlass.compiler.add_module(op_list) + compiler.add_module(op_list) self.operation = operation @@ -189,7 +224,9 @@ def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interl self.dtype_C = GemmUniversalLauncher.numpy_type(operation.C.element) self.dtype_D = GemmUniversalLauncher.numpy_type(operation.C.element) - accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator] + accumulator_size = DataTypeSize[ + operation.tile_description.math_instruction.element_accumulator + ] element_size = DataTypeSize[operation.A.element] if element_size == 1: @@ -213,32 +250,39 @@ def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interl #: compute type self.compute_type = operation.epilogue_functor.element_epilogue - self.accumulator_type = operation.tile_description.math_instruction.element_accumulator + self.accumulator_type = ( + operation.tile_description.math_instruction.element_accumulator + ) def print_problem_size(self, p, mode, batch_count): - if mode == cutlass.gemm.Mode.Gemm: + if mode == cutlass_bindings.gemm.Mode.Gemm: mode = "Gemm" - elif mode == cutlass.gemm.Mode.Batched: + elif mode == cutlass_bindings.gemm.Mode.Batched: mode = "GemmBatched" - elif mode == cutlass.gemm.Mode.GemmSplitKParallel: + elif mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: mode = "GemmSplitKParallel" problem_size = "problem: %d, %d, %d\n batch_count: %d\n mode: %s" % ( - p.m(), p.n(), p.k(), batch_count, mode) + p.m(), + p.n(), + p.k(), + batch_count, + mode, + ) print(problem_size) @staticmethod def numpy_type(type): - if type == cutlass.float64: + if type == cutlass_bindings.float64: return np.float64 - elif type == cutlass.float32: + elif type == cutlass_bindings.float32: return np.float32 - elif type == cutlass.float16: + elif type == cutlass_bindings.float16: return np.float16 - elif type == cutlass.bfloat16: + elif type == cutlass_bindings.bfloat16: return bfloat16 - elif type == cutlass.int32: + elif type == cutlass_bindings.int32: return np.int32 - elif type == cutlass.int8: + elif type == cutlass_bindings.int8: return np.int8 else: raise ValueError("unsupported type: %s" % ShortDataTypeNames[type]) @@ -247,22 +291,25 @@ def uniform_init(self, size, dtype): if dtype in [np.float32, np.float16, bfloat16, np.float64]: return np.ceil( np.random.uniform( - low=self.scope_min - 0.5, high=self.scope_max - 0.5, - size=size).astype(dtype) + low=self.scope_min - 0.5, high=self.scope_max - 0.5, size=size + ).astype(dtype) ) else: return np.random.uniform( - low=self.scope_min - 1, high=self.scope_max + 1, - size=size).astype(dtype) + low=self.scope_min - 1, high=self.scope_max + 1, size=size + ).astype(dtype) def reorder_tensor_B(self, tensor_B, problem_size): reordered_tensor_B = np.empty_like(tensor_B) tensor_ref_B = getTensorRef( - tensor_B, problem_size, "b", self.operation.B.layout) + tensor_B, problem_size, "b", self.operation.B.layout + ) reordered_tensor_ref_B = getTensorRef( - reordered_tensor_B, problem_size, "b", self.operation.B.layout) - cutlass.gemm.host.reorder_column( - tensor_ref_B, reordered_tensor_ref_B, problem_size) + reordered_tensor_B, problem_size, "b", self.operation.B.layout + ) + cutlass_bindings.gemm.host.reorder_column( + tensor_ref_B, reordered_tensor_ref_B, problem_size + ) return reordered_tensor_B def host_reference(self, problem_size, batch_count, tensor_A, tensor_B, tensor_C, alpha, beta): @@ -278,40 +325,88 @@ def host_reference(self, problem_size, batch_count, tensor_A, tensor_B, tensor_C for i in range(batch_count): if self.operation.switched: tensor_ref_A = getTensorRef( - tensor_A, problem_size, "a", transpose(self.operation.B.layout), batch_offset=i) + tensor_A, + problem_size, + "a", + transpose(self.operation.B.layout), + batch_offset=i, + ) tensor_ref_B = getTensorRef( - tensor_B, problem_size, "b", transpose(self.operation.A.layout), batch_offset=i) + tensor_B, + problem_size, + "b", + transpose(self.operation.A.layout), + batch_offset=i, + ) tensor_ref_C = getTensorRef( - tensor_C, problem_size, "c", transpose(self.operation.C.layout), batch_offset=i) + tensor_C, + problem_size, + "c", + transpose(self.operation.C.layout), + batch_offset=i, + ) tensor_ref_D_ref = getTensorRef( - tensor_D_ref, problem_size, "d", transpose(self.operation.C.layout), batch_offset=i) + tensor_D_ref, + problem_size, + "d", + transpose(self.operation.C.layout), + batch_offset=i, + ) else: tensor_ref_A = getTensorRef( - tensor_A, problem_size, "a", self.operation.A.layout, batch_offset=i) + tensor_A, problem_size, "a", self.operation.A.layout, batch_offset=i + ) tensor_ref_B = getTensorRef( - tensor_B, problem_size, "b", self.operation.B.layout, batch_offset=i) + tensor_B, problem_size, "b", self.operation.B.layout, batch_offset=i + ) tensor_ref_C = getTensorRef( - tensor_C, problem_size, "c", self.operation.C.layout, batch_offset=i) + tensor_C, problem_size, "c", self.operation.C.layout, batch_offset=i + ) tensor_ref_D_ref = getTensorRef( - tensor_D_ref, problem_size, "d", self.operation.C.layout, batch_offset=i) + tensor_D_ref, + problem_size, + "d", + self.operation.C.layout, + batch_offset=i, + ) if self.math_operation in [MathOperation.multiply_add_saturate]: - cutlass.test.gemm.host.gemm_saturate( - problem_size, alpha, tensor_ref_A, tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) + cutlass_bindings.test.gemm.host.gemm_saturate( + problem_size, + alpha, + tensor_ref_A, + tensor_ref_B, + beta, + tensor_ref_C, + tensor_ref_D_ref, + init_acc, + ) else: - cutlass.test.gemm.host.gemm(problem_size, alpha, tensor_ref_A, - tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) + cutlass_bindings.test.gemm.host.gemm( + problem_size, + alpha, + tensor_ref_A, + tensor_ref_B, + beta, + tensor_ref_C, + tensor_ref_D_ref, + init_acc, + ) return tensor_D_ref def equal(self, tensor_D, tensor_D_ref, problem_size, batch_count): for i in range(batch_count): tensor_view_D = getTensorView( - tensor_D, problem_size, "d", self.operation.C.layout, batch_offset=i) + tensor_D, problem_size, "d", self.operation.C.layout, batch_offset=i + ) tensor_view_D_ref = getTensorView( - tensor_D_ref, problem_size, "d", self.operation.C.layout, batch_offset=i) + tensor_D_ref, problem_size, "d", self.operation.C.layout, batch_offset=i + ) - if not cutlass.test.gemm.host.equals(tensor_view_D, tensor_view_D_ref): + if not cutlass_bindings.test.gemm.host.equals( + tensor_view_D, tensor_view_D_ref + ): return False return True @@ -321,10 +416,11 @@ def bytes(self, problem_size, batch_count=1, alpha=1.0, beta=0.0): n = problem_size.n() k = problem_size.k() - bytes = \ - (DataTypeSize[self.operation.A.element] * m // 8) * k + \ - (DataTypeSize[self.operation.B.element] * n // 8) * k + \ - (DataTypeSize[self.operation.C.element] * m // 8) * n + bytes = ( + (DataTypeSize[self.operation.A.element] * m // 8) * k + + (DataTypeSize[self.operation.B.element] * n // 8) * k + + (DataTypeSize[self.operation.C.element] * m // 8) * n + ) if beta != 0: bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n @@ -342,10 +438,13 @@ def flops(self, problem_size, batch_count=1): return flops_ - def run_cutlass_profiler(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0): - - cutlass_path = os.getenv('CUTLASS_PATH') - assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." + def run_cutlass_profiler( + self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0 + ): + cutlass_path = os.getenv("CUTLASS_PATH") + assert ( + cutlass_path is not None + ), "Environment variable 'CUTLASS_PATH' is not defined." values = { "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler", @@ -355,28 +454,29 @@ def run_cutlass_profiler(self, mode, problem_size, batch_count=1, alpha=1.0, bet "m": str(problem_size.m()), "n": str(problem_size.n()), "k": str(problem_size.k()), - 'split_k_slices': str(batch_count), - 'alpha': str(alpha), - 'beta': str(beta), - 'warmup': str(self.warmup_iterations), - 'profile': str(self.iterations) + "split_k_slices": str(batch_count), + "alpha": str(alpha), + "beta": str(beta), + "warmup": str(self.warmup_iterations), + "profile": str(self.iterations), } - cmd_template = \ - "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" \ + cmd_template = ( + "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" " --providers=${provider} --m=${m} --n=${n} --k=${k}" + ) cmd = SubstituteTemplate(cmd_template, values) result = subprocess.getoutput(cmd) m = re.search(r"Runtime:\s+(?P\d+.\d+)", result) - runtime = float(m.group('runtime')) + runtime = float(m.group("runtime")) m = re.search(r"Bytes:\s+(?P\d+)", result) - bytes = int(m.group('bytes')) + bytes = int(m.group("bytes")) m = re.search(r"FLOPs:\s+(?P\d+)", result) - flops = int(m.group('flops')) + flops = int(m.group("flops")) # check if the problem size matches assert bytes == self.bytes(problem_size, alpha, beta) @@ -385,61 +485,86 @@ def run_cutlass_profiler(self, mode, problem_size, batch_count=1, alpha=1.0, bet return runtime def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0): - assert get_allocated_size( - ) == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size() + assert get_allocated_size() == 0, ( + "%d byte of pool memory is not released in previous run" + % get_allocated_size() + ) np.random.seed(self.seed) # Assign an actual batch count in cases where we are not running in batched mode. # This is to differentiate between the number of split K slices and the batch count, # which are overloaded within the single `batch_count` variable. - true_batch_count = batch_count if mode == cutlass.gemm.Mode.Batched else 1 + true_batch_count = ( + batch_count if mode == cutlass_bindings.gemm.Mode.Batched else 1 + ) tensor_A = self.uniform_init( - size=(problem_size.m() * problem_size.k() * true_batch_count,), dtype=self.dtype_A) + size=(problem_size.m() * problem_size.k() * true_batch_count,), + dtype=self.dtype_A, + ) tensor_B = self.uniform_init( - size=(problem_size.n() * problem_size.k() * true_batch_count,), dtype=self.dtype_B) + size=(problem_size.n() * problem_size.k() * true_batch_count,), + dtype=self.dtype_B, + ) tensor_C = self.uniform_init( - size=(problem_size.m() * problem_size.n() * true_batch_count,), dtype=self.dtype_C) + size=(problem_size.m() * problem_size.n() * true_batch_count,), + dtype=self.dtype_C, + ) tensor_D = np.zeros( - shape=(problem_size.m() * problem_size.n() * true_batch_count,), dtype=self.dtype_D) + shape=(problem_size.m() * problem_size.n() * true_batch_count,), + dtype=self.dtype_D, + ) # # Launch kernel # arguments = GemmArguments( - operation=self.operation, problem_size=problem_size, - A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + operation=self.operation, + problem_size=problem_size, + A=tensor_A, + B=tensor_B, + C=tensor_C, + D=tensor_D, output_op=self.operation.epilogue_type(alpha, beta), - gemm_mode=mode, split_k_slices=split_k_slices, batch=batch_count + gemm_mode=mode, + split_k_slices=split_k_slices, + batch=batch_count, ) - if mode == cutlass.gemm.Mode.GemmSplitKParallel: + if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: reduction_arguments = ReductionArguments( - self.reduction_operation, problem_size=[ - problem_size.m(), problem_size.n()], + self.reduction_operation, + problem_size=[problem_size.m(), problem_size.n()], partitions=split_k_slices, workspace=arguments.ptr_D, destination=tensor_D, source=tensor_C, - output_op=self.reduction_operation.epilogue_type(alpha, beta) + output_op=self.reduction_operation.epilogue_type(alpha, beta), ) self.operation.run(arguments) - if mode == cutlass.gemm.Mode.GemmSplitKParallel: + if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: self.reduction_operation.run(reduction_arguments) passed = True if self.verification: - if mode == cutlass.gemm.Mode.GemmSplitKParallel: + if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: reduction_arguments.sync() else: arguments.sync() tensor_D_ref = self.host_reference( - problem_size, true_batch_count, tensor_A, tensor_B, tensor_C, alpha, beta) + problem_size, + true_batch_count, + tensor_A, + tensor_B, + tensor_C, + alpha, + beta, + ) passed = self.equal(tensor_D, tensor_D_ref, problem_size, true_batch_count) try: @@ -451,13 +576,13 @@ def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, be sleep(self.sleep_time) for _ in range(self.warmup_iterations): self.operation.run(arguments) - if mode == cutlass.gemm.Mode.GemmSplitKParallel: + if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: self.reduction_operation.run(reduction_arguments) self.timer.start() for _ in range(self.iterations): self.operation.run(arguments) - if mode == cutlass.gemm.Mode.GemmSplitKParallel: + if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: self.reduction_operation.run(reduction_arguments) self.timer.stop_and_wait() @@ -465,45 +590,56 @@ def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, be # free memory and clear buffers del arguments - if mode == cutlass.gemm.Mode.GemmSplitKParallel: + if mode == cutlass_bindings.gemm.Mode.GemmSplitKParallel: del reduction_arguments - assert get_allocated_size( - ) == 0, "%d byte of pool memory is not released after current run" % get_allocated_size() + assert get_allocated_size() == 0, ( + "%d byte of pool memory is not released after current run" + % get_allocated_size() + ) if self.profiling: return runtime return passed -def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"): - +def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal"): passed = True minimum_operand_element_size = min( - DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]) + DataTypeSize[operation.A.element], DataTypeSize[operation.B.element] + ) opcode_class = operation.tile_description.math_instruction.opcode_class - if opcode_class == cutlass.OpClass.Simt: + if opcode_class == cutlass_bindings.OpClass.Simt: alignment = 1 else: alignment = 128 // minimum_operand_element_size # int8_t gemm alignment constraints - if opcode_class == cutlass.OpClass.Simt and operation.A.element == cutlass.int8 and operation.A.layout == cutlass.ColumnMajor: + if opcode_class == cutlass_bindings.OpClass.Simt and operation.A.element == cutlass_bindings.int8 and operation.A.layout == cutlass_bindings.ColumnMajor: alignment_m = 4 else: alignment_m = alignment - if opcode_class == cutlass.OpClass.Simt and operation.B.element == cutlass.int8 and operation.A.layout == cutlass.RowMajor: + if ( + opcode_class == cutlass_bindings.OpClass.Simt + and operation.B.element == cutlass_bindings.int8 + and operation.A.layout == cutlass_bindings.RowMajor + ): alignment_n = 4 else: alignment_n = alignment - if opcode_class == cutlass.OpClass.Simt and operation.A.element == cutlass.int8 \ - and operation.B.element == cutlass.int8 \ - and (operation.A.layout == cutlass.RowMajor or operation.B.layout == cutlass.ColumnMajor): - + if ( + opcode_class == cutlass_bindings.OpClass.Simt + and operation.A.element == cutlass_bindings.int8 + and operation.B.element == cutlass_bindings.int8 + and ( + operation.A.layout == cutlass_bindings.RowMajor + or operation.B.layout == cutlass_bindings.ColumnMajor + ) + ): alignment_k = 4 else: alignment_k = alignment @@ -511,35 +647,55 @@ def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"): threadblock_k = operation.tile_description.threadblock_shape[2] if testcase == "interleaved": - if operation.A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]: + if operation.A.layout in [ + cutlass_bindings.ColumnMajorInterleaved32, + cutlass_bindings.RowMajorInterleaved32, + ]: interleavedk = 32 else: raise ValueError("Unknown layout") + # Split K mode via Python is currently only supported pre-SM90, and when stream K is not used. + # Stream K enables split-k functionality with mode `Gemm` and a non-unit batch count. + supports_split_k = operation.arch < 90 and not isinstance( + operation.swizzling_functor, cutlass_bindings.ThreadblockSwizzleStreamK + ) if testcase == "interleaved": - modes = [cutlass.gemm.Mode.Gemm, ] - problem_size_m = [interleavedk, 512+interleavedk] - problem_size_n = [interleavedk, 512+interleavedk] - problem_size_k = [interleavedk, threadblock_k * - operation.tile_description.stages + interleavedk] + modes = [ + cutlass_bindings.gemm.Mode.Gemm, + ] + problem_size_m = [interleavedk, 512 + interleavedk] + problem_size_n = [interleavedk, 512 + interleavedk] + problem_size_k = [ + interleavedk, + threadblock_k * operation.tile_description.stages + interleavedk, + ] problem_alpha = [1.0] problem_beta = [0.0] - batch_counts = [1, ] + batch_counts = [ + 1, + ] elif testcase == "multistage": - modes = [cutlass.gemm.Mode.Gemm, ] + modes = [ + cutlass_bindings.gemm.Mode.Gemm, + ] problem_size_m = [16, 528] problem_size_n = [16, 528] - problem_size_k = [threadblock_k, threadblock_k * operation.tile_description.stages + - operation.tile_description.math_instruction.instruction_shape[2]] + problem_size_k = [ + threadblock_k, + threadblock_k * operation.tile_description.stages + + operation.tile_description.math_instruction.instruction_shape[2], + ] problem_alpha = [1.0] problem_beta = [0.0] - batch_counts = [1, ] + batch_counts = [ + 1, + ] else: # universal - modes = [cutlass.gemm.Mode.Gemm] + modes = [cutlass_bindings.gemm.Mode.Gemm] batch_counts = [1, 2, 3, 5, 7] - if operation.arch < 90: - # Split K kernels via Python are currently only supported pre-SM90 - modes.append(cutlass.gemm.Mode.GemmSplitKParallel) + if supports_split_k: + modes.append(cutlass_bindings.gemm.Mode.GemmSplitKParallel) problem_size_m = [alignment_m, 512 - 3 * alignment_m] problem_size_n = [alignment_n, 512 - 2 * alignment_n] @@ -550,12 +706,12 @@ def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"): problem_size_k = [ alignment_k, threadblock_k * stages_for_k_calc - alignment_k, - threadblock_k * stages_for_k_calc * 3 - alignment_k] + threadblock_k * stages_for_k_calc * 3 - alignment_k, + ] problem_alpha = [1.0] problem_beta = [2.0] - testbed = GemmUniversalLauncher( - operation, interleaved=(testcase == "interleaved")) + testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved")) for mode in modes: for m in problem_size_m: @@ -566,27 +722,35 @@ def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"): for beta in problem_beta: # skip very small K problems if testcase == "universal": - if (k // batch_count < 2 * threadblock_k): + if k // batch_count < 2 * threadblock_k: continue - problem_size = cutlass.gemm.GemmCoord(m, n, k) + problem_size = cutlass_bindings.gemm.GemmCoord(m, n, k) - if operation.arch < 90: + if supports_split_k: split_k_slices = batch_count else: split_k_slices = 1 overridden_mode = mode - if mode == cutlass.gemm.Mode.Gemm and batch_count > 1: - overridden_mode = cutlass.gemm.Mode.Batched + if ( + mode == cutlass_bindings.gemm.Mode.Gemm + and batch_count > 1 + ): + overridden_mode = cutlass_bindings.gemm.Mode.Batched passed = testbed.run( - overridden_mode, problem_size, batch_count, split_k_slices, alpha, beta) - - err, = cudart.cudaDeviceSynchronize() + overridden_mode, + problem_size, + batch_count, + split_k_slices, + alpha, + beta, + ) + + (err,) = cudart.cudaDeviceSynchronize() if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError( - "CUDA Error %s" % str(err)) + raise RuntimeError("CUDA Error %s" % str(err)) if not passed: return False diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py b/python/cutlass/backend/test/profiler.py similarity index 91% rename from tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py rename to python/cutlass/backend/test/profiler.py index 66a4e960..31d14b3d 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +++ b/python/cutlass/backend/test/profiler.py @@ -30,24 +30,23 @@ # ################################################################################################# -from cuda import cuda -from cuda import cudart +from cuda import cuda, cudart class GpuTimer: def __init__(self) -> None: self.events = [ cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], - cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1] + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], ] def start(self, stream=cuda.CUstream(0)): - err, = cuda.cuEventRecord(self.events[0], stream) + (err,) = cuda.cuEventRecord(self.events[0], stream) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) def stop(self, stream=cuda.CUstream(0)): - err, = cuda.cuEventRecord(self.events[1], stream) + (err,) = cuda.cuEventRecord(self.events[1], stream) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) pass @@ -55,11 +54,11 @@ def stop(self, stream=cuda.CUstream(0)): def stop_and_wait(self, stream=cuda.CUstream(0)): self.stop(stream) if stream: - err, = cuda.cuStreamSynchronize(stream) + (err,) = cuda.cuStreamSynchronize(stream) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) else: - err, = cudart.cudaDeviceSynchronize() + (err,) = cudart.cudaDeviceSynchronize() if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/utils.py b/python/cutlass/backend/test/utils.py similarity index 64% rename from tools/library/scripts/pycutlass/src/pycutlass/test/utils.py rename to python/cutlass/backend/test/utils.py index f1a25f92..1489a4aa 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/utils.py +++ b/python/cutlass/backend/test/utils.py @@ -30,22 +30,27 @@ # ################################################################################################# -import cutlass -from pycutlass import library, SubstituteTemplate +import cutlass_bindings + +from cutlass import KernelScheduleSuffixes +from cutlass.backend import library +from cutlass.backend.utils.software import SubstituteTemplate class Layout: """ Utility class to map transpose and non-transpose terminology to row- and column-major terminology """ - T = cutlass.RowMajor - N = cutlass.ColumnMajor + + T = cutlass_bindings.RowMajor + N = cutlass_bindings.ColumnMajor class LayoutCombination: """ Utility class defining all combinations of row- and column-major layouts for operands to a GEMMs """ + NNN = (Layout.N, Layout.N, Layout.N) NNT = (Layout.N, Layout.N, Layout.T) NTN = (Layout.N, Layout.T, Layout.N) @@ -56,9 +61,22 @@ class LayoutCombination: TTT = (Layout.T, Layout.T, Layout.T) -def get_name(layouts, alignments, element_output, - element_accumulator, element_epilogue, cluster_shape, - threadblock_shape, stages, element_a, element_b, arch, opclass, suffix=""): +def get_name( + layouts, + alignments, + element_output, + element_accumulator, + element_epilogue, + cluster_shape, + threadblock_shape, + stages, + element_a, + element_b, + arch, + opclass, + kernel_schedule=None, + suffix="", +): """ Generates a procedural name for a test case. @@ -76,34 +94,38 @@ def get_name(layouts, alignments, element_output, :param arch: compute capability of kernel being generated :type arch: int :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass + :type opclass: cutlass_bindings.OpClass + :param kernel_schedule: kernel_schedule type + :type kernel_schedule: cutlass.KernelScheduleType :param suffix: additional string to add to the suffix of the name :type suffix: str :return: str """ - name_format = 'test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${suffix}' - return SubstituteTemplate(name_format, + name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${suffix}" + return SubstituteTemplate( + name_format, { - 'arch': str(arch), - 'eA': library.DataTypeNames[element_a], - 'eB': library.DataTypeNames[element_b], - 'eC': library.DataTypeNames[element_output], - 'lA': library.ShortLayoutTypeNames[layouts[0]], - 'lB': library.ShortLayoutTypeNames[layouts[1]], - 'lC': library.ShortLayoutTypeNames[layouts[2]], - 'opclass': library.OpcodeClassNames[opclass], - 'acc': library.DataTypeNames[element_accumulator], - 'cM': str(cluster_shape[0]), - 'cN': str(cluster_shape[1]), - 'cK': str(cluster_shape[2]), - 'tbM': str(threadblock_shape[0]), - 'tbN': str(threadblock_shape[1]), - 'tbK': str(threadblock_shape[2]), - 'stages': str(stages) if stages is not None else 'auto', - 'aA' : str(alignments[0]), - 'aB' : str(alignments[1]), - 'aC' : str(alignments[2]), - 'suffix': '' if suffix is None else suffix - } + "arch": str(arch), + "eA": library.DataTypeNames[element_a], + "eB": library.DataTypeNames[element_b], + "eC": library.DataTypeNames[element_output], + "lA": library.ShortLayoutTypeNames[layouts[0]], + "lB": library.ShortLayoutTypeNames[layouts[1]], + "lC": library.ShortLayoutTypeNames[layouts[2]], + "opclass": library.OpcodeClassNames[opclass], + "acc": library.DataTypeNames[element_accumulator], + "cM": str(cluster_shape[0]), + "cN": str(cluster_shape[1]), + "cK": str(cluster_shape[2]), + "tbM": str(threadblock_shape[0]), + "tbN": str(threadblock_shape[1]), + "tbK": str(threadblock_shape[2]), + "stages": str(stages) if stages is not None else "auto", + "aA": str(alignments[0]), + "aB": str(alignments[1]), + "aC": str(alignments[2]), + "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule], + "suffix": "" if suffix is None else suffix, + }, ) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py b/python/cutlass/backend/type_hint.py similarity index 89% rename from tools/library/scripts/pycutlass/src/pycutlass/type_hint.py rename to python/cutlass/backend/type_hint.py index d767a042..d1e8ba91 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +++ b/python/cutlass/backend/type_hint.py @@ -30,10 +30,6 @@ # ################################################################################ -from typing import Union -from typeguard import typechecked +GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]" - -GemmOperation = 'Union[GemmOperationUniversal, GemmOperationGrouped]' - -Tensor = 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]' +Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]" diff --git a/python/cutlass/backend/utils/__init__.py b/python/cutlass/backend/utils/__init__.py new file mode 100644 index 00000000..3d71d4da --- /dev/null +++ b/python/cutlass/backend/utils/__init__.py @@ -0,0 +1,41 @@ +################################################################################ +# +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +from cutlass.backend.utils.datatypes import * +from cutlass.backend.utils.device import check_cuda_errors, device_cc +from cutlass.backend.utils.reference_model import ReferenceModule +from cutlass.backend.utils.software import ( + CheckPackages, + SubstituteTemplate, + device_sm_count, + get_memory_pool, +) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/utils/datatypes.py b/python/cutlass/backend/utils/datatypes.py similarity index 58% rename from tools/library/scripts/pycutlass/src/pycutlass/utils/datatypes.py rename to python/cutlass/backend/utils/datatypes.py index f4cc56ba..834a6071 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/utils/datatypes.py +++ b/python/cutlass/backend/utils/datatypes.py @@ -34,88 +34,96 @@ Utility functions for converting between frontend datatypes and CUTLASS datatypes """ -from typing import Union, Tuple +import cutlass_bindings -import cutlass +from cutlass.backend.utils.software import CheckPackages -import pycutlass.library as library +numpy_available = CheckPackages().check_numpy() +if numpy_available: + import numpy as np + numpy_to_cutlass_dict = { + np.float16: cutlass_bindings.float16, + np.float32: cutlass_bindings.float32, + np.float64: cutlass_bindings.float64, + np.int8: cutlass_bindings.int8, + np.int32: cutlass_bindings.int32, + np.dtype('float16'): cutlass_bindings.float16, + np.dtype('float32'): cutlass_bindings.float32, + np.dtype('float64'): cutlass_bindings.float64, + np.dtype('int8'): cutlass_bindings.int8, + np.dtype('int32'): cutlass_bindings.int32, + } -try: - import numpy as np - numpy_available = True -except ImportError: - numpy_available = False def numpy_to_cutlass(inp): + numpy_available = CheckPackages().check_numpy() if numpy_available: - if inp == np.float16: - return cutlass.float16 - elif inp == np.float32: - return cutlass.float32 - elif inp == np.float64: - return cutlass.float64 - elif inp == np.int8: - return cutlass.int8 - elif inp == np.int32: - return cutlass.int32 - return None + return numpy_to_cutlass_dict.get(inp, None) -try: + +cupy_available = CheckPackages().check_cupy() +if cupy_available: import cupy as cp - cupy_available = True + cupy_to_cutlass_dict = { - cp.float16: cutlass.float16, - cp.float32: cutlass.float32, - cp.float64: cutlass.float64 + cp.float16: cutlass_bindings.float16, + cp.float32: cutlass_bindings.float32, + cp.float64: cutlass_bindings.float64, } -except ImportError: - cupy_available = False + def cupy_to_cutlass(inp): + cupy_available = CheckPackages().check_cupy() if cupy_available: - if inp == cp.float16: - return cutlass.float16 - elif inp == cp.float32: - return cutlass.float32 - elif inp == cp.float64: - return cutlass.float64 - return None + return cupy_to_cutlass_dict.get(inp, None) -try: + +torch_available = CheckPackages().check_torch() +if torch_available: import torch - torch_available = True + torch_to_cutlass_dict = { - torch.half: cutlass.float16, - torch.float16: cutlass.float16, - torch.float: cutlass.float32, - torch.float32: cutlass.float32, - torch.double: cutlass.float64, - torch.float64: cutlass.float64 + torch.half: cutlass_bindings.float16, + torch.float16: cutlass_bindings.float16, + torch.float: cutlass_bindings.float32, + torch.float32: cutlass_bindings.float32, + torch.double: cutlass_bindings.float64, + torch.float64: cutlass_bindings.float64, } -except ImportError: - torch_available = False + def torch_to_cutlass(inp): if torch_available: return torch_to_cutlass_dict.get(inp, None) + try: import bfloat16 + bfloat16_available = True + numpy_to_cutlass_dict[np.dtype(bfloat16.bfloat16)] = cutlass_bindings.bfloat16 except ImportError: bfloat16_available = False + def bfloat16_to_cutlass(inp): if bfloat16_available: if inp == bfloat16.bfloat16: - return cutlass.bfloat16 + return cutlass_bindings.bfloat16 def to_cutlass(inp): - for cvt_fn in [bfloat16_to_cutlass, cupy_to_cutlass, numpy_to_cutlass, torch_to_cutlass]: + for cvt_fn in [ + bfloat16_to_cutlass, + cupy_to_cutlass, + numpy_to_cutlass, + torch_to_cutlass, + ]: out = cvt_fn(inp) if out is not None: return out - raise Exception('No available conversion from type {} to a CUTLASS type.'.format(inp)) + raise Exception( + "No available conversion from type {} to a CUTLASS type.".format(inp) + ) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py b/python/cutlass/backend/utils/device.py similarity index 100% rename from tools/library/scripts/pycutlass/src/pycutlass/utils/device.py rename to python/cutlass/backend/utils/device.py diff --git a/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py b/python/cutlass/backend/utils/reference_model.py similarity index 64% rename from tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py rename to python/cutlass/backend/utils/reference_model.py index 041143a2..2e25dcce 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +++ b/python/cutlass/backend/utils/reference_model.py @@ -30,24 +30,39 @@ # ################################################################################################# -import numpy as np -import cutlass -from pycutlass.library import TensorDescription from typing import Union + from bfloat16 import bfloat16 -try: +import cutlass_bindings +import numpy as np + +from cutlass.backend.library import TensorDescription +from cutlass.backend.utils.software import CheckPackages + +torch_available = CheckPackages().check_torch() +if torch_available: import torch - torch_available = True -except ImportError: - torch_available = False + class ReferenceModule: - def __init__(self, A: TensorDescription, B: TensorDescription, C: TensorDescription) -> None: + def __init__( + self, A: TensorDescription, B: TensorDescription, C: TensorDescription + ) -> None: self.layout_A = A.layout self.layout_B = B.layout self.layout_C = C.layout - - def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0, bias=False, batch=1): + + def run( + self, + A: np.ndarray, + B: np.ndarray, + C: np.ndarray, + problem_size: cutlass_bindings.gemm.GemmCoord, + alpha: float = 1.0, + beta: float = 0.0, + bias=False, + batch=1, + ): """ Compute the reference result on CPU Args: @@ -57,19 +72,19 @@ def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass """ M, N, K = problem_size.m(), problem_size.n(), problem_size.k() if isinstance(A, np.ndarray): - if self.layout_A == cutlass.RowMajor: + if self.layout_A == cutlass_bindings.RowMajor: A_row = np.reshape(A, newshape=(batch, M, K)) else: A_col = np.reshape(A, newshape=(batch, K, M)) A_row = np.transpose(A_col, axes=(0, 2, 1)) - - if self.layout_B == cutlass.RowMajor: + + if self.layout_B == cutlass_bindings.RowMajor: B_row = np.reshape(B, newshape=(batch, K, N)) else: B_col = np.reshape(B, newshape=(batch, N, K)) B_row = np.transpose(B_col, axes=(0, 2, 1)) - if self.layout_C == cutlass.RowMajor: + if self.layout_C == cutlass_bindings.RowMajor: if bias: C_row = np.reshape(C, newshape=(batch, 1, N)) else: @@ -80,49 +95,56 @@ def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass else: C_col = np.reshape(C, newshape=(batch, N, M)) C_row = np.transpose(C_col, axes=(0, 2, 1)) - + if A_row.dtype == bfloat16: # numpy's einsum doesn't support bfloat16 - out_row = np.einsum("bik,bkj->bij", A_row.astype(np.float32), B_row.astype(np.float32)) * alpha + C_row * beta + out_row = ( + np.einsum( + "bik,bkj->bij", + A_row.astype(np.float32), + B_row.astype(np.float32), + ) + * alpha + + C_row * beta + ) out_row = out_row.astype(C_row.dtype) else: out_row = np.einsum("bik,bkj->bij", A_row, B_row) * alpha + C_row * beta - if self.layout_C == cutlass.ColumnMajor: + if self.layout_C == cutlass_bindings.ColumnMajor: out = np.transpose(out_row, axes=(0, 2, 1)) else: out = out_row - + return out.ravel() elif isinstance(A, torch.Tensor): - if self.layout_A == cutlass.RowMajor: + if self.layout_A == cutlass_bindings.RowMajor: A_row = A.view((M, K)) else: A_col = A.view((K, M)) A_row = torch.permute(A_col, (1, 0)) - - if self.layout_B == cutlass.RowMajor: + + if self.layout_B == cutlass_bindings.RowMajor: B_row = B.view((K, N)) else: B_col = B.view((N, K)) B_row = torch.permute(B_col, (1, 0)) - if self.layout_C == cutlass.RowMajor: + if self.layout_C == cutlass_bindings.RowMajor: C_row = C.view((M, N)) else: C_col = C.view((N, M)) C_row = torch.permute(C_col, (1, 0)) - + out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta - if self.layout_C == cutlass.ColumnMajor: + if self.layout_C == cutlass_bindings.ColumnMajor: out = torch.permute(out_row, (1, 0)) else: out = out_row - - return torch.flatten(out) + return torch.flatten(out) ##################################################################################################### @@ -130,17 +152,31 @@ def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass ##################################################################################################### if torch_available: + import torch + class Conv2dReferenceModule: - def __init__(self, A: TensorDescription, B: TensorDescription, C: TensorDescription, kind: cutlass.conv.Operator.fprop) -> None: + def __init__( + self, + A: TensorDescription, + B: TensorDescription, + C: TensorDescription, + kind: cutlass_bindings.conv.Operator.fprop, + ) -> None: self.layout_A = A.layout self.layout_B = B.layout self.layout_C = C.layout self.kind = kind - - def run(self, + + def run( + self, A: Union[np.ndarray, torch.Tensor], B: Union[np.ndarray, torch.Tensor], - C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0, bias=False) -> np.ndarray: + C: Union[np.ndarray, torch.Tensor], + problem_size, + alpha=1.0, + beta=0.0, + bias=False, + ) -> np.ndarray: """ Compute the reference result on CPU """ @@ -170,86 +206,112 @@ def run(self, if isinstance(A, np.ndarray): # the pytorch activation layout is NCHW # weight layout is Cout Cin Kh Kw (also NCHW) - if self.layout_A == cutlass.TensorNHWC: + if self.layout_A == cutlass_bindings.TensorNHWC: A_nhwc = np.reshape(A, newshape=(n, h, w, c)) A_torch_nhwc = torch.from_numpy(A_nhwc).to("cuda") A_torch_nchw = torch.permute(A_torch_nhwc, (0, 3, 1, 2)) - - if self.layout_B == cutlass.TensorNHWC: + + if self.layout_B == cutlass_bindings.TensorNHWC: B_nhwc = np.reshape(B, newshape=(k, r, s, c)) B_torch_nhwc = torch.from_numpy(B_nhwc).to("cuda") B_torch_nchw = torch.permute(B_torch_nhwc, (0, 3, 1, 2)) - - if self.layout_C == cutlass.TensorNHWC: + + if self.layout_C == cutlass_bindings.TensorNHWC: C_nhwc = np.reshape(C, newshape=(n, p, q, k)) C_torch_nhwc = torch.from_numpy(C_nhwc).to("cuda") C_torch_nchw = torch.permute(C_torch_nhwc, (0, 3, 1, 2)) - + elif isinstance(A, torch.Tensor): - if self.kind == cutlass.conv.Operator.wgrad: - if self.layout_A == cutlass.TensorNHWC: + if self.kind == cutlass_bindings.conv.Operator.wgrad: + if self.layout_A == cutlass_bindings.TensorNHWC: A_nhwc = A.view((n, p, q, k)) A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2)) - - if self.layout_B == cutlass.TensorNHWC: + + if self.layout_B == cutlass_bindings.TensorNHWC: B_nhwc = B.view((n, h, w, c)) B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) - - if self.layout_C == cutlass.TensorNHWC: + + if self.layout_C == cutlass_bindings.TensorNHWC: if bias: C_nhwc = C.view((1, 1, 1, c)) else: C_nhwc = C.view((k, r, s, c)) C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) - elif self.kind == cutlass.conv.Operator.dgrad: - if self.layout_A == cutlass.TensorNHWC: + elif self.kind == cutlass_bindings.conv.Operator.dgrad: + if self.layout_A == cutlass_bindings.TensorNHWC: A_nhwc = A.view((n, p, q, k)) A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2)) - - if self.layout_B == cutlass.TensorNHWC: + + if self.layout_B == cutlass_bindings.TensorNHWC: B_nhwc = B.view((k, r, s, c)) B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) - - if self.layout_C == cutlass.TensorNHWC: + + if self.layout_C == cutlass_bindings.TensorNHWC: if bias: C_nhwc = C.view((1, 1, 1, c)) else: C_nhwc = C.view((n, h, w, c)) C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) else: - if self.layout_A == cutlass.TensorNHWC: + if self.layout_A == cutlass_bindings.TensorNHWC: A_nhwc = A.view((n, h, w, c)) A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2)) - - if self.layout_B == cutlass.TensorNHWC: + + if self.layout_B == cutlass_bindings.TensorNHWC: B_nhwc = B.view((k, r, s, c)) B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) - - if self.layout_C == cutlass.TensorNHWC: + + if self.layout_C == cutlass_bindings.TensorNHWC: if bias: C_nhwc = C.view((1, 1, 1, k)) else: C_nhwc = C.view((n, p, q, k)) C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) - if self.kind == cutlass.conv.Operator.fprop: - D_torch_nchw = alpha * torch.nn.functional.conv2d( - A_torch_nchw, B_torch_nchw, stride=(stride_h, stride_w), - padding=(pad_h, pad_w), dilation=(dilation_h, dilation_w), groups=groups) + beta * C_torch_nchw - elif self.kind == cutlass.conv.Operator.dgrad: - D_torch_nchw = alpha * torch.nn.grad.conv2d_input( - (n, c, h, w), B_torch_nchw, A_torch_nchw, padding=(pad_h, pad_w), stride=(stride_h, stride_w) - ).to(torch.float32) + beta * C_torch_nchw - elif self.kind == cutlass.conv.Operator.wgrad: - D_torch_nchw = alpha * torch.nn.grad.conv2d_weight( - B_torch_nchw, (k, c, r, s), A_torch_nchw, padding=(pad_h, pad_w), stride=(stride_h, stride_w) - ).to(torch.float32) + beta * C_torch_nchw - - - if self.layout_C == cutlass.TensorNHWC: + if self.kind == cutlass_bindings.conv.Operator.fprop: + D_torch_nchw = ( + alpha + * torch.nn.functional.conv2d( + A_torch_nchw, + B_torch_nchw, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(dilation_h, dilation_w), + groups=groups, + ) + + beta * C_torch_nchw + ) + elif self.kind == cutlass_bindings.conv.Operator.dgrad: + D_torch_nchw = ( + alpha + * torch.nn.grad.conv2d_input( + (n, c, h, w), + B_torch_nchw, + A_torch_nchw, + padding=(pad_h, pad_w), + stride=(stride_h, stride_w), + ).to(torch.float32) + + beta * C_torch_nchw + ) + elif self.kind == cutlass_bindings.conv.Operator.wgrad: + D_torch_nchw = ( + alpha + * torch.nn.grad.conv2d_weight( + B_torch_nchw, + (k, c, r, s), + A_torch_nchw, + padding=(pad_h, pad_w), + stride=(stride_h, stride_w), + ).to(torch.float32) + + beta * C_torch_nchw + ) + + if self.layout_C == cutlass_bindings.TensorNHWC: if isinstance(A, np.ndarray): - D_torch_out = torch.permute(D_torch_nchw, (0, 2, 3, 1)).detach().cpu().numpy() + D_torch_out = ( + torch.permute(D_torch_nchw, (0, 2, 3, 1)).detach().cpu().numpy() + ) elif isinstance(A, torch.Tensor): D_torch_out = torch.permute(D_torch_nchw, (0, 2, 3, 1)) - + return D_torch_out.flatten() diff --git a/python/cutlass/backend/utils/software.py b/python/cutlass/backend/utils/software.py new file mode 100644 index 00000000..86bcc411 --- /dev/null +++ b/python/cutlass/backend/utils/software.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import re +import sys + +from cutlass.backend.memory_manager import PoolMemoryManager + + +class CheckPackages: + def __init__(self) -> None: + pass + + def check_cupy(self): + if "cupy" in sys.modules: + return True + else: + try: + import cupy + + cupy_available = True + except ImportError: + print("cupy is not loaded.") + + def check_numpy(self): + if "numpy" in sys.modules: + return True + else: + try: + import numpy + + numpy_available = True + except ImportError: + print("numpy is not loaded.") + + def check_torch(self): + if "torch" in sys.modules: + return True + else: + try: + import torch + + torch_available = True + except ImportError: + print("torch is not loaded.") + + +def SubstituteTemplate(template, values): + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + + +# this._device_sm_count = None +def device_sm_count(): + # Query the number of SMs, if needed + # if this._device_sm_count is None: + from cuda import cuda + + _device = 0 + err, _device_sm_count = cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, _device + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception( + "Failed to retireve SM count. " + f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}" + ) + + return _device_sm_count + + +def get_memory_pool(init_pool_size=0, max_pool_size=2 ** 34): + memory_pool = PoolMemoryManager( + init_pool_size=init_pool_size, max_pool_size=max_pool_size + ) + return memory_pool diff --git a/tools/library/scripts/pycutlass/src/cpp/compiler.h b/python/cutlass/cpp/compiler.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/compiler.h rename to python/cutlass/cpp/compiler.h diff --git a/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp b/python/cutlass/cpp/cutlass_bindings.cpp similarity index 98% rename from tools/library/scripts/pycutlass/src/cpp/cutlass.cpp rename to python/cutlass/cpp/cutlass_bindings.cpp index 9e471882..c5becc57 100644 --- a/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp +++ b/python/cutlass/cpp/cutlass_bindings.cpp @@ -62,10 +62,10 @@ namespace py = pybind11; -PYBIND11_MODULE(cutlass, m) { +PYBIND11_MODULE(cutlass_bindings, m) { // module doc - m.doc() = "cutlass C++ binding"; + m.doc() = "CUTLASS C++ binding"; // // Bind data type diff --git a/tools/library/scripts/pycutlass/src/cpp/include/arch.h b/python/cutlass/cpp/include/arch.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/arch.h rename to python/cutlass/cpp/include/arch.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h b/python/cutlass/cpp/include/conv/conv_problem_size.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h rename to python/cutlass/cpp/include/conv/conv_problem_size.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h b/python/cutlass/cpp/include/conv/convolution.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h rename to python/cutlass/cpp/include/conv/convolution.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h b/python/cutlass/cpp/include/conv/host.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/conv/host.h rename to python/cutlass/cpp/include/conv/host.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_generic.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_generic.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_generic.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/binary_ops.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_accumulator.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_binary.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_broadcast.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_column_reduction.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_linear_combination.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_broadcast.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_row_reduction.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_input.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_tensor_output.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h b/python/cutlass/cpp/include/epilogue/epilogue_visitor_with_layernorm.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_with_layernorm.h rename to python/cutlass/cpp/include/epilogue/epilogue_visitor_with_layernorm.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h b/python/cutlass/cpp/include/gemm/gemm.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h rename to python/cutlass/cpp/include/gemm/gemm.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h b/python/cutlass/cpp/include/gemm/gemm_universal_with_visitor.h similarity index 99% rename from tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h rename to python/cutlass/cpp/include/gemm/gemm_universal_with_visitor.h index 64b65a03..73c5f962 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h +++ b/python/cutlass/cpp/include/gemm/gemm_universal_with_visitor.h @@ -430,6 +430,16 @@ struct GemmUniversalwithEpilogueVisitor { return can_implement(args.problem_size); } + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmUniversalwithEpilogueVisitor op; + op(params, shared_storage); + } + /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { diff --git a/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h b/python/cutlass/cpp/include/gemm/host.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h rename to python/cutlass/cpp/include/gemm/host.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h b/python/cutlass/cpp/include/layout/layout.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h rename to python/cutlass/cpp/include/layout/layout.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h b/python/cutlass/cpp/include/layout/matrix.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h rename to python/cutlass/cpp/include/layout/matrix.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h b/python/cutlass/cpp/include/layout/tensor.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h rename to python/cutlass/cpp/include/layout/tensor.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h b/python/cutlass/cpp/include/swizzling.h similarity index 93% rename from tools/library/scripts/pycutlass/src/cpp/include/swizzling.h rename to python/cutlass/cpp/include/swizzling.h index 970cd6d3..27994ded 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h +++ b/python/cutlass/cpp/include/swizzling.h @@ -61,7 +61,7 @@ void bind_identity_swizzle(py::module & m, std::string name) { &T::get_tiled_shape, py::const_ ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), R"pbdoc(Returns the shape of the problem in units of logical tiles - + :param problem_size: gemm(M, N, K) :type problem_size: :class:`cutlass.gemm.GemmCoord` )pbdoc") @@ -70,7 +70,7 @@ void bind_identity_swizzle(py::module & m, std::string name) { &T::get_tiled_shape, py::const_ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), R"pbdoc(Returns the shape of the problem in units of logical tiles - + :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) :type problem_size: :class:`cutlass.gemm.GemmCoord`) )pbdoc") @@ -79,7 +79,7 @@ void bind_identity_swizzle(py::module & m, std::string name) { &T::get_tiled_shape, py::const_ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), R"pbdoc(Returns the shape of the problem in units of logical tiles - + :param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC) :type problem_size: :class:`cutlass.gemm.GemmCoord`) )pbdoc") @@ -100,18 +100,27 @@ void bind_swizzle(py::module & m, std::string name, std::string doc) { &T::get_tiled_shape, py::const_ ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), R"pbdoc(Returns the shape of the problem in units of logical tiles - + :param problem_size: gemm(M, N, K) :type problem_size: :class:`cutlass.gemm.GemmCoord` )pbdoc") .def("get_grid_shape", &T::get_grid_shape, - py::arg("tiled_shape"), + py::arg("tiled_shape"), R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") .def("tag", [](const T & swizzle){ return demangle(typeid(T).name()); }, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc"); } +template +void bind_swizzle_streamk(py::module & m, std::string name, std::string doc) { + py::class_(m, name.c_str(), doc.c_str()) + .def(py::init<>()) + .def("tag", [](const T & swizzle){ + return demangle(typeid(T).name()); + }, R"pbdoc(Returns the c++ name of the swizzling for code emission)pbdoc"); +} + template void bind_dgrad_swizzle(py::module & m, std::string name) { py::class_(m, name.c_str(), @@ -122,13 +131,13 @@ void bind_dgrad_swizzle(py::module & m, std::string name) { &T::get_tiled_shape, py::const_ ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), R"pbdoc(Returns the shape of the problem in units of logical tiles - + :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) :type problem_size: :class:`cutlass.gemm.GemmCoord`) )pbdoc") .def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) { return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); - }, py::arg("tiled_shape"), + }, py::arg("tiled_shape"), R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") .def("tag", [](const T & swizzle){ return demangle(typeid(T).name()); @@ -153,6 +162,8 @@ void bind_threadblock_swizzle(py::module &m) { bind_swizzle(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc"); bind_swizzle(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc"); + bind_swizzle_streamk(m, "ThreadblockSwizzleStreamK", R"pbdoc(Threadblock swizzling function using Stream K feature)pbdoc"); + bind_dgrad_swizzle>(m, "StridedDgradIdentitySwizzle1"); bind_dgrad_swizzle>(m, "StridedDgradIdentitySwizzle4"); bind_dgrad_swizzle(m, "StridedDgradHorizontalSwizzle"); diff --git a/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h b/python/cutlass/cpp/include/tensor_coord.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h rename to python/cutlass/cpp/include/tensor_coord.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h b/python/cutlass/cpp/include/tensor_ref_view.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h rename to python/cutlass/cpp/include/tensor_ref_view.h diff --git a/tools/library/scripts/pycutlass/src/cpp/include/types.h b/python/cutlass/cpp/include/types.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/include/types.h rename to python/cutlass/cpp/include/types.h diff --git a/tools/library/scripts/pycutlass/src/cpp/library.h b/python/cutlass/cpp/library.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/library.h rename to python/cutlass/cpp/library.h diff --git a/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h b/python/cutlass/cpp/test/conv/conv_problems.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h rename to python/cutlass/cpp/test/conv/conv_problems.h diff --git a/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h b/python/cutlass/cpp/test/conv/convolution.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h rename to python/cutlass/cpp/test/conv/convolution.h diff --git a/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h b/python/cutlass/cpp/test/conv/host.h similarity index 99% rename from tools/library/scripts/pycutlass/src/cpp/test/conv/host.h rename to python/cutlass/cpp/test/conv/host.h index ca15ce6d..142c4688 100644 --- a/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h +++ b/python/cutlass/cpp/test/conv/host.h @@ -56,7 +56,7 @@ template>); + Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>); m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey); } diff --git a/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h b/python/cutlass/cpp/test/gemm/gemm.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h rename to python/cutlass/cpp/test/gemm/gemm.h diff --git a/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h b/python/cutlass/cpp/test/gemm/host.h similarity index 100% rename from tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h rename to python/cutlass/cpp/test/gemm/host.h diff --git a/tools/library/scripts/pycutlass/test/frontend/run_test.sh b/python/cutlass/emit/__init__.py similarity index 94% rename from tools/library/scripts/pycutlass/test/frontend/run_test.sh rename to python/cutlass/emit/__init__.py index 072f60b5..52200ca7 100644 --- a/tools/library/scripts/pycutlass/test/frontend/run_test.sh +++ b/python/cutlass/emit/__init__.py @@ -1,6 +1,6 @@ ################################################################################################# # -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -30,4 +30,4 @@ # ################################################################################################# -CUPY_CACHE_DIR=./ python test_frontend.py +from cutlass.emit.pytorch import pytorch diff --git a/python/cutlass/emit/common.py b/python/cutlass/emit/common.py new file mode 100644 index 00000000..4d1dd4cd --- /dev/null +++ b/python/cutlass/emit/common.py @@ -0,0 +1,182 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Common utilities for emitting CUTLASS kernels +""" + +import cutlass + +# Strings used for printing information about the generation of emitted scripts +_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass.__version__} Python interface (https://github.com/nvidia/cutlass/python)" + + +_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR} +""" + + +_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR} +""" + +_CUTLASS_KERNEL_ARGS_2x = """ + typename DeviceKernel::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, // problem size + 1, + {alpha, beta}, + A, B, C, D, + 0, 0, 0, 0, // batch strides + DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda + DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc + DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd + }; +""" + +_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """ + typename DeviceKernel::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, // problem size + 1, + {alpha, beta}, + A, B, C, D, + 0, 0, 0, 0, // batch strides + DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda + DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd + -1 // avail_sms + }; +""" + +_CUTLASS_KERNEL_RUN_GEMM_2x = """ +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +cutlass::Status ${name}_kernel_run(int M, int N, int K, + const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta) { + ${args} + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + DeviceKernel gemm_op; + cutlass::Status status = gemm_op.initialize(arguments, + workspace.get(), + nullptr); // CUDA stream + + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = gemm_op(); + return status; +} +""" + +_CUTLASS_KERNEL_RUN_GEMM_3x = """ +using StrideA = typename DeviceKernel::GemmKernel::StrideA; +using StrideB = typename DeviceKernel::GemmKernel::StrideB; +using StrideC = typename DeviceKernel::GemmKernel::StrideC; +using StrideD = typename DeviceKernel::GemmKernel::StrideD; + +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +cutlass::Status ${name}_kernel_run( + int M, int N, int K, int L, + const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) { + + typename DeviceKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, L}, // problem size + A, // ptrA + make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A + B, // ptrB + make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B + { + C, // ptrC + make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C + D, // ptrD + make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D + {alpha, beta}, + }, + hw_info + }; + + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + DeviceKernel gemm_op; + cutlass::Status status = gemm_op.run(arguments, + workspace.get(), + nullptr); // CUDA stream + + return status; +} +""" + + +_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """ +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +int threadblock_count = DeviceKernel::sufficient(); + +cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes, + DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D, + int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd, + ElementCompute alpha, ElementCompute beta) { + + typename DeviceKernel::Arguments arguments { + problem_sizes, + problem_count, + threadblock_count, + {alpha, beta}, + A, B, C, D, + lda, ldb, ldc, ldd + }; + + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + DeviceKernel gemm_op; + cutlass::Status status = gemm_op.initialize(arguments, + workspace.get(), + nullptr); // CUDA stream + + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = gemm_op(); + return status; +} +""" diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py new file mode 100644 index 00000000..61cc5d94 --- /dev/null +++ b/python/cutlass/emit/pytorch.py @@ -0,0 +1,639 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel. +If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method. + +Example usage with JIT compilation: + +.. highlight:: python +.. code-block:: python + + plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor) + op = plan.construct() + mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True) + + # Generate inputs for the GEMM + A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] + + # Run the module + D = mod.run(A, B, C) + + +Example usage without JIT compilation: + +.. highlight:: python +.. code-block:: python + + plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor) + op = plan.construct() + cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output') + +After this call, the directory ``output`` contains ``setup.py``, +``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from +within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``. + +The module can later be used in Python via: + +.. highlight:: python +.. code-block:: python + + import torch + import cutlass_gemm + + # Generate inputs for the GEMM + A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] + + # Run the module + D = cutlass_gemm.run(A, B, C) +""" + +import logging +import os + +import cutlass_bindings + +from cutlass import CUTLASS_PATH, logger, swizzle +from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal +from cutlass.backend.library import ApiVersion +from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate +from cutlass.emit import common + +torch_available = CheckPackages().check_torch() +if torch_available: + import torch + + +_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" + +${includes} +${declaration} +${impl} +""" + +_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, float alpha=1.f, float beta=0.f); + +// C++ interface +at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, float alpha=1.f, float beta=0.f) { + return ${name}_kernel(A, B, C, alpha, beta); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", py::overload_cast, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f); +} +""" + +_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +std::vector ${name}_kernel(const std::vector& A, const std::vector& B, at::optional> C=at::nullopt, float alpha=1.f, float beta=0.f); + +// C++ interface +std::vector ${name}(const std::vector& A, const std::vector& B, at::optional> C=at::nullopt, float alpha=1.f, float beta=0.f) { + return ${name}_kernel(A, B, C, alpha, beta); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", py::overload_cast&, const std::vector&, at::optional>, float, float>(&${name}), + py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f); +} +""" + +_PYTORCH_GEMM_INCLUDES = { + ApiVersion.v2x: """ +#include "cutlass/gemm/device/gemm_universal.h" +""", + ApiVersion.v3x: """ +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/util/packed_stride.hpp" +""", +} + +_PYTORCH_GROUPED_GEMM_INCLUDES = """ +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +""" + +_CUTLASS_TYPE_TO_TORCH_TYPE = { + cutlass_bindings.float16: "torch::kF16", + cutlass_bindings.float32: "torch::kF32", + cutlass_bindings.float64: "torch::kF64", + cutlass_bindings.int8: "torch::I8", + cutlass_bindings.int32: "torch::I32", +} + +_PYTORCH_GEMM_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_GEMM_2x + + """ +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C, float alpha, float beta) { + int M = A.size(0); + int N = B.size(1); + int K = A.size(1); + + typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->contiguous().data_ptr()); + at::Tensor D = B.new_empty({M, N}, ${torch_type_C}); + + cutlass::Status status = ${name}_kernel_run(M, N, K, + reinterpret_cast(A.contiguous().data_ptr()), + reinterpret_cast(B.contiguous().data_ptr()), + ptrC, + reinterpret_cast(D.contiguous().data_ptr()), + ElementCompute(alpha), ElementCompute(beta)); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + +_PYTORCH_GEMM_IMPL_TEMPLATE_3x = ( + common._CUTLASS_KERNEL_RUN_GEMM_3x + + """ +bool hw_info_queried = false; +cutlass::KernelHardwareInfo hw_info; + +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C, float alpha, float beta) { + int M = A.size(0); + int N = B.size(1); + int K = A.size(1); + int L = 1; + + // Query hardware info if we haven't already + if (!hw_info_queried) { + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->contiguous().data_ptr()); + at::Tensor D = B.new_empty({M, N}, ${torch_type_C}); + + cutlass::Status status = ${name}_kernel_run(M, N, K, L, + reinterpret_cast(A.contiguous().data_ptr()), + reinterpret_cast(B.contiguous().data_ptr()), + ptrC, + reinterpret_cast(D.contiguous().data_ptr()), + ElementCompute(alpha), ElementCompute(beta), + hw_info); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + + +_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = ( + common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x + + """ +std::vector ${name}_kernel(const std::vector& A, const std::vector& B, at::optional> C, float alpha, float beta) { + size_t num = A.size(); + + // To avoid performing many small cudaMallocs and host-to-device copies, + // we serialize the grouped GEMM arguments on the host, allocate one + // large chunk of device memory, and perform a single cudaMemcpy to + // copy the host data to the device. Allocation overheads could be + // avoided by using a memory pool. + + // Calculate the total size of the data to be copied from host to device + size_t total_size = sizeof(cutlass::gemm::GemmCoord) + + sizeof(DeviceKernel::ElementA*) + + sizeof(DeviceKernel::ElementB*) + + sizeof(DeviceKernel::ElementC*) + + sizeof(DeviceKernel::ElementC*) + + sizeof(int64_t) + + sizeof(int64_t) + + sizeof(int64_t); + total_size *= num; + + // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple + // of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system). + // To ensure that we don't end up having misaligned loads in the kernel, + // we pad to the nearest multiple of 8. + // + // Note that, even on a 32-bit system (for which sizeof(X*) will not equal + // sizeof(int64_t)), only padding between the list of GemmCoords and the + // list of ptr_As is sufficient because the set of four equal-length lists of pointers + // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always + // start on a multiple of 8. + int64_t padding = 8 - (total_size % 8); + total_size += padding; + + uint8_t* host_data = new uint8_t[total_size]; + cutlass::DeviceAllocation device_data(total_size); + + uint8_t* start = host_data; + cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast(start); + + // Apply the padding after the list of GemmCoords + start += num * sizeof(cutlass::gemm::GemmCoord) + padding; + + int64_t ptr_A_offset = start - host_data; + DeviceKernel::ElementA** ptr_A_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementA*); + + int64_t ptr_B_offset = start - host_data; + DeviceKernel::ElementB** ptr_B_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementB*); + + int64_t ptr_C_offset = start - host_data; + DeviceKernel::ElementC** ptr_C_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementC*); + + int64_t ptr_D_offset = start - host_data; + DeviceKernel::ElementC** ptr_D_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementC*); + + int64_t lda_offset = start - host_data; + int64_t* lda_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + int64_t ldb_offset = start - host_data; + int64_t* ldb_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + int64_t ldc_offset = start - host_data; + int64_t* ldc_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + std::vector D(num); + + bool need_C = (C != at::nullopt) && (beta != 0.f); + for (size_t i = 0; i < num; ++i) { + int M = A[i].size(0); + int N = B[i].size(1); + int K = A[i].size(1); + *(problem_sizes_host + i) = {M, N, K}; + *(ptr_A_host + i) = reinterpret_cast(A[i].contiguous().data_ptr()); + *(ptr_B_host + i) = reinterpret_cast(B[i].contiguous().data_ptr()); + + if (need_C) { + *(ptr_C_host + i) = reinterpret_cast(C->at(i).contiguous().data_ptr()); + } + else { + *(ptr_C_host + i) = nullptr; + } + + D[i] = B[i].new_empty({M, N}, ${torch_type_C}); + *(ptr_D_host + i) = reinterpret_cast(D[i].contiguous().data_ptr()); + + *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0); + *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0); + *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0); + } + + device_data.copy_from_host(host_data); + + cutlass::Status status = ${name}_kernel_run( + num, + reinterpret_cast(device_data.get()), + reinterpret_cast(device_data.get() + ptr_A_offset), + reinterpret_cast(device_data.get() + ptr_B_offset), + reinterpret_cast(device_data.get() + ptr_C_offset), + reinterpret_cast(device_data.get() + ptr_D_offset), + reinterpret_cast(device_data.get() + lda_offset), + reinterpret_cast(device_data.get() + ldb_offset), + reinterpret_cast(device_data.get() + ldc_offset), + reinterpret_cast(device_data.get() + ldc_offset), + ElementCompute(alpha), ElementCompute(beta)); + + delete[] host_data; + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + + +_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='${name}', + ext_modules=[ + CUDAExtension('${name}', [ + '${name}.cpp', + '${name}_kernel.cu', + ], + include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'], + extra_compile_args=['-std=c++17'] + ), + ], + cmdclass={ + 'build_ext': BuildExtension + }) + +""" + + +def _generate_setup(name: str, sourcedir: str): + """ + Generates a setup.py file for the extension + + :param name: name of the module to generate + :type name: str + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + """ + setup_py_file = os.path.join(sourcedir, "setup.py") + setup_source = SubstituteTemplate( + _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH} + ) + with open(setup_py_file, "w") as outfile: + outfile.write(setup_source) + + +class _ArchListSetter: + """ + Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST`` + environment variable when building a PyTorch CUDA module. + + ``TORCH_CUDA_ARCH_LIST`` is a space-delmited list of compute capabilites for which a PyTorch + CUDA module should be compiled. + + For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of + ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the + compilation of the module. + + This utility wraps the building of a PyTorch CUDA module with a setting of this environment + variable according to the current compute capability being targetted. + + Example usage: + + .. highlight:: python + .. code-block:: python + + # Temporarily set TORCH_CUDA_ARCH_LIST="8.0" + with _ArchListSetter(80): + # Perform JIT compilation and loading of the module + mod = torch.utils.cpp_extension.load(...) + + :param cc: compute capability + :type cc: int + """ + + _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST" + + def __init__(self, cc: int): + self.cc_str = ".".join(list(str(cc))) + + def __enter__(self): + """ + Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc`` + """ + self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST) + os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str + + return self + + def __exit__(self, exc_type, exc_val, traceback): + """ + Restores the old value of TORCH_CUDA_ARCH_LIST + """ + os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list + + +def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): + """ + JIT compiles and loads a PyTorch CUDA extension. + + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param cpp_file: path to file containing extension's C++ interface + :type cpp_file: str + :param cuda_file: path to file containing extension's CUDA interface + :type cuda_file: str + + :return: loaded PyTorch module + """ + + from torch.utils.cpp_extension import load + + extra_cuda_cflags = ["-std=c++17"] + if cc == 90: + # PyTorch does not currently add the sm_90a target when compute capability + # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target. + extra_cuda_cflags.append("-gencode=arch=compute_90a,code=sm_90a") + + with _ArchListSetter(cc): + jitmodule = load( + name, + [cpp_file, cuda_file], + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=[ + os.path.join(CUTLASS_PATH, "include"), + os.path.join(CUTLASS_PATH, "tools/util/include"), + ], + verbose=(logger.level == logging.DEBUG) + ) + return jitmodule + + +def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + extra_kw = {} + if op.api == ApiVersion.v3x: + impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x + else: + impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x + if isinstance(op.swizzling_functor, swizzle.ThreadblockSwizzleStreamK): + extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K + else: + extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x + impl_template = ( + _PYTORCH_GEMM_IMPL_TEMPLATE_3x + if op.api == ApiVersion.v3x + else _PYTORCH_GEMM_IMPL_TEMPLATE_2x + ) + cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_GEMM_INCLUDES[op.api], + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + _PYTORCH_GEMM_CPP_TEMPLATE, + {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + _generate_setup(name, sourcedir) + + if jit: + return _jit(name, cc, cpp_file, cuda_file) + + return None + + +def _pytorch_grouped_gemm( + op, name: str, cc: int, jit: bool = False, sourcedir: str = "" +): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if op.api != ApiVersion.v2x: + raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x") + + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_GROUPED_GEMM_INCLUDES, + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + _PYTORCH_GROUPED_GEMM_CPP_TEMPLATE, + {"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + _generate_setup(name, sourcedir) + + if jit: + return _jit(name, cc, cpp_file, cuda_file) + + return None + + +def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + The result of this method is files within ``sourcedir`` that can be used for building + a PyTorch module. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module (if ``jit=True``) or None + """ + device_op = op.device_op() + if isinstance(op, GemmOperationUniversal): + return _pytorch_gemm(device_op, name, cc, jit, sourcedir) + elif isinstance(op, GemmOperationGrouped): + return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir) + else: + raise Exception( + f"Operation type {type(op)} is not currently supported for PyTorch emission." + ) diff --git a/python/cutlass/epilogue.py b/python/cutlass/epilogue.py new file mode 100644 index 00000000..6355a071 --- /dev/null +++ b/python/cutlass/epilogue.py @@ -0,0 +1,107 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Registry of elementwise epilogues + +Elementwise epilogues can be added to many CUTLASS kernels in the CUTLAS Python interface via +code like the following for GEMM: + +.. highlight:: python +.. code-block:: python + + plan = cutlass.op.Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor) + plan.activation = cutlass.epilogue.relu +""" + +from cutlass.backend import epilogue + +gelu = epilogue.gelu +hardswish = epilogue.hardswish +identity = epilogue.identity +leaky_relu = epilogue.leaky_relu +relu = epilogue.relu +sigmoid = epilogue.sigmoid +silu = epilogue.silu +tanh = epilogue.tanh + + +_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh] + + +def get_activations() -> list: + """ + Returns a list of available activation functions + + :return: list of available activation functions + :rtype: list + """ + return _activations + + +def get_activation_epilogue( + activation, + element_output, + elements_per_access, + element_accumulator, + element_compute, +): + """ + Return an epilogue corresponding to the activation function, data types, and alignment + used in the kernel + + :param activation: elementwise activation function to use + :param element_output: data type of the output + :param elements_per_access: alignment of operand C of the kernel + :type elements_per_access: int + :param element_accumulator: data type of the accumulated output C + :param element_compute: data type in which compute operations should be performed + + :return: epilogue functor + """ + if activation not in _activations: + raise Exception( + f"Unsupported activation type {activation}. Available activations are: {_activations}" + ) + + if activation == identity: + return epilogue.LinearCombination( + element_output, elements_per_access, element_accumulator, element_compute + ) + else: + return epilogue.LinearCombinationGeneric( + activation(element_compute), + element_output, + elements_per_access, + element_accumulator, + element_compute, + ) diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py new file mode 100644 index 00000000..f72ca394 --- /dev/null +++ b/python/cutlass/library_defaults.py @@ -0,0 +1,445 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Classes containing valid operations for a given compute capability and data types. +""" + +import logging +from cuda import __version__ + +# Strip any additional information from the CUDA version +_cuda_version = __version__.split("rc")[0] + +# Imports from CUTLASS profiler generator and manifest scripts +import generator as prof_generator +import manifest as prof_manifest + +import cutlass +from cutlass.utils.check import valid_stage_count +from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op, has_binding_type + + +_generator_ccs = [50, 60, 61, 70, 75, 80, 90] + + +class KernelsForDataType: + """ + Container class for keeping track of kernels that correspond to a particular combination + of data types for operands A, B, and accumulator + """ + + def __init__(self, datatype_comb: tuple, layout_comb: tuple): + self.datatype_comb = datatype_comb + self.layout_comb = layout_comb + + # Dictionary mapping from alignment (int) to a list of kernels that fit the alignment + # constraint for the data type combination + self.kernels_by_alignment = {} + + def add(self, operation): + """ + Add an operation to the list of supported kernels + """ + alignment = operation.A.alignment + if alignment not in self.kernels_by_alignment: + self.kernels_by_alignment[alignment] = [] + self.kernels_by_alignment[alignment].append(operation) + + @property + def alignments(self): + """ + Returns an unsorted list of alignments supported by this data type combination + + :return: unsorted list of alignments supported by this data type combination + :rtype: list + """ + return list(self.kernels_by_alignment.keys()) + + @property + def all_operations(self): + """ + Returns a list of all operations supported by this data type combination + + :return: list of all operations supported by this data type combination + :rtype: list + """ + ops = [] + for _, alignment_ops in self.kernels_by_alignment.items(): + ops.extend(alignment_ops) + return ops + + def operations(self, alignment: int): + """ + Returns operations satisfying the alignment constraint indicated by `alignment` + + :param alignment: alignment constraint of operations to return + :type alignment: int + + :return: list of operations + :rtype: list + """ + if alignment not in self.kernels_by_alignment: + raise Exception( + f"No operations of alignment {alignment} found for data type and layout " + f"combination {self.datatype_comb} {self.layout_comb}" + ) + return self.kernels_by_alignment[alignment] + + def find_alignment(self, shape: tuple, layout: cutlass.LayoutType) -> int: + """ + Returns the most preferable alignment for a given shape and layout + + :param shape: extent of each dimension of the tensor + :type shape: tuple + :param layout: layout of the tensor + :type layout: cutlass.LayoutType + + :return: maximum alignment supported by the data type combination and tensor size + :rtype: int + """ + # Determine the leading dimension of the shape + if layout == cutlass.LayoutType.RowMajor: + ld = shape[0] + elif layout == cutlass.LayoutType.RowMajor: + ld = shape[1] + else: + raise Exception(f"Unexpected or unsupported layout {layout}") + + for alignment in sorted(list(self.kernels_by_alignment.keys()), reverse=True): + if ld % alignment == 0: + return alignment + + # Default to alignment of 1 if no others match + return 1 + + def sort(self): + """ + Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape + """ + key = lambda op: ( + op.tile_description.threadblock_shape[0] + * op.tile_description.threadblock_shape[1] + * op.tile_description.threadblock_shape[2] + ) + for alignment in self.kernels_by_alignment.keys(): + self.kernels_by_alignment[alignment].sort(key=key, reverse=True) + + +class ArchOptions: + """ + Structure for keeping track of kernels available on a given compute capability + + :param target_cc: compute capability of the device on which kernels will be run + :type target_cc: int + :param kernel_cc: compute capability of the kernels to generate + :type kernel_cc: int + :param operation_kind: type of operation to register + :type operation_kind: cutlass.OperationKind + :param gemm_kinds: types of GEMM operations that can be included + :type gemm_kinds: list + :param allowed_math_operations: types of primitive math operations allowed + :type allowed_math_operations: list + """ + + def __init__( + self, + target_cc: int, + kernel_cc: int, + operation_kind: cutlass.OperationKind, + gemm_kinds: list, + allowed_math_operations: list = [ + cutlass.MathOperation.multiply_add, + cutlass.MathOperation.multiply_add_saturate, + ] + ): + self.cc = kernel_cc + + # Dictionary with following structure: + # Key: OpcodeClass + # Value: Dictionary with the following structure: + # Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType, LayoutType), + # representing ((element_a, element_b, element_accumulator), (layout_a, layout_b)) + # Value: KernelsForDataType + self.operations_by_opclass = {} + self.op_class = None + self.allowed_math_operations = allowed_math_operations + + # Identify the method within CUTLASS generator script that generates kernel + # descriptions for the target CC + generate_function_name = "GenerateSM" + str(kernel_cc) + if not hasattr(prof_generator, generate_function_name): + cutlass.logger.warning(f"No generator found for architecture {kernel_cc}") + return + generate_function = getattr(prof_generator, generate_function_name) + + # Initialize a default manifest and populate it with valid kernel descriptions + # for the target CC + args = [ + "--kernels=all", + f"--log-level={logging.getLevelName(cutlass.logger.level)}" + ] + manifest_args = prof_generator.define_parser().parse_args(args) + manifest = prof_manifest.Manifest(manifest_args) + generate_function(manifest, _cuda_version) + + if operation_kind not in manifest.operations: + # No kernels generated for this architecture, this could be because the CUDA + # toolkit is insufficient to support operations in this CC + cutlass.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}") + return + + # Iterate through the available operations for this operation kind and + # find available opclasses and data types + for name, op_list in manifest.operations[operation_kind].items(): + for op in op_list: + if op.gemm_kind not in gemm_kinds: + continue + + mi = op.tile_description.math_instruction + if mi.math_operation not in self.allowed_math_operations: + continue + + datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator) + + # Skip any data types that do not currently have conversions via cutlass_bindings + if False in [has_binding_type(elt) for elt in datatype_comb]: + continue + + # Prune operations that don't fit in shared memory + td = td_from_profiler_op(op) + if not valid_stage_count(target_cc, td)[0]: + continue + + if mi.opcode_class not in self.operations_by_opclass: + self.operations_by_opclass[mi.opcode_class] = {} + + datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator) + layout_comb = (op.A.layout, op.B.layout) + + # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations + if datatype_comb == (cutlass.DataType.tf32, cutlass.DataType.tf32, cutlass.DataType.f32): + # TF32 kernels only supported on SM80 and beyond + if self.cc < 80: + continue + elif self.cc == 90: + if (op.A.element != cutlass.DataType.f32 + or op.B.element != cutlass.DataType.f32 + or op.C.element != cutlass.DataType.f32): + continue + + datatype_comb = (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32) + + opclass_dict = self.operations_by_opclass[mi.opcode_class] + key = (datatype_comb, layout_comb) + if key not in opclass_dict: + opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb) + opclass_dict[key].add(op) + + # Set the default opclass to TensorOp, if available. Otherwise default to SIMT + if cutlass.OpcodeClass.TensorOp in self.operations_by_opclass: + self.op_class = cutlass.OpcodeClass.TensorOp + else: + self.op_class = cutlass.OpcodeClass.Simt + + # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels. + # Here, we generate additional versions via a generic TileDescription. + if cutlass.OpcodeClass.Simt not in self.operations_by_opclass: + self.operations_by_opclass[cutlass.OpcodeClass.Simt] = {} + + types = [ + (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8), + (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32), + (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16), + (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32), + (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32), + (cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64), + ] + + layouts = [ + (cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor), + (cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor), + (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor), + (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor), + ] + alignment = 1 + epilogue_functor = cutlass.EpilogueFunctor.LinearCombination + swizzling_functor = cutlass.SwizzlingFunctor.Identity8 + for type_comb in types: + for layout_comb in layouts: + comb = (type_comb, layout_comb) + if comb in self.operations_by_opclass[cutlass.OpcodeClass.Simt]: + continue + + A = cutlass.TensorDescription(type_comb[0], layout_comb[0], alignment) + B = cutlass.TensorDescription(type_comb[1], layout_comb[1], alignment) + C = cutlass.TensorDescription(type_comb[2], cutlass.LayoutType.ColumnMajor, alignment) + math_inst = cutlass.MathInstruction( + [1, 1, 1], + type_comb[0], + type_comb[1], + type_comb[2], + cutlass.OpcodeClass.Simt, + cutlass.MathOperation.multiply_add + ) + + td = cutlass.TileDescription( + [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024) + + # Prune operations that don't fit in shared memory + if not valid_stage_count(target_cc, td_from_profiler_td(td))[0]: + continue + + new_operation = prof_manifest.GemmOperation( + cutlass.GemmKind.Universal, td.minimum_compute_capability, + td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor) + + new_kernels = KernelsForDataType(type_comb, layout_comb) + new_kernels.add(new_operation) + self.operations_by_opclass[cutlass.OpcodeClass.Simt][comb] = new_kernels + + # Sort all operations + for oc in self.operations_by_opclass.keys(): + for comb in self.operations_by_opclass[oc].keys(): + self.operations_by_opclass[oc][comb].sort() + + def opclass_supports_combination( + self, op_class: cutlass.OpcodeClass, datatype_comb: tuple, layout_comb: tuple + ) -> bool: + """ + Returns whether the provided operation class supports the provided data type and layout combination + + :param op_class: operation class to consider + :type op_class: cutlass.OpcodeClass + :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator) + :type datatype_comb: tuple[cutlass.DataType] + :param layout_comb: tuple of data types for (layout_A, layout_B) + :type layout_comb: tuple[cutlass.LayoutType] + + :return: set of operation classes that support the provided data type and layout combination + :rtype: set + """ + if op_class not in self.operations_by_opclass: + raise Exception(f"Unexpected or unsupported operation class {op_class}") + + return (datatype_comb, layout_comb) in self.operations_by_opclass[op_class] + + def supporting_opclasses( + self, + element_a: cutlass.DataType, + element_b: cutlass.DataType, + element_accumulator: cutlass.DataType, + layout_a: cutlass.LayoutType, + layout_b: cutlass.LayoutType, + ) -> set: + """ + Returns a set of operation classes that support the provided data type combination + + :param element_a: data type of operand A + :type element_a: cutlass.DataType + :param element_b: data type of operand B + :type element_b: cutlass.DataType + :param element_accumulator: data type of accumulator + :type element_accumulator: cutlass.DataType + :param layout_a: layout of operand A + :type layout_a: cutlass.LayoutType + :param layout_b: layout of operand B + :type layout_b: cutlass.LayoutType + + :return: set of operation classes that support the provided data type combination + :rtype: set + """ + supporting_op_classes = set() + datatype_comb = (element_a, element_b, element_accumulator) + layout_comb = (layout_a, layout_b) + + for op_class in self.operations_by_opclass.keys(): + if self.opclass_supports_combination(op_class, datatype_comb, layout_comb): + supporting_op_classes.add(op_class) + return supporting_op_classes + + def operations( + self, + op_class: cutlass.OpcodeClass, + element_a: cutlass.DataType, + element_b: cutlass.DataType, + element_accumulator: cutlass.DataType, + layout_a: cutlass.LayoutType, + layout_b: cutlass.LayoutType, + ) -> KernelsForDataType: + """ + Returns whether the provided operation class supports the provided data type combination + + :param op_class: operation class to consider + :type op_class: cutlass.OpcodeClass + :param element_a: data type of operand A + :type element_a: cutlass.DataType + :param element_b: data type of operand B + :type element_b: cutlass.DataType + :param element_accumulator: data type of accumulator + :type element_accumulator: cutlass.DataType + :param layout_a: layout of operand A + :type layout_a: cutlass.LayoutType + :param layout_b: layout of operand B + :type layout_b: cutlass.LayoutType + + :return: container of kernels by alignment supported by the provided combination of parameters + :rtype: KernelsForDataType + """ + datatype_comb = (element_a, element_b, element_accumulator) + layout_comb = (layout_a, layout_b) + if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb): + raise Exception( + f"Data type layout combination {datatype_comb}, {layout_comb} " + f"is not supported by opcode class {op_class} on CC {self.cc}." + ) + return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)] + + +class OptionRegistry: + """ + Container of all architecture-specific options + + :param target_cc: compute capability of the device on which operations will be run + :type target_cc: int + """ + + def __init__(self, target_cc: int): + self.registry = {} + + gemm_kinds = [cutlass.GemmKind.Universal, cutlass.GemmKind.Universal3x] + # Construct options for each CC + for kernel_cc in _generator_ccs: + self.registry[kernel_cc] = ArchOptions(target_cc, kernel_cc, cutlass.OperationKind.Gemm, gemm_kinds) + + def options_for_cc(self, cc: int) -> ArchOptions: + return self.registry.get(cc, None) diff --git a/tools/library/scripts/pycutlass/build_doc.sh b/python/cutlass/op/__init__.py similarity index 90% rename from tools/library/scripts/pycutlass/build_doc.sh rename to python/cutlass/op/__init__.py index 3fad0808..59b02a36 100644 --- a/tools/library/scripts/pycutlass/build_doc.sh +++ b/python/cutlass/op/__init__.py @@ -1,6 +1,6 @@ ################################################################################################# # -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -30,7 +30,6 @@ # ################################################################################################# -pip install enum-tools -pip install sphinx-toolbox -pip install m2r2 -sphinx-build -b html docs/source/ docs/build/html +from cutlass.op.gemm import Gemm +from cutlass.op.gemm_grouped import GroupedGemm +from cutlass.op.op import OperationBase diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py new file mode 100644 index 00000000..e33843ae --- /dev/null +++ b/python/cutlass/op/gemm.py @@ -0,0 +1,696 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running GEMMs. + + The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run + GEMM operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS GEMMs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # A, B, C, and D are torch/numpy/cupy tensor objects + plan = cutlass.op.Gemm(A, B, C, D) + plan.run() + + + One can also use the interface by specifying data types of operands at construction + and using different tensor objects with these data types at runtime: + + .. highlight:: python + .. code-block:: python + + # The following is shorthand for: + # cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32, + # element_C=torch.float32, element_D=torch.float32, + # element_accumulator=torch.float32, + # layout=cutlass.LayoutType.RowMajor) + plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor) + + A0 = torch.rand((128, 256), device='cuda') + B0 = torch.rand((256, 64), device='cuda') + C0 = torch.zeros((128, 64), device='cuda') + D0 = torch.zeros((128, 64), device.'cuda') + plan.run(A0, B0, C0, D0) + + A = torch.rand((32, 128), device='cuda') + B = torch.rand((128, 256), device='cuda') + C = torch.zeros((32, 256), device='cuda') + D = torch.zeros((32, 256), device.'cuda') + plan.run(A1, B1, C1, D1) + + The interface additionally enables one to decouple the compilation of the underlying CUTLASS + kernel from its execution: + + .. highlight:: python + .. code-block:: python + + plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor) + plan.compile() + + # Do other work... + + plan.run(A0, B0, C0, D0) + + # Do other work... + + plan.run(A1, B1, C1, D1) + + Elementwise activation functions are easily fused to the GEMM via the interface: + + .. highlight:: python + .. code-block:: python + + plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor) + plan.activation = cutlass.epilogue.relu + + Operations can also be run asynchronously: + + .. highlight:: python + .. code-block:: python + + plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor) + args = plan.run() + + # Do other work... + + args.sync() +""" + +import cutlass_bindings + +import cutlass +from cutlass import epilogue, swizzle +from cutlass.backend import compiler +from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass.backend.library import TensorDescription, TileDescription +from cutlass.op.op import OperationBase +from cutlass.utils import check, datatypes + + +class Gemm(OperationBase): + """ + Constructs a ``Gemm`` object. + + The data types and layouts of operands A, B, and C, along with the data type of output D + and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime -- + these are not to be changed after a ``Gemm`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. The following + constructors are equivalent: + + .. highlight:: python + .. code-block:: python + + # Use F32 for A, B, C, D, and accumulation. All operands are row major. + + # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts + # for operands to the same values. + Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor) + + # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``. + Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32, + element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor) + + # Set the data types and elements from existing tensors. Note that one can use different tensors when + # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must + # have the same data type and layout as those passed in here). + # A, B, C, and D are row-major torch.Tensor objects of type torch.float32 + Gemm(A=A, B=B, C=C, D=D) + + # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is + # the same as that for D, at present) + Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor, + layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor) + + # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types + # and layouts will inherit those passed in via the generic ``element`` and ``layout`` + Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor, + element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor) + + The order of precedence for the setting of the data type and layout for a given operand/output is as follows: + 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor + 2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those + 3) Otherwise, use the generic values (e.g., ``element``, ``layout``) + + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass.DataType + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass.DataType + :param layout: generic layout type to be used for operands A, B, C, and D + :type layout: cutlass.LayoutType + :param element_A: data type to be used for operand A + :type element_A: cutlass.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass.DataType + :type layout_A: layout of operand A + :param layout_A: cutlass.LayoutType + :type layout_B: layout of operand B + :param layout_B: cutlass.LayoutType + :type layout_C: layout of operand C + :param layout_C: cutlass.LayoutType + :type layout_D: layout of operand D + :param layout_D: cutlass.LayoutType + """ + + def __init__( + self, A=None, B=None, C=None, D=None, + alpha=1.0, beta=0.0, element_accumulator=None, + element=None, layout=None, + element_A=None, element_B=None, element_C=None, element_D=None, + layout_A=None, layout_B=None, layout_C=None, + cc: int = None, kernel_cc: int = None + ): + super().__init__(cc=cc, kernel_cc=kernel_cc) + self.name = "gemm" + self.compiled = False + + elements = [] + layouts = [] + + # Check that at least one of the following is set for each tensor (illustrated assuming tensor A): + # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout`` + for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D], + [layout_A, layout_B, layout_C, layout_C], + [A, B, C, D], + ["A", "B", "C", "D"]): + if elt is not None and tens is not None: + raise Exception(f'Must not specify both element_{name} and tensor {name}') + if lay is not None and tens is not None: + raise Exception(f'Must not specify both layout_{name} and tensor {name}') + if elt is None and tens is None and element is None: + raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') + if lay is None and tens is None and layout is None: + raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.') + + elt_to_set = None + lay_to_set = None + if tens is not None: + elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens) + else: + elt_to_set = elt if elt is not None else element + lay_to_set = lay if lay is not None else layout + + elements.append(datatypes.library_type(elt_to_set)) + layouts.append(datatypes.library_layout(lay_to_set)) + + self._element_a, self._element_b, self._element_c, self._element_d = elements + self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts + + if element_accumulator is None: + self._element_accumulator = self._element_c + else: + self._element_accumulator = datatypes.library_type(element_accumulator) + + self.A = A + self.B = B + self.C = C + self.D = D + + self.alpha = alpha + self.beta = beta + + self.epilogue_functor = None + self.op_class = None + + self._reset_operations() + + self._swizzling_functor = cutlass.swizzle.IdentitySwizzle1 + + def _reset_operations(self, reset_epilogue: bool = True): + # Set the default op class + datatype_comb = (self._element_a, self._element_b, self._element_accumulator) + layout_comb = (self._layout_a, self._layout_b) + self.possible_op_classes = self.options.supporting_opclasses( + self._element_a, self._element_b, self._element_accumulator, + self._layout_a, self._layout_b) + + if cutlass.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass.OpcodeClass.TensorOp + elif cutlass.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass.OpcodeClass.Simt + else: + raise Exception(f'No kernel configuration found for supported data type and layout ' + f'combination {datatype_comb}x{layout_comb}') + + if reset_epilogue: + self._reset_epilogue_functor_activation(epilogue.identity) + + def _reset_epilogue_functor_activation(self, activation): + if self.epilogue_functor is None: + if self.op_class == cutlass.OpcodeClass.Simt: + elements_per_access = 1 + else: + elements_per_access = 128 // cutlass.DataTypeSize[self._element_c] + else: + elements_per_access = self.epilogue_functor.epilogue_vector_length + + if not self.specified_kernel_cc: + if self.current_cc == 90 and activation != epilogue.identity: + # CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation, + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity): + # SM80 fallback kernels are currently used. Since an identity activation is requested, + # we can switch back to using SM90 kernels. + self._reset_options(90) + self._reset_operations(reset_epilogue=False) + else: + if self.current_cc == 90 and activation != epilogue.identity: + raise Exception("Epilogues with elementwise fusion are not currently supported " + "in the Python interface for 3.x kernels. To use 2.x kernels " + "with fused elementwise epilogues, do not set the `kernel_cc` " + "parameter when constructing the Gemm object.") + + self.epilogue_functor = epilogue.get_activation_epilogue( + activation, + datatypes.binding_type(self._element_c), + elements_per_access, + datatypes.binding_type(self._element_accumulator), + datatypes.binding_type(self._element_accumulator), + ) + + def _reset_epilogue_functor_alignment(self, alignment): + if self.epilogue_functor is None or not hasattr(self.epilogue_functor, 'activation_functor'): + activation = epilogue.identity + else: + activation = type(self.epilogue_functor.activation_functor) + + self.epilogue_functor = epilogue.get_activation_epilogue( + activation, + datatypes.binding_type(self._element_c), + alignment, + datatypes.binding_type(self._element_accumulator), + datatypes.binding_type(self._element_accumulator), + ) + + @property + def activation(self): + """ + Returns the type of the current activation function used + """ + return type(self.epilogue_functor.activation_functor) + + @activation.setter + def activation(self, act): + """ + Sets the type of the activation function to use + """ + self._reset_epilogue_functor_activation(act) + + @property + def opclass(self) -> cutlass.OpcodeClass: + """ + Returns the opcode class currently in use by the GEMM + + :return: opcode class currently in use + :rtype: cutlass.OpcodeClass + """ + return self.op_class + + @opclass.setter + def opclass(self, oc: cutlass.OpcodeClass): + """ + Sets the opcode class to use in the GEMM. If the opcode class is not supported under + the given compute capability and element/layout combinations of the GEMM, an exception is raised. + """ + if oc in self.possible_op_classes: + self.op_class = oc + else: + raise Exception( + f'Unsupported operation class {oc} for CC {self.cc} and data type combination ' + f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and ' + f'layout combination ({self._layout_a}, {self._layout_b}).') + + # Changing the op class changes the elements per access in the epilogue. Reset this. + if self.op_class == cutlass.OpcodeClass.Simt: + elements_per_access = 1 + else: + elements_per_access = 128 // cutlass.DataTypeSize[self._element_c] + + if self.epilogue_functor is not None: + self._reset_epilogue_functor_alignment(elements_per_access) + + # Changing the op class also changes the possible operations available. Reset these. + self.possible_operations = self.options.operations( + self.op_class, self._element_a, self._element_b, + self._element_accumulator, self._layout_a, self._layout_b) + + @property + def swizzling_functor(self): + """ + Returns the type of the swizzling functor currently being used by the GEMM + + :return: swizzing functor type + """ + return self._swizzling_functor + + @swizzling_functor.setter + def swizzling_functor(self, swizzling_functor): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + if swizzling_functor == swizzle.ThreadblockSwizzleStreamK: + if self.op_class == cutlass.OpcodeClass.Simt: + raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') + + if self.current_cc == 90: + raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90') + self._swizzling_functor = swizzling_functor + + def _valid_tile_description(self, td: TileDescription) -> tuple: + """ + Checks whether the provided tile description is valid for the given compute capability. At present, + this checks the following: + + - Does the tile description use a number of stages supported by the compute capability in question? + - Does the tile size requested fit within shared memory? + - Are cluster dimensions outside the valid range requested for a given architecture (e.g., + more non-unit cluster dimensions for pre-SM90 architectures)? + - Is the kernel schedule being used supported on the architecture in question? + + :param td: tile description to validate + :type td: cutlass.backend.TileDescription + :return: tuple in which the first element is a bool indicating that the tile description is valid + and the second element is a string providing an optional error message. + :rtype: tuple + """ + # Check stage count based on the CC to which we are compiling (self.cc), rather + # than the CC from which we find kernels (self.current_cc) + valid, msg = check.valid_stage_count(self.cc, td) + if not valid: + return (valid, msg) + + valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) + if not valid: + return (valid, msg) + + valid, msg = check.valid_kernel_schedule(self.current_cc, td.kernel_schedule) + return valid, msg + + def tile_descriptions(self) -> list: + """ + Returns a list of valid tile descriptions for the operations + + :returns: list of valid tile descriptions for the operations + :rtype: list + """ + return [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations] + + def construct( + self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal: + """ + Constructs a ``cutlass.backend.GemmUniversalOperation`` based on the input parameters and current + kernel specification of the ``Gemm`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + + :return: operation that was constructed + :rtype: cutlass.backend.GemmOperationUniversal + """ + alignment_pref_A = min(128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments)) + alignment_pref_B = min(128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments)) + alignment_pref_C = min(128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments)) + alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A) + alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B) + alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C) + + self._reset_epilogue_functor_alignment(alignment_C) + + tensor_A = TensorDescription( + datatypes.binding_type(self._element_a), + datatypes.binding_layout(self._layout_a), + alignment_A + ) + tensor_B = TensorDescription( + datatypes.binding_type(self._element_b), + datatypes.binding_layout(self._layout_b), + alignment_B + ) + tensor_C = TensorDescription( + datatypes.binding_type(self._element_c), + datatypes.binding_layout(self._layout_c), + alignment_C + ) + + if tile_description is None: + op = self.possible_operations.operations(alignment_A)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self.tile_description = tile_description + + operation = GemmOperationUniversal( + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + epilogue_functor=self.epilogue_functor, + swizzling_functor=self._swizzling_functor, + ) + + return operation + + def compile(self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + print_module: bool = False) -> cutlass.backend.GemmOperationUniversal: + """ + Emits and compiles the kernel currently specified. If ``tile_description`` and any + of the ``alignment`` parameters are set, the kernel will be chosen using this + tile description and alignments. Otherwise, a default tile description and alignment + will be used. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param print_module: whether to print the emitted C++ code + :type print_module: bool + + :return: operation that was compiled + :rtype: cutlass.backend.GemmOperationUniversal + """ + self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C) + + if print_module: + print(self.operation.rt_module.emit()) + + compiler.add_module([self.operation,]) + return self.operation + + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): + """ + Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception + is raised if it does not. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + """ + dtype, layout = datatypes.get_datatype_and_layout(tensor) + if dtype != ref_type or layout != ref_layout: + raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) ' + f'does not match the expected type and ' + f'layout of ({ref_type}, {ref_layout}).') + + def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): + """ + Verifies the following properties: + 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) + 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions + set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + + :return: valid tensor object to use + :rtype: numpy/cupy/torch array/tensor object + """ + if tensor is None: + if ref_tensor is None: + raise Exception(f"Tensor {name} must be set.") + return ref_tensor + + self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name) + return tensor + + def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): + """ + Verifies the following properties: + 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``) + 2) If ``scalar`` is not ``None``, its datatype must match matches the current version + set by the plan (i.e., those in ``ref_dtype``) + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned. + + :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type scalar: numpy/cupy/torch scalar + :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_scalar: numpy/cupy/torch scalar + :param ref_dtype: data type for the scalar that this object was initialized to + :param name: identifier of the scalar to verify. Used in raising exceptions + :type name: str + + :return: valid scalar to use + :rtype: numpy/cupy/torch scalar + """ + if scalar is None: + if ref_scalar is None: + raise Exception(f"Scalar {name} must be set.") + return ref_scalar + dtype = datatypes.library_type(scalar.dtype) + if dtype != ref_dtype: + raise Exception( + f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}." + ) + return scalar + + def run(self, A=None, B=None, C=None, D=None, + alpha=None, beta=None, batch_count: int = 1, + sync: bool = True, print_module: bool = False) -> GemmArguments: + """ + Runs the kernel currently specified. If it has not already been, the kernel is emitted and + compiled. Tensors holding operands and outputs of the kernel are sourced either from the + ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` + parameters provided in this call, or from those + passed in on the construction of this object -- one of the two must be specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param batch_count: number of GEMMs in the batch + :type batch_count: int + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + + :return: arguments passed in to the kernel + :rtype: cutlass.backend.GemmArguments + """ + if batch_count < 1: + raise Exception(f"Invalid batch count {batch_count}. Value must be an integer >= 1.") + + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") + B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") + C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") + D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a) + alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b) + alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c) + self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, print_module=print_module) + + problem_size = cutlass_bindings.gemm.GemmCoord(A.shape[0], B.shape[1], A.shape[1]) + + if batch_count == 1: + mode = cutlass_bindings.gemm.Mode.Gemm + kwargs = {'split_k_slices': 1} + else: + mode = cutlass_bindings.gemm.Mode.Batched + kwargs = {'batch': batch_count} + + arguments = GemmArguments( + operation=self.operation, problem_size=problem_size, + A=A, B=B, C=C, D=D, + output_op=self.operation.epilogue_type(alpha, beta), + gemm_mode=mode, + **kwargs + ) + + self.operation.run(arguments) + + if sync: + arguments.sync() + + return arguments diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass/op/gemm_grouped.py new file mode 100644 index 00000000..b8261fc1 --- /dev/null +++ b/python/cutlass/op/gemm_grouped.py @@ -0,0 +1,270 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running GEMMs. + + The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run + grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS grouped GEMMs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects + plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor) + plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1]) +""" + +import cutlass_bindings + +from cutlass.backend.gemm_operation import ( + GemmGroupedArguments, + GemmOperationGrouped, +) +from cutlass.backend.library import ( + DataTypeSize, + SchedulerMode, + TensorDescription, + TileDescription, +) +from cutlass.op.gemm import Gemm +from cutlass.utils import check, datatypes + + +class GroupedGemm(Gemm): + """ + Constructs a ``GroupedGemm`` object. + + The data types and layouts of operands A, B, and C, along with the data type of output D + and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime -- + these are not to be changed after a ``GroupedGemm`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. Please see the constructor + for ``Gemm`` for examples of these. + + :param cc: compute capability of device to generate kernels for + :type cc: int + :param A: tensor representing data type and layout of operands A + :param B: tensor representing data type and layout of operands B + :param C: tensor representing data type and layout of operands C + :param D: tensor representing data type and layout of operands D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass.DataType + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass.DataType + :param layout: generic layout type to be used for operands A, B, C, and D + :type layout: cutlass.LayoutType + :param element_A: data type to be used for operand A + :type element_A: cutlass.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass.DataType + :type layout_A: layout of operand A + :param layout_A: cutlass.LayoutType + :type layout_B: layout of operand B + :param layout_B: cutlass.LayoutType + :type layout_C: layout of operand C + :param layout_C: cutlass.LayoutType + :type layout_D: layout of operand D + :param layout_D: cutlass.LayoutType + """ + + def __init__( + self, A=None, B=None, C=None, D=None, + alpha=1.0, beta=0.0, element_accumulator=None, + element=None, layout=None, + element_A=None, element_B=None, element_C=None, element_D=None, + layout_A=None, layout_B=None, layout_C=None, + cc: int = None, + ): + super().__init__( + A=A, B=B, C=C, D=D, + alpha=alpha, beta=beta, + element_accumulator=element_accumulator, + element=element, layout=layout, + element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_D, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + cc=cc + ) + + # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80 + if self.current_cc == 90: + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + + self.name = "grouped_gemm" + + @Gemm.swizzling_functor.setter + def swizzling_functor(self, swizzling_functor): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + raise Exception('Grouped GEMM does not currently support different swizzling functors') + + def construct(self, tile_description: TileDescription = None, + alignment_A: int = None, + alignment_B: int = None, + alignment_C: int = None) -> GemmOperationGrouped: + """ + Constructs a ``cutlass.backend.GemmOperationGrouped`` based on the input parameters and current + kernel specification of the ``Gemm`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + + :return: operation that was constructed + :rtype: cutlass.backend.GemmOperationGrouped + """ + alignment_preference = max(self.possible_operations.alignments) + alignment_A = check.alignment_or_default(alignment_A, alignment_preference) + alignment_B = check.alignment_or_default(alignment_B, alignment_preference) + alignment_C = check.alignment_or_default(alignment_C, alignment_preference) + + self._reset_epilogue_functor_alignment(alignment_C) + + tensor_A = TensorDescription( + datatypes.binding_type(self._element_a), + datatypes.binding_layout(self._layout_a), + alignment_A + ) + tensor_B = TensorDescription( + datatypes.binding_type(self._element_b), + datatypes.binding_layout(self._layout_b), + alignment_B + ) + tensor_C = TensorDescription( + datatypes.binding_type(self._element_c), + datatypes.binding_layout(self._layout_c), + alignment_C + ) + + if tile_description is None: + op = self.possible_operations.operations(alignment_A)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self.tile_description = tile_description + + operation = GemmOperationGrouped( + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + epilogue_functor=self.epilogue_functor, + swizzling_functor=self._swizzling_functor, + precompute_mode=SchedulerMode.Device) + + return operation + + def run(self, A, B, C, D, + alpha=None, beta=None, sync: bool = True, + print_module: bool = False) -> GemmGroupedArguments: + """ + Runs the kernel currently specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: list of tensors representing data type and layout of operand A + :type A: list + :param B: list of tensors representing data type and layout of operand B + :type B: list + :param C: list of tensors representing data type and layout of operand C + :type C: list + :param D: list of tensors representing data type and layout of operand D + :type D: list + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + + :return: arguments passed in to the kernel + :rtype: cutlass.backend.GemmGroupedArguments + """ + if len(A) != len(B) or len(A) != len(C) or len(A) != len(D): + raise Exception("Lengths of A, B, C, and D lists must be equal") + + problem_sizes = [] + As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4)) + for i in range(len(A)): + As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A") + Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B") + Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C") + Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D") + problem_sizes.append(cutlass_bindings.gemm.GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1])) + + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a) for A in As)) + alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b) for B in Bs)) + alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c) for C in Cs)) + self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, print_module=print_module) + + arguments = GemmGroupedArguments( + operation=self.operation, + problem_sizes=problem_sizes, + A=As, B=Bs, C=Cs, D=Ds, + output_op=self.operation.epilogue_type(alpha, beta) + ) + + self.operation.run(arguments) + + if sync: + arguments.sync() + + return arguments diff --git a/python/cutlass/op/op.py b/python/cutlass/op/op.py new file mode 100644 index 00000000..cb76b3ed --- /dev/null +++ b/python/cutlass/op/op.py @@ -0,0 +1,116 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) +""" + +from bisect import bisect_left + +from cutlass import option_registry +from cutlass.backend.utils.device import device_cc +from cutlass.epilogue import get_activations +from cutlass.library_defaults import _generator_ccs +from cutlass.swizzle import get_swizzling_functors + + +class OperationBase: + """ + Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) + """ + + def __init__(self, cc: int = None, kernel_cc: int = None): + """ + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + """ + self.cc = cc if cc is not None else device_cc() + self.specified_kernel_cc = kernel_cc is not None + self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc) + self.tile_description = None + + self.options = option_registry.options_for_cc(self.current_cc) + + if self.options is None: + raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}") + + def _find_closest_cc(self, cc: int) -> int: + """ + Returns the closest CC in _generator_ccs less than or equal to `cc` + + :param cc: compute capability to query + :type cc: int + + :returns: closest CC in _generator_ccs less than or equal to `cc` + :rtype: int + """ + if cc in _generator_ccs: + return cc + + # Find closest CC lower than this CC + idx = bisect_left(_generator_ccs, cc) + if idx == 0: + raise Exception(f'No valid CC to fall back to for {cc}') + return _generator_ccs[idx-1] + + def activations(self) -> list: + """ + Returns possible activation functions that can be used + + :return: list of activation functions that can be used + :rtype: list + """ + return get_activations() + + def swizzling_functors(self) -> list: + """ + Returns possible swizzling functions that can be used + + :return: list of swizzling functions that can be used + :rtype: list + """ + return get_swizzling_functors() + + def _reset_options(self, cc: int): + """ + Resets the kernel options based on cc + + :param cc: compute capability to reset to + :type cc: int + """ + if cc != self.current_cc: + if cc not in _generator_ccs: + raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.') + self.current_cc = cc + self.options = option_registry.options_for_cc(self.current_cc) diff --git a/tools/library/scripts/pycutlass/docs/Makefile b/python/cutlass/swizzle.py similarity index 60% rename from tools/library/scripts/pycutlass/docs/Makefile rename to python/cutlass/swizzle.py index 9581604b..479fafdb 100644 --- a/tools/library/scripts/pycutlass/docs/Makefile +++ b/python/cutlass/swizzle.py @@ -1,6 +1,6 @@ ################################################################################################# # -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -30,23 +30,37 @@ # ################################################################################################# -# Minimal makefile for Sphinx documentation -# +""" +Registry of swizzling functions +""" + +import cutlass_bindings + +IdentitySwizzle1 = cutlass_bindings.IdentitySwizzle1 +IdentitySwizzle2 = cutlass_bindings.IdentitySwizzle2 +IdentitySwizzle4 = cutlass_bindings.IdentitySwizzle4 +IdentitySwizzle8 = cutlass_bindings.IdentitySwizzle8 +HorizontalSwizzle = cutlass_bindings.HorizontalSwizzle +BatchedIdentitySwizzle = cutlass_bindings.BatchedIdentitySwizzle +ThreadblockSwizzleStreamK = cutlass_bindings.ThreadblockSwizzleStreamK +StridedDgradIdentitySwizzle1 = cutlass_bindings.StridedDgradIdentitySwizzle1 +StridedDgradIdentitySwizzle4 = cutlass_bindings.StridedDgradIdentitySwizzle4 +StridedDgradHorizontalSwizzle = cutlass_bindings.StridedDgradHorizontalSwizzle -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = source -BUILDDIR = build -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +_swizzling_functors = [ + IdentitySwizzle1, + IdentitySwizzle2, + IdentitySwizzle4, + IdentitySwizzle8, + HorizontalSwizzle, + BatchedIdentitySwizzle, + ThreadblockSwizzleStreamK, + StridedDgradIdentitySwizzle1, + StridedDgradIdentitySwizzle4, + StridedDgradHorizontalSwizzle, +] -.PHONY: help Makefile -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +def get_swizzling_functors(): + return _swizzling_functors diff --git a/python/cutlass/utils/__init__.py b/python/cutlass/utils/__init__.py new file mode 100644 index 00000000..27c11413 --- /dev/null +++ b/python/cutlass/utils/__init__.py @@ -0,0 +1,40 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass.utils.check import ( + alignment_or_default, + calculate_smem_usage, + calculate_smem_usage_per_stage, + valid_cluster_shape, + valid_kernel_schedule, + valid_stage_count, +) diff --git a/python/cutlass/utils/check.py b/python/cutlass/utils/check.py new file mode 100644 index 00000000..3cd4dd1d --- /dev/null +++ b/python/cutlass/utils/check.py @@ -0,0 +1,192 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for checking constraints on kernels and calculating kernel attributes +""" + +import ctypes + +import cutlass_bindings +import cutlass +from cutlass.backend.library import DataTypeSize, TileDescription + + +def calculate_smem_usage_per_stage(tile_description, operation_kind): + """ + Returns the amount of shared memory in bytes consumed in a single stage of a kernel. + + :return: number of bytes of shared memory consumed by a single stage + :rtype: int + """ + m, n, k = tile_description.threadblock_shape + + if operation_kind == cutlass.OperationKind.Gemm: + stage_barrier_bytes = 32 + return ( + (DataTypeSize[tile_description.math_instruction.element_a] * m * k // 8) + + (DataTypeSize[tile_description.math_instruction.element_b] * k * n // 8) + + stage_barrier_bytes + ) + else: + raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}") + + +def calculate_smem_usage(operation): + """ + Returns the amount of shared memory in bytes consumed by a kernel. + + :return: number of bytes of shared memory consumed by the operation + :return: int + """ + _per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind) + return _per_stage * operation.tile_description.stages + + +def valid_stage_count(cc: int, td: TileDescription) -> tuple: + """ + Checks whether a device with `cc` supports the number of stages within `tile_description`, both + based on raw limits on the number of stages and based on shared memory capacity + + :param cc: compute capability of device in question + :type cc: int + :param td: tile description to check + :type td: TileDescription + + :return: tuple with the first element indicating whether the provided tile description is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + if cc == 90 and (td.stages is None or td.stages == 0): + # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically + # determines the stage count to use. Thus, all settings are valid in these scenarios. + return (True, "") + + if td.stages <= 0: + return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.") + + if cc < 80 and td.stages != 2: + return (False, f"Tile description has stage count of {td.stages}, " + f"but only 2 stages are supported on SM{cc}.") + + smem_per_stage = calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm) + smem_arch = cutlass.SharedMemPerCC[cc] << 10 + if (smem_per_stage * td.stages) > smem_arch: + return ( False, + "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n" + f"Details: configuration uses {smem_per_stage} bytes of shared memory per stage, and " + f"{td.stages} stages for a total of {smem_per_stage * td.stages} bytes.\n" + f"The maxmium amoung of shared memory that can be used per block on CC {cc} is {smem_arch}.") + + return (True, "") + + +def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: + """ + Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`. + + :param cc: compute capability of device in question + :type cc: int + :param cluster_shape: dimensions of thread block cluster shape to check + :type cluster_shape: list + + :return: tuple with the first element indicating whether the provided cluster shape is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + + if cc < 90: + if cluster_shape != [1, 1, 1]: + return (False, + f"Cluster shape for pre-SM90 architectures must be [1, 1, 1]. Received cluster shape of " + f"{cluster_shape} for SM{cc}.") + else: + return (True, "") + + if len(cluster_shape) != 3: + return (False, + f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}") + + if cluster_shape[2] != 1: + return (False, + "CUTLASS kernels currently require the third dimension of cluster shape to be 1. " + f"Received cluster shape of {cluster_shape}.") + + # The CUDA programming guide currently defines a maximum of 8 thread blocks per cluster + # as being portably supported (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#thread-block-clusters). + # Current CUTLASS kernels only have non-unit cluster dimensions within the first two dimensions, + # so we check that the first two dimensions of the cluster shape do not exceed 8 thread blocks in total. + blocks_in_2d = cluster_shape[0] * cluster_shape[1] + if blocks_in_2d > 8: + return (False, + f"Thread block clusters with more than 8 thread blocks are currently unsupported on SM{cc}. " + f"Received cluster shape {cluster_shape}, which has {blocks_in_2d} thread blocks.") + return (True, "") + + +def valid_kernel_schedule(cc: int, kernel_schedule: cutlass.KernelScheduleType) -> tuple: + """ + Checks whether a device with ``cc`` supports ``kernel_schedule``. + + :param cc: compute capability of device in question + :type cc: int + :param kernel_schedule: kernel schedule type + :type KernelScheduleType: cutlass.KernelScheduleType + + :return: tuple with the first element indicating whether the provided kernel schedule is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + if kernel_schedule != cutlass.KernelScheduleType.ScheduleAuto and cc < 90: + return (False, "Non-default kernel schedules are only supported on SM90 and beyond") + return (True, "") + + +def alignment_or_default(alignment_provided: int, default_alignment: int) -> int: + """ + Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks + that `alignment_provided` does not exceed `default_alignment`. + + :param alignment_provided: alignment preference specified. Can be None. + :type alignment_provided: int + :param default_alignment: alignment to use if `alignment_provided` is None + :type default_alignment: int + + :return: alignment to use + :rtype: int + """ + if alignment_provided is not None: + if alignment_provided > default_alignment: + raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") + return alignment_provided + + return default_alignment diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass/utils/datatypes.py new file mode 100644 index 00000000..98984e3b --- /dev/null +++ b/python/cutlass/utils/datatypes.py @@ -0,0 +1,339 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for converting between frontend datatypes and CUTLASS datatypes +""" + +import cutlass_bindings + +import cutlass +from cutlass.backend.library import ( + DataTypeSize, + MathInstruction, + MathOperation, + ShortLayoutTypeNames, + TileDescription, +) + +try: + import numpy as np + + numpy_available = True + _library_to_numpy_dict = { + cutlass.DataType.f16: np.float16, + cutlass.DataType.f32: np.float32, + cutlass.DataType.f64: np.float64, + cutlass.DataType.s8: np.int8, + cutlass.DataType.s32: np.int32, + } +except ImportError: + numpy_available = False + _library_to_numpy_dict = {} + + +def numpy_library_type(inp) -> cutlass.DataType: + if numpy_available: + if inp == np.float16: + return cutlass.DataType.f16 + elif inp == np.float32: + return cutlass.DataType.f32 + elif inp == np.float64: + return cutlass.DataType.f64 + elif inp == np.int8: + return cutlass.DataType.s8 + elif inp == np.int32: + return cutlass.DataType.s32 + return None + + +def numpy_type(inp): + return _library_to_numpy_dict.get(inp, None) + + +try: + import cupy as cp + + cupy_available = True + _library_to_cupy_dict = { + cutlass.DataType.f16: cp.float16, + cutlass.DataType.f32: cp.float32, + cutlass.DataType.f64: cp.float64, + cutlass.DataType.s8: cp.int8, + cutlass.DataType.s32: cp.int32, + } +except ImportError: + cupy_available = False + _library_to_cupy_dict = {} + + +def cupy_library_type(inp) -> cutlass.DataType: + if cupy_available: + if inp == cp.float16: + return cutlass.DataType.f16 + elif inp == cp.float32: + return cutlass.DataType.f32 + elif inp == cp.float64: + return cutlass.DataType.f64 + return None + + +def cupy_type(inp): + return _library_to_cupy_dict.get(inp, None) + + +try: + import torch + + torch_available = True + _torch_to_library_dict = { + torch.half: cutlass.DataType.f16, + torch.float16: cutlass.DataType.f16, + torch.float: cutlass.DataType.f32, + torch.float32: cutlass.DataType.f32, + torch.double: cutlass.DataType.f64, + torch.float64: cutlass.DataType.f64, + } + + _library_to_torch_dict = { + cutlass.DataType.f16: torch.half, + cutlass.DataType.f16: torch.float16, + cutlass.DataType.f32: torch.float, + cutlass.DataType.f32: torch.float32, + cutlass.DataType.f64: torch.double, + cutlass.DataType.f64: torch.float64, + } +except ImportError: + torch_available = False + _torch_to_library_dict = {} + _library_to_torch_dict = {} + + +def torch_library_type(inp) -> cutlass.DataType: + return _torch_to_library_dict.get(inp, None) + + +def torch_type(inp): + return _library_to_torch_dict.get(inp, None) + + +try: + import bfloat16 + + bfloat16_available = True +except ImportError: + bfloat16_available = False + + +def bfloat16_library_type(inp) -> cutlass.DataType: + if bfloat16_available: + if inp == bfloat16.bfloat16: + return cutlass.DataType.bf16 + + +def bfloat16_type(inp) -> bfloat16.bfloat16: + if bfloat16_available: + if inp == cutlass.DataType.bf16: + return bfloat16.bfloat16 + + +# Mapping from library data type to Python-bound CUTLASS data type +library_to_binding_dict = { + cutlass.DataType.s8: cutlass_bindings.int8, + cutlass.DataType.s32: cutlass_bindings.int32, + cutlass.DataType.f16: cutlass_bindings.float16, + cutlass.DataType.bf16: cutlass_bindings.bfloat16, + cutlass.DataType.f32: cutlass_bindings.float32, + cutlass.DataType.f64: cutlass_bindings.float64, + cutlass.DataType.tf32: cutlass_bindings.tfloat32, +} + +# Mapping from Python-bound CUTLASS data type to library data type +binding_to_library = { + cutlass_bindings.int8: cutlass.DataType.s8, + cutlass_bindings.int32: cutlass.DataType.s32, + cutlass_bindings.float16: cutlass.DataType.f16, + cutlass_bindings.bfloat16: cutlass.DataType.bf16, + cutlass_bindings.float32: cutlass.DataType.f32, + cutlass_bindings.float64: cutlass.DataType.f64, + cutlass_bindings.tfloat32: cutlass.DataType.tf32, +} + + +def binding_library_type(inp): + if inp in binding_to_library: + return binding_to_library[inp] + return None + + +def has_binding_type(inp: cutlass.DataType): + return inp in library_to_binding_dict + + +def library_to_binding(inp: cutlass.DataType): + if not has_binding_type(inp): + raise Exception(f"No available conversion from library type {inp} to Python-bound CUTLASS type") + return library_to_binding_dict[inp] + + +def library_type(inp): + if inp in cutlass.DataTypeSize.keys(): + return inp + + for cvt_fn in [ + bfloat16_library_type, + cupy_library_type, + numpy_library_type, + torch_library_type, + binding_library_type, + ]: + out = cvt_fn(inp) + if out is not None: + return out + + raise Exception(f"No available conversion from type {inp} to a library type.") + + +def library_layout(layout): + if layout in cutlass.LayoutTag.keys(): + return layout + + # Convert Python-bound CUTLASS layout to profiler library layout + if layout == cutlass_bindings.RowMajor: + return cutlass.LayoutType.RowMajor + elif layout == cutlass_bindings.ColumnMajor: + return cutlass.LayoutType.ColumnMajor + else: + raise Exception(f"No conversion available for layout {layout} to library layout.") + + +def binding_type(inp): + if inp in DataTypeSize.keys(): + return inp + + libtype = library_type(inp) + return library_to_binding(libtype) + + +def binding_layout(layout): + if layout in ShortLayoutTypeNames.keys(): + return layout + elif layout == cutlass.LayoutType.RowMajor: + return cutlass_bindings.RowMajor + elif layout == cutlass.LayoutType.ColumnMajor: + return cutlass_bindings.ColumnMajor + else: + raise Exception(f"No conversion available for layout {layout} to Python-bound CUTLASS layout.") + + +def _tensor_from_numpy(np_tensor): + dtype = library_type(np_tensor.dtype) + if np_tensor.flags.c_contiguous: + layout = cutlass.LayoutType.RowMajor + elif np_tensor.flags.f_contiguous: + layout = cutlass.LayoutType.ColumnMajor + return (dtype, layout) + + +def _tensor_from_torch(pt_tensor): + dtype = library_type(pt_tensor.dtype) + return (dtype, cutlass.LayoutType.RowMajor) + + +def get_datatype_and_layout(tensor): + if (numpy_available and isinstance(tensor, np.ndarray)) or ( + cupy_available and isinstance(tensor, cp.ndarray) + ): + return _tensor_from_numpy(tensor) + elif torch_available and isinstance(tensor, torch.Tensor): + return _tensor_from_torch(tensor) + else: + raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") + + +def binding_opclass(opclass: cutlass.OpcodeClass): + if opclass == cutlass.OpcodeClass.TensorOp: + return cutlass_bindings.OpClass.TensorOp + elif opclass == cutlass.OpcodeClass.Simt: + return cutlass_bindings.OpClass.Simt + else: + raise Exception(f"Unable to convert opcode class of type {opclass} to Python-bound CUTLASS opcode class.") + + +_math_operation_value_map = {x.value: x for x in MathOperation} + + +def backend_math_operation(math_op: cutlass.MathOperation): + if math_op.value not in _math_operation_value_map.keys(): + raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.") + return _math_operation_value_map[math_op.value] + + +def construct_backend_td(td: cutlass.TileDescription, + kernel_schedule: cutlass.KernelScheduleType) -> TileDescription: + mi = td.math_instruction + backend_mi = MathInstruction( + mi.instruction_shape, + binding_type(mi.element_a), + binding_type(mi.element_b), + binding_type(mi.element_accumulator), + binding_opclass(mi.opcode_class), + backend_math_operation(mi.math_operation) + ) + return TileDescription(td.threadblock_shape, td.stages, td.warp_count, + backend_mi, td.cluster_shape, kernel_schedule) + + +def td_from_profiler_op(op) -> TileDescription: + """ + Converts the profiler's TileDescription in ``op`` into the backend TileDescription + + :param op: profiler Operation + + :returns: backend TileDescription + :rtype: cutlass.backend.TileDescription + """ + schedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None + return construct_backend_td(op.tile_description, schedule) + + +def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription: + """ + Converts the profiler's TileDescription into the backend TileDescription + + :param td: profiler TileDescription + :type td: cutlass.TileDescription + + :returns: backend TileDescription + :rtype: cutlass.backend.TileDescription + """ + return construct_backend_td(td, kernel_schedule=None) diff --git a/tools/library/scripts/pycutlass/docker/Dockerfile-cuda11.8-pytorch b/python/docker/Dockerfile-cuda11.8-pytorch similarity index 96% rename from tools/library/scripts/pycutlass/docker/Dockerfile-cuda11.8-pytorch rename to python/docker/Dockerfile-cuda11.8-pytorch index c36e0e2e..c573dfe7 100644 --- a/tools/library/scripts/pycutlass/docker/Dockerfile-cuda11.8-pytorch +++ b/python/docker/Dockerfile-cuda11.8-pytorch @@ -1,6 +1,6 @@ ################################################################################################# # -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/python/docker/Dockerfile-cuda12.0-pytorch b/python/docker/Dockerfile-cuda12.0-pytorch new file mode 100644 index 00000000..a9a84bf3 --- /dev/null +++ b/python/docker/Dockerfile-cuda12.0-pytorch @@ -0,0 +1,38 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +FROM nvcr.io/nvidia/pytorch:23.01-py3 + +RUN chmod ugo+rwx /home +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH +ENV LIBRARY_PATH=/usr/local/cuda/lib64:$LIBRARY_PATH +ENV CUDA_INSTALL_PATH=/usr/local/cuda diff --git a/python/docs_src/Makefile b/python/docs_src/Makefile new file mode 100644 index 00000000..92dd33a1 --- /dev/null +++ b/python/docs_src/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/tools/library/scripts/pycutlass/docs/make.bat b/python/docs_src/make.bat similarity index 94% rename from tools/library/scripts/pycutlass/docs/make.bat rename to python/docs_src/make.bat index 061f32f9..954237b9 100644 --- a/tools/library/scripts/pycutlass/docs/make.bat +++ b/python/docs_src/make.bat @@ -7,10 +7,8 @@ REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) -set SOURCEDIR=source -set BUILDDIR=build - -if "%1" == "" goto help +set SOURCEDIR=. +set BUILDDIR=_build %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( @@ -25,6 +23,8 @@ if errorlevel 9009 ( exit /b 1 ) +if "%1" == "" goto help + %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end diff --git a/python/docs_src/source/_static/cutlass-logo-small.png b/python/docs_src/source/_static/cutlass-logo-small.png new file mode 100644 index 0000000000000000000000000000000000000000..6c2a313fde38fd6c691d9595fa6a79060236661d GIT binary patch literal 1488 zcmaJ>Yd8}M7+#k-BIIzWX`?N=h-#M6NG_QXrpIv$rOw7SmW?JxiZms9a$6=_lOY

6p^!^+*7)heKvw7}2+Cic&3B_OXtEDJ&dGl(5-t8Fy1@9GN%wY36 zWjgz#V(b-zCW_0`(xK%#(5)Q7eMKaK?4V<1~%K2N+*4!Z=obe~PHH`Anv? z&#p~xrG6tgc(0OKsvO&$Gdefa+!pYDjt1ehGVONzN?sN^8yOo{P=@^JOS#%6_fz~3 zljVeYaDJSgjjRWku}dbvzbghVkX*$pr@zzH6m)VC%m+@wBh6IrI!SIBCv!EU^ZoqyIpXo7$dzm&U68{^QA4Oy4eho0lTJ?fKnK#I zyKm$>U6@JWQsbNce)lS$fMGWA3+pCt3{P+Ww8>&(6J9!_MEF*cGnib+#%OGkD-vzurWCje>v6m z%D=DurvVXa1o*S|>YqSAhs46=no;zF7bVLm;k1>-+ZPhQmT#^H1u*7i)1gmKC@|_6 z=Tc%^T6h+rJcH1j^^3J&4gOxL;rxO+bNzNtoLm!%U)1!p5JE;bz=RrMYDj1-GLliS zFo?LBtu8u3KQGdq)-O=a`@q5z4HDMJ(>^AO#nz)z5=7>v%bTFwaSxOLR?qFIShOYE ztQ6A~l!sJw6$(LOE9>O1%OjuL=#FWgiT&(5p^v*(AUrMy`(KM)fF?R&kgG}F)zP+1 z8DD~=()1GgF0v-5PAhPMZ4%wmIxWAd7?ruc8ws5{lG^oIEzd=r_jsk!DAF=ndN%G* zqOE|}=hhz9Y499eVN&C?;5#!wcFjUGAbw3R zy2{~T_UD&%&=a@UDOvwB7JnxN%R&SZM@nLS=G=Dh`}9ck2jc#QX{)JYT5r*H=5A!P zcm1dy$s7qKJ~J20o~D3KxZDErltw#}OPAu;aS7ziCXGjRi-({v_@sjJ(QYQQC(dwP z9OY$h%95v|;D=^HfKD=_M+CuDuK>{^S40)mJ(?Lh5UdtueWQ*3KiQS%VLBsSu!2JV zQvR)>ptRg{D%u$=P&9>J>{M!8#3~73qe9pdM_oz0Ndz$rq<$CAC}|%~au7e&j}B)| zOiYx{ml{kw+;O@C7XX(&@TpKNX+elX`APk%G~F7umg?fLy<2YoXU!a0MzxFsJq|`0 u2Z_tqEC}ID;VIBj2^TSo1qEwpNP*dL2WS`X5nW|p7=St9WK(2yHs%jSv%?nv literal 0 HcmV?d00001 diff --git a/python/docs_src/source/_static/logo-dark-mode.png b/python/docs_src/source/_static/logo-dark-mode.png new file mode 100644 index 0000000000000000000000000000000000000000..6b005a283ba6b7299a08cda1d37ceac8f693f535 GIT binary patch literal 50546 zcmeFZc|6qp_dom^CRB_S$BWKTk5o3U?0!d0QgQe;a?B}=kzlZr~3 zk!;zC8IdLXexK>{{e16!_kZ`l_npUeJ?b(uXI`(fJkL4LbKcj#r;PPjS$4A^gjn^D zA2mUUJqw{NUQ7(|o8SW8CHRBc^SA{LA+c?=e=&DBw+#vyqQOIB2{I6&H^_%~y;%`j-R}=rg z1&T2(Lxi+-qhkCk2J6DE$Fv{#6GBfNq@k|g;`9#3k=BCn#o=!qbG<)0+~zks8b-{m z-~S1~#K(t{V1Vp;p7|3gV|k@kZB*sQ$I5@(Fqty(L4mn{ytVNX7f8)`d0=GK&9rt( z^z=7B`mq06g7)VU6YXoC?|fa{89;A#c>cg@XvxfiCpz=>&pfeo0NzD>{crO)*-!rH z+kqgMYya;L(eGa}2C-d`m#~Zp7P`Uw!*OKZ&insJAv7E~CHz`t)7!;+#qDyXoWTEn z8lizqTI%4wckN9CuGUfhQ1AHGKVK!#UR67o>g{fF$WMN2>x~YfKVH^er@hSQ#B#{8 z_4mNdOHrV9Ta$`rmPIzL{`(bRuldjd`+rAN_gwo;WzXo&C;N zEF^nI?ARajPzOyfF@A#;p4`4CZvEj)XgWj$j|K{{&rX%;-m(!HQds}=r_5;{+RL7A zb3N6e>e%zA%xQVrOB_d^5TADltq)3kb2+JO$^FMW2ukpV%=7iQ0gn}#s2E<6hbMGa zxW7;+H_JDHKIA-a?`is~YC#mod z9z+O|V-`T=RW3gQ_(EIg&$?e{ZF)78_U24G1`WiR^>6Yvdi|U|vj6eRs=?@3UaMBR zAo<=Hqb&~@F@=e)OC@!S@swDr>s{~05mAF=F;HCa{K77T%xOiW-$5g+<~2UE@T+50 zydoDkvcaV`$ub2zW*6rDlwRH_F2Bpu@zRH{v8eN8?u7rx4b~(?5UAP_WqrLPxuw!v z04=4qb9rv<6K6nnv0&!%3JJ}{IirE;x8V9j?&_d;N6x1tkMn+SRL`|47+C}e(fO!Y z&|TIrdub^PvgkAF_B+hmd%5A8*%9NRIENyZo?DCmV$jl6nl$EcF`{FE{K_laBu|I^ z2zd(3y53l39pCjN)$X;@w!xa(l;u9U)S(Ak(1xssGTs4lr zv3+?jLO5CkJbOz)8LfV&MIOeqU=V@tqibxJ!np8biOV9gnpClB%DF>IJSHE;JLOrC z@|;gtHTzfVReEjR1~0c4bbe%G%58)=I5}xWyeBLB$psJ-z| z-$_YugYrbh(EdlEd-@Z`vP;zWGw)<8mL>M3?bqS1)?-0sRSSOW>}b`oe2WeEV^11f zb1BIpcrsR;AoW}ST2%EA^^;0DT2&Hhn_X3R_v-mdZOtS#y(&O@J)BPlL2E;vi zTQSIgJj?(#rM`|Vtd~wa+Ksx)(;ebsmc4JN@Y6v|#@$W|Q%^P#*%Mya5%a!P!Et2Q zmm7X(GtFd-#-e=~5dTns#TkvdfX%Hz2%Tcgvh!Vd1Z_`n_I6b??VfD;s7ZQ~(~BL0 zCs z!b6Ub*1@RJx@2b?Gdh%Oa(-i`x|*kU|J1(=7-Z*APvkG(X!;2L2qz_GFsZ&Qu3p1k zo}1d#>a|EB%iYVy#Kia)DcTw7(s4LLY@TdQNPJi|B{Y?tmiqXE$NN{xN;xs*&m&dd ztMN5tl?Zyq@W(gx|D+!Xx_PLkoJvDWnQ|`%@y@%gR0z#f$N?)r{OQFxUUS5zfcw!)-?-|J;CPuBl!G*Mo zU5j>T`RVf%wA;bg<>SxZF)z=bWX&sy`=+m?%C3jDa=SEd`c;D8k z?b#EpS|4u=wNB2|o16rSxoWmb;nz}^l$WFVy@pM*EYblc5i<9JQ0Cnb0tW~;O)$A+ zI-tGKgHVh{K4LqGdD8MuGjxG9isM?v5l#S*!vWdFP<6<6wxsA_}*&2GvcbSmf0bxAijgcEF z^k&+(C?mZY)AV!o;KtBFr{y)LSuKRfUshklIXC;y!mK8#op%S6^qc~J!qcWkg zu)%+L-aw+^vY<@@s6NK2w*@nhm2luW9k-W@1R@gQnwbvulq2tFn9?ZYDlfgKc|4Ac zH8f}DQ8PMbiX?G^_a~5#kPA2S|kk&nVo^FgYo}SR)(3rH9aTLI1P!IAzra(QF``%-T zWRo3q$_C)j7ojSc001HN`x$7*8b}l$;zdM%nhNdfDZL!OaxRfLXM!tsJ+G-sAI4Yc zp4XgxBmzMqHT8lmLr^9tej@qEwX>7@%-@uZ&X|GA!O7GUOo9teA6)f9h@R^(BY}ZA zLrGTe?f91v-t!{2`YJ|dKh*++6L$7n7)qR+Q4XH8_OUGxj_f-a&GPhq#gpIn>|{rT zjL!J%OHUK-=`d$sCzFkeg=cfTwdkxa^+2`m(UG3WGvEvR!s;BkUU z(mKIH1v$s8v!$Fe&=2*)Al_(WDms;AtC}Is{j{$@fK+@(TV3yM$k;7SDZN!d%_HaV zV0I0OJ=WyvRQLB;8i7#{#5lepPJe{NK`(g zKUF0zmRWRYCqh@rP=QH$RlEzX&}GKVrHuKr+1?#$5m?Vx3+&FPyzjuB$dZDdi#!nY zH10rZOG~LkI-lgQU&4YToMXY}NK*8p9q8+&<oKqJdSU;@OqfI~T9dra*EXiViMI!Zs^N9}i3&zI30=8YwyN?NS8p$7TAy*Q(4GA8{sq;a=Yn zzjX`P|NQMCuj!sDYUN8V{Q4*2Am04P8wDdxJv2aGDf^UEM$8^X^)|V%)eIgpGyUBnYRo8b!VXr zLEJx?KYc{R+Y=Qn$~^lJ+NKS$?)wTWb&WKBXIQUe7QAVRS#o;+&Axx&ixJ^ij>+|W zsua1abkj+fS&q#je5)ZMq^NgqM+4L;5wT4B_EX^W91xZQR_^k)?x*Xc?2T%DG_=M7 z+1j{{c_idR0jhn~L|rvgENAJ!RK)try5=nW(}}Af2nErh>5M61=Ohg^MA%qA$&BpO zKxO{?pA4U!z6SVnm_mL)X_8J)c3K(<^u1TJrHPV|Y>4GIfdb^bmuyZo)BBJQuAS^# z1{QzcuPWK(mx7Ua4zgpfufh+h)Ze#4g!C~mF<*I)h_$;{d`ytCr^QwN%Jd`bCP#Z@ z%gSYwgURUkWf?S*O(s81tI(~5;N+6BDNbU}ea}H_iG~Z{j<;*rDj`aqfqTgUcAIy7 znK50eW|IY{Gh9-aV#JGHlM6ia+Y29v4Cx*ME130Z2&hK)2b$2u4G6gkTAigs3a`XZ z<>_EL?@aAd9X<^H7sm`a*&u0>BIfp(U%D}VdO44uT(YB{YSh9_OUplN&Ty=^S{-hm z!DV52gMv+`8f?K#W>l(9HRT@8VdY!LKCw;^m(HFki zx9#`U+WS;|v{n=HuN8XIB^+tNph-In4;ZphG&ue2 zu$)S0B&pb&NQT-&=|7P#UzrgCA<#F!GHQV|XM5$I(;@N?wOc}}$d#)-=Zws4ETSgR z&Qt$oy*kq9u4s)FNuD`_zV7y6-Wp~IXUJ@vM>lz=0@q*4e$MCrG8KmBTW4#HdN1zfRO&1J6ABu@?BIk=3jDT_k!$jR$qCz5XBbtZvt2Ed3Q+ z@owFFHrx0GUKifKdyu$X7f>*sT*O|Tyj33&Y`Wmvk3`~jMMgmjp2i8`)o)(;z3+`q zr6}AZ=s2er&+lvW3Y)?$F0nn>dU~g#X1ylkkwQ)IksEM` z>LR?BtbOqmQ$2;qNlQkf=?2pMTg4Lrgd|6a%l#ErdA1E2sRzmyUSHdbh-hLfnm$=8 z#e4H$Cqe?WW4wdR=#STOvaP3Rv!=_I?W?EyRHlpClH{-D>qDg1Oh|3q_p1n1IsU%= z4wnK+7sTwtwqTZcf?RICtz<%(qKBzy+m!c>fU(f@vkvD#PAteNw{@@Gei`zl#BYsA zni_zUunKsF%lex{j4lJF9TSORBla)F+3kZ$5@?eZp-47jp-kCv!$q9ZiAgEFxB5%` zx#rhnR(}cTIObCt7mq6wS!mV0rbNt!p6a0QN1el^>E{wNugHi%BHat(!a9#Qocer; z9^nr&2@r&>-R1S(YVg~AI}@k3wDtF45Nc1PQfA`}8}Xt%F2Lr^4oT7_H+_%)fFdc*K?F}CHT zB(OPW|4}iYKbYoOb*Uijh4O=5O`I0gzB+g$G8Dwr;YXMKGHe?nV*q5LiXurqTKye2 zK+eBAIQ^_!#*2d+snR+y7b~=?oO-75kSjFbd!Wu^x^G<^U7`;5Jtt4_;Gz^ za|^H-E#dm-UHQq1ZW=D_6sgw*vb1^oVa80u%8V%)fS1t9bbwEgf}{k@DV&C$P36N?cc>;VPfrlq{0a=n;>ps5L*@r;OQ&?9quLO)7^l1Gcz>~ZeR|*v zL|GK=Qn5|amwHKli3OBTdzyaVt zY^b|J&sS>HdNCcS%!44tt5f!>5){dj2))t=2Yd_;*qD9elu4|0HPnZop&eV6w!-C# zgFVs?1HX)vmqJ=}W}!_sjnLfHpLgh-q)scXZyCslx&;t0&_89Iv!hOwe(t9}^FW=| zMHWQxZfY%g`BEsww(;@4DWAHYPtQxr0YTio zu9Khcfh!)B%=vWH4;#k1T5TP7MU2S1`x`XkDENI%g17f=c5S~qTm%X|$`rrLg~$G?0sWkv?qO*d`eEy(+!CcApC~55;!l=dzLXmsFLXAJ!Bf z?BE>pKG<(I_bEN^5X6aSg2`m7LK=gN&Vmd>wY{eEgEwy=Q#gd7%;ze-Rq|`yxJVp4 zGAB*5;CZm1Wc26iyxE@59vdgi^fwy>5C5o(YHiC8BSE!qO&XNb5g?7Y+^BtCmGZ6!_~AcG{y|6QL_K@dm1QCJmHqG-q!< z=D26mCObH{oUA&2ukQ)8W&g(9HPn9Q&j8TnU5+3_b#`coJm%aSXjN#=*|BOM;7#C& znu8Jo*?P<7Jg7Fryn3E+r+!(+zkuTTR2XR=goK%UeUei3F>6(2DoBPhV#{)PYN8{5 z?C@A`rLFsIAKilC`&B(d7o{O!C!Qu8IN7T394nwdh=c_nFq&pvscl5aO!(#&a_J-N zKmZ0&Z$koWBd-UatB+e7Q#q0KCZHtpOI1?Rluq*1uQLyf+To(axOCB4TRP>GTN%loA-;V0mg}!p}h)7c?|LOb6?J0 zQvy3qUJPlA(qw-f*@;uWSkBW+;#8gLNjCb+S`Ioa|%&|@%rDJ(k8e!njAND6xK5iqN;^7{o@!osM# zSZ-Vh#Wy`I55U-G6&h7L55s*6dldT?ZXH1(Z=Kv~dxkG7U~`Oc@7l4FcqGYz}b%}j;}Rj^S_XHa}Y`}ly?_%%*^1vX_lr8 zu4oiUKMhFs6PL$+JiANPv`Hemm|dUDc>W4YYT{DoQKb|lX|L>vKlMU8aDCEHJSZ;2 z{nxKgOSI8JsX|ZqbC;BAC!fn?LvFca$5P)of!aFVY9@lYwfQ2Q$7g>vX`NY2aiW#2 zl>p?0H^lZ@9+vZ!eKKxg=S6V3yphwb68wIie>qt@&@Az?hAdBo=i0$x3qa9Qh099# zXpEWw$$%jLr6dphxH~`Dx(9noX%=7dOLp@_Z#sZ#raPS8&3&v-Qn@TXSflaLYCs+7 z%g9c@*SOMYX{##i7lF0>GnG9Lt=!otX4BP7wBpusw->2V%`Ag|&)eXRvUA{+g_juQ zPM*21ZgeIt#(YCXN`Ky}njW1}0cce$#$Vo<&`WY*e%I20n^W}P;8%T9NZ;G?+>hsI z!H|l|wXJ;MELXPq4gBcN(q7?%6l6{6r6mlvO&e2#qDi`3|8%p6AJrSfn+?U#e>2il z@vj)aA3Gm?5KvZ`PC3-kWlr@cES1^ke%4ye)fbW8^AFMie_^*>)MmnED(gr+7cXD$ zZ%5&4)4|6eWGC{+C9@{3pAzm_6#LPYQKxuL1y7^n4Oz)Z$Tr$`KLfS$$%>QXWr;D#h&r zlIbyJf>!@6gp}=`eyK{heE;tA`CqTYGAheoaiL5Tn*1T4(IcG~!C=FrCU5&U8vON- zPIutGkBDrupHuDG(h?u`k2*WJ@yBBmQIm3Z(fl7rlpHRYU3*M}OXbK}F*JH{W9TGL zS|uvk^}*%<(oTc8UV15Z!u8A*!R5d>eN^*-vKJak+sbsU=&gLzNV!1M$)+4R{!Mj% zpehI@H%q5}ggiT`&zvb$072;asVdQKr?v*}CD&QX$c@b~bzqCu!U5DWN+U1dNZ#nJ z-v3(hTcGjih)P9}8TGzg;Gy*7CVSI8SEg}}^bBk$ z%_{fIWJ*dW)W=^#ZEhDLQd4Fd9<2+er=liuZ*@_SnzGyoZK;q++_JsK7%9-!I&#$tXVl51ecZiSvl{spTIU z&$OUmU%a?8eX6u(A4XdZ=UXsdAVyd%IP~tW4`bTY_KS{4*-hYZh)sXhF*b^t+`eK<#wZcD;YO-D2fl3g)(QavUQyU%42P?7g}($Hd@E3-x0 zL+!8p?EFf>Oz)0HYHvVXA;l=tscQL<8{pCI%{{C0RTHbb=aXw#>zuySt0k4}B5vH_ zIx8MFq&0vvI?q=vs$tbp>t&Nvn4x_7ZHSX60l_S0^cj5n|gySd@|` z#{w$MqlOzoY#KCx^hkZ|vN0mq|87{j^~T~Zwp+&n8A5s*~j1}gzEIfG`U#-UW<4!{6*r8hi zyHL{9(JXvk0j9ykK!3DXGGxJT*LbO7=ShbDQD;$PE7pB+RytDOniM{*GI5bPJ8bKf zLJgfcmWE&&O5sPYG|qW0Jz!;RUmk7~?@N1xJwme8+$`;?+CtEpVqV*A;CEVJff=pb zO~^D@zPqX6T5>F1E$BM9B=mIyRZTyG_B9)>hGNN5!3>e4L0(9dGfb`*+d6y7mOuM# zcnxbrLz%6eyy3aj6U62ku$yX;64JBId5iZvO?6EG(tVDhxYGh|@49xP*`K+}rQR|7Ae?yUj?GnFaNo30tPYwt62r6Uzyg#-j)&~(VYhRd~ci%d>Yl1xd{3IBG zHBQAt%+xvMNEa&XOD7oqm`>k9P8S!;#a)?yS9jFxonko?9UAb5c)DsyR}N?)MzA4u zDs5zWS3aCN#=aV*b2Ihxm_!Ob&oyF#?TIp`$j&1q|Y-K)t30*|f9r zt2GpZU)>yo$y_*YW3L288W-cY4^s3(%296EWtkdOu`+3FR2HN!lnL7Hb3W+G+_Ohl z3ZHC4p8Udi0>;Yywg{HYJ}p8B{p4>Oaku_1S;e|VDH+?sE;!7#6N9#W2Dz6C>Ceer zB>0Mn$t}{MjM>pUszLPCh7#{q`0Q$yzdkshQN3_2plM3()$4b#wZaJaPVIp@nQ3^> zAn-nU+p6N`%LZW_Bo`-^JFfj*5!*7(+iwB+xFJVMHBA0XxYU&HdP|lEqLUl7LBu}P zSAupYf&ndkq)7(d4LFxNPBg#^4kBQA*df%avwtkGRH9_oJmbg`e^jZ}nwQDHMobR{ zwQ9FORpH;wEl^umz^brOv)$)?xX$bj5IXRtpEL^)ztqd*>PdK^;%7N_-Z{JO)s4_R zA84btf;QM{#+ZIO28Gd4SL5fo6hp#eXw55Wko*n~952`(0epAx)~gnLzpN!3_fX@< zh3p{5;z@lqDmk0oTIn_sZ*ik0ur5C`=TNEQo>R^(VeX*5gS zS86k>?Ap%6qYItONhX0zQ4c!m#7PuV=d~dXoG$&+-Eoys&(+PHw9|IiA4BNq+r@mR z_9%6Bh0)l>jlE*j&Jcr+y%{MiesogTo+_dJx_5#;@S7{^T{j;8J zzrW)Ie8=|6!-57VBtt$O*fRUOJpNC{ResI$yKmZN2(8RGSavcYLJY@IihbeLk;;}i zrBEkRYSNp`pD#;+sM?>kKL}f$n1G0DUALPTuO-hXl$;eMJ>BXDH6GABJYYN*OTt)FR@Xz z%AMW90+Xp3t|6g_yx_Miu@?2#!8-jchFIMdUOlXO*=+Js=Q#!#41nRXX@)yEcly`- zenRd1l8uac;c0B%rTb-DS~|pH+)w;I>s(QPGsEgE%NcwF#y0nQ2^Qh)m`^RUGEbdH zg2Q2@P=`mCf<*-!)ycpMlv4_B0G7G-qm+_jhup{Y=epfL4t#*_UBF!z60? zgtb!BR%S06n3~O598|60FARk$O~%{h>bx`}8Ux<$7TNE&4}*6I!pqjFvj=1s$u)ZO zEsiX-zUWa2#r7(#_B-v5pg!MzRLfU=D`fELUaq6bUj&+*h5M4pEic6NnZd7xq28Nw zVit1ZX=km`Ec;Bqq!bkEqxwVb(*{lyC^kcG3@Q}|Z$bhO*~JO&ZtO~FXUJi&Dizob>-Y5*4>nHx+8b(s}+b%*xk0e9MQ z>`2d8w{l>siiTc@aG>Nr1odf$DTU@EXEv8C>YAA*x{!7jB<#oFCaKww4?4D%mmkB=qj8 zTW>t}UsLw1qjjiWbv+ksM$az2aLn(h3DO{?tI)o0SUTGMrNK?ucVJ)8Q!HH_Dl7~powRMV=2)MenP8NMCQ?8qdQXj{jxCM7O<8<|9X*ROnM04f zG-edC`g*!(Vk(@SR{Io6gUVQ>b*yFIHF9a&`kfdL4SjI>DkW@|KVGJ}zv__`bH5eo zc7c=595?erLA7)z2e`yFp%vb!9}A~-#QEEk6NpYRD%_}Zq)IEghHe~6Dg3GYw!!2P z>G7P2Z0D<60WnJfR!;50%tTd3*^p7iuvmrpzEfTNA&oJP!JC#gpxrzbx` zz2*oTV~Y>5y?j4vnuLiP)XyCXjAhL^NITUvoFDXT!^nFE?C2P#MQg&p>M{uo*wInj zIJ9j&PV~o)S$}uzNslBc&A(8qoSq;?`kf1Vtn9m=9?2DoT^wDH>f>$ol(UZdl!dgr zAG`bt%JoqrI^%}&SlaRpFI(QNUOAg$|AHPBa)SBjfy9n3bZ5nqJ9YuFktyERc{WaP z5I5|rB@x*Y{ZCA7U-7HpbDjg0dH_XHPxu1+k5#WV32Qst8;05;nhrG=+uHTA$#Dgp zY8Rb)X*FpmY9)t3g{&a_>VC|AK8wbti0bDzDLaG#Yt2~)XLmp1I3@_Qh(OY9p_Qy* zWm^4BBqzy;^j)F4oFF|T^YkvK%z@XPLYd1xI&5{q7G?C?h<#_;#&1`AsXLTB|LRRe zFe0!*$HK2&M!o!9?9}^?#$vg_s9}Xjl54^c=dp~$3xGo78b>0i(Q(84Y@8T$^KN*seMu3>|fdKN0s6rS_tk`YraDqm(~5= zh7Ua6E9M+14Gs1apNNvf8CD`&?lC?+y%po(njO_?%Ye+2IaG-aPy;``idH%riko(n z-WDDD9ueuTE=yve=NPb#b!t#%fQnAlZ%%@Bg%Ud7@}`ysAbG$EX&SYDUpt(UHyCa& zy~;cKT6Kh>{g}#w7cL))VOEd_yHIvVK;tUbB*kz|wImo++Hd#WD%(+9)RhJNQcG{Z zHoEHPAsyxi1=d1ne8aGp^YzaSvw1AmUdIbr^_f#fRt1qiyUrY2#R(SPd&>nTO&;9R zo}AAPu_NLwNLz{`1RDWiTZalopo7K&<18d9ed~IH8$>yd+Ujy0)xvM1c-{hoo@Fz_ zn0g;CWR}_LeD+9m6^1a}v8|nTeI|d)&5xSL(o$Sf21L%1)Dbb4CRON_W!iqi##m9X zQj(I_dAjE}2Q3PHx22B*hpoh~8ik`ZvD+H6o@yqV99y@t4PLRtnC+azcVmW)J} zM=h>)va8tp!HC+&MOZZP=im2)r_X6cg!gnA#Ek*U46DqtY@)~4Yo19vZ{@J0ks~aV z`9E^6gIrO;L6U7i_PPATRWzv4Gpp$#1NoR6m_#g~LZXSh$YFIEon*o{f$(El0 zk4jtT`HK%s^F&b7ELHVaHCIwz1?v?}RSH?aHrdTadvDdkh;Y zz_+}n$Mvt1@{6lmxwKD87WMJlZTrO>K@Is_Ac4PATKiHFcyA4uoyNarH+2-vlNp&q$n!?lfkIu1rJq1maS3qwnc(T{ z(Oe`>_?&`m=}vdBo-X>e_H5hg3-)?;z^pZ)N5_<=l0XY4AM;r3jIrpM{ZuAt8 z@V@E8Fa}7sb($StnVp~oz(hkctiY@_Kj$LTGoq@y2?NUY5~GI^p}+7m$~^Y@$k|GF zs(<`pB|J#g?f_pe9nKtOOX3i!D==cgd?@)oZh_K{?47kroE_J2JLYkKB>5)aLHQ{l zb(m$g6sF-IH>_33G14JZ&T_g;Nmc%HFDM7wx*EH8%>MVBN5x6K%J>C7Iw-+N-qlv>8H%C8U*FKSdbS_cB$tD(}>0L<6FEt!e{8rv1@; zi~_BFu5AK1A#{qzDI4X|3Ta1I5y|0t7DQmDS?88Jb+x?bE6j^cAWC#m+PBUng|Y>z zE($60+E0MyLvQSX!O7~Ng*R0%Djg>2QJxBHwOhTnD4C@HCPU}u{rNcUe44cLH5n9E zh>I;5ANS6on4vV59I>Mxeg~oTHpQBm8e_=4GWhVAPh$@w zI(rd@m@HSef5WO?*sKp6bZ*xj-G&G<0GMvDIH8!F!QU zOt-@nEFFe`t1JxqR{iWcdUGmy&$Np8{F?`jBRlm`Z*X^$IXENW(dR8Db8j_6w9*r_ zYvd^yG|y3fIenHnZ>^}@eCU>Qm-3b|KZSgfP$>b zF?)&=TH~5~9)1GQ^q~EbB9(XkM_$_p14QJ+AISd}ll;A+r%W1_iW8vWo1z`g!NESr zeKdsJJLCWyh`3|sI~CFh-#F*D-Yd0fiLcEWQ05DtYFLd*e=6TA#z^n?9Ey%nRYdUf zHrqnT?xcV56{NCuF4r*9di4!jHWWX6PLb|ri6nDotEx^bH*|C73}>&1JE-h}qvd_D zmr&iiqb2?k)-InOJ-d^bX)Uut_sQ0rFVCR$Woh|)U}XhqM{{(rcSue}3nEgl>sjQ@ z%j5WAEXIUbj~kqNa$7dnVq>CR_euZ_MYW)%>};9mS8?0#X3TqVD6pM#Lojk~EY_VW zxS7t;+1u!1v03=@o?0ctant45BemwyJ6!jM;r_ZPH?bz`SSyF~OZ^+dWx)3jt*#Wjo<~ zL_J5b>Ci2OqCq_E*tVB?t&eE6VI2?$g$H}dSFjWBuQ4_-U5yF?=9jW8XgUZi!N|eZcg?7QFD2Nd(-yU07ueHy3qfPASnl0*<`U-zA*TVH0hE;E1g#Z)E4rMwfvkq5%=%Ko)x+?)l+&XB-u+EfC)|O0@G$njoRu`1{}J|=pi^)ROT74wwGLM(Zxeo^m7s( z!IlBdu*xumFP3zh>{2d%x~P<#O1b6UtI>b^tGyqM(sQLnY~0YT9Sw)J)eARWssVca z)W3Py=Gv`*y$Xk)38;QRy&a7@)Z?;}_PDAXRHV zr_EbQg_3jBfVVR>Sf4HAMKMAxseB=13oyCY36oPQYaQIlvH4BHhB(kA`YI`*Q=(vK zlIp%>z2!fW8jC9mrK5B9z~xVy(0vLI;rOaCx~hwM%VE^&d~bbO)8E79!xy7%xLt^N z4ra`DqMxstT~2vy!JQYYIy7_!5tJZ6*&Z9tFW`^FGcn%UXlPK-Fcke1o@}{Sr}qeL zUiPtYtWz-`^46vG0xWO$*AxDMhkiKA&?ECVu%%q6nD+1Uq{IQzW0#MYzdpD~+Jy!f zp^8#vV)AfZ;v2(1U9W0teyv)$piZ2X%B|AE%>8UE^AoFb2t2OF&b zu*(&^&}FCohc+^qr?dBnfhYQdXXHYA*X@C2+NBrS#J17E?t;I5@gE_|FL3hcWo>{>mK*zNN^{U9R3#%1OnwtwAR zmzq(sz?!KEU7+A_jdep2b{tFq;?_Hj8=x8`)=pI#AWquS=O=rZgWr4nVfPk7iI`Y! zQZZ@DDfTOG&|~<&kApoEIXL48{8#c7-F4w&`V9iVuRjv*xsn=CNPvIu+(;=A3YG1)tl_n+-c~`gcwc!h*n|8wZ z1J+#EJzouW8ViH}Yxhn_*$1CdW2CG5MPMhud$O?}oj?qO9?N!Zt5hvw_8t$Tx+0oZ9^xt9k$*Xh69yVD(nQ0FM8W z0z8Xh{v4f>=V5Pe+7fVCF#>RiNChAH;c(Udhz1^G^ZGg0+HH3#FU99`uLKq!zn4(^ zbE%fEACA)a<;4@h`mY&}-A82ra#d2ljHp$lVH@`sKtIwf6@j&qeSoElnd4?FK*ffh4&J4@YI1RK~l1X2^*p*0rz-IIJVXeH)VwDhX(%)WM2o$4qzV zpSG-S3c>_g1v*;7H|ar1v~>Luw(Xb3rq9B|?qMr=B)1>*Rmsfk0R$PwANbnug}avY zvg6r5U1QY1!4Qv2yWjWbZ%8}qPD6^yf%Vlie5{9W{zoYKkTz+;;%jr?TggKT4b8wZ zHPb@$o;x)aCzW=C?IwG!&{9n#ZVNdc&}^Oe_u?_RLG2~itAd0fjR!W%)=*y&IF+l; z$p&p7|6o9VK$|5UgFUlTT6~J#Ev(i`_+fgHA#b(GFWU!ufGXvKT#MG=DIe=;nGH=g ze3|%f18U$Z?0WIjHi@%?KlMlVA-khAmudx3R7?d-v__t z#CpZNnWueNDJpgIe-PnF7d+3CZIyCPHa1RC4jyj1_D@$31B!y@(8w|4D&FNY?uLr8SGaZya`ayv zvRw0-O3%66z5@=Yp+yHXMYR6jfUG`64a+gOQ5$d)@PcDbs^UI|1Q`3e(aJ zT4&u1h4gPvKKHeG#J|elln%#(v|j+s?at;~8Rb8YYYCshuxGHf+d!>l=Wc{0D+$rc z4>@d5jy5a*VMwo;exFC$0KgZp>#6fJlw^XBfBQbr?TbIs!P>h-c*ph6CN+$*CzSWu zzi35~bnHKk2B-g%NrJn=ru|gDJex3N26C~9fyuRJ75pArG${FGKfO88+w488xRZh0 z^tC!UE3xmm-ou=kh)o)44)`-3}3HBC@h3{GtIz{wZ_^vux!*1p8j z$LwDVYShL~JB8Yt>@iCZXJ?^>b$s9Wg8t~}DSGNCYo=~Zb;k__2qz1!C6{OXdO&qB#;#i_|( zOqnDl65lvTIqI)x01x?QO23i2`5QnR?{Mu;UgSDCa_TFC3U&u%60n?LwdcKfh@Rvb z43C4&K0hz8SfQHkfp+NcW}P_1>|*g5V^@s`eYkQ8bP&Xcok3+*7R>357&NMY2syO1 z27-IrEX`&s_Q4NK!@2eCi+|jKa~v#q!bHQFrJF)BQk25a3olBy`%8!|FQv}XR;kh4 zNBi&_6l?0NqnNX|`)^ItW=h)+I~RMarX^qR*FRLqG<}?oda(R=DMJiPf0vjU@F643 zjOxDy;)r6CU{}F^yo!l!pWEH_QP@F(P*@S!r@p@Sp13kK;pRF~GVw{xbQfSP5JgpY1r? z(5BFf^-Q_}``^CYH#yJ2$u|y?{S)!z0M?+}!OCt?5nzuJ5Aa4OgKZ~R_nnPnsM5E6w%nPy_4#L+_fEwu%Hk zJfnAf1LxeU@%Vt&kjL~Bg%j$tMQfvQ&G_uuJHU7e)$Zkn*iK3$7}`egNuS8)KtcRW zEay&?ai&Ti?%>Z^KdYvyk2H47O!UVc`~>w$6@LRPoStM7HEY!P8)MD!Z2lzF>q7%x zfhhUQfLSHBMYbk#Fl+M`Y+enF3MnT|M}8nHW+0*5w~!~4s-|!M6v+?~126L+!rv#y5Y`6}7fR5wK=Z#i zvJD2a`Wml4PAd005Pq?FMUKMG1c{(g!21aJ5R1y>rp`5ciJW(40*$z}a@K*~<~T3k z(3g5u;1}gl3cpKcC!FTnZn@A-DnK24#sxs?>VPjRDB8UfoOFimR@&0K3gjrZZXp}5 z1B*d|J2DjTk(H4l&2wI4uwip9y^fQi%&i|gMxFs&0tCo1G^}tR^;g7l$DaOc@T;FV zl1mHsG$V|aPcG;UfV}Tl7Uv6RXvQC}TObkVmYbyUquJrZZRB`lqi~lBG@dRXlU9n` z0TYFj8n(X{6ps~Ou!K4oOsUZoCHF5+v9NsWC@0{P1Pe4Vpv$y$;ZAz%z}k_f$ci8v z4EcITK4o=dJzZ|^-L0Q3+(|>+le4Gz`}aeQV((%C`_V%ITW(o3+{~4u0|nH}E#`d>+h=oR z8x-VW(L z0x2bz31M5gES(X=Zt^s()s7A!`vHBWEUv1b6?3gT)|5~#SnyoX3Xw`n>shFkfGZo6 z-jRMd9@6+WIZ9U?=^}Uo^xP6M$y35uuW&juSub~}1MW#caFrSJ98iBrnMGq{o z|5)ZTl~(ugR&IxLJGyQ>r1viy1I{g1Fk_ZhzLHe&J8v^Wok!qiXZD{My!LTJkTBIc zl@2YD*#oo_>zsaq*xI9(<_TJBc+9W!eLs7p*N4QAik>^;( zup-NvT>clreqE~1MmFjvS7wPYGy`&BR-;#6Jc&z;l2LCfB0-rCq8WrHx9@gb5w%}1 z9kheJvcg_kKd#FZb%HO+(52oGe;7F!Oz%kfno?FE1-`l-CpG zA4fwjcbpB*U$Z}6J-LegIB9&a5?Xall+uTaET=UwOZxqPGdr(cW?B-^w}u^(a&#UJ zuH;dvzCnBIkuF}^zoi|lq7R2icz04E=~++|A4RfK%Ngwci2cp+W*w)w6fCy34^s7ad3FX?)tdrI%HRujPMBt2yt4c&^85hbjege5Ny38*F?RLn{IsE*;vBj!N3*iG8Wu ze5o+?tDCDU;0d`weZ{lvrc`rN+pmAoO9dU0x>CCr!I}as!}ZhM1)^KoPOH$m-C46@ z*NxLtrW-c({729;PusD~*s|EiqD?HF*=eb{7vB;;e4q+L@ILh2NNGkJ<@>2Dm$3%l zBTzZ=0iXr^?}sDyRffi96>CK9y)_&6gp{S<7YXFO#I3&OV0OQxX0sP45KF!$1K#3( zDXrd>NY|-?8E&Mjwtxzm!$_N3MMs8=U6=zxW}>Yn?CjYlB@a|^I}L#i3GWcIJK!Pw zz#K9z5O^w#A01z2^~K=&Tb4ZMPjiD0vLVZO)lGa{Y6^UHq|U-{pQY14LzNwMPNfnD zf>lHhY;Yo2Ks#umZ(W}1Nm*r(A4av~)iQ0P2Uz$;b4fj${KbTU5e-Sc2 zP;pmrdm#7oeY6)!xf+TT7By~8+7}Y{`5!z9h7V2A;`ClMq9OgCU!=X z$bOm1267{fGJ15m#}2BbOjLT3&-tRd*L_XNkUnGOi<=AcKb|~(`C!*=h~20cRc@}` zuR54+1im=L)r4xQB-h(9Hxes(>GN8gzXI=!WEVeIY2}`P$isX1an1| z(@RP5jfT%cq5_3l7C*^uur^6kA^5)FY!4fTEH_;H#ZHvXv%*eO;4PmDUKnUdwn+mc zVUFOH(4*(`sP7*A1020AeSA6q77kS!ef1*W+&DDZsNN#g&)oBN|La{f`lZh(AJywJP=m*M;Vm?P4m9Y^(TGAR*4&%pRVLu1b0^dYgdNF5BeDd|#; z6QW#&K~VlG>OmhKA)3EYQNNb9cagQ{F<+qHPz{@IoZ}GPkg0Fb4gb$nvl-Yqh2Bjk zOrI&}IIq5$587{vmC%Y*T;~P{*5yY%{o?PZE@C!v4C z{Jh10l;SB%Fntn|FP}2Cq*yvV4_}$GLhvU6f%Ur$4iP!}JoCk%8>jc}uol#13n>rl z7fxL-im*tsys~-64I=gy|Libq6z6{z!3mm8*l>TYRQY-s^RK&9k0N3VjNxwV+Tq zmwpyfbr?n}ier0a5bQ{Tn{}=f8V%j;m`GQTvXpJ7K#bA;?@W3+;ycN}wt+^f?g`Jj z9Aa*#5(TM@fwLyIBcBidM}3Lk2pg*@==P1}#} zu7>K8`j2=&tLtwYgHl_M*F#C%mdWbky^mpcHZAk}^%|(V?75_D41i;`6u?l1;Ib|{ zoEs-z*n8gAUlYtxuCMEm<2reYMb&dl_%-(d$%TnEky34xeE)VFNS6m6;lzg0)6LarlryP6R-or@z zf@sc{nqWR`ODE{aJ(?5(Gs1~r4}yl-x{DuEjft0`%NvuX&#O8QRJagyqkH?&MOs?< z(nvLhzUTGpYzQJ?w)!BjP;6UPdHiU*{BfkKzL6dGa^fVucIX6E{8Zac!FNRr0yB-o* zvo}1cexoS3j|Qo;>@Vso?kmIB=D1uE5@^$NycP5cA((@vkrkU551~UGSaqAD_-0L` z>y`qDY!zAU|GF|&ubqukG-dX-pBB!4fbcdU)`{Elfjz^(&$MuM6Z=Jg+VM$eGtbpi zOhw*3mXln9zvMJWjig@%D|7pimS*u(M{Z#LxA{u(0|E7omAbUd;N-oUt{U-e^IJRl z{&GyqzM80r0C36Gz~3c5dWO%LXl?IeB=!L(tQaELIk(eK*|Fj2VL*#tP@ax81V^#n zq9LZYzvymYF?X_UR;f(exzViidtcpJO(}jHf%3th<$G0BX>~{#ncMe-ABEu^pkmIlB9QvS_S$`XyZw;zASFT+2X6kFH1zpn zid665yy@T6vcGveqe-hO^QM>lDlmzD?ynL&R{eR$hg{X&ekoY#URVT!VSL!u`Iwa2 z;<4)MHGn6fah3Aay&@S9JpOY!s&yGT7G4?QxpgxgB#xo zb$r32e0GL%ER={Ds)M9GuwD9@A#@BoK@Hy8$A370@O|;4Jc>Qzm`|w zG9b*P8vNv2iX=qZv8z+VEvpo>Z$Kg2q@Ca1v-AqHAGw*53jB%AdT_O9xq4PiJPlXA zgKMDA``tNEa~^Y<^g~r@gjj#+;i_8j*K<-c==_n}6-7S%7Nf*IR^L1A~*bOOtoPqG}I5!LSwEwS%dk{huYAxxs4=%ED?!*Xflou+$ z&@UXvT&Bw*nyxNJmZgE1_3vRg>-(*)?U(?7b?VV2vxa7CYK^@YcUvtm;v8{~TG-<|9vYlF>{! zj1;xEigq0!wWq~8`FT?!jqKV~HpSq0y}x(cNOH;r8J7g(7#t7QrlQh^^LhQD-~3Ro z*n6DU+4!eW_)ba~7_{xr9|Jg~L$c;m0P?H3lIzjiBN`>oFOm5rMG-a$(-o+ZKBeEu zB}m0_sQ}=yX_GY24^{m7+ZP#`~`Lem%3F>pUK}xti2vX)Y`$6Fr`!gqQ zn*O6mm*|)ZMd>NE7%IF=XE(S_NHQu=I-h*9wJjG}nHA>*gg;ffHY{ZE{f>oEQXk4= zVW7?YRUKaUh!HV#yn;2WjrX8V^l`<#imjLr5Cr`QJ^DzuTJ<)?*d;o8!ewX~M#%1< z$yM!`_7A4Wv%DOiZ_g6iz?t*=x)qLVGfynWm5dOLJ|jJ?%MkL94?fu2w*fH6?63m= zOW^XCq1H&HF9>G1s3eAPdD?xx7_%Fok4t*Q5lh*t$a>@blZWGkHa>!<;S$vIJm?#< zf~r*bj*vq5&gk}i^)lgAiO_pF`j8Dm><&q3RrOFtvdnMr$}zS8XKK2Y>Zv7DNI{QO z0z*YOdS?a@u8&{z_#w)wK8_{T^F@>VKQrW)_k^ttPw!}QVTc7h9q-vHFEMp@ zZ{e7Ed%*;I?4lff@9#I)m@1lF;-^IHA5G}YoF?WILf4zP2zE>enF}r~dp4JCc-J5BXhmOL2u`q9p&k{5K0}`fSQz zteeKj9LJ+Tji7fp4dyaQ<;>y_ry*U$PrTus@$H`b6!VPn3_2`K;3L`4{Ca>NHxIeB zkuZ#O33vu0QctwW^C{`H!+BJ=R|(X5e&W-+>;7jZNlP-dW-wl zCYr2UYOXTpuvD{53>kn(pEV+fYk55M_wD@dt2U{eUsd~VkPbnn?mVf@JbJPJa%NK| zjDkWmX~PYC`59OzvN$x~qq;JENCUv~o{xmPxl5|%Osorw$A@F~=H^gh&1%#-3< zF7d)YnShSH{Os~MYv;q;=&Wsf$YL)Yx4u^l(&!M4cpEbOUQWM0r;TXNz)+;^atkYp z@x9Dq7N&U2jv*FmS)-~#QEaMrGv=|vT7fRgmAyc)pdXWeX7=X zT?VqwGfpfD@c&RqyAy3XRdS&W{PuQMcW-M+mv_X@qIpHscjs|z30-X-bu;rkJ3grf zff9b|><}eVwE;a-6=(kOE`pC|>jC{aj2k`prbZ+c*`hazTZ~`4EylPPBpP^~bf)V0 ze6yZD9iRF2X~Fa7N(jGz-Bp~T;{YB|TVhDlGiFE05s!>rqP5JJ&j0S;{yVUn23DvTB*O?_2bT@tifxn^x>&R?wPS9wy0}8{g>Ta!>oh*eRVvJ zO*ASM7(rGp@krM|_%||0LvT_*c`jGa!kfG}aO^y)zhTrVeQwy)8H(|i2;o{81Semd zJZejU_=6N<7pS`30>DcWeG}Si!>7YFSr~D6vOTK*q^w$dvm*t={bPn}4PkQ+mK5F2 zPJ^Sjm#Ibeu6gNO-ZO?LTV%gF?i^80e!k0Hq>N8ZyrnKzl83pFjfw3eh%7Ih<=iDY z8MSGV_{ghZ=0+qhlUVE4h~KNCqS#E6ZS*zO+Q*6$88U;k;U#rB&Yj?_@TNOX7Ntf$(pdawDc-)_mmoAgOUzySf<@GfA=uP zn>+TU8#QtiGLk=XsR(hKSEp_9%+rDMGML@KyNgk4X@EX&OkMogbN`2O1O3y=B^0Zy zUw>xLO)nky9SDZ|WsqlVu`yhTe4y4G@gnr}D{+*o&Yd>MBL3?= z^Z9JzAUZQuQ`EDUV?yZMPNAqc`EA52HHzFZB>@DBPBJ-K92;3>_yL;dljP2}T&Mkn zUuIQzyaqu3seXa3_rn26)roF!ixq?X|wGccU_*&st4_af3VY#$_VlMrMaGMD4`fkxs!WxwMpt1l+F?M z8%i?geUnE`C$^bEvl-1vTL7}13p2o~_UbnK+Wok}oK;7clhg8kF#CI?Y4T#43GAjT z#fe+>w&IV@vMI&L>X#}}2)28si$~L}1USD*Wuz|7iN)n7!hG&JasKIZATM3Vuk^c) zY)BVv+VD~ZA$UE>!h}MOF={_1;$4E+DieOWi-{jePXOoF9wzouznQunSbU)VHdlA^ zzI=tR?)P7515fmg7m?SMp;gA$-)5KYb)sv7x#sVJTcWbhl+vTY%C%Aji`mG-r(Rj| z`O$=r?_eqfP}kxSxWQdaT1#=q8S;Kd-WRBri(x9L;O3g}cLX$fZW?zi`5NU>$VM8R zrQZl}w#kxUAI2b@&mx!Xk?IN=>a}1&3}KiB@skId)rrLKu!4h{@ToHCQz;Lb@746< zf=#8y4(It1Vuv%FR`_MVtNmueO5Pb0R#+#yB=yK$^YG%gYN{3#X$^6RjUbGxBJ2mV zNlFsh(4osSJ8I=A;V}7^z=^f+JD1^vQT_u;I8^{Z6{b^CYswKu7MH<`!bx;T2s2G)a^(hHcQYP~eZ1QbgJ zXSUnW(yV-DV$>|UJAo&$#4(nsMpmuYGoiOyYac^=!geo#vIESvu{kwDFj(5J64Wh>c5`GoKYA zCk&q(k-=RQPYiemu{10?(BNtc{v4EWwYN2`(9a==eA!AUyKNd)#<<_pV%mu0?EuSInX%{UeLXQBkv_@D3Ptp{A+T-K1^@&AvUZT zuXjU6L&P!292#rH*tbTRQ8 zpKshjz#|fb3Hk;<(j(=p{U>r4mH0a@NqT9cPWEjzoukuw^cjtP(138}YJ_BX+VueG z3cHmQ40)i@=?pY7T7g0@g*H{4hM*NylGSF4oLXD66V5TPw}9uJ@x*)Cox3AOB(E~w z4p4mreuoy7=DzT)22}<$3k)T4gu9H`;)uHwMH$xkryPBR0*6%f8f| z7iaIuJ*FT}IM5)1v!Xv+DG?tO-L1Rib2HL{Y!1^pDtZoq;XVg5G=v) ze+ZJ~V}zBBMZ;dW-z733SYh>;#c;=yEPO>&c25l$@#VoQ7{mwe2oa9o(VQMR`H+0l zIXof$G+mP0Ue)P6N~08{vyvoiCum~8J63+|%U9Z@i;H>jkC&6HR+zffp**0>IAeqL zUe!g`<{aeCX;AE%4ZI=iw$lpWLwwbhr*9M|Eb>JK8F10K2Db57zY`8O)mU`cGXZ48 z(!j+RatLR%eMUwTU@Qdgg6{}jUnN1g9|4#4uT?;V;UiMO<^*ccswcs1rtZ#BA8n7< z(5mHViot3}PY;vG>JnNKi#cHs-^v#Ajn+P!Dy1@kYBg{&0iGGxhT8N<7)S&G9H5A0 zkSag~fh&y%v+{Ib>-P^8C$IDQ@>&~Q!OaE81yT`fuaTA1I8BMGtc2`5ZUvXfDG- zTGjsIDLQNSo2Fi)Lt+rK@HW$T$n6mHZ0GVBl*whOx6WYjyz~5%vW)rgH#ES>#QVbO z&J1RKk&bWr>?v%(RA8|6edAz!1I2Pxk{^O;4m$4hJ3+!d#KAbg6vqNUNA~wm{`AJ} zu0}wy0tRJ2NvP>OHfZpfNQ-PiH5n3l>MF^Z4y&&H56RMoTOGM3H^*j@Mz|jMy9$My z$}BqnbDzThL_#o!9qKZXt}p%%4kBC)5)|XEg{emwb*$)N?gGe}-?e|-Q{DTPlfs?g zcc0hkH%Pi^5~}@Sw&u7dVY-zM5)8a9xwmW0*U+>Uh&HSP`L&sY;DJjA^KS;LQt}d2x994HK zVpXRFPi2;snKQ<*R||IOX_I?N2DmK?^`LeM$O{Ct0|N5BO6YJrwlGnk>qJ_ABxjb= z(T_V4_0I}}A)H*=abu1Ka7=dlt}7abl9J)oOoaJ4v3#8*1V0b103i|X^l+;!cZBHY z&z}O0Ms5Wu9Nm7ZqIbPIF{*qy>Ye&#xqVEBsF(7T5bsle74%~MtdCm8P+W85sWiyH?wkDzT>&)#>dd)6o?k;+r zZ&d81ZyDr}-D7o9@qm^2Qy>-tajS*U-it~Jmq&#Hz_!eoA_L`lQ;L-EgAtbxMKYE+ zzFQdU+OYGYDWaWn$dCvJ*l7uhLhKAr)722s^jud;B0D1308rCjw8}di3J(x*z-Phl zoPdq>5F&IONqd==VKeS+(QB%6Q>4ty=VW=uLa(%pI%s z7s*nxwds2#x*764K#!3kK()lf7Zw?(LxE3wk3%gUl;k`G+}K`l?I|F|qzs>6zcg!9 z>e)c*b1n%v=`7Z;cPY&eZ2bLCBSJEe2YX(jqWbjpI5ep^ILm?G1b%23k<{N+lZnJ7 zL!_T~H=l^`oUI?cO0`v(`sLn>LlkvOO;&Oe%VHq1WONAYhw++3z_W;XQp$wc7~(|- zrku23l?Am&fFx8L0WSw~$CCuPad{cTE@q6qm57wqF+li{aFnIwYw7abJ_WX!~3&Q<{(7ntlmM&Zh`(qh{2lqFxWs(HqFw{&ol@HFlx|uvk1lSoPi1S z4R6^h+0=_&>Mh_>I?>BO9Mv}|H6Ubs4p@5+f2r*inYHW&;pp1Zgy%*YFjrR+#?C+^ z9wJac)176$r8>=X&d3l`YMcb4>8EOzpXMkw#t}uo@R`oX);gv^MDU|0Od}<#Z82uK!Ir1 z^WCQk*WzC6wAxZ@M_9=57!I}z_Imu7uxy-YoyNZ@( z;_1*2UOe(xQ+oW`qHsj=-Q4~f`lFx~*TPvfi!FiHp|V7%kcZnYf@#xs!1vEwxM`R; z@bhbj8xNhOxA`9GdgfXg%*8dQK{Vm>O~Fqm6&0T^PM=lC_GngVHEOt~RCO6)Nb9hE zRwjTfvYTqcVbe3w1 zdcrM~f)sq!f1?nb&_-2Z@JvFsPR(AMqN_>M&&p!qJ6UiAvbT2~rs>0LLt4PnHV+dk zM97r$A%kDw)YB;+cjmP6k8I6|;F@xBU(7d8?feoDNo6MVph{W$BVn;ZrB$f`AM%0& zY8IT#u=JO88#0pJ2X<3c_iZPE=T!Aq7-@djii4AKY32+DH@*k1DP!mkUcvu_N0g!D zVJb)Rz~+nc=CP(xGY9G(j7>DE67GJhNi+#&tK|T+-f;Hw9R9zQ8 z&2c!ZVcUgO;0?)O)5tO}L8v$nu(VX1rWse^8nL<<6U0h>(Rn($tf6#W-Of;Q;eMm9 zq{zy*Lh3FYenIA2w1jZ`4OV*m-w7|Cv+Tp!tW5y>#tSdii3ssoomC! zMq5aIbciYnsv>tE({#sy2tT+~}S7ISYoGped3t!68q=YC! z_qCQ7r`WYl5+UL;Db#e*qkjY3P9Y*pR-#iH?1&SWht%SbX?0ZG3D zD*0#u&3z!mIn_%~f?&wbmEX#`x{4BAPe;Db3UDcB19hK~6dH_NTfG`UfQ4hU@wV%c zO^0Ak`^MmjcU_5q{WX?)N-*Y52X7u0A3K*BO8UU@V98%n4A@4_7_*z^JxT(mo7Mf{ zviJ+;?ZiFz-|nThN10);c!*Id2-1`YrzXN*bCRjG`>wSZUZ$MM|D2;a&X3%ti&<)s zWt{Ylyz>5c6jQ=i8vo?qLH*25i$|m` z+=fdv6UQ?TOG#>3d{`zgDx7ziIU(udwLCH-LrGd^vv2eCZs!qio*slY*RN-(@RFcX zG-=T+vd@sIjO~lOQ)W>d6TflRqU166`j{Hx&eO9Knqqh&@%r*o;u zPx@K4=R1sAG&UPIe+o7=3C21t-t5|f!Ql6(@Nc*BsgSO?!M`H*Dq20qHmm3jWPP!t z#1nq;i<}3~Kp)BK+eH~KB~^2M)CL=jWL3S5zY~gj@Y%YbDc2IY#N2Nx2}U;P>syjc zOD0G~A3MKPza!d>^^dKNdg0pw5o3MOx1QpXEoG7%eMQ;~4wK88xk5p3MkyuFT!fP8T{g?v zfnXRQ8&{@vO$AR6bA)_%tSM^VJyt4H*^Dn*IXk+s{U)k!wJ`Nai`rL(FLgSWaZaCy z*>ug|HgctrNF%VX%9c7GTIM*8<%G9c%!-~bS5a0?@JyWhIKF;!Qv--J5Bi{YJ~_H| zmrw#H5dQQVlvKQ;A8D-}rJN;O`T7Az4mJr6#b`{cF^w>@taY1jOZ`}e8Aarg#zLBI z-|eT<+x;}V5Q*kA-2G3}zqPHG`U<-E`KP{Qb z14jK5yRvvxqv#!>8U*pXS|@i>m^6a=%48%C>Wf{FB0sk@t88O~vo-4(`-~)+PtKa% zaEZ(8eKLvD(%4+j6xy&ctabUF>jOnNH$X>~CX|szPO$w=)iaUISC`dW85(cAhkpiE zU(2$schRVCX@jeawK_e7)L^}%w-*znCB-DXZC z9KB;JL+hoo0@f355tr$*dX-*+UPfF~)ZBX3*P!-Li9VOopIlZX$NchYePb+ZDAd!``Nm4cL@`~#%;I;#lvd>lOo;}5fMY9tkc4JcJQ4;Yz#eDhqw%_k(be+R%ez_T4 zia4R(syouf6^?~`0*n!W_Pfz?4nd0V!UQ9ggtJz9=x}d5t#S;rE$+<}C?Qu7JoAU| z03=Pas;+$-Joq#DM@QO7UCz4v$S*;SUn7^F^oDUQdaleihp&1n1#{yK07OSbfb;%3 zv0M&{hs~GQL?;c$xj(K;Ny>r+KTtmpnt+T|is}09ZwIAw9~BuGdJ+CwQjDPvUo>s= zipSiH`3;xG{Kju7X)U+sBx+MHcRYCc{YK93!TE4VZzNLssetT`%W&CuU;DE0B00WG zEK4Fp(4Scw8}LJIWm}yANW@YsPm|Si4JL;?LJ4nV*u}R>y;ivEMMYR_pFLNzemBf3 z&;&zR1yVz$tJc>BD8Niz(<+_y8lwsdwo{VO7c_l`Jj!(TiRXee9RK~W;#?OcHQ{LL9`Biz_1T6;X9Lu z@Y#1Lf>S`<{&K(=6E?s56w*XZxi#Ac@*`bTCZB)2Bx>A2B1KSIB2ix}Fd2dEhrw`EJE)Qz7^L-}h2fg(yab`4-^M(4?x*Nc zk6(TlR8_r*Gi5-`#xl46gvZ#P8Y?}ccLJa^-7y1YEBw!9ru&*Z8zd?sPEKUO)1(x{jr^lFUr#DkgI0>+V(iCYO(ss`Z0 z-8C<({XK6a+597Bu7VetLZM{U50`kAM4rNlE$Zm9H+x_aWNwjXw7-WH1NIfrEq9|0 zjj*7Up#$vH*fqHyVbtZ}`Cf_DO{cA1N<2PGmBv@yiy6|yj7cP>3Iv#)Qoyv37o}f{ z5akI+hjSbrrJ7eQaE1O9`BK$Buc8_16iC^&mjuzv1K-E$Fy zae0T$WYz3OIq!1{9PA@F<&mj$LflulIe}qI%jbJ@DUDLhJxX7RWF5j7w`_m6l-j=6 zRQ>nrtm1~6gkz>Fv*uf4U@uqaBg#mGy@^VcKTOSn2^!D(z|ea2oX(|^PEGqElj6^8 zT*gF4Xc1jAP3nkz4h#G`vop-U%Z!j0Mt6Q_A#|lUm5-Y36 z0{%jh(QmT1!dI}n^rmx;O`yyz+tkNifj9sODiS?!J9A~vH^k7{94D$u-{$5N7f-N% z9$MsiOr40QC4S3mRVNXs!4iQQM+Myf}LQ*9Cq@|^| z;alxfb?9kJ`{YEkOjtCHJO{O?)-M`hjD97BN# zk;D@M=RZe~TL|6={v}Gr8!ktb3Ci~ZQ6Xymg|O@G-H$;^fAL)CH&OEqAEoTuU-;4H zn%?EZ(d+jQz4S2I4j<2TUgwzpiD8Qz2?`TP&j5^w4L$nad`jsd*sy(Yb{%<(brh|~ zM8rG}VW<~xL3MwCoj)D%$`i8oj?3lW7}p~U1zzHcF0v!`&WyF4=|ApAgfFb{Z6#o2 zuW10mY;N?b55~KXY2-0uL^UhxCUN|1zDb{8#R@)P5=k)ztw-mK;7=#o=R)%hzIKgQFv9NpI1&cp3DsroPb`N06I!MXXtrw$r2UqKQs~QkrWm z((ekQ%I(Q&U&F)W<0Z?SlGOV1OeH`4u94YMB5%G3)ItQrfV7-5?D+<|y{*YUqm;(4 zpP@o*^F(pjvrtulf)+WjRh2>F8_yhMmp1pL$2zdYv`D|<=xv*~d(21UHKtLV-0Uvs zYhmb3B<=|+ZZA%aEpwU?D@RmyK%h%sT@*BdFTv}JDua!}dy@q4PhmbhK`kwf@;WyX zN!H}n7rsA~%k3Sk?){`FLou>ZhAE_ItNQz7c|%!Am=CeqPowZmJ9+wRXXZTbOmhrgXCe649WQ;sep5|?g8E@e$w|ptjYfw!$>ds= zIWvnGQe*Xu2hWD62A9KA3PP_gSgsQL^Gjn~-}XG$~P2ocYR0_(1Cb2#9JS z(*q^K84sfhp)0{)GJx)-yd zQ7(hQ9zYkpID%Vh!{M+@y3=rS!eiWREomx$6{u^t z;bw|?=*{@aGZx(-3JYuUvsI?PwAhG)AW39~n-I;9F<=3aPY5&TlBwwb04<`!Zdb?*>gG>!$n41M|H|(dx-h;dO z*CU(H6+>cUYmaw#{42a!NIjV~mr|Df3Vc$G#eOSc2h>Kl+UB#Ei9ovblCor>LN{bIZ1VW@0;!Vgvb|2%{z zMhOCJ9hNETZHq;({E-{`K^hLczL=w9W6|S8WvyX7;f8AUe;#~OSu_OKH8*->HW)=b zMZbivpTjY0en;OxT_-Qr$b4FPDNtcVHE+yB?cYCw%*!`L!{oJ z)`9b{L>11z9|8kPQO*;ooqfZvr&j)Ra1b1}>+rv<#$lEC;NS5@n>{#p&Pk8v=V-|G zKhFnUAajLzaYRx`W_+{w$(v#KIy#!%qyBU42y!3+eZQzrQQMP}PS~ z%Epc`^ooj%YW(W=T~ON}TT_)|N7n(;`#+DBCcKB>Xu8?_<^1g@$4*@M4>iGxn23Oj z6Iy#OL*P=><&A4LwSGZ6oyLv-`#nM;D7_j8S#Q6OMIJKGY0J8}+y_Aa|2+7YsCa79 z*PYr9( z?`1dO5W&x*2KP*vks%Sd)qg)&IaD|htA|Mww_mZ$3v&MN{UbY7|C?UuZ|u-}`LE{< zy8r)rE}*CWujdHR^1q*ZK>GiB@Tgq;*Q1Hrj{kau_WqBM|Leb}puF&3&!6r5ha7*X z^B-FL;hq1`;{P|2^`9gE`L{n5_(Op|6!=4dKNR>wfj<=Ze@TJy-z(#IG*Hs)`rSWE zcjo`)+JDyiLxDdO_(Op|6!=4dKNR>wfj<=ZLxKN$3QQG$ibfEs;gh;rf4tNGd-DDH k#UBd%p}_yADKJN&uDL(qkzWpUIwHtPJyYE>ZEW=a0qVmQf&c&j literal 0 HcmV?d00001 diff --git a/python/docs_src/source/_static/logo-light-mode.png b/python/docs_src/source/_static/logo-light-mode.png new file mode 100644 index 0000000000000000000000000000000000000000..c07d6848c98d3084b6df4ef4f21fd5d8fd32b2bc GIT binary patch literal 48816 zcmeFZc|4Wt_dk4Xvt&$>F|;d`C_@NahKdYPQsz{YIrF?DLzxwtyiQ@SdtL8qt@j%5ZW-zyXWhcL1tG+G@`R=l zLhMfwqIYFtfbW#fHd?^{m|ahtcSi`n75fi&kCSB>ptrurG5?uoZJ9Q+FL&PLuZ zCQQnouRjR z_=CV71pXlK2Z8@@1hP8VH4q}3J9>T5)U@4=h+F0@_^WLG&j+oUgLcv&bS3>(WDa~= zqMaZ9_Mfi@gP9`#`Sbt$e~@sK?Z3Z8;Cf2?-`}~TN&W9{pxvBG|F7?y{3GUn5cz}6 ze@x?#-TViMKR)#zB+&nFiQ+de@5qRYjC0t`%#FD|MsfB^4F2`5Ba!ic!GR(1^Hm!r zgnYuqY!=NlWsj@QuRpR5P~Y4beOGg?j17$&j`01T-zO^FYQ>?lZ+FQnXRL!vIxjr4 zIjUvbX&+k~v{C={O4D&OsFYF#thRR&)%6fwTe&!oap#16gyXAQlS`0h2g*n%$ z8kgth9ljfoeqF{!*Z=!vt_M7yZ**~1NszfMF0lFi*#G_r{_tcsJl7OABg4sQu{!de zHp1cAPK#}p+W-C?B9kB@$%nY-$|x#UZS6GNR_pPv9~0ZL=(E3T-;Nzm+!gFm&MpG*eAV5J?sW-x8wh_3s7c>p&NJv|L=zk{yvq$^^n8813%*DxI+F5 zkWB1u0_r&4YcsVB4ee3)+h)1)ujim?Lre`8#h9Izc@EFwB2`6i@A@z4GMFNpyN*UI zbEW?~upbSFo3OsbU9hFaMxE}z?r~8J?$N{`e~TyQe|@A^8vxGkxWa?K`NFXHz^FyE?!@@y4Q7N{re;nV9P`a zR+};Tn-sE|htnMAm*biqb)u1Bv!R)?|buq8$fmf9Le

wT^+ z8h0$i6U`j!#^CV(`$fKuX{W_kJ6p;TPQ%_-3=dFy^#m&J9?nWAXhzLJ=VT5kf0u^<$T8MMLcZ(7?{t;*!JX-%J8iIlE2>1y0!e04THqt;j`eWI^8j3-OQ&! z4(&|qNJJu}zYGp=JcHx1w?;4U4yU#V=o^pMiC;6%EA`F^>PMZ2 z+>N&4(38H|)b7QeRa!#T(fc0{cRcR6i+-oYt_sI=h~u}u;1F?dxBdpd$SKNo;f3sW zNs*OH%T30f`Fn80o`vyxf(XY}~aeKYM>TAj5aM0E$^;O1Poz0|&Ftjd3==+z z>(%Pi9w?qs77dCh{RH;i()n)V#+;L809&w*&#q*Z)wT-XkEpn}!BlmeZ7H=%`Q&x6 zh7zu=+KwtltY`i3Rd)`fOFqbzhsXg&b`5ZRDWF*nW$A3AU5Ptml>MG@1#^v+O6^8K$ z$tSZ5WR};+A(S- zA|7Z^?^ta1j{(<3<6eD(yylruyOE`^)K{l)7k!d`YxtTw*hO3Iy;JBtMK`OIX012D zvFbgSCF{2$YvACegN(-u+D+LJ!VXr$EO+R>a9BQj;)mp-s^HMEhPsg{zjsb`UYcck z$*G*$5VH9m3cnGlQd3ktk{{=+YTF{2b5f@@%npx;QXM!^y1#JNQsUYE`N;wuw{a+Z zFHni*&1uB}lVg!9oMRj3S1I!I_bdJu^tPE3I0mjj=$o~^F&7k{AIsGCzQR;EnjMky zW*!oh&`~LI9Dm$c!6AgA^?>gf6vB3-)?UQBY>TuUnEACFus+OXOdX(*{WD7LhGEA1 z>G{&O6Q#^m7A-a^ukGY_{Pq1!*uqYBL>O@2dcy-PE&Fqq#7FlbRKEw*3gt5TAzVA^ zsr&WdTMt!yGz44OAD-J*#sxapN(4urtn1p?iZ~UY>t4*JL)HfXtEk1@#nA(|F@-Jte z^_U3_%HKiSkEVGXgxU@{(;{o@V^k!~G&F0&jE@^RqB0fS_u$v8aF4I>o4S!jhS+pb zK#A79r_)T%QsJp_rrm5)?zcf5-8n6U*0br$1O<_^k?AZ+=o>F9vNi`{t%^glPoHFL zde$lC8ENiTcz$@E`SO-$R%e00>@Vx{*HYyqZTem>+HqXClR#qG3$g!uP3bfKah&vu zKj$?!S$SH-o5-PrJam=L+zR6jXt_PC|=+%cxAMY{u%wA*#T zLho}SWF@froI{(@dTWRI3(unrh)l$^bE^%t5Fgre?dP2gtGv?m)MLA<%)U}CWGsF> z3lJ0QM|1Ixssr~b;%2g!pS&m!{BA*mz+;F)1Kx(u1lXkL5pOJjC9AVk$%5~d*W-^B zxzk@Yc6i;hOp|S|1n~6a37gsF4*k))eQLHTlGmP3jfJW*pnM^SG?(zCi^rX~k@ZpR z;&_Kq*TTQlxfc70&DTz&#=nO$6a&Eb3d^%mri%4cj{|3cBaEQ zjPa3O;*3Zzze9&*I4=#>JBK*{Ds84zh34ijb3c8`-pL5$wkE%*k@)(C@)9#AUHA5} zPASKb3sFjo2gGsRd-_EXL2S4#(e;QRnyhIddF8Fr;}Gw=O>rD*n}U1lJw`k!_sH#b zWS_5a&F9q*G*l0%_XqB5TrYd)xkc*hjak)C@PpbIO_1)Note{E`P$_?b5<)*;#0Uz z=|bDyT51O^vY%IBC_em|2JuD%W@jc`Zv~VsUC$N?p}3Cc^R#l4YKli_BTj5^Fy3Lj~*w6YZN~)j3+sy?7TMKc(vywL=M)jL|Q`5 zJ+9VW4iK$EUAHMANBb?-LUp1Sb96M1rY<@+gwUs#^~DNty8Zp6c_zY&eZ&$^k7@ zbl6z|_Cy~=bI~_~ReHu&SD!H;)XJwoYHn#RS=Z4amVAi54YtQ$xCasRK(Q>KNVL8U zI4V);FnvcKTr&9U^EMnQlYYW()WKs`J@`h_TF1Y00&}YAMx4FOM^qrf9Gj zX<`TC)OLw7!y<-uRqoJCQW{-ffyjfL%lPROgiIslaH!-?Jdw>*fdNsUfR*mI-=eoA zwX-gV(kf*$$56f<0;j#f2@oSzVC)a9|kBp{yq$vo9DmxERj0R>0-sl9{?F%a|M`4PQ2DJ#IhXL$ieQA)YZEj?#X|uPRlXu5 zh~3D;Q$z0j#--kmoK<*KBV^#|@BVpu_TqX|A(uzvELuBrng)?%8p?855SdG0;e^XQ zxZSyv5afhB$~0BYzvK6l94F-hQW9vwcSrYgl%@fz#c*nFs?NHn9&BoM(wTH8Fr(Hq z?hq>4KO1!N6Cq-f4-wh03t)#MWK5UG81#s9zJ}E47pPGCTLs%(5ua5HEoS1w-FRxq z;y%>+w@=#$H)>jUzR};M%ZggX0l=jU@qi%Ky@D<;-W>?2R-SUYLf;QcF#T5IK(k!k zqe1VonUX0xTt|F7$yWeRO4_e4W7mp4kP)OG9GaUi)95VJ7?BnBTx#SlB^1Z;*K@m%N0|Rd zi(7U0TpQyt6CI{h+S!vL(Uw(XCUaIISd}R&-Ebz{pnSQ+O1`n+K{j2pB(R6a2TM^D zBpODPvMqw;eI<;WrjD448kDB)Z8x=G9&)!<6wi>ZFluwC*uI_)*G8C})xErAXcCg! zqn7>l2IHjrE;!9><%bG{1_HeaQ?Venl?Ma1VTkMMmOU4k&Rbz}RK~zc>)wYiT}WQM zD|V5`S(gU7^C@QGF`=n4nd9e7XiR?EeL06v+fmEynv8?5CQe@FKF^HMiU=s>oiV*E z=XmXc874}(K|DHlDzDX7_;SVezEy}{Sqby+;>vM6OX7hJs~C5LMw`)p!7ZIKp<(#@ zH^7*)F9ohH>vpvkm2lNk2rB>3#<(o~7jFp=(WaES)gfw1T&MMbg2HRr^=qpFnlva{ z)(cZc>{xlUyu8jQI2t1yEnGgV#mYs^P6dCn<%ZJpvam0OMs0R}c4e6FW+m)e>1i5T z$!IRz%_vxO5UK*{uC)&vVyNH$`0mTrWG~1W86h>`?ckO!lJT&>gvevW;U?)9_<42T z*Bt8JcDON#ONGmrpoDnj6y?ngM!+E0aOE-@j9H|OuE*U+C>c&=eZ;#o)ojqT^W>6% zed|9)15!S?y}^_rRg|Z5tsOrd;;bT$V}V?{5^`z8;3=CzKN#ap8@&&-BdlV|-0o>n za8sHZ?&RJ)s54q%85)1@dhZBbuVUdw>(rMjd=S>88uj9t_bpfG8CZJ$LTA*Q&3{7Hezx zd@Hx+z;7t(DrFq={m#ad#=Y3Oxi~pnz-6XKkDTKWw63fcF7q{_K}0X?Bw^wKzZPaa zCn5PPUl$luC`~(e#qap<2$A(x#!+eOY+>#<%@5{t2`hexn9@*jekTea+TpR}FAyJ( zLnJ+5CRxa@1WKoUjbhF|`MkP>4`7?#19C% z`%l1;ez{;8-p~!9quUU&06Qa{2RrMzQxwZp+S5{KgaSV0P!k@8P(vV0l1FE?@ErOR zrY1B(7)f>vAf!r3V$q@LUA@=q)W>tep*T`PoKb2Aap%82*`@s7y{+y=wE6zPyJC|%~2t7xk z03YyvQh&PM2(6#cL8+=@^{P;OK$sw1b5^rDdn@ujzu&{h%e=9v$bOmQ$pIeCQDnIl&WT zzvOcMKC;?k_h_RnZJDNDiQ~IeQts$_xscT9&;I;dBzzWhG}o>(fnntrboR7164m_K z90XN@OV00H3Qzwha!pSs9ocyfCp(ZxYc-28{3K?h5+PQvxcq(;Qe?u&zT1C=jQOrO z7PLHaApj2vVRUiXz#Ug;zq7#*?PbPWq+8~FPlF4zpQ#R*>234q(dAO+-20@KIfpsOD=>QZvWLbGqsfF z(h4Xk?IksiK)FX2&E;HJL9(La>IvpbUl<%h!K|0SJ*}OLBOAYeR??@T)~Q%QN(=SN z!LROnh@V|3fj?dtbcHqF$SKpG(klrL@)|PP3l{DBt>x6-dt9WPtymFAzBqiz_A*hv zEA?FJ^+gt2sDFii?D0S*!wa)#6UyjYql~uBm_N@?YF6m&yOUDZC6g|8wiu@bl}|p2NM1ah!T(qA`$;d4z~jkS^fJ(Sy?nSZDl!>J9Zk^NUHa?YYaL455bMNm~^lc7aLH@kL{+ODk6SeZZL&J_Zc1Nr*jL^b!B z*WTh#yN3tyBVxpoGcQSBJlmy>I`=WvcrPO~C(oS16g@R`#405F*^@;Ys-TK;2<6O3 zYR{MX!}BMMt)9kpEiCU#!Kg<>tlwi{!gdNGOi_#!HVrcMoN!#*YG}84D)7_Z;Mf`P z^?}`xLS(Zgo-p?5xJXv=GCmT&kotJPkA8DT&lhj+(r1NL+#->djfI@mMc~AviHxhh z&97E|F!1-gp9#rOJVJ-aOFIbzr=0W$*2dfLU~GhBh{H+ELtKi$IuE#p7G~>H6ITb9 z<@W1*y))+JVKe7vp6M;!qssI9+`Zan2M=ZbG&Z848bN+TY^RuvohYV!QlH1oIZb(JuAy?WaDI!wZ}(uUAne0H z;xKI+q5Ff$rgVGxdgWWYcR7@3P2Q?4|0Ljv7G`QJoNO_W@{T^Fyf@|Ni;An%?a=mp z6wFc&%??DohCB9#`&fD)ZoSGS&z$od ztj3X51tdi%*HRyE(;V2?xeu~AgpN8Fm0Qf*^%j@Jg)UsJ@&z$ho``<$B-ZFpt%am? z-KWV<&ZcXR4XmlH2*>2AVd-aLAHC!G27Od@i}HQGW7eVTD5DFP_CTAbUvQzt>cM{` zOR>)t?hy7)Oo=Dn(o63@j`HW;{LHc13~%=qNso2gi3aRqUg^JIEmVfy2>Lk#%rq#O zfgYb+_wzEgk5NjEkX2!qZ+xrpWbrP0;=Zqi!^s6$?$G{r&*PGH=vPSzjMAdD$wV$k ziJIBY)TVP;Z8G)UyS%qJxKw>9O%wdxe@w3G{59aToUA5)@TtkMJ7`?j#_5^%fQn(( zlE7ovtq!LVYKM3>sA1e?Vtq|T>qPzvMcrqIkhEa_>7xd(wrS>(CJod*CWR1*ZTXe$ zSa7|Id}`rXuT=dvkF!Tzx~2O{aNW>8KwCv?W*Wwv_uJ+kJ7DYJzMnY1YWZ`v&wJ>j z-53lJ2K_f+c;M3l9zxgJxtnM?Rd@&5A^I&EMw~vqzMF4alwZ_e`wQzNL6YXQs3iwo zq|@I^hSugg^Ts;I1L~iI%Ly|2FHjAPLWat7UORyW<*QD2b`w2uC8Gp7w4N)|z1Vp# zsz6Q*ON<4<0DQi3mlfNWq})e-2`%ba`G2#;Mg~ZJc4Ki5>w`p5x-XLUiqEu;p`jisdIbE)`-}GY}`bamFT{dy3Asl2R6bCTlVBPdg7PctRI=0};6lK1ouciyW)6yLG9T?N%q3&1kOh=J#9FX~MB32Kd%LKezFCdXP} zP-!K0lTEs;4*wT7LE*HOf#MW>tK`+qT8iOo@h5MLZ@st3e2scaJcogkWa5|Hw>^rX ztepY2whH)_{UO5G6JYI2V<~>X1 zfw;!24~_gF=?MObo$9O4ojQKZzeVKL>sUXSUV4Ssh)LHTnxuoxiN3L@1sP8CIo(fbivw2*jeZ&Xn_uXXf zJ#$*;m--mUFzU~K_oD4SDZWHL!Jk_Y$N%U-JZaet9A;zn9B*5Kccmnsvq0R#4vB& zeJ*M+jwoCcA^5r38|*eIp-$yzJP)tC%%~kZBJ)6e*g7Z`-B>}_T6%R#5$u=O*s$|f zQ6mOvAiplmpK>&l7N<8Vok*$k-x~S^-(CvbK7V)v5tZU2mqz-Uc%Upx(lQFU)}VD@ zP5FFys0c(2C>&bL#g9HXbJ3HAI$k;-Q1(vO)IMbB$)Yo(xgk{Ger3G6j$PHoP6&x^ zKzDU9HdKXn4ljPgF7YQs3*ygulbXtB%QtCj(5Maj#bo`Mb^nJ}Sd^Rxl|6pC-jb!i zu28Bl%m!NDz!e_$TW(pLz37>TJ7IE-V!X7*w@ zqDW+QV2>0uhsd%N8TlV)g{XFc3vS+f#nl%bw>Ty8lM%@oFK{NJd9Pf2Q8Ft~OCc+RSBS*F=2o9`>M8L8-;rc=o%wl{lq$Dc_C}1< zTbfR%hqdJN@>PlDJBUFZ+`dD>H5!J+^a%As<03)hh}F*gH7^;a1&xIRjc$Eo1~X@^yoh=H(Gud-vdEgEZ!9WSo;2fJ{3{p|du zwwU?iz)`ior@~tP$v^sE%`Oc0-+x>(mh3GK`IY;^MaLsgxQ88MB~^d@5~$0OQ<*vc z(-udpm=Fz!kj=jo%{Yy(&Gzj6evflX55su_vn?jvTr)Gtp{K57J&9(BvcC41G1Jwz z+<{%uIcjC>my5EH)aLs7O;@GEMKv%s&R`Krr2C@a@)Zkpl?87Djv7F=o0%`eZGEeH zJEGqA3~_tT(L}R!E%;>M-`blN<1Ld%_{O0igc6d3J67=+?ljkF^ zvw40NbL$4uRiUXdVS0=_+*7^Zn1G>69U2kRT>EDIdDrY~-n|6VfZje|qY&p6k+~d1 zsG|GQAead=?8=skGs!6mYc7dni899g*N>cnM(F9yOgc2r72fj%tz4h%9;G-Xi(Z8G zW8SRl^#+ki+umbi2VqDBiw|FMOw$=JS`JFcM9I!CI(~hids7BRs+?a1DCt{=o8B|5 z`W8D)@}!EboI!)D0ei1KVIvO&rUw=1`wnkz6vmNuI_JQKVG@TDA&f0?;B8Ji-(yaK zRfa{4edpt41s_X>)n((}9sGI5`&5Lb^lN_9(0!o<{({SLYQ$qrn6i(figmn`^+J;s zYA2Z=+0|vLZb$Z^Q<^bfab+iU44?C{H8!d;v_PLi0gOB0q4U^w!oUNxvNov{(bG~Y z?Kiz9w%LC#w`D{@eeSpGJkZ4|o@u)nB5u?g>|pd)w=~Is9tC~&PD6vF*bCch*pC9Z z^N`(SeXu~ulT4(R-i*xYr}Qb?dt5c2(6!Xy-w!CH@5S5r@%RWg%|+`)+|AcBlj;U7 z^2X_5+*34OZF=0<%p@eEff@R&0TGuf*+?$w+RyI|M9fpXu=LCn#^eM%UF)k|L{*o^ zpN${uzkIE~8q-`oN!{AK__Tf;t!Q5}`nvJp`>?$ZRt#RojLEt}`~nKJrJ)va_*{sDb z;rot9ft=dH8-0eg^-<;K^cdWV0g4QWSa*vW=c^9{Dk~{|gTjXFy(~vo4Tn96oU2-a zw#nBo*9Z^1DymDljEG_I1V_g$<9iaio*Zl&z1ji=8;_IU_Pn9lZ0@LS+QLtEtz<*o zcWQjY(F%!=h0LCb{>akBF1jq%!YT+KaZqdDcDv!GOjNzAQ@=aNz^4Q%R>*krMR0T1 z%GCW*7M_TGf5{cKO?4nvdWB5)#SR}??PRtC8VKax3>UG_rA=-{HUZ9Sy@k+1*ft=t zvy?XaUCc+%mu|_mOCJhZk@Qg!4pRzR6e!=gKrFx z-(ou5*_XaC3adVkzR%L2q=y_zKCddBRJ|Z)cma=`T8pFqQKy@Ye65d|hh$e6kYQ+zPE_@8zMhv>c zS3GT@wICyhH9yD!5aV~q(|ggkj_)FFn40Muo9Vxld2@f-_ngtqT`T+Gx`rQqD-opR z&+sDaGq9H?`*ZD+o@}7oo03JQl~WKh@d3ef1giqBCC4kg!e1B|^$~0%FbP_#J4fe6 zp%o1$#fUoY2Zq0g9Ji$`G8fT$jg`*QJ)2y+`m(D%ZubBR={>^)zq<(}NuZm-e1s`n zdY=n6*F4@SX(!))%~^jGAx{k3NaW{9*DCDLs84TB(yLexU>bAf%oeN?aNs zGi}Itq551s*xB8Ek~fueMF(1;ZlRvE$nQ?Al6FoCKR^E$Hbw)`QWLbq(NNnt=hn|F z>v#VxiVY8nK0Zx%t()ZlPV}|R@MhXf+o~!UncX zmvCY_I@Z}T+ZgAH=6n4YaOeu8CDMnV_WgWvGcdR<*7#;gB{A9TO2nQ>RRnwpHAGMN zZ$6AZ!5d?F2~YiWn)&5}c!Cu(+N#m@$yA8s+3_%1X4ELav0eqCXVB44x(nz2u-_Sxe;~#v#ep_e2 z9x6ZVUthw*vcke-+Y68B!=`pup@=}lR4~t^M6EC|KRv=VK4DBE!~&ZpSFV2GQy8dJ4RS|5c=obqp$XVKgm27g~+Xt zi;Iclum!bwiR3kb1-YFKG-#_j*n5x`vph6XqUc&FF8*u;>pdY#j|nLr@q5U&d4Gsil70 zdIbL~UU1M`U`0GgmaNV{A1my~uTfT(eE_T9t0953-nnDDiD>$QoH6+n;kwX29m)L@ z5|O#V^Uii*3y*6PC1Fs1csS5uzWcCy;SDn@CS)CLlS@x>8O}^mUVC^uu2{T70f&MJ z7)wy9dP1@y`1S2ZyevC*gA~ELs$Nwf2Db&W+ZOHXey8Ix_9B27UR&!F`e|m7uYM!? z4Vv{@-6Dw42|WA2INsWq=22?Lal-X_y9+3pb;(F`Q|aDUUzfF_?@K4)kcjt&?-i?i z&=-`vQO$t{h1p)ylC;4^&M0ksO9`yh`#txa0=!IP4_@fF8$akob8l9)1;SW+gI>#_ z`MDe{Ay^-uz_?=eg3}Q?c=cuJ$&8}PYm)(XMtSGP(NlrD5V__y7cYhO%pOVgV*PHx zgcCMX!iZc7Ua|jD#mB;|ci$bA{2|(AdxhL}+%G8AZH*6f*fP|I=VlDFuLI+~c#q38 zlX3l)DhDE)(hdh$b1dBFQ&?J-X851bT9A@p?Q2ZBtu_zj(&DUSNK@j~*6Gw=I?B;+ zal0qUWI`8c71_=Y?Z)Rw z_u~hzY&rR|;V50}i{`T&frZucd^=L+ivx>n1DRkOS>kDA&{1(5brlg`a^vkSf%JmI zXKL#T;=b82pjJsBy^(K&uaCYZt2vS%M;Pv2Q5HSs>OCfNui8YJ0glYd8JZZ5^Y#i25w|LKPm!wV!kKe(T-tE3ACk^!Wh?h-*iW2A!i(66OBF?rt%! zoyv&Xp2I|o-~E|eRe=mAY?#7=fFc~xvy1w=GM$c-rj4DGkcoHJu)^}`k^ZyLW+=Lh z^MIcIbT*ww`aX390i9R<)LFO%_UL7t<@0>CTHdEZ3sL-xIOK60%%0vjzc}k%0bAlz z{Zng^TG>51P@nUOjt#sXm~WFQ;ZQ>p-S!5=?7mWx`!j%YO|B z41S9PbBL~H19J#GjU33^PZ#%*^D~otRGY}G z`B%lWpedOM-QBpkv)t>GN|Vs9#-SwN{dshR=Z;&?*j3g*jSDH!1tS{rYQ)2sfDc|^9uTFnq zsGvcNY>Q2Vhr+qDA=qs= zmTmRX9Qolx0+}2VR>1~L6S=dhT+sAT**gjA%Y$YRp7?=&gkj7_&W9-4xN?)P2;1Bt z_h1Tvxt={LuC94xdQeRNfaV*awjr)_#TR1G%HNZUVF+e*WN+Zta5?W)Ww&hX=Vf z`1hKj&rENg!si*b3)OlNm{#L5)tSdzbXsGw>7pQjxX_}dr)hL3n;i3gzCf+JEbB7H z_#3bcZ%u0~{VMBSTC5HS=Z?7|&Z5b;%;UMaux=6T{}~2ZFYSJHLo(OCJ&+$6@-Pmg z{37mEDS;F5HZxv< ztX3Idt4)XE&W1=OiM^6sk%hm}qV?^b`3EoP`Q5=%Ar}zE@5a9((c9+sEV=5SI4)T! zHt0hwcbn&-vD;I8xViT3Vs8lCX89*?3zaA536Fl7{RHum7D=bbmeOP;trO-YbgUiQ zaG1Sg?U(!mj};OVdOdQI18YNMyS*dmJUjOibD^DwWcPtdHYHLn-+!(O)g5XIUe4t@ zMip)k|6TJgJ%&mryf?i1_A{@Hn80`H^d7+Iwt+L(%QU&#`fh>=IGMeH^;)6334^c- zSuvRZ#TQaNco_r{o(rU$%d^89K8}Ty2Q6s;&i>9XEMccDhTeh7 z5y|`eONt%9Cq}x~7=v|nNC#z|7UR=PHrFVQnk^z-rwWiG5oRJu>ZPUs{ zg^PGRX;yVwiQJMTy0WP#xtm!1d3Icd9uZgzxJXXJg#yRmVZw-LQ*Oa4y*ZL8K$tC0 zqia41?&k`!3>jIZ8~HXgE5vu3%7&FHW)VL>7fwSbo)*+O8wsT_F?9?E(Jq$u2*&j3La7>m6==Ccr=Bll*E5Fi8ZH!*aVEgeR zaopVHo-z#P1BZ?i_%pYU_Wpd{0`x(Ib;7I*r)ix0r7?(41H^5OPE?lb{1x&Gx6KnwZRM(e@AqeU@CFF>CXdji<}}NWJwoT=38~PA z-U11LPWcNzH)oB``XctM{Szu-SOXu1Ap!gb*bb}V0r&5!kR`21z*!&^hP=ECDa@{# zAlMJy2{+`~BmVY!Z@3WWj)EfgqIsR}3Rh>rVKy*-0VsoAN8blZZ*Ej%)BXN@)*anT zKhMS!1Q{{Y<-H5twf7mjOi$l>Z<^IsUBr&PZPVk_q+%}0OfgYM7W{zbVExJbfP!DW(M?xp`rD{N z9!B}(RNn$F={xCkEBhez7$!-O7|t^XsD|#-X-$V_JlP%`QDX1z$|IS7DN+fm#K@ur zC#v_5c?rygy2L~O5m63B(%bG|lmag5Gtdeq1H>(+oi7rrGG--oNO3gC{stU}e}C`( z4^C7zrKR-1-Dj;iI=833a$ur#z@Y7$cLXM>Rm$P-z!N&eY6wK=(Y<)mH7({Vc2&pK z=#ilcTn7Ua^TE$AJTN#TJ=sQ*Xs9QP74{OV&G`*Ph-`?^;LeOkeD~r9H9D%RB5)|$ z3T5dT(M&D8g@tzGIfJL`>^z;#+75w+%k z_RMrT-INZ%06~&qn=OW>E!di0C8{(vitCWTX0o8|MG(7rlgCb-Ec$c-$5Olq17Fyp zvT=7{1y2?QjZo`=zy;<&P!LflN8Pk8!aG)_ssPW@IH@)&;vtpDI%a zX`MBW`)~RVJUS*_x&zELMg&_ECvJ2tbw*^S3|m#jfTp?iMbxt#Ds<Ud z$&H2HrOiPs$eIZR#3)WV1(@CIY3&Q&z?D9Irnc6v7R(x<&~^+ih@y$rRHR|WVlLU1 zl{Tm#?PK*mfUG$warO2BuKl_DZrITxk_iZMD)@FuQ9k{uRX5|Q$@NUegQswF={L`E z5Ue6Yfs0W1V0fzPZv0Z?eJ_-B^OvKLL_g{Ws*z6 zEcX7`{wsjV{c1n{mAOUc(GHlIiO^s?JKX#_f#mDFu^}c2jZ-UNJi=6VUt7+;8y99p z{BPrkzpz2uC#UvCFqwttfn=A?gGaZGNd^nQf;m9$=5JEtZ7Bp$%u^m?p2Mm-#a-qQ9ufIW2Rr0afGO~r#33`-M2pF=~r@-xE1z(r+rGUM` zs{;Ff!D% z&07#N_xXk|q(Aw$vgzJUiTnO%7hvTOB6fqwq85hu5mWBgYE?&Jm$2(LH&xw!>!pPP zamH%FXTeyAL*;=PfgP(kBiKcGe$+9P^FLl{98Wrl(XAy0k8?#%_r;cDz2UimfwlEo z+V|jHq1zmUylAv?$d=ey0)BwNpYAYhE~m=k0@>lOXWm`V>buS5SpPK3sUVdXQ6GRa zF@E-cY`DV%*e;KKcj?o^Ne)aS9izL6s%jW(`@lO33?{H~{IMktatw&tcWN(aHJh6k z<-qF|3y>~Z0NYU`v{J(aq>kBlcSbz1WOBSrA>GmHUlpf=`UPH)Eo~`s`O#Bc5NDN2 zOWrK>R`}>NZpur@d(&S{7gHprPp?-j7E71cq{1<|FE)5i!-&pD|LjXEZwmhJ?MxxN z)GQI&v0MCY+IGB;Qp-dKY)qk6QSf2yts+HfT=z`_OC3ZStQsO#kb94i_vckxe#@g` z7!i=qnMZUU%vrJ_ItnBynHq(=P7z3>Jh=7UZ%qi@E#cq6+cZp`5J z+p6v$2oG=&3?xnqLe~imu7sNXNZ9E8c+RC7jzbLlpp$Q(r<2d51-VK*#Gn{JQ6rpi z;q7g0diH+Wq!PA5flYB zY8CpAF5plz^hE5NINsUt$ul@!Tfa1aPy+4QOH{?-=2qe&hNZ6Gl!JdkK$@jHP)BEt2jr%8VzjW$*8L z1*suwl|I;LR+2^+Wt*3!&kt^T0b8Ju8hU(-Y3K2qj0b`Xgll6^+=`rQqXlJ#8#>uS z?44`c7T9}bFY*lHNq!D~NP4;Af38j&k2jTVxhZ9A5^{~A>@I{_w_--XdZQSOpo`z? zZp=Yjx}>)D8)mtk^KM&kb9{oo6osTJ(##4$nP--TlIrt+-3~Gl^lmx`-^k}qd4XwH z&9D9Y8pGfkdnzBgifTkQFMU@&so*gP&GoOzZAbJF~Dx~Q|qT75rG>}>xSrSQ*kwX7pG{FhDH)oj&EN2BvD?h9y)+Cuba zA{h`+)rB6LDw+BUys-DoMd&o5XwUUrx-+{ZZ%o0c6q(Onrx&qnV6Rq&_7UdozwUm% zyLQbbfLUkb5|avzxr2kHV4#pb4fU*q%tf>wRU9=8pW_{_2&>4k2@W21kk5mT%V{8F zDJoN?<9X&C1{wuQ-=5KhIatyOdy-)bR`***m>ei#fpHl+xxyS{(5wXqQdq_%1FX zbL;-h_HRpvw4pmqiklE6>RWtjeGyP1*TkXiy#}7+ZBe&Reg0l*4E;x}pdmc-P`8pl z*psbsI^-HuGw2$;>YmT#&fwRPF|NYPEiW3(FBWo#d`<@tlcPNHmNeHr>QJS$R@LKL z2hLr#UXfDa#JBP7US4MaAlQ*XKDHY#y?A&_%UH+9C~$_~x<#=@UY&@Def3I?ZeL8^ z6#k$jQWbB2y{N8m0g%ZM-EVO1zrft0*+xCf z{4$xn`4%D+0|A&^VOn8IGZ?-U@GWUgi%)5uX1qu+Bq#ck4=&obSoL{$qKtsIaH=sc zguta4zT?a&cVI0MWGZ-;`*0?V~az(nKVC`1oT zlx0Zm*De-beyejwcJ9sRgK_q8%wPFu&*rjU3p(HI#lXRYOcCf5M)d-dq!*nzsAGGC-opCBF2mLfFCDhxEG}NLPE60s~pO^to zT$x7KJ}Y{|k8<37H!>6y@j!GxGKSO-4v+I)Wjad}(;V24ocRB=_vZ0Xw(TGAWhqP9 zg{-M0OP1_QBWod4_ClnR?1n5^hGZ$rog&FDS+Z8LGf}k2k{DSBWiZ8*eeFFi_j5n@ z@AjR??nd!9M*_xlDZNOCOIM*MzWUFNV^vFy$%=Ui`WER3EI3!e~#K$+=Xe#SXuN*X+jk^DcyY^ zc(xdGxwS!VEx-$0Ok|hI1s4qGtAH>K9DbD(aQDjjXY46p;wC{~9oLz%IF?sh)jgq0 zXQ>}PbvmyES(_q#oM2_Fn*^kbgLAl^wCa2BK&4tLuaOl6IDRZ~YWe{GC!3wa&<T85z`Et3fY+9kzR(6R;)^rGaT>sL^Yxtc<9EH$(0q+`}5PYkWICjZY5_?cniW(I%3uywpMnr^YrDZsF zN?iTw;?q*@_J~`D+idIUhm4Q5sYt5|7T@V0=af{!uz(lXVE`-)uD7QgK#*AgLfBW% zH(~M5AcTRu!MrGfzYoH;lK27JEv0rNgHvHgxm7sI(E<&i5ZLJL9E( z>ZMd*S&0l}A>=)O0DjI8Bf;^L`YMv;tExg>3|LxS34?ORrS$e}(6Qh~5k(b6o>a^Q zKH5p6`EryYr-g%5FqS77drA5C!=9OAf@xKvOJXSHK7SYq98TbVB2#R|dHS4|`u9Ni7XCLUBLmuPEb>bx9JUlb!Jqo{ z(cA>#wC~_baSCMG+>`I)1NT_v71K5zNNe*-qek4Iz)C!!P1m2po@;39(Szry3uveJ zXkPHSA!tnqZNH$eP&9FC_bw|@hFlux>mZoI#eRK$l!~wff+3F_`4wJ&`MJ>8daDg1 zLXiQr5+V|-^e~uE$KWU1dfQ;bfK}j$sa6%Kn~9`CwP29I;DXmJN%D?Z1m$VgN)zh) z0RxvYq~p2MKKQJ^;HBx??8SilnHNLwoNyV{$*iB`mkcPcv>y`KbfOq0GbfCZVAmF# z%Z)`o!w~VSw+*-5p^52CVNw3IP2E`i#JaaJz`YV)Eq0T9M4SyZ$TO$kkyz2E+p7Y>s|Ztx*HOjp(_d znOQy}hEDs$cn-TIoLsH*vR(1fMhDDM(cVISKk6xkvh+xq+Bt}!xcnRPXw3~Vdxulbsvw@GP_Px z({ue+&p=+^f5bP1UFKOFyYdzl!nI$53O@tIBe0b` zbbVj`9!^7a*FIid@Op|i=C^=Tvv(`Ykk$iIj8E{_uFd)@cJ1XOi=VGsnLOmCLWnBG zZ^)IeAG992V5ndRk=@la?t?W~cJ!UNDKM|P7)#^$N|g_>=LA)%uJf1k(L(aElhj!P*EGgepu42itrN-wQxzHg*y#PWp02)`iJ8Nc&tg9A{ufsry zB(IN*DeDjKk*GZ5J)BAvM&VK*F^%Vm*nE=R^Zk~vyt&!dKG+g?0b;r7g?P-G{(nJeN&1Bp_x!?_9q&nF?}8siHQ!{1^R#$ z)h1YP#VTkn45rtPT-k3V;WbNfO*8Qm^p+|jI4JuyTat_T9o)dPlzX;RN=OVJeMdOM zl>P)-x_=}A$gMMnQv_cvL!w4b0b&AoX18!Ms&S}a|Bh8~C>;FJ(1p_zh_CCPp{O_h z=dN5!A3l)t{8(GG+eHE#WCSh?d4V7 zjIWF?DzZ27xaf+wC)Q$&TiN5btwF!B_VWB%LISyMXqtJxUGylmX5$_oN+y69a84qo}#yOi!~#&Z1Q^WDB1%D|xMAL$;9%%U)j< zQ}41_l=U8B$ig<=DV=HsSLT$@@Do~fJWbSMmjfem)tYgcgQ`#U_j8+!yw@u-%Rm4< zlmJZGqI_H&JmFy*$NFf`%wcB&J*pq8(v1!1cgm`0Ar0PkUHGdm#SW6zJ%I9Fp*1Bh zene2;_^J51x7oQLI1~4+X*iYu z;H+V(l_N=+-|y$qwDUY)beFM8d^6_z>$Ae%^)1u_f)rm3Ao?RE3v-i-rMXlH`66;| z;vqYNl|jlc0U9jb49qNWeLOPPPz*=%9uN&(^}$1HMqX^+=YMwV8d_-Jo!rVeTpNM>a7EaAp0DihHljO=jdBUKB2QFX_LJP9z?>QVp|NR*{P za4vfhi$nisLx+5abeRP>^8fNF-SVzaVvP;FYV)Q>BG7{I%5~-0aovLJsJH&{GN`KBMEMpAM{wk`5 zQB>#d7aLMpsQsp~ci)Mi@*$}8>fWb{fs{vQXV(1XT6tn9tEYT4JoNU)Kg{8j4EU(v zqpM^L&Uxj@xGh_CJ#Kv@%iPX?xHN|r&jPb!tRB%II;ceC3LiIz@J5lVQ3xA8HKFX= zz%A&hG&cVbZg0qIf?v;wTKV-gMf&VVRPG;%6qSQS2-0ls96270qhqateJ6Hm=I3%ZPwQ~zWu2!G2ioU2=c&6wj9lVHs@86sgR5}hUFy=VT)+A=%1<$A?F&} zOqdJ_IYQeZR~zFoPB!WHlppI)eQKP1pIKk^)t=-|{-n z?e)>~u#Ca3Z0I7JV)KFV6zaH0*BxEXS3?ZOvgLuDA0S}t*;!{HbMd3Zs|1Vg6}_R3 zoyZE>yqwR0Pd&x1({%0o-q0Qx+SDc-i7%$CZv5?E`w#j^yr=As%HfnhdLms5O^FNE z8p9s*7=F91ACGrHa8AV@KbFf{Aj>rh>5JnE--T(A=ip@9i_wPVgRI>VroFmhL_QW9 z!mFBhHj16wwsXf7l2SzHEww>!$}a)lNBPFt2SEF)7ai8Dd=e-#=xcF|8D&nL=pFXM zXZ!CV1zxi|{|=(=?4wsNJh>M>sTUIsxm)R$GD-fr-P#4F<|@` z9bwmbr3p&&oL~{(hg?c;W&VCiq(l#$NaZcou?*MN!-@+GC~7DgI5T@_s22FC<|Htr3(PNo^~b_l)0&=R zQDaQEw5QL!QK*~SmID%~e8KuASljn$0{%X45V~`L@y2DPAA_;##CnWvs>Oi02KA8&; zbRP()6f8S?u0OM7;Tgzq(yw|Nq|}3IAid`!8NnZf5W}h-rE^*1;7_ydRZ;4w?v?&k z7o(B>3Hw^()k5mzU+mib>ORV4c7zkm(r2B`y|x}#tcqJtfR5|=R-SPBDd$fLCS$IV z8I8{k2v4b}z(8ZZ!0pN{_1!@5X7D=&&0V+7S++XU=Zq?H8R~FORAkJ}JpJuKzcW;k zpq(MI712Jb;j$E0sk>4iP@C!40hjvZY;&Vy+*x%lR zXgCa@9z~!_5j1#-X)dg{Ix;FjxUuz%`!{xLb~$vWC}b|PlQD%mHGf^@NXPj;0_4^u z%lf9w17CcIovRl_W5nB!J?9di-(*MWl3K+A%h^){r+l7d6>Lp4m{~iiT4x8c`Dw8t zZoW0n>&uFtfFfz-6)Q%G`-Ram_FP((x<5$ehkFzCZTMQ(`&o=A6<}2%#KoF;V<{6E zsE7aO8HRC&xYXTa(gG**vAop!Lo$b=N~=eBnIzCuj8wBKL)BC8&LH#Vh$XDo?|?XG z;hY&jxP?~y?)@?!Lyv#E@5w9;QI%r_Q0KTOW$+m^)H&`;+bt>@9Wq+o-cl`DMD>FM zb0O{rC5P>*^K77;2Z;4JLv7#mi2J1VbkSHHHFC|9aoBD3{26K@2Z;f}sw4Pb^oGDy z{7TzB$N!)FmYavk3+B84=34 zeEFYqB!DSESnA6t%+Y9hA6a>=LX#DYMk}(A^*`MpJayje;?eM_gt$(AIE^oy<^+10 zmhIEjqJ4=aEZ%kQPqjf6*@vBtxIwk?epd&s6MBUD(8xlD3m74FUo=KVbD^y@zjW#} z!`OT;93`0q#jdO{vIR^9oOSeN9il>#59(HPvQkJO@ zcgSiTH{HHj%!smrg##5FLLUy1-d6t#kT(POlClgQNrYZFTO^LUj9mM4U!Sk$#=C*~ zwt(Zz3T1pI;NP(IRP<=exOkwreRE3+&Sl#kMbhoG|%F4BO@i@39Y zGnr`Fkd_}>TD*oar_O88U{ zyN-}w`=moVS1;~oFk|L(L6MGOm~hqg)+V{uu8jjy8Do`47WJDpep|=zhr*ze;sduj z_uk}ArFx()K~rvrR`1Zf&hh8aQRBrMemYkENO6p__1QKpvH4hqkLJL+0&Jl z+@kFS8bGhqg?vf6hzYHsv>&`hQ7IJ(bMBNwikJC`>fjDQ^9t?>Lr&*@hMWWq_gmd# zA(^T>ME6*yv%@NrX@~GYF>>Bb+Z#3$z1zIf64(-K)N0ZE{1DFNJhJ6k0yDDFMJ8bp z=#0*(xB;o!knY91QKck7NamHG)G7<00I{N31q;l#CPV$=pr+~WY$eh7X7!s+UsUW5 za|YZlIGfcmY26)&;IB+}qdV+lIH+)} zhnW`t!2;Axf%|-IDjKs(pgUin1ub|`D17yP#|XkH?4T0zvhLO#@!dQ@f~&z9F42F9 zc^A3Z)nzZ1C zL}nj@8}c;wz?$S)owUf~dpP31Wv|aEy6)~SH?i#;-k1Bg4mHFM$uwCjfeF;GW0BBI z!dxI<7Ab7j1P#GQ2vX@7?(kJ7eX+qJFfJsd_BwQC0(=K~dN3Npr@8R3^|_2{$@JPu zTcP~Y<*c1T20E@$tIX$8NWb3j(*Zb==~e$)P~+s_Vf}4wsPVR?%rs4#WcN7EMbAFw zyCy<;NajL_Qs{(T7YHe**ZNymBlq>F_y)U15j8jxrg;S8`D%*>WOBq?3?&jPK5#Ru zX7=>kkiPMPQ9TB}Kw-MetH5{#1h>36xX#;gj-PCm)jH3O*t~C%MsQC-tjn*3Pd!(G zo2YW1+&ieU{}-#?XawIb$rP{mENN<2ol6a28YZx_KMZRze?RuBF?C9i0pxFdZu3SM z_FBE2&qsQj3<3;TK}E(a`#iR{K&P&>FQ-%NfF*QO@JEfM+Y@OwT{o*v$Tcpd?TL9e;6mmx>PE4D45 z=cJ!cMr;3sLDVw$)#C@|Gx#>(oGG`-u``dyG=e=A`5hrYNHCjzKS~Zc=mI)IMhNk2 zSu_X7jd%{su?*TZ2;u6T=o02C6F&0lX6CJllvJXM5-TRZvZix&4>)I9XTrrnwWGIu z^)2SR=}lkBay>X4@izHQF?0o1ZspS=ls4AeeR($o^Gmr6Z1E|bP$bwrF%qY*2l82m zF>cu9$)t7)x5VvN)X)<!0i5nn=49xH8MEV_5R$H zV;{qwOXau)*k_!Rc&Wbqf+Hi;x6iknYone6lj@m!=@?plv9mZ$TzgzI34WcuzE&~p3<`D1I$E1W^8n6*k6RiOfnKIP;1UqI zflJkV8_F+=wCM?Moq4dZbA|cg$^Xzq#$=K*4p(QdVv~#akMILTs1bwoS?aBmW7a~* z>^*ImTd;yw0c|u{gRevB`CzmiL6$w#Pv1nitNto?`uE@yr?O!Nz1znJ;SeL@&z3da zJBmNt2F}6f*LUvY48bR;?t#?$-(*RW3ZGgp_fm^O^T2H9gZ zTg4t7yQB|%BG5b+_&UGpMip21=4Ua?{)>){4h|8W&X*fMj-ef)r^^WCPP7T&N5J4^ zA5A*U_Ab5qC3YZ`eoSqHhVl&CF2|4SCxv_9Yg!Jv>_O4Rem1aTFB*|CAHTqjoHBpH0Moo|XzA^{0o` zXD|D|c>i44E&Y743KcFCm77LviTRG2fO7$Z!+x9#trN7f#S)R z3CBF2qvCA#T22=4v=rVxQO@~@axpf!;|NRy#dV@uMyZtZK((^dY^klsl=qZ-UloNWL%I#0R4d4 zZ+E7Sr4MfKK6(sPf|1hcS3jv?Vl^#ySlRcaEOA{1nU2DfjGcB=NahB*_jK5)_P+N! zRyT~%a|8In7i1(Rxz;!Mt8!Bq;oL6x^|~Bd&q_;F4J4A)34wSg0L+hH-e3IY>1tsO zHqQ4jLeT_GnbOZ1{_^1pc;z%3WKXm}uBt#@dkehj71B}%fdqp0{8&co34M(~(c?Wk zh!V9~1n9?-xSNEp*@;|5-(ApdPrp{9vY?)yec%{Al_Ar#!DI{zC)u*NkMN5Eo4k_K z!}|gF8NS!5mIP@hq<$T;BKCs+@t|ZD?mZTMah(-UYtdSNq8D;bX&#SvWk98Cqxdk@ z#tZ~oNC9!6V^M&DHdv~_&(HJP-S19}13fV31AMku!0!y21ys6|;XJzN#Y^o5J zv&p3AdC}ldU7kA2Qf?-)-|kjt;TKj+)e@|iqQ}6p_R6wt2Z|CPnc*@a+B7&r%})Mg!OPJpyEpI9&_^I*x!TrHg?;fr)1ikCps2*&Z1*jw+?w!f`yRoKc z7_ap%j`)LjEps!6mXd%nrF}MSSJ)^$()UJ?HlC_ev+A|qYL)s1DW$XNhsLJU-uT?& zAw=>G`pSW1$0>a9=@{L9pd(ZIIjW3lVABs?+~=u$VBNb>59(1^IMUvNk=r-|$_&}{ zf=Qa4CL~!4DOa>+@Z3#QbDOnEJAlau9_*wn|A=Zb!Kxu#Q1C*yu1q!lIMVjPYi3Ou zD(P|(Y!w2{Bx{v|GlUbo9ogH`I1l;4csii05h75n2TcSigjW?^IhB~@q;Y;~jbjx% zA7s8tMR56$>cs8fRUNMV@`7J1qB|@AvU(W8V%jxZpJP?4kkqB9>oO40{3ue6cTps` z;T7sB4N^<{@D*J}FP>%kI=`9|2|0vVeANIy?lNz4R36z@M(X74=+BF1;?;&{!n*HA z`}h7#OJ$OF)*ssmQ%UiO_sgqx&!N~;N$231ha5y+fGInj{#ilXoAdI!qriHxp-`Rk z%C}T%oHcq!%>sTE0^znW&AdZ*qibw&pFE>Fk(Eui4SNlcp9d~-Xu&=a!-bhZru`@~ z2Y0!|=$0K-k^L{hmjm^`qpQ%Ia6-`T?b)75C>P@gE{WijQAGHzGmuEGwqsl#cX4E) z-U2J~@f}kp)+8yz&m{Yu03z4LAb^O=((W-NSl-AUdFC_s9UuIHTSsX9;YLgTyw>VCG)~E zg@_j$#mwJ1%)e}AL;%%@yYiLNb(H*#$`e5{N+R~rd)9$o>sZD&)gP=pwyrl>e}ZO8 zB+}{i3H7=BF+#(#4~)&HkzoT9UGSOJQ_@!B{iFVrR!m3O0=rfrbiuQgwWUA8eKy%% z{7_PPRC{0>03_70(3VC$wbzGC$hHLdDiN5ogCh7Q~cP8;CA zkA&EAKO_t!9w-WSKdK=kSS34}F|Rm2>xy4*A_daoy3uzsWrb)L1a#avKYMTgICy)C zKB+rquu#JT@K_0=TKTul`Pn~c3zuruSWH{|7abmk5ht@#3*k4xbg0bHjNUL0pZq*9 zj2T}2VyIKE?$8+eb&RQiHZW`IjsoI1X7y%lsBgz6e)(BViGE@8>Rd+YxS`dk$3){Hx%IJm zH500B_{`OIFYX$0?odEQ`^J(N^Wz7~epc5jSUN5DkDKRX&q1g%zNJ(HV%QZ)VAXQ6 zMOl&(p-_ORQc#2n*eja`V}tiy_m^?lBdq{>4ajI*>cpvR+p&^2%D~W)0tK-1v1_&c73utK7?+n_z`B(tLt%bRrz*1T zLpynWlR`cc!t&5#PWiXLrYyqJgI+?uL}2e(8=n&i(MN+a@tgdH@=pE^`iy+7aEOgl zReSq#P@$5m5v8T3BmqV%z?`~TUzVWYB=mtLCts~~_&QbQay~iMEscg;^>*wsE255C zWQ+0YJ3m;@suk)bEqTqNuNpLOyHcI%714Nz3OW6YF>2hk@F|uxw?|?*si#?GeUy%FIrn;Y7?0_J5kwNysv94Bp>W6)|Ard# zo%WZOuFcFB(Ry;P_x#Xj3Vzu_9U5t>Qy(<{+wH*kk4Oy{$7*U%X%KxYG28A(_q6t< zFCO-4e3(0^Vnhz~K)?^T_jyx!h)E5lEXM^C=rA+|7Xb1Mp38ynw|pL9$Jt45Hsc(h z$5Q+R^|oy0$PvFoM$U84ldcxTUgdN}S(tfk@kW4mkS%wg>L+g>154x5M{EC6+suH4CK8x70rW8(lcU)sY-OzGYaQwSa+b#3Ux5r{H zE+*i576$BrGFfuqp7_@HnqLDXJ_6@EFM50l@Dom20OFpliMX5kx)z#;(<_pA1h?el ztt4Kd!pwCoc9ysNAe1{0o_C57M{xlMfv@IiVq4&{yd@noug8%H>{Y(nei$0!A)qF{ zDpKkV?5>yHfW)Pm)^dECj*Ogy5_v^<(}R7i@3&WL>?!(y!|vRJS8hw(oL>5#+sL{$ z_Ab*`?o%c>%kDJz2y|l?n6f8^7QMXh)?Ulho)+<`vCB=EJ4V|vc9RIP7?LDra9kUO zN#W=OwDxhPJaF1Iqi%^zK*M42*Y$=%i)n_?IlXCs!hzrKS&2p>Z{3BOr)*qt^iPzg zf3?Pxs|4ke4@5m`fz>`w6H}A*&|ryZECRXy$p`nd2CzG21lN1q1|CV5ouhlq*ZNyQ zTJxRKMnIo1GK9uO)$dq!j%fJ>9eC69Y9acL5$o83u9#Q1$Xh=e!rqRsN9Dp0-y9z{ z=%2`o^8$fd?%L${=USKb1u@MPGM#K$db}ej^1NzrLhv!l75eQCAV~gcTPgSjr)kkA zj0U@Mg)rtnsHWF!b9sj- z_9D!s!G$pV87=GYd=^P-(>8wN=C|UH-%YBVJVy9GT0nW28l)RQT~6PQ(k>Z`H{YBP zyRbp{5Zbu=R>;ek+6&>H)z1X+5wsA%BUahtswwZ=pfSa!?m8zCNycMNrIT3gYbQR` z_dmI`71XcEL}i8IH|*uLc{cOlF5-lr@Hu|#3!RIitlfW42?*Q4vAwfzx7klupzb*F zwu|5xlf4+&*>_tsr|K+njc$z2X}_GI!zY4%Kyvw1l}Hdo^{YGWb$k}VwAX&0KR5(P zsF;J$1J%pu3UuLo_)q?N*~Zok_Gjtl9q%Qw4;8;&eIp8KZOIPQkxv0U?4YM93vS2M z+WZHILPqwFbGk#z7J^qBPAfPAr5~ixD>|ARlD1*6leQXKey1U6iC5+?TNXc9d_=6o z(LApTh+X3_?a=odPPeMVPc%pqU zy_r4uPTnD5t9{t;YBz(55pJgY)w3gpyxGR9t(`j%C9vPf6@`fI3$+;UetBx5Cig2W zi4#s$J?cWAZJL1kOMMY7zNai0v#7xI__6ggjlF>?+@J}hy&=mlr;3c$OkHneP;!E!+Ka#O zRZ2n=FXgIH)kpADOQ*U3e+hR)3>5z`@a(kIloflzX;eQoq9kIi(M2 z1A8`b#EGA}eyE4eT+c8A6NDVft&>+cjILEwIV!y0Zk+#M1sCl#=&(1)aEKrkZ(i;B zROi#_Rob{z?``6bQRzccQbC5w4>v<^i$X+)i8fs5h&u>33V}Snqh68<$qJr(HD5yJ z%%b9hisu<>&mq}V&UvNGns0j-R62TBBhP(Po7sijL$NkE5!KoKC(`POk*xcV-vj*c z%`fTz#6>=wZl*Dm42`8mx-weI`<{=`oJ(xR7?z|Z5CdGaI`3BAFqYi4<%;dsdp&)7 zv?^f`a*})rvQUCB2F;r@9n0L6Y%<(oGS&(tLnLFQdL(UuS8-P$z=M(L#-YBN!MUS{9>fF3o4>l&Y2?G|V!8TRLVf z!xBb=d@cb&v9uP$C6%|Gom*w}+>mWlpRYU~L%u-&JahWFKNlV2>A4EoBY zEb&a35l?Bg!a*$~f$20B^?p7VRzGoi<;a2HYWc@s#)l#$Z=qX-BWHBfHJp_U?^E?!Fk%+GIN#nJ%BaZQT=j)>y^S6Qqqy6ii*{wcudFs>VHO|X~ ztAmIWF;|=I;NE~6p5LoQVpBxByRvk@C}`~ZD(ndZG{n3@V$2z&cT-p2ZmD6O>uOsVq#8ziAYBp5GJ&s8D1XvUL=Nn=Tez!L zg<4f4BNrSXd=J5~tVP4=8?tf+y%S&j(IH;`CxYt_%S3<0Ig07N7V8*(|KL{*CNv|c z?Oclr^Y)B1%VI%KHR-Cx* zN2T9C_&*kKz9!!rFAozM$_fF!U2v7Yw6nR<}&>qx|XL@5T@Tf6h~$=QEDljk@IQTKCoT7VG|G-U;2~yYl#9irzauK>eGChUw z^XU$VxHpfk*;)NI;T)$S=%=y%sm<28s^p#a*Nv*V+zLmIP?-e`kvA4M zFIcVD^_NthLhkffg}?dLxXV{i?zx?Z-1i&6d^HoBYbCJiliR;*=2vne#HNtglMso! z4v6Tn^mK>qAnTHr(*Wv~tRG9$|C?g^hMj_cIiZjc%%gmzIfQ3r%XhqzdZdlA@nsTe zUFI#l<+$NClTlt56CLVUTR3I3tUUK^2*|&78T>ZPU}g@JX~-3?Ar-w_Lf>;!S<`7G zKaN#n>4IU?+fPb&14pZL#-3cIpDQwz1W!kPn8rbYA!T>$ro>KWM$-$p&)ylB&CkK& zzM-$V|NFsj2r_ja=#UMJY0I5cY^^c2#WCYe=^@P+Qn2fjP(usNU_Ml&!?b?xZh+QL$KAmaSB-Q=G#Z2mQceus> zJOynj!fX5nX*7`jPle&Oh<$a)@q;JnB4xgTxBj2U{(6~L#;L~FWEDetc#8-Q{4}Cw zxDQdHfjAU`>czjGEGf6No3m9`9ljZ5!;U)vRR{-*J^2^&V5Np`&y9vh-v8Y&a=PhL z@QX9ciE)rD~j(@$z`>5gVa{ziKA9aKDo>`=%3|LoMB zEj7g*v5vt8*-5G)3NjG8+UtQ+FqM9=KlA7*_NL~0i*_D%+luAy?+l%-O;EhJKy^uE z|HzF{I>byDsDJwIh)=ts)_rEj>4xN1)oIWh|9S8)@2~LRQIh>W@cBHnQxqKpKcT=% z!(cR1uL&zy96vqz&nKapSrAdck4r+iK=QTM@=sb_X@dmT!Dq6TgZ*LdcK}gwNMA?O ztcg$V-|s->7iI7A)0Srf`F=VuzBrt2&I(xIf1gt(=>P9;v(=k&dmV+dtT*fpR5~5) z$pK#Oe;=H64D{Yxe7|-sP$wul_}28P8FBfl&3~^H(*6eMIniN+I$8!>D0G7tZRXZYr^{tlw#7y$a{sQ=y1FO zrS4z<&sx1G9Z$z?iVdvr=Pd(%5Yk-9yrq7>XPW|^`vUw5J@G$}RV>Q#xE;>pWQI|X z^`3-Ay_hoXC0T~He?J2!IEbF$e9vowFr+a5eOyO0;M@Q6;C7)O{hvo4==VR*@qgw0 zKY#xVW&i(qwtw+|{_p8x;= literal 0 HcmV?d00001 diff --git a/python/docs_src/source/_templates/layout.html b/python/docs_src/source/_templates/layout.html new file mode 100644 index 00000000..31c1aaaa --- /dev/null +++ b/python/docs_src/source/_templates/layout.html @@ -0,0 +1,94 @@ +{% extends "!layout.html" %} + {% block sidebartitle %} {{ super() }} + + + {% endblock %} + + {% block footer %} {{ super() }} + + + + {%- if nvidia_analytics_id %} + + {%- endif %} + + {% endblock %} diff --git a/python/docs_src/source/conf.py b/python/docs_src/source/conf.py new file mode 100644 index 00000000..57cd633d --- /dev/null +++ b/python/docs_src/source/conf.py @@ -0,0 +1,100 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath('../../media/docs')) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'CUTLASS Python interface' +copyright = '2023, NVIDIA' +author = 'NVIDIA' +release = '3.1.0' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'myst_parser', + 'nbsphinx', + 'nbsphinx_link', + 'sphinx_copybutton', + 'sphinx.ext.autodoc', + 'sphinx.ext.autosectionlabel', + 'sphinx.ext.autosummary', + 'sphinx.ext.coverage', + 'sphinx.ext.extlinks', + 'sphinx.ext.ifconfig', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_inline_tabs', + ] + +source_suffix = { + '.rst': 'restructuredtext', + '.md': 'markdown', +} + +autodoc_typehints = 'description' + +pygments_style = "sphinx" +pygments_dark_style = "monokai" + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# Ignore errors when converting notebooks +nbsphinx_allow_errors = True + +language = 'en' +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_static_path = ['_static'] + +html_title = "CUTLASS Python" +html_baseurl = 'docs' +html_theme = 'furo' +html_theme_options = { + "light_logo": "cutlass-logo-small.png", + "dark_logo": "cutlass-logo-small.png", + "light_css_variables": { + "color-brand-primary": "#76B900", + "color-brand-content": "#76B900", + }, + "dark_css_variables": { + "color-brand-primary": "#76B900", + "color-brand-content": "#76B900", + }, + "footer_icons": [ + { + "name": "GitHub", + "url": "https://github.com/NVIDIA/cutlass", + "html": """ + + + + """, + "class": "", + }, + ], +} diff --git a/python/docs_src/source/contribute.md b/python/docs_src/source/contribute.md new file mode 100644 index 00000000..42475252 --- /dev/null +++ b/python/docs_src/source/contribute.md @@ -0,0 +1,9 @@ +# Contributing + +Thank you for your interest in contributing to the CUTLASS Python interface. Based on the type of contribution, it will fall into two categories: + +1. You want to report a bug, feature request, or documentation issue + - File an [issue](https://github.com/NVIDIA/cutlass/issues/new/choose) describing what you encountered or what you want to see changed. + - The CUTLASS team will evaluate the issues and triage them, scheduling them for a release. If you believe the issue needs priority attention, comment on the issue to notify the team. +2. You want to implement a feature or bug-fix + - We welcome contributions from the community. We recommend that you contribute via a [pull request](https://github.com/NVIDIA/cutlass/pulls). If you have questions about CUTLASS, consider asking a question via the [Discussions](https://github.com/NVIDIA/cutlass/discussions) tab. Please be sure to search through both existing issues and discussions to see whether your question has already been answered. diff --git a/python/docs_src/source/cutlass.emit.rst b/python/docs_src/source/cutlass.emit.rst new file mode 100644 index 00000000..3e65d407 --- /dev/null +++ b/python/docs_src/source/cutlass.emit.rst @@ -0,0 +1,18 @@ +Emitters +======== + +Common +------ + +.. automodule:: cutlass.emit.common + :members: + :undoc-members: + :show-inheritance: + +PyTorch +------- + +.. automodule:: cutlass.emit.pytorch + :members: + :undoc-members: + :show-inheritance: diff --git a/python/docs_src/source/cutlass.op.rst b/python/docs_src/source/cutlass.op.rst new file mode 100644 index 00000000..3b8a2b7e --- /dev/null +++ b/python/docs_src/source/cutlass.op.rst @@ -0,0 +1,26 @@ +Operations +========== + +GEMM +---- + +.. automodule:: cutlass.op.gemm + :members: + :undoc-members: + :show-inheritance: + +Grouped GEMM +------------ + +.. automodule:: cutlass.op.gemm_grouped + :members: + :undoc-members: + :show-inheritance: + +Operation +--------- + +.. automodule:: cutlass.op.op + :members: + :undoc-members: + :show-inheritance: diff --git a/python/docs_src/source/cutlass.rst b/python/docs_src/source/cutlass.rst new file mode 100644 index 00000000..a65c2518 --- /dev/null +++ b/python/docs_src/source/cutlass.rst @@ -0,0 +1,36 @@ +CUTLASS +======= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 1 + + cutlass.emit + cutlass.op + cutlass.utils + +Epilogue +-------- + +.. automodule:: cutlass.epilogue + :members: + :undoc-members: + :show-inheritance: + +Library Defaults +---------------- + +.. automodule:: cutlass.library_defaults + :members: + :undoc-members: + :show-inheritance: + +Swizzle +---------- + +.. automodule:: cutlass.swizzle + :members: + :undoc-members: + :show-inheritance: diff --git a/python/docs_src/source/cutlass.utils.rst b/python/docs_src/source/cutlass.utils.rst new file mode 100644 index 00000000..58e56e56 --- /dev/null +++ b/python/docs_src/source/cutlass.utils.rst @@ -0,0 +1,18 @@ +Utilities +========= + +Checks +------ + +.. automodule:: cutlass.utils.check + :members: + :undoc-members: + :show-inheritance: + +Data Types +---------- + +.. automodule:: cutlass.utils.datatypes + :members: + :undoc-members: + :show-inheritance: diff --git a/python/docs_src/source/examples.rst b/python/docs_src/source/examples.rst new file mode 100644 index 00000000..3cea3621 --- /dev/null +++ b/python/docs_src/source/examples.rst @@ -0,0 +1,9 @@ +Examples +================== + +.. toctree:: + :maxdepth: 5 + + Basic GEMM + Epilogue + PyTorch Extension diff --git a/python/docs_src/source/externals/00_basic_gemm.nblink b/python/docs_src/source/externals/00_basic_gemm.nblink new file mode 100644 index 00000000..b3841985 --- /dev/null +++ b/python/docs_src/source/externals/00_basic_gemm.nblink @@ -0,0 +1,3 @@ +{ + "path": "./../../../../examples/python/00_basic_gemm.ipynb" +} diff --git a/python/docs_src/source/externals/01_epilogue.nblink b/python/docs_src/source/externals/01_epilogue.nblink new file mode 100644 index 00000000..14503a1e --- /dev/null +++ b/python/docs_src/source/externals/01_epilogue.nblink @@ -0,0 +1,3 @@ +{ + "path": "./../../../../examples/python/01_epilogue.ipynb" +} diff --git a/python/docs_src/source/externals/02_pytorch_extension_grouped_gemm.nblink b/python/docs_src/source/externals/02_pytorch_extension_grouped_gemm.nblink new file mode 100644 index 00000000..7da19aff --- /dev/null +++ b/python/docs_src/source/externals/02_pytorch_extension_grouped_gemm.nblink @@ -0,0 +1,3 @@ +{ + "path": "./../../../../examples/python/02_pytorch_extension_grouped_gemm.ipynb" +} diff --git a/python/docs_src/source/index.rst b/python/docs_src/source/index.rst new file mode 100644 index 00000000..73cc742d --- /dev/null +++ b/python/docs_src/source/index.rst @@ -0,0 +1,55 @@ +.. CUTLASS Python interface documentation master file, created by + sphinx-quickstart on Mon Feb 13 17:57:39 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +.. include:: ../../README.md + :start-line: 1 + :parser: markdown + +.. toctree:: + :hidden: + + Home + +.. toctree:: + :hidden: + :caption: Getting Started: + + install.md + Getting Started + contribute.md + +.. toctree:: + :hidden: + :caption: Python Documentation: + + modules.rst + +.. toctree:: + :hidden: + :caption: Examples and Tutorials: + + examples.rst + +.. toctree:: + :hidden: + :caption: Advanced: + +.. toctree:: + :hidden: + :caption: FAQ: + +.. toctree:: + :hidden: + :caption: Reference: + + Github + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/python/docs_src/source/install.md b/python/docs_src/source/install.md new file mode 100644 index 00000000..8f901df9 --- /dev/null +++ b/python/docs_src/source/install.md @@ -0,0 +1,37 @@ +# Installation + +## Installing from source + +Installing from source requires the latest CUDA Toolkit that matches the major.minor of CUDA Python installed. + +Prior to installing the CUTLASS Python interface, one may optionally set the following environment variables: +* `CUTLASS_PATH`: the path to the cloned CUTLASS repository +* `CUDA_INSTALL_PATH`: the path to the installation of CUDA + +If these environment variables are not set, the installation process will infer them to be the following: +* `CUTLASS_PATH`: one directory level above the current directory (i.e., `$(pwd)/..`) +* `CUDA_INSTALL_PATH`: the directory holding `/bin/nvcc` for the first version of `nvcc` on `$PATH` (i.e., `which nvcc | awk -F'/bin/nvcc' '{print $1}'`) + +**NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`. + +### Installing a developer-mode package +The CUTLASS Python interface can currently be installed via: +```bash +python setup.py develop --user +``` +This will allow changes to the Python interface source to be reflected when using the Python interface. + +We plan to add support for installing via `python setup.py install` in a future release. + +## Docker +To ensure that you have all of the necessary Python modules for running the examples using the +CUTLASS Python interface, we recommend using one of the Docker images for CUDA [11.8](../../../python/docker/Dockerfile-cuda11.8-pytorch) +and [12.0](../../../python/docker/Dockerfile-cuda12.0-pytorch) are located in the docker directory. + +For example, to build and launch a container that uses CUDA 12.0 via an NGC PyTorch container, run: +```bash +docker build -t cutlass-cuda12.0:latest -f docker/Dockerfile-cuda12.0-pytorch . +docker run --gpus all -it --rm cutlass-cuda12.0:latest +``` + +The CUTLASS Python interface has been tested with CUDA 11.8 and CUDA 12.0 on Python 3.8.10 and 3.9.7. diff --git a/python/docs_src/source/modules.rst b/python/docs_src/source/modules.rst new file mode 100644 index 00000000..467824e9 --- /dev/null +++ b/python/docs_src/source/modules.rst @@ -0,0 +1,7 @@ +CUTLASS Python API +================== + +.. toctree:: + :maxdepth: 5 + + cutlass diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 00000000..4c97819a --- /dev/null +++ b/python/setup.py @@ -0,0 +1,106 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS' +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import os +from setuptools import setup + + +def _cutlass_path_from_dir() -> str: + cutlass_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../') + if not os.path.isdir(cutlass_path): + raise Exception(f'Environment variable "CUTLASS_PATH" is not defined, and default path of {cutlass_path} does not exist.') + return cutlass_path + + +def _cuda_install_path_from_nvcc() -> str: + import subprocess + # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC + result = subprocess.run(['which', 'nvcc'], capture_output=True) + if result.returncode != 0: + raise Exception(f'Unable to find nvcc via `which` utility.') + + cuda_install_path = result.stdout.decode('utf-8').split('/bin/nvcc')[0] + if not os.path.isdir(cuda_install_path): + raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, and default path of {cuda_install_path} does not exist.') + + return cuda_install_path + + +cutlass_path = ( + os.getenv('CUTLASS_PATH') + if os.getenv('CUTLASS_PATH') is not None + else _cutlass_path_from_dir() +) + +cuda_install_path = ( + os.getenv('CUDA_INSTALL_PATH') + if os.getenv('CUDA_INSTALL_PATH') is not None + else _cuda_install_path_from_nvcc() +) + +ext_modules = [] + +try: + from pybind11.setup_helpers import Pybind11Extension, build_ext + include_dirs = [ + cutlass_path + '/include', + cuda_install_path + '/include', + cutlass_path + '/tools/util/include', + cutlass_path + '/test', + ] + + ext_modules = [ + Pybind11Extension('cutlass_bindings', + ['cutlass/cpp/cutlass_bindings.cpp'], + include_dirs=include_dirs, + extra_compile_args=['-fpermissive', '-w', '-std=c++17', '-DCUTLASS_PYTHON_HOST_CC=1']) + ] +except ImportError: + pass + + +setup( + name='cutlass', + version='3.1.0', + description='CUTLASS Pythonic Interface', + package_dir={'': '.'}, + packages=['cutlass', 'cutlass.emit', 'cutlass.op', 'cutlass.utils', 'cutlass.backend', 'cutlass.backend.utils'], + setup_requires=['pybind11'], + install_requires=[ + 'bfloat16', + 'cuda-python>=11.8.0', + 'pybind11', + 'scikit-build', + 'treelib' + ], + ext_modules=ext_modules, +) diff --git a/tools/library/scripts/pycutlass/test/conv/__init__.py b/test/python/backend/conv/__init__.py similarity index 100% rename from tools/library/scripts/pycutlass/test/conv/__init__.py rename to test/python/backend/conv/__init__.py diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/test/python/backend/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py similarity index 64% rename from tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py rename to test/python/backend/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 2f003b50..831602b6 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/test/python/backend/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -from pycutlass.conv2d_operation import * -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +from cutlass.backend.conv2d_operation import * +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -69,14 +69,14 @@ def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -84,22 +84,22 @@ def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -110,14 +110,14 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -125,22 +125,22 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -151,24 +151,24 @@ def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -178,22 +178,22 @@ def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -204,24 +204,24 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -229,5 +229,5 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc self.assertTrue(test_all_conv2d(operation, problem_sizes)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py similarity index 69% rename from tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 2813f1c7..30b1d5cb 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage3(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -69,14 +69,14 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -84,22 +84,22 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage4(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -110,14 +110,14 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -125,22 +125,22 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage3_64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -151,14 +151,14 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -166,22 +166,22 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage4_64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -192,18 +192,18 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/test/python/backend/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py similarity index 73% rename from tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py rename to test/python/backend/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index 93d9e3bb..c811cff7 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/test/python/backend/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -31,11 +31,11 @@ ################################################################################################# # test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -import pycutlass -from pycutlass.conv2d_operation import * -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend.conv2d_operation import * +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -44,22 +44,22 @@ class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( instruction_shape=[1, 1, 1], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.Simt, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=1) tile_description = TileDescription( @@ -70,14 +70,14 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -85,22 +85,22 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( instruction_shape=[1, 1, 1], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.Simt, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=1) tile_description = TileDescription( @@ -111,14 +111,14 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -126,5 +126,5 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py similarity index 74% rename from tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index 53fb0ebc..e4b9d07d 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -69,14 +69,14 @@ def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhw epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -84,22 +84,22 @@ def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhw def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -110,18 +110,18 @@ def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nh epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Unity, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py similarity index 52% rename from tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py index 0e980616..42fa5187 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -31,76 +31,77 @@ ################################################################################################# # test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -import pycutlass -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") def conv2d_few_channel_problemsizes(channels): problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 8, 8, channels), - cutlass.Tensor4DCoord(16, 3, 3, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 8, 8, channels), + cutlass_bindings.Tensor4DCoord(16, 3, 3, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 16, 16, channels), - cutlass.Tensor4DCoord(16, 3, 3, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 16, 16, channels), + cutlass_bindings.Tensor4DCoord(16, 3, 3, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 16, 16, channels), - cutlass.Tensor4DCoord(16, 7, 7, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 16, 16, channels), + cutlass_bindings.Tensor4DCoord(16, 7, 7, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 224, 224, channels), - cutlass.Tensor4DCoord(32, 7, 7, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(32, 7, 7, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 224, 224, channels), - cutlass.Tensor4DCoord(64, 7, 7, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 7, 7, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 224, 224, channels), - cutlass.Tensor4DCoord(64, 5, 5, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 5, 5, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 224, 224, channels), - cutlass.Tensor4DCoord(64, 5, 5, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 5, 5, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -111,22 +112,22 @@ class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCa def test_SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -137,14 +138,14 @@ def test_SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16n epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.few_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation, conv2d_few_channel_problemsizes(2))) @@ -152,22 +153,22 @@ def test_SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16n def test_SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_1(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=1) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=1) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -178,18 +179,18 @@ def test_SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16n epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.few_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation, conv2d_few_channel_problemsizes(1))) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py similarity index 58% rename from tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py index b4d9b45e..76787d14 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -31,58 +31,59 @@ ################################################################################################# # test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu -import pycutlass -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") def conv2d_fixed_channel_problemsizes(channels): problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 8, 8, channels), - cutlass.Tensor4DCoord(16, 3, 3, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 8, 8, channels), + cutlass_bindings.Tensor4DCoord(16, 3, 3, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 224, 224, channels), - cutlass.Tensor4DCoord(32, 7, 7, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(32, 7, 7, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 224, 224, channels), - cutlass.Tensor4DCoord(64, 7, 7, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 7, 7, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 224, 224, channels), - cutlass.Tensor4DCoord(64, 5, 5, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 5, 5, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 224, 224, channels), - cutlass.Tensor4DCoord(64, 5, 5, channels), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 5, 5, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -93,22 +94,22 @@ class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.Test def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -119,14 +120,14 @@ def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f1 epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.fixed_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation, conv2d_fixed_channel_problemsizes(8))) @@ -134,22 +135,22 @@ def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f1 def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_4(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -160,14 +161,14 @@ def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f1 epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.fixed_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation, conv2d_fixed_channel_problemsizes(4))) @@ -175,22 +176,22 @@ def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f1 def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -201,19 +202,19 @@ def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f1 epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.fixed_channels, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation, conv2d_fixed_channel_problemsizes(2))) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/test/python/backend/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py similarity index 53% rename from tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py rename to test/python/backend/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index cf772782..35f78a14 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/test/python/backend/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -69,14 +69,14 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -84,22 +84,22 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -110,14 +110,14 @@ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -125,22 +125,22 @@ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -151,42 +151,42 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 14), - cutlass.Tensor4DCoord(8, 3, 3, 14), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 14), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 14), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 23, 56, 98), - cutlass.Tensor4DCoord(128, 3, 3, 98), - cutlass.Tensor4DCoord(4, 0, 5, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 23, 56, 98), + cutlass_bindings.Tensor4DCoord(128, 3, 3, 98), + cutlass_bindings.Tensor4DCoord(4, 0, 5, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -196,22 +196,22 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -222,42 +222,42 @@ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 14), - cutlass.Tensor4DCoord(8, 3, 3, 14), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 14), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 14), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 23, 56, 98), - cutlass.Tensor4DCoord(128, 3, 3, 98), - cutlass.Tensor4DCoord(4, 0, 5, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 23, 56, 98), + cutlass_bindings.Tensor4DCoord(128, 3, 3, 98), + cutlass_bindings.Tensor4DCoord(4, 0, 5, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -267,22 +267,22 @@ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -293,42 +293,42 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float16) + math_inst.element_accumulator, cutlass_bindings.float16) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 28), - cutlass.Tensor4DCoord(8, 3, 3, 28), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 28), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 28), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 23, 56, 100), - cutlass.Tensor4DCoord(128, 3, 3, 100), - cutlass.Tensor4DCoord(4, 0, 5, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 23, 56, 100), + cutlass_bindings.Tensor4DCoord(128, 3, 3, 100), + cutlass_bindings.Tensor4DCoord(4, 0, 5, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -337,5 +337,5 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py similarity index 79% rename from tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 8276bdd9..28ee42c2 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dFpropImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -69,18 +69,18 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/test/python/backend/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py similarity index 73% rename from tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py rename to test/python/backend/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index 6949697f..c63c41cc 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/test/python/backend/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -31,11 +31,11 @@ ################################################################################################# # test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -import pycutlass -from pycutlass.conv2d_operation import * -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend.conv2d_operation import * +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -44,22 +44,22 @@ class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( instruction_shape=[1, 1, 1], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.Simt, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=1) tile_description = TileDescription( @@ -70,14 +70,14 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle2 + swizzling_functor=cutlass_bindings.IdentitySwizzle2 ) self.assertTrue(test_all_conv2d(operation)) @@ -85,22 +85,22 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( instruction_shape=[1, 1, 1], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.Simt, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=1) tile_description = TileDescription( @@ -111,18 +111,18 @@ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py similarity index 69% rename from tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index 10520e1f..5067b342 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -69,14 +69,14 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhw epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -84,22 +84,22 @@ def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhw def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=2) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -110,24 +110,24 @@ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nh epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.fprop, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ) ] @@ -135,5 +135,5 @@ def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nh self.assertTrue(test_all_conv2d(operation, problem_sizes)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py similarity index 62% rename from tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index efa2d2d1..85a62e24 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittes def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -69,14 +69,14 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + swizzling_functor=cutlass_bindings.StridedDgradIdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -84,22 +84,22 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_ def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x256_64x3_64x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -110,14 +110,14 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + swizzling_functor=cutlass_bindings.StridedDgradIdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -125,22 +125,22 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_ def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4_128x128_32x3_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -151,24 +151,24 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + swizzling_functor=cutlass_bindings.StridedDgradIdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -178,22 +178,22 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_ def test_SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -204,14 +204,14 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + swizzling_functor=cutlass_bindings.StridedDgradIdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -219,22 +219,22 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc def test_SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -245,33 +245,33 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.dgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + swizzling_functor=cutlass_bindings.StridedDgradIdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 56, 56, 12), - cutlass.Tensor4DCoord(8, 1, 1, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 56, 56, 12), + cutlass_bindings.Tensor4DCoord(8, 1, 1, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 55, 55, 12), - cutlass.Tensor4DCoord(8, 1, 1, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 55, 55, 12), + cutlass_bindings.Tensor4DCoord(8, 1, 1, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -281,5 +281,5 @@ def test_SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/test/python/backend/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py similarity index 75% rename from tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py rename to test/python/backend/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 2e6828c2..df4f79ea 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/test/python/backend/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestC def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -69,15 +69,15 @@ def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tenso epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, - cutlass.float16 + cutlass_bindings.float16 ) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -85,22 +85,22 @@ def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tenso def test_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float16, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float16, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -111,19 +111,19 @@ def test_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tens epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, - cutlass.float16 + cutlass_bindings.float16 ) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py similarity index 64% rename from tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index bb7533b6..cf7547c2 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestC def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -69,14 +69,14 @@ def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tenso epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -84,22 +84,22 @@ def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tenso def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -110,14 +110,14 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -125,22 +125,22 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_64x256_32x4_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=8) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -151,14 +151,14 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -166,22 +166,22 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -192,24 +192,24 @@ def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tenso epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -219,22 +219,22 @@ def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tenso def test_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -245,24 +245,24 @@ def test_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 4, 4, 12), - cutlass.Tensor4DCoord(8, 3, 3, 12), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(3, 3), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -270,5 +270,5 @@ def test_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens self.assertTrue(test_all_conv2d(operation, problem_sizes)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/test/python/backend/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py similarity index 73% rename from tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py rename to test/python/backend/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index e2a60f9a..04f52da8 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/test/python/backend/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -31,11 +31,11 @@ ################################################################################################# # test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu -import pycutlass -from pycutlass.conv2d_operation import * -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend.conv2d_operation import * +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -44,22 +44,22 @@ class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase) def test_SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( instruction_shape=[1, 1, 1], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.Simt, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=1) tile_description = TileDescription( @@ -70,14 +70,14 @@ def test_SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_ epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.analytic, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -85,22 +85,22 @@ def test_SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( instruction_shape=[1, 1, 1], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.Simt, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=1) tile_description = TileDescription( @@ -111,18 +111,18 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/test/python/backend/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py similarity index 69% rename from tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py rename to test/python/backend/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index 213618b1..1e3d3e07 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/test/python/backend/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -31,10 +31,10 @@ ################################################################################################# # test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * +from cutlass.backend.utils.device import device_cc import unittest @@ -43,22 +43,22 @@ class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.Te def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=4) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=8) tile_description = TileDescription( @@ -69,14 +69,14 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nh epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) self.assertTrue(test_all_conv2d(operation)) @@ -84,22 +84,22 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nh def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) A = TensorDescription( element=math_inst.element_a, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=1) B = TensorDescription( element=math_inst.element_b, - layout=cutlass.TensorNHWC, + layout=cutlass_bindings.TensorNHWC, alignment=1) C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, + element=cutlass_bindings.float32, + layout=cutlass_bindings.TensorNHWC, alignment=4) tile_description = TileDescription( @@ -110,24 +110,24 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nh epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + conv_kind=cutlass_bindings.conv.Operator.wgrad, iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, arch=80, tile_description=tile_description, A=A, B=B, C=C, stride_support=StrideSupport.Strided, epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1 ) problem_sizes = [ - cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 8, 8, 1), - cutlass.Tensor4DCoord(1, 3, 3, 1), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 8, 8, 1), + cutlass_bindings.Tensor4DCoord(1, 3, 3, 1), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, 1, 1 ), ] @@ -135,5 +135,5 @@ def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nh self.assertTrue(test_all_conv2d(operation, problem_sizes)) if __name__ == '__main__': - pycutlass.get_memory_pool(2**26, 2**26) + cutlass.backend.get_memory_pool(2**26, 2**26) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/run_all_tests.py b/test/python/backend/conv/run_all_tests.py similarity index 94% rename from tools/library/scripts/pycutlass/test/conv/run_all_tests.py rename to test/python/backend/conv/run_all_tests.py index 9fec5d28..abbad54c 100644 --- a/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +++ b/test/python/backend/conv/run_all_tests.py @@ -30,12 +30,12 @@ # ################################################################################################# -import pycutlass +import cutlass.backend import unittest -from pycutlass.memory_manager import * +from cutlass.backend.memory_manager import * if __name__ == '__main__': - pycutlass.get_memory_pool(2**32, 2**32) + cutlass.backend.get_memory_pool(2**32, 2**32) loader = unittest.TestLoader() tests = loader.discover('./', 'conv2d_*.py') testRunner = unittest.runner.TextTestRunner() diff --git a/tools/library/scripts/pycutlass/test/gemm/__init__.py b/test/python/backend/gemm/__init__.py similarity index 100% rename from tools/library/scripts/pycutlass/test/gemm/__init__.py rename to test/python/backend/gemm/__init__.py diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py b/test/python/backend/gemm/gemm_bf16_sm80.py similarity index 74% rename from tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py rename to test/python/backend/gemm/gemm_bf16_sm80.py index de81e4b0..d77a005e 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +++ b/test/python/backend/gemm/gemm_bf16_sm80.py @@ -30,13 +30,13 @@ # ################################################################################################# -import pycutlass -from pycutlass import * -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * import unittest -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") @@ -44,8 +44,8 @@ class GemmBF16TensorOpSm80(unittest.TestCase): def SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32_64x128x64_32x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.bfloat16, element_b=cutlass.bfloat16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.bfloat16, element_b=cutlass_bindings.bfloat16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -56,23 +56,23 @@ def SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32_64x128x64_32x64x64(self): ) A = TensorDescription( - element=cutlass.bfloat16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.bfloat16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.bfloat16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.bfloat16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=4 ) epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -85,8 +85,8 @@ def SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32_64x128x64_32x64x64(self): def test_SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32_128x256x64_64x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.bfloat16, element_b=cutlass.bfloat16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.bfloat16, element_b=cutlass_bindings.bfloat16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -97,23 +97,23 @@ def test_SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32_128x256x64_64x64x64(se ) A = TensorDescription( - element=cutlass.bfloat16, layout=cutlass.RowMajor, + element=cutlass_bindings.bfloat16, layout=cutlass_bindings.RowMajor, alignment=8 ) B = TensorDescription( - element=cutlass.bfloat16, layout=cutlass.RowMajor, + element=cutlass_bindings.bfloat16, layout=cutlass_bindings.RowMajor, alignment=8 ) C = TensorDescription( - element=cutlass.bfloat16, layout=cutlass.RowMajor, + element=cutlass_bindings.bfloat16, layout=cutlass_bindings.RowMajor, alignment=8 ) epilogue_functor = LinearCombination( C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) + math_inst.element_accumulator, cutlass_bindings.float32) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -124,5 +124,5 @@ def test_SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32_128x256x64_64x64x64(se self.assertTrue(test_all_gemm(operation, "multistage")) if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm90.py b/test/python/backend/gemm/gemm_bf16_sm90.py similarity index 77% rename from tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm90.py rename to test/python/backend/gemm/gemm_bf16_sm90.py index 8d91979e..9970c218 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm90.py +++ b/test/python/backend/gemm/gemm_bf16_sm90.py @@ -31,18 +31,18 @@ ################################################################################################# from functools import partial -import pycutlass -from pycutlass import * -from pycutlass import library -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend import library +from cutlass.backend.test import * import unittest -from pycutlass.test.utils import LayoutCombination, get_name -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc -name_fn = partial(get_name, element_a=cutlass.bfloat16, element_b=cutlass.bfloat16, arch=90) +name_fn = partial(get_name, element_a=cutlass_bindings.bfloat16, element_b=cutlass_bindings.bfloat16, arch=90) def add_test(cls, layouts, alignments, element_output, element_accumulator, element_epilogue, cluster_shape, threadblock_shape, stages, opclass, persistent=False): @@ -61,7 +61,7 @@ def add_test(cls, layouts, alignments, element_output, element_accumulator, elem :param stages: number of pipeline stages to use in the kernel :type stages: int :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass + :type opclass: cutlass_bindings.OpClass :param persistent: whether this is a persistent warp-specialized kernel :type persistent: bool """ @@ -71,10 +71,10 @@ def run(self): Dynamically-generated function that constructs a GEMM operation and verifies it against multiple test cases. """ - element_A = cutlass.bfloat16 - element_B = cutlass.bfloat16 - inst_shape = [1, 1, 1] if opclass == cutlass.OpClass.Simt else None - warp_count = [2, 2, 1] if opclass == cutlass.OpClass.Simt else None + element_A = cutlass_bindings.bfloat16 + element_B = cutlass_bindings.bfloat16 + inst_shape = [1, 1, 1] if opclass == cutlass_bindings.OpClass.Simt else None + warp_count = [2, 2, 1] if opclass == cutlass_bindings.OpClass.Simt else None math_inst = MathInstruction( instruction_shape=inst_shape, element_a=element_A, element_b=element_B, element_accumulator=element_accumulator, @@ -95,7 +95,7 @@ def run(self): epilogue_functor = LinearCombination(C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=90, tile_description=tile_description, A=A, B=B, C=C, @@ -123,16 +123,16 @@ class GemmBF16Sm90(unittest.TestCase): pass -add_test_tensorop = partial(add_test, opclass=cutlass.OpClass.TensorOp) -add_test_simt = partial(add_test, opclass=cutlass.OpClass.Simt) +add_test_tensorop = partial(add_test, opclass=cutlass_bindings.OpClass.TensorOp) +add_test_simt = partial(add_test, opclass=cutlass_bindings.OpClass.Simt) -add_test_tensorop(GemmBF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass.bfloat16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], 3) -add_test_tensorop(GemmBF16Sm90, LayoutCombination.NNN, [4, 4, 8], cutlass.bfloat16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], 5) -add_test_tensorop(GemmBF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.bfloat16, cutlass.float32, cutlass.float32, [2, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmBF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.bfloat16, cutlass.float32, cutlass.float32, [2, 1, 1], [128, 128, 32], None, persistent=True) -add_test_simt(GemmBF16Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.bfloat16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 8], 2) +add_test_tensorop(GemmBF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass_bindings.bfloat16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], 3) +add_test_tensorop(GemmBF16Sm90, LayoutCombination.NNN, [4, 4, 8], cutlass_bindings.bfloat16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], 5) +add_test_tensorop(GemmBF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass_bindings.bfloat16, cutlass_bindings.float32, cutlass_bindings.float32, [2, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmBF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass_bindings.bfloat16, cutlass_bindings.float32, cutlass_bindings.float32, [2, 1, 1], [128, 128, 32], None, persistent=True) +add_test_simt(GemmBF16Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass_bindings.bfloat16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 8], 2) if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py b/test/python/backend/gemm/gemm_f16_sm80.py similarity index 68% rename from tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py rename to test/python/backend/gemm/gemm_f16_sm80.py index b4f245e3..fed73dd1 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +++ b/test/python/backend/gemm/gemm_f16_sm80.py @@ -30,13 +30,13 @@ # ################################################################################################# -import pycutlass -from pycutlass import * -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * import unittest -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") @@ -44,8 +44,8 @@ class GemmF16Sm80(unittest.TestCase): def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -56,25 +56,25 @@ def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32( ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.BatchedIdentitySwizzle + swizzling_functor = cutlass_bindings.BatchedIdentitySwizzle operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -88,8 +88,8 @@ def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32( def test_SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32_128x128x64_64x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -100,25 +100,25 @@ def test_SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32_128x128x64_64x64x64(self) ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -131,8 +131,8 @@ def test_SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32_128x128x64_64x64x64(self) def test_SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32_128x256x64_64x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -143,25 +143,25 @@ def test_SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32_128x256x64_64x64x64(self) ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -174,8 +174,8 @@ def test_SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32_128x256x64_64x64x64(self) def test_SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32_256x128x64_64x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -186,25 +186,25 @@ def test_SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32_256x128x64_64x64x64(self) ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -217,8 +217,8 @@ def test_SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32_256x128x64_64x64x64(self) def test_SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k_128x64x64_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -229,25 +229,25 @@ def test_SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k_128x64x64_64x64x ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=4 ) - element_epilogue = cutlass.float16 + element_epilogue = cutlass_bindings.float16 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -260,8 +260,8 @@ def test_SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k_128x64x64_64x64x def test_SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32_64x64x32_32x32x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float16, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -272,25 +272,25 @@ def test_SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32_64x64x32_32x32x3 ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=4 ) - element_epilogue = cutlass.float16 + element_epilogue = cutlass_bindings.float16 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -303,8 +303,8 @@ def test_SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32_64x64x32_32x32x3 def test_SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32_256x128x64_64x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -315,25 +315,25 @@ def test_SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32_256x128x64_64x64x64(self) ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -346,8 +346,8 @@ def test_SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32_256x128x64_64x64x64(self) def test_test_SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k_128x64x64_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -358,25 +358,25 @@ def test_test_SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k_128x64x64_6 ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -389,8 +389,8 @@ def test_test_SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k_128x64x64_6 def test_SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_128x256x64_64x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -401,25 +401,25 @@ def test_SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_128x256x64_64x64x64(self) ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.RowMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -432,8 +432,8 @@ def test_SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_128x256x64_64x64x64(self) def test_SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32_128x256x64_64x64x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -444,25 +444,25 @@ def test_SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32_128x256x64_64x64x64(self) ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -475,5 +475,5 @@ def test_SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32_128x256x64_64x64x64(self) if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm90.py b/test/python/backend/gemm/gemm_f16_sm90.py similarity index 58% rename from tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm90.py rename to test/python/backend/gemm/gemm_f16_sm90.py index 79339cae..357ec7d9 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm90.py +++ b/test/python/backend/gemm/gemm_f16_sm90.py @@ -31,19 +31,19 @@ ################################################################################################# from functools import partial -import pycutlass -from pycutlass import * -from pycutlass import library -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend import library +from cutlass.backend.test import * import unittest -from pycutlass.test.utils import LayoutCombination, get_name -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc -# Partial specialization for naming tests -name_fn = partial(get_name, element_a=cutlass.float16, element_b=cutlass.float16, arch=90) +# Partial specialziation for naming tests +name_fn = partial(get_name, element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, arch=90) def add_test(cls, layouts, alignments, element_output, element_accumulator, element_epilogue, @@ -63,7 +63,7 @@ def add_test(cls, layouts, alignments, element_output, element_accumulator, elem :param stages: number of pipeline stages to use in the kernel :type stages: int :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass + :type opclass: cutlass_bindings.OpClass :param persistent: whether this is a persistent warp-specialized kernel :type persistent: bool """ @@ -74,10 +74,10 @@ def run(self): multiple test cases. """ - element_A = cutlass.float16 - element_B = cutlass.float16 - inst_shape = [1, 1, 1] if opclass == cutlass.OpClass.Simt else None - warp_count = [2, 2, 1] if opclass == cutlass.OpClass.Simt else None + element_A = cutlass_bindings.float16 + element_B = cutlass_bindings.float16 + inst_shape = [1, 1, 1] if opclass == cutlass_bindings.OpClass.Simt else None + warp_count = [2, 2, 1] if opclass == cutlass_bindings.OpClass.Simt else None math_inst = MathInstruction( instruction_shape=inst_shape, element_a=element_A, element_b=element_B, element_accumulator=element_accumulator, @@ -98,7 +98,7 @@ def run(self): epilogue_functor = LinearCombination(C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=90, tile_description=tile_description, A=A, B=B, C=C, @@ -126,57 +126,57 @@ class GemmF16Sm90(unittest.TestCase): pass -add_test_tensorop = partial(add_test, opclass=cutlass.OpClass.TensorOp) -add_test_simt = partial(add_test, opclass=cutlass.OpClass.Simt) +add_test_tensorop = partial(add_test, opclass=cutlass_bindings.OpClass.TensorOp) +add_test_simt = partial(add_test, opclass=cutlass_bindings.OpClass.Simt) # Tests with 1x1x1 clusters -add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], 3) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NTT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 64, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 64, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass.float16, cutlass.float16, cutlass.float16, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float16, cutlass.float16, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 64, 64], 5) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [2, 2, 2], cutlass.float16, cutlass.float16, cutlass.float16, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], 3) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [64, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 64, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [64, 64, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass_bindings.float16, cutlass_bindings.float16, cutlass_bindings.float16, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float16, cutlass_bindings.float16, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [64, 64, 64], 5) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [2, 2, 2], cutlass_bindings.float16, cutlass_bindings.float16, cutlass_bindings.float16, [1, 1, 1], [128, 128, 32], None) # Tests with different cluster shapes -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 4, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 4, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [4, 1, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [4, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [1, 4, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 4, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [4, 1, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [4, 2, 1], [64, 128, 64], None) # Tests for persistent warp-specialized threadblocks -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 1, 1], [64, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 1, 1], [128, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 2, 1], [64, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 2, 1], [64, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [1, 4, 1], [64, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [2, 4, 1], [64, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [4, 1, 1], [64, 128, 64], None, persistent=True) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, [4, 4, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 1, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 1, 1], [128, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [1, 2, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 2, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [1, 4, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [2, 4, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [4, 1, 1], [64, 128, 64], None, persistent=True) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass_bindings.float32, cutlass_bindings.float32, cutlass_bindings.float32, [4, 4, 1], [64, 128, 64], None, persistent=True) # Tests using SIMT -add_test_simt(GemmF16Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 128, 8], 2) -add_test_simt(GemmF16Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 128, 8], 2) -add_test_simt(GemmF16Sm90, LayoutCombination.NTN, [1, 1, 1], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [128, 64, 8], 2) -add_test_simt(GemmF16Sm90, LayoutCombination.TTN, [1, 1, 1], cutlass.float16, cutlass.float32, cutlass.float32, [1, 1, 1], [64, 64, 8], 2) -add_test_simt(GemmF16Sm90, LayoutCombination.NNT, [1, 1, 1], cutlass.float16, cutlass.float16, cutlass.float16, [1, 1, 1], [128, 128, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 128, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [64, 128, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.NTN, [1, 1, 1], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [128, 64, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.TTN, [1, 1, 1], cutlass_bindings.float16, cutlass_bindings.float32, cutlass_bindings.float32, [1, 1, 1], [64, 64, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.NNT, [1, 1, 1], cutlass_bindings.float16, cutlass_bindings.float16, cutlass_bindings.float16, [1, 1, 1], [128, 128, 8], 2) if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py b/test/python/backend/gemm/gemm_f32_sm80.py similarity index 72% rename from tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py rename to test/python/backend/gemm/gemm_f32_sm80.py index 0bdf0084..31c2d2d5 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +++ b/test/python/backend/gemm/gemm_f32_sm80.py @@ -30,14 +30,14 @@ # ################################################################################################# -import pycutlass -from pycutlass import * -from pycutlass.memory_manager import get_allocated_size -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.memory_manager import get_allocated_size +from cutlass.backend.test import * import unittest -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") @@ -45,8 +45,8 @@ class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add_fast_bf16 ) @@ -57,25 +57,25 @@ def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32( ) A = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=4 ) B = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -89,8 +89,8 @@ def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32( def test_SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_f32_128x128x32_64x64x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -101,25 +101,25 @@ def test_SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_f32_128x128x32_64x64x32(self) ) A = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) B = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -132,8 +132,8 @@ def test_SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_f32_128x128x32_64x64x32(self) def test_SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_fast_accurate_f32_64x64x32_32x32x32(self): math_inst = MathInstruction( instruction_shape=[16, 8, 8], - element_a=cutlass.float32, element_b=cutlass.float32, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float32, element_b=cutlass_bindings.float32, + element_accumulator=cutlass_bindings.float32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add_fast_f32 ) @@ -144,25 +144,25 @@ def test_SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_fast_accurate_f32_64x64x32_32 ) A = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) B = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -173,6 +173,6 @@ def test_SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_fast_accurate_f32_64x64x32_32 self.assertTrue(test_all_gemm(operation, "universal")) if __name__ == '__main__': - pycutlass.get_memory_pool(2**24, 2**24) - pycutlass.compiler.load_from_cache() + cutlass.backend.get_memory_pool(2**24, 2**24) + cutlass.backend.compiler.load_from_cache() unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py b/test/python/backend/gemm/gemm_f64_sm80.py similarity index 75% rename from tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py rename to test/python/backend/gemm/gemm_f64_sm80.py index 4e1aff70..afccac2f 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +++ b/test/python/backend/gemm/gemm_f64_sm80.py @@ -30,13 +30,13 @@ # ################################################################################################# -import pycutlass -from pycutlass import * -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * import unittest -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") @@ -44,8 +44,8 @@ class GemmF64TensorOpSm80(unittest.TestCase): def test_SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64_32x32x16_16x16x16(self): math_inst = MathInstruction( instruction_shape=[8, 8, 4], - element_a=cutlass.float64, element_b=cutlass.float64, - element_accumulator=cutlass.float64, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float64, element_b=cutlass_bindings.float64, + element_accumulator=cutlass_bindings.float64, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -57,25 +57,25 @@ def test_SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64_32x32x16_16x16x16(self): # alignment 1 restricted for double A = TensorDescription( - element=cutlass.float64, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.ColumnMajor, alignment=1 ) B = TensorDescription( - element=cutlass.float64, layout=cutlass.RowMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.RowMajor, alignment=1 ) C = TensorDescription( - element=cutlass.float64, layout=cutlass.RowMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.RowMajor, alignment=1 ) - element_epilogue = cutlass.float64 + element_epilogue = cutlass_bindings.float64 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -88,8 +88,8 @@ def test_SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64_32x32x16_16x16x16(self): def test_SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64_64x64x16_32x32x16(self): math_inst = MathInstruction( instruction_shape=[8, 8, 4], - element_a=cutlass.float64, element_b=cutlass.float64, - element_accumulator=cutlass.float64, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.float64, element_b=cutlass_bindings.float64, + element_accumulator=cutlass_bindings.float64, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -101,25 +101,25 @@ def test_SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64_64x64x16_32x32x16(self): # alignment 1 restricted for double A = TensorDescription( - element=cutlass.float64, layout=cutlass.RowMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.RowMajor, alignment=1 ) B = TensorDescription( - element=cutlass.float64, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.ColumnMajor, alignment=1 ) C = TensorDescription( - element=cutlass.float64, layout=cutlass.RowMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.RowMajor, alignment=1 ) - element_epilogue = cutlass.float64 + element_epilogue = cutlass_bindings.float64 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -130,5 +130,5 @@ def test_SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64_64x64x16_32x32x16(self): self.assertTrue(test_all_gemm(operation, "universal")) if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm90.py b/test/python/backend/gemm/gemm_f64_sm90.py similarity index 82% rename from tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm90.py rename to test/python/backend/gemm/gemm_f64_sm90.py index d4d6fdc1..3d40c70f 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm90.py +++ b/test/python/backend/gemm/gemm_f64_sm90.py @@ -31,18 +31,18 @@ ################################################################################################# from functools import partial -import pycutlass -from pycutlass import * -from pycutlass import library -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend import library +from cutlass.backend.test import * import unittest -from pycutlass.test.utils import LayoutCombination, get_name -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc -name_fn = partial(get_name, element_a=cutlass.float64, element_b=cutlass.float64, arch=90) +name_fn = partial(get_name, element_a=cutlass_bindings.float64, element_b=cutlass_bindings.float64, arch=90) def add_test(cls, layouts, alignments, element_output, element_accumulator, element_epilogue, cluster_shape, threadblock_shape, stages, opclass): @@ -61,7 +61,7 @@ def add_test(cls, layouts, alignments, element_output, element_accumulator, elem :param stages: number of pipeline stages to use in the kernel :type stages: int :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass + :type opclass: cutlass_bindings.OpClass """ def run(self): @@ -69,10 +69,10 @@ def run(self): Dynamically-generated function that constructs a GEMM operation and verifies it against multiple test cases. """ - element_A = cutlass.float64 - element_B = cutlass.float64 - inst_shape = [1, 1, 1] if opclass == cutlass.OpClass.Simt else None - warp_count = [2, 2, 1] if opclass == cutlass.OpClass.Simt else None + element_A = cutlass_bindings.float64 + element_B = cutlass_bindings.float64 + inst_shape = [1, 1, 1] if opclass == cutlass_bindings.OpClass.Simt else None + warp_count = [2, 2, 1] if opclass == cutlass_bindings.OpClass.Simt else None math_inst = MathInstruction( instruction_shape=inst_shape, element_a=element_A, element_b=element_B, element_accumulator=element_accumulator, @@ -92,7 +92,7 @@ def run(self): epilogue_functor = LinearCombination(C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=90, tile_description=tile_description, A=A, B=B, C=C, @@ -115,10 +115,10 @@ class GemmF64Sm90(unittest.TestCase): pass -add_test_simt = partial(add_test, opclass=cutlass.OpClass.Simt) -add_test_simt(GemmF64Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.float64, cutlass.float64, cutlass.float64, [1, 1, 1], [64, 64, 32], 2) +add_test_simt = partial(add_test, opclass=cutlass_bindings.OpClass.Simt) +add_test_simt(GemmF64Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass_bindings.float64, cutlass_bindings.float64, cutlass_bindings.float64, [1, 1, 1], [64, 64, 32], 2) if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py b/test/python/backend/gemm/gemm_grouped_sm80.py similarity index 71% rename from tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py rename to test/python/backend/gemm/gemm_grouped_sm80.py index a1ee9ed3..03800fbb 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +++ b/test/python/backend/gemm/gemm_grouped_sm80.py @@ -30,22 +30,22 @@ # ################################################################################################# -import pycutlass -from pycutlass import * -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.test import * import unittest -from pycutlass.test.gemm_grouped_testbed import TestbedGrouped -from pycutlass.utils.device import device_cc +from cutlass.backend.test.gemm_grouped_testbed import TestbedGrouped +from cutlass.backend.utils.device import device_cc @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class GemmGroupedSm80(unittest.TestCase): def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x32(self): math_inst = MathInstruction( - instruction_shape=[16, 8, 16], element_a=cutlass.float16, - element_b=cutlass.float16, element_accumulator=cutlass.float32, - opcode_class=cutlass.OpClass.TensorOp, + instruction_shape=[16, 8, 16], element_a=cutlass_bindings.float16, + element_b=cutlass_bindings.float16, element_accumulator=cutlass_bindings.float32, + opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -56,25 +56,25 @@ def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x3 ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.BatchedIdentitySwizzle + swizzling_functor = cutlass_bindings.BatchedIdentitySwizzle for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: operation = GemmOperationGrouped( @@ -90,9 +90,9 @@ def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x3 def test_SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64_64x64x16_32x32x16(self): math_inst = MathInstruction( - instruction_shape=[8, 8, 4], element_a=cutlass.float64, - element_b=cutlass.float64, element_accumulator=cutlass.float64, - opcode_class=cutlass.OpClass.TensorOp, + instruction_shape=[8, 8, 4], element_a=cutlass_bindings.float64, + element_b=cutlass_bindings.float64, element_accumulator=cutlass_bindings.float64, + opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -103,25 +103,25 @@ def test_SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64_64x64x16_32x32x16( ) A = TensorDescription( - element=cutlass.float64, layout=cutlass.RowMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.RowMajor, alignment=1 ) B = TensorDescription( - element=cutlass.float64, layout=cutlass.RowMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.RowMajor, alignment=1 ) C = TensorDescription( - element=cutlass.float64, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float64, layout=cutlass_bindings.ColumnMajor, alignment=1 ) - element_epilogue = cutlass.float64 + element_epilogue = cutlass_bindings.float64 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.BatchedIdentitySwizzle + swizzling_functor = cutlass_bindings.BatchedIdentitySwizzle for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: operation = GemmOperationGrouped( @@ -137,9 +137,9 @@ def test_SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64_64x64x16_32x32x16( def test_SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32_128x64x8_64x32x1(self): math_inst = MathInstruction( - instruction_shape=[1, 1, 1], element_a=cutlass.float32, - element_b=cutlass.float32, element_accumulator=cutlass.float32, - opcode_class=cutlass.OpClass.Simt, + instruction_shape=[1, 1, 1], element_a=cutlass_bindings.float32, + element_b=cutlass_bindings.float32, element_accumulator=cutlass_bindings.float32, + opcode_class=cutlass_bindings.OpClass.Simt, math_operation=MathOperation.multiply_add ) @@ -150,25 +150,25 @@ def test_SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32_128x64x8_64x32x1(self): ) A = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=1 ) B = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=1 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.RowMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.RowMajor, alignment=1 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.BatchedIdentitySwizzle + swizzling_functor = cutlass_bindings.BatchedIdentitySwizzle for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: operation = GemmOperationGrouped( @@ -184,9 +184,9 @@ def test_SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32_128x64x8_64x32x1(self): def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x32_cache(self): math_inst = MathInstruction( - instruction_shape=[16, 8, 16], element_a=cutlass.float16, - element_b=cutlass.float16, element_accumulator=cutlass.float32, - opcode_class=cutlass.OpClass.TensorOp, + instruction_shape=[16, 8, 16], element_a=cutlass_bindings.float16, + element_b=cutlass_bindings.float16, element_accumulator=cutlass_bindings.float32, + opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -197,25 +197,25 @@ def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x3 ) A = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) B = TensorDescription( - element=cutlass.float16, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float16, layout=cutlass_bindings.ColumnMajor, alignment=8 ) C = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.float32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) - element_epilogue = cutlass.float32 + element_epilogue = cutlass_bindings.float32 epilogue_functor = LinearCombination( C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.BatchedIdentitySwizzle + swizzling_functor = cutlass_bindings.BatchedIdentitySwizzle for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: operation = GemmOperationGrouped( @@ -231,5 +231,5 @@ def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x3 if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py b/test/python/backend/gemm/gemm_s8_sm80.py similarity index 71% rename from tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py rename to test/python/backend/gemm/gemm_s8_sm80.py index 552b3bec..1352f8e1 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +++ b/test/python/backend/gemm/gemm_s8_sm80.py @@ -30,14 +30,14 @@ # ################################################################################################# -import pycutlass -from pycutlass import * -from pycutlass.epilogue import LinearCombinationClamp -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend.epilogue import LinearCombinationClamp +from cutlass.backend.test import * import unittest -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") @@ -45,8 +45,8 @@ class GemmS8TensorOpF32Sm80(unittest.TestCase): def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_64x64x64_32x32x64(self): math_inst = MathInstruction( instruction_shape=[16, 8, 32], - element_a=cutlass.int8, element_b=cutlass.int8, - element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.int8, element_b=cutlass_bindings.int8, + element_accumulator=cutlass_bindings.int32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add_saturate ) @@ -57,15 +57,15 @@ def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_64x64x64_32x32x64(self): ) A = TensorDescription( - element=cutlass.int8, layout=cutlass.ColumnMajorInterleaved32, + element=cutlass_bindings.int8, layout=cutlass_bindings.ColumnMajorInterleaved32, alignment=16 ) B = TensorDescription( - element=cutlass.int8, layout=cutlass.RowMajorInterleaved32, + element=cutlass_bindings.int8, layout=cutlass_bindings.RowMajorInterleaved32, alignment=16 ) C = TensorDescription( - element=cutlass.int8, layout=cutlass.ColumnMajorInterleaved32, + element=cutlass_bindings.int8, layout=cutlass_bindings.ColumnMajorInterleaved32, alignment=8 ) @@ -73,7 +73,7 @@ def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_64x64x64_32x32x64(self): C.element, C.alignment ) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -86,8 +86,8 @@ def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_64x64x64_32x32x64(self): def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_256x128x128_64x64x128(self): math_inst = MathInstruction( instruction_shape=[16, 8, 32], - element_a=cutlass.int8, element_b=cutlass.int8, - element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.int8, element_b=cutlass_bindings.int8, + element_accumulator=cutlass_bindings.int32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -98,15 +98,15 @@ def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_256x128x128_64x64x128(self): ) A = TensorDescription( - element=cutlass.int8, layout=cutlass.RowMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.RowMajor, alignment=16 ) B = TensorDescription( - element=cutlass.int8, layout=cutlass.ColumnMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.ColumnMajor, alignment=16 ) C = TensorDescription( - element=cutlass.int8, layout=cutlass.RowMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.RowMajor, alignment=16 ) @@ -114,7 +114,7 @@ def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_256x128x128_64x64x128(self): C.element, C.alignment ) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -127,8 +127,8 @@ def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_256x128x128_64x64x128(self): def test_SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32_128x128x128_64x64x128(self): math_inst = MathInstruction( instruction_shape=[16, 8, 32], - element_a=cutlass.int8, element_b=cutlass.int8, - element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.int8, element_b=cutlass_bindings.int8, + element_accumulator=cutlass_bindings.int32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -139,15 +139,15 @@ def test_SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32_128x128x128_64x64x128(self): ) A = TensorDescription( - element=cutlass.int8, layout=cutlass.RowMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.RowMajor, alignment=16 ) B = TensorDescription( - element=cutlass.int8, layout=cutlass.ColumnMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.ColumnMajor, alignment=16 ) C = TensorDescription( - element=cutlass.int8, layout=cutlass.ColumnMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.ColumnMajor, alignment=16 ) @@ -155,7 +155,7 @@ def test_SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32_128x128x128_64x64x128(self): C.element, C.alignment ) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -168,8 +168,8 @@ def test_SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32_128x128x128_64x64x128(self): def test_SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32_128x128x128_64x64x128(self): math_inst = MathInstruction( instruction_shape=[16, 8, 32], - element_a=cutlass.int8, element_b=cutlass.int8, - element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.int8, element_b=cutlass_bindings.int8, + element_accumulator=cutlass_bindings.int32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -180,26 +180,26 @@ def test_SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32_128x128x128_64x64x128(self) ) A = TensorDescription( - element=cutlass.int8, layout=cutlass.RowMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.RowMajor, alignment=16 ) B = TensorDescription( - element=cutlass.int8, layout=cutlass.ColumnMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.ColumnMajor, alignment=16 ) C = TensorDescription( - element=cutlass.int32, layout=cutlass.ColumnMajor, + element=cutlass_bindings.int32, layout=cutlass_bindings.ColumnMajor, alignment=4 ) - element_epilogue = cutlass.int32 + element_epilogue = cutlass_bindings.int32 epilogue_functor = LinearCombinationClamp( C.element, C.alignment, math_inst.element_accumulator, element_epilogue ) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -212,8 +212,8 @@ def test_SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32_128x128x128_64x64x128(self) def test_SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32_128x128x128_64x64x128(self): math_inst = MathInstruction( instruction_shape=[16, 8, 32], - element_a=cutlass.int8, element_b=cutlass.int8, - element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + element_a=cutlass_bindings.int8, element_b=cutlass_bindings.int8, + element_accumulator=cutlass_bindings.int32, opcode_class=cutlass_bindings.OpClass.TensorOp, math_operation=MathOperation.multiply_add ) @@ -224,26 +224,26 @@ def test_SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32_128x128x128_64x64x128(self) ) A = TensorDescription( - element=cutlass.int8, layout=cutlass.RowMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.RowMajor, alignment=16 ) B = TensorDescription( - element=cutlass.int8, layout=cutlass.ColumnMajor, + element=cutlass_bindings.int8, layout=cutlass_bindings.ColumnMajor, alignment=16 ) C = TensorDescription( - element=cutlass.int32, layout=cutlass.RowMajor, + element=cutlass_bindings.int32, layout=cutlass_bindings.RowMajor, alignment=4 ) - element_epilogue = cutlass.int32 + element_epilogue = cutlass_bindings.int32 epilogue_functor = LinearCombinationClamp( C.element, C.alignment, math_inst.element_accumulator, element_epilogue ) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=80, tile_description=tile_description, @@ -257,5 +257,5 @@ def test_SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32_128x128x128_64x64x128(self) if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm90.py b/test/python/backend/gemm/gemm_s8_sm90.py similarity index 73% rename from tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm90.py rename to test/python/backend/gemm/gemm_s8_sm90.py index 682ab7d5..2491f3a8 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm90.py +++ b/test/python/backend/gemm/gemm_s8_sm90.py @@ -31,18 +31,18 @@ ################################################################################################# from functools import partial -import pycutlass -from pycutlass import * -from pycutlass import library -from pycutlass.test import * +import cutlass.backend +from cutlass.backend import * +from cutlass.backend import library +from cutlass.backend.test import * import unittest -from pycutlass.test.utils import LayoutCombination, get_name -from pycutlass.test.gemm_testbed import test_all_gemm -from pycutlass.utils.device import device_cc +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.gemm_testbed import test_all_gemm +from cutlass.backend.utils.device import device_cc -name_fn = partial(get_name, element_a=cutlass.float16, element_b=cutlass.float16, arch=90) +name_fn = partial(get_name, element_a=cutlass_bindings.float16, element_b=cutlass_bindings.float16, arch=90) def add_test(cls, layouts, alignments, element_output, element_accumulator, element_epilogue, cluster_shape, threadblock_shape, stages, opclass, persistent=False): @@ -61,7 +61,7 @@ def add_test(cls, layouts, alignments, element_output, element_accumulator, elem :param stages: number of pipeline stages to use in the kernel :type stages: int :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass + :type opclass: cutlass_bindings.OpClass :param persistent: whether this is a persistent warp-specialized kernel :type persistent: bool """ @@ -71,10 +71,10 @@ def run(self): Dynamically-generated function that constructs a GEMM operation and verifies it against multiple test cases. """ - element_A = cutlass.int8 - element_B = cutlass.int8 - inst_shape = [1, 1, 1] if opclass == cutlass.OpClass.Simt else None - warp_count = [2, 2, 1] if opclass == cutlass.OpClass.Simt else None + element_A = cutlass_bindings.int8 + element_B = cutlass_bindings.int8 + inst_shape = [1, 1, 1] if opclass == cutlass_bindings.OpClass.Simt else None + warp_count = [2, 2, 1] if opclass == cutlass_bindings.OpClass.Simt else None math_inst = MathInstruction( instruction_shape=inst_shape, element_a=element_A, element_b=element_B, element_accumulator=element_accumulator, @@ -93,13 +93,13 @@ def run(self): B = TensorDescription(element=element_B, layout=layouts[1], alignment=alignments[1]) C = TensorDescription(element=element_output, layout=layouts[2], alignment=alignments[2]) - if opclass == cutlass.OpClass.Simt: + if opclass == cutlass_bindings.OpClass.Simt: epilogue_functor_cls = LinearCombinationClamp else: epilogue_functor_cls = LinearCombination epilogue_functor = epilogue_functor_cls(C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - swizzling_functor = cutlass.IdentitySwizzle1 + swizzling_functor = cutlass_bindings.IdentitySwizzle1 operation = GemmOperationUniversal( arch=90, tile_description=tile_description, A=A, B=B, C=C, @@ -127,28 +127,28 @@ class GemmS8Sm90(unittest.TestCase): pass -add_test_tensorop = partial(add_test, opclass=cutlass.OpClass.TensorOp) -add_test_simt = partial(add_test, opclass=cutlass.OpClass.Simt) +add_test_tensorop = partial(add_test, opclass=cutlass_bindings.OpClass.TensorOp) +add_test_simt = partial(add_test, opclass=cutlass_bindings.OpClass.Simt) # Tests with 1x1x1 clusters -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNN, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 128, 128], 3) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 8], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [64, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 64, 32], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [4, 4, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNN, [16, 16, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [1, 1, 1], [128, 128, 128], 3) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [1, 1, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 8], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [1, 1, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [1, 1, 1], [64, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [1, 1, 1], [128, 64, 32], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [4, 4, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [1, 1, 1], [128, 128, 128], None) # Tests with different cluster shapes -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [2, 2, 1], [128, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [1, 4, 1], [128, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [4, 4, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [2, 2, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [1, 4, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [4, 4, 1], [128, 128, 128], None) # Tests with persistent warp-specialized threadblocks -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.int8, cutlass.int32, cutlass.int32, [2, 1, 1], [128, 128, 128], None, persistent=True) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [2, 1, 1], [128, 128, 128], None, persistent=True) # Tests for SIMT -add_test_simt(GemmS8Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.int8, cutlass.int32, cutlass.int32, [1, 1, 1], [64, 32, 8], 2) +add_test_simt(GemmS8Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass_bindings.int8, cutlass_bindings.int32, cutlass_bindings.int32, [1, 1, 1], [64, 32, 8], 2) if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py b/test/python/backend/gemm/run_all_tests.py similarity index 96% rename from tools/library/scripts/pycutlass/test/gemm/run_all_tests.py rename to test/python/backend/gemm/run_all_tests.py index 38f040b1..12b8c609 100644 --- a/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +++ b/test/python/backend/gemm/run_all_tests.py @@ -30,11 +30,11 @@ # ################################################################################################# -import pycutlass +import cutlass.backend import unittest if __name__ == '__main__': - pycutlass.get_memory_pool(2**30, 2**30) + cutlass.backend.get_memory_pool(2**30, 2**30) loader = unittest.TestLoader() tests = loader.discover('./', 'gemm_*.py') testRunner = unittest.runner.TextTestRunner() diff --git a/test/python/emit/pytorch.py b/test/python/emit/pytorch.py new file mode 100644 index 00000000..3ac1c9b0 --- /dev/null +++ b/test/python/emit/pytorch.py @@ -0,0 +1,161 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests emitting a CUTLASS kernel to a PyTorch CUDA extension +""" + +import random +import tempfile +import unittest + +import cutlass + +if cutlass.utils.datatypes.torch_available: + import torch + + +def _initialize(dtype, M: int, N: int, K: int): + """ + Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K + + :param dtype: data type of tensors + :param M: M dimension of GEMM problem + :type M: int + :param N: N dimension of GEMM problem + :type N: int + :param K: N dimension of GEMM problem + :type K: int + + :return: initialized tensors A, B, C, and D + :rtype: list + """ + sizes = [(M, K), (K, N), (M, N), (M, N)] + return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes] + + +def _generate_problems(dtype, num): + """ + Utility function to generate `num` GEMMs of random sizes + + :param dtype: data type of tensors + :param num: number of GEMMs to generate + :type num: int + + :return: lists of A, B, C, and D tensors + :rtype: list + """ + valid_sizes = [128, 256, 512, 1024] + As, Bs, Cs, Ds = [], [], [], [] + for _ in range(num): + M, N, K = [random.choice(valid_sizes) for _ in range(3)] + A, B, C, D = _initialize(dtype, M, N, K) + As.append(A) + Bs.append(B) + Cs.append(C) + Ds.append(D) + return As, Bs, Cs, Ds + + +@unittest.skipIf(not cutlass.utils.datatypes.torch_available, 'PyTorch must be available to run PyTorch extension tests') +class PyTorchExtensionTest(unittest.TestCase): + + def test_gemm(self): + random.seed(2023) + + dtype = torch.float16 + plan = cutlass.op.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor) + plan.activation = cutlass.epilogue.relu + op = plan.construct() + + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass.emit.pytorch(op, name='gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + + A, B, C, _ = _initialize(dtype, 1024, 256, 512) + + D_ref = torch.nn.functional.relu(A @ B) + D = mod.run(A, B) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C, 1.0) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C, 1.0, 0.0) + assert torch.allclose(D, D_ref) + + alpha = 2.0 + beta = -1.0 + D_ref = torch.nn.functional.relu((A @ B) * alpha + (beta * C)) + D = mod.run(A, B, C, alpha, beta) + assert torch.allclose(D, D_ref) + + def test_grouped_gemm(self): + random.seed(2023) + + dtype = torch.float16 + plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor) + op = plan.construct() + + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + + As, Bs, Cs, _ = _generate_problems(dtype, 50) + + def check_all(X, Y): + for x, y in zip(X, Y): + assert torch.allclose(x, y) + + Ds_ref = [a @ b for a, b in zip(As, Bs)] + Ds = mod.run(As, Bs) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs, 1.0) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs, 1.0, 0.0) + check_all(Ds, Ds_ref) + + alpha = 2.0 + beta = -1.0 + Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)] + Ds = mod.run(As, Bs, Cs, alpha, beta) + check_all(Ds, Ds_ref) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/gemm/gemm_f16_sm80.py b/test/python/gemm/gemm_f16_sm80.py new file mode 100644 index 00000000..0c32fa52 --- /dev/null +++ b/test/python/gemm/gemm_f16_sm80.py @@ -0,0 +1,167 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F16 operands on SM80 +""" + +from functools import partial + +import cutlass +from cutlass.utils.datatypes import binding_opclass, binding_type +from cutlass.backend.test.gemm_testbed import test_all_gemm +import unittest + +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.utils.device import device_cc + +cc = 80 + +# Partial specialziation for naming tests +bound_type = binding_type(cutlass.DataType.f16) +name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) + + +def add_test(cls, layouts, alignments, element_output, element_accumulator, + threadblock_shape, warp_count, stages, opclass, swizzle=None): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass.DataType + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param swizzle: threadblock swizzling functor + """ + cluster_shape = [1, 1, 1] + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.DataType.f16 + element_B = cutlass.DataType.f16 + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_output, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator, + kernel_cc=cc) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + td = plan.tile_descriptions()[0] + td.threadblock_shape = threadblock_shape + td.stages = stages + td.warp_count = warp_count + td.cluster_shape = cluster_shape + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + self.assertTrue(test_all_gemm(op, 'universal')) + + element_epilogue = element_accumulator + name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), + binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, opclass=binding_opclass(opclass)) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmF16Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmF16Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +# Tests using TensorOp +add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) + +add_test_tensorop(GemmF16Sm80, LayoutCombination.NNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.NNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.NTN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.NTT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TTN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TTT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [64, 128, 32], [1, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 64, 32], [2, 1, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [64, 64, 64], [1, 1, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [4, 4, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [4, 4, 8], cutlass.DataType.f16, cutlass.DataType.f16, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f16, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [64, 64, 64], [1, 1, 1], 5) +add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [2, 2, 2], cutlass.DataType.f16, cutlass.DataType.f16, [128, 128, 32], [2, 2, 1], 3) + +# Tests using SIMT +add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) + +add_test_simt(GemmF16Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 8], [2, 2, 1], 2) +add_test_simt(GemmF16Sm80, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [64, 128, 8], [1, 2, 1], 2) +add_test_simt(GemmF16Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [128, 64, 8], [2, 1, 1], 2) +add_test_simt(GemmF16Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [64, 64, 8], [1, 1, 1], 2) +add_test_simt(GemmF16Sm80, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f16, [128, 128, 8], [2, 2, 1], 2) + +# Stream K tests +add_test_streamk = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(GemmF16Sm80StreamK, LayoutCombination.NNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_streamk(GemmF16Sm80StreamK, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [64, 64, 64], [1, 1, 1], 5) + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/gemm/gemm_f16_sm90.py b/test/python/gemm/gemm_f16_sm90.py new file mode 100644 index 00000000..8b5ce3f5 --- /dev/null +++ b/test/python/gemm/gemm_f16_sm90.py @@ -0,0 +1,173 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F16 operands on SM90 +""" + +from functools import partial + +import cutlass +from cutlass.utils.datatypes import binding_opclass, binding_type +from cutlass.backend.test.gemm_testbed import test_all_gemm +import unittest + +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.utils.device import device_cc + +cc = 90 + +# Partial specialziation for naming tests +bound_type = binding_type(cutlass.DataType.f16) +name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) + + +def add_test(cls, layouts, alignments, element_output, element_accumulator, + cluster_shape, threadblock_shape, stages, opclass, + kernel_schedule=cutlass.KernelScheduleType.ScheduleAuto, + swizzle=None): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass.DataType + :param cluster_shape: dimensions of threadblock cluster + :type cluster_shape: list or tuple + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param kernel_schedule: kernel schedule type + :type kernel_schedule: cutlass.KernelScheduleType + :param swizzle: threadblock swizzling functor + """ + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.DataType.f16 + element_B = cutlass.DataType.f16 + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_output, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + td = plan.tile_descriptions()[0] + td.threadblock_shape = threadblock_shape + td.stages = stages + td.cluster_shape = cluster_shape + td.kernel_schedule = kernel_schedule + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + self.assertTrue(test_all_gemm(op, 'universal')) + + element_epilogue = element_accumulator + name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), + binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, + opclass=binding_opclass(opclass), kernel_schedule=kernel_schedule) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +class GemmF16Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], 3) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass.DataType.f16, cutlass.DataType.f16, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f16, [1, 1, 1], [128, 128, 32], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [64, 64, 64], 5) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [2, 2, 2], cutlass.DataType.f16, cutlass.DataType.f16, [1, 1, 1], [128, 128, 32], None) + +# Tests with different cluster shapes +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f16, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [2, 2, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [1, 4, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [2, 4, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [4, 1, 1], [64, 128, 64], None) +add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [4, 2, 1], [64, 128, 64], None) + +# Tests for different schedule modes +add_test_schedule = partial(add_test, GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, opclass=cutlass.OpcodeClass.TensorOp) +add_test_schedule([1, 1, 1], [128, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong) +add_test_schedule([1, 1, 1], [128, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative) +add_test_schedule([2, 1, 1], [128, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong) +add_test_schedule([2, 1, 1], [128, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative) +add_test_schedule([2, 1, 1], [256, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative) +add_test_schedule([2, 1, 1], [128, 128, 64], 5, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong) +add_test_schedule([2, 1, 1], [128, 128, 64], 5, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative) + +# Tests using SIMT +add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) +add_test_simt(GemmF16Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [64, 128, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 64, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [64, 64, 8], 2) +add_test_simt(GemmF16Sm90, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f16, [1, 1, 1], [128, 128, 8], 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/gemm/gemm_f32_sm80.py b/test/python/gemm/gemm_f32_sm80.py new file mode 100644 index 00000000..beb19f50 --- /dev/null +++ b/test/python/gemm/gemm_f32_sm80.py @@ -0,0 +1,155 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F32 operands on SM80 +""" + +from functools import partial + +import cutlass +from cutlass.utils.datatypes import binding_opclass, binding_type +from cutlass.backend.test.gemm_testbed import test_all_gemm +import unittest + +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.utils.device import device_cc + +cc = 80 + +# Partial specialziation for naming tests +bound_type = binding_type(cutlass.DataType.f32) +name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) + + +def add_test(cls, layouts, alignments, element_output, element_accumulator, + threadblock_shape, warp_count, stages, opclass, swizzle=None): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass.DataType + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param swizzle: threadblock swizzling functor + """ + + cluster_shape = [1, 1, 1] + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.DataType.f32 + element_B = cutlass.DataType.f32 + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_output, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator, + kernel_cc=cc) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + td = plan.tile_descriptions()[0] + td.threadblock_shape = threadblock_shape + td.stages = stages + td.warp_count = warp_count + td.cluster_shape = cluster_shape + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + + self.assertTrue(test_all_gemm(op, 'universal')) + + element_epilogue = element_accumulator + name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), + binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, opclass=binding_opclass(opclass)) + setattr(cls, name, run) + + return run + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmF32Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmF32Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +# Tests using TensorOp +add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) + +add_test_tensorop(GemmF32Sm80, LayoutCombination.NNN, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF32Sm80, LayoutCombination.NNT, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop(GemmF32Sm80, LayoutCombination.NTN, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [64, 128, 32], [1, 2, 1], 3) +add_test_tensorop(GemmF32Sm80, LayoutCombination.NTN, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [64, 64, 32], [1, 1, 1], 4) +# Tests using SIMT +add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) + +add_test_simt(GemmF32Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 8], [2, 2, 1], 2) +add_test_simt(GemmF32Sm80, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [64, 128, 8], [1, 2, 1], 2) +add_test_simt(GemmF32Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [128, 64, 8], [2, 1, 1], 2) +add_test_simt(GemmF32Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [64, 64, 8], [1, 1, 1], 2) +add_test_simt(GemmF32Sm80, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 8], [2, 2, 1], 2) + +# Stream K tests +add_test_streamk = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(GemmF32Sm80StreamK, LayoutCombination.TTN, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/gemm/gemm_f64_sm80.py b/test/python/gemm/gemm_f64_sm80.py new file mode 100644 index 00000000..10c43ddf --- /dev/null +++ b/test/python/gemm/gemm_f64_sm80.py @@ -0,0 +1,156 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM80 +""" + +from functools import partial + +import cutlass +from cutlass.utils.datatypes import binding_opclass, binding_type +from cutlass.backend.test.gemm_testbed import test_all_gemm +import unittest + +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.utils.device import device_cc + +cc = 80 + +# Partial specialziation for naming tests +bound_type = binding_type(cutlass.DataType.f64) +name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) + + +def add_test(cls, layouts, alignments, element_output, element_accumulator, + threadblock_shape, warp_count, stages, opclass, swizzle=None): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass.DataType + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param swizzle: threadblock swizzling functor + """ + + cluster_shape = [1, 1, 1] + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.DataType.f64 + element_B = cutlass.DataType.f64 + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_output, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator, + kernel_cc=cc) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + td = plan.tile_descriptions()[0] + td.threadblock_shape = threadblock_shape + td.stages = stages + td.warp_count = warp_count + td.cluster_shape = cluster_shape + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + + self.assertTrue(test_all_gemm(op, 'universal')) + + element_epilogue = element_accumulator + name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), + binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, opclass=binding_opclass(opclass)) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmF64Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmF64Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +# Tests using TensorOp +add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) + +add_test_tensorop(GemmF64Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 128, 16], [4, 2, 1], 3) +add_test_tensorop(GemmF64Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [64, 64, 16], [2, 2, 1], 4) +add_test_tensorop(GemmF64Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [32, 32, 16], [2, 1, 1], 5) + +# Tests using SIMT +add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) + +add_test_simt(GemmF64Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 128, 8], [2, 2, 1], 2) +add_test_simt(GemmF64Sm80, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [64, 128, 8], [1, 2, 1], 2) +add_test_simt(GemmF64Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 64, 8], [2, 1, 1], 2) +add_test_simt(GemmF64Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [64, 64, 8], [1, 1, 1], 2) +add_test_simt(GemmF64Sm80, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 128, 8], [2, 2, 1], 2) + +# Stream K tests +add_test_streamk = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(GemmF64Sm80StreamK, LayoutCombination.NTT, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 128, 16], [4, 2, 1], 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/gemm/gemm_f64_sm90.py b/test/python/gemm/gemm_f64_sm90.py new file mode 100644 index 00000000..4a51df99 --- /dev/null +++ b/test/python/gemm/gemm_f64_sm90.py @@ -0,0 +1,142 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM90 +""" + +from functools import partial + +import cutlass +from cutlass.utils.datatypes import binding_opclass, binding_type +from cutlass.backend.test.gemm_testbed import test_all_gemm +import unittest + +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.utils.device import device_cc + +cc = 90 + +# Partial specialziation for naming tests +bound_type = binding_type(cutlass.DataType.f64) +name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) + + +def add_test(cls, layouts, alignments, element_output, element_accumulator, + cluster_shape, threadblock_shape, stages, opclass, persistent=False, swizzle=None): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass.DataType + :param cluster_shape: dimensions of threadblock cluster + :type cluster_shape: list or tuple + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param persistent: whether this is a persistent warp-specialized kernel + :type persistent: bool + :param swizzle: threadblock swizzling functor + """ + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.DataType.f64 + element_B = cutlass.DataType.f64 + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_output, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + td = plan.tile_descriptions()[0] + td.threadblock_shape = threadblock_shape + td.stages = stages + td.cluster_shape = cluster_shape + td.persistent = persistent + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + self.assertTrue(test_all_gemm(op, 'universal')) + + if persistent: + suffix = "_persistent" + else: + suffix = "" + + element_epilogue = element_accumulator + name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), + binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, + opclass=binding_opclass(opclass), suffix=suffix) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +class GemmF64Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) +add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) + +add_test_tensorop(GemmF64Sm90, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [1, 1, 1], [128, 128, 32], 3) +add_test_tensorop(GemmF64Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [1, 1, 1], [128, 128, 32], 3) +add_test_simt(GemmF64Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [1, 1, 1], [128, 128, 8], 2) +add_test_simt(GemmF64Sm90, LayoutCombination.TTT, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [1, 1, 1], [64, 128, 8], 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/gemm/gemm_s8_sm80.py b/test/python/gemm/gemm_s8_sm80.py new file mode 100644 index 00000000..128f5e58 --- /dev/null +++ b/test/python/gemm/gemm_s8_sm80.py @@ -0,0 +1,156 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM80 +""" + +from functools import partial + +import cutlass +from cutlass.utils.datatypes import binding_opclass, binding_type +from cutlass.backend.test.gemm_testbed import test_all_gemm +import unittest + +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.utils.device import device_cc + +cc = 80 + +# Partial specialziation for naming tests +bound_type = binding_type(cutlass.DataType.s8) +name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) + + +def add_test(cls, layouts, alignments, element_output, element_accumulator, + threadblock_shape, warp_count, stages, opclass, swizzle=None): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass.DataType + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param swizzle: threadblock swizzling functor + """ + + cluster_shape = [1, 1, 1] + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.DataType.s8 + element_B = cutlass.DataType.s8 + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_output, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator, + kernel_cc=cc) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + td = plan.tile_descriptions()[0] + td.threadblock_shape = threadblock_shape + td.stages = stages + td.warp_count = warp_count + td.cluster_shape = cluster_shape + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + + self.assertTrue(test_all_gemm(op, 'universal')) + + element_epilogue = element_accumulator + name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), + binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, opclass=binding_opclass(opclass)) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmS8Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmS8Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +# Tests using TensorOp +add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) + +add_test_tensorop(GemmS8Sm80, LayoutCombination.TNN, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [256, 128, 64], [4, 2, 1], 3) +add_test_tensorop(GemmS8Sm80, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [128, 256, 64], [2, 4, 1], 3) +add_test_tensorop(GemmS8Sm80, LayoutCombination.TNN, [16, 16, 4], cutlass.DataType.s32, cutlass.DataType.s32, [64, 64, 64], [1, 1, 1], 4) + +# Tests using SIMT +add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) + +add_test_simt(GemmS8Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.s8, cutlass.DataType.s32, [128, 128, 8], [2, 2, 1], 2) +add_test_simt(GemmS8Sm80, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.s8, cutlass.DataType.s32, [64, 128, 8], [1, 2, 1], 2) +add_test_simt(GemmS8Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.s8, cutlass.DataType.s32, [128, 64, 8], [2, 1, 1], 2) +add_test_simt(GemmS8Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.s32, cutlass.DataType.s32, [64, 64, 8], [1, 1, 1], 2) +add_test_simt(GemmS8Sm80, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.s32, cutlass.DataType.s32, [128, 128, 8], [2, 2, 1], 2) + +# Stream K tests +add_test_streamk = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(GemmS8Sm80StreamK, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [128, 256, 64], [2, 4, 1], 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/gemm/gemm_s8_sm90.py b/test/python/gemm/gemm_s8_sm90.py new file mode 100644 index 00000000..376c80b5 --- /dev/null +++ b/test/python/gemm/gemm_s8_sm90.py @@ -0,0 +1,155 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM90 +""" + +from functools import partial + +import cutlass +from cutlass.utils.datatypes import binding_opclass, binding_type +from cutlass.backend.test.gemm_testbed import test_all_gemm +import unittest + +from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.utils.device import device_cc + +cc = 90 + +# Partial specialziation for naming tests +bound_type = binding_type(cutlass.DataType.s8) +name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) + + +def add_test(cls, layouts, alignments, element_output, element_accumulator, + cluster_shape, threadblock_shape, stages, opclass, persistent=False, swizzle=None): + """ + Create a test-running function with the given specification and set it as a method of `cls`. + + :param cls: class to which the generated method will be added + :type cls: type + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass.DataType + :param cluster_shape: dimensions of threadblock cluster + :type cluster_shape: list or tuple + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param persistent: whether this is a persistent warp-specialized kernel + :type persistent: bool + :param swizzle: threadblock swizzling functor + """ + + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = cutlass.DataType.s8 + element_B = cutlass.DataType.s8 + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_output, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + td = plan.tile_descriptions()[0] + td.threadblock_shape = threadblock_shape + td.stages = stages + td.cluster_shape = cluster_shape + td.persistent = persistent + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + self.assertTrue(test_all_gemm(op, 'universal')) + + if persistent: + suffix = "_persistent" + else: + suffix = "" + + element_epilogue = element_accumulator + name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), + binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, + opclass=binding_opclass(opclass), suffix=suffix) + setattr(cls, name, run) + + return run + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +class GemmS8Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNN, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 128, 128], 3) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 8], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [64, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 64, 32], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [4, 4, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 128, 128], None) + +# Tests with different cluster shapes +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [2, 2, 1], [128, 128, 128], None) +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 4, 1], [128, 128, 128], None) + +# Tests with persistent warp-specialized threadblocks +add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [2, 1, 1], [128, 128, 128], None, persistent=True) + +# Tests for SIMT +add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) +add_test_simt(GemmS8Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [64, 32, 8], 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/gemm/run_all_tests.py b/test/python/gemm/run_all_tests.py new file mode 100644 index 00000000..57b23a22 --- /dev/null +++ b/test/python/gemm/run_all_tests.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + tests = loader.discover('./', 'gemm_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/test/python/interface/gemm_interface.py b/test/python/interface/gemm_interface.py new file mode 100644 index 00000000..7696a5b0 --- /dev/null +++ b/test/python/interface/gemm_interface.py @@ -0,0 +1,354 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests the high-level GEMM interface +""" + +from math import ceil +import unittest + +import cutlass +import cutlass_bindings +import cutlass.utils.datatypes as datatypes +from cutlass.backend.utils.device import device_cc + + +class GemmEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Gemm interface + """ + def __init__(self, element_A, element_B, element_C, element_D, element_accumulator, + layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C): + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.layout_A = layout_A + self.layout_B = layout_B + self.layout_C = layout_C + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + self.plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C) + self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default GEMM + :type other_plan: cutlass.op.Gemm + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct(alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C) + + # Compare whether the operations are equal by comparing the C++ code that would be emitted for them + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types + and layouts for constructing the Gemm interface + """ + if not datatypes.numpy_available: + return + + # Test when specifying all parameters + plan_other = cutlass.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass.op.Gemm(element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_B=self.layout_B, layout_C=self.layout_C, + element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if self.element_A == self.element_B and self.layout_A == self.layout_B: + plan_other = cutlass.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_C=self.layout_C, element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if self.element_C == self.element_accumulator: + plan_other = cutlass.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B, + layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_other = cutlass.op.Gemm(element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend + """ + if not datatypes.numpy_available: + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + layout_to_order = { + cutlass.LayoutType.RowMajor: 'C', + cutlass.LayoutType.ColumnMajor: 'F' + } + size = (2, 2) + A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A) + B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B) + C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C) + D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D) + + # Test when specifying all parameters via tensors + plan_np = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if type_A == type_B and self.layout_A == self.layout_B: + plan_np = cutlass.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if type_C == type_accum: + plan_np = cutlass.op.Gemm(A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_np = cutlass.op.Gemm(element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Gemm interface + """ + self.generic_test() + self.numpy_test() + + +class GemmEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Gemm interface + """ + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16, + element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, + layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16, + element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, + layout_A=cutlass.LayoutType.ColumnMajor, layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.ColumnMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self): + gemm_eq = GemmEquivalence( + element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16, + element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, + layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.") + def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self): + gemm_eq = GemmEquivalence( + element_A=cutlass.DataType.f64, element_B=cutlass.DataType.f64, element_C=cutlass.DataType.f64, + element_D=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, + layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor, layout_C=cutlass.LayoutType.RowMajor, + alignment_A=1, alignment_B=1, alignment_C=1) + gemm_eq.test_all() + + +class ExpectException: + """ + Utility class to assert that an exception was raised when expected + + Example: + + .. highlight:: python + .. code-block:: python + + with ExceptionExpected(True, 'Division by zero'): + x = 1.0 / 0.0 + + :param exception_expected: whether an exception is expected to be raised + :type exception_expected: bool + :param message: message to print if an exception is raised when not expected or vice versa + :type message: str + """ + def __init__(self, exception_expected: bool, message: str = ''): + self.exception_expected = exception_expected + self.message = message + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + exception_raised = exc_type is not None + assert self.exception_expected == exception_raised, self.message + + # Suppress the exception + return True + + +class GemmErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Gemm interface + """ + + def test_alignment(self): + """ + Tests case in which the alignment specified is unsupported + """ + plan = cutlass.op.Gemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor) + + with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'): + op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16) + + def test_tensorop_availability(self): + """ + Tests case in which only SIMT operations are available but TensorOp is requested + """ + cc = device_cc() + + # F64 Tensor Core operations are only avaiable on devices with CC >= 80 + supports_tensorop_f64 = cc >= 80 + plan = cutlass.op.Gemm(cc=cc, element=cutlass.DataType.f64, layout=cutlass.LayoutType.RowMajor) + + error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' + with ExpectException(not supports_tensorop_f64, error_msg): + plan.opclass = cutlass.OpcodeClass.TensorOp + + expected_opclass = cutlass.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass.OpcodeClass.Simt + assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}' + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.") + def test_opclass_switch(self): + """ + Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT) + """ + plan = cutlass.op.Gemm( element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor) + assert plan.opclass == cutlass.OpcodeClass.TensorOp + + # Ensure that all tile descriptions have opclass of TensorOp + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_bindings.OpClass.TensorOp + + plan.opclass = cutlass.OpcodeClass.Simt + + # Ensure that all tile descriptions have opclass of Simt + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_bindings.OpClass.Simt + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + cc = device_cc() + plan = cutlass.op.Gemm(cc=cc, element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor) + td = plan.tile_descriptions()[0] + stages = td.stages + + # Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage + # count should be used + with ExpectException(cc < 90, f'Requested zero stages'): + td.stages = 0 + plan.construct(td) + + with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): + td.stages = 3 + plan.construct(td) + + with ExpectException(True, f'Requested too many stages'): + td.stages = 100 + plan.construct(td) + + # Reset stage count + td.stages = stages + + cluster_shape = td.cluster_shape + with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'): + td.cluster_shape = [2, 1, 1] + plan.construct(td) + + # Reset cluster shape + td.cluster_shape = cluster_shape + + kernel_schedule = td.kernel_schedule + with ExpectException(cc < 90, f'Requested a persistent kernel on SM{cc}'): + td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong + plan.construct(td) + + # Ensure that all returned tile descriptions are unique + ops = {} + for i, td in enumerate(plan.tile_descriptions()): + op = plan.construct(td) + code_str = op.rt_module.emit() + if code_str in ops: + conflicting_td = ops[code_str] + assert False, f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}' + + +if __name__ == '__main__': + unittest.main() diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index c4e0634c..48a55d33 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -121,6 +121,8 @@ set(SUBDIRS reduction util pipeline + substrate + cluster_launch ) if(TARGET nvidia::nvrtc AND TARGET nvidia::cuda_driver) diff --git a/test/unit/cluster_launch/CMakeLists.txt b/test/unit/cluster_launch/CMakeLists.txt new file mode 100644 index 00000000..5f410c8f --- /dev/null +++ b/test/unit/cluster_launch/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_cluster_launch + cluster_launch.cu +) diff --git a/test/unit/cluster_launch/cluster_launch.cu b/test/unit/cluster_launch/cluster_launch.cu new file mode 100644 index 00000000..9f755ea8 --- /dev/null +++ b/test/unit/cluster_launch/cluster_launch.cu @@ -0,0 +1,370 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit test for the launch_on_cluster function +*/ + +#include "../common/cutlass_unit_test.h" +#include "cutlass/cluster_launch.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include +#include +#include + +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + +namespace { // (anonymous) + +// Using a struct instead of a lambda makes it possible +// to name the deleter type without std::function +// (which type-erases). +struct scalar_deleter { + void operator() (float* p) { + if (p != nullptr) { + cudaFree(p); + } + } +}; + +using scalar_device_pointer = std::unique_ptr; + +// Each test needs to initialize this anew, +// from a scalar instance that is in scope during the test. +__device__ float* scalar_ptr_gpu; + +// A single scalar value on device. +// The constructor allocates space on device for one value, +// copies the value to device, and sets the global pointer +// `scalar_ptr_gpu` (see above) to point to it. +// sync_to_host() copies that value back to host. +// +// This class exists only for the tests in this file. +// In order to know whether a kernel that launch_on_cluster +// claimed to launch actually got launched, each kernel +// performs a side effect: it modifies the scalar value +// through the scalar_ptr_gpu value. +// It performs a side effect through a global, +// rather than through an argument, +// so that we can test kernel launch +// with kernels that take zero parameters. +class scalar { +private: + static constexpr std::size_t num_bytes = sizeof(float); + +public: + scalar(float value) : value_host_(value) + { + float* ptr_gpu_raw = nullptr; + auto err = cudaMalloc(&ptr_gpu_raw, num_bytes); + assert(err == cudaSuccess); + + scalar_device_pointer ptr_gpu{ptr_gpu_raw, scalar_deleter{}}; + err = cudaMemcpy(ptr_gpu.get(), &value_host_, + num_bytes, cudaMemcpyHostToDevice); + assert(err == cudaSuccess); + ptr_gpu_ = std::move(ptr_gpu); + upload_device_pointer(); + } + + float sync_to_host() + { + auto err = cudaMemcpy(&value_host_, ptr_gpu_.get(), + num_bytes, cudaMemcpyDeviceToHost); + assert(err == cudaSuccess); + return value_host_; + } + +private: + void upload_device_pointer() + { + float* ptr_raw = ptr_gpu_.get(); + auto err = cudaMemcpyToSymbol(scalar_ptr_gpu, &ptr_raw, sizeof(float*)); + assert(err == cudaSuccess); + } + + float value_host_ = 0.0; + scalar_device_pointer ptr_gpu_; +}; + +template +CUTE_DEVICE void check_cluster_shape() { + [[maybe_unused]] const dim3 cluster_shape = cute::cluster_shape(); + assert(cluster_shape.x == cluster_x); + assert(cluster_shape.y == cluster_y); + assert(cluster_shape.z == cluster_z); +} + +template +__global__ void kernel_0() +{ + check_cluster_shape(); + + // Write to global memory, so that we know + // whether the kernel actually ran. + const dim3 block_id = cute::block_id_in_cluster(); + if (threadIdx.x == 0 && block_id.x == 0 && block_id.y == 0 && block_id.z == 0) { + *scalar_ptr_gpu = 0.1f; + } +} + +template +__global__ void kernel_1(int p0) +{ + check_cluster_shape(); + assert(p0 == expected_p0); + + // Write to global memory, so that we know + // whether the kernel actually ran. + const dim3 block_id = cute::block_id_in_cluster(); + if (threadIdx.x == 0 && block_id.x == 0 && block_id.y == 0 && block_id.z == 0) { + *scalar_ptr_gpu = 1.2f; + } +} + +template +__global__ void kernel_2(int p0, void* p1, int p2) +{ + check_cluster_shape(); + assert(p0 == expected_p0); + assert(p1 == nullptr); + assert(p2 == expected_p2); + + // Write to global memory, so that we know + // whether the kernel actually ran. + const dim3 block_id = cute::block_id_in_cluster(); + if (threadIdx.x == 0 && block_id.x == 0 && block_id.y == 0 && block_id.z == 0) { + *scalar_ptr_gpu = 2.3f; + } +} + +struct OverloadedOperatorAmpersand { + struct tag_t {}; + + // Test that kernel launch uses the actual address, + // instead of any overloaded operator& that might exist. + CUTE_HOST_DEVICE tag_t operator& () const { + return {}; + } + + int x = 0; + int y = 0; + int z = 0; + int w = 0; +}; + +static_assert(sizeof(OverloadedOperatorAmpersand) == 4 * sizeof(int)); + +template +__global__ void kernel_3(int p0, OverloadedOperatorAmpersand p1, std::uint64_t p2) +{ + check_cluster_shape(); + assert(p0 == expected_p0); + assert(p1.x == expected_p1_x); + assert(p1.y == expected_p1_y); + assert(p1.z == expected_p1_z); + assert(p1.w == expected_p1_w); + assert(p2 == expected_p2); + + // Write to global memory, so that we know + // whether the kernel actually ran. + const dim3 block_id = cute::block_id_in_cluster(); + if (threadIdx.x == 0 && block_id.x == 0 && block_id.y == 0 && block_id.z == 0) { + *scalar_ptr_gpu = 3.4f; + } +} + +} // namespace (anonymous) + +TEST(SM90_ClusterLaunch, Kernel_0) +{ + scalar global_value(-1.0f); + + const dim3 grid_dims{2, 1, 1}; + const dim3 block_dims{1, 1, 1}; + const dim3 cluster_dims{grid_dims.x * block_dims.x, 1, 1}; + const int smem_size_in_bytes = 0; + cutlass::ClusterLaunchParams params{ + grid_dims, block_dims, cluster_dims, smem_size_in_bytes}; + + void const* kernel_ptr = reinterpret_cast(&kernel_0<2, 1, 1>); + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, + kernel_ptr); + ASSERT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t result = cudaDeviceSynchronize(); + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST("Kernel launch succeeded\n"); + } + else { + CUTLASS_TRACE_HOST("Kernel launch FAILED\n"); + cudaError_t error = cudaGetLastError(); + EXPECT_EQ(result, cudaSuccess) << "Error at kernel sync: " + << cudaGetErrorString(error) << "\n"; + } + + ASSERT_EQ(global_value.sync_to_host(), 0.1f); +} + +TEST(SM90_ClusterLaunch, Kernel_1) +{ + scalar global_value(-1.0f); + + const dim3 grid_dims{2, 1, 1}; + const dim3 block_dims{1, 1, 1}; + const dim3 cluster_dims{grid_dims.x * block_dims.x, 1, 1}; + const int smem_size_in_bytes = 0; + cutlass::ClusterLaunchParams params{ + grid_dims, block_dims, cluster_dims, smem_size_in_bytes}; + + constexpr int expected_p0 = 42; + void const* kernel_ptr = reinterpret_cast(&kernel_1<2, 1, 1, expected_p0>); + const int p0 = expected_p0; + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, + kernel_ptr, p0); + ASSERT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t result = cudaDeviceSynchronize(); + if (result == cudaSuccess) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("Kernel launch succeeded\n"); +#endif + } + else { + CUTLASS_TRACE_HOST("Kernel launch FAILED\n"); + cudaError_t error = cudaGetLastError(); + EXPECT_EQ(result, cudaSuccess) << "Error at kernel sync: " + << cudaGetErrorString(error) << "\n"; + } + + ASSERT_EQ(global_value.sync_to_host(), 1.2f); +} + +TEST(SM90_ClusterLaunch, Kernel_2) +{ + scalar global_value(-1.0f); + + const dim3 grid_dims{2, 1, 1}; + const dim3 block_dims{1, 1, 1}; + const dim3 cluster_dims{grid_dims.x * block_dims.x, 1, 1}; + const int smem_size_in_bytes = 0; + cutlass::ClusterLaunchParams params{ + grid_dims, block_dims, cluster_dims, smem_size_in_bytes}; + + constexpr int expected_p0 = 42; + constexpr int expected_p2 = 43; + + int p0 = expected_p0; + int* p1 = nullptr; + int p2 = expected_p2; + + void const* kernel_ptr = reinterpret_cast( + &kernel_2<2, 1, 1, expected_p0, expected_p2>); + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, + kernel_ptr, p0, p1, p2); + ASSERT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t result = cudaDeviceSynchronize(); + if (result == cudaSuccess) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("Kernel launch succeeded\n"); +#endif + } + else { + CUTLASS_TRACE_HOST("Kernel launch FAILED\n"); + cudaError_t error = cudaGetLastError(); + EXPECT_EQ(result, cudaSuccess) << "Error at kernel sync: " + << cudaGetErrorString(error) << "\n"; + } + + ASSERT_EQ(global_value.sync_to_host(), 2.3f); +} + +TEST(SM90_ClusterLaunch, Kernel_3) +{ + scalar global_value(-1.0f); + + const dim3 grid_dims{2, 1, 1}; + const dim3 block_dims{1, 1, 1}; + const dim3 cluster_dims{grid_dims.x * block_dims.x, 1, 1}; + const int smem_size_in_bytes = 0; + cutlass::ClusterLaunchParams params{ + grid_dims, block_dims, cluster_dims, smem_size_in_bytes}; + + constexpr int expected_p0 = 42; + constexpr int expected_p1_x = 1; + constexpr int expected_p1_y = 2; + constexpr int expected_p1_z = 3; + constexpr int expected_p1_w = 4; + constexpr std::uint64_t expected_p2 = 1'000'000'000'000uLL; + + int p0 = expected_p0; + OverloadedOperatorAmpersand p1{expected_p1_x, + expected_p1_y, expected_p1_z, expected_p1_w}; + // Verify that operator& is overloaded for this type. + static_assert(! std::is_same_v); + std::uint64_t p2 = expected_p2; + + void const* kernel_ptr = reinterpret_cast( + &kernel_3<2, 1, 1, expected_p0, expected_p1_x, + expected_p1_y, expected_p1_z, expected_p1_w, + expected_p2>); + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, + kernel_ptr, p0, p1, p2); + ASSERT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t result = cudaDeviceSynchronize(); + if (result == cudaSuccess) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("Kernel launch succeeded\n"); +#endif + } + else { + CUTLASS_TRACE_HOST("Kernel launch FAILED\n"); + cudaError_t error = cudaGetLastError(); + EXPECT_EQ(result, cudaSuccess) << "Error at kernel sync: " + << cudaGetErrorString(error) << "\n"; + } + + ASSERT_EQ(global_value.sync_to_host(), 3.4f); +} + +#endif // CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt index f8013751..11671dee 100644 --- a/test/unit/conv/device/CMakeLists.txt +++ b/test/unit/conv/device/CMakeLists.txt @@ -243,3 +243,4 @@ if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 75) endif() endif() + diff --git a/test/unit/conv/device/conv2d_problems.h b/test/unit/conv/device/conv2d_problems.h index 5d1fbdcf..29ad122d 100644 --- a/test/unit/conv/device/conv2d_problems.h +++ b/test/unit/conv/device/conv2d_problems.h @@ -35,8 +35,6 @@ #include -#include "../../common/cutlass_unit_test.h" - #include "cutlass/cutlass.h" #include "cutlass/layout/matrix.h" #include "cutlass/conv/convolution.h" diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 221d8c09..47c3ca18 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -573,7 +573,7 @@ bool TestSpecificConv2d( ///////////////////////////////////////////////////////////////////////////////////////////////////////// // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes // (conv_blacklist_sizes) ///////////////////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/test/unit/conv/device/conv2d_testbed_interleaved.h b/test/unit/conv/device/conv2d_testbed_interleaved.h index 201d4fe7..cc00d82b 100644 --- a/test/unit/conv/device/conv2d_testbed_interleaved.h +++ b/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -410,6 +410,7 @@ class InterleavedTestbedConv2d { LayoutC, ElementCompute, ElementAccumulator, + ElementC, cutlass::NumericConverterClamp >( kConvolutionalOperator, @@ -517,7 +518,7 @@ class InterleavedTestbedConv2d { ///////////////////////////////////////////////////////////////////////////////////////////////////////// // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes // (conv_blacklist_sizes) ///////////////////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h index 1b771607..7bbe6745 100644 --- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -502,7 +502,7 @@ class TestbedConv2dWithBroadcast { ///////////////////////////////////////////////////////////////////////////////////////////////////////// // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes // (conv_blacklist_sizes) ///////////////////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/test/unit/conv/device/conv3d_testbed.h b/test/unit/conv/device/conv3d_testbed.h index 60b12d6f..64447c52 100644 --- a/test/unit/conv/device/conv3d_testbed.h +++ b/test/unit/conv/device/conv3d_testbed.h @@ -522,7 +522,7 @@ class TestbedConv3d { ///////////////////////////////////////////////////////////////////////////////////////////////////////// // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes // (conv_blacklist_sizes) ///////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu index acf073fa..b4d4c1a4 100644 --- a/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu @@ -241,6 +241,106 @@ TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhw //////////////////////////////////////////////////////////////////////////////// +// Analytic 2 stage SingleGroup kernel +TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + SingleGroupPerCTA_128x128_64x2_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); +} + +//////////////////////////////////////////////////////////////////////////////// + +// Analytic 2 stage MutipleGroup kernel +TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + MutipleGroupPerCTA_64x64_64x2_32x32x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kMultipleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_multiple_group_sizes)); +} + +//////////////////////////////////////////////////////////////////////////////// + TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, SingleGroupPerCTA_128x128_64x3_64x64x64) { @@ -340,14 +440,14 @@ TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nh //////////////////////////////////////////////////////////////////////////////// -// Optimized 2 stage singleGroup kernel +// Optimized 2 stage SingleGroup kernel TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, SingleGroupPerCTA_64x64_64x2_32x32x64) { /// Conv operation element types for the Gemm equivalent (ImplicitGemm) using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; - using ElementC = float; + using ElementC = cutlass::half_t; using ElementAccumulator = float; using ElementCompute = float; using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; diff --git a/test/unit/cute/CMakeLists.txt b/test/unit/cute/CMakeLists.txt index 43a7bd00..16e5df3b 100644 --- a/test/unit/cute/CMakeLists.txt +++ b/test/unit/cute/CMakeLists.txt @@ -30,6 +30,7 @@ add_subdirectory(core) add_subdirectory(ampere) add_subdirectory(hopper) add_subdirectory(layout) +add_subdirectory(msvc_compilation) add_custom_target( cutlass_test_unit_cute @@ -38,6 +39,7 @@ add_custom_target( cutlass_test_unit_cute_core cutlass_test_unit_cute_ampere cutlass_test_unit_cute_hopper + cutlass_test_unit_cute_msvc_compilation ) add_custom_target( @@ -47,4 +49,5 @@ add_custom_target( test_unit_cute_core test_unit_cute_ampere test_unit_cute_hopper + test_unit_cute_msvc_compilation ) diff --git a/test/unit/cute/core/CMakeLists.txt b/test/unit/cute/core/CMakeLists.txt index e8e3555a..0a2006dc 100644 --- a/test/unit/cute/core/CMakeLists.txt +++ b/test/unit/cute/core/CMakeLists.txt @@ -29,8 +29,10 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_core + array_subbyte.cpp bitfield.cpp coalesce.cpp + compact_xmajor.cpp compare.cpp complement.cpp composition.cpp diff --git a/test/unit/cute/core/array_subbyte.cpp b/test/unit/cute/core/array_subbyte.cpp new file mode 100644 index 00000000..0667e8e3 --- /dev/null +++ b/test/unit/cute/core/array_subbyte.cpp @@ -0,0 +1,114 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include + +#include + +TEST(CuTe_core, ArraySubbyte) +{ + using namespace cute; + + { + array_subbyte a; + + //std::cout << sizeof_bits::value << std::endl; + EXPECT_EQ(sizeof_bits::value, 14*8); + + fill(a, uint8_t(13)); + for (int i = 0; i < int(a.size()); ++i) { + //std::cout << i << ": " << int(a[i]) << " -> "; + EXPECT_EQ(a[i], uint8_t(13)); + a[i] = uint8_t(i); + //std::cout << int(a[i]) << std::endl; + EXPECT_EQ(a[i], uint8_t(i)); + } + + //std::cout << std::endl; + } + + { + array_subbyte a; + + //std::cout << sizeof_bits::value << std::endl; + EXPECT_EQ(sizeof_bits::value, 14/2*8); + + fill(a, int4_t(-5)); + for (int i = 0; i < int(a.size()); ++i) { + //std::cout << i << ": " << int4_t(a[i]) << " -> "; + EXPECT_EQ(int4_t(a[i]), int4_t(-5)); + a[i] = int4_t(i); + //std::cout << int4_t(a[i]) << std::endl; + EXPECT_EQ(int4_t(a[i]), int4_t(i)); + } + + //std::cout << std::endl; + } + + { + array_subbyte a; + + //std::cout << sizeof_bits::value << std::endl; + EXPECT_EQ(sizeof_bits::value, 4*8); + + fill(a, uint2_t(-5)); + for (int i = 0; i < int(a.size()); ++i) { + //std::cout << i << ": " << uint2_t(a[i]) << " -> "; + EXPECT_EQ(uint2_t(a[i]), uint2_t(-5)); + a[i] = uint2_t(i); + //std::cout << uint2_t(a[i]) << std::endl; + EXPECT_EQ(uint2_t(a[i]), uint2_t(i)); + } + + //std::cout << std::endl; + } + + { + array_subbyte a; + + //std::cout << sizeof_bits::value << std::endl; + EXPECT_EQ(sizeof_bits::value, 2*8); + + fill(a, bool(1)); + for (int i = 0; i < int(a.size()); ++i) { + //std::cout << i << ": " << bool(a[i]) << " -> "; + EXPECT_EQ(a[i], bool(1)); + a[i] = bool(i % 2); + //std::cout << bool(a[i]) << std::endl; + EXPECT_EQ(a[i], bool(i % 2)); + } + //std::cout << std::endl; + } +} diff --git a/test/unit/cute/core/compact_xmajor.cpp b/test/unit/cute/core/compact_xmajor.cpp new file mode 100644 index 00000000..21d5898a --- /dev/null +++ b/test/unit/cute/core/compact_xmajor.cpp @@ -0,0 +1,231 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +TEST(CuTe_core, CompactColMajor_Static) +{ + using namespace cute; + + CUTE_STATIC_ASSERT_V((compact_col_major(Int<1>{}) == Int<0>{})); + CUTE_STATIC_ASSERT_V((compact_col_major(Int<1>{}, Int<3>{}) == Int<0>{})); + CUTE_STATIC_ASSERT_V((compact_col_major(Int<8>{}) == Int<1>{})); + CUTE_STATIC_ASSERT_V((compact_col_major(Int<8>{}, Int<3>{}) == Int<3>{})); + + CUTE_STATIC_ASSERT_V((compact_col_major(1) == Int<1>{})); + CUTE_STATIC_ASSERT_V((compact_col_major(8) == Int<1>{})); + + { + auto test = make_tuple(Int<4>{}, Int<8>{}); + auto result = make_tuple(Int<1>{}, Int<4>{}); + CUTE_STATIC_ASSERT_V((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(Int<4>{}, Int<8>{}, Int< 2>{}); + auto result = make_tuple(Int<1>{}, Int<4>{}, Int<32>{}); + CUTE_STATIC_ASSERT_V((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(Int<4>{}, Int<8>{}, Int<1>{}, Int< 2>{}); + auto result = make_tuple(Int<1>{}, Int<4>{}, Int<0>{}, Int<32>{}); + CUTE_STATIC_ASSERT_V((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(make_tuple(Int<4>{}, Int<8>{}), Int<1>{}, Int< 2>{}); + auto result = make_tuple(make_tuple(Int<1>{}, Int<4>{}), Int<0>{}, Int<32>{}); + CUTE_STATIC_ASSERT_V((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(Int<4>{}, make_tuple(Int<8>{}, Int<1>{}, Int< 2>{})); + auto result = make_tuple(Int<1>{}, make_tuple(Int<4>{}, Int<0>{}, Int<32>{})); + CUTE_STATIC_ASSERT_V((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(Int<4>{}, make_tuple(Int<8>{}, Int<1>{}, make_tuple(Int< 2>{}, Int< 3>{}))); + auto result = make_tuple(Int<1>{}, make_tuple(Int<4>{}, Int<0>{}, make_tuple(Int<32>{}, Int<64>{}))); + CUTE_STATIC_ASSERT_V((compact_col_major(test) == result)); + } +} + +TEST(CuTe_core, CompactColMajor_Dynamic) +{ + using namespace cute; + + ASSERT_TRUE((compact_col_major(1) == 1)); + ASSERT_TRUE((compact_col_major(1, 3) == 3)); + ASSERT_TRUE((compact_col_major(8) == 1)); + ASSERT_TRUE((compact_col_major(8, 3) == 3)); + + ASSERT_TRUE((compact_col_major(1) == 1)); + ASSERT_TRUE((compact_col_major(8) == 1)); + + { + auto test = make_tuple(4, 8); + auto result = make_tuple(1, 4); + ASSERT_TRUE((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(4, 8, 2); + auto result = make_tuple(1, 4, 32); + ASSERT_TRUE((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(4, 8, 1, 2); + auto result = make_tuple(1, 4, 32, 32); + ASSERT_TRUE((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(make_tuple(4, 8), 1, 2); + auto result = make_tuple(make_tuple(1, 4), 32, 32); + ASSERT_TRUE((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(4, make_tuple(8, 1, 2)); + auto result = make_tuple(1, make_tuple(4, 32, 32)); + ASSERT_TRUE((compact_col_major(test) == result)); + } + + { + auto test = make_tuple(4, make_tuple(8, 1, make_tuple( 2, 3))); + auto result = make_tuple(1, make_tuple(4, 32, make_tuple(32, 64))); + ASSERT_TRUE((compact_col_major(test) == result)); + } +} + +TEST(CuTe_core, CompactRowMajor_Static) +{ + using namespace cute; + + CUTE_STATIC_ASSERT_V((compact_row_major(Int<1>{}) == Int<0>{})); + CUTE_STATIC_ASSERT_V((compact_row_major(Int<1>{}, Int<3>{}) == Int<0>{})); + CUTE_STATIC_ASSERT_V((compact_row_major(Int<8>{}) == Int<1>{})); + CUTE_STATIC_ASSERT_V((compact_row_major(Int<8>{}, Int<3>{}) == Int<3>{})); + + CUTE_STATIC_ASSERT_V((compact_row_major(1) == Int<1>{})); + CUTE_STATIC_ASSERT_V((compact_row_major(8) == Int<1>{})); + + { + auto test = make_tuple(Int<4>{}, Int<8>{}); + auto result = make_tuple(Int<8>{}, Int<1>{}); + CUTE_STATIC_ASSERT_V((compact_row_major(test) == result)); + } + + { + auto test = make_tuple(Int< 4>{}, Int<8>{}, Int<2>{}); + auto result = make_tuple(Int<16>{}, Int<2>{}, Int<1>{}); + CUTE_STATIC_ASSERT_V((compact_row_major(test) == result)); + } + + { + auto test = make_tuple(Int< 4>{}, Int<8>{}, Int<1>{}, Int<2>{}); + auto result = make_tuple(Int<16>{}, Int<2>{}, Int<0>{}, Int<1>{}); + CUTE_STATIC_ASSERT_V((compact_row_major(test) == result)); + } + + { + auto test = make_tuple(make_tuple(Int< 4>{}, Int<8>{}), Int<1>{}, Int<2>{}); + auto result = make_tuple(make_tuple(Int<16>{}, Int<2>{}), Int<0>{}, Int<1>{}); + CUTE_STATIC_ASSERT_V((compact_row_major(test) == result)); + } + + { + auto test = make_tuple(Int< 4>{}, make_tuple(Int<8>{}, Int<1>{}, Int<2>{})); + auto result = make_tuple(Int<16>{}, make_tuple(Int<2>{}, Int<0>{}, Int<1>{})); + CUTE_STATIC_ASSERT_V((compact_row_major(test) == result)); + } + + { + auto test = make_tuple(Int< 4>{}, make_tuple(Int<8>{}, Int<1>{}, make_tuple(Int<2>{}, Int<3>{}))); + auto result = make_tuple(Int<48>{}, make_tuple(Int<6>{}, Int<0>{}, make_tuple(Int<3>{}, Int<1>{}))); + CUTE_STATIC_ASSERT_V((compact_row_major(test) == result)); + } +} + +TEST(CuTe_core, CompactRowMajor_Dynamic) +{ + using namespace cute; + + ASSERT_TRUE((compact_row_major(1) == 1)); + ASSERT_TRUE((compact_row_major(1, 3) == 3)); + ASSERT_TRUE((compact_row_major(8) == 1)); + ASSERT_TRUE((compact_row_major(8, 3) == 3)); + + ASSERT_TRUE((compact_row_major(1) == 1)); + ASSERT_TRUE((compact_row_major(8) == 1)); + + { + auto test = make_tuple(4, 8); + auto result = make_tuple(8, 1); + ASSERT_TRUE((compact_row_major(test) == result)); + } + + { + auto test = make_tuple( 4, 8, 2); + auto result = make_tuple(16, 2, 1); + ASSERT_TRUE((compact_row_major(test) == result)); + } + + { + auto test = make_tuple( 4, 8, 1, 2); + auto result = make_tuple(16, 2, 2, 1); + ASSERT_TRUE((compact_row_major(test) == result)); + } + + { + auto test = make_tuple(make_tuple( 4, 8), 1, 2); + auto result = make_tuple(make_tuple(16, 2), 2, 1); + ASSERT_TRUE((compact_row_major(test) == result)); + } + + { + auto test = make_tuple( 4, make_tuple(8, 1, 2)); + auto result = make_tuple(16, make_tuple(2, 2, 1)); + ASSERT_TRUE((compact_row_major(test) == result)); + } + + { + auto test = make_tuple( 4, make_tuple(8, 1, make_tuple(2, 3))); + auto result = make_tuple(48, make_tuple(6, 6, make_tuple(3, 1))); + ASSERT_TRUE((compact_row_major(test) == result)); + } +} diff --git a/test/unit/cute/hopper/CMakeLists.txt b/test/unit/cute/hopper/CMakeLists.txt index ce301101..bffc16db 100644 --- a/test/unit/cute/hopper/CMakeLists.txt +++ b/test/unit/cute/hopper/CMakeLists.txt @@ -32,6 +32,8 @@ add_custom_target( cutlass_test_unit_cute_hopper_stsm cutlass_test_unit_cute_hopper_tma_load cutlass_test_unit_cute_hopper_tma_store + cutlass_test_unit_cute_hopper_bulk_load + cutlass_test_unit_cute_hopper_bulk_store ) add_custom_target( @@ -40,6 +42,8 @@ add_custom_target( test_unit_cute_hopper_stsm test_unit_cute_hopper_tma_load test_unit_cute_hopper_tma_store + test_unit_cute_hopper_bulk_load + test_unit_cute_hopper_bulk_store ) cutlass_test_unit_add_executable( @@ -56,3 +60,14 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_hopper_tma_store tma_store.cu ) + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_hopper_bulk_load + bulk_load.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_hopper_bulk_store + bulk_store.cu +) + diff --git a/test/unit/cute/hopper/bulk_load.cu b/test/unit/cute/hopper/bulk_load.cu new file mode 100644 index 00000000..7f93c29f --- /dev/null +++ b/test/unit/cute/hopper/bulk_load.cu @@ -0,0 +1,196 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Basic tests for BULK_COPY usage with various layouts. +*/ + +#include "cutlass_unit_test.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +struct SharedStorage { + cute::array_aligned> smem; + cute::uint64_t bulk_copy_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +template +__global__ void +bulk_copy_test_device_cute(T const* g_in, + T * g_out, + GmemLayout gmem_layout, + SmemLayout smem_layout) +{ + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); + // Construct the GMEM tensor + Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* bulk_copy_mbar = shared_storage.bulk_copy_mbar; + + // + // Perform the BULK_COPY load + // + + auto atom = Copy_Atom{}; + +#if 0 + if (thread0()) { + print("sA: "); print(sA.data()); print(" o "); print(sA.layout()); print("\n"); + print("gA: "); print(gA.data()); print(" o "); print(gA.layout()); print("\n"); + } +#endif + + // Set the bytes transferred in this transaction (may involve multiple issues) + constexpr int transaction_bytes = size(sA) * sizeof(T); + + if (threadIdx.x == 0) { + /// Initialize shared memory barrier + bulk_copy_mbar[0] = 0; + initialize_barrier(bulk_copy_mbar[0], 1 /*numThreads*/); + set_barrier_transaction_bytes(bulk_copy_mbar[0], transaction_bytes); + + copy(atom.with(bulk_copy_mbar[0]), gA, sA); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value + constexpr int kPhaseBit = 0; + wait_barrier(bulk_copy_mbar[0], kPhaseBit); + +#if 0 + if (thread0()) { + print(sA); + } +#endif + + // + // Write out trivially + // + + Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + // Output smem -> gmem + for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { + gA_out(i) = sA(i); + } +} + +template +void run_and_validate(GLayout gmem_layout, + SLayout smem_layout) +{ + thrust::host_vector h_in(cosize(gmem_layout)); + for (int32_t i = 0; i < h_in.size(); ++i) { + h_in[i] = T(i); + } + + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(d_in.size(), T(-1)); + + int32_t smem_size = static_cast(sizeof(SharedStorage)); + bulk_copy_test_device_cute<<<1, 128, smem_size>>>(thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + gmem_layout, + smem_layout); + // Transfering results back to host + thrust::host_vector h_out = d_out; + + // Validate the results + for (int i = 0; i < cute::size(gmem_layout); ++i) { + int k = gmem_layout(i); + EXPECT_EQ(int(h_in[k]), int(h_out[k])); + } +} + +// } // namespace + +TEST(SM90_CuTe_BLKCP, ColMajor) +{ + + auto smem_layout = make_layout(Shape<_32,_32>{}, GenColMajor{}); + auto gmem_layout = smem_layout; + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); +} + +TEST(SM90_CuTe_BLKCP, RowMajor) +{ + + auto smem_layout = make_layout(Shape<_32,_32>{}, GenRowMajor{}); + auto gmem_layout = smem_layout; + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); +} + +TEST(SM90_CuTe_BLKCP, NonCompact) +{ + + { + auto smem_layout = make_layout(Shape<_32,_32>{}, Stride<_1,Int<48>>{}); + auto gmem_layout = smem_layout; + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); + } + { + auto smem_layout = make_layout(Shape<_32,_32>{}, Stride<_1,Int<48>>{}); + auto gmem_layout = make_layout(Shape, Shape<_4,_8>>{}, Stride,Stride<_16,_128>>{}); + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); + } + { + auto smem_layout = make_layout(Shape<_32,_32>{}, Stride<_64,_1>{}); + auto gmem_layout = smem_layout; + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); + } +} +#endif // #if CUDA_12_0_SM90_FEATURES_SUPPORTED diff --git a/test/unit/cute/hopper/bulk_store.cu b/test/unit/cute/hopper/bulk_store.cu new file mode 100644 index 00000000..13324b63 --- /dev/null +++ b/test/unit/cute/hopper/bulk_store.cu @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Basic tests for BULK_COPY usage with various layouts. +*/ + +#include "cutlass_unit_test.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +struct SharedStorage { + cute::array_aligned> smem; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED +template +__global__ void +bulk_copy_test_device_cute(T const* g_in, + T * g_out, + GmemLayout gmem_layout, + SmemLayout smem_layout) +{ + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); + // Construct the GMEM tensor + Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + + // + // Read in trivially + // + + // Input gmem -> smem + for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { + sA(i) = gA(i); + } + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // + // Perform the BULK_COPY store + // + +#if 0 + if (thread0()) { + print("sA: "); print(sA.data()); print(" o "); print(sA.layout()); print("\n"); + print("gA: "); print(gA.data()); print(" o "); print(gA.layout()); print("\n"); + } +#endif + + Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + auto atom = Copy_Atom, uint8_t>{}; + + copy(atom, sA, gA_out); + // Bulk Copy store requires the same sync as TMA store. + tma_store_arrive(); + tma_store_wait<0>(); +} + +template +void run_and_validate(GLayout gmem_layout, + SLayout smem_layout) +{ + thrust::host_vector h_in(cosize(gmem_layout)); + for (int32_t i = 0; i < h_in.size(); ++i) { + h_in[i] = T(i); + } + + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(d_in.size(), T(-1)); + + int32_t smem_size = static_cast(sizeof(SharedStorage)); + bulk_copy_test_device_cute<<<1, 128, smem_size>>>(thrust::raw_pointer_cast(d_in.data()), + thrust::raw_pointer_cast(d_out.data()), + gmem_layout, + smem_layout); + // Transfering results back to host + thrust::host_vector h_out = d_out; + + // Validate the results + for (int i = 0; i < cute::size(gmem_layout); ++i) { + int k = gmem_layout(i); + EXPECT_EQ(int(h_in[k]), int(h_out[k])); + } +} + +// } // namespace + +TEST(SM90_CuTe_BLKCP, ColMajor) +{ + + auto smem_layout = make_layout(Shape<_32,_32>{}, GenColMajor{}); + auto gmem_layout = smem_layout; + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); +} + +TEST(SM90_CuTe_BLKCP, RowMajor) +{ + + auto smem_layout = make_layout(Shape<_32,_32>{}, GenRowMajor{}); + auto gmem_layout = smem_layout; + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); +} + +TEST(SM90_CuTe_BLKCP, NonCompact) +{ + + { + auto smem_layout = make_layout(Shape<_32,_32>{}, Stride<_1,Int<48>>{}); + auto gmem_layout = smem_layout; + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); + } + { + auto smem_layout = make_layout(Shape<_32,_32>{}, Stride<_1,Int<48>>{}); + auto gmem_layout = make_layout(Shape, Shape<_4,_8>>{}, Stride,Stride<_16,_128>>{}); + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); + } + { + auto smem_layout = make_layout(Shape<_32,_32>{}, Stride<_64,_1>{}); + auto gmem_layout = smem_layout; + run_and_validate< int8_t>(gmem_layout, smem_layout); + run_and_validate< half_t>(gmem_layout, smem_layout); + run_and_validate(gmem_layout, smem_layout); + } +} +#endif // #if CUDA_12_0_SM90_FEATURES_SUPPORTED diff --git a/test/unit/cute/hopper/stsm.cu b/test/unit/cute/hopper/stsm.cu index ffc8aa74..c5d45def 100644 --- a/test/unit/cute/hopper/stsm.cu +++ b/test/unit/cute/hopper/stsm.cu @@ -264,7 +264,7 @@ TEST(SM90_CuTe_Hopper, Stsm) //printf("%d %d\n", int(h_in[i]), int(h_out[i])); EXPECT_EQ(h_out[i], h_in[i]); } - CUTLASS_TRACE_HOST("CuTe 32x8 interleaved STS.U16 SUCCESS\n"); + CUTLASS_TRACE_HOST("CuTe 32x8 interleaved STSM.U16 SUCCESS\n"); } { @@ -352,7 +352,7 @@ TEST(SM90_CuTe_Hopper, Stsm) //printf("%d %d\n", int(h_in[i]), int(h_out[i])); EXPECT_EQ(h_out[i], h_in[i]); } - CUTLASS_TRACE_HOST("CuTe 32x32 STS.U16 SUCCESS\n"); + CUTLASS_TRACE_HOST("CuTe 32x32 STSM.U16 SUCCESS\n"); } { diff --git a/test/unit/cute/hopper/tma_load.cu b/test/unit/cute/hopper/tma_load.cu index 24f17fca..ddb95c3c 100644 --- a/test/unit/cute/hopper/tma_load.cu +++ b/test/unit/cute/hopper/tma_load.cu @@ -47,78 +47,51 @@ struct SharedStorage cute::uint64_t tma_load_mbar[1]; }; -// __grid_constant__ was introduced in CUDA 11.7. -#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) -# define CUTE_GRID_CONSTANT_SUPPORTED -#endif - -// __grid_constant__ can be enabled only on SM70+ -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -# define CUTE_GRID_CONSTANT_ENABLED -#endif - -#if ! defined(CUTE_GRID_CONSTANT) -# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) -# define CUTE_GRID_CONSTANT __grid_constant__ -# else -# define CUTE_GRID_CONSTANT -# endif -#endif - #if CUDA_12_0_SM90_FEATURES_SUPPORTED -template +template __global__ void tma_test_device_cute(T const* g_in, T* g_out, - CUTE_GRID_CONSTANT TiledCopy const tma, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, GmemLayout gmem_layout, SmemLayout smem_layout) { - assert(product_each(shape(gmem_layout)) == product_each(smem_layout.shape())); + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); // Use Shared Storage structure to allocate and distribute aligned SMEM addresses extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) // Shared memory barriers use 64bits in SMEM for synchronization uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; - // Construct SMEM tensor - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); - -#if 0 - - // - // Read in trivially - // - - Tensor gA_in = make_tensor(make_gmem_ptr(g_in), gmem_layout); - - // Input gmem -> smem - for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { - sA(i) = gA_in(i); - } - __syncthreads(); - -#else // TMA requires special handling of strides to deal with coord codomain mapping // Represent the full tensors -- get these from TMA - Tensor gA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + constexpr int R = rank_v; + Tensor gA = local_tile(mA, cta_tiler, repeat(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = local_tile(mB, cta_tiler, repeat(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) // // Prepare the TMA_LOAD // - auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice - Tensor tAgA = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N) - Tensor tAsA = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) + Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) #if 0 if (thread0()) { - print(" gA: "); print(gA.data()); print(" o "); print(gA.layout()); print("\n"); - print("tAgA: "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); - print(" sA: "); print(sA.data()); print(" o "); print(sA.layout()); print("\n"); - print("tAsA: "); print(tAsA.data()); print(" o "); print(tAsA.layout()); print("\n"); + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA.data()); print(" o "); print( mA.layout()); print("\n"); + print(" gA : "); print( gA.data()); print(" o "); print( gA.layout()); print("\n"); + print("tAgA_x: "); print(tAgA_x.data()); print(" o "); print(tAgA_x.layout()); print("\n"); + print(" sA : "); print( sA.data()); print(" o "); print( sA.layout()); print("\n"); + print("tAsA_x: "); print(tAsA_x.data()); print(" o "); print(tAsA_x.layout()); print("\n"); } #endif @@ -126,14 +99,24 @@ tma_test_device_cute(T const* g_in, T* g_out, // Perform the TMA_LOAD // - // Group the TMA_M and TMA_N modes - Tensor tAgA_2 = group_modes<1,rank(tAgA)>(tAgA); // (TMA,Rest) - Tensor tAsA_TR = group_modes<1,rank(tAsA)>(tAsA); // (TMA,Rest) - static_assert(size<1>(tAsA_TR) == 1); - Tensor tAsA_2 = tAsA_TR(_,0); + // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x); // (TMA,REST) + Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x); // (TMA,REST) + static_assert(size<1>(tAsA) == 1); + + // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output + Tensor tBgB = group_modes<0,R>(group_modes(gB)); // (CTA_TILE, REST) + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); + print("tAsA : "); print(tAsA.data()); print(" o "); print(tAsA.layout()); print("\n"); + print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n"); + } +#endif // Loop over the TMA stages, using smem as our buffer - for (int stage = 0; stage < size<1>(tAgA_2); ++stage) + for (int stage = 0; stage < size<1>(tAgA); ++stage) { // Set the bytes transferred in this TMA transaction (may involve multiple issues) constexpr int kTmaTransactionBytes = size(sA) * sizeof(T); @@ -145,7 +128,7 @@ tma_test_device_cute(T const* g_in, T* g_out, cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); - copy(tma.with(tma_load_mbar[0]), tAgA_2(_,stage), tAsA_2); + copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0)); } __syncthreads(); @@ -153,343 +136,282 @@ tma_test_device_cute(T const* g_in, T* g_out, constexpr int kPhaseBit = 0; cute::wait_barrier(tma_load_mbar[0], kPhaseBit); - #endif - // - // Write out trivially + // Write out trivially smem -> gmem // - Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); - // Do the same slicing and grouping as sA - Tensor tAgA_out = cta_tma.partition_D(gA_out); // (TMA,TMA_M,TMA_N) - Tensor tAgA_2_out = group_modes<1,rank(tAgA_out)>(tAgA_out); // (TMA,Rest) - - // Output smem -> gmem - for (int i = threadIdx.x; i < size(tAsA_2); i += blockDim.x) { - tAgA_2_out(i,stage) = tAsA_2(i); + for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { + tBgB(i,stage) = sA(i); } __syncthreads(); } } -TEST(SM90_CuTe_Hopper, Tma_load_32x32_Col) +template +void +test_tma_load(GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) { - using T = half_t; - Layout smem_layout = Layout, Stride<_1,_32>>{}; - Layout gmem_layout = smem_layout; - - thrust::host_vector h_in(size(gmem_layout)); + thrust::host_vector h_in(cosize(gmem_layout)); for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } thrust::device_vector d_in = h_in; thrust::device_vector d_out(h_in.size(), T(-1)); - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + //print("TMA Instr size: "); print(decltype(tma)::NumValSrc); print("\n"); int smem_size = int(sizeof(SharedStorage)); tma_test_device_cute<<<1, 128, smem_size>>>( thrust::raw_pointer_cast(d_in.data()), thrust::raw_pointer_cast(d_out.data()), - tma, + tma, cta_tile, gmem_layout, smem_layout); thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + Tensor hA_in = make_tensor(h_in.data(), gmem_layout); + Tensor hA_out = make_tensor(h_out.data(), gmem_layout); + for (int i = 0; i < size(gmem_layout); ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD 32x32 ColMajor SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_load_32x32_Row) +template +void +test_tma_load(GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout) { - using T = half_t; - Layout smem_layout = Layout, Stride<_32,_1>>{}; - Layout gmem_layout = smem_layout; - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); - - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); - } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD 32x32 RowMajor SUCCESS\n"); + return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); } -TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN) +TEST(SM90_CuTe_Hopper, Tma_Load_32x32_Col) { - using T = half_t; - auto smem_layout = GMMA::Layout_MN_SW128_Atom{}; - Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + Layout smem_layout = Layout, Stride<_1,_32>>{}; + { + Layout gmem_layout = smem_layout; + test_tma_load(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + test_tma_load< float>(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + } - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + { + Layout gmem_layout = make_layout(make_shape(32,32), GenColMajor{}); + test_tma_load(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + test_tma_load< float>(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + } - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + { + Layout gmem_layout = make_layout(make_shape(32,32), make_stride(Int<1>{}, 1024)); + test_tma_load(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + test_tma_load< float>(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_K) +TEST(SM90_CuTe_Hopper, Tma_Load_32x32_Row) { - using T = half_t; - auto smem_layout = GMMA::Layout_K_SW128_Atom{}; - Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenRowMajor{}); - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + Layout smem_layout = Layout, Stride<_32,_1>>{}; + { + Layout gmem_layout = smem_layout; + test_tma_load(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + test_tma_load< float>(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + } - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + { + Layout gmem_layout = make_layout(make_shape(32,32), GenRowMajor{}); + test_tma_load(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + test_tma_load< float>(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + } - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + { + Layout gmem_layout = make_layout(make_shape(32,32), make_stride(1024, Int<1>{})); + test_tma_load(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + test_tma_load< float>(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_K_SW128_Atom SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi) +template typename SWIZZLE_ATOM> +void +test_tma_load_swizzle_atom_mn() { - using T = half_t; - auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}); - Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); - - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); - } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); + auto smem_layout = SWIZZLE_ATOM{}; + Layout gmem_layout = make_layout(shape(smem_layout), GenColMajor{}); + return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); } -TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi2) +template typename SWIZZLE_ATOM> +void +test_tma_load_swizzle_atom_k() { - using T = half_t; - // Tile the GMMA::Layout atom in the K-mode first, then the M-mode to get a bigger box size - auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); - Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); - - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); - } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); + auto smem_layout = SWIZZLE_ATOM{}; + Layout gmem_layout = make_layout(shape(smem_layout), GenRowMajor{}); + return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); } -TEST(SM90_CuTe_Hopper, Tma_load_GMMA_SW128_MN_Multi_Dyn) +TEST(SM90_CuTe_Hopper, Tma_Load_Swizzle_Atoms) { - using T = half_t; - auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); - Layout gmem_layout = make_layout(make_shape(128, 128), GenColMajor{}); - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); - - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); - } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); + test_tma_load_swizzle_atom_mn(); + test_tma_load_swizzle_atom_mn(); + test_tma_load_swizzle_atom_mn< float, GMMA::Layout_MN_SW128_Atom>(); + test_tma_load_swizzle_atom_mn(); + + test_tma_load_swizzle_atom_mn(); + test_tma_load_swizzle_atom_mn(); + test_tma_load_swizzle_atom_mn< float, GMMA::Layout_MN_SW64_Atom>(); + test_tma_load_swizzle_atom_mn(); + + test_tma_load_swizzle_atom_mn(); + test_tma_load_swizzle_atom_mn(); + test_tma_load_swizzle_atom_mn< float, GMMA::Layout_MN_SW32_Atom>(); + test_tma_load_swizzle_atom_mn(); + + test_tma_load_swizzle_atom_mn(); + test_tma_load_swizzle_atom_mn(); + test_tma_load_swizzle_atom_mn< float, GMMA::Layout_MN_INTER_Atom>(); + test_tma_load_swizzle_atom_mn(); + + test_tma_load_swizzle_atom_k(); + test_tma_load_swizzle_atom_k(); + test_tma_load_swizzle_atom_k< float, GMMA::Layout_K_SW128_Atom>(); + test_tma_load_swizzle_atom_k(); + + test_tma_load_swizzle_atom_k(); + test_tma_load_swizzle_atom_k(); + test_tma_load_swizzle_atom_k< float, GMMA::Layout_K_SW64_Atom>(); + test_tma_load_swizzle_atom_k(); + + test_tma_load_swizzle_atom_k(); + test_tma_load_swizzle_atom_k(); + test_tma_load_swizzle_atom_k< float, GMMA::Layout_K_SW32_Atom>(); + test_tma_load_swizzle_atom_k(); + + test_tma_load_swizzle_atom_k(); + test_tma_load_swizzle_atom_k(); + test_tma_load_swizzle_atom_k< float, GMMA::Layout_K_INTER_Atom>(); + test_tma_load_swizzle_atom_k(); } -TEST(SM90_CuTe_Hopper, Tma_load_32x32_Multimode) +template typename SWIZZLE_ATOM> +void +test_tma_load_swizzle_tile_mn() { - using T = half_t; - auto smem_layout = Layout, Stride<_32,_1>>{}; - Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); - - //auto smem_layout = Layout>{}; - //Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + auto smem_layout = tile_to_shape(SWIZZLE_ATOM{}, Shape<_128,_128>{}); + Layout gmem_layout = make_layout(make_shape(int(size<0>(smem_layout)), int(size<1>(smem_layout))), GenColMajor{}); + return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); +} - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); - } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +template typename SWIZZLE_ATOM> +void +test_tma_load_swizzle_tile_k() +{ + auto smem_layout = tile_to_shape(SWIZZLE_ATOM{}, Shape<_128,_128>{}); + Layout gmem_layout = make_layout(make_shape(int(size<0>(smem_layout)), int(size<1>(smem_layout))), GenRowMajor{}); + return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); } -TEST(SM90_CuTe_Hopper, Tma_load_Tensor_blocking) +TEST(SM90_CuTe_Hopper, Tma_Load_Swizzle_Tiles) { - using T = half_t; - auto gmem_layout = make_shape(make_shape(336,40),make_shape(32,656)); // GMEM - auto cta_tile = make_shape(make_shape(_16{},_8{}),make_shape(_32{},_2{})); // GMEM Tiling: - // Take 16-elem from m0, 8-elem from m1, - // Take 32-elem from k0, 2-elem from k1 - auto smem_layout = make_layout(cta_tile); // Col-Major SMEM - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); + // Other T-types use too much smem + test_tma_load_swizzle_tile_mn(); + test_tma_load_swizzle_tile_mn(); + test_tma_load_swizzle_tile_mn(); + test_tma_load_swizzle_tile_mn(); + test_tma_load_swizzle_tile_mn(); + test_tma_load_swizzle_tile_mn(); + test_tma_load_swizzle_tile_mn(); + test_tma_load_swizzle_tile_mn(); + test_tma_load_swizzle_tile_k(); + test_tma_load_swizzle_tile_k(); + test_tma_load_swizzle_tile_k(); + test_tma_load_swizzle_tile_k(); + test_tma_load_swizzle_tile_k(); + test_tma_load_swizzle_tile_k(); + test_tma_load_swizzle_tile_k(); + test_tma_load_swizzle_tile_k(); +} - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); +TEST(SM90_CuTe_Hopper, Tma_Load_Metamode) +{ + { + auto smem_layout = Layout, Stride<_1,_32>>{}; + { + Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); + test_tma_load(gmem_layout, smem_layout); + } + { + Layout gmem_layout = make_layout(make_shape(make_shape(8,32), 32), GenColMajor{}); + test_tma_load(gmem_layout, smem_layout); + } + { + Layout gmem_layout = make_layout(make_shape(make_shape(64,32), 32), GenColMajor{}); + test_tma_load(gmem_layout, smem_layout); + } + } - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + { + auto smem_layout = Layout, Stride<_32,_1>>{}; + { + Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); + test_tma_load(gmem_layout, smem_layout); + } + { + Layout gmem_layout = make_layout(make_shape(make_shape(8,32), 32), GenRowMajor{}); + test_tma_load(gmem_layout, smem_layout); + } + { + Layout gmem_layout = make_layout(make_shape(make_shape(64,32), 32), GenRowMajor{}); + test_tma_load(gmem_layout, smem_layout); + } } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD Tensor blocking SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_load_Tensor_blocking_2) +TEST(SM90_CuTe_Hopper, Tma_Load_Tensor) { - using T = half_t; - auto gmem_layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM - auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: - // Take 128-elem from m: m0 must divide 128, - // m-last may be predicated - // Take 32-elem from k0, 2-elem from k1 - auto smem_layout = make_layout(cta_tile); // Col-Major SMEM - - thrust::host_vector h_in(size(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + // Tensor by-mode + { + Layout gmem_layout = make_layout(make_shape(make_shape(80,40),make_shape(32,12))); + auto cta_tile = Shape,Shape<_32,_2>>{}; // GMEM Tiling: + // Take 16-elem from m0, 8-elem from m1, + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_128,_64>{}); + test_tma_load(gmem_layout, smem_layout, cta_tile); + } - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + // Tensor Metamode -- Tiler selects flat elements from a multimode + { + Layout gmem_layout = make_layout(make_shape(make_shape(32,40),make_shape(make_shape(8,8),12))); + auto cta_tile = Shape<_128, Shape<_32,_2>>{}; // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_128,_64>{}); + test_tma_load(gmem_layout, smem_layout, cta_tile); + } - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + // Tensor Multimode -- TMA with more than 5 modes in GMEM (packs residual modes into last TMA mode) + { + Layout gmem_layout = make_layout(make_shape(make_shape(32,3,2,2),make_shape(32,4,2))); + auto cta_tile = Shape, Shape<_32,_2>>{}; // GMEM Tiling: + // Take 32-elem from m0 + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_32,_64>{}); + test_tma_load(gmem_layout, smem_layout, cta_tile); } - CUTLASS_TRACE_HOST("CuTe TMA_LOAD Tensor blocking 2 SUCCESS\n"); + } + #endif diff --git a/test/unit/cute/hopper/tma_store.cu b/test/unit/cute/hopper/tma_store.cu index 448b7f91..4d96070a 100644 --- a/test/unit/cute/hopper/tma_store.cu +++ b/test/unit/cute/hopper/tma_store.cu @@ -46,339 +46,363 @@ struct SharedStorage cute::array_aligned> smem; }; -// __grid_constant__ was introduced in CUDA 11.7. -#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) -# define CUTE_GRID_CONSTANT_SUPPORTED -#endif - -// __grid_constant__ can be enabled only on SM70+ -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) -# define CUTE_GRID_CONSTANT_ENABLED -#endif - -#if ! defined(CUTE_GRID_CONSTANT) -# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) -# define CUTE_GRID_CONSTANT __grid_constant__ -# else -# define CUTE_GRID_CONSTANT -# endif -#endif - #if CUDA_12_0_SM90_FEATURES_SUPPORTED -template +template __global__ void tma_test_device_cute(T const* g_in, T* g_out, - CUTE_GRID_CONSTANT TiledCopy const tma, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, GmemLayout gmem_layout, SmemLayout smem_layout) { + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); // Construct SMEM tensor - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); - - // - // Read in trivially - // - - Tensor gA_in = make_tensor(make_gmem_ptr(g_in), gmem_layout); - - // Input gmem -> smem - for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { - sA(i) = gA_in(i); - } - - __syncthreads(); - -#if 0 - - // - // Write out trivially - // - - Tensor gA_out = make_tensor(make_gmem_ptr(g_out), gmem_layout); - - // Output smem -> gmem - for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { - gA_out(i) = sA(i); - } - -#else + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) // TMA requires special handling of strides to deal with coord codomain mapping // Represent the full tensors -- get these from TMA - Tensor gA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor mB = tma.get_tma_tensor(shape(gmem_layout)); + + constexpr int R = rank_v; + Tensor gA = local_tile(mA, cta_tiler, repeat(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = local_tile(mB, cta_tiler, repeat(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) // // Prepare the TMA_STORE // - auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice - - Tensor tAsA = cta_tma.partition_S(sA); - Tensor tAgA = cta_tma.partition_D(gA); + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice - // - // Perform the TMA_STORE - // + Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N) + Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N) - if (threadIdx.x == 0) { - copy(tma, tAsA, tAgA); +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mB : "); print( mB.data()); print(" o "); print( mB.layout()); print("\n"); + print(" gB : "); print( gB.data()); print(" o "); print( gB.layout()); print("\n"); + print("tBgB_x: "); print(tBgB_x.data()); print(" o "); print(tBgB_x.layout()); print("\n"); + print(" sB : "); print( sB.data()); print(" o "); print( sB.layout()); print("\n"); + print("tBsB_x: "); print(tBsB_x.data()); print(" o "); print(tBsB_x.layout()); print("\n"); } - #endif -} - -TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Col) -{ - using T = half_t; - Layout smem_layout = Layout, Stride<_1,_32>>{}; - Layout gmem_layout = smem_layout; - thrust::host_vector h_in(size(smem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); + // + // Perform the TMA_STORE + // - Tensor gA = make_tensor(d_out.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + // INPUT: Group the CTA_TILE_X modes and REST_X modes for input + Tensor tAgA = group_modes<0,R>(group_modes(gA)); // (CTA_TILE, REST) - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + // OUTPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tBgB = group_modes<1,rank(tBgB_x)>(tBgB_x); // (TMA,REST) + Tensor tBsB = group_modes<1,rank(tBsB_x)>(tBsB_x); // (TMA,REST) + static_assert(size<1>(tBsB) == 1); - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); + print("tBsB : "); print(tBsB.data()); print(" o "); print(tBsB.layout()); print("\n"); + print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n"); } - CUTLASS_TRACE_HOST("CuTe TMA_STORE 32x32 ColMajor SUCCESS\n"); -} +#endif -TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Row) -{ - using T = half_t; - Layout smem_layout = Layout, Stride<_32,_1>>{}; - Layout gmem_layout = smem_layout; + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tBgB); ++stage) + { + // + // Read in trivially gmem -> smem + // - thrust::host_vector h_in(size(smem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); + for (int i = threadIdx.x; i < size(sB); i += blockDim.x) { + sB(i) = tAgA(i,stage); + } - Tensor gA = make_tensor(d_out.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + __syncthreads(); - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + // + // Perform the TMA_STORE + // - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + if (threadIdx.x == 0) { + copy(tma, tBsB(_,0), tBgB(_,stage)); + } + + tma_store_wait<0>(); + __syncthreads(); } - CUTLASS_TRACE_HOST("CuTe TMA_STORE 32x32 RowMajor SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN) +template +void +test_tma_store(GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) { - using T = half_t; - auto smem_layout = GMMA::Layout_MN_SW128_Atom{}; - Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); - - thrust::host_vector h_in(size(smem_layout)); + thrust::host_vector h_in(cosize(gmem_layout)); for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } thrust::device_vector d_in = h_in; thrust::device_vector d_out(h_in.size(), T(-1)); Tensor gA = make_tensor(d_out.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout, cta_tile, Int<1>{}); + //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + //print("TMA Instr size: "); print(decltype(tma)::NumValSrc); print("\n"); int smem_size = int(sizeof(SharedStorage)); tma_test_device_cute<<<1, 128, smem_size>>>( thrust::raw_pointer_cast(d_in.data()), thrust::raw_pointer_cast(d_out.data()), - tma, + tma, cta_tile, gmem_layout, smem_layout); thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + Tensor hA_in = make_tensor(h_in.data(), gmem_layout); + Tensor hA_out = make_tensor(h_out.data(), gmem_layout); + for (int i = 0; i < size(gmem_layout); ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); } - CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_K) +template +void +test_tma_store(GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout) { - using T = half_t; - auto smem_layout = GMMA::Layout_K_SW128_Atom{}; - Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenRowMajor{}); - - thrust::host_vector h_in(size(smem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); + return test_tma_store(gmem_layout, smem_layout, product_each(shape(smem_layout))); +} - Tensor gA = make_tensor(d_out.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); +TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Col) +{ + Layout smem_layout = Layout, Stride<_1,_32>>{}; + { + Layout gmem_layout = smem_layout; + test_tma_store(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + test_tma_store< float>(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + } - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + { + Layout gmem_layout = make_layout(make_shape(32,32), GenColMajor{}); + test_tma_store(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + test_tma_store< float>(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + } - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + { + Layout gmem_layout = make_layout(make_shape(32,32), make_stride(Int<1>{}, 1024)); + test_tma_store(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + test_tma_store< float>(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); } - CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_K_SW128_Atom SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi) +TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Row) { - using T = half_t; - auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}); - Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); - - thrust::host_vector h_in(size(smem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_out.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + Layout smem_layout = Layout, Stride<_32,_1>>{}; + { + Layout gmem_layout = smem_layout; + test_tma_store(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + test_tma_store< float>(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + } - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + { + Layout gmem_layout = make_layout(make_shape(32,32), GenRowMajor{}); + test_tma_store(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + test_tma_store< float>(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + } - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + { + Layout gmem_layout = make_layout(make_shape(32,32), make_stride(1024, Int<1>{})); + test_tma_store(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + test_tma_store< float>(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); } - CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi2) +template typename SWIZZLE_ATOM> +void +test_tma_store_swizzle_atom_mn() { - using T = half_t; - // Tile the GMMA::Layout atom in the K-mode first, then the M-mode to get a bigger box size - auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); - Layout gmem_layout = make_layout(make_shape(size<0>(smem_layout), size<1>(smem_layout)), GenColMajor{}); - - thrust::host_vector h_in(size(smem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); + auto smem_layout = SWIZZLE_ATOM{}; + Layout gmem_layout = make_layout(shape(smem_layout), GenColMajor{}); + return test_tma_store(gmem_layout, smem_layout, product_each(shape(smem_layout))); +} - Tensor gA = make_tensor(d_out.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); +template typename SWIZZLE_ATOM> +void +test_tma_store_swizzle_atom_k() +{ + auto smem_layout = SWIZZLE_ATOM{}; + Layout gmem_layout = make_layout(shape(smem_layout), GenRowMajor{}); + return test_tma_store(gmem_layout, smem_layout, product_each(shape(smem_layout))); +} - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); +TEST(SM90_CuTe_Hopper, Tma_Store_Swizzle_Atoms) +{ + test_tma_store_swizzle_atom_mn(); + test_tma_store_swizzle_atom_mn(); + test_tma_store_swizzle_atom_mn< float, GMMA::Layout_MN_SW128_Atom>(); + test_tma_store_swizzle_atom_mn(); + + test_tma_store_swizzle_atom_mn(); + test_tma_store_swizzle_atom_mn(); + test_tma_store_swizzle_atom_mn< float, GMMA::Layout_MN_SW64_Atom>(); + test_tma_store_swizzle_atom_mn(); + + test_tma_store_swizzle_atom_mn(); + test_tma_store_swizzle_atom_mn(); + test_tma_store_swizzle_atom_mn< float, GMMA::Layout_MN_SW32_Atom>(); + test_tma_store_swizzle_atom_mn(); + + test_tma_store_swizzle_atom_mn(); + test_tma_store_swizzle_atom_mn(); + test_tma_store_swizzle_atom_mn< float, GMMA::Layout_MN_INTER_Atom>(); + test_tma_store_swizzle_atom_mn(); + + test_tma_store_swizzle_atom_k(); + test_tma_store_swizzle_atom_k(); + test_tma_store_swizzle_atom_k< float, GMMA::Layout_K_SW128_Atom>(); + test_tma_store_swizzle_atom_k(); + + test_tma_store_swizzle_atom_k(); + test_tma_store_swizzle_atom_k(); + test_tma_store_swizzle_atom_k< float, GMMA::Layout_K_SW64_Atom>(); + test_tma_store_swizzle_atom_k(); + + test_tma_store_swizzle_atom_k(); + test_tma_store_swizzle_atom_k(); + test_tma_store_swizzle_atom_k< float, GMMA::Layout_K_SW32_Atom>(); + test_tma_store_swizzle_atom_k(); + + test_tma_store_swizzle_atom_k(); + test_tma_store_swizzle_atom_k(); + test_tma_store_swizzle_atom_k< float, GMMA::Layout_K_INTER_Atom>(); + test_tma_store_swizzle_atom_k(); +} - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); - } - CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); +template typename SWIZZLE_ATOM> +void +test_tma_store_swizzle_tile_mn() +{ + auto smem_layout = tile_to_shape(SWIZZLE_ATOM{}, Shape<_128,_128>{}); + Layout gmem_layout = make_layout(make_shape(int(size<0>(smem_layout)), int(size<1>(smem_layout))), GenColMajor{}); + return test_tma_store(gmem_layout, smem_layout, product_each(shape(smem_layout))); } -TEST(SM90_CuTe_Hopper, Tma_Store_GMMA_SW128_MN_Multi_Dyn) +template typename SWIZZLE_ATOM> +void +test_tma_store_swizzle_tile_k() { - using T = half_t; - auto smem_layout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, Shape,Int<128>>{}, Step<_2,_1>{}); - Layout gmem_layout = make_layout(make_shape(128, 128), GenColMajor{}); + auto smem_layout = tile_to_shape(SWIZZLE_ATOM{}, Shape<_128,_128>{}); + Layout gmem_layout = make_layout(make_shape(int(size<0>(smem_layout)), int(size<1>(smem_layout))), GenRowMajor{}); + return test_tma_store(gmem_layout, smem_layout, product_each(shape(smem_layout))); +} - thrust::host_vector h_in(size(smem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); +TEST(SM90_CuTe_Hopper, Tma_Store_Swizzle_Tiles) +{ + // Other T-types use too much smem + test_tma_store_swizzle_tile_mn(); + test_tma_store_swizzle_tile_mn(); + test_tma_store_swizzle_tile_mn(); + test_tma_store_swizzle_tile_mn(); + test_tma_store_swizzle_tile_mn(); + test_tma_store_swizzle_tile_mn(); + test_tma_store_swizzle_tile_mn(); + test_tma_store_swizzle_tile_mn(); + test_tma_store_swizzle_tile_k(); + test_tma_store_swizzle_tile_k(); + test_tma_store_swizzle_tile_k(); + test_tma_store_swizzle_tile_k(); + test_tma_store_swizzle_tile_k(); + test_tma_store_swizzle_tile_k(); + test_tma_store_swizzle_tile_k(); + test_tma_store_swizzle_tile_k(); +} - Tensor gA = make_tensor(d_out.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); +TEST(SM90_CuTe_Hopper, Tma_Store_Metamode) +{ + { + auto smem_layout = Layout, Stride<_1,_32>>{}; + { + Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); + test_tma_store(gmem_layout, smem_layout); + } + { + Layout gmem_layout = make_layout(make_shape(make_shape(8,32), 32), GenColMajor{}); + test_tma_store(gmem_layout, smem_layout); + } + { + Layout gmem_layout = make_layout(make_shape(make_shape(64,32), 32), GenColMajor{}); + test_tma_store(gmem_layout, smem_layout); + } + } - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + { + auto smem_layout = Layout, Stride<_32,_1>>{}; + { + Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); + test_tma_store(gmem_layout, smem_layout); + } + { + Layout gmem_layout = make_layout(make_shape(make_shape(8,32), 32), GenRowMajor{}); + test_tma_store(gmem_layout, smem_layout); + } + { + Layout gmem_layout = make_layout(make_shape(make_shape(64,32), 32), GenRowMajor{}); + test_tma_store(gmem_layout, smem_layout); + } } - CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); } -TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Multimode) +TEST(SM90_CuTe_Hopper, Tma_Store_Tensor) { - using T = half_t; - auto smem_layout = Layout, Stride<_32,_1>>{}; - Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); - - //auto smem_layout = Layout>{}; - //Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); - - thrust::host_vector h_in(size(smem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - - Tensor gA = make_tensor(d_out.data().get(), gmem_layout); - auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + // Tensor by-mode + { + Layout gmem_layout = make_layout(make_shape(make_shape(80,40),make_shape(32,12))); + auto cta_tile = Shape,Shape<_32,_2>>{}; // GMEM Tiling: + // Take 16-elem from m0, 8-elem from m1, + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_128,_64>{}); + test_tma_store(gmem_layout, smem_layout, cta_tile); + } - int smem_size = int(sizeof(SharedStorage)); - tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), - tma, - gmem_layout, - smem_layout); + // Tensor Metamode -- Tiler selects flat elements from a multimode + { + Layout gmem_layout = make_layout(make_shape(make_shape(32,40),make_shape(make_shape(8,8),12))); + auto cta_tile = Shape<_128, Shape<_32,_2>>{}; // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_128,_64>{}); + test_tma_store(gmem_layout, smem_layout, cta_tile); + } - thrust::host_vector h_out = d_out; - for (int i = 0; i < size(smem_layout); ++i) { - //printf("%d %d\n", int(h_in[i]), int(h_out[i])); - EXPECT_EQ(h_out[i], h_in[i]); + // Tensor Multimode -- TMA with more than 5 modes in GMEM (packs residual modes into last TMA mode) + { + Layout gmem_layout = make_layout(make_shape(make_shape(32,3,2,2),make_shape(32,4,2))); + auto cta_tile = Shape, Shape<_32,_2>>{}; // GMEM Tiling: + // Take 32-elem from m0 + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_32,_64>{}); + test_tma_store(gmem_layout, smem_layout, cta_tile); } - CUTLASS_TRACE_HOST("CuTe TMA_STORE GMMA::Layout_MN_SW128_Atom Multi SUCCESS\n"); + } + #endif diff --git a/test/unit/cute/msvc_compilation/CMakeLists.txt b/test/unit/cute/msvc_compilation/CMakeLists.txt new file mode 100644 index 00000000..308e296f --- /dev/null +++ b/test/unit/cute/msvc_compilation/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_msvc_compilation + + tuple.cpp +) diff --git a/test/unit/cute/msvc_compilation/tuple.cpp b/test/unit/cute/msvc_compilation/tuple.cpp new file mode 100644 index 00000000..44c268a2 --- /dev/null +++ b/test/unit/cute/msvc_compilation/tuple.cpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include +#include + +#include +#include + +template +class ConvertibleTo { +public: + ConvertibleTo(T val) : val_(val) {} + + operator T () const { return val_; } + +private: + T val_ = 0; +}; + +template +using IC = std::integral_constant; + +TEST(CuTe_core_msvc_compilation, TupleAssignment) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("cute::tuple creation and assignment"); + CUTLASS_TRACE_HOST("-------------------------------"); + + using forty_two_type = IC; + using forty_three_type = IC; + + using ebo_s_type = cute::detail::EBO<0, forty_two_type>; + [[maybe_unused]] ebo_s_type ebo_s; + static_assert(std::is_same_v); + + using ebo_d_type = cute::detail::EBO<1, size_t>; + [[maybe_unused]] ebo_d_type ebo_d(43u); + assert(ebo_d.t_ == 43u); + static_assert(std::is_same_v>, size_t > ); + assert(cute::detail::getv(ebo_d) == 43u); + + [[maybe_unused]] cute::detail::TupleBase, int, forty_two_type, size_t> tb0{ + 41, forty_two_type{}, size_t(43u) }; + [[maybe_unused]] cute::detail::TupleBase, int, forty_two_type, size_t> tb1; + + int val41 = ConvertibleTo{41}; + assert(val41 == 41); + size_t val43 = ConvertibleTo{size_t(43u)}; + assert(val43 == size_t{43u}); + [[maybe_unused]] cute::detail::TupleBase, int, forty_two_type, size_t> tb2{ + ConvertibleTo{41}, forty_two_type{}, ConvertibleTo{size_t(43u)}}; + + [[maybe_unused]] cute::detail::TupleBase, int> tb3{ 41 }; + [[maybe_unused]] cute::detail::TupleBase, int> tb3a{ 42 }; + tb3 = tb3a; + + using tuple_0d_type = cute::tuple<>; + using tuple_1d_d_type = cute::tuple; + using tuple_1d_s_type = cute::tuple; + using tuple_2d_dd_type = cute::tuple; + using tuple_2d_ss_type = cute::tuple; + + [[maybe_unused]] tuple_0d_type t0; + + // Symptom: "illegal member initialization: 'TupleBase' is not a base or member" + [[maybe_unused]] tuple_1d_d_type t1{ 42 }; + + [[maybe_unused]] tuple_1d_s_type t2; + + [[maybe_unused]] tuple_1d_d_type t1a{ 43 }; + t1 = t1a; + + [[maybe_unused]] tuple_2d_dd_type t3{ 42, size_t(43u) }; + [[maybe_unused]] tuple_2d_ss_type t4; + t3 = t4; + + [[maybe_unused]] tuple_2d_dd_type t3a{ 44, size_t(45u) }; + // Symptom: "illegal member initialization: + // 'TupleBase' is not a base or member" + t3 = t3a; +} + +TEST(CuTe_core_msvc_compilation, TupleGetSingleInteger) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("cute::get on cute::tuple for single integer I"); + CUTLASS_TRACE_HOST("-------------------------------"); + + cute::tuple, IC> t0{ 41, size_t(42u), IC{} }; + + [[maybe_unused]] auto t0_0 = cute::get<0>(t0); + static_assert(std::is_same_v); + assert(t0_0 == 41); + + [[maybe_unused]] auto t0_1 = cute::get<1>(t0); + static_assert(std::is_same_v>); + + [[maybe_unused]] auto t0_2 = cute::get<2>(t0); + static_assert(std::is_same_v>); +} + +TEST(CuTe_core_msvc_compilation, TupleGetRecursive) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("cute::get on cute::tuple"); + CUTLASS_TRACE_HOST("-------------------------------"); + + using inner_tuple_type = cute::tuple, IC>; + using outer_tuple_type = cute::tuple, inner_tuple_type, size_t>; + + inner_tuple_type t0_inner{ 41, size_t(42u), IC{} }; + outer_tuple_type t0_outer{ IC{}, t0_inner, size_t(44u) }; + + [[maybe_unused]] auto t0_outer_0 = cute::get<0>(t0_outer); + static_assert(std::is_same_v>); + + [[maybe_unused]] auto t0_outer_1 = cute::get<1>(t0_outer); + static_assert(std::is_same_v); + + [[maybe_unused]] auto t0_outer_2 = cute::get<2>(t0_outer); + static_assert(std::is_same_v); + assert(t0_outer_2 == size_t(44u)); + + // Leftmost index is innermost in the nexted get sequence. + [[maybe_unused]] auto t0_outer_10 = cute::get<1, 0>(t0_outer); + static_assert(std::is_same_v); + assert(t0_outer_10 == 41); +} diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 2803a896..717dbd5b 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -267,6 +267,19 @@ cutlass_test_unit_add_executable( sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu ) +# Fused epilogue tests +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_epilogue_fusion_sm90 + + BATCH_SOURCES ON + BATCH_SIZE 4 + sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu + sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu + sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu +) + cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 @@ -276,7 +289,17 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu - sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_gmma_rs_warpspecialized_sm90 + + BATCH_SOURCES ON + BATCH_SIZE 4 + + sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu ) cutlass_test_unit_add_executable( @@ -337,6 +360,7 @@ cutlass_test_unit_add_executable( gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu + gemm_s8t_s8n_f16t_tensor_op_s32_sm80.cu gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu @@ -416,7 +440,6 @@ cutlass_test_unit_add_executable( gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu ) - cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_grouped diff --git a/test/unit/gemm/device/default_gemm_configuration.hpp b/test/unit/gemm/device/default_gemm_configuration.hpp index 76422b15..adfb9eda 100644 --- a/test/unit/gemm/device/default_gemm_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_configuration.hpp @@ -40,6 +40,7 @@ #include "cutlass/layout/layout.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -200,7 +201,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -331,7 +333,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -397,7 +400,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -504,7 +508,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; @@ -579,7 +584,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -642,7 +648,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -703,7 +710,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -764,7 +772,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -827,7 +836,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -886,7 +896,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -947,7 +958,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -1008,7 +1020,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -1071,7 +1084,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; /* using EpilogueOutputOp = epilogue::collective::Epilogue< @@ -1148,7 +1162,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -1211,7 +1226,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -1274,7 +1290,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< TagToStrideC_t, TagToStrideC_t, - epilogue::thread::LinearCombination>; + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; }; /////////////////////////////////////////////////////////////////////////////// @@ -1330,10 +1347,16 @@ struct DefaultGemmConfigurationToCutlass3Types< >; // Epilogue - using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< - TagToStrideC_t, - TagToStrideC_t, - epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + double, double, + double, cutlass::layout::ColumnMajor, 1, + double, cutlass::layout::ColumnMajor, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + }; /////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_grouped_sm80.cu b/test/unit/gemm/device/gemm_grouped_sm80.cu index 3fa3519a..38834bda 100644 --- a/test/unit/gemm/device/gemm_grouped_sm80.cu +++ b/test/unit/gemm/device/gemm_grouped_sm80.cu @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file \brief Tests for device-wide GEMM interface - + */ #include @@ -80,7 +80,7 @@ struct GemmGroupedProblemVisitor { // // Data members // - + SharedStorage &shared_storage; Params const ¶ms; cutlass::MatrixCoord threadblock_shape; @@ -95,7 +95,7 @@ struct GemmGroupedProblemVisitor { // CUTLASS_DEVICE GemmGroupedProblemVisitor( - SharedStorage &shared_storage_, + SharedStorage &shared_storage_, Params const ¶ms_, cutlass::MatrixCoord threadblock_shape_, int32_t block_idx @@ -187,7 +187,7 @@ struct GemmGroupedProblemVisitor { CUTLASS_DEVICE void advance(int32_t grid_size) { - tile_idx += grid_size; + tile_idx += grid_size; } }; @@ -199,9 +199,9 @@ __global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) { __shared__ GemmGroupedProblemVisitor::SharedStorage shared_storage; GemmGroupedProblemVisitor problem_visitor( - shared_storage, - params, - {ThreadblockShapeM, ThreadblockShapeN}, + shared_storage, + params, + {ThreadblockShapeM, ThreadblockShapeN}, blockIdx.x); while (problem_visitor.next_tile()) { @@ -220,12 +220,12 @@ __global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) { if (threadIdx.x == 0) { #if 0 - printf("Block %d - tile: %lld, problem %d, threadblock_idx: %lld, threadblock(m: %d, n: %d)\n", - blockIdx.x, - problem_visitor.tile_index(), - problem_visitor.problem_index(), - threadblock_idx, - threadblock_tile_m_idx, + printf("Block %d - tile: %lld, problem %d, threadblock_idx: %lld, threadblock(m: %d, n: %d)\n", + blockIdx.x, + static_cast(problem_visitor.tile_index()), + problem_visitor.problem_index(), + threadblock_idx, + threadblock_tile_m_idx, threadblock_tile_n_idx); #endif } @@ -272,10 +272,10 @@ TEST(SM80_Device_GemmGrouped_scheduler, 64x64x32_32x32x32) { tile_counts.at(i) = tile_count; if (false) { - std::cout << "Problem " << i << " size(" - << problem_sizes.at(i).m() << "-by-" << problem_sizes.at(i).n() - << ") - tiles: " << problem_tile_count << ", grid(" << grid_shape.m() << ", " << grid_shape.n() - << "), tiles[" << tile_start << ", " << tile_count << ")" << std::endl; + std::cout << "Problem " << i << " size(" + << problem_sizes.at(i).m() << "-by-" << problem_sizes.at(i).n() + << ") - tiles: " << problem_tile_count << ", grid(" << grid_shape.m() << ", " << grid_shape.n() + << "), tiles[" << tile_start << ", " << tile_count << ")" << std::endl; } } @@ -309,25 +309,25 @@ TEST(SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) using ElementAccumulator = float; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - cutlass::half_t, - cutlass::layout::ColumnMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8, cutlass::half_t, - cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8, ElementOutput, cutlass::layout::ColumnMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3>::GemmKernel; using Gemm = cutlass::gemm::device::GemmGrouped; @@ -340,7 +340,7 @@ TEST(SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) bool passed = testbed.run(24); EXPECT_TRUE(passed); - + } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -392,25 +392,25 @@ TEST(SM80_Device_GemmGrouped_f16t_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { using ElementAccumulator = float; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - cutlass::half_t, - cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8, cutlass::half_t, - cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8, ElementOutput, cutlass::layout::ColumnMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 64, 32>, - cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 4>::GemmKernel; using Gemm = cutlass::gemm::device::GemmGrouped; @@ -475,17 +475,17 @@ TEST(SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64, 64x64x16_32x32x16) { using ElementAccumulator = double; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - ElementInput, - cutlass::layout::RowMajor, + ElementInput, + cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1, ElementInput, - cutlass::layout::RowMajor, + cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1, ElementOutput, cutlass::layout::ColumnMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 64, 16>, cutlass::gemm::GemmShape<32, 32, 16>, @@ -493,7 +493,7 @@ TEST(SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64, 64x64x16_32x32x16) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 1, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 4>::GemmKernel; using Gemm = cutlass::gemm::device::GemmGrouped; @@ -517,17 +517,17 @@ TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x128x8_64x32x1) { using ElementAccumulator = float; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - ElementInput, - cutlass::layout::RowMajor, + ElementInput, + cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1, ElementInput, - cutlass::layout::RowMajor, + cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1, ElementOutput, cutlass::layout::ColumnMajor, - ElementAccumulator, - cutlass::arch::OpClassSimt, + ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 8>, cutlass::gemm::GemmShape<64, 32, 8>, @@ -535,7 +535,7 @@ TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x128x8_64x32x1) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 1, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3>::GemmKernel; using Gemm = cutlass::gemm::device::GemmGrouped; @@ -685,17 +685,17 @@ TEST(SM80_Device_GemmGrouped_cf32n_cf32n_cf32n_tensorop_f32, 64x64x16_32x32x16) using ElementAccumulator = cutlass::complex; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - ElementInput, - cutlass::layout::ColumnMajor, + ElementInput, + cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 1, ElementInput, - cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 1, ElementOutput, cutlass::layout::ColumnMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 64, 16>, cutlass::gemm::GemmShape<32, 32, 16>, @@ -703,7 +703,7 @@ TEST(SM80_Device_GemmGrouped_cf32n_cf32n_cf32n_tensorop_f32, 64x64x16_32x32x16) cutlass::epilogue::thread::LinearCombination< ElementOutput, 1, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, cutlass::arch::OpMultiplyAddComplex>::GemmKernel; @@ -729,17 +729,17 @@ TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32n_tensorop_f32, 64x64x16_32x32x16) using ElementAccumulator = cutlass::complex; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - ElementInput, - cutlass::layout::ColumnMajor, + ElementInput, + cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kConjugate, 1, ElementInput, - cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kConjugate, 1, ElementOutput, cutlass::layout::ColumnMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<64, 64, 16>, cutlass::gemm::GemmShape<32, 32, 16>, @@ -747,7 +747,7 @@ TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32n_tensorop_f32, 64x64x16_32x32x16) cutlass::epilogue::thread::LinearCombination< ElementOutput, 1, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, cutlass::arch::OpMultiplyAddComplex>::GemmKernel; @@ -817,17 +817,17 @@ TEST(SM80_Device_GemmGrouped_cf32t_cf32h_cf32n_tensorop_f32, 64x64x16_16x16x16) using ElementAccumulator = cutlass::complex; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - ElementInput, - cutlass::layout::RowMajor, + ElementInput, + cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1, ElementInput, - cutlass::layout::RowMajor, + cutlass::layout::RowMajor, cutlass::ComplexTransform::kConjugate, 1, ElementOutput, cutlass::layout::ColumnMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<32, 32, 16>, cutlass::gemm::GemmShape<16, 16, 16>, @@ -835,7 +835,7 @@ TEST(SM80_Device_GemmGrouped_cf32t_cf32h_cf32n_tensorop_f32, 64x64x16_16x16x16) cutlass::epilogue::thread::LinearCombination< ElementOutput, 1, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, cutlass::arch::OpMultiplyAddComplex>::GemmKernel; diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu index 09a502bd..27c37159 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu @@ -116,6 +116,38 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) { EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } +TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32_align8, 256x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 8, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) { using ElementOutput = cutlass::int4b_t; diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu index 7b002d55..bce66df7 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu @@ -249,6 +249,26 @@ CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x12 EXPECT_TRUE(testbed.run_all()); } ) +CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32_align8, 256x128x128_64x64x128, { + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 8, ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128, { using ElementOutput = cutlass::int4b_t; using ElementAccumulator = int32_t; diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu index 525677a6..1fb0d7b0 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu @@ -116,6 +116,38 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) { EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } +TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32_align8, 256x128x128_64x64x128) { + + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, + cutlass::layout::RowMajor, + cutlass::int4b_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<8, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, + 8, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) { using ElementOutput = cutlass::int4b_t; diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu index ccaaabfd..6d62511b 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu @@ -249,6 +249,26 @@ CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x12 EXPECT_TRUE(testbed.run_all()); } ) +CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32_align8, 256x128x128_64x64x128, { + using ElementOutput = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 8, ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + CUTLASS_TEST_L0(SM80_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128, { using ElementOutput = cutlass::int4b_t; using ElementAccumulator = int32_t; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_f16t_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_f16t_tensor_op_s32_sm80.cu new file mode 100644 index 00000000..526758d1 --- /dev/null +++ b/test/unit/gemm/device/gemm_s8t_s8n_f16t_tensor_op_s32_sm80.cu @@ -0,0 +1,77 @@ +/************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "multistage_testbed.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8t_s8n_f16t_tensor_op_s32, 128x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} + +//////////////////////////////////////////////////////////////////////////////// +#endif // #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu index cb7401f8..2880777b 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm75.cu @@ -89,6 +89,24 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32_align8, 256x128x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 8>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, { using ElementOutput = int8_t; using ElementAccumulator = int32_t; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu index 18daa2bb..b0a1276b 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu @@ -249,6 +249,26 @@ CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 256x128x64_64x64x64, EXPECT_TRUE(testbed.run_all()); } ) +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32_align8, 256x128x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::ColumnMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 8>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32, 128x128x64_64x64x64, { using ElementOutput = int8_t; using ElementAccumulator = int32_t; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu index a13a6eb0..9e290763 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu @@ -88,6 +88,24 @@ CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_align8, 256x128x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<8, 8, 16>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 8>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} ) + CUTLASS_TEST_L0(SM75_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, { using ElementOutput = int8_t; using ElementAccumulator = int32_t; diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu index 9af1d01c..36ea57c1 100644 --- a/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu +++ b/test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu @@ -249,6 +249,26 @@ CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 256x128x64_64x64x64, EXPECT_TRUE(testbed.run_all()); } ) +CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_align8, 256x128x64_64x64x64, { + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::FastLinearCombinationClamp< + ElementOutput, 8>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + test::gemm::device::MultistageTestbed testbed; + + EXPECT_TRUE(testbed.run_all()); +} ) + CUTLASS_TEST_L0(SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32, 128x128x64_64x64x64, { using ElementOutput = int8_t; using ElementAccumulator = int32_t; diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 24a9e242..456d7360 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -67,7 +67,10 @@ namespace device { namespace detail{ -template +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity +> struct TestbedImpl { // Kernel data types using ElementA = typename Gemm::GemmKernel::ElementA; @@ -82,6 +85,8 @@ struct TestbedImpl { using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute; using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ThreadEpilogueOp = typename Gemm::GemmKernel::CollectiveEpilogue::ThreadEpilogueOp; + using ActivationFunctor = ActivationFunctor_; static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); @@ -110,7 +115,6 @@ struct TestbedImpl { using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B()); using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); - using LayoutTagPackedVector = cutlass::layout::PackedVectorLayout; /// Initialization StrideA stride_a; @@ -136,7 +140,6 @@ struct TestbedImpl { // Used to force multi-wave tests for persistent kernel schedules constexpr static int MaxSmCount = 16; - // // Methods // @@ -214,6 +217,10 @@ struct TestbedImpl { view.data(), view.capacity()); } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { EXPECT_TRUE(false) << "Not implemented"; return false; @@ -260,7 +267,7 @@ struct TestbedImpl { // in the upper left corner of each operand. tensor_A.host_view().at({0, 0}) = ElementA(1); tensor_B.host_view().at({0, 0}) = ElementB(1); - tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = ElementC(1); + tensor_C.host_view().at({0, 0}) = ElementC(1); cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); @@ -274,8 +281,8 @@ struct TestbedImpl { bool compare_reference( cute::Shape problem_shape_MNKL, ElementScalar alpha, - ElementScalar beta - ) { + ElementScalar beta) + { auto [M, N, K, L] = problem_shape_MNKL; tensor_D.sync_host(); @@ -322,8 +329,8 @@ struct TestbedImpl { bool verify( ProblemShapeType problem_size, ElementScalar alpha, - ElementScalar beta - ) { + ElementScalar beta) + { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto M = cute::size<0>(problem_shape_MNKL); auto N = cute::size<1>(problem_shape_MNKL); @@ -338,6 +345,10 @@ struct TestbedImpl { cute::make_layout(cute::make_shape(M, N, L), stride_c)); auto D = cute::make_tensor(reference_D.host_data(), cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto Bias = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, 1))); + auto T = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; cutlass::reference::host::GettEpilogueParams< @@ -345,18 +356,19 @@ struct TestbedImpl { ElementAccumulator, ElementCompute, decltype(C), - decltype(D) + decltype(D), + decltype(Bias), + decltype(T), + ActivationFunctor > epilogue_params{ alpha, beta, - C, D + C, D, Bias, T }; cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - return compare_reference( - problem_shape_MNKL, alpha, beta - ); + return compare_reference(problem_shape_MNKL, alpha, beta); } /// Determine if the CUDA device is sufficient to run the kernel @@ -429,12 +441,12 @@ struct TestbedImpl { /// Executes one test bool run( - ProblemShapeType problem_size, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - bool profiling = false, - int iterations = 20 - ) { + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20) + { // Fail test if insufficient CUDA device if (!sufficient()) { std::cout << "Test failed due to insufficient CUDA device." << std::endl; @@ -459,17 +471,21 @@ struct TestbedImpl { hw_info.sm_count = this->sm_count; } - // DefaultEpilogue - arguments = typename Gemm::Arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - tensor_A.device_data(), - stride_a, - tensor_B.device_data(), - stride_b, - {tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d, {alpha, beta}}, - hw_info - }; + // DefaultEpilogue + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + tensor_A.device_data(), stride_a, + tensor_B.device_data(), stride_b + }, + { + {alpha, beta}, + tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d + }, + hw_info + }; + Gemm gemm_op; size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -505,9 +521,7 @@ struct TestbedImpl { // // Verify // - bool passed = this->verify( - problem_size, alpha, beta - ); + bool passed = this->verify(problem_size, alpha, beta); if (!passed) { std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta) << "\n"; @@ -525,33 +539,143 @@ struct TestbedImpl { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Testbed { +template < + typename Gemm, + template class ActivationFunctor +> +struct Testbed3x { - using TestBedImplementation = typename detail::TestbedImpl; + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; - using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute; - using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; - using LayoutTagA = typename TestBedImplementation::LayoutTagA; - using LayoutTagB = typename TestBedImplementation::LayoutTagB; - using LayoutTagC = typename TestBedImplementation::LayoutTagC; - using LayoutTagD = typename TestBedImplementation::LayoutTagD; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementCompute = typename Epilogue::ElementCompute; + using ElementScalar = typename Epilogue::ElementScalar; + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; // Detail Implementation - TestBedImplementation impl_; + TestBedImpl impl_; // // Methods // - Testbed( + Testbed3x( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(init_A_, init_B_, init_C_, seed_) {} + + Testbed3x( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + init_A_, + init_B_, + init_C_, + seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20) + { + return impl_.run( + problem_size, alpha, beta, profiling, iterations + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Testbed for GEMMs with epilogues including a bias operation and an elementwise function +template +struct Testbed3xBiasElementwise { + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementA = typename Kernel::ElementA; + using StrideA = typename Kernel::StrideA; + using ElementB = typename Kernel::ElementB; + using StrideB = typename Kernel::StrideB; + using ElementC = typename Kernel::ElementC; + using StrideC = typename Kernel::StrideC; + using ElementD = typename Kernel::ElementD; + using StrideD = typename Kernel::StrideD; + + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementCompute = typename Epilogue::ElementCompute; + using ProblemShapeType = typename Kernel::ProblemShape; + using ElementBias = typename Epilogue::ElementBias; + using ElementT = typename Epilogue::ElementT; + using ElementScalar = typename Epilogue::ElementScalar; + using ActivationFunctor = typename Epilogue::ActivationFunctor; + using BinaryOp = typename Epilogue::BinaryOp; + + static constexpr bool IsBiasEnabled = Epilogue::iskThreadEpilogueOpWithBias; + static constexpr bool StoreT = Epilogue::StoreT; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + cutlass::HostTensor bias; + cutlass::HostTensor< ElementT, LayoutTagD> tensor_T; + cutlass::HostTensor< ElementT, LayoutTagD> reference_T; + + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + // Factors used for calculating relative equality. These default + // values are borrowed from those used by default in the CUTLASS + // profiler for performing relative equality checks. + float epsilon = 0.05f; + float nonzero_floor = 1.0f / 256.0f; + + // + // Methods + // + Testbed3xBiasElementwise( + bool check_relative_equality_, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImplementation::kDefaultSeed) - : impl_(init_A_, init_B_, init_C_, seed_) {} + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(check_relative_equality_) { } - Testbed( + Testbed3xBiasElementwise( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(false) { } + + Testbed3xBiasElementwise( typename LayoutTagA::Stride stride_factor_A_, typename LayoutTagB::Stride stride_factor_B_, typename LayoutTagC::Stride stride_factor_C_, @@ -559,33 +683,292 @@ struct Testbed { cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImplementation::kDefaultSeed) - : impl_(stride_factor_A_, - stride_factor_B_, - stride_factor_C_, - stride_factor_D_, - init_A_, - init_B_, - init_C_, - seed_) {} - - /// Executes one test - bool run( - typename TestBedImplementation::ProblemShapeType problem_size, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - bool profiling = false, - int iterations = 20 - ) { - return impl_.run( - problem_size, alpha, beta, profiling, iterations - ); - } + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + init_A_, + init_B_, + init_C_, + seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B/C/D/T tensor + // + impl_.initialize(problem_size); + + if constexpr (StoreT) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + auto c_coord = cutlass::make_Coord(M * L, N); + tensor_T.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.stride_factor_D)); + reference_T.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.stride_factor_D), false); + tensor_T.sync_device(); + } + } + + void initialize_bias(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + bias.resize(cutlass::Coord<1>(M)); + + EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2023)); + bias.sync_device(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + if (check_relative_equality) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, Element(epsilon), Element(nonzero_floor)); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + auto [M, N, K, L] = problem_shape_MNKL; + auto coord_0 = cutlass::make_Coord(0); + + impl_.tensor_D.sync_host(); + tensor_T.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_C.host_view()), 0); + + if (impl_.tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_D.host_view()), 0); + } + + if (impl_.reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.reference_D.host_view()), 0); + } + + if constexpr (StoreT) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_T.host_view()), 0); + } + + bool passed_D = equality_check(impl_.reference_D.host_view(), impl_.tensor_D.host_view()); + EXPECT_TRUE(passed_D); + + bool passed_T = StoreT ? equality_check(reference_T.host_view(), tensor_T.host_view()) : true; + EXPECT_TRUE(passed_T); + + bool passed = passed_D && passed_T; + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n"; + + if constexpr (IsBiasEnabled) { + file << "Bias = \n" << bias.host_view()<< "\n\n"; + } + + file + << "A =\n" << impl_.tensor_A.host_view() + << "\nB =\n" << impl_.tensor_B.host_view() + << "\nC =\n" << impl_.tensor_C.host_view(); + if constexpr (StoreT) { + file + << "\n\nReference_T =\n" << reference_T.host_view() + << "\n\nComputed_T =\n" << tensor_T.host_view(); + } + file + << "\n\nReference_D =\n" << impl_.reference_D.host_view() + << "\n\nComputed_D =\n" << impl_.tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result against a reference implementation + bool verify( + ProblemShapeType problem_size, + ElementScalar alpha, + ElementScalar beta) + { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + + auto A = cute::make_tensor(impl_.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); + auto B = cute::make_tensor(impl_.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b)); + auto C = cute::make_tensor(impl_.tensor_C.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + auto D = cute::make_tensor(impl_.reference_D.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); + auto Bias = cute::make_tensor(static_cast(IsBiasEnabled ? bias.host_data() : nullptr), + cute::make_layout(cute::make_shape(M, 1))); + auto T = cute::make_tensor(static_cast(StoreT ? reference_T.host_data() : nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(T), + ActivationFunctor, + BinaryOp> + epilogue_params{ + alpha, + beta, + C, + D, + Bias, + T + }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + return compare_reference(problem_shape_MNKL, alpha, beta); + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20) + { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// bias + if constexpr (IsBiasEnabled){ + initialize_bias(problem_size); + } + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.tensor_A.device_data(), impl_.stride_a, + impl_.tensor_B.device_data(), impl_.stride_b + }, + { // Epilogue arguments + { + alpha, + beta + }, + impl_.tensor_C.device_data(), + impl_.stride_c, + impl_.tensor_D.device_data(), + impl_.stride_d, + bias.device_data(), + tensor_T.device_data() + }, // Epilogue arguments end + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha, beta); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta) + << "\n"; + } + + return passed; + } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> bool TestAll() { using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -595,7 +978,73 @@ bool TestAll() { std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; if constexpr (std::is_same_v) { + cutlass::gemm::KernelTmaWarpSpecializedPingpong>) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3x testbed; + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0) + ); + + if (!passed) { + return false; + } + } + } + } + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0) + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllBiasElementwise(bool check_relative_equality=false) { + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (std::is_same_v) { problem_size_m.push_back(768); problem_size_n.push_back(768); } @@ -605,7 +1054,7 @@ bool TestAll() { std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; - Testbed testbed; + Testbed3xBiasElementwise testbed(check_relative_equality); bool passed = true; for (int m : problem_size_m) { @@ -651,7 +1100,7 @@ bool TestAll() { ///////////////////////////////////////////////////////////////////////////////////////////////// template -bool TestGemmPerf(int iterations = 20) { +bool TestGemmPerf3x(int iterations = 20) { using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; using ElementScalar = ElementAccumulator; @@ -661,7 +1110,7 @@ bool TestGemmPerf(int iterations = 20) { std::vector problem_size_n = { 4608 }; std::vector problem_size_k = { 8192 }; - Testbed testbed; + Testbed3x testbed; for (int m : problem_size_m) { for (int n : problem_size_n) { diff --git a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp new file mode 100644 index 00000000..3e5424a1 --- /dev/null +++ b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -0,0 +1,488 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with elementwise tensor-tensor broadcast epilogue +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed_utils.h" +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed3xTensorBroadcast { + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementA = typename Kernel::ElementA; + using StrideA = typename Kernel::StrideA; + using ElementB = typename Kernel::ElementB; + using StrideB = typename Kernel::StrideB; + using ElementC = typename Kernel::ElementC; + using StrideC = typename Kernel::StrideC; + using ElementD = typename Kernel::ElementD; + using StrideD = typename Kernel::StrideD; + + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementCompute = typename Epilogue::ElementCompute; + using ElementScalar = typename Epilogue::ElementScalar; + using ProblemShapeType = typename Kernel::ProblemShape; + using ElementBias = typename Epilogue::ElementBias; + using ActivationFunctor = typename Epilogue::ActivationFunctor; + + static constexpr bool IsBinaryOp0Enabled = Epilogue::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C1; + // tensor_C0 is taken from TestbedImpl's tensor_C + + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3xTensorBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(init_A_, init_B_, init_C_, seed_) { } + + Testbed3xTensorBroadcast( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + init_A_, + init_B_, + init_C_, + seed_) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B/C/D tensor + // + impl_.initialize(problem_size); + } + + void initialize_bias(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + bias.resize(cutlass::Coord<1>(M)); + + EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2023)); + bias.sync_device(); + } + + void initialize_c1(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto c_coord = cutlass::make_Coord(M * L, N); + + tensor_C1.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.stride_factor_C)); + EXPECT_TRUE(impl_.initialize_tensor(tensor_C1.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2024)); + tensor_C1.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta, + bool use_bias) + { + auto [M, N, K, L] = problem_shape_MNKL; + auto coord_0 = cutlass::make_Coord(0); + + impl_.tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_B.host_view()), 0); + + if (impl_.tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_D.host_view()), 0); + } + + if (impl_.reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.reference_D.host_view()), 0); + } + + bool passed = cutlass::reference::host::TensorEquals(impl_.reference_D.host_view(), impl_.tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_broadcast" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias << "\n\n"; + + if (use_bias){ + file << "Bias = \n" << bias.host_view()<< "\n\n"; + } + + file + << "A =\n" << impl_.tensor_A.host_view() + << "\nB =\n" << impl_.tensor_B.host_view() + << "\nC0 =\n" << impl_.tensor_C.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << impl_.reference_D.host_view() + << "\n\nComputed =\n" <(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + + auto A = cute::make_tensor(impl_.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); + auto B = cute::make_tensor(impl_.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b)); + auto D = cute::make_tensor(impl_.reference_D.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); + auto Bias = cute::make_tensor(static_cast(use_bias ? bias.host_data() : nullptr), + cute::make_layout(cute::make_shape(M, 1))); + auto C0 = cute::make_tensor(impl_.tensor_C.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + auto C1 = cute::make_tensor(tensor_C1.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + + // Create host workspace for output of testbed. This computes a portion of the epilogue: + // ref_compute_out = Activation(alpha * (A @ B) + bias) + cutlass::HostTensor ref_compute_out; + auto c_coord = cutlass::make_Coord(M * L, N); + ref_compute_out.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.stride_factor_C), false); + auto RefComputeOut = cute::make_tensor(ref_compute_out.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + // Use a dummy null tensor for operand C because the epilogue overrides C. + auto dummy_C = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); + ElementCompute dummy_beta(0); + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(dummy_C), + decltype(RefComputeOut), + decltype(Bias), + decltype(dummy_C), + ActivationFunctor> epilogue_params{ + alpha, + dummy_beta, + dummy_C, + RefComputeOut, + Bias, + dummy_C + }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + cutlass::NumericConverter source_converter; + cutlass::NumericConverter destination_converter; + cutlass::multiplies mul; + + // Compute broadcast operations atop the reference + #pragma omp parallel for collapse(3) + for (int64_t l = 0; l < cute::size<2>(A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(A.layout()); ++m) { + for (int64_t n = 0; n < cute::size<0>(B.layout()); ++n) { + ElementCompute intermediate = RefComputeOut(m, n, l); + // Apply BinaryOp0, if needed + if constexpr (IsBinaryOp0Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp0 bin0; + ElementCompute converted_source = source_converter(C0(m, n, l)); + intermediate = bin0(intermediate, mul(beta, converted_source)); + } + + // Apply BinaryOp1, if needed + if constexpr (IsBinaryOp1Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp1 bin1; + ElementCompute converted_source = source_converter(C1(m, n, l)); + intermediate = bin1(intermediate, mul(beta, converted_source)); + } + + // Apply UnaryOp, if needed + if constexpr (IsUnaryOpEnabled) { + typename Epilogue::ThreadEpilogueOp::UnaryOp unary; + intermediate = unary(intermediate); + } + + D(m, n, l) = destination_converter(intermediate); + } + } + } + + return compare_reference(problem_shape_MNKL, alpha, beta, use_bias); + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20, + bool use_bias = true) + { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + /// Initializes data structures + /// A/B/C0/D Tensor + initialize(problem_size); + initialize_bias(problem_size); + + if constexpr (IsBinaryOp1Enabled) { + initialize_c1(problem_size); + } + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { impl_.tensor_A.device_data(), impl_.stride_a, + impl_.tensor_B.device_data(), impl_.stride_b + }, + { // Epilogue arguments + { alpha, beta }, // ThreadOp arguments + impl_.stride_c, + impl_.tensor_D.device_data(), + impl_.stride_d, + use_bias ? bias.device_data() : nullptr, + impl_.tensor_C.device_data(), + tensor_C1.device_data() + }, // Epilogue arguments end + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha, beta, use_bias); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << float(alpha) + << ", beta: " << float(beta) + << ", use_bias: " << use_bias + << "\n"; + } + + return passed; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllTensorBroadcast(bool use_bias=true) { + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (std::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xTensorBroadcast testbed; + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (bool use_bias : {true, false}) { + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20, // iterations + use_bias + ); + + if (!passed) { + return false; + } + } + } + } + } + + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20 // iterations + ); + if (!passed) { + return false; + } + } + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu index 9fbbd862..b0279f0f 100644 --- a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu @@ -43,6 +43,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -72,15 +73,20 @@ TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -104,15 +110,20 @@ TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 4, + cutlass::bfloat16_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -136,15 +147,20 @@ TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 2, + cutlass::bfloat16_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -168,15 +184,20 @@ TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu index d3983e4d..a151126b 100644 --- a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu @@ -42,6 +42,7 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -71,15 +72,20 @@ TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -103,15 +109,20 @@ TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -135,15 +146,20 @@ TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -167,15 +183,20 @@ TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu index 0ee526be..1e0c395b 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu @@ -43,6 +43,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -74,15 +75,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::KernelMultistage >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -104,15 +110,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -135,15 +146,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -169,15 +185,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::KernelMultistage >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -201,15 +222,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -233,15 +259,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -267,15 +298,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::KernelMultistage >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -299,15 +335,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -331,15 +372,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -365,15 +411,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::KernelMultistage >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -397,15 +448,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -429,15 +485,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu index 4fea99ab..cc049da4 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op.cu @@ -42,8 +42,9 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/epilogue.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -72,15 +73,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -102,15 +108,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 128x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -132,15 +143,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32, 64x64x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -164,15 +180,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -194,15 +215,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 128x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -224,15 +250,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32, 64x64x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -256,14 +287,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -285,15 +322,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 128x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -315,15 +357,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f32, 64x64x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -347,15 +394,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -377,15 +429,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 128x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -407,15 +464,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f32, 64x64x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -441,15 +503,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -471,15 +538,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 128x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -501,15 +573,20 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f16, 64x64x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -533,15 +610,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -563,15 +645,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 128x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -593,15 +680,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16, 64x64x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -625,15 +717,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -655,15 +752,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 128x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -685,15 +787,20 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_tensor_op_gmma_f16, 64x64x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -717,15 +824,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 64x128x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -747,15 +859,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 128x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -777,295 +894,20 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_tensor_op_gmma_f16, 64x64x64) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - EpilogueOp - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_Epilogue, 64x128x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - cutlass::half_t, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using EpilogueOp = cutlass::epilogue::collective::Epilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination, - ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_1,_64>>>, - Copy_Atom, - TiledCopy,Layout,Stride<_8,_1>>,Shape<_64,_16>>, - Copy_Atom>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - EpilogueOp - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_Epilogue, 128x64x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - cutlass::half_t, - Shape<_128,_64,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using EpilogueOp = cutlass::epilogue::collective::Epilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination, - ComposedLayout, smem_ptr_flag_bits::value>, Layout,_64>,Stride,_64>>>, - Copy_Atom, - TiledCopy,Layout,Stride<_8,_1>>,Shape<_128,_8>>, - Copy_Atom>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - EpilogueOp - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_Epilogue, 64x128x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - cutlass::half_t, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using EpilogueOp = cutlass::epilogue::collective::Epilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination, - ComposedLayout, smem_ptr_flag_bits::value>, Layout>,Stride<_64,Stride<_1,_4096>>>>, - Copy_Atom, - TiledCopy,Layout,Stride<_8,_1>>,Shape<_8,_128>>, - Copy_Atom>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - EpilogueOp - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_Epilogue, 128x64x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - cutlass::half_t, - Shape<_128,_64,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using EpilogueOp = cutlass::epilogue::collective::Epilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination, - ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_64,_1>>>, - Copy_Atom, - TiledCopy,Layout,Stride<_8,_1>>,Shape<_16,_64>>, - Copy_Atom>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - EpilogueOp - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_Epilogue, 64x128x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using EpilogueOp = cutlass::epilogue::collective::Epilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination, - ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_1,_64>>>, - Copy_Atom, - TiledCopy,Layout,Stride<_8,_1>>,Shape<_64,_16>>, - Copy_Atom>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - EpilogueOp - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_Epilogue, 128x64x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - Shape<_128,_64,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using EpilogueOp = cutlass::epilogue::collective::Epilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination, - ComposedLayout, smem_ptr_flag_bits::value>, Layout,_64>,Stride,_64>>>, - Copy_Atom, - TiledCopy,Layout,Stride<_8,_1>>,Shape<_128,_8>>, - Copy_Atom>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - EpilogueOp - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_Epilogue, 64x128x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using EpilogueOp = cutlass::epilogue::collective::Epilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination, - ComposedLayout, smem_ptr_flag_bits::value>, Layout>,Stride<_64,Stride<_1,_4096>>>>, - Copy_Atom, - TiledCopy,Layout,Stride<_8,_1>>,Shape<_8,_128>>, - Copy_Atom>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - EpilogueOp - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_Epilogue, 128x64x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - Shape<_128,_64,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + Shape<_64,_64,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::Epilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination, - ComposedLayout, smem_ptr_flag_bits::value>, Layout,Stride<_64,_1>>>, - Copy_Atom, - TiledCopy,Layout,Stride<_8,_1>>,Shape<_16,_64>>, - Copy_Atom>; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu index 16466329..f8eef136 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_unspecialized.cu @@ -42,6 +42,7 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -73,10 +74,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -105,10 +111,15 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -137,10 +148,15 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -169,10 +185,15 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -204,10 +225,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -236,10 +262,15 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -268,10 +299,15 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -300,10 +336,15 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -336,10 +377,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -368,10 +414,15 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -400,10 +451,15 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -432,10 +488,15 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -468,10 +529,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -500,10 +566,15 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -532,10 +603,15 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -564,10 +640,15 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_unspecialized, 64x128x64 cutlass::gemm::KernelTma >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu index 378315d6..942d1862 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu @@ -42,6 +42,7 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -73,10 +74,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -105,10 +111,15 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -137,10 +148,15 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -169,10 +185,15 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -204,10 +225,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -236,10 +262,15 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -268,10 +299,15 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -300,10 +336,15 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -336,10 +377,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -368,10 +414,15 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -400,10 +451,15 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -432,10 +488,15 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -468,10 +529,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -500,10 +566,15 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -532,10 +603,15 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, @@ -564,10 +640,15 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x cutlass::gemm::KernelTmaWarpSpecialized >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu new file mode 100644 index 00000000..6eff33a7 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu @@ -0,0 +1,850 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_1x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_1x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x2x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_2x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_2x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 4x1x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_4x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_4x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 1x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_1x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x128x64_1x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu new file mode 100644 index 00000000..b261d5a5 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu @@ -0,0 +1,366 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with bias and elementwise epilogues. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "testing_elementwise.hpp" +#include "gemm_testbed_3x.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_ReLU) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeElementwise< + cutlass::epilogue::thread::ReLu>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAll(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_ReLU) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< + cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::plus, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_GELU) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< + cutlass::epilogue::thread::GELU, cutlass::half_t, cutlass::plus, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool check_relative_equality = true; + bool passed = test::gemm::device::TestAllBiasElementwise(check_relative_equality); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_ReLU_NoStoreT) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = false; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< + cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::plus, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_Bias_Negate) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< + test::gemm::device::detail::Negate, cutlass::half_t, cutlass::plus, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< + cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasMul_ReLU) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< + cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu similarity index 71% rename from test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu rename to test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu index c7d814b8..f549e105 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_persistent.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu @@ -43,7 +43,8 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -65,10 +66,15 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_1,_1,_1>; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -77,7 +83,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -100,12 +106,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_2,_1,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -114,7 +125,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -137,12 +148,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_1,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -151,7 +167,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -174,12 +190,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -188,7 +209,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -212,12 +233,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_4,_1,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -226,7 +252,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -249,12 +275,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_1,_4,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -263,7 +294,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_1x ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -286,12 +317,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_2,_4,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -300,7 +336,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_2x ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -323,12 +359,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_4,_4,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -337,7 +378,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 64x128x64_4x ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -360,12 +401,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_1,_1,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -374,7 +420,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1 ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -397,12 +443,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_1,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -411,7 +462,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2 ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -434,12 +485,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_1,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -448,7 +504,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1 ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -471,12 +527,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -485,7 +546,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2 ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -509,12 +570,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_4,_1,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -523,7 +589,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4 ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -546,12 +612,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_1,_4,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -560,7 +631,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_1 ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -583,12 +654,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_4,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -597,7 +673,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2 ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -620,12 +696,17 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_4,_4,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -634,7 +715,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_persistent, 128x128x64_4 ElementAccumulator, TileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -660,19 +741,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 64x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using PreSwizzleLayout = Layout,Stride<_1,_64>>; using TileShapeS2R = Shape<_64,_16>; - using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, Copy_Atom, TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, - Copy_Atom>; + Copy_Atom>>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -680,8 +762,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 64x ElementB, LayoutB, 8, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -705,19 +787,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 128 using TileShape_MNK = Shape<_128,_64,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using PreSwizzleLayout = Layout,_64>,Stride,_64>>; using TileShapeS2R = Shape<_128,_8>; - using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, Copy_Atom, TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, - Copy_Atom>; + Copy_Atom>>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -725,8 +808,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f16_persistent_Epilogue, 128 ElementB, LayoutB, 8, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -752,19 +835,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 64x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using PreSwizzleLayout = Layout>,Stride<_64,Stride<_1,_4096>>>; using TileShapeS2R = Shape<_8,_128>; - using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, Copy_Atom, TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, - Copy_Atom>; + Copy_Atom>>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -772,8 +856,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 64x ElementB, LayoutB, 8, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -797,19 +881,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 128 using TileShape_MNK = Shape<_128,_64,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using PreSwizzleLayout = Layout,Stride<_64,_1>>; using TileShapeS2R = Shape<_16,_64>; - using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, Copy_Atom, TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, - Copy_Atom>; + Copy_Atom>>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -817,8 +902,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f16_persistent_Epilogue, 128 ElementB, LayoutB, 8, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -844,19 +929,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 64x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using PreSwizzleLayout = Layout,Stride<_1,_64>>; using TileShapeS2R = Shape<_64,_16>; - using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, Copy_Atom, TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, - Copy_Atom>; + Copy_Atom>>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -864,8 +950,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 64x ElementB, LayoutB, 8, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -889,19 +975,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 128 using TileShape_MNK = Shape<_128,_64,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using PreSwizzleLayout = Layout,_64>,Stride,_64>>; using TileShapeS2R = Shape<_128,_8>; - using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, Copy_Atom, TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, - Copy_Atom>; + Copy_Atom>>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -909,8 +996,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_Epilogue, 128 ElementB, LayoutB, 8, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -936,19 +1023,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 64x using TileShape_MNK = Shape<_64,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using PreSwizzleLayout = Layout>,Stride<_64,Stride<_1,_4096>>>; using TileShapeS2R = Shape<_8,_128>; - using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, Copy_Atom, TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, - Copy_Atom>; + Copy_Atom>>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -956,8 +1044,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 64x ElementB, LayoutB, 8, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -981,19 +1069,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 128 using TileShape_MNK = Shape<_128,_64,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPersistent; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; using PreSwizzleLayout = Layout,Stride<_64,_1>>; using TileShapeS2R = Shape<_16,_64>; - using CollectiveEpilogue = cutlass::epilogue::collective::Epilogue< + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, cutlass::epilogue::thread::LinearCombination, ComposedLayout, smem_ptr_flag_bits>, PreSwizzleLayout>, Copy_Atom, TiledCopy,Layout,Stride<_8,_1>>,TileShapeS2R>, - Copy_Atom>; + Copy_Atom>>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -1001,8 +1090,94 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_Epilogue, 128 ElementB, LayoutB, 8, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelTmaWarpSpecializedPersistent + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using ElementC = ElementA; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent, 128x128x64_2x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using ElementC = ElementA; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu new file mode 100644 index 00000000..75eb94d9 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu @@ -0,0 +1,365 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide persistent GEMM interface with bias and elementwise epilogues. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "testing_elementwise.hpp" +#include "gemm_testbed_3x.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_ReLU) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedElementwise< + cutlass::epilogue::thread::ReLu>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + bool passed = test::gemm::device::TestAll(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_ReLU) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< + cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::plus, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_GELU) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< + cutlass::epilogue::thread::GELU, cutlass::half_t, cutlass::plus, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool check_relative_equality = true; + bool passed = test::gemm::device::TestAllBiasElementwise(check_relative_equality); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_ReLU_NoStoreT) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = false; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< + cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::plus, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_Bias_Negate) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< + test::gemm::device::detail::Negate, cutlass::half_t, cutlass::plus, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< + cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasMul_ReLU) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + static constexpr bool StoreT = true; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< + cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu new file mode 100644 index 00000000..5aff82b9 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu @@ -0,0 +1,298 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with an elementwise tensor-tensor broadcast epilogue +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp" +#include "cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_tensor_broadcast.hpp" +#include "testing_elementwise.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32_tensor_broadcast, 64x128x64_ActIdentity_Bin0Plus_Bin1NoOp_UnaryIdentity) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using ElementOutput = float; + using ElementAccumulator = ElementOutput; + using ElementCompute = ElementOutput; + using ElementBias = ElementOutput; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + ElementOutput, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueTensorBroadcast< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombinationTensorBroadcast, + cutlass::gemm::EpilogueDefault>>; + + EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled); + EXPECT_TRUE(!EpilogueOp::IsBinaryOp1Enabled); + EXPECT_TRUE(!EpilogueOp::IsUnaryOpEnabled); + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32_tensor_broadcast, 64x128x64_ActReLu_Bin0Plus_Bin1Plus_UnaryNegate) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using ElementOutput = float; + using ElementAccumulator = ElementOutput; + using ElementCompute = ElementOutput; + using ElementBias = ElementOutput; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + ElementOutput, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueTensorBroadcast< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombinationTensorBroadcast< + ElementOutput, ElementAccumulator, ElementCompute, ElementBias, + cutlass::epilogue::thread::ReLu, + cutlass::plus, + cutlass::plus, + test::gemm::device::detail::Negate + >, + cutlass::gemm::EpilogueDefault>>; + + EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled); + EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled); + EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled); + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16t_tensor_op_gmma_f32_tensor_broadcast, 64x128x64_ActReLu_Bin0Mul_Bin1Plus_UnaryNegate) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using ElementOutput = float; + using ElementAccumulator = ElementOutput; + using ElementCompute = ElementOutput; + using ElementBias = ElementOutput; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + ElementOutput, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueTensorBroadcast< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombinationTensorBroadcast< + ElementOutput, ElementAccumulator, ElementCompute, ElementBias, + cutlass::epilogue::thread::ReLu, + cutlass::multiplies, + cutlass::plus, + test::gemm::device::detail::Negate + >, + cutlass::gemm::EpilogueDefault>>; + + EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled); + EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled); + EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled); + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast()); +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32_tensor_broadcast, 128x128x64_ActReLu_Bin0NoOp_Bin1Plus_UnaryNegate) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using ElementOutput = float; + using ElementAccumulator = ElementOutput; + using ElementCompute = ElementOutput; + using ElementBias = ElementOutput; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + ElementOutput, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueTensorBroadcast< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombinationTensorBroadcast< + ElementOutput, ElementAccumulator, ElementCompute, ElementBias, + cutlass::epilogue::thread::ReLu, + cutlass::epilogue::thread::detail::NoOp, + cutlass::plus, + test::gemm::device::detail::Negate + >, + cutlass::gemm::EpilogueDefault>>; + + EXPECT_TRUE(!EpilogueOp::IsBinaryOp0Enabled); + EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled); + EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled); + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized_tensor_broadcast, 64x128x64_2x2x1_ActReLu_Bin0Mul_Bin1Plus_UnaryNegate) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using ElementOutput = float; + using ElementAccumulator = ElementOutput; + using ElementCompute = ElementOutput; + using ElementBias = ElementOutput; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueTensorBroadcast< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombinationTensorBroadcast< + ElementOutput, ElementAccumulator, ElementCompute, ElementBias, + cutlass::epilogue::thread::ReLu, + cutlass::multiplies, + cutlass::plus, + test::gemm::device::detail::Negate + >, + cutlass::gemm::EpilogueDefault>>; + + EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled); + EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled); + EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled); + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu index b4edaf61..680c4dc2 100644 --- a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu @@ -36,9 +36,9 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/collective/default_transposed_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" #include "../../common/cutlass_unit_test.h" @@ -66,10 +66,15 @@ TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32, 64x128x32_1x2x1) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu new file mode 100644 index 00000000..735d14fb --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32_tensor_broadcast.cu @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with an elementwise tensor-tensor broadcast epilogue +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp" +#include "cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_tensor_broadcast.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32_tensor_broadcast, 64x128x32_1x2x1_ActReLU_Bin0Mul_Bin1Plus_UnaryHardSwish) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using ElementOutput = float; + using ElementAccumulator = ElementOutput; + using ElementCompute = ElementOutput; + using ElementBias = ElementOutput; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + float, LayoutA, 4, + float, LayoutB, 4, + float, + Shape<_64,_128,_128>, Shape<_1,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueTensorBroadcast< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombinationTensorBroadcast< + ElementOutput, ElementAccumulator, ElementCompute, ElementBias, + cutlass::epilogue::thread::ReLu, + cutlass::multiplies, + cutlass::plus, + cutlass::epilogue::thread::HardSwish + >, + cutlass::gemm::EpilogueDefault>>; + + EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled); + EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled); + EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled); + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu index 5d30e961..174fb1b4 100644 --- a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu @@ -44,6 +44,7 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -72,15 +73,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 8, + int8_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -102,15 +108,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) { cutlass::gemm::KernelMultistage >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 8, + int8_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -132,15 +143,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 4, + int8_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu index f0762a9d..63f6f470 100644 --- a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32.cu @@ -43,6 +43,7 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -71,15 +72,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 64x128x128) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 16, + int8_t, LayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -103,15 +109,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 64x128x128_1x2x1) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 16, + int8_t, LayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -133,15 +144,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 16, + int8_t, LayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -163,15 +179,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_1x2x1) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 16, + int8_t, LayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -193,15 +214,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x1x1) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 16, + int8_t, LayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -223,15 +249,20 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32, 128x128x128_2x2x1) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 16, + int8_t, LayoutC, 16, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu new file mode 100644 index 00000000..a1f352d6 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with an elementwise tensor-tensor broadcast epilogue +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp" +#include "cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_tensor_broadcast.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_tensor_op_gmma_s32_tensor_broadcast, 128x128x128_2x2x1_ActReLU_Bin0Mul_Bin1Plus_UnaryHardSwish) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using ElementOutput = int32_t; + using ElementAccumulator = ElementOutput; + using ElementCompute = ElementOutput; + using ElementBias = ElementOutput; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_128,_128,_128>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueTensorBroadcast< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombinationTensorBroadcast< + ElementOutput, ElementAccumulator, ElementCompute, ElementBias, + cutlass::epilogue::thread::ReLu, + cutlass::multiplies, + cutlass::plus, + cutlass::epilogue::thread::HardSwish + >, + cutlass::gemm::EpilogueDefault>>; + + EXPECT_TRUE(EpilogueOp::IsBinaryOp0Enabled); + EXPECT_TRUE(EpilogueOp::IsBinaryOp1Enabled); + EXPECT_TRUE(EpilogueOp::IsUnaryOpEnabled); + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllTensorBroadcast()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu index e95772f3..bb25de29 100644 --- a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu @@ -43,6 +43,7 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -71,15 +72,20 @@ TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32, 64x128x32) { cutlass::gemm::KernelMultistage >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -101,15 +107,20 @@ TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 2, + float, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -131,15 +142,20 @@ TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32, 128x64x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 1, + float, LayoutC, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu index ce570a2f..bc31d24a 100644 --- a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu @@ -43,6 +43,7 @@ #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" @@ -69,15 +70,20 @@ TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -101,15 +107,20 @@ TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -133,15 +144,20 @@ TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 1, + float, LayoutC, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -165,15 +181,20 @@ TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; - using EpilogueOp = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveOp, - EpilogueOp + CollectiveEpilogue >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu new file mode 100644 index 00000000..0bfeb0a8 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu @@ -0,0 +1,566 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::gemm::EpilogueTransposed + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32_4x2x1) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_4,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32_4x2x1) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_4,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32_4x2x1) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_4,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::gemm::EpilogueTransposed + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32_4x2x1) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_4,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +//////////// CollectiveBuilder with KernelScheduleAuto ////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32_4x2x1_auto_schedule) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_4,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32_4x2x1_auto_schedule) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_4,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32_4x2x1_auto_schedule) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_4,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::gemm::EpilogueTransposed + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_rs_ws_f32, 64x128x32_4x2x1_auto_schedule) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_32>; + using ClusterShape_MNK = Shape<_4,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 4, + ElementB, LayoutB, 4, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/testing_elementwise.hpp b/test/unit/gemm/device/testing_elementwise.hpp new file mode 100644 index 00000000..a2d5b3ea --- /dev/null +++ b/test/unit/gemm/device/testing_elementwise.hpp @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Elementwise activation functors used only for testing purposes. +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" + +namespace test { +namespace gemm { +namespace device { +namespace detail{ + +/// Simple activation function that negates the input. +template +struct Negate { + static constexpr T neg_one = T(-1); + + CUTLASS_HOST_DEVICE + T operator()(const T& data) { + return data * neg_one; + } +}; + +} // namespace detail +} // namespace device +} // namespace gemm +} // namespace test diff --git a/test/unit/gemm/warp/wmma_sm72.cu b/test/unit/gemm/warp/wmma_sm72.cu index 8b562206..eab1536f 100644 --- a/test/unit/gemm/warp/wmma_sm72.cu +++ b/test/unit/gemm/warp/wmma_sm72.cu @@ -56,7 +56,7 @@ /////////////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////// Integer wmma.mma //////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// TODO: FIXME SM75 should SM72, but the compilation breaks as SM72 shows up and runs on VOLTA +// TODO: SM75 should be SM72, but the compilation breaks as SM72 shows up and runs on VOLTA TEST(SM75_warp_wmma_row_col_s8, 16x16x16_16x16x16_16x16x16) { // Threadblock and warp with just one native WMMA operation (most basic unit test) using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; diff --git a/test/unit/pipeline/pipeline_async.cu b/test/unit/pipeline/pipeline_async.cu index d2adad6a..61a46c7d 100644 --- a/test/unit/pipeline/pipeline_async.cu +++ b/test/unit/pipeline/pipeline_async.cu @@ -51,7 +51,7 @@ #include "cutlass/util/GPU_Clock.hpp" #include "testbed.h" -#include "cutlass/pipeline.hpp" +#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/arch/barrier.h" #include "cute/arch/cluster_sm90.hpp" @@ -98,21 +98,21 @@ void pipeline_async_basic_device(uint32_t const num_iterations) cute::cluster_wait(); __syncthreads(); + if (lane_predicate) { // Producer Warps if (warp_idx==0 || warp_idx==1) { + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); int prologue_iterations = min(NumStages, num_iterations); for ( int i = 0; i < prologue_iterations; ++i) { // Can also specify stage to commit directly - pipeline.producer_commit(i); + pipeline.producer_commit(smem_pipe_write); + ++smem_pipe_write; } int mainloop_iterations = num_iterations - prologue_iterations; - // Only the mainloop needs a PipelineState because this is where we start "waiting" (acquiring) - PipelineState smem_pipe_write; - for ( ; mainloop_iterations > 0; --mainloop_iterations) { pipeline.producer_acquire(smem_pipe_write); pipeline.producer_commit(smem_pipe_write); @@ -123,7 +123,7 @@ void pipeline_async_basic_device(uint32_t const num_iterations) PipelineState smem_pipe_read; for (int iter=0 ; iter < num_iterations; ++iter) { pipeline.consumer_wait(smem_pipe_read); - pipeline.consumer_release(smem_pipe_read.index()); + pipeline.consumer_release(smem_pipe_read); ++smem_pipe_read; } } diff --git a/test/unit/pipeline/pipeline_tma_async.cu b/test/unit/pipeline/pipeline_tma_async.cu index 90e0ca3a..0bd40b1f 100644 --- a/test/unit/pipeline/pipeline_tma_async.cu +++ b/test/unit/pipeline/pipeline_tma_async.cu @@ -41,7 +41,7 @@ #include #include -#include +#include #include #include @@ -52,7 +52,7 @@ #include "cutlass/util/GPU_Clock.hpp" #include "testbed.h" -#include "cutlass/pipeline.hpp" +#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/arch/barrier.h" #include "cute/arch/cluster_sm90.hpp" @@ -68,12 +68,11 @@ struct SharedStorage // Goal of this kernel is to complete deadlock-free template -__global__ static +__global__ static void pipeline_device(uint32_t const NumIterations) { extern __shared__ char shared_memory[]; - using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmma; using MainloopPipeline = cutlass::PipelineTmaAsync; using PipelineState = cutlass::PipelineState; @@ -86,8 +85,8 @@ void pipeline_device(uint32_t const NumIterations) dim3 block_id_in_cluster = cute::block_id_in_cluster(); auto cluster_shape = ClusterShape{}; - - // #Producers = #RowsInCluster + #ColsInCluster - 1 + + // #Producers = #RowsInCluster + #ColsInCluster - 1 uint32_t const NumProducers = cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1; uint32_t const TmaTransactionBytes = sizeof(uint32_t) * NumProducers; uint32_t const per_cta_bytes = sizeof(uint32_t); @@ -104,7 +103,7 @@ void pipeline_device(uint32_t const NumIterations) __syncthreads(); // Ensure All CTAs in Cluster have completed init before issuing commits - cute::cluster_arrive_relaxed(); + cute::cluster_arrive_relaxed(); cute::cluster_wait(); // Total number of gemm_k_iterations @@ -126,7 +125,7 @@ void pipeline_device(uint32_t const NumIterations) for(int i = 0; i < k_pipe_tma_prologue; ++i) { pipeline.producer_acquire(smem_pipe_write); // cp.async.bulk.tensor would typically happen here - pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); + pipeline.producer_commit(smem_pipe_write, per_cta_bytes); ++smem_pipe_write; } tma_k_iterations -= k_pipe_tma_prologue; @@ -156,7 +155,7 @@ void pipeline_device(uint32_t const NumIterations) if (lane_predicate && (warp_idx == 0) && (tma_k_iterations > 0)) { pipeline.producer_acquire(smem_pipe_write); // cp.async.bulk.tensor would typically happen here - pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); + pipeline.producer_commit(smem_pipe_write, per_cta_bytes); ++smem_pipe_write; --tma_k_iterations; } @@ -167,7 +166,7 @@ void pipeline_device(uint32_t const NumIterations) } // To make sure remote SMEM doesn't get destoryed - cute::cluster_arrive(); + cute::cluster_arrive(); cute::cluster_wait(); } ///////////////////////////////////////////////////// @@ -224,11 +223,6 @@ struct PipelineTest { } for (int iter = 0; iter < iterations; ++iter) { - - // Define the tiled MMA layout (static, 4warps) - using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmma; - using MainloopPipeline = typename cutlass::PipelineTmaAsync; - int smem_size = int(sizeof(SharedStorage)); result = cudaFuncSetAttribute( @@ -237,15 +231,15 @@ struct PipelineTest { smem_size); // Launch a single Cluster, with 128 thread per CTA - dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); - dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), 1); + dim3 dimGrid(size<0>(cluster_shape), size<1>(cluster_shape), 1); dim3 dimBlock(kBlockSize,1,1); const void* kernel = (const void*)pipeline_device; int iters = kNumIters; void* kernel_params[] = {reinterpret_cast(&iters)}; cutlass::ClusterLauncher::launch(dimGrid, dimCluster, dimBlock, smem_size, stream, kernel, kernel_params); - + } // profiling loop ends result = cudaEventRecord(events[1]); diff --git a/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu b/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu index f0d6a79c..16a70a46 100644 --- a/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu +++ b/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu @@ -50,7 +50,7 @@ #include "cutlass/util/GPU_Clock.hpp" #include "testbed.h" -#include "cutlass/pipeline.hpp" +#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/arch/barrier.h" #include "cute/arch/cluster_sm90.hpp" #include "cutlass/arch/barrier.h" @@ -138,7 +138,7 @@ void pipeline_device(KernelParams const kernel_params) for(int i = 0; i < tma_k_prologue; ++i) { pipeline.producer_acquire(smem_pipe_write); // Simulating cp.async.bulk.tensor behavior - pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); + pipeline.producer_commit(smem_pipe_write, per_cta_bytes); ++smem_pipe_write; } int tma_k_iter = kernel_params.num_iterations - tma_k_prologue; @@ -150,7 +150,7 @@ void pipeline_device(KernelParams const kernel_params) pipeline.producer_acquire(smem_pipe_write); // Simulating cp.async.bulk.tensor behavior - pipeline.producer_commit(smem_pipe_write.index(), per_cta_bytes); + pipeline.producer_commit(smem_pipe_write, per_cta_bytes); // Advance write stage ++smem_pipe_write; diff --git a/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu b/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu index 4b6a3b1d..8fa64561 100644 --- a/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu +++ b/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu @@ -50,7 +50,7 @@ #include "cutlass/util/GPU_Clock.hpp" #include "testbed.h" -#include "cutlass/pipeline.hpp" +#include "cutlass/pipeline/pipeline.hpp" #include "cutlass/arch/barrier.h" #include "cute/arch/cluster_sm90.hpp" #include "cutlass/arch/barrier.h" @@ -90,7 +90,7 @@ struct CollectiveSimulation { for(int i = 0; i < tma_k_prologue; ++i) { pipeline.producer_acquire(tile_start_state_pipe); // Simulating cp.async.bulk.tensor behavior - pipeline.producer_commit(tile_start_state_pipe.index(), per_cta_bytes); + pipeline.producer_commit(tile_start_state_pipe, per_cta_bytes); ++tile_start_state_pipe; } int tma_k_iter = num_iterations - tma_k_prologue; @@ -103,7 +103,7 @@ struct CollectiveSimulation { pipeline.producer_acquire(wr_pipe); // Simulating cp.async.bulk.tensor behavior - pipeline.producer_commit(wr_pipe.index(), per_cta_bytes); + pipeline.producer_commit(wr_pipe, per_cta_bytes); // Advance write stage ++wr_pipe; @@ -198,9 +198,6 @@ __global__ static void pipeline_device(KernelParams params) { extern __shared__ char shared_memory[]; - using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename cutlass::PipelineState; @@ -345,9 +342,6 @@ struct PipelineTest { } for (int iter = 0; iter < iterations; ++iter) { - - using MainloopPipeline = typename cutlass::PipelineTmaAsync; - constexpr int StagesPerMathWarpGroup = 2; constexpr int MathWarpGroupCountPersistent = 2; int smem_size = int(sizeof(SharedStorage +#include "cutlass/detail/dependent_false.hpp" + +namespace { // (anonymous) + +template +void test_dependent_bool_value() +{ + static_assert(cutlass::detail::dependent_bool_value == true); + static_assert(cutlass::detail::dependent_bool_value == false); +} + +template +void test_dependent_false() +{ + static_assert(cutlass::detail::dependent_false == false); +} + +template +void test_all() +{ + test_dependent_bool_value(); + test_dependent_false(); +} + +// Types to use in Args +struct Type0 {}; +struct Type1 {}; +struct Type2 {}; + +} // end namespace (anonymous) + +TEST(LibcudacxxNext, DependentBoolValue) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("dependent_bool_value"); + CUTLASS_TRACE_HOST("-------------------------------"); + + test_dependent_bool_value(); + test_dependent_bool_value(); + test_dependent_bool_value(); + test_dependent_bool_value(); +} + +TEST(LibcudacxxNext, DependentFalse) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("dependent_false"); + CUTLASS_TRACE_HOST("-------------------------------"); + + test_dependent_false(); + test_dependent_false(); + test_dependent_false(); + test_dependent_false(); +} diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 1471b5c3..606cf7f6 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -25,7 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +cmake_policy(SET CMP0112 NEW) add_subdirectory(util) if (CUTLASS_ENABLE_LIBRARY) diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 07447495..8745f39e 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -25,7 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +cmake_policy(SET CMP0112 NEW) include(GNUInstallDirs) find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) @@ -94,6 +94,9 @@ file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOU # set cutlass generator compiler version to filter kernels in the generator not supported by a specific toolkit. set(CUTLASS_GENERATOR_CUDA_COMPILER_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) +# --log-level is set to DEBUG to enable printing information about which kernels were excluded +# from generation in /tools/library/scripts/manifest.py. To avoid having this information appear +# in ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log, set this parameter to INFO execute_process( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/scripts COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/scripts/generator.py @@ -112,6 +115,8 @@ execute_process( ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log ) +message(STATUS "Completed generation of library instances. See ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log for more information.") + if(NOT cutlass_lib_INSTANCE_GENERATION_RESULT EQUAL 0) message(FATAL_ERROR "Error generating library instances. See ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log") endif() diff --git a/tools/library/include/cutlass/library/arch_mappings.h b/tools/library/include/cutlass/library/arch_mappings.h index 0d6790e7..a48c173e 100644 --- a/tools/library/include/cutlass/library/arch_mappings.h +++ b/tools/library/include/cutlass/library/arch_mappings.h @@ -102,6 +102,12 @@ template struct ArchMap { static int const kMax = 1024; }; +// Arch conditional WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h index 8a0dfcba..93070f31 100644 --- a/tools/library/include/cutlass/library/handle.h +++ b/tools/library/include/cutlass/library/handle.h @@ -178,7 +178,7 @@ class Handle { int K, /// GEMM K dimension NumericTypeID element_compute, /// Data type of internal accumulation - + NumericTypeID element_scalar, /// Data type of alpha/beta scalars void const *alpha, /// Pointer to alpha scalar @@ -186,29 +186,29 @@ class Handle { NumericTypeID element_A, /// Data type of A matrix elements LayoutTypeID layout_A, /// Layout of A matrix ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices - void const * ptr_A, /// Pointer to A matrix in Global Memory - int64_t lda, /// Leading dimension of A matrix + int64_t lda, /// Leading dimension of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices - void const * ptr_B, /// Pointer to B matrix in Global Memory - int64_t ldb, /// Leading dimension of B matrix + int64_t ldb, /// Leading dimension of B matrix void const * beta, /// Pointer to beta scalar - NumericTypeID element_C, /// Data type of C and D matrices - + NumericTypeID element_C, /// Data type of C matrix + LayoutTypeID layout_C, /// Layout of D matrix void const * ptr_C, /// Pointer to C matrix - int64_t ldc, /// Leading dimension of C matrix + int64_t ldc, /// Leading dimension of C matrix + NumericTypeID element_D, /// Data type of D matrix + LayoutTypeID layout_D, /// Layout of D matrix void * ptr_D, /// Pointer to D matrix - int64_t ldd, /// Leading dimension of D matrix - + int64_t ldd, /// Leading dimension of D matrix + int batch_count = 1, /// Batch count or number of split-K slices - + int64_t batch_stride_A = 0, /// Batch stride of A operand int64_t batch_stride_B = 0, /// Batch stride of B operand int64_t batch_stride_C = 0, /// Batch stride of C operand diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index dbd70c44..387765e5 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -114,6 +114,8 @@ enum class NumericTypeID { kS16, kS32, kS64, + kFE4M3, + kFE5M2, kF16, kBF16, kTF32, @@ -474,9 +476,12 @@ struct GemmDescription : public OperationDescription { /// Describes the B operand TensorDescription B; - /// Describes the source and destination matrices + /// Describes the source matrix TensorDescription C; + /// Describes the destination matrix + TensorDescription D; + /// Describes the sparse meta matrices TensorDescription E; @@ -501,6 +506,7 @@ struct GemmDescription : public OperationDescription { TensorDescription const &A = TensorDescription(), TensorDescription const &B = TensorDescription(), TensorDescription const &C = TensorDescription(), + TensorDescription const &D = TensorDescription(), NumericTypeID element_epilogue = NumericTypeID::kInvalid, SplitKMode split_k_mode = SplitKMode::kNone, ComplexTransform transform_A = ComplexTransform::kNone, @@ -510,6 +516,7 @@ struct GemmDescription : public OperationDescription { A(A), B(B), C(C), + D(D), element_epilogue(element_epilogue), split_k_mode(split_k_mode), transform_A(transform_A), @@ -527,13 +534,14 @@ struct SparseGemmDescription : public GemmDescription { TensorDescription const &A = TensorDescription(), TensorDescription const &B = TensorDescription(), TensorDescription const &C = TensorDescription(), + TensorDescription const &D = TensorDescription(), TensorDescription const &E = TensorDescription(), NumericTypeID element_epilogue = NumericTypeID::kInvalid, SplitKMode split_k_mode = SplitKMode::kNone, ComplexTransform transform_A = ComplexTransform::kNone, ComplexTransform transform_B = ComplexTransform::kNone ): - GemmDescription(gemm_kind, A, B, C, element_epilogue, split_k_mode, transform_A, transform_B) + GemmDescription(gemm_kind, A, B, C, D, element_epilogue, split_k_mode, transform_A, transform_B) {this->E = E;} }; @@ -1019,6 +1027,9 @@ struct GemmUniversalArguments { int64_t batch_stride_B; int64_t batch_stride_C; int64_t batch_stride_D; + + // Needed for some 3.x kernels + int sm_count; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/include/cutlass/library/operation_table.h b/tools/library/include/cutlass/library/operation_table.h index 037703fa..06ea28b0 100644 --- a/tools/library/include/cutlass/library/operation_table.h +++ b/tools/library/include/cutlass/library/operation_table.h @@ -66,6 +66,9 @@ struct GemmFunctionalKey { LayoutTypeID layout_B; ComplexTransform transform_B; NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; // // Methods @@ -83,7 +86,10 @@ struct GemmFunctionalKey { NumericTypeID element_B = NumericTypeID::kF16, LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, ComplexTransform transform_B = ComplexTransform::kNone, - NumericTypeID element_C = NumericTypeID::kF16 + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor ): provider(provider), gemm_kind(gemm_kind), @@ -95,7 +101,10 @@ struct GemmFunctionalKey { element_B(element_B), layout_B(layout_B), transform_B(transform_B), - element_C(element_C) + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D) { } inline @@ -111,7 +120,10 @@ struct GemmFunctionalKey { (element_B == rhs.element_B) && (layout_B == rhs.layout_B) && (transform_B == rhs.transform_B) && - (element_C == rhs.element_C); + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D); } inline @@ -137,6 +149,9 @@ std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey << " layout_B: " << to_string(k.layout_B) << "\n" << " transform_B: " << to_string(k.transform_B) << "\n" << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" << "}"; return out; @@ -157,18 +172,21 @@ struct GemmFunctionalKeyHasher { size_t operator()(GemmFunctionalKey const &key) const { IntHash hash; - return - rotl(hash(int(key.provider)), 1) ^ - rotl(hash(int(key.gemm_kind)), 2) ^ + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ rotl(hash(int(key.element_compute)), 3) ^ - rotl(hash(int(key.element_scalar)), 4) ^ - rotl(hash(int(key.element_A)), 5) ^ - rotl(hash(int(key.layout_A)), 6) ^ - rotl(hash(int(key.transform_A)), 7) ^ - rotl(hash(int(key.element_B)), 8) ^ - rotl(hash(int(key.layout_B)), 9) ^ - rotl(hash(int(key.transform_B)), 10) ^ - rotl(hash(int(key.element_C)), 11); + rotl(hash(int(key.element_scalar)), 4) ^ + rotl(hash(int(key.element_A)), 5) ^ + rotl(hash(int(key.layout_A)), 6) ^ + rotl(hash(int(key.transform_A)), 7) ^ + rotl(hash(int(key.element_B)), 8) ^ + rotl(hash(int(key.layout_B)), 9) ^ + rotl(hash(int(key.transform_B)), 10) ^ + rotl(hash(int(key.element_C)), 11) ^ + rotl(hash(int(key.layout_C)), 12) ^ + rotl(hash(int(key.element_D)), 13) ^ + rotl(hash(int(key.layout_D)), 14); } }; diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index e4c86a71..cb1075bb 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -23,7 +23,8 @@ class GemmOperation: # def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ - epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8): + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, + kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto): self.prefix = "3x" if gemm_kind == GemmKind.Universal3x else "" self.operation_kind = OperationKind.Gemm @@ -33,6 +34,15 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, self.A = A self.B = B self.C = C + self.D = D + if self.D == None: + self.D = self.C + + if gemm_kind != GemmKind.Universal3x: + assert(kernel_schedule == KernelScheduleType.ScheduleAuto) + assert(epilogue_schedule == EpilogueScheduleType.ScheduleAuto) + self.kernel_schedule = kernel_schedule + self.epilogue_schedule = epilogue_schedule self.element_epilogue = element_epilogue self.epilogue_functor = epilogue_functor self.swizzling_functor = swizzling_functor @@ -122,11 +132,12 @@ def extended_name(self): def extended_name_3x(self): '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' - extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}".format( + extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( element_a = DataTypeNames[self.A.element], element_b = DataTypeNames[self.B.element], element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], element_c = DataTypeNames[self.C.element], + element_d = DataTypeNames[self.D.element], core_name = self.core_name()) return extended_name @@ -152,12 +163,20 @@ def layout_name_3x(self): ShortLayoutTypeNames[self.B.layout], ShortLayoutTypeNames[self.C.layout]) + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name_3x(self): + return KernelScheduleSuffixes[self.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name_3x(self): + return EpilogueScheduleSuffixes[self.epilogue_schedule] + # Generates the full kernel function name def procedural_name(self): ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] if self.arch >= 90: - kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}" + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}" return kernel_name_template.format( p = self.prefix, ar = self.arch, @@ -171,7 +190,9 @@ def procedural_name(self): ck = self.tile_description.cluster_shape[2], l = self.tile_description.stages, s = self.layout_name_3x(), - al = str(max(self.A.alignment, self.B.alignment))) + al = str(max(self.A.alignment, self.B.alignment)), + k = self.kernel_schedule_name_3x(), + e = self.epilogue_schedule_name_3x()) else: threadblock = self.tile_description.procedural_name() return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( @@ -604,8 +625,7 @@ def __init__(self, operation_suffix = ''): "cutlass/numeric_types.h", "cutlass/gemm/kernel/gemm_universal.hpp", "cutlass/gemm/collective/collective_builder.hpp", - "cutlass/epilogue/collective/default_epilogue.hpp", - "cutlass/epilogue/thread/linear_combination.h", + "cutlass/epilogue/collective/collective_builder.hpp", ] self.builtin_epilogue_functor_template = """ ${epilogue_functor}< @@ -617,6 +637,18 @@ def __init__(self, operation_suffix = ''): """ self.gemm_template = """ +using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule} + >::CollectiveOp; + using ${operation_name}_mainloop = typename cutlass::gemm::collective::CollectiveBuilder< ${arch}, ${opcode_class}, @@ -625,18 +657,11 @@ def __init__(self, operation_suffix = ''): ${element_accumulator}, cute::Shape, cute::Shape, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::collective::StageCountAutoCarveout< + sizeof(typename ${operation_name}_epilogue::SharedStorage)>, + ${kernel_schedule} >::CollectiveOp; -using ${operation_name}_epilogue = - cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t<${layout_c}>, - cutlass::gemm::TagToStrideC_t<${layout_c}>, - cutlass::epilogue::thread::LinearCombination< - ${element_c}, ${epilogue_vector_length}, ${element_accumulator}, ${element_epilogue}> - >; - // Gemm operator ${operation_name} using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< cute::Shape, @@ -670,8 +695,8 @@ def emit(self, operation): stage_count_string = "cutlass::gemm::collective::StageCountAuto" warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - instance_layout_A, instance_layout_B, instance_layout_C = \ - (operation.A.layout, operation.B.layout, operation.C.layout) + instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \ + (operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout) # 3.0 profiler integration only supports trivial epilogues for now epilogue_vector_length = 1 @@ -697,6 +722,8 @@ def emit(self, operation): 'layout_b': LayoutTag[instance_layout_B], 'element_c': DataTypeTag[operation.C.element], 'layout_c': LayoutTag[instance_layout_C], + 'element_d': DataTypeTag[operation.D.element], + 'layout_d': LayoutTag[instance_layout_D], 'element_accumulator': DataTypeTag[operation.accumulator_type()], 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], 'arch': "cutlass::arch::Sm%d" % operation.arch, @@ -712,10 +739,14 @@ def emit(self, operation): 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'kernel_schedule' : str(KernelScheduleTag[operation.kernel_schedule]), + 'epilogue_schedule' : str(EpilogueScheduleTag[operation.epilogue_schedule]), 'epilogue_functor': epilogue_functor, 'stages': stage_count_string, 'align_a': str(operation.A.alignment), 'align_b': str(operation.B.alignment), + 'align_c': str(operation.C.alignment), + 'align_d': str(operation.C.alignment), 'transform_a': ComplexTransformTag[operation.A.complex_transform], 'transform_b': ComplexTransformTag[operation.B.complex_transform], 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 77a5138b..174f33df 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -92,6 +92,7 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ # Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts def CreateGemmUniversal3xOperator( manifest, layouts, tile_descriptions, data_type, + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]], complex_transforms=None, epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=SwizzlingFunctor.Identity1): @@ -99,7 +100,12 @@ def CreateGemmUniversal3xOperator( if complex_transforms is None: complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ] - element_a, element_b, element_c, element_epilogue = data_type + element_a = data_type["a_type"] + element_b = data_type["b_type"] + element_c = data_type["c_type"] + element_d = data_type["d_type"] + element_acc = data_type["acc_type"] + element_epilogue = data_type.get("epi_type", element_acc) operations = [] @@ -110,18 +116,22 @@ def CreateGemmUniversal3xOperator( for layout in layouts: for tile_description in tile_descriptions: for complex_transform in complex_transforms: - A = TensorDescription( - element_a, layout[0][0], layout[0][1], complex_transform[0]) - B = TensorDescription( - element_b, layout[1][0], layout[1][1], complex_transform[1]) - C = TensorDescription(element_c, layout[2][0], layout[2][1]) + for kernel_schedule, epilogue_schedule in schedules: + A = TensorDescription( + element_a, layout[0][0], layout[0][1], complex_transform[0]) + B = TensorDescription( + element_b, layout[1][0], layout[1][1], complex_transform[1]) + + C = TensorDescription(element_c, layout[2][0], layout[2][1]) + D = TensorDescription(element_d, layout[2][0], layout[2][1]) - operation = GemmOperation( - GemmKind.Universal3x, tile_description.minimum_compute_capability, - tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + operation = GemmOperation( + GemmKind.Universal3x, tile_description.minimum_compute_capability, + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, D, + kernel_schedule, epilogue_schedule) - manifest.append(operation) - operations.append(operation) + manifest.append(operation) + operations.append(operation) return operations @@ -136,7 +146,7 @@ def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, \ element_a, element_b, element_c, element_epilogue = data_type gemm_kinds = [GemmKind.Sparse] - + operations = [] # by default, only generate the largest tile and largest alignment @@ -148,9 +158,9 @@ def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, \ for tile_description in tile_descriptions: for alignment in alignment_constraints: for complex_transform in complex_transforms: - + alignment_c = min(8, alignment) - + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) C = TensorDescription(element_c, layout[2], alignment_c) @@ -173,12 +183,12 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t element_a, element_b, element_c, element_epilogue = data_type gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray] - + # by default, only generate the largest tile and largest alignment if manifest.kernel_filter == '': tile_descriptions = [tile_descriptions[0],] alignment_constraints = [alignment_constraints[0],] - + for gemm_kind in gemm_kinds: for layout in layouts: for tile_description in tile_descriptions: @@ -238,7 +248,7 @@ def CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, data_t swizzling_functor = SwizzlingFunctor.Identity8): element_a, element_c, element_epilogue = data_type - + operations = [] # by default, only generate the largest tile and largest alignment @@ -250,7 +260,7 @@ def CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, data_t for fill_mode in fill_modes: for tile_description in tile_descriptions: for alignment in alignment_constraints: - + # SERK supported layouts (RowMajor, ColumnMajor) with no conjugation complex_transform = ComplexTransform.none @@ -259,7 +269,7 @@ def CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, data_t complex_transform = ComplexTransform.conj alignment_c = 1 # Alignment only applies to A in SYRK - + A = TensorDescription(element_a, layout[0], alignment, complex_transform) C = SymmetricTensorDescription(element_c, layout[1], fill_mode, alignment_c) @@ -269,7 +279,7 @@ def CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, data_t manifest.append(new_operation) operations.append(new_operation) - + # Rank-2K update new_operation = Rank2KOperation(RankKKind.Universal, tile_description.minimum_compute_capability, \ tile_description, A, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) @@ -288,7 +298,7 @@ def CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, ti complex_transforms = [(ComplexTransform.none),] element_a, element_b, element_c, element_epilogue = data_type - + operations = [] # by default, only generate the largest tile and largest alignment @@ -303,9 +313,9 @@ def CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, ti for tile_description in tile_descriptions: for alignment in alignment_constraints: for complex_transform in complex_transforms: - + alignment_c = min(8, alignment) - + A = TriangularTensorDescription(element_a, layout[0], side_mode, fill_mode, diag_type, alignment, complex_transform) B = TensorDescription(element_b, layout[1], alignment) @@ -325,7 +335,7 @@ def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descripti swizzling_functor = SwizzlingFunctor.Identity8): element_a, element_b, element_c, element_epilogue = data_type - + operations = [] # by default, only generate the largest tile and largest alignment @@ -338,13 +348,13 @@ def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descripti for fill_mode in fill_modes: for tile_description in tile_descriptions: for alignment in alignment_constraints: - + # SYMM supported layouts (RowMajor, ColumnMajor) with no conjugation complex_transform = ComplexTransform.none alignment_a = 1 # No vectorized access for the triangular matrix alignment_c = min(8, alignment) - + A = SymmetricTensorDescription(element_a, layout[0], fill_mode, alignment_a, complex_transform, side_mode) # tensor A and B have same data type and layout B = TensorDescription(element_b, layout[0], alignment) @@ -356,7 +366,7 @@ def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descripti manifest.append(new_operation) operations.append(new_operation) - + # SYMM/HEMM update new_operation = SymmOperation(SymmKind.Universal, tile_description.minimum_compute_capability, \ tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) @@ -382,11 +392,11 @@ def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descripti def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): - + element_a, element_b, element_c, element_epilogue = data_type - + # one exceptional case - + # iterator algorithm (analytic and optimized) iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] @@ -406,14 +416,14 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme A = TensorDescription(element_a, layout[0], alignment) B = TensorDescription(element_b, layout[1], alignment) C = TensorDescription(element_c, layout[2], alignment_c) - + swizzling_functor_ = swizzling_functor - + # # Conv2d Fprop # if ConvKind.Fprop in conv_kinds: - + # Strided support for Analytic and Optimized Fprop for iterator_algorithm in iterator_algorithms: new_operations = [ @@ -437,51 +447,51 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme for new_operation in new_operations: manifest.append(new_operation) operations.append(new_operation) - + # # Conv2d Dgrad # if ConvKind.Dgrad in conv_kinds: - + # Unity stride for Analytic and Optimized Dgrad for iterator_algorithm in iterator_algorithms: new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_) - + manifest.append(new_operation) operations.append(new_operation) - + # Strided support for Analytic Dgrad # strided dgrad uses a special threadblock swizzle - # note that SwizzlingFunctor.StridedDgradHorizontal might be + # note that SwizzlingFunctor.StridedDgradHorizontal might be # better for problem sizes with large activation channel count swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1 - + if IteratorAlgorithm.Analytic in iterator_algorithms: new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) - + manifest.append(new_operation) operations.append(new_operation) - + # Strided support for Optimized Dgrad if IteratorAlgorithm.Optimized in iterator_algorithms: new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) - + manifest.append(new_operation) operations.append(new_operation) - + # # Conv2d Wgrad # if ConvKind.Wgrad in conv_kinds: - + # Strided support for Analytic and Optimized Wgrad for iterator_algorithm in iterator_algorithms: new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) - + manifest.append(new_operation) operations.append(new_operation) @@ -582,12 +592,12 @@ def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_ty # Convolution for 3D operations def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment, \ conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination): - + element_a, element_b, element_c, element_epilogue = data_type - + # one exceptional case alignment_c = min(8, alignment) - + # iterator algorithm (analytic and optimized) iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] @@ -603,7 +613,7 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme A = TensorDescription(element_a, layout, alignment) B = TensorDescription(element_b, layout, alignment) C = TensorDescription(element_c, layout, alignment_c) - + # # Conv3d Fprop # @@ -618,7 +628,7 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme # Conv3d Wgrad # if ConvKind.Wgrad in conv_kinds: - + # Strided support for Analytic and Optimized Wgrad for iterator_algorithm in iterator_algorithms: new_operation = Conv3dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ @@ -628,11 +638,11 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme # All tile sizes for Conv3dDgrad for tile in tile_descriptions: - + A = TensorDescription(element_a, layout, alignment) B = TensorDescription(element_b, layout, alignment) C = TensorDescription(element_c, layout, alignment_c) - + # # Conv3d Dgrad # @@ -640,15 +650,15 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme # Unity stride for Optimized Dgrad new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor) - + manifest.append(new_operation) operations.append(new_operation) - - # Strided support for Analytic Dgrad - # Conv3dDgrad has a naive strided support which does not cut down redundant MMAs + + # Strided support for Analytic Dgrad + # Conv3dDgrad has a naive strided support which does not cut down redundant MMAs new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) - + manifest.append(new_operation) operations.append(new_operation) @@ -658,9 +668,9 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): - + element_a, element_b, element_c, element_epilogue = data_type - + # iterator algorithm (FixedStrideDilation, Optimized) iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized] @@ -679,11 +689,11 @@ def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type A = TensorDescription(element_a, layout[0], alignment) B = TensorDescription(element_b, layout[1], alignment) C = TensorDescription(element_c, layout[2], alignment_c) - + swizzling_functor_ = swizzling_functor if ConvKind.Fprop in conv_kinds: - + # Strided support for Optimized and FixedStridedDilation Depthwise Conv for iterator_algorithm in iterator_algorithms: stride_support = StrideSupport.Strided @@ -694,16 +704,16 @@ def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type if iterator_algorithm == IteratorAlgorithm.Optimized: if tile.stride != [-1, -1] or tile.dilation != [-1,-1]: - continue - new_operation = Conv2dOperation(ConvKind.Fprop, - iterator_algorithm, - tile.minimum_compute_capability, + continue + new_operation = Conv2dOperation(ConvKind.Fprop, + iterator_algorithm, + tile.minimum_compute_capability, tile, - A, B, C, - element_epilogue, - stride_support, - epilogue_functor, - swizzling_functor_, + A, B, C, + element_epilogue, + stride_support, + epilogue_functor, + swizzling_functor_, group_mode=GroupMode.Depthwise) manifest.append(new_operation) @@ -757,7 +767,7 @@ def GenerateSM50_Simt(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -805,7 +815,7 @@ def GenerateSM50_Simt_complex(manifest, cuda_version): DataType.cf32, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -861,7 +871,7 @@ def GenerateSM60_Simt(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) # @@ -874,7 +884,7 @@ def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version): OpcodeClass.Simt, \ MathOperation.multiply_add), ] - + min_cc = 60 max_cc = 1024 @@ -904,18 +914,18 @@ def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version): for math_inst in math_instructions: for stride, dilation in product(strides, dilations): tile_descriptions.extend([ - # filter3x3 ThreadBlock_output, filter, stage, warp + # filter3x3 ThreadBlock_output, filter, stage, warp Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_3x3, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), - + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_3x3, 4, stride, dilation,[4, 1, 1], math_inst, min_cc, max_cc), Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), - # filter5x5 ThreadBlock_output, filter, stage, warp + # filter5x5 ThreadBlock_output, filter, stage, warp Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), @@ -990,7 +1000,7 @@ def GenerateSM61_Simt(manifest, cuda_version): math_inst.element_a, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -1054,7 +1064,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -1073,7 +1083,7 @@ def GenerateSM70_TensorOp_884(manifest, cuda_version): CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints) - + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) # @@ -1125,7 +1135,7 @@ def GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, complex_transforms) @@ -1141,7 +1151,7 @@ def GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version): CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, complex_transforms) - + # def GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version): @@ -1185,7 +1195,7 @@ def GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -1308,7 +1318,7 @@ def GenerateSM75_TensorOp_1688(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -1387,7 +1397,7 @@ def GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, complex_transforms) @@ -1439,24 +1449,25 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 256, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 128, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), ] data_type = [ @@ -1465,7 +1476,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): math_inst.element_accumulator, DataType.s32, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) @@ -1487,7 +1498,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) @@ -1551,7 +1562,7 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version): math_inst.element_a, DataType.f32, ] - + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) @@ -1609,10 +1620,10 @@ def GenerateSM75_TensorOp_8832_TN(manifest, cuda_version): math_inst.element_accumulator, DataType.s32, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) @@ -1631,7 +1642,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, cuda_version): operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) @@ -1696,12 +1707,12 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version): data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) - + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) for op in operations: - op.C.alignment = 16 + op.C.alignment = 16 # # @@ -1722,7 +1733,7 @@ def GenerateSM75_TensorOp_88128(manifest, cuda_version): MathOperation.xor_popc), ] - min_cc = 75 + min_cc = 75 max_cc = 1024 alignment_constraints = [128,] @@ -1785,7 +1796,7 @@ def GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version): math_inst.element_accumulator, DataType.f32, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -1923,7 +1934,7 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -2011,7 +2022,7 @@ def GenerateSM80_SparseTensorOp_16832(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -2086,7 +2097,7 @@ def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, complex_transforms) @@ -2152,7 +2163,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] - + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] @@ -2162,12 +2173,12 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) - + operations = [] operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) @@ -2176,7 +2187,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - + for op in operations: if op.tile_description.threadblock_shape[1] >= 128: op.C.alignment = 16 @@ -2279,17 +2290,17 @@ def GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version): TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), ] - + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] - + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - + for op in operations: op.C.alignment = 8 # @@ -2341,25 +2352,25 @@ def GenerateSM80_TensorOp_16864_TN(manifest, cuda_version): TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] - + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) - + operations = [] - + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) - + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - + for op in operations: if op.tile_description.threadblock_shape[1] >= 128: op.C.alignment = 16 @@ -2459,21 +2470,21 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version): TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), ] - + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] - + operations = [] - + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) - + conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) - + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) - + for op in operations: - op.C.alignment = 16 + op.C.alignment = 16 # # @@ -2495,7 +2506,7 @@ def GenerateSM80_TensorOp_168256(manifest, cuda_version): ] min_cc = 80 - max_cc = { + max_cc = { MathOperation.xor_popc: 1024 } @@ -2562,7 +2573,7 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version): TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2647,7 +2658,7 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version): TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2883,7 +2894,7 @@ def GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version): FillMode.Lower, FillMode.Upper, ] - math_instructions = [ + math_instructions = [ MathInstruction( \ [16, 8, 8], \ DataType.tf32, DataType.tf32, DataType.f32, \ @@ -2942,7 +2953,7 @@ def GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version): FillMode.Lower, FillMode.Upper, ] - math_instructions = [ + math_instructions = [ MathInstruction( \ [16, 8, 8], \ DataType.tf32, DataType.tf32, DataType.f32, \ @@ -3005,7 +3016,7 @@ def GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version): DiagType.NonUnit, DiagType.Unit, ] - math_instructions = [ + math_instructions = [ MathInstruction( \ [16, 8, 8], \ DataType.tf32, DataType.tf32, DataType.f32, \ @@ -3021,7 +3032,7 @@ def GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version): min_cc = 80 max_cc = 1024 - alignment_constraints = [1, 2, 4] + alignment_constraints = [1, 2, 4] for math_inst in math_instructions: tile_descriptions = [ @@ -3072,7 +3083,7 @@ def GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version): DiagType.NonUnit, DiagType.Unit, ] - math_instructions = [ + math_instructions = [ MathInstruction( \ [16, 8, 8], \ DataType.tf32, DataType.tf32, DataType.f32, \ @@ -3130,7 +3141,7 @@ def GenerateSM80_TensorOp_1688_symm(manifest, cuda_version): FillMode.Lower, FillMode.Upper, ] - math_instructions = [ + math_instructions = [ MathInstruction( \ [16, 8, 8], \ DataType.tf32, DataType.tf32, DataType.f32, \ @@ -3148,7 +3159,7 @@ def GenerateSM80_TensorOp_1688_symm(manifest, cuda_version): alignment_constraints = [ 1, 2, 4 - ] + ] for math_inst in math_instructions: tile_descriptions = [ @@ -3194,7 +3205,7 @@ def GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version): FillMode.Lower, FillMode.Upper, ] - math_instructions = [ + math_instructions = [ MathInstruction( \ [16, 8, 8], \ DataType.tf32, DataType.tf32, DataType.f32, \ @@ -3396,7 +3407,7 @@ def GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version): ] fill_modes = [ - FillMode.Lower, FillMode.Upper, + FillMode.Lower, FillMode.Upper, ] math_inst = \ @@ -3696,7 +3707,7 @@ def GenerateSM80_TensorOp_884_symm(manifest, cuda_version): ] fill_modes = [ - FillMode.Lower, FillMode.Upper, + FillMode.Lower, FillMode.Upper, ] math_inst = \ @@ -3878,7 +3889,7 @@ def GenerateSM80_Simt_f32(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) @@ -3925,7 +3936,7 @@ def GenerateSM80_Simt_f64(manifest, cuda_version): math_inst.element_accumulator, math_inst.element_accumulator, ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) # @@ -3953,7 +3964,7 @@ def GenerateSM80_Simt_complex(manifest, cuda_version): DataType.cf32, DataType.cf32 ] - + layouts = [ (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), @@ -3980,7 +3991,7 @@ def GenerateSM80_Simt_complex(manifest, cuda_version): TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] - + CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms) conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) @@ -3998,7 +4009,7 @@ def GenerateSM80(manifest, cuda_version): GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version) GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version) GenerateSM80_TensorOp_1688_complex(manifest, cuda_version) - # 3xTF32 + # 3xTF32 GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version) GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version) GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version) @@ -4068,41 +4079,93 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): for math_inst in math_instructions: tile_descriptions = [ - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + #TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + # 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - Not compatible with TmaWarpSpecializedCooperative + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + #TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + # 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),- Not compatible with TmaWarpSpecializedCooperative + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + #TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + # 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - Not compatible with TmaWarpSpecializedCooperative ] - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + # Set alignment c based on Destination format. + for layout in layouts: + if data_type["c_type"] in [DataType.s32, DataType.f32]: + layout[2][1] = 4 + elif data_type["c_type"] in [DataType.f16, DataType.bf16]: + layout[2][1] = 8 + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + kernel_schedules = [ + KernelScheduleType.ScheduleAuto, + KernelScheduleType.TmaWarpSpecializedCooperative, + KernelScheduleType.TmaWarpSpecializedPingpong, + KernelScheduleType.TmaWarpSpecialized + ] + else: + kernel_schedules = [ + KernelScheduleType.ScheduleAuto, + KernelScheduleType.TmaWarpSpecialized + # TmaWarpSpecializedCooperative and TmaWarpSpecializedPingpong require CUDA version >= 12.1 for optimal performance. + ] + + schedules = [[s, EpilogueScheduleType.ScheduleAuto] for s in kernel_schedules] - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules) + + # persistent kernels with TMA epilogues + if data_type["c_type"] in [DataType.f16, DataType.bf16] and CudaToolkitVersionSatisfies(cuda_version, 12, 1): + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) # for mixed precision kernels, also generate kernels that write output matrix in the A/B format # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) if math_inst.element_a != math_inst.element_accumulator: - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] + data_type_mixed = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + # Set alignment c based on Destination format. + for layout in layouts: + if data_type_mixed["c_type"] in [DataType.s32, DataType.f32]: + layout[2][1] = 4 + elif data_type_mixed["c_type"] in [DataType.f16, DataType.bf16]: + layout[2][1] = 8 CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed) - + # persistent kernels with TMA epilogues + if data_type_mixed["c_type"] in [DataType.f16, DataType.bf16] and CudaToolkitVersionSatisfies(cuda_version, 12, 1): + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, + [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) # def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): @@ -4111,10 +4174,10 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): # layouts for ABC and their alignments layouts_tf32 = [ - [[LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.ColumnMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], - [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], ] math_inst = MathInstruction( @@ -4127,46 +4190,62 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): max_cc = 90 tile_descriptions = [ - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] - data_type_tf32 = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGemmUniversal3xOperator(manifest, layouts_tf32, tile_descriptions, data_type_tf32) - - # F32 kernel, TN only supported for now - layouts_f32 = [layouts_tf32[2]] - - data_type_f32 = [ - DataType.f32, - DataType.f32, - math_inst.element_accumulator, - DataType.f32, - ] - - CreateGemmUniversal3xOperator(manifest, layouts_f32, tile_descriptions, data_type_f32) + data_type_tf32 = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + # TMA kernels with TN or NN layout + layouts_tf32_tn_nn = [layouts_tf32[0], layouts_tf32[2]] + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_tf32) + + # TMA kernels with NT layout, only support 64x128x32 tile for now. + layouts_tf32_nt = [layouts_tf32[3]] + tile_64x128x32_descriptions = [tile_descriptions[0], tile_descriptions[1], tile_descriptions[2]] + CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_tf32) + + # TMA kernels with TT layout use EpilogueTransposed, because swapping NN kernel and transposed its epilogue will get the kernel + layouts_tf32_tt = [layouts_tf32[1]] + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_tf32, + [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.EpilogueTransposed]]) + + # F32 kernel share same settings with tf32 I/O kernels excluding data type + data_type_f32 = { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_type_f32) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_64x128x32_descriptions, data_type_f32) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_type_f32, + [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.EpilogueTransposed]]) +# def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - + # layouts for ABC and their alignments layouts = [ [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], @@ -4190,28 +4269,41 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): for math_inst in math_instructions: tile_descriptions = [ - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([128, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), - TileDescription([ 64, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.s8, + "d_type" : DataType.s8, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } ] - CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type) + for data_type in data_types: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type) # def GenerateSM90_TensorOp_1684(manifest, cuda_version): @@ -4376,7 +4468,7 @@ def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version): ] fill_modes = [ - FillMode.Lower, FillMode.Upper, + FillMode.Lower, FillMode.Upper, ] math_inst = \ @@ -4676,7 +4768,7 @@ def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version): ] fill_modes = [ - FillMode.Lower, FillMode.Upper, + FillMode.Lower, FillMode.Upper, ] math_inst = \ @@ -4835,14 +4927,32 @@ def GenerateSM90(manifest, cuda_version): ################################################################################################### -if __name__ == "__main__": +def numeric_log_level(log_level: str) -> int: + """ + Converts the string identifier of the log level into the numeric identifier used + in setting the log level + :param x: string representation of log level (e.g., 'INFO', 'DEBUG') + :type x: str + + :return: numeric representation of log level + :rtype: int + """ + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f'Invalid log level: {log_level}') + return numeric_level + + +# This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface +# to leverage the functionality in this file without running this script via a shell prompt. +def define_parser(): parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels") parser.add_argument("--operations", default="all", help="Specifies the operation to generate (gemm, all)") parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory") parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory") parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") - parser.add_argument("--architectures", default='53;60;61;70;75;80', help="Target compute architectures") + parser.add_argument("--architectures", default='53;60;61;70;75;80;90', help="Target compute architectures") parser.add_argument("--kernels", default='', help='Comma delimited list to filter kernels by name.') parser.add_argument("--ignore-kernels", default='', help='Comma delimited list of kernels to exclude from build.') parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.') @@ -4852,26 +4962,13 @@ def GenerateSM90(manifest, cuda_version): help='Specify the output log file containing all enabled kernels in this build') parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels") parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every archs in --architectures") - - def numeric_log_level(log_level: str) -> int: - """ - Converts the string identifier of the log level into the numeric identifier used - in setting the log level - - :param x: string representation of log level (e.g., 'INFO', 'DEBUG') - :type x: str - - :return: numeric representation of log level - :rtype: int - """ - numeric_level = getattr(logging, log_level.upper(), None) - if not isinstance(numeric_level, int): - raise ValueError(f'Invalid log level: {log_level}') - return numeric_level - parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, help='Logging level to be used by the generator script') + return parser + +if __name__ == "__main__": + parser = define_parser() args = parser.parse_args() # Set the logging level based on the user-provided `--log-level` command-line option @@ -4886,7 +4983,6 @@ def numeric_log_level(log_level: str) -> int: GenerateSM75(manifest, args.cuda_version) GenerateSM80(manifest, args.cuda_version) GenerateSM90(manifest, args.cuda_version) - if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index 6919479e..b12de786 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -361,6 +361,58 @@ class LayoutType(enum.Enum): (LayoutType.RowMajor, ComplexTransform.conj): 'h' } +################################################################################################### +class KernelScheduleType(enum.Enum): + ScheduleAuto = enum_auto() + Multistage = enum_auto() + Tma = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() +# +KernelScheduleTag = { + KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', + KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', + KernelScheduleType.Tma: 'cutlass::gemm::KernelTma', + KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', + KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', + KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative', +} + +# +KernelScheduleSuffixes = { + KernelScheduleType.ScheduleAuto: '', + KernelScheduleType.Multistage: '_cpasync', + KernelScheduleType.Tma: '_unspecialized', + KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', + KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative', +} + +class EpilogueScheduleType(enum.Enum): + ScheduleAuto = enum_auto() + EpilogueTransposed = enum_auto() + NoSmemWarpSpecialized = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() +# +EpilogueScheduleTag = { + EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto', + EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed', + EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', + EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', + EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', +} + +# +EpilogueScheduleSuffixes = { + EpilogueScheduleType.ScheduleAuto: '', + EpilogueScheduleType.EpilogueTransposed: '', + EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', + EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', +} + ################################################################################################### # diff --git a/tools/library/scripts/pycutlass/README.md b/tools/library/scripts/pycutlass/README.md deleted file mode 100644 index 8d0dbaef..00000000 --- a/tools/library/scripts/pycutlass/README.md +++ /dev/null @@ -1,143 +0,0 @@ -# PyCUTLASS: CUTLASS Python Interface - -PyCUTLASS is a python interface of CUTLASS C++ template library. PyCUTLASS takes user-defined operation descriptions, emits C++ code, and compiles it with `nvcc` or `nvrtc`. It also provides wrappers for user-provide arguments from [numpy](https://numpy.org/), [torch](https://pytorch.org/), and [cupy](https://github.com/cupy/cupy) and encode them to kernel's parameters. - -```python -import pycutlass -from pycutlass import * -import torch - -pycutlass.get_memory_pool(2**8, 2**32) - -math_inst = MathInstruction( - [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32, - cutlass.OpClass.Simt, MathOperation.multiply_add -) - -tile_description = TileDescription( - [128, 128, 8], 4, [2, 4, 1], - math_inst -) - -A = TensorDescription( - cutlass.float32, cutlass.RowMajor, 1 -) - -B = TensorDescription( - cutlass.float32, cutlass.RowMajor, 1 -) - -C = TensorDescription( - cutlass.float32, cutlass.RowMajor, 1 -) - -epilogue_functor = LinearCombination(cutlass.float32, 1, cutlass.float32, cutlass.float32) - -operation = GemmOperationUniversal( - arch=80, tile_description=tile_description, - A=A, B=B, C=C, - epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 -) - -pycutlass.compiler.add_module([operation,]) - -problem_size = cutlass.gemm.GemmCoord(512, 256, 128) - -tensor_A = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.k()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) -tensor_B = torch.ceil(torch.empty(size=(problem_size.k(), problem_size.n()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) -tensor_C = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.n()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) -tensor_D = torch.empty_like(tensor_C) - - -alpha = 1.0 -beta = 0.0 - -arguments = GemmArguments( - operation=operation, problem_size=problem_size, - A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=operation.epilogue_type(alpha, beta), - gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 -) - -operation.run(arguments) - -arguments.sync() - -tensor_D_ref = alpha * tensor_A @ tensor_B + beta * tensor_C - -assert torch.equal(tensor_D, tensor_D_ref) -``` -PyCUTLASS also provides infrastructures for profiling, compiled artifact management, and pool memory manager - -## Supported Features -PyCUTLASS currently supports following operations: -* GEMM with mode {Serial, Parallel Split K, Batched GEMM, Array GEMM}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor, Row/ColumnMajorInterleaved<32> for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, swizzling functions {IdentitySwizzle<1,2,4,8>, HorizontalSwizzle, BatchedIdentitySwizzle}, and epilogue {LinearCombination, LinearCombinationClamp} -* GEMM grouped with op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {RowMajor, ColumnMajor}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, scheduling mode {Host, Device}, and epilogue {LinearCombination, LinearCombinationClamp}. -* Conv2d with {Fprop, Dgrad, Wgrad}, op class {SIMT, TensorCore}, data type {int8, f16, bf16, f32, f64}, layout {Tensor NHWC, TensorNC32HW32 and TensorC32RSK for int8}, math operation {MultiplyAdd, MultiplyAddFastF16, MultiplyAddFastBF16, MultiplyAddFastF32}, split-k mode {Parallel, Serial}, and epilogue {LinearCombination, LinearCombinationClamp} - -The tiling size of above operations can also be customized. - -## Installation - -### Using Docker -We recommend using one of our provided Docker images for using PyCUTLASS. - -**To run CUTLASS 3 GEMM kernels targeting the NVIDIA Hopper architecture via PyCUTLASS,** you can use an included [Dockerfile](docker/Dockerfile-cuda12.0) based on the NGC CUDA 12.0 container: -```shell -docker build -t pycutlass-cuda12.0:latest -f docker/Dockerfile-cuda12.0 . -docker run --gpus all -it --rm pycutlass-cuda12.0:latest -``` -Note that this Docker container does not include CuPy or PyTorch, and, thus, will not be able to run PyCUTLASS examples that -leverage these packages. - -**To run CUTLASS 2.x kernels targeting pre-SM90 architectures via PyCUTLASS,** you can use an included [Dockerfile](docker/Dockerfile-cuda11.8-pytorch) based on an NGC PyTorch container: -```shell -docker build -t pycutlass-cuda11.8-pytorch:latest -f docker/Dockerfile-cuda11.8-pytorch . -docker run --gpus all -it --rm pycutlass-cuda11.8-pytorch:latest -``` - -### Environment variables -PyCUTLASS requires two environment variables: -* `CUTLASS_PATH`: the root directory of CUTLASS. You can set this from the location at which you cloned CUTLASS via: `export CUTLASS_PATH=$(pwd)`. -* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed. If running in bash with `nvcc` installed under a CUDA toolkit, you can set this to the location of your `nvcc` installation via: `export CUDA_INSTALL_PATH=$(which nvcc | awk -F'/bin/nvcc' '{print $1}')` - -After setting these two environment variables, PyCUTLASS can be installed with -```shell -cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh -``` - -## Examples -Examples can be found in [$CUTLASS_PATH/examples/40_cutlass_py](examples/40_cutlass_py) - -## Test -The test cases are listed in `$CUTLASS_PATH//tools/library/scripts/pycutlass/test`. The unit test can be run with -```shell -# Each of these tests are only supported on devices with compute capability of SM80. For other devices, -# see the basic examples in $CUTLASS_PATH/examples/40_cutlass_py -cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/unit && python test_sm80.py -cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && bash run_all_example.sh -``` - -## build documentation -Run -```shell -bash build_doc.sh -``` - - -## Troubleshooting - -### Issue 1: permission denied -Building PyCUTLASS requires installing dependencies to python. So conda could an option if you don't have permission. - -### Issue 2: rmm: module not found -PyCUTLASS manages the device memory with [RMM](https://github.com/rapidsai/rmm). Our `build.sh` automatically pull the [rmm branch-22.08](https://github.com/rapidsai/rmm/tree/branch-22.08) from github and build it from source. The rmm is allocated at `$CUTLASS_PATH/tools/library/scripts/pycutlass/rmm`. It requires `cmake > 3.20.1`. If the build fails, it can be manually fixed with the following steps: -```shell -cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm && ./build.sh librmm rmm - -cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm/python -python setup.py build_ext --inplace -python setup.py install -``` -To test whether rmm is successfully installed, try `import rmm`. For other issues related to rmm, please check https://github.com/rapidsai/rmm/issues. diff --git a/tools/library/scripts/pycutlass/docker/Dockerfile-cuda12.0 b/tools/library/scripts/pycutlass/docker/Dockerfile-cuda12.0 deleted file mode 100644 index f81d79d0..00000000 --- a/tools/library/scripts/pycutlass/docker/Dockerfile-cuda12.0 +++ /dev/null @@ -1,46 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -FROM nvcr.io/nvidia/cuda:12.0.0-devel-ubuntu20.04 - -RUN apt-get update -RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata -RUN apt-get install -y git cmake vim python3 python3-pip -RUN ln -s /usr/bin/python3 /usr/bin/python -RUN chmod ugo+rwx /home -RUN pip install numpy==1.23 -RUN pip install cudf-cu11 dask-cudf-cu11 --extra-index-url=https://pypi.ngc.nvidia.com -RUN pip install cuml-cu11 --extra-index-url=https://pypi.ngc.nvidia.com -RUN pip install cugraph-cu11 --extra-index-url=https://pypi.ngc.nvidia.com -ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH -ENV LIBRARY_PATH=/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu/:$LIBRARY_PATH -ENV CUDA_INSTALL_PATH=/usr/local/cuda diff --git a/tools/library/scripts/pycutlass/docs/source/conf.py b/tools/library/scripts/pycutlass/docs/source/conf.py deleted file mode 100644 index 25c601b0..00000000 --- a/tools/library/scripts/pycutlass/docs/source/conf.py +++ /dev/null @@ -1,96 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) - - -# -- Project information ----------------------------------------------------- - -project = 'PyCutlass' -copyright = '2022, Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall' -author = 'Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall' - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - 'sphinx.ext.duration', - 'sphinx.ext.doctest', - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'enum_tools.autoenum', - 'sphinx.ext.autosummary', - 'm2r2' -] - -source_suffix = [".rst", ".md"] - -autosummary_generate = True -autosummary_imported_members = True - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'bizstyle' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ['_static'] diff --git a/tools/library/scripts/pycutlass/docs/source/conv2d_op.rst b/tools/library/scripts/pycutlass/docs/source/conv2d_op.rst deleted file mode 100644 index 7ce0510d..00000000 --- a/tools/library/scripts/pycutlass/docs/source/conv2d_op.rst +++ /dev/null @@ -1,13 +0,0 @@ -CONV2D Operation -================ - -.. autoclass:: pycutlass.Conv2dOperation - :special-members: - :members: run - :exclude-members: __weakref__, configuration_name, core_name, extended_name, procedural_name - -.. autoclass:: pycutlass.Conv2dArguments - :special-members: - :members: - :exclude-members: initialize - :show-inheritance: diff --git a/tools/library/scripts/pycutlass/docs/source/cutlass.rst b/tools/library/scripts/pycutlass/docs/source/cutlass.rst deleted file mode 100644 index 43c13e5e..00000000 --- a/tools/library/scripts/pycutlass/docs/source/cutlass.rst +++ /dev/null @@ -1,100 +0,0 @@ -cutlass -======= - -.. rubric:: Operator Classification - -.. autoclass:: cutlass.OpClass - :members: - -.. rubric:: GEMM Layout - -.. autoclass:: cutlass.RowMajor - :members: - -.. autoclass:: cutlass.ColumnMajor - :members: - -.. autoclass:: cutlass.RowMajorInterleaved32 - :members: - -.. autoclass:: cutlass.ColumnMajorInterleaved32 - :members: - -.. rubric:: Conv Layout - -.. autoclass:: cutlass.TensorNHWC - :members: - -.. autoclass:: cutlass.TensorNC32HW32 - :members: - -.. autoclass:: cutlass.TensorC32RSK32 - :members: - -.. rubric:: Threadblock Swizzle - -.. autoclass:: cutlass.dim3 - :special-members: - :members: - -.. autoclass:: cutlass.IdentitySwizzle1 - :special-members: - :members: - -.. autoclass:: cutlass.IdentitySwizzle2 - :special-members: - :members: - -.. autoclass:: cutlass.IdentitySwizzle4 - :special-members: - :members: - -.. autoclass:: cutlass.IdentitySwizzle8 - :special-members: - :members: - -.. autoclass:: cutlass.HorizontalSwizzle - :special-members: - :members: - -.. autoclass:: cutlass.BatchedIdentitySwizzle - :special-members: - :members: - -.. autoclass:: cutlass.StridedDgradIdentitySwizzle1 - :special-members: - :members: - -.. autoclass:: cutlass.StridedDgradIdentitySwizzle4 - :special-members: - :members: - -.. autoclass:: cutlass.StridedDgradHorizontalSwizzle - :special-members: - :members: - -.. rubric:: Coordinates - -.. autoclass:: cutlass.Tensor4DCoord - :special-members: - :members: - -.. autoclass:: cutlass.Tensor3DCoord - :special-members: - :members: - -.. autoclass:: cutlass.MatrixCoord - :special-members: - :members: - - -.. rubric:: Convolution - -.. autoclass:: cutlass.conv.Operator - :members: - -.. autoclass:: cutlass.conv.IteratorAlgorithm - :members: - -.. autoclass:: cutlass.conv.StrideSupport - :members: diff --git a/tools/library/scripts/pycutlass/docs/source/gemm_op.rst b/tools/library/scripts/pycutlass/docs/source/gemm_op.rst deleted file mode 100644 index e4bcd8b4..00000000 --- a/tools/library/scripts/pycutlass/docs/source/gemm_op.rst +++ /dev/null @@ -1,18 +0,0 @@ -GEMM Operation -============== - -.. autoclass:: pycutlass.GemmOperationUniversal - :special-members: - :members: - -.. autoclass:: pycutlass.GemmOperationGrouped - :special-members: - :members: - -.. autoclass:: pycutlass.GemmArguments - :special-members: - :members: - -.. autoclass:: pycutlass.GemmGroupedArguments - :special-members: - :members: diff --git a/tools/library/scripts/pycutlass/docs/source/index.rst b/tools/library/scripts/pycutlass/docs/source/index.rst deleted file mode 100644 index b8a16e16..00000000 --- a/tools/library/scripts/pycutlass/docs/source/index.rst +++ /dev/null @@ -1,31 +0,0 @@ -.. PyCutlass documentation master file, created by - sphinx-quickstart on Sun Jun 19 12:05:42 2022. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -CUTLASS Python Project Documentation -===================================== -.. mdinclude:: ../../README.md - -.. toctree:: - :maxdepth: 2 - :caption: Contents: - - - -.. Indices and tables -.. ================== - -.. * :ref:`genindex` -.. * :ref:`modindex` -.. * :ref:`search` - - -Indices -================== -.. toctree:: - user_guide - visitor_tree - gemm_op - conv2d_op - cutlass diff --git a/tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md b/tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md deleted file mode 100644 index 7cda6873..00000000 --- a/tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md +++ /dev/null @@ -1,225 +0,0 @@ -# Epilogue Visitor Tree -The Epilogue Visitor Tree is an experimental feature that directly generates epilogues from user-provide python functions. - -## Usage - -The Epilogue Visitor tree support many different operations. - -### Unary functions -Epilogue Visitor Tree supports unary functions like activation functions. For example, -```python -class UnaryEpilogue_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - alpha: 'scalar', beta: 'scalar'): - # - T = leaky_relu.numpy(accum, 0.2) - Z = alpha * T + beta * c - return Z -epilogue_functor = UnaryEpilogue_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) -``` - -### Broadcast Operation -Epilogue Visitor Tree supports broadcasting row and column vectors to the whole output matrix. To use broadcast, you just need to specify whether the source vector is a `row` vector or a `column` vector. Here is an example. -```python -class ColumnBroadcast_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - vector: 'column', alpha: 'scalar', beta: 'scalar'): - # - T = accum + vector - scale_T = leaky_relu.numpy(alpha * T, 0.2) - Z = scale_T + beta * c - return Z, T -epilogue_functor = ColumnBroadcast_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) -``` - -### Reduction Operation - -Epilogue Visitor Tree also supports row and column-wise reduction in each threadblock tile. The syntax for reduction is -```python -{reduction_output} = reduction_op({input_tensor}, {row|column}, {Add}, {threadblock_shape.n|threadblock_shape.m}) -``` -The `{row|column}` indicates whether the `row` vectors are reduced or the `column` vectors are reduction. The `{Add}` specifies the reduction operation. The `{threadblock_shape.n|threadblock_shape.m}` are the reduction lengths. - -**Constraint** -* The `{input_tensor}` can only be the name of source or intermediate result. `reduction_op(A + B, ...)` will not work, please use `C = A + B`, `reduction_op(C, ...)` instead. -* The `{reduction_output}` cannot be used in the epilogue. It will be directly written to global memory after the reduction is done. -```python -class RowReduction_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - alpha: 'scalar', beta: 'scalar'): - # - D = alpha * accum + tanh.numpy(beta * c) - reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) - return D, reduction -epilogue_functor = RowReduction_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) -epilogue_functor.initialize() -``` - -## Get output_op - -As shown in the user guide, an `output_op` is required by the argument wrapper. We will take the `RowReduction_` as an example to show how to get `output_op`. -```python -class RowReduction_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - alpha: 'scalar', beta: 'scalar'): - # - D = alpha * accum + tanh.numpy(beta * c) - reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) - return D, reduction -epilogue_functor = RowReduction_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) -epilogue_functor.initialize() - -cta_n = args.threadblock_shape[1] -num_cta_n = (problem_size.n() + cta_n - 1) // cta_n -reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, element_c)) -# get output op -output_op = operation.epilogue_type( - D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] -) -``` -Like other epilogue functors such as `LinearCombination`, the output op for EpilogueVisitorTree is also created with `operation.epilogue_type(*)`. However, there are two differences: -* The arguments need to be passed as keyword-arguments. The keywords are the argument names in `def __call__`. -* An additional `problem_size=[problem_size.m(), problem_size.n()]` is required. - - -## Add new Unary Operation (e.g. Activation Function) -To add additional unary operation into epilogue visitor tree, a new unary op -should be created for `VisitorOpUnary`. We will take `tanh` as an example. - -### Step 1: define TanhVisitor - -The visitor defines the parameters and computation required by the unary option. -The unary operations are registered in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/unary_ops.h). But you can define it in any header file and include the header file in [pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h](tools/library/scripts/pycutlass/src/cpp/include/epilogue/epilogue_visitor_op/visitor_op_unary.h). - - -* Two template arguments are required: - * `T`: data type used to compute the unary operation - * `N`: compute fragment length -* We also need to provide the `Arguments` and `Params` structures. The `Arguments` will be assembled by [ctypes](https://docs.python.org/3/library/ctypes.html), the `Params` will be generated from `Arguments` automatically. If the unary function takes no argument, an integer like `int tmp` can be provide to ensure the correctness of ctypes. -* The constructor can only take the `params` as the single argument. -* The operation is defined in `Array operator()(Array const &frag) const `. On common way to do that is first define a scalar computation, and them use it for the fragment computation with an unrolled for-loop. -* A guard function is required. If it returns `true`, it will disable all the children nodes of the unary node and return zeros to parent node. This is very helpful for multiplication with scalar while the scalar is `0`. For general cases, you can just return `true`. -```c++ -// T: data type used to compute the unary operation -// N: compute fragment length -template -struct TanhVisitor { - /// Argument - struct Arguments { - // a placeholder argument to ensure correctness of ctypes - int tmp; - - CUTLASS_HOST_DEVICE - Arguments(): tmp(0) { }; - - CUTLASS_HOST_DEVICE - Arguments(int tmp): tmp(tmp) { }; - }; - - /// Param - struct Params { - CUTLASS_HOST_DEVICE - Params(){ }; - Params(Arguments const &args) { } - }; - - /// Constructor - CUTLASS_HOST_DEVICE - TanhVisitor(Params const ¶ms) { } - - // scalar operator - CUTLASS_HOST_DEVICE - T tanh_op(T const &scalar) const { - return fast_tanh(scalar); - } - - /// vector operator - CUTLASS_HOST_DEVICE - Array operator()(Array const &frag) const { - Array y; - - CUTLASS_PRAGMA_UNROLL - for (int i=0; i < N; ++i) { - y[i] = tanh_op(frag[i]); - } - - return y; - } - - // Guard - CUTLASS_HOST_DEVICE - bool guard() { - return true; - } -}; -``` - -### Step 2: register Tanh function -After defining the function in C++, we need to register it in python. The class below gives an example. -* The init function takes the data type `element_compute`, which will be the `T` in the C++ template. -In the init function, we also generate the `_Arguments` class as a `ctypes.Structure`. It includes all the data members in the `TanhVisitor::Arguments`. -* The `_Arguments` need to be registered as `self.argument_type` of `tanh` class. -* A `emit` function is required to emit the namespace and typename of `TanhVisitor`. -* A staticmethod as numpy reference is required to implement the python code to parse. - -The built-in functions are defined in [pycutlass/src/pycutlass/epilogue.py](tools/library/scripts/pycutlass/src/pycutlass/epilogue.py). You can defined yours in any file as long as it can be found by [/pycutlass/src/pycutlass/parser.py](tools/library/scripts/pycutlass/src/pycutlass/parser.py). -```python -class tanh(ActivationFunctor): - def __init__(self, element_compute) -> None: - super().__init__() - class _Arguments(ctypes.Structure): - _fields_ = [ - ("tmp", ctypes.c_int) - ] - def __init__(self, *args) -> None: - self.tmp = 0 - self.argument_type = _Arguments - - def emit(self): - return "cutlass::TanhVisitor" - - @staticmethod - def numpy(x: np.ndarray): - return np.tanh(x) -``` - -### Step 3: Run the function -Now the new unary op is ready to use. An epilogue visitor tree can be built with -```python -class RowReduction_(EpilogueVisitTree): - def __call__( - self, accum: NDArray['tensor', 'float32'], c: NDArray['tensor', 'float32'], - alpha: 'float32', beta: 'float32'): - # - D = alpha * accum + tanh.numpy(beta * c) - reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) - return D, reduction -epilogue_functor = RowReduction_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) -epilogue_functor.initialize() -``` - -## Limitations and Future work - -Although the Epilogue Visitor Tree brings great flexibility to epilogue construction, as the epilogue is formulated as a single tree, there are several limitations. -* [Future Work] Serial and Parallel Split-K GEMM are not supported yet. - * To support serial split-k, additional tree transformation pass is required to inject a `binaryOpNode(Add)` + `TensorInputNode` before each `TensorOutputNode` to fetch the partial sum back. The `semaphore` also needs to be passed into epilogue. - * To support parallel split-k, an Reduction with visitor kernel is required. -* [Future Work] Convolution and GEMM Grouped are not supported yet. - * To support Conv2d and GEMM Grouped, corresponding *_with_visitor kernels are required. - -* [Limitation] If the same node is used by two operations (except that one of them is reduction), the node and all its offsprings will be executed twice. -* [Limitation] The result of reduction can only be used as the return value. diff --git a/tools/library/scripts/pycutlass/docs/source/md/basic_idea.md b/tools/library/scripts/pycutlass/docs/source/md/basic_idea.md deleted file mode 100644 index a417afd9..00000000 --- a/tools/library/scripts/pycutlass/docs/source/md/basic_idea.md +++ /dev/null @@ -1,283 +0,0 @@ -# Basics of PyCUTLASS - -PyCUTLASS handles the following things when launch the CUTLASS kernels -* Memory management -* Operation Description -* Code emission and compilation -* Arguments preprocessing -* Kernel launching -* Result Synchronization - -## Memory management - -PyCUTLASS uses [RMM](https://github.com/rapidsai/rmm) to manage device memory. At the beginning of the program, call -```python -pycutlass.get_memory_pool({init_pool_size_in_bytes}, {max_pool_size_in_bytes}) -``` -We also provide functions to query the allocated size. -```python -bytes = get_allocated_size() -``` - - -## Operation Description -PyCUTLASS provides operation description for GEMM, GEMM Grouped and Conv2d operations. These operation descriptions are assembled from four foundamental concepts -* Math Instruction: math instruction executed in GPU cores -* Tile Description: tiling sizes and pipeline stages -* Operand Description: data type, layout, memory alignment -* Epilogue Functor: epilogue function - -### Math Instruction - -The math instruction is defined as follows: -```python -math_inst = MathInstruction( - {instruction_shape}, {element_a}, {element_b}, - {element_acc}, {opclass}, {math_operation} -) -``` -The `{instruction_shape}` and `{opclass}` defines the instruction size and type. The table below lists valid combinations. `{element_a}`, `{element_b}` define the source operand data type for each instructions, and `{element_acc}` defines the accumulator type. The `{math_operation}` defines the math operation applied. - -|Opclass | element_a/element_b | element_acc | instruction_shape | math_operation | -| -- | -- | -- | -- | -- | -| cutlass.OpClass.TensorOp | cutlass.float64 | cutlass.float64 | [8, 8, 4] | MathOperation.multiply_add| -| | cutlass.float32 cutlass.tfloat32, cutlass.float16 cutlass.bfloat16 | cutlass.float32 | [16, 8, 8] | MathOperation.multiply_add MathOperation.multiply_add_fast_f32 MathOperation.multiply_add_fast_f16 MathOperation.multiply_add_fast_bf16 | -| | cutlass.float16 | cutlass.float16/cutlass.float32|[16, 8, 16]| MathOperation.multiply_add | -| | cutlass.bfloat_16 | cutlass.float32 | [16, 8, 16]|MathOperation.multiply_add | -| | cutlass.int8 | cutlass.int32 | [16, 8, 32] | MathOperation.multiply_add_saturate| -|cutlass.OpClass.Simt| cutlass.float64 | cutlass.float64 | [1, 1, 1] | MathOperation.multiply_add | -| | cutlass.float32 | cutlass.float32 | [1, 1, 1] | MathOperation.multiply_add | - -The `cutlass.OpClass.TensorOp` indicates that the tensor core is used, while `cutlass.OpClass.Simt` uses the SIMT Core. - -The `multiply_add_fast_f32` emulates fast accurate SGEMM kernel which is accelerated -using Ampere Tensor Cores. More details can be found in [examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm](examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm). - -### Tile Description -The tile description describes the threadblock and warp tiling sizes, as well as the pipeline stages. -```python -tile_description = TileDescription( - {threadblock_shape}, {stages}, {warp_count}, - math_inst -) -``` -The `{threadblock_shape}` is a list of 3 integers `[Tile_M, Tile_N, Tile_K]` that defines the threadblock tiling size. `{stages}` defines the number of software pipeline stages ([detail](https://developer.nvidia.com/blog/controlling-data-movement-to-boost-performance-on-ampere-architecture/)). `{warp_count}` defines the number of warps along `M`, `N`, and `K` dimension. I.e., with `{threadblock_shape}=[Tile_M, Tile_N, Tile_K]` and `{warp_count}=[W_M, W_N, W_K]`, the warp tile size would be `[Tile_M / W_M, Tile_N / W_N, Tile_K / W_K]`. - -### Operand Description -The Operand Description defines the data type, layout, and memory alignment of input tensor A, B, and C. The output D shares the same attributes with C. The description is as follows: -```python -A = TensorDescription( - {element_a}, {layout_a}, {alignment_a} -) - -B = TensorDescription( - {element_b}, {layout_b}, {alignment_b} -) - -C = TensorDescription( - {element_c}, {layout_c}, {alignment_c} -) -``` -The table below lists the supported layout and data types for each operation -| Operation | data type | layout | -| -- | -- | -- | -| GEMM, GEMM Grouped | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.RowMajor, cutlass.ColumnMajor | -| | cutlass.int8 | cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32| -| Conv2d Fprop, Dgrad, Wgrad | cutlass.float64, cutlass.float32, cutlass.float16, cutlass.bfloat16 | cutlass.TensorNHWC | -| Conv2d Fprop | cutlass.int8 | cutlass.TensorNHWC, cutlass.TensorNC32HW32, cutlass.TensorC32RSK32| - -### Epilogue Functor -The epilogue functor defines the epilogue executed after mainloop. -We expose the following epilogue functors. -| Epilogue Functor | Remark | -| -- | -- | -| LinearCombination | $D=\alpha \times Accm + \beta \times C$ | -| LinearCombinationClamp | $D=\alpha \times Accm + \beta \times C$, Output is clamped to the maximum value of the data type output | -| FastLinearCombinationClamp | $D=\alpha \times Accm + \beta \times C$, only used for problem size $K\le 256$ for cutlass.int8, with accumulator data type `cutlass.int32` and epilogue compute data type `cutlass.float32` | -| LinearCombinationGeneric | $D = activation(\alpha \times Accm + \beta \times C)$, available activations include `relu`, `leaky_relu`, `tanh`, `sigmoid`, `silu`, `hardswish`, and `gelu` | - -The epilogue functors can be created as follows -```python -# LinearCombination -epilogue_functor = LinearCombination( - element_C, alignment_c, element_acc, element_epilogue_compute -) - -# LinearCombinationClamp -epilogue_functor = LinearCombinationClamp( - element_C, alignment_c, element_acc, element_epilogue_compute -) - -# FastLinearCombinationClamp -epilogue_functor = FastLinearCombinationClamp( - element_C, alignment_c -) - -# LinearCombinationGeneric -epilogue_functor = LinearCombinationGeneric( - relu(element_epilogue_compute), element_C, alignment_c, - element_acc, element_epilogue_compute -) -``` - -We also provides an experimental feature "Epilogue Visitor Tree" for GEMM operation. The details can be found in [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md). - - -### GEMM Operation - -The GEMM Operation description can be created with -```python -operation = GemmOperationUniversal( - {compute_capability}, tile_description, - A, B, C, epilogue_functor, - {swizzling_functor}, {visitor} -) -``` -* `{compute_capability}` is an integer indicates the compute capability of the GPU. For A100, it is 80. -* `{swizzling_functor}` describes how threadblocks are scheduled on GPU. This is used to improve the L2 Locality ([detail](https://developer.nvidia.com/blog/optimizing-compute-shaders-for-l2-locality-using-thread-group-id-swizzling/)). Currently we support `cutlass.{IdentitySwizzle1|IdentitySwizzle2|IdentitySwizzle4|IdentitySwizzle8|BatchedIdentitySwizzle}`. The last one is used for batched or array GEMM. -* `{visitor}`: a bool variable indicates whether the epilogue visitor tree is used. - -### GEMM Grouped Operation -The GEMM Grouped Operation description can be created with -```python -operation = GemmOperationGrouped( - compute_capability, tile_description, - A, B, C, epilogue_functor, - swizzling_functor, {precompute_mode} -) -``` -* `{precompute_mode}`: It could be `SchedulerMode.Host` or `SchedulerMode.Device`. See [examples/24_gemm_grouped](examples/24_gemm_grouped) for more details. - - -### Conv2d Operation -The Conv2d Operation description can be created with -```python -operation = Conv2dOperation( - {conv_kind}, {iterator_algorithm}, - compute_capability, tile_description, - A, B, C, {stride_support}, - epilogue_functor, swizzling_functor -) -``` -* `{conv_kind}` defines which convolution is executed. Available options include `fprop`, `dgrad`, and `wgrad`. -* `{iterator_algorithm}` specifies the iterator algorithm used by the implicit GEMM in convolution. The options are as follows: - * `analytic`: functionally correct in all cases but lower performance - * `optimized`: optimized for R <= 32, S <= 32 and unity-stride dgrad - * `fixed_channels`: analytic algorithm optimized for fixed channel count (C == AccessSize) - * `few_channels`: Analytic algorithm optimized for few channels (C divisible by AccessSize) -* `{stride_support}`: distinguishes among partial specializations that accelerate certain problems where convolution -stride is unit. - * `strided`: arbitrary convolution stride - * `unity`: unit convolution stride - -*** -## Code Emission and Compilation -After implementing the operation description, the related host and device code can be compiled with -```python -import pycutlass - -pycutlass.compiler.add_module([operation,]) -``` -Several operations can be compiled together. The `nvcc` at `$CUDA_INSTALL_PATH/bin` is used by default as the compiler backend. But you can also switch to [CUDA Python](https://nvidia.github.io/cuda-python/overview.html)'s `nvrtc` with -```python -pycutlass.compiler.nvrtc() -``` -We also have an internal compiled artifact manager that caches the compiled kernel in both memory and disk. The `compiled_cache.db` at your workspace is the database that contains the binary files. You can delete the file if you want to recompile the kernels. -*** -## Argument Processing -We provide argument wrapper to convert python tensors to the kernel parameters. Currently it supports [torch.Tensor](https://pytorch.org/), [numpy.ndarray](https://numpy.org/), and [cupy.ndarray](https://cupy.dev/). -### GEMM Arguments -The Gemm arguments can be created with -```python -arguments = GemmArguments( - operation=operation, problem_size={problem_size}, - A={tensor_A}, B={tensor_B}, C={tensor_C}, D={tensor_D}, - output_op={output_op}, - gemm_mode={gemm_mode}, - split_k_slices={split_k_slices}, batch={batch} -) -``` -* `problem_size` is a `cutlass.gemm.GemmCoord(M, N, K)` object that defines $M\times N\times K$ matrix multiplication. -* `tensor_X`: user-provide tensors. -* `output_op`: the params for the epilogue functor. -* `gemm_mode`, `split_k_slices`, and `batch`: - -|gemm_mode| split_k_slices | batch | remark| -|--|--|--|--| -|cutlass.gemm.Mode.Gemm | number of split-K slices | - | the ordinary GEMM or GEMM with serial split-K| -|cutlass.gemm.Mode.GemmSplitKParallel | number of split-K slices | - | GEMM Split-K Parallel| -|cutlass.gemm.Mode.Batched | - | batch size | Batched GEMM | -|cutlass.gemm.Mode.Array | - | batch size | Array GEMM | - -### GEMM Grouped Arguments -The GEMM grouped arguments can be created with -```python -arguments = GemmGroupedArguments( - operation, {problem_sizes_coord}, {tensor_As}, {tensor_Bs}, {tensor_Cs}, {tensor_Ds}, - output_op=output_op) -) -``` -* `problem_size_coord` is a list of `cutlass.gemm.GemmCoord(M, N, K)` for each problem size. -* `tensor_Xs` is a list of user-provide tensors. -* `output_op`: the params of the epilogue functor - -### Conv2d Arguments -The Conv2d arguments can be created with -```python -arguments = Conv2dArguments( - operation, {problem_size}, {tensor_A}, - {tensor_B}, {tensor_C}, {tensor_D}, - {output_op}, - {split_k_mode}, - {split_k_slices} -) -``` -* `problem_size`: it can be constructed with - ```python - problem_size = cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(N, H, W, C), - cutlass.Tensor4DCoord(K, R, S, C), - cutlass.Tensor4DCoord(pad[0], pad[1], pad[2], pad[3]), - cutlass.MatrixCoord(stride[0], stride[1]), - cutlass.MatrixCoord(dilation[0], dilation[1]), - cutlass.conv.Mode.cross_correlation, - split_k_slices, 1 - ) - ``` -* `tensor_X` are user-provide tensors -* `output_op`: the params of the epilogue functor -* `split_k_mode`: currently we support `cutlass.conv.SplitKMode.Serial` and `cutlass.conv.SplitKMode.Parallel`. -* `split_k_slice`: number of split-k slices - -For ordinary conv2d, just use `cutlass.conv.SplitKMode.Serial` with `split_k_slice=1`. - -### Getting output_op -The way to create output_op is listed below -```python -output_op = operation.epilogue_type(*([alpha, beta] + args.activation_args)), -``` -It is a list of arguments start with the scaling factor `alpha` and `beta`. -The `output_op` of EpilogueVisitorTree is slightly different. Please check [EpilogueVisitorTree](tools/library/scripts/pycutlass/docs/source/md/EpilogueVisitorTree.md) for details. - - -## Kernel Launching - -With the arguments and operations, the kernel can be launched simply with -```python -operation.run(arguments) -``` - -## Sync results - -We also provide function to synchronize the kernel execution. If you use `numpy`, it will also copy the result back to host. To do that, run -```python -arguments.sync() -``` -If you use EpilogueVisitorTree, please call -```python -output_op.sync() -``` - -## Reduction Kernel behind Parallel Split-K - -If you use parallel-split-K in GEMM or Conv2d, an additional reduction kernel is required. Please check [examples/40_cutlass_py](examples/40_cutlass_py) for detail. diff --git a/tools/library/scripts/pycutlass/docs/source/user_guide.rst b/tools/library/scripts/pycutlass/docs/source/user_guide.rst deleted file mode 100644 index 3db70dbb..00000000 --- a/tools/library/scripts/pycutlass/docs/source/user_guide.rst +++ /dev/null @@ -1,4 +0,0 @@ -User Guide -===================================== - -.. mdinclude:: ./md/basic_idea.md diff --git a/tools/library/scripts/pycutlass/docs/source/visitor_tree.rst b/tools/library/scripts/pycutlass/docs/source/visitor_tree.rst deleted file mode 100644 index c48cdba3..00000000 --- a/tools/library/scripts/pycutlass/docs/source/visitor_tree.rst +++ /dev/null @@ -1,4 +0,0 @@ -User Guide -===================================== - -.. mdinclude:: ./md/EpilogueVisitorTree.md diff --git a/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py b/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py deleted file mode 100644 index 23fd156a..00000000 --- a/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +++ /dev/null @@ -1,106 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -from pycutlass import * -import pycutlass -from pycutlass.epilogue import LinearCombination -from pycutlass.test.conv2d_testbed import Conv2dLauncher - - -if __name__ == "__main__": - pycutlass.get_memory_pool(2**33, 2**33) - pycutlass.compiler.nvcc() - - math_inst = MathInstruction( - instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, - math_operation=MathOperation.multiply_add - ) - - A = TensorDescription( - element=math_inst.element_a, - layout=cutlass.TensorNHWC, - alignment=8) - B = TensorDescription( - element=math_inst.element_b, - layout=cutlass.TensorNHWC, - alignment=8) - C = TensorDescription( - element=cutlass.float32, - layout=cutlass.TensorNHWC, - alignment=8) - - tile_description = TileDescription( - threadblock_shape=[128, 128, 64], stages=4, - warp_count=[2, 2, 1], - math_instruction=math_inst - ) - - epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32) - - operation = Conv2dOperation( - conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, - arch=80, tile_description=tile_description, A=A, B=B, C=C, - element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, - epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 - ) - - profiler = Conv2dLauncher(operation, verification=False, profiling=True) - - python_runtime = profiler.run( - problem_size = cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(32, 224, 224, 128), - cutlass.Tensor4DCoord(128, 3, 3, 128), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, - 1, 1 - ), split_k_mode=cutlass.conv.SplitKMode.Serial - ) - - - cpp_runtime = profiler.run_cutlass_profiler( - problem_size = cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(32, 224, 224, 128), - cutlass.Tensor4DCoord(128, 3, 3, 128), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, - 1, 1 - ), split_k_mode=cutlass.conv.SplitKMode.Serial - ) - - print(cpp_runtime / python_runtime) diff --git a/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py b/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py deleted file mode 100644 index e4a82885..00000000 --- a/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +++ /dev/null @@ -1,91 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.test.gemm_testbed import GemmUniversalLauncher - -if __name__ == '__main__': - pycutlass.get_memory_pool(2**32, 2**32) - pycutlass.compiler.nvcc() - - math_inst = MathInstruction( - instruction_shape=[16, 8, 16], - element_a=cutlass.float16, element_b=cutlass.float16, - element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, - math_operation=MathOperation.multiply_add - ) - - tile_description = TileDescription( - threadblock_shape=[256, 128, 32], - stages=3, warp_count=[4, 2, 1], - math_instruction=math_inst - ) - - A = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, - alignment=4 - ) - B = TensorDescription( - element=cutlass.float16, layout=cutlass.RowMajor, - alignment=4 - ) - C = TensorDescription( - element=cutlass.float32, layout=cutlass.ColumnMajor, - alignment=4 - ) - - element_epilogue = cutlass.float32 - - epilogue_functor = LinearCombination(cutlass.float32, 4, cutlass.float32, cutlass.float32) - - swizzling_functor = cutlass.IdentitySwizzle1 - - operation = GemmOperationUniversal( - arch=80, tile_description=tile_description, - A=A, B=B, C=C, element_epilogue=element_epilogue, - epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor - ) - - profiler = GemmUniversalLauncher(operation, verification=False, profiling=True) - python_runtime = profiler.run( - mode=cutlass.gemm.Mode.Gemm, - problem_size=cutlass.gemm.GemmCoord(4096, 4096, 4096) - ) - - cpp_runtime = profiler.run_cutlass_profiler( - mode=cutlass.gemm.Mode.Gemm, - problem_size=cutlass.gemm.GemmCoord(4096, 4096, 4096), - ) - - print(cpp_runtime / python_runtime) diff --git a/tools/library/scripts/pycutlass/pyproject.toml b/tools/library/scripts/pycutlass/pyproject.toml deleted file mode 100644 index e192f102..00000000 --- a/tools/library/scripts/pycutlass/pyproject.toml +++ /dev/null @@ -1,9 +0,0 @@ -[build-system] - -requires = [ - "setuptools", - "scikit-build>0.13.1", - "pybind11", - "numpy<1.23", - "cmake>=3.20.1,!=3.23.0" -] diff --git a/tools/library/scripts/pycutlass/setup.py b/tools/library/scripts/pycutlass/setup.py deleted file mode 100644 index bf950ae8..00000000 --- a/tools/library/scripts/pycutlass/setup.py +++ /dev/null @@ -1,116 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import distutils.cmd -from setuptools import setup -import setuptools.command.build_py -import os - -# build rmm dependency -class BuildRMM(distutils.cmd.Command): - user_options = [] - def initialize_options(self): - pass - def finalize_options(self): - pass - def run(self): - try: - import rmm - except ImportError: - print("installing rmm") - os.system("git clone -b branch-22.10 --recurse-submodules https://github.com/rapidsai/rmm.git") - os.chdir("./rmm") - os.system("./build.sh librmm rmm") - os.chdir("./python") - os.system("python setup.py build_ext --inplace") - os.system("python setup.py install") - -cutlass_path = os.getenv('CUTLASS_PATH') -assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." -cuda_install_path = os.getenv('CUDA_INSTALL_PATH') -assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." - -ext_modules = [] - -try: - from pybind11.setup_helpers import Pybind11Extension, build_ext - include_dirs = [ - cutlass_path + "/include", - cuda_install_path + "/include", - cutlass_path + "/tools/util/include", - cutlass_path + "/test", - cutlass_path + "/tools/library/scripts/pycutlass/googletest/googletest/include" - ] - - ext_modules = [ - Pybind11Extension("cutlass", - ["src/cpp/cutlass.cpp"], - include_dirs=include_dirs, - extra_compile_args=["-fpermissive", "-w", "-std=c++17"]), - Pybind11Extension("cute", - ["src/cpp/cute.cpp"], - include_dirs=include_dirs, - extra_compile_args=["-fpermissive", "-w", "-std=c++17"]) - ] -except ImportError: - pass - -setup( - name="PyCutlass", - version="0.0.1", - author="Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall", - author_email="zhaodongc@nvidia.com", - description="Python interface for CUTLASS", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - package_dir={"": "src"}, - packages=['pycutlass', 'pycutlass.utils', 'pycutlass.test'], - setup_requires=["pybind11", "numpy<1.23"], - install_requires=[ - "numpy<1.23", - 'pybind11', - 'cuda-python>=11.8.0', - 'typeguard', - 'bfloat16', - 'typing', - 'scikit-build', - 'treelib' - ], - cmdclass={ - 'rmm': BuildRMM - }, - ext_modules=ext_modules, - python_requires=">=3.6", -) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/__init__.py b/tools/library/scripts/pycutlass/src/pycutlass/__init__.py deleted file mode 100644 index 18f3e84d..00000000 --- a/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +++ /dev/null @@ -1,55 +0,0 @@ -import re - - -def SubstituteTemplate(template, values): - text = template - changed = True - while changed: - changed = False - for key, value in values.items(): - regex = "\\$\\{%s\\}" % key - newtext = re.sub(regex, value, text) - if newtext != text: - changed = True - text = newtext - return text - -from pycutlass.type_hint import * -from pycutlass.tensor_ref import * -from pycutlass.operation import * -from pycutlass.epilogue import * -from pycutlass.parser import * -from pycutlass.compiler import ArtifactManager -from pycutlass.memory_manager import * -from pycutlass.arguments import * -from pycutlass.library import * -from pycutlass.c_types import * -from pycutlass.gemm_operation import * -from pycutlass.conv2d_operation import * -from pycutlass.compiler import * -from pycutlass.utils import * -from pycutlass.frontend import * -from pycutlass.reduction_operation import * -from pycutlass.compiler import * -from pycutlass.utils.device import device_cc - -# module-wide variables - -import sys -this = sys.modules[__name__] - -# artifact manager -this.compiler = ArtifactManager() - -try: - if not hasattr(this, 'DEVICE_CC') or this.DEVICE_CC is None: - this.DEVICE_CC = device_cc() -except: - this.DEVICE_CC = None - -def get_memory_pool(init_pool_size=0, max_pool_size=2**34): - this.memory_pool = PoolMemoryManager( - init_pool_size=init_pool_size, - max_pool_size=max_pool_size - ) - return this.memory_pool diff --git a/tools/library/scripts/pycutlass/src/pycutlass/builder/collective_op_builder.py b/tools/library/scripts/pycutlass/src/pycutlass/builder/collective_op_builder.py deleted file mode 100644 index 13f52435..00000000 --- a/tools/library/scripts/pycutlass/src/pycutlass/builder/collective_op_builder.py +++ /dev/null @@ -1,395 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utilities for stamping out collective mainloops for SM90 kernels -""" - -import cute -import cutlass -from pycutlass import SubstituteTemplate -import pycutlass.library as library - - -tma_alignment_bytes = 16 -cp_async_min_alignment_bytes = 4 - - -class RowColMajorToGMMAMajor: - @staticmethod - def A(layout, element): - """ - Converts operand A's layout from row/column major format into CuTe's GMMA major format - - :param layout: layout of the A operand - :type layout: cutlass.RowMajor or cutlass.ColumnMajor - :param element: data type of the A operand - - :return: C++ CuTe GMMA major format - :rtype: cute.GMMAMajor - """ - type_requires_k_major = (element == cutlass.tfloat32) or (element == cutlass.int8) - if layout == cutlass.ColumnMajor and not type_requires_k_major: - return cute.GMMAMajor.MN - else: - return cute.GMMAMajor.K - - @staticmethod - def B(layout, element): - """ - Converts operand B's layout from row/column major format into CuTe's GMMA major format - - :param layout: layout of the B operand - :type layout: cutlass.RowMajor or cutlass.ColumnMajor - :param element: data type of the B operand - - :return: C++ CuTe GMMA major format - :rtype: cute.GMMAMajor - """ - type_requires_k_major = (element == cutlass.tfloat32) or (element == cutlass.int8) - if layout == cutlass.RowMajor and not type_requires_k_major: - return cute.GMMAMajor.MN - else: - return cute.GMMAMajor.K - - -def cluster_shape_to_tma(dim): - """ - Returns the TMA copy type for a given cluster dimension - - :param dim: a given dimension of a cluster - :type dim: layout - - :return: C++ TMA copy time - :rtype: str - """ - return 'cute::SM90_TMA_LOAD' if dim == 1 else 'cute::SM90_TMA_LOAD_MULTICAST' - - -def make_cpasync_gmem_tiled_copy(thread_count, element, alignment, gmma_layout, dim_mn, dim_k): - """ - Returns a `make_tiled_copy` call for a given configuration - - :param thread_count: number of threads in the threadblock - :type thread_count: int - :param element: datatype of the operand in question - :param alignment: byte alignment of the operand in question - :type alignment: int - :param gmma_layout: GMMA layout of the operand in question - :type gmma_layout: cute.GMMAMajor - :param dim_mn: extent of the M/N dimension of the tile - :type dim_mn: int - :param dim_k: extent of the reduction dimension of the tile - :type dim_k: int - - :return: C++ call to `make_tiled_copy` - :rtype: str - """ - - emission_str = """decltype(cute::make_tiled_copy( - cute::Copy_Atom(sizeof(${element})) * ${alignment}>>, ${element}>{}, - cute::Layout, - cute::Stride<_${stride_x}, _${stride_y}>>{}, - cute::Layout>{}))""" - if gmma_layout == cute.GMMAMajor.K: - threads_major = dim_k // alignment - threads_minor = thread_count // threads_major - values = { - 'shape0_x': str(threads_minor), - 'shape0_y': str(threads_major), - 'stride_x': str(threads_major), - 'stride_y': '1', - 'shape1_x': '1', - 'shape1_y': str(alignment) - } - elif gmma_layout == cute.GMMAMajor.MN: - threads_major = dim_mn // alignment - threads_minor = thread_count // threads_major - values = { - 'shape0_x': str(threads_major), - 'shape0_y': str(threads_minor), - 'stride_x': '1', - 'stride_y': str(threads_major), - 'shape1_x': str(alignment), - 'shape1_y': '1' - } - else: - raise Exception('Unexpected GMMA layout {}'.format(gmma_layout)) - - # Add common values - values['element'] = library.DataTypeTag[element] - values['alignment'] = str(alignment) - return SubstituteTemplate(emission_str, values) - - -def max_stages(op, arch): - """ - Returns the maximum number pipeline stages that can be used for an operation. - - :param op: operation for which the maximum stages should be computed. If stages are - set via the `op.tile_description.stages` parameter, this setting is ignored - in the present calculation - :type op: pycutlass.GemmOperation - :param arch: compute capability of the device on which the operation will be run - :type arch: int - - :return: maximum number of pipeline stages that can be used for an operation - :rtype: int - """ - smem_per_stage = library.CalculateSmemUsagePerStage(op) - smem_capacity = library.SharedMemPerCC[arch] - return int(smem_capacity // smem_per_stage) - - -class LayoutToStride: - _variable_first = 'cute::Stride, int64_t>' - _variable_last = 'cute::Stride, int64_t, int64_t>' - - @staticmethod - def A(layout): - """ - Returns the CuTe shape type corresponding to the layout of operand A - - :param layout: layout of the B operand - :type layout: cutlass.RowMajor or cutlass.ColumnMajor - - :return: C++ declaration of CuTe stride - :rtype: str - """ - if layout == cutlass.RowMajor: - return LayoutToStride._variable_first - elif layout == cutlass.ColumnMajor: - return LayoutToStride._variable_last - else: - raise Exception('Unsupported layout {}'.format(layout)) - - @staticmethod - def B(layout): - """ - Returns the CuTe shape type corresponding to the layout of operand B - - :param layout: layout of the B operand - :type layout: cutlass.RowMajor or cutlass.ColumnMajor - - :return: C++ declaration of CuTe stride - :rtype: str - """ - if layout == cutlass.RowMajor: - return LayoutToStride._variable_last - elif layout == cutlass.ColumnMajor: - return LayoutToStride._variable_first - else: - raise Exception('Unsupported layout {}'.format(layout)) - - -EMISSION_STR = """ -using TileShape_MNK = cute::Shape<_${threadblock_shape_m}, _${threadblock_shape_n}, _${threadblock_shape_k}>; -using ClusterShape_MNK = cute::Shape<_${cluster_shape_m}, _${cluster_shape_n}, _${cluster_shape_k}>; -using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< - ${internal_element_A}, ${internal_element_B}, ${element_accumulator}, TileShape_MNK, ${gmma_layout_A}, ${gmma_layout_B}>())); - -using SmemLayoutAtomA = decltype(cute::GMMA::smem_selector<${gmma_layout_A}, ${internal_element_A}, _${threadblock_shape_m}, _${threadblock_shape_k}>()); -using SmemLayoutAtomB = decltype(cute::GMMA::smem_selector<${gmma_layout_B}, ${internal_element_B}, _${threadblock_shape_n}, _${threadblock_shape_k}>()); - -using CollectiveOp = typename cutlass::gemm::collective::CollectiveMma< - ${mainloop_type}<${stage_count}, ClusterShape_MNK${kernel_schedule}>, - TileShape_MNK, - ${element_A}, - ${stride_A}, - ${element_B}, - ${stride_B}, - TiledMma, - ${gmem_tiled_copy_A}, - SmemLayoutAtomA, - void, // GMMA_SS does not need an SmemCopyAtom - ${transform_A}, - ${gmem_tiled_copy_B}, - SmemLayoutAtomB, - void, // GMMA_SS does not need an SmemCopyAtom - ${transform_B} ->; -""" - - -def internal_element(element): - """ - Returns the data type internally used for `element`. - - :param element: data type - - :return: data type used internally - """ - return cutlass.tfloat32 if element == cutlass.float32 else element - - -def common_values(op, stage_count, transform_A, transform_B): - """ - Returns a dictionary containing common values to be substituted in the emission of the - collective operation declaration. Values specific to a particular collective operation - should be added to these. - - :param op: GEMM operation for which to build a collective operation - :type op: pycutlass.GemmOperation - :param stage_count: number of pipeline stages to use in the operation - :type stage_count: int - :param transform_A: transformation to perform on the A operand - :type transform_A: str - :param transform_B: transformation to perform on the B operand - :type transform_B: str - - :return: dictionary containing values to substitute in emission string - :rtype: dict - """ - internal_element_a = internal_element(op.A.element) - internal_element_b = internal_element(op.B.element) - - return { - 'threadblock_shape_m': str(op.tile_description.threadblock_shape[0]), - 'threadblock_shape_n': str(op.tile_description.threadblock_shape[1]), - 'threadblock_shape_k': str(op.tile_description.threadblock_shape[2]), - 'cluster_shape_m': str(op.tile_description.cluster_shape[0]), - 'cluster_shape_n': str(op.tile_description.cluster_shape[1]), - 'cluster_shape_k': str(op.tile_description.cluster_shape[2]), - 'element_A': library.DataTypeTag[op.A.element], - 'element_B': library.DataTypeTag[op.B.element], - 'internal_element_A': library.DataTypeTag[internal_element_a], - 'internal_element_B': library.DataTypeTag[internal_element_b], - 'element_accumulator': library.DataTypeTag[op.accumulator_type()], - 'gmma_layout_A': library.CuTeLayoutTag[RowColMajorToGMMAMajor.A(op.A.layout, internal_element_a)], - 'gmma_layout_B': library.CuTeLayoutTag[RowColMajorToGMMAMajor.B(op.B.layout, internal_element_b)], - 'stride_A': LayoutToStride.A(op.A.layout), - 'stride_B': LayoutToStride.B(op.B.layout), - 'stage_count': str(stage_count), - 'transform_A': transform_A, - 'transform_B': transform_B - } - - -def build_gmma_tma(op): - """ - Builds a collective operation declaration targeting TMA GMMA kernels - - :param op: GEMM operation for which to build a collective operation - :type op: pycutlass.GemmOperation - - :return: string containing the C++ declaration of collective operation - :rtype: str - """ - A_tma_aligned = (library.DataTypeSizeBytes[op.A.element] * op.A.alignment) % tma_alignment_bytes == 0 - B_tma_aligned = (library.DataTypeSizeBytes[op.B.element] * op.B.alignment) % tma_alignment_bytes == 0 - if not A_tma_aligned or not B_tma_aligned: - raise Exception('Each of the A or B operands must be aligned to {} bytes to use TMA'.format(tma_alignment_bytes)) - - max_stage_count = max_stages(op, arch=90) - if op.tile_description.stages is None: - op.tile_description.stages = max_stage_count - elif op.tile_description.stages > max_stage_count: - raise Exception('Combination of threadblock shape, data types, and number of stages exceeds shared memory capacity.') - - kernel_schedule = 'cutlass::gemm::KernelTmaWarpSpecialized' - if op.tile_description.persistent: - kernel_schedule = 'cutlass::gemm::KernelTmaWarpSpecializedPersistent' - - transform_A = 'cute::identity' - transform_B = 'cute::identity' - values = common_values(op, op.tile_description.stages, transform_A, transform_B) - specific_values = { - 'mainloop_type': 'cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized', - 'kernel_schedule': ', ' + kernel_schedule, - 'gmem_tiled_copy_A': cluster_shape_to_tma(op.tile_description.cluster_shape[1]), - 'gmem_tiled_copy_B': cluster_shape_to_tma(op.tile_description.cluster_shape[0]) - } - values.update(specific_values) - - return SubstituteTemplate(EMISSION_STR, values) - - -def build_gmma_cpasync(op): - """ - Builds a collective operation declaration targeting cp.async GMMA kernels - - :param op: GEMM operation for which to build a collective operation - :type op: pycutlass.GemmOperation - - :return: string containing the C++ declaration of collective operation - :rtype: str - """ - A_cp_async_aligned = (library.DataTypeSizeBytes[op.A.element] * op.A.alignment) % cp_async_min_alignment_bytes == 0 - B_cp_async_aligned = (library.DataTypeSizeBytes[op.B.element] * op.B.alignment) % cp_async_min_alignment_bytes == 0 - if not A_cp_async_aligned or not B_cp_async_aligned: - raise Exception('Each of the A or B operands must be aligned to {} bytes to use cp.async'.format(cp_async_min_alignment_bytes)) - - max_stage_count = max_stages(op, arch=90) - if op.tile_description.stages is None: - op.tile_description.stages = max_stage_count - elif op.tile_description.stages > max_stage_count: - raise Exception('Combination of threadblock shape, data types, and number of stages exceeds shared memory capacity.') - - transform_A = 'cute::identity' - transform_B = 'cute::identity' - - thread_count = 128 - cpasync_copy_A = make_cpasync_gmem_tiled_copy(thread_count, op.A.element, op.A.alignment, RowColMajorToGMMAMajor.A(op.A.layout, op.A.element), - op.tile_description.threadblock_shape[0], op.tile_description.threadblock_shape[2]) - cpasync_copy_B = make_cpasync_gmem_tiled_copy(thread_count, op.B.element, op.B.alignment, RowColMajorToGMMAMajor.B(op.B.layout, op.B.element), - op.tile_description.threadblock_shape[1], op.tile_description.threadblock_shape[2]) - - values = common_values(op, op.tile_description.stages, transform_A, transform_B) - specific_values = { - 'mainloop_type': 'cutlass::gemm::MainloopSm90CpAsyncGmma', - 'kernel_schedule': '', - 'gmem_tiled_copy_A': cpasync_copy_A, - 'gmem_tiled_copy_B': cpasync_copy_B - } - values.update(specific_values) - - return SubstituteTemplate(EMISSION_STR, values) - - -def build(operation): - """ - Builds a collective operation declaration targeting cp.async or TMA for GMMA kernels - - :param operation: GEMM operation for which to build a collective operation - :type operation: pycutlass.GemmOperation - - :return: string containing the C++ declaration of collective operation - :rtype: str - """ - A_tma_aligned = (library.DataTypeSizeBytes[operation.A.element] * operation.A.alignment) % tma_alignment_bytes == 0 - B_tma_aligned = (library.DataTypeSizeBytes[operation.B.element] * operation.B.alignment) % tma_alignment_bytes == 0 - tma_correct_size = (library.DataTypeSizeBytes[operation.A.element] == 2 and library.DataTypeSizeBytes[operation.B.element] == 2) - tma_correct_layout = (operation.A.layout == cutlass.RowMajor or operation.B.layout == cutlass.ColumnMajor) - if A_tma_aligned and B_tma_aligned and (tma_correct_size or tma_correct_layout): - return build_gmma_tma(operation) - else: - return build_gmma_cpasync(operation) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/library.py b/tools/library/scripts/pycutlass/src/pycutlass/library.py deleted file mode 100644 index 08280340..00000000 --- a/tools/library/scripts/pycutlass/src/pycutlass/library.py +++ /dev/null @@ -1,870 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import re - -################################################################################################### - -import enum -import cutlass -import cute - -# The following block implements enum.auto() for Python 3.5 variants that don't include it such -# as the default 3.5.2 on Ubuntu 16.04. -# -# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility - -try: - from enum import auto as enum_auto -except ImportError: - __cutlass_library_auto_enum = 0 - - def enum_auto() -> int: - global __cutlass_library_auto_enum - i = __cutlass_library_auto_enum - __cutlass_library_auto_enum += 1 - return i - -################################################################################################### - -# - - -class GeneratorTarget(enum.Enum): - Library = enum_auto() - -# -GeneratorTargetNames = { - GeneratorTarget.Library: 'library', -} -# - -################################################################################################### - -# -ShortDataTypeNames = { - cutlass.int32: 'i', - cutlass.float16: 'h', - cutlass.float32: 's', - cutlass.float64: 'd', - cutlass.dtype.cf32: 'c', - cutlass.dtype.cf64: 'z', -} - -# -DataTypeNames = { - cutlass.dtype.b1: "b1", - cutlass.dtype.u4: "u4", - cutlass.dtype.u8: "u8", - cutlass.dtype.u16: "u16", - cutlass.dtype.u32: "u32", - cutlass.dtype.u64: "u64", - cutlass.dtype.s4: "s4", - cutlass.int8: "s8", - cutlass.dtype.s16: "s16", - cutlass.int32: "s32", - cutlass.dtype.s64: "s64", - cutlass.float16: "f16", - cutlass.bfloat16: "bf16", - cutlass.float32: "f32", - cutlass.tfloat32: "tf32", - cutlass.float64: "f64", - cutlass.dtype.cf16: "cf16", - cutlass.dtype.cbf16: "cbf16", - cutlass.dtype.cf32: "cf32", - cutlass.dtype.ctf32: "ctf32", - cutlass.dtype.cf64: "cf64", - cutlass.dtype.cu4: "cu4", - cutlass.dtype.cu8: "cu8", - cutlass.dtype.cu16: "cu16", - cutlass.dtype.cu32: "cu32", - cutlass.dtype.cu64: "cu64", - cutlass.dtype.cs4: "cs4", - cutlass.dtype.cs8: "cs8", - cutlass.dtype.cs16: "cs16", - cutlass.dtype.cs32: "cs32", - cutlass.dtype.cs64: "cs64", -} - -DataTypeTag = { - cutlass.dtype.b1: "cutlass::uint1b_t", - cutlass.dtype.u4: "cutlass::uint4b_t", - cutlass.dtype.u8: "uint8_t", - cutlass.dtype.u16: "uint16_t", - cutlass.dtype.u32: "uint32_t", - cutlass.dtype.u64: "uint64_t", - cutlass.dtype.s4: "cutlass::int4b_t", - cutlass.int8: "int8_t", - cutlass.dtype.s16: "int16_t", - cutlass.int32: "int32_t", - cutlass.dtype.s64: "int64_t", - cutlass.float16: "cutlass::half_t", - cutlass.bfloat16: "cutlass::bfloat16_t", - cutlass.float32: "float", - cutlass.tfloat32: "cutlass::tfloat32_t", - cutlass.float64: "double", - cutlass.dtype.cf16: "cutlass::complex", - cutlass.dtype.cbf16: "cutlass::complex", - cutlass.dtype.cf32: "cutlass::complex", - cutlass.dtype.ctf32: "cutlass::complex", - cutlass.dtype.cf64: "cutlass::complex", - cutlass.dtype.cu4: "cutlass::complex", - cutlass.dtype.cu8: "cutlass::complex", - cutlass.dtype.cu16: "cutlass::complex", - cutlass.dtype.cu32: "cutlass::complex", - cutlass.dtype.cu64: "cutlass::complex", - cutlass.dtype.cs4: "cutlass::complex", - cutlass.dtype.cs8: "cutlass::complex", - cutlass.dtype.cs16: "cutlass::complex", - cutlass.dtype.cs32: "cutlass::complex", - cutlass.dtype.cs64: "cutlass::complex", -} - -DataTypeSize = { - cutlass.dtype.b1: 1, - cutlass.dtype.u4: 4, - cutlass.dtype.u8: 8, - cutlass.dtype.u16: 16, - cutlass.dtype.u32: 32, - cutlass.dtype.u64: 64, - cutlass.dtype.s4: 4, - cutlass.int8: 8, - cutlass.dtype.s16: 16, - cutlass.int32: 32, - cutlass.dtype.s64: 64, - cutlass.float16: 16, - cutlass.bfloat16: 16, - cutlass.float32: 32, - cutlass.tfloat32: 32, - cutlass.float64: 64, - cutlass.dtype.cf16: 32, - cutlass.dtype.cbf16: 32, - cutlass.dtype.cf32: 64, - cutlass.dtype.ctf32: 32, - cutlass.dtype.cf64: 128, - cutlass.dtype.cu4: 8, - cutlass.dtype.cu8: 16, - cutlass.dtype.cu16: 32, - cutlass.dtype.cu32: 64, - cutlass.dtype.cu64: 128, - cutlass.dtype.cs4: 8, - cutlass.dtype.cs8: 16, - cutlass.dtype.cs16: 32, - cutlass.dtype.cs32: 64, - cutlass.dtype.cs64: 128, -} - - -class DataTypeSizeBytes: - """ - Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the - data type key is less than a full byte or a non-integer number of bytes. - """ - @staticmethod - def __class_getitem__(datatype): - """ - Returns the number of bytes in size the data type is. Raises an exception if the data type - is either less than a full byte or a non-integer number of bytes in size. - - :param datatype: data type to query - - :return: number of bytes the data type occupies - :rtype: int - """ - bits = DataTypeSize[datatype] - if bits < 8: - raise Exception('Data type {} is less than one byte in size.'.format(datatype)) - elif bits % 8 != 0: - raise Exception('Data type {} is not an integer number of bytes.'.format(datatype)) - return bits // 8 - -################################################################################################### -# - - -class BlasMode(enum.Enum): - symmetric = enum_auto() - hermitian = enum_auto() - - -# -BlasModeTag = { - BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric', - BlasMode.hermitian: 'cutlass::BlasMode::kHermitian', -} - -# -ComplexTransformTag = { - cutlass.complex_transform.none: 'cutlass::ComplexTransform::kNone', - cutlass.complex_transform.conj: 'cutlass::ComplexTransform::kConjugate', -} - -# -RealComplexBijection = [ - (cutlass.float16, cutlass.dtype.cf16), - (cutlass.float32, cutlass.dtype.cf32), - (cutlass.float64, cutlass.dtype.cf64), -] - -# - - -def is_complex(data_type): - for r, c in RealComplexBijection: - if data_type == c: - return True - return False - -# - - -def get_complex_from_real(real_type): - for r, c in RealComplexBijection: - if real_type == r: - return c - return cutlass.dtype.invalid - -# - - -def get_real_from_complex(complex_type): - for r, c in RealComplexBijection: - if complex_type == c: - return r - return cutlass.dtype.invalid - -# - - -class ComplexMultiplyOp(enum.Enum): - multiply_add = enum_auto() - gaussian = enum_auto() - -################################################################################################### - -# - - -class MathOperation(enum.Enum): - multiply_add = enum_auto() - multiply_add_saturate = enum_auto() - xor_popc = enum_auto() - multiply_add_fast_bf16 = enum_auto() - multiply_add_fast_f16 = enum_auto() - multiply_add_fast_f32 = enum_auto() - multiply_add_complex_fast_f32 = enum_auto() - multiply_add_complex = enum_auto() - multiply_add_complex_gaussian = enum_auto() - - -# -MathOperationNames = { - MathOperation.multiply_add: 'multiply_add', - MathOperation.multiply_add_saturate: 'multiply_add_saturate', - MathOperation.xor_popc: 'xor_popc', - MathOperation.multiply_add_fast_bf16: 'multiply_add_fast_bf16', - MathOperation.multiply_add_fast_f16: 'multiply_add_fast_f16', - MathOperation.multiply_add_fast_f32: 'multiply_add_fast_f32', - MathOperation.multiply_add_complex_fast_f32: 'multiply_add_complex_fast_f32', - MathOperation.multiply_add_complex: 'multiply_add_complex', - MathOperation.multiply_add_complex_gaussian: 'multiply_add_complex_gaussian', -} - -# -MathOperationTag = { - MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', - MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', - MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', - MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', - MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', - MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', - MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', - MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', - MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', -} - -################################################################################################### - -# -LayoutTag = { - cutlass.ColumnMajor: 'cutlass::layout::ColumnMajor', - cutlass.RowMajor: 'cutlass::layout::RowMajor', - cutlass.layout.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', - cutlass.layout.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', - cutlass.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', - cutlass.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', - cutlass.layout.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', - cutlass.layout.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', - cutlass.TensorNHWC: 'cutlass::layout::TensorNHWC', - cutlass.layout.TensorNDHWC: 'cutlass::layout::TensorNDHWC', - cutlass.layout.TensorNCHW: 'cutlass::layout::TensorNCHW', - cutlass.layout.TensorNGHWC: 'cutlass::layout::TensorNGHWC', - cutlass.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', - cutlass.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', - cutlass.layout.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', - cutlass.layout.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', -} - -# -TransposedLayout = { - cutlass.ColumnMajor: cutlass.RowMajor, - cutlass.RowMajor: cutlass.ColumnMajor, - cutlass.layout.ColumnMajorInterleaved2: cutlass.layout.RowMajorInterleaved2, - cutlass.layout.RowMajorInterleaved2: cutlass.layout.ColumnMajorInterleaved2, - cutlass.ColumnMajorInterleaved32: cutlass.RowMajorInterleaved32, - cutlass.RowMajorInterleaved32: cutlass.ColumnMajorInterleaved32, - cutlass.layout.ColumnMajorInterleaved64: cutlass.layout.RowMajorInterleaved64, - cutlass.layout.RowMajorInterleaved64: cutlass.layout.ColumnMajorInterleaved64, - cutlass.TensorNHWC: cutlass.TensorNHWC -} - -# -ShortLayoutTypeNames = { - cutlass.ColumnMajor: 'n', - cutlass.layout.ColumnMajorInterleaved2: 'n2', - cutlass.ColumnMajorInterleaved32: 'n32', - cutlass.layout.ColumnMajorInterleaved64: 'n64', - cutlass.RowMajor: 't', - cutlass.layout.RowMajorInterleaved2: 't2', - cutlass.RowMajorInterleaved32: 't32', - cutlass.layout.RowMajorInterleaved64: 't64', - cutlass.TensorNHWC: 'nhwc', - cutlass.layout.TensorNDHWC: 'ndhwc', - cutlass.layout.TensorNCHW: 'nchw', - cutlass.layout.TensorNGHWC: 'nghwc', - cutlass.TensorNC32HW32: 'nc32hw32', - cutlass.layout.TensorNC64HW64: 'nc64hw64', - cutlass.TensorC32RSK32: 'c32rsk32', - cutlass.layout.TensorC64RSK64: 'c64rsk64' -} - -# -ShortComplexLayoutNames = { - (cutlass.ColumnMajor, cutlass.complex_transform.none): 'n', - (cutlass.ColumnMajor, cutlass.complex_transform.conj): 'c', - (cutlass.RowMajor, cutlass.complex_transform.none): 't', - (cutlass.RowMajor, cutlass.complex_transform.conj): 'h' -} - -# -CuTeLayoutTag = { - cute.GMMAMajor.K: 'cute::GMMA::Major::K', - cute.GMMAMajor.MN: 'cute::GMMA::Major::MN' -} - -################################################################################################### - -# - - -class SideMode(enum.Enum): - Left = enum_auto() - Right = enum_auto() - - -# -SideModeTag = { - SideMode.Left: 'cutlass::SideMode::kLeft', - SideMode.Right: 'cutlass::SideMode::kRight' -} - -# -ShortSideModeNames = { - SideMode.Left: 'ls', - SideMode.Right: 'rs' -} - -################################################################################################### - -# - - -class FillMode(enum.Enum): - Lower = enum_auto() - Upper = enum_auto() - - -# -FillModeTag = { - FillMode.Lower: 'cutlass::FillMode::kLower', - FillMode.Upper: 'cutlass::FillMode::kUpper' -} - -# -ShortFillModeNames = { - FillMode.Lower: 'l', - FillMode.Upper: 'u' -} - -################################################################################################### - -# - - -class DiagType(enum.Enum): - NonUnit = enum_auto() - Unit = enum_auto() - - -# -DiagTypeTag = { - DiagType.NonUnit: 'cutlass::DiagType::kNonUnit', - DiagType.Unit: 'cutlass::DiagType::kUnit' -} - -# -ShortDiagTypeNames = { - DiagType.NonUnit: 'nu', - DiagType.Unit: 'un' -} - -################################################################################################### - -OpcodeClassNames = { - cutlass.OpClass.Simt: 'simt', - cutlass.OpClass.TensorOp: 'tensorop', - cutlass.OpClass.WmmaTensorOp: 'wmma_tensorop', - cutlass.OpClass.SparseTensorOp: 'sptensorop' -} - -OpcodeClassTag = { - cutlass.OpClass.Simt: 'cutlass::arch::OpClassSimt', - cutlass.OpClass.TensorOp: 'cutlass::arch::OpClassTensorOp', - cutlass.OpClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', - cutlass.OpClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp' -} - -################################################################################################### - -# - -class OperationKind(enum.Enum): - Gemm = enum_auto() - RankK = enum_auto() - Rank2K = enum_auto() - Trmm = enum_auto() - Symm = enum_auto() - Conv2d = enum_auto() - Conv3d = enum_auto() - - -# -OperationKindNames = { - OperationKind.Gemm: 'gemm', OperationKind.RankK: 'rank_k', OperationKind.Rank2K: 'rank_2k', OperationKind.Trmm: 'trmm', OperationKind.Symm: 'symm', OperationKind.Conv2d: 'conv2d', OperationKind.Conv3d: 'conv3d' -} - -# -ArchitectureNames = { - 50: 'maxwell', - 60: 'pascal', - 61: 'pascal', - 70: 'volta', - 75: 'turing', - 80: 'ampere', - 90: 'hopper' -} - -# -SharedMemPerCC = { - 70: 96 << 10, # 96KB of SMEM - 72: 96 << 10, # 96KB of SMEM - 75: 64 << 10, # 64KB of SMEM - 80: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver - 86: 100 << 10, # 100KB of SMEM - 87: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver - 89: 100 << 10, # 100KB of SMEM - 90: 227 << 10, # 228KB of SMEM - 1KB reserved for the driver -} - -################################################################################################### - -class GemmKind(enum.Enum): - Gemm = enum_auto() - Sparse = enum_auto() - Universal = enum_auto() - PlanarComplex = enum_auto() - PlanarComplexArray = enum_auto() - Grouped = enum_auto() - - -# -GemmKindNames = { - GemmKind.Gemm: "gemm", - GemmKind.Sparse: "spgemm", - GemmKind.Universal: "gemm", - GemmKind.PlanarComplex: "gemm_planar_complex", - GemmKind.PlanarComplexArray: "gemm_planar_complex_array", - GemmKind.Grouped: "gemm_grouped" -} - -# - - -class RankKKind(enum.Enum): - Universal = enum_auto() - - -# -RankKKindNames = { - RankKKind.Universal: "rank_k" -} - -# - - -class TrmmKind(enum.Enum): - Universal = enum_auto() - - -# -TrmmKindNames = { - TrmmKind.Universal: "trmm" -} - -# - - -class SymmKind(enum.Enum): - Universal = enum_auto() - - -# -SymmKindNames = { - SymmKind.Universal: "symm" -} - -# - - -class SwizzlingFunctor(enum.Enum): - Identity1 = enum_auto() - Identity2 = enum_auto() - Identity4 = enum_auto() - Identity8 = enum_auto() - Horizontal = enum_auto() - BatchedIdentity1 = enum_auto() - StridedDgradIdentity1 = enum_auto() - StridedDgradIdentity4 = enum_auto() - StridedDgradHorizontal = enum_auto() - - -# -SwizzlingFunctorTag = { - cutlass.IdentitySwizzle1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', - SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', - SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', - SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', - SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle', - SwizzlingFunctor.BatchedIdentity1: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle", - SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', - SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', - SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', -} - -# - - -class SchedulerMode(enum.Enum): - Device = enum_auto(), - Host = enum_auto() - - -# -SchedulerModeTag = { - SchedulerMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly', - SchedulerMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute' -} - -# -ShortSchedulerModeNames = { - SchedulerMode.Device: 'Device', - SchedulerMode.Host: 'Host' -} - -################################################################################################### - - -# -ConvKindTag = { - cutlass.conv.Operator.fprop: 'cutlass::conv::Operator::kFprop', - cutlass.conv.Operator.dgrad: 'cutlass::conv::Operator::kDgrad', - cutlass.conv.Operator.wgrad: 'cutlass::conv::Operator::kWgrad' -} - -ConvKindNames = { - cutlass.conv.Operator.fprop: 'fprop', - cutlass.conv.Operator.dgrad: 'dgrad', - cutlass.conv.Operator.wgrad: 'wgrad', -} - - -# -IteratorAlgorithmTag = { - cutlass.conv.IteratorAlgorithm.analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', - cutlass.conv.IteratorAlgorithm.optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', - cutlass.conv.IteratorAlgorithm.fixed_channels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels', - cutlass.conv.IteratorAlgorithm.few_channels: 'cutlass::conv::IteratorAlgorithm::kFewChannels' -} - -IteratorAlgorithmNames = { - cutlass.conv.IteratorAlgorithm.analytic: 'analytic', - cutlass.conv.IteratorAlgorithm.optimized: 'optimized', - cutlass.conv.IteratorAlgorithm.fixed_channels: 'fixed_channels', - cutlass.conv.IteratorAlgorithm.few_channels: 'few_channels' -} - -# - - -class StrideSupport(enum.Enum): - Strided = enum_auto() - Unity = enum_auto() - - -# -StrideSupportTag = { - StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', - StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', -} - -StrideSupportNames = { - StrideSupport.Strided: '', - StrideSupport.Unity: 'unity_stride', -} - - -class ConvMode(enum.Enum): - CrossCorrelation = enum_auto() - Convolution = enum_auto() - - -# -ConvModeTag = { - ConvMode.CrossCorrelation: 'cutlass::conv::Mode::kCrossCorrelation', - ConvMode.Convolution: 'cutlass::conv::Mode::kConvolution' -} - -################################################################################################### - -# - - -class MathInstruction: - """ - Description of a the lowest-level matrix-multiply-accumulate operation to be used in a kernel - """ - def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class=cutlass.OpClass.Simt, math_operation=MathOperation.multiply_add): - """ - :param instruction_shape: size of the [M, N, K] dimensions of the instruction - :type instruction_shape: list or tuple - :param element_a: data type of operand A - :param element_b: data type of operand B - :param element_accumulator: data type used in accumulation - :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core) - :type opcode_class: cutlass.OpClass - :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate) - :type math_operation: MathOperation - """ - self.instruction_shape = instruction_shape - self.element_a = element_a - self.element_b = element_b - self.element_accumulator = element_accumulator - self.opcode_class = opcode_class - self.math_operation = math_operation - -# - - -class TileDescription: - """ - Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes, - stage count, and math instruction specification - """ - def __init__(self, threadblock_shape, stages, warp_count, math_instruction, cluster_shape=[1, 1, 1], persistent=False): - """ - :param threadblock_shape: shape of a threadblock tyle - :type threadblock_shape: list or tuple - :param stages: number of pipeline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum - number of stages that can be supported for an operation on a given architecture will be computed at a later time - :type stages: int or None - :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile - :type warp_count: list, tuple, or None - :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands - :type math_instruction: MathInstruction - :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster - :param persistent: whether the kernel uses persistent warp-specialized threadblocks (only available for SM90+) - :type persistent: bool - """ - self.threadblock_shape = threadblock_shape - self.cluster_shape = cluster_shape - self.persistent: bool = persistent - self.stages: int = stages - - self.math_instruction = math_instruction - - # Number of warps along x, y, z directions - self.warp_count = warp_count - - @property - def num_threads(self): - """ - Returns the number of threads in the threadblock - - :return: number of threads in the threadblock - :rtype: int or None (if warp count is None) - """ - if self.warp_count is not None: - threads = 32 - for cnt in self.warp_count: - threads *= cnt - return threads - return None - - def procedural_name(self): - """ - Returns a name identifying the tile description - - :return: name identifying the tile description - :rtype: int - """ - emit_stages = 0 if self.stages is None else self.stages - name = "%dx%dx%d_%dx%d_%dx%d" % ( - self.cluster_shape[0], self.cluster_shape[1], self.cluster_shape[2], - self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], emit_stages) - - if self.persistent: - name += '_persistent' - return name - -# - - -class TensorDescription: - def __init__(self, element, layout, alignment=1, complex_transform=cutlass.complex_transform.none): - self.element = element - self.layout = layout - self.alignment = min(128 // DataTypeSize[self.element], alignment) - self.complex_transform = complex_transform - -# - - -class SymmetricTensorDescription: - def __init__(self, element, layout, fill_mode, alignment=1, complex_transform=cutlass.complex_transform.none, side_mode=SideMode.Left): - self.element = element - self.layout = layout - self.fill_mode = fill_mode - self.alignment = alignment - self.complex_transform = complex_transform - self.side_mode = side_mode - -# - - -class TriangularTensorDescription: - def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment=1, complex_transform=cutlass.complex_transform.none): - self.element = element - self.layout = layout - self.side_mode = side_mode - self.fill_mode = fill_mode - self.diag_type = diag_type - self.alignment = alignment - self.complex_transform = complex_transform - -################################################################################################### - -# -def CalculateSmemUsagePerStage(operation): - """ - Returns the amount of shared memory in bytes consumed in a single stage of a kernel. - - :param op: operation for which the maximum stages should be computed. If stages are - set via the `op.tile_description.stages` parameter, this setting is ignored - in the present calculation - :type op: pycutlass.Operation - - :return: number of bytes of shared memory consumed by a single stage - :rtype: int - """ - m, n, k = operation.tile_description.threadblock_shape - - if operation.operation_kind == OperationKind.Gemm: - stage_barrier_bytes = 32 - return (DataTypeSize[operation.A.element] * m * k // 8) + \ - (DataTypeSize[operation.B.element] * k * n // 8) + stage_barrier_bytes - else: - raise Exception('Unsupported operation kind {}.'.format(operation.operation_kind)) - - -# -def CalculateSmemUsage(operation): - """ - Returns the amount of shared memory in bytes consumed by a kernel. - - :param op: operation for which the maximum stages should be computed. If stages are - set via the `op.tile_description.stages` parameter, this setting is ignored - in the present calculation - :type op: pycutlass.Operation - - :return: int - """ - return operation.tile_description.stages * CalculateSmemUsagePerStage(operation) - - -class ApiVersion(enum.Enum): - """ - Differentiate between CUTLASS 2.x and 3.x API versions - """ - v2x = enum_auto() - v3x = enum_auto() - - -def api_version(arch, opclass, datatype): - """ - Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x - or 3.x for code emission. - - :param arch: compute capability of device on which to run - :type arch: int - :param opclass: class of the operation being performed - :type opclass: cutlass.OpClass - :param datatype: data type to be used in operation (assumes that ElementA and ElementB are the same) - - :return: API version to be used in code emission - :rtype: ApiVersion - """ - if arch >= 90 and opclass == cutlass.OpClass.TensorOp and (datatype != cutlass.float64): - return ApiVersion.v3x - else: - return ApiVersion.v2x - -################################################################################################### diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py b/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py deleted file mode 100644 index dacdc43a..00000000 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from pycutlass.test.profiler import * -from pycutlass.test.conv2d_testbed import * -from pycutlass.test.gemm_testbed import * -from pycutlass.test.gemm_grouped_testbed import * diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py deleted file mode 100644 index 43f2cee5..00000000 --- a/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +++ /dev/null @@ -1,632 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import pycutlass -from pycutlass import * -from pycutlass.test import * -from time import sleep -from bfloat16 import bfloat16 -import subprocess -from typeguard import typechecked -import re - - - -def getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand): - ptr = tensor.__array_interface__['data'][0] - if operand == "a": - tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent(conv_kind, problem_size) - elif operand == "b": - tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent(conv_kind, problem_size) - elif operand in ["c", "d"]: - tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent(conv_kind, problem_size) - else: - raise ValueError("unknown operand: " + operand) - - layout = tensor_layout.packed(tensor_coord) - - if tensor.dtype == np.float64: - return cutlass.TensorRefF64NHWC(ptr, layout) - elif tensor.dtype == np.float32: - return cutlass.TensorRefF32NHWC(ptr, layout) - elif tensor.dtype == np.float16: - return cutlass.TensorRefF16NHWC(ptr, layout) - if tensor.dtype == bfloat16: - return cutlass.TensorRefBF16NHWC(ptr, layout) - elif tensor.dtype == np.int32: - return cutlass.TensorRefS32NHWC(ptr, layout) - elif tensor.dtype == np.int8: - if tensor_layout == cutlass.TensorNC32HW32: - return cutlass.TensorRefS8NC32HW32(ptr, layout) - elif tensor_layout == cutlass.TensorC32RSK32: - return cutlass.TensorRefS8C32RSK32(ptr, layout) - else: - return cutlass.TensorRefS8NHWC(ptr, layout) - else: - raise ValueError("unsupported data type") - -def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand): - tensor_ref = getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand) - - if operand == "a": - tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent(conv_kind, problem_size) - elif operand == "b": - tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent(conv_kind, problem_size) - elif operand in ["c", "d"]: - tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent(conv_kind, problem_size) - else: - raise ValueError("unknown operand: " + operand) - - if tensor.dtype == np.float64: - return cutlass.TensorViewF64NHWC(tensor_ref, tensor_coord) - elif tensor.dtype == np.float32: - return cutlass.TensorViewF32NHWC(tensor_ref, tensor_coord) - elif tensor.dtype == np.float16: - return cutlass.TensorViewF16NHWC(tensor_ref, tensor_coord) - elif tensor.dtype == bfloat16: - return cutlass.TensorViewBF16NHWC(tensor_ref, tensor_coord) - elif tensor.dtype == np.int32: - return cutlass.TensorViewS32NHWC(tensor_ref, tensor_coord) - elif tensor.dtype == np.int8: - if tensor_layout == cutlass.TensorNC32HW32: - return cutlass.TensorViewS8NC32HW32(tensor_ref, tensor_coord) - elif tensor_layout == cutlass.TensorC32RSK32: - return cutlass.TensorViewS8C32RSK32(tensor_ref, tensor_coord) - else: - return cutlass.TensorViewS8NHWC(tensor_ref, tensor_coord) - - else: - raise ValueError("unsupported data type") - - - -# @typechecked -class Conv2dLauncher: - """ - Launcher that runs the operation on given problem size - """ - def __init__(self, operation: 'Conv2dOperation', seed: int=2080, interleaved=False, - verification=True, profiling=False, warmup_iterations=500, iterations=500, **kwargs) -> None: - - self.enable_cached_results = True - self.interleaved = interleaved - - # create the reduction kernel - self.reduction_operation = ReductionOperation( - shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment), - C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator, - element_compute=operation.epilogue_functor.element_epilogue, epilogue_functor=operation.epilogue_functor, - count=operation.C.alignment - ) - - #: verify the output result - self.verification = verification - #: profile the kernel's runtime - self.profiling = profiling - - self.timer = GpuTimer() - - self.warmup_iterations = warmup_iterations - self.iterations = iterations - - if "sleep" in kwargs.keys(): - self.sleep_time = kwargs["sleep"] - else: - self.sleep_time = 0 - - # - # Compile the operator - # - - pycutlass.compiler.add_module([operation, self.reduction_operation]) - - self.operation = operation - - self.dtype_A = Conv2dLauncher.numpy_type(operation.A.element) - self.layout_A = operation.A.layout - self.dtype_B = Conv2dLauncher.numpy_type(operation.B.element) - self.layout_B = operation.B.layout - self.dtype_C = Conv2dLauncher.numpy_type(operation.C.element) - self.layout_C = operation.C.layout - self.dtype_D = Conv2dLauncher.numpy_type(operation.C.element) - self.layout_D = operation.C.layout - - accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator] - element_size = DataTypeSize[operation.A.element] - - if element_size <= 8: - self.scope = 1 - elif element_size == 16: - if accumulator_size <= 16: - self.scope = 2 - else: - self.scope = 4 - else: - self.scope = 7 - - # Seed - self.seed = seed - - self.conv_kind = operation.conv_kind - - - # - # Get the host reference function - # - - self.element_compute = operation.epilogue_functor.element_epilogue - - self.host_conv2d = cutlass.test.conv.host.conv2d - - self.timer = GpuTimer() - - @staticmethod - def numpy_type(type): - if type == cutlass.float64: - return np.float64 - elif type == cutlass.float32: - return np.float32 - elif type == cutlass.float16: - return np.float16 - elif type == cutlass.bfloat16: - return bfloat16 - elif type == cutlass.int32: - return np.int32 - elif type == cutlass.int8: - return np.int8 - else: - raise ValueError("unsupported type: %s" % ShortDataTypeNames[type]) - - def print_problem_size(self, p, split_k_mode=1): - print("nhwc_%dx%dx%dx%d_krsc_%dx%dx%dx%d_padding_%dx%d_stride_%dx%d_dilation_%dx%d_splitkslices_%d_splitkmode_%d" - % (p.N, p.H, p.W, p.C, p.K, p.R, p.S, p.C, p.pad_h, - p.pad_w, p.stride_h, p.stride_w, p.dilation_h, p.dilation_w, p.split_k_slices, split_k_mode)) - - def uniform_init(self, size, dtype): - if dtype in [np.float32, np.float16, bfloat16, np.float64]: - return np.ceil( - np.random.uniform( - low=-self.scope - 0.5, high=self.scope - 0.5, - size=size).astype(dtype) - ) - else: - return np.random.uniform( - low=-self.scope - 1, high=self.scope + 1, - size=size).astype(dtype) - - def eq_gemm_size(self, problem_size): - n = problem_size.N - p = problem_size.P - q = problem_size.Q - k = problem_size.K - r = problem_size.R - s = problem_size.S - c = problem_size.C - h = problem_size.H - w = problem_size.W - if self.conv_kind == cutlass.conv.Operator.fprop: - return cutlass.gemm.GemmCoord(n * p * q, k, r * s * c) - elif self.conv_kind == cutlass.conv.Operator.dgrad: - return cutlass.gemm.GemmCoord(n * h * w, c, k * r * s) - else: - return cutlass.gemm.GemmCoord(k, r * s * c, n * p * q) - - def bytes(self, problem_size, alpha, beta): - mnk = self.eq_gemm_size(problem_size) - - bytes_ = \ - (DataTypeSize[self.operation.A.element] * mnk.m() // 8) * mnk.k() + \ - (DataTypeSize[self.operation.B.element] * mnk.n() // 8) * mnk.k() + \ - (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n() - - if beta != 0: - bytes_ += (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n() - - return bytes_ - - def flops(self, problem_size): - mnk = self.eq_gemm_size(problem_size) - - flops_mainloop_ = mnk.m() * mnk.n() * mnk.k() * 2 - flops_epilogue_ = mnk.m() * mnk.n() * 2 - - # Adjust mainloop flop for dgrad stride - if self.conv_kind == cutlass.conv.Operator.dgrad: - flops_mainloop_ = flops_mainloop_ // (problem_size.stride_h * problem_size.stride_w) - - flops_total_ = flops_mainloop_ + flops_epilogue_ - - return flops_total_ - - - - def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): - if self.element_compute == cutlass.float16: - alpha = cutlass.float16(alpha) - beta = cutlass.float16(beta) - elif self.element_compute == cutlass.int32: - alpha = int(alpha) - beta = int(beta) - else: - alpha = alpha - beta = beta - - # if cached result is loaded - cached_result_loaded = False - - if self.enable_cached_results: - # get problem key - cached_test_key = cutlass.test.conv.host.CreateCachedConv2dTestKey( - self.conv_kind, problem_size, alpha, beta, - getTensorView(tensor_A, self.layout_A, self.conv_kind, problem_size, "a"), - getTensorView(tensor_B, self.layout_B, self.conv_kind, problem_size, "b"), - getTensorView(tensor_C, self.layout_C, self.conv_kind, problem_size, "c"), - ) - - cached_test_result = cutlass.test.conv.host.CachedTestResult() - - conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % (self.operation.arch, self.seed) - - cached_results = cutlass.test.conv.host.CachedTestResultListing(conv2d_result_cache_name) - # CachedTestResultListing cached_results(conv2d_result_cache_name); - cached = cached_results.find(cached_test_key) - cached_result_loaded = cached[0] - if cached_result_loaded : - cached_test_result = cached[1] - - if not cached_result_loaded: - # compute the conv2d on host - tensor_D_ref = np.ones_like(tensor_C) - tensor_ref_A = getTensorRef(tensor_A, self.layout_A, self.conv_kind, problem_size, "a") - tensor_ref_B = getTensorRef(tensor_B, self.layout_B, self.conv_kind, problem_size, "b") - tensor_ref_C = getTensorRef(tensor_C, self.layout_C, self.conv_kind, problem_size, "c") - tensor_ref_D_ref = getTensorRef(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d") - - self.host_conv2d( - self.conv_kind, problem_size, - tensor_ref_A, tensor_ref_B, tensor_ref_C, tensor_ref_D_ref, - alpha, beta - ) - - tensor_view_D_ref = getTensorView(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d") - - if self.enable_cached_results: - cached_test_result.D = cutlass.test.conv.host.TensorHash(tensor_view_D_ref) - cached_results = cutlass.test.conv.host.CachedTestResultListing(conv2d_result_cache_name) - cached_results.append(cached_test_key, cached_test_result) - cached_results.write(conv2d_result_cache_name) - else: - return tensor_D_ref - - return cached_test_result.D - - def equal(self, tensor_D, tensor_D_ref, problem_size): - if self.enable_cached_results: - tensor_view_D = getTensorView(tensor_D, self.layout_D, self.conv_kind, problem_size, "d") - tensor_D_hash = cutlass.test.conv.host.TensorHash(tensor_view_D) - - return tensor_D_hash == tensor_D_ref - else: - tensor_view_D = getTensorView(tensor_D, self.layout_D, self.conv_kind, problem_size, "d") - tensor_view_D_ref = getTensorView(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d") - return cutlass.test.conv.host.equals(tensor_view_D, tensor_view_D_ref) - - def run_cutlass_profiler(self, problem_size, split_k_mode=cutlass.conv.SplitKMode.Serial, alpha=1.0, beta=0.0): - - if split_k_mode == cutlass.conv.SplitKMode.Serial: - split_k_mode_ = "serial" - else: - split_k_mode_ = "parallel" - - cutlass_path = os.getenv('CUTLASS_PATH') - assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." - - values = { - "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler", - "kernel_name": self.operation.procedural_name(), - "verification_providers": "device", - "provider": "cutlass", - 'n': str(problem_size.N), - 'h': str(problem_size.H), - 'w': str(problem_size.W), - 'c': str(problem_size.C), - 'k': str(problem_size.K), - 'r': str(problem_size.R), - 's': str(problem_size.S), - 'p': str(problem_size.P), - 'q': str(problem_size.Q), - 'pad_h': str(problem_size.pad_h), - 'pad_w': str(problem_size.pad_w), - 'stride_h': str(problem_size.stride_h), - 'stride_w': str(problem_size.stride_w), - 'dilation_h': str(problem_size.dilation_h), - 'dilation_w': str(problem_size.dilation_w), - 'split_k_slices': str(problem_size.split_k_slices), - 'split_k_mode': split_k_mode_, - 'alpha': str(alpha), - 'beta': str(beta), - 'warmup': str(self.warmup_iterations), - 'profile': str(self.iterations) - } - - cmd_template = \ - "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" \ - " --providers=${provider} --n=${n} --h=${h} --w=${w} --c=${c} --k=${k} --r=${r} --s=${s} --p=${p}" \ - " --q=${q} --pad_h=${pad_h} --pad_w=${pad_w} --stride_h={stride_h} --stride_w=${stride_w}" \ - " --dilation_h=${dilation_h} --dilation_w=${dilation_w} --warmup-iterations=${warmup} --profiling-iterations=${profile}" \ - " --split_k_slices=${split_k_slices} --alpha=${alpha} --beta=${beta} --split_k_mode=${split_k_mode}" - - cmd = SubstituteTemplate(cmd_template, values) - result = subprocess.getoutput(cmd) - - m = re.search(r"Runtime:\s+(?P\d+.\d+)", result) - runtime = float(m.group('runtime')) - - m = re.search(r"Bytes:\s+(?P\d+)", result) - bytes = int(m.group('bytes')) - - m = re.search(r"FLOPs:\s+(?P\d+)", result) - flops = int(m.group('flops')) - - # check if the problem size matches - assert bytes == self.bytes(problem_size, alpha, beta) - assert flops == self.flops(problem_size) - - return runtime - - - - def run(self, problem_size, split_k_mode=cutlass.conv.SplitKMode.Serial, - alpha=1.0, beta=0.0): - - assert get_allocated_size() == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size() - - # - # Initialize input and output tensors - # - tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(self.conv_kind, problem_size) - tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(self.conv_kind, problem_size) - tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(self.conv_kind, problem_size) - - np.random.seed(self.seed) - - tensor_A = self.uniform_init(size=(tensor_A_size,), dtype=self.dtype_A) - tensor_B = self.uniform_init(size=(tensor_B_size,), dtype=self.dtype_B) - tensor_C = self.uniform_init(size=(tensor_C_size,), dtype=self.dtype_C) - tensor_D = np.zeros(shape=(tensor_C_size,), dtype=self.dtype_D) - - - # - # Launch kernel - # - - arguments = Conv2dArguments( - operation=self.operation, problem_size=problem_size, A=tensor_A, - B=tensor_B, C=tensor_C, D=tensor_D, - output_op = self.operation.epilogue_type(alpha, beta), - split_k_slices=problem_size.split_k_slices, - split_k_mode=split_k_mode - ) - - if split_k_mode == cutlass.conv.SplitKMode.Parallel: - implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(self.operation.conv_kind, arguments.problem_size) - reduction_arguments = ReductionArguments( - self.reduction_operation, - problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], partitions=problem_size.split_k_slices, - workspace=arguments.ptr_D, - destination=tensor_D, - source=tensor_C, - output_op = self.reduction_operation.epilogue_type(alpha, beta) - ) - - self.operation.run(arguments) - if split_k_mode == cutlass.conv.SplitKMode.Parallel: - self.reduction_operation.run(reduction_arguments) - - passed = True - if self.verification: - if split_k_mode == cutlass.conv.SplitKMode.Parallel: - reduction_arguments.sync() - else: - arguments.sync() - - tensor_D_ref = self.host_reference(problem_size, tensor_A, tensor_B, tensor_C, alpha, beta) - - passed = self.equal(tensor_D, tensor_D_ref, problem_size) - - try: - assert passed - except AssertionError: - self.print_problem_size(problem_size, split_k_mode) - - if self.profiling: - sleep(self.sleep_time) - for _ in range(self.warmup_iterations): - self.operation.run(arguments) - if split_k_mode == cutlass.conv.SplitKMode.Parallel: - self.reduction_operation.run(reduction_arguments) - - self.timer.start() - for _ in range(self.warmup_iterations): - self.operation.run(arguments) - if split_k_mode == cutlass.conv.SplitKMode.Parallel: - self.reduction_operation.run(reduction_arguments) - self.timer.stop_and_wait() - runtime = self.timer.duration(self.iterations) - - # free memory - del arguments - if split_k_mode == cutlass.conv.SplitKMode.Parallel: - del reduction_arguments - - assert get_allocated_size() == 0, "%d byte of pool memory is not released after current run" % get_allocated_size() - if self.profiling: - return runtime - return passed - - - -######################################################################################################## -# TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference -# TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes -# Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes -# (conv_blacklist_sizes) -############################################################################################################ - -def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes = [], interleaved=False): - passed = True - # - # Testbed object - # - - testbed = Conv2dLauncher(operation, interleaved=interleaved) - - # - # Get conv problem sizes to run conv operator - # - - conv_problems = cutlass.test.conv.TestbedConv2dProblemSizes(64) - - # Vector of conv2d problem sizes to avoid duplicate runs - conv_tested_sizes = [] - - # Flatten 2D problem_vectors into a 1D problem sizes - problem_sizes = conv_problems.conv2d_default_sizes - - problem_sizes = [conv_problem for conv_problem in problem_sizes] + conv_test_sizes - - # Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slices=1, alpha=1.0, beta=0.0) - for conv_problem in problem_sizes: - - if conv_problem in conv_tested_sizes: - continue - - # skip channel dimension % 32 != 0 for interleaved case - if interleaved: - if conv_problem.K % 32 != 0 or conv_problem.C % 32 != 0: - continue - - # - # Procedurally disable certain cases - # - - # CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Unity: - if not ((conv_problem.stride_h == 1) and (conv_problem.stride_w == 1)): - continue - - if not interleaved: - # Fixed channels algorithm requires channel count to match access size - if operation.iterator_algorithm == cutlass.conv.IteratorAlgorithm.fixed_channels: - if conv_problem.C != operation.A.alignment: - continue - - # Few channels algorithm requires channel count to match access size - if operation.iterator_algorithm == cutlass.conv.IteratorAlgorithm.few_channels: - if conv_problem.C % operation.A.alignment: - continue - - # CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} - # Although strided dgrad works for all stride combinations, we are only going - # to run strided dgrad for non-unity strides - - if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Strided: - if (conv_problem.stride_h == 1) and (conv_problem.stride_w == 1): - continue - - # - # Test - # - - # push back tested problem size to avoid re-running duplicates - conv_tested_sizes.append(conv_problem) - - passed = testbed.run(conv_problem) - - if not passed: - return False - - if interleaved: - return True - # - # filter the cases for split K - # - - # Small-channels convolution can't run here. - if operation.iterator_algorithm in [cutlass.conv.IteratorAlgorithm.fixed_channels, cutlass.conv.IteratorAlgorithm.few_channels]: - return True - - # CUTLASS DGRAD's *stride* specialization does not support split-k mode - if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Strided: - conv_problem = cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 56, 56, 8), - cutlass.Tensor4DCoord(8, 1, 1, 8), - cutlass.Tensor4DCoord(0, 0, 0, 0), - cutlass.MatrixCoord(2, 2), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, - 1, 1 - ) - passed = testbed.run(conv_problem) - - return passed - - # Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for - # a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters - # which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep - # alpha and beta for local testing, but only runs one value for alpha and beta. - - conv2d_split_k_test_size = cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(1, 17, 11, 288), - cutlass.Tensor4DCoord(160, 3, 3, 288), - cutlass.Tensor4DCoord(1, 1, 1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.MatrixCoord(1, 1), - cutlass.conv.Mode.cross_correlation, - 1, 1 - ) - - split_k_modes = [cutlass.conv.SplitKMode.Parallel, cutlass.conv.SplitKMode.Serial] - - split_k_slices = [1, 2, 3, 4, 201] - problem_alpha = [2.0,] - problem_beta = [2.0,] - - for split_k_mode in split_k_modes: - for split_k_slice in split_k_slices: - for alpha in problem_alpha: - for beta in problem_beta: - passed = testbed.run(conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), - split_k_mode, - alpha, beta) - - return passed diff --git a/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py b/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py deleted file mode 100644 index 70e44b18..00000000 --- a/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from pycutlass.utils.reference_model import * diff --git a/tools/library/scripts/pycutlass/test/conv/cached_results_SM80.txt b/tools/library/scripts/pycutlass/test/conv/cached_results_SM80.txt deleted file mode 100644 index 91cbe531..00000000 --- a/tools/library/scripts/pycutlass/test/conv/cached_results_SM80.txt +++ /dev/null @@ -1,274 +0,0 @@ -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1767700736 2104699940 3506659864 557648934 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1539314507 3971227455 1976927351 1642148785 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 276489656 653235219 3147305346 880610205 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 272457724 2178229139 2786201726 4170295839 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 242235041 2149454506 784935854 682531065 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 3478189705 1667216236 1437761176 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 379326961 1780379994 3740415776 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 924848818 3533854396 2683779476 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2870331951 359232443 2147867990 1653277018 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2870331951 3784314846 2644315999 4224154526 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3787448414 3562991793 535073859 2563373454 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 426169840 2464808416 864648234 461884698 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2564934525 3910792915 3577331017 827498183 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 28479234 867695528 1947311971 83328334 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4192922822 4244595864 2296602326 2349214706 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 274678245 3464152269 1682550229 3446204619 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3993280136 828543035 1319748516 956044554 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 832003025 3799813757 4030292245 457791957 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1444316594 4129865888 93616503 412257611 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2931873718 1841508064 1497852219 36703874 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2931873718 1841508064 1497852219 1842147148 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1612565294 109894479 1782187316 3370789453 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 841569299 1010785577 1158956167 3261208135 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1893352157 48149942 3544807462 446577726 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 3585320147 2150950452 1625817025 3964129474 -conv2d dgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 289918791 2624928614 3423533117 3186342135 -conv2d dgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 2732296888 1838622641 4203745561 -conv2d dgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 3456572634 893492926 1966259884 -conv2d dgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 671982235 4014726279 4027869577 1510990157 -conv2d dgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 798317794 4140605332 3580988556 3425909428 -conv2d dgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1721270411 2106553169 835800311 3417471222 -conv2d dgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 860217059 166776702 1109666471 -conv2d dgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2128738105 855244826 2670006594 3857976152 -conv2d dgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1931093565 3079461262 3579256638 2926210806 -conv2d dgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2472246681 2952423142 2045838875 3445165841 -conv2d dgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2956871200 2133381336 2601441527 2035094220 -conv2d dgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 365467186 1700915522 2515933441 406719240 -conv2d dgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 3347784734 156533442 1012781676 688128904 -conv2d dgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 927718585 3117803557 1370701307 1462167731 -conv2d dgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 973422497 1926250028 3440543762 -conv2d dgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 2892862516 3649300762 1521470286 -conv2d dgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2075083065 3181416651 1733426984 872275640 -conv2d dgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4005590448 1639170045 388151578 4186957447 -conv2d dgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 181075276 1433744686 860506550 3475157408 -conv2d dgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1747719409 877465841 2345541783 -conv2d dgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 856324887 2307248012 337386755 3363072703 -conv2d dgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1906605830 722034901 2562804622 2508759317 -conv2d dgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 805717279 2196645331 3235235362 1518334120 -conv2d dgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3168796339 72559978 778918419 1260968000 -conv2d dgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 261954979 2634885882 451986822 3792829599 -conv2d dgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_fnhwc_f_f 3747142491 2426759809 2622222681 371723930 -conv2d dgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2056905385 3612826298 2531545294 476754549 -conv2d dgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 2391975923 197605094 3409942185 -conv2d dgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 3071904063 408984565 2378809888 -conv2d dgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3414629540 3067676760 1540919649 2008865071 -conv2d dgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4100326666 1085505037 2778215386 230227569 -conv2d dgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3662895757 2731079464 3570839563 3483629877 -conv2d dgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 408419601 3415600242 2106927195 -conv2d dgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2154102133 3606099389 4034802752 3200055633 -conv2d dgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2609259399 3910244699 1319285699 2229775542 -conv2d dgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2948772873 2780071616 2703730845 3090625734 -conv2d dgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 752289976 4278696824 360883914 3802692600 -conv2d dgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3723912751 653419877 359675571 283806385 -conv2d dgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 2027599472 1075980921 3101013494 2025203940 -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 991402150 1393431534 1148212814 1350914659 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 4283492776 419570292 1210341563 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4178596783 3828059710 2735749436 2671012171 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 924522595 563724475 3750778972 4152580670 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1021044158 1686067905 3765040166 4102272733 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 2674994719 635224486 2759329777 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 4201252830 2920298728 304256151 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 70289262 646435722 4137562540 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1317457392 1288095320 2132879813 656196754 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1317457392 2202157489 2326567490 2475188414 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2476454437 1857118302 4164386062 239840568 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2767650699 3514840131 590439733 3879821123 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3896287283 3112762669 2515107934 2106635937 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1903067870 1021832870 3003938078 2751931686 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3489785028 2466126497 1374078692 2737628040 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2051350923 263676708 3639860119 1370886256 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 719099834 1474713672 204857540 2768940347 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3441724486 3162593831 421721594 3097845598 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2034354027 1249407570 2567025479 1441082595 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 941893937 3608468045 635631428 2369653089 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 941893937 3608468045 635631428 1218705038 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 172579142 319546523 718795680 1453661415 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2823351660 1326352711 1110204809 1155441703 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3238446487 2572503545 686287700 1559476701 -conv2d fprop_1x8x8x1_4x4_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 991402150 1883874274 1180207512 3934800419 -conv2d fprop_1x16x16x1_8x8_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 24290453 4230587034 4117433929 2540623821 -conv2d fprop_1x16x16x1_12x12_16x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 24290453 3802993432 1563447158 515257167 -conv2d fprop_1x224x224x1_220x220_32x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 7656882 2583340103 3928463259 1564251818 -conv2d fprop_1x224x224x1_110x110_64x7x7_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 7656882 2966178620 3457283045 1726663817 -conv2d fprop_1x224x224x1_222x222_64x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 7656882 1794561978 3101289788 3492498648 -conv2d fprop_1x224x224x1_111x111_64x5x5_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 7656882 1794561978 498358130 4111289929 -conv2d fprop_1x8x8x2_4x4_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2693144988 3876248534 3038023830 1910263513 -conv2d fprop_1x16x16x2_8x8_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 3355193355 319259163 535683577 -conv2d fprop_1x16x16x2_12x12_16x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 1548147432 3385829172 2741952709 -conv2d fprop_1x224x224x2_220x220_32x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3479872296 2686562907 3948710179 3669872932 -conv2d fprop_1x224x224x2_110x110_64x7x7_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3479872296 576815792 2317227037 1211532666 -conv2d fprop_1x224x224x2_222x222_64x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3479872296 27596985 555460201 895685163 -conv2d fprop_1x224x224x2_111x111_64x5x5_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3479872296 27596985 1465341652 2228916523 -conv2d fprop_1x8x8x4_4x4_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 24290453 137535877 1436667267 1395660627 -conv2d fprop_1x224x224x4_220x220_32x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2495921302 2226159049 4051661898 209529384 -conv2d fprop_1x224x224x4_110x110_64x7x7_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2495921302 3541851870 2271016226 2671623385 -conv2d fprop_1x224x224x4_222x222_64x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2495921302 982184919 2007343215 3362992769 -conv2d fprop_1x224x224x4_111x111_64x5x5_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2495921302 982184919 20610297 1086800078 -conv2d fprop_1x8x8x8_4x4_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 3117444553 1497663382 3561001103 -conv2d fprop_1x224x224x8_220x220_32x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3188907679 1414143072 827338392 2827855918 -conv2d fprop_1x224x224x8_110x110_64x7x7_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3188907679 3886996022 26545788 3407771964 -conv2d fprop_1x224x224x8_222x222_64x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3188907679 380272816 2374613655 3601677176 -conv2d fprop_1x224x224x8_111x111_64x5x5_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3188907679 380272816 778374730 2110111988 -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1736512560 49406874 846358010 3314905564 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1848484956 1432417472 1903569827 3750799351 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4236427320 3696009469 69852620 201921851 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 109006944 450017448 1793784844 903209915 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 813367872 2397796503 1928191746 3210229460 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 1307184141 46021356 1674017987 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 1212511562 3331767121 2446286369 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 2013675943 1681111033 1469213228 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1703349794 500298386 3218034344 4159283207 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1703349794 1123534155 145385311 4273847179 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3862659311 349459322 1503631520 1404971956 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1623686755 961217371 552550209 3980749384 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3554927580 1131648083 4149599295 3119557776 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1767639287 3350675774 128324027 1059816532 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3986143536 17411088 40173029 1694092310 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1157793540 3513299281 48848814 1435528367 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 988962069 4292634763 388976034 2674929544 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4202383208 3529769234 1046186503 3368902675 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 856448884 3057259762 2063087558 1995545427 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2281940872 144496548 2455451862 400986166 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2281940872 144496548 2455451862 1082696406 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2702905851 1992889713 731289041 608504198 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2742293143 4197915274 606840 3671124731 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 149434841 2288560511 2994968424 2881838300 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 2226824643 327135318 3718671210 2121176659 -conv2d fprop_1x4x4x12_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3254575292 1119957081 672831271 -conv2d fprop_1x4x4x14_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3115523958 3622905002 4020453928 3853387318 -conv2d fprop_1x23x56x98_10x22_128x3x3_pad_h4w5_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1702870033 1876930844 1190400523 3937287850 -conv2d fprop_1x4x4x28_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 2587856937 2021107274 2789519899 -conv2d fprop_1x23x56x100_10x22_128x3x3_pad_h4w5_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2368669977 1353376771 744357395 786349633 -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 991402150 1393431534 2496492611 3901723984 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4208297221 4283492776 3148637036 258220505 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4178596783 3828059710 281106520 1103939403 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 924522595 563724475 1938163814 2197809394 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1021044158 1686067905 350851834 3999808950 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3335547 2674994719 1034822169 1611033520 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3335547 4201252830 1597212204 2181492560 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3335547 70289262 3001492060 1379239000 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1317457392 1288095320 4211138051 2804617605 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1317457392 2202157489 1043108884 2923122465 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2476454437 1857118302 3877008798 1206012078 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2767650699 3514840131 2946529611 3907056932 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3896287283 3112762669 1581171257 3959460786 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1903067870 1021832870 1926804094 1756790353 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3489785028 2466126497 1712378956 434322965 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2051350923 263676708 355203300 821870356 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 719099834 1474713672 2886387159 4086314983 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3441724486 3162593831 1422796372 2049419539 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2034354027 1249407570 1196036582 2684312264 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 941893937 3608468045 2198911423 1060050551 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 941893937 3608468045 2198911423 3361618746 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 172579142 319546523 2332616929 543467298 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2823351660 1326352711 3839068434 65031397 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3238446487 2572503545 3604065639 2111204111 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_fnhwc_f_f 2149247508 1775375365 2663631601 1249487679 -conv2d fprop_1x4x4x12_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 403997062 1679063623 4062928786 -conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 3464637181 1623218578 436154205 -conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 1479940693 3253144559 3883419107 -conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 1871463331 2425320272 74566211 -conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3484040069 664160900 3610888033 22347127 -conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1924855848 1382111427 2541177413 -conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 868180534 1764715518 3070473696 2392864704 -conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3437976747 666906244 3401957738 2050602745 -conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4195072693 1575210381 781892324 2848949054 -conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3457330201 2316839359 1539389419 4293781748 -conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 2469024119 2885305868 2693098375 -conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 2469024119 2885305868 1969608051 -conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1690216859 554790212 2885143346 780489333 -conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3184127693 835105643 3337423971 3866137775 -conv2d dgrad_1x4x4x12_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2956180805 1092015789 3160693693 1526395881 -conv2d dgrad_1x56x56x12_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3184127693 1941683430 2236679600 3168985259 -conv2d dgrad_1x55x55x12_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3184127693 1941683430 3784328837 471971363 -conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 289918791 1266976707 942688231 3457364823 -conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 1027662440 2005082293 2235558527 -conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 3380032042 1370040310 1348846927 -conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 671982235 1423304149 2107662762 1234913781 -conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 798317794 1709026638 2421185623 3308071321 -conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1721270411 2519327328 2541413264 3185574975 -conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 2070174510 1364436192 3531942595 -conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2128738105 2056902987 3079166829 2329433528 -conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3857917762 3227877956 645422556 -conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3857917762 3817218800 985231315 -conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 1398036015 3630062764 2492522537 -conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2784049299 643733019 3649549642 2637869234 -conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 2332160299 302086821 3303132343 -conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1931093565 2458714707 2919710256 2311575036 -conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2472246681 2260022344 500095455 2760458995 -conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1530672622 3635363851 2402907878 4131497953 -conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1500864134 2536338700 2459524764 2504484273 -conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3344871528 2667385029 2714805835 3487838445 -conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 966721255 1547169349 3198573835 302049294 -conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 2440004820 1576818970 1317923157 -conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 2440004820 1576818970 3186679687 -conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4028893260 4220759192 2236533218 3731336532 -conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2956871200 1591352238 1756650151 1262787222 -conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 365467186 892422645 1334708242 1372556938 -conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 3347784734 150035460 2897171548 3701081496 -conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 927718585 4106152802 2634710231 744755886 -conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 3464637181 2709881923 2407415563 -conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 3723472741 3733128758 3129111191 -conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2075083065 2042513140 253288229 404121198 -conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4005590448 1116254439 525487530 3284739065 -conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 181075276 1743485155 91136873 2508716910 -conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 386662952 1127709182 4026285141 -conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 856324887 3954249564 2591894666 2655687700 -conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 1300426008 1263618595 1313664339 -conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 1300426008 1756414462 2995557277 -conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 447261065 121940906 1497499264 -conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3484040069 2966693627 1423016429 341928547 -conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1759979610 2761559427 68093525 -conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1906605830 2980501720 1650970502 3258883197 -conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 805717279 3502822733 3985958544 2568949300 -conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 868180534 3289288595 385631111 328914986 -conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3437976747 3391080565 1513955316 1521294163 -conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4195072693 1669352457 2608107448 4284090805 -conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3457330201 1126870455 106232038 3054809396 -conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 1723074453 1186911503 4239438967 -conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 1723074453 1186911503 2113601884 -conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1690216859 2413490039 36034283 1112346965 -conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3168796339 1601750164 14375779 2894970748 -conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 261954979 1300976652 4259930640 305685205 -conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_fnhwc_f_f 3747142491 1747587481 4137156526 1174257270 -conv2d wgrad_1x4x4x12_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2956180805 1086820986 1644914756 2013471312 -conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2056905385 447674669 724481645 1457430910 -conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 1227883689 3401425854 3897766524 -conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 3749787834 3350064812 1136116240 -conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3414629540 820341033 770836461 2451581199 -conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4100326666 2581696511 1088458082 1521190911 -conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3662895757 2885454895 935600441 2615245898 -conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 3831334389 3506139121 814982501 -conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2154102133 737968461 1291834254 2665225480 -conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3573498719 1809195644 1765637461 -conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3573498719 3379808294 483095299 -conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 4194153035 2863868771 1639389008 -conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2624318208 157618421 1779474147 814087242 -conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 2300180628 423968553 3890279569 -conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2609259399 1848932917 522753581 1926508271 -conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2948772873 3663040534 4014266327 1288646188 -conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3271403719 1585195072 1487505772 3253374264 -conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1419588777 451194147 3578359696 3659768981 -conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 763924990 2780826684 2883769406 148530958 -conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2578426561 3849874822 102765469 1305171059 -conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 1995451256 2632815435 1516344656 -conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 1995451256 2632815435 1586331550 -conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2462511240 2274021368 1188866747 3178890497 -conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 752289976 1226457131 4187777346 1400559240 -conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3723912751 1585959358 3731079159 1498901684 -conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 2027599472 2758666204 3287095476 4291916486 -conv2d wgrad_1x8x8x1_8x8_1x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1767700736 4278264698 2331753571 2554564568 -conv2d dgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 927718585 3117803557 1370701307 1462167731 -conv2d dgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 973422497 1926250028 3440543762 -conv2d dgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 2892862516 3649300762 1521470286 -conv2d dgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2075083065 3181416651 1733426984 872275640 -conv2d dgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4005590448 1639170045 388151578 4186957447 -conv2d dgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 181075276 1433744686 860506550 3475157408 -conv2d dgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1747719409 877465841 2345541783 -conv2d dgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 856324887 2307248012 337386755 3363072703 -conv2d dgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1906605830 722034901 2562804622 2508759317 -conv2d dgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 805717279 2196645331 3235235362 1518334120 -conv2d dgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3168796339 72559978 778918419 1260968000 -conv2d dgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 261954979 2634885882 451986822 3792829599 -conv2d dgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_fnhwc_f_f 3747142491 2426759809 2622222681 371723930 diff --git a/tools/library/scripts/pycutlass/test/example/run_all_example.sh b/tools/library/scripts/pycutlass/test/example/run_all_example.sh deleted file mode 100755 index 0a51ccf6..00000000 --- a/tools/library/scripts/pycutlass/test/example/run_all_example.sh +++ /dev/null @@ -1,112 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -pushd $CUTLASS_PATH/examples/40_cutlass_py/customizable - -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 - -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 - -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 - -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 - -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4 - -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 - -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 - -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3 - -python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5 - -python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1 - -python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device - -python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host - -python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device - -python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device - -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 - -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0 - -python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0 - -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0 - -python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 - -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 - -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 - -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 - -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 - -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 - -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu - -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2 - -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh - -python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid - -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu - -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish - -python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu - -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 - -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2 - -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 - -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 - -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 - -python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 - -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 - -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3 -popd diff --git a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py deleted file mode 100644 index 8eaf42f6..00000000 --- a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +++ /dev/null @@ -1,154 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Test cases for frontends -""" - -import pycutlass -import unittest -from pycutlass import * -from pycutlass.utils.device import device_cc - - -class Test_Frontend(unittest.TestCase): - def setUp(self) -> None: - # - # define the cutlass operator - # - cc = device_cc() - math_inst = MathInstruction( - [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32, - cutlass.OpClass.Simt, MathOperation.multiply_add - ) - - stages = 2 - tile_description = TileDescription( - [128, 128, 8], stages, [2, 4, 1], - math_inst - ) - - A = TensorDescription( - cutlass.float32, cutlass.RowMajor, 1 - ) - - B = TensorDescription( - cutlass.float32, cutlass.RowMajor, 1 - ) - - C = TensorDescription( - cutlass.float32, cutlass.RowMajor, 1 - ) - - epilogue_functor = LinearCombination( - C.element, C.alignment, - math_inst.element_accumulator, cutlass.float32) - - self.operation = GemmOperationUniversal( - arch=cc, tile_description=tile_description, - A=A, B=B, C=C, - epilogue_functor=epilogue_functor, - swizzling_functor=cutlass.IdentitySwizzle1 - ) - - pycutlass.compiler.add_module([self.operation,]) - - - def test_torch_frontend(self): - try: - import torch - except: - self.assertTrue(False, "Unable to import torch") - - problem_size = cutlass.gemm.GemmCoord(512, 256, 128) - - tensor_A = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.k()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) - tensor_B = torch.ceil(torch.empty(size=(problem_size.k(), problem_size.n()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) - tensor_C = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.n()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) - tensor_D = torch.empty_like(tensor_C) - - - alpha = 1.0 - beta = 0.0 - - arguments = GemmArguments( - operation=self.operation, problem_size=problem_size, - A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=self.operation.epilogue_type(alpha, beta), - gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 - ) - - self.operation.run(arguments) - - arguments.sync() - - tensor_D_ref = alpha * tensor_A @ tensor_B + beta * tensor_C - - self.assertTrue(torch.equal(tensor_D, tensor_D_ref)) - - def test_cupy_frontend(self): - try: - import cupy as cp - except: - self.assertTrue(False, "Unable to import cupy") - - cp.cuda.set_allocator(rmm.rmm_cupy_allocator) - - problem_size = cutlass.gemm.GemmCoord(512, 256, 128) - - tensor_A = cp.ceil(cp.random.uniform(low=-8.5, high=7.5, size=(problem_size.m(), problem_size.k()), dtype=cp.float32)) - tensor_B = cp.ceil(cp.random.uniform(low=-8.5, high=7.5, size=(problem_size.k(), problem_size.n()), dtype=cp.float32)) - tensor_C = cp.ceil(cp.random.uniform(low=-8.5, high=7.5, size=(problem_size.m(), problem_size.n()), dtype=cp.float32)) - tensor_D = cp.ones_like(tensor_C) - - alpha = 1.0 - beta = 1.0 - - tensor_D_ref = alpha * tensor_A @ tensor_B + beta * tensor_C - - arguments = GemmArguments( - operation=self.operation, problem_size=problem_size, - A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=self.operation.epilogue_type(alpha, beta), - gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 - ) - - self.operation.run(arguments) - - arguments.sync() - - self.assertTrue(cp.array_equal(tensor_D, tensor_D_ref)) - - -if __name__ == '__main__': - pycutlass.get_memory_pool(2**32, 2**32) - unittest.main() diff --git a/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt b/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt deleted file mode 100644 index c026860a..00000000 --- a/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt +++ /dev/null @@ -1,363 +0,0 @@ -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1767700736 2104699940 3506659864 557648934 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1539314507 3971227455 1976927351 1642148785 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 276489656 653235219 3147305346 880610205 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 272457724 2178229139 2786201726 4170295839 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 242235041 2149454506 784935854 682531065 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 3478189705 1667216236 1437761176 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 379326961 1780379994 3740415776 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 924848818 3533854396 2683779476 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2870331951 359232443 2147867990 1653277018 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2870331951 3784314846 2644315999 4224154526 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3787448414 3562991793 535073859 2563373454 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 426169840 2464808416 864648234 461884698 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2564934525 3910792915 3577331017 827498183 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 28479234 867695528 1947311971 83328334 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4192922822 4244595864 2296602326 2349214706 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 274678245 3464152269 1682550229 3446204619 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3993280136 828543035 1319748516 956044554 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 832003025 3799813757 4030292245 457791957 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1444316594 4129865888 93616503 412257611 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2931873718 1841508064 1497852219 36703874 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2931873718 1841508064 1497852219 1842147148 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1612565294 109894479 1782187316 3370789453 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 841569299 1010785577 1158956167 3261208135 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1893352157 48149942 3544807462 446577726 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 3585320147 2150950452 1625817025 3964129474 -conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 1227883689 3016005301 4142905842 -conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3337296764 4183699161 3654176452 -conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 3852963969 864006170 920352568 -conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2624318208 2750240096 2120184232 2600672872 -conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 3224082300 2084034673 3588056946 -conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3271403719 3033073939 304048758 1882633089 -conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1419588777 610026473 447427404 2639856195 -conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 763924990 2818680871 58428273 3332443900 -conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2578426561 1891702153 103393067 2558647731 -conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 162127134 3567670201 3173514764 -conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 162127134 3567670201 363897018 -conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2462511240 1350938697 1696306119 1005311005 -conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3884703009 3552725366 1975514757 1210310496 -conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2056905385 447674669 724481645 1457430910 -conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 1227883689 3401425854 3897766524 -conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 3749787834 3350064812 1136116240 -conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3414629540 820341033 770836461 2451581199 -conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4100326666 2581696511 1088458082 1521190911 -conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3662895757 2885454895 935600441 2615245898 -conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 3831334389 3506139121 814982501 -conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2154102133 737968461 1291834254 2665225480 -conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3573498719 1809195644 1765637461 -conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3573498719 3379808294 483095299 -conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 4194153035 2863868771 1639389008 -conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2624318208 157618421 1779474147 814087242 -conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 2300180628 423968553 3890279569 -conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2609259399 1848932917 522753581 1926508271 -conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2948772873 3663040534 4014266327 1288646188 -conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3271403719 1585195072 1487505772 3253374264 -conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1419588777 451194147 3578359696 3659768981 -conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 763924990 2780826684 2883769406 148530958 -conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2578426561 3849874822 102765469 1305171059 -conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 1995451256 2632815435 1516344656 -conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 1995451256 2632815435 1586331550 -conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2462511240 2274021368 1188866747 3178890497 -conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 752289976 1226457131 4187777346 1400559240 -conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3723912751 1585959358 3731079159 1498901684 -conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 2027599472 2758666204 3287095476 4291916486 -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3393706648 3519979618 1149261202 799742106 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3409586999 409840186 1724648597 2642018980 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1815685330 1398622058 2431638856 1016967269 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2555706782 3271563943 1020153035 299097281 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 4173830187 736684125 472021975 2064613035 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3010335403 2751224679 2250540122 3725638844 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3010335403 1583610315 3287895411 2394340435 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3010335403 2356047354 7055632 915702611 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2748205217 2539405983 1217377670 2011175578 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2748205217 2114448427 249997769 2711364520 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1528321643 1532777511 3597171412 296622236 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1326617037 3415095747 847196866 1481554158 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1122706355 2841974626 2791878604 632900093 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1728385278 2462678309 3066040807 1334515660 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2175275779 1117731224 857614711 2096711962 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 4140401170 3710340185 1683575469 317397427 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3552249008 2918315307 2290683130 536859016 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2869959072 2516947012 3328285094 2393284712 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1349264322 1823945068 400087667 2893025864 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3321662203 426084311 4233055093 4078572279 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3321662203 426084311 4233055093 3044377475 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 803041205 2521863610 3206942690 127091020 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 4083508736 37801570 240515127 2234797539 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2207374588 535059558 2268619394 1489214085 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 dnhwc_dnhwc_dnhwc_d_d 3614026280 1721563676 2979825951 1104908081 -conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 2226238626 2053372396 2462697514 -conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 235646718 1374133172 3696289981 -conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2536722089 184705847 3148323124 84213385 -conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2264868815 1724845245 3498302256 4094034457 -conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1621735632 233390337 1801952602 3532884734 -conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3048346885 2306163504 642074123 4083120683 -conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2798030672 683783039 3025345160 1890891136 -conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1731071506 1844675436 2292509333 4006304179 -conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 132147677 604503886 143348844 3037223953 -conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1269799445 1678940393 3405733837 1820114523 -conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1269799445 1678940393 3405733837 467254076 -conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1794301352 2320042028 2134048179 508141072 -conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 561590023 3382154048 4154621995 517057927 -conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 593915463 2360210889 2685491481 2265099675 -conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 2226238626 1155815529 558646991 -conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2536722089 1876429398 4216128545 1754596046 -conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 348523586 2609019785 3938405680 2601133907 -conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1984146316 1475870285 1157657800 1143965395 -conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2971058593 1478256319 503014742 3930504182 -conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1621735632 1214508920 1537003531 3830217225 -conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2031518387 2695641559 933408074 4026827730 -conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 517276344 1158854831 3123629043 -conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 517276344 1448394173 1864626308 -conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2536722089 711164468 2465036841 2993377049 -conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2264868815 3003481795 333430991 3094857755 -conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1621735632 1126010692 3313703859 637497110 -conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1130094757 2605103293 2477101661 1276123281 -conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 4286533436 1302900889 2613245986 2523724148 -conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3048346885 923365529 1681226722 417509256 -conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2798030672 3441819646 1293178065 188472807 -conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1731071506 1117530547 2706270359 502156742 -conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 132147677 2029225588 3851064913 3164530726 -conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1269799445 2337137106 3312954197 2466682688 -conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1269799445 2337137106 3312954197 2684544683 -conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1794301352 72938921 2354994612 1463501392 -conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 252570564 2903451081 3619280116 1448586411 -conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2037991187 1665743881 241585763 103256264 -conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 dnhwc_dnhwc_dnhwc_d_d 2653975581 3337638999 1440125233 2448165745 -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 991402150 1393431534 1148212814 1350914659 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 4283492776 419570292 1210341563 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4178596783 3828059710 2735749436 2671012171 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 924522595 563724475 3750778972 4152580670 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1021044158 1686067905 3765040166 4102272733 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 2674994719 635224486 2759329777 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 4201252830 2920298728 304256151 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 70289262 646435722 4137562540 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1317457392 1288095320 2132879813 656196754 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1317457392 2202157489 2326567490 2475188414 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2476454437 1857118302 4164386062 239840568 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2767650699 3514840131 590439733 3879821123 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3896287283 3112762669 2515107934 2106635937 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1903067870 1021832870 3003938078 2751931686 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3489785028 2466126497 1374078692 2737628040 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2051350923 263676708 3639860119 1370886256 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 719099834 1474713672 204857540 2768940347 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3441724486 3162593831 421721594 3097845598 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2034354027 1249407570 2567025479 1441082595 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 941893937 3608468045 635631428 2369653089 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 941893937 3608468045 635631428 1218705038 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 172579142 319546523 718795680 1453661415 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2823351660 1326352711 1110204809 1155441703 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3238446487 2572503545 686287700 1559476701 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_f_f 2149247508 1775375365 3317647029 2497607448 -conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 3464637181 1623218578 436154205 -conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 1479940693 3253144559 3883419107 -conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 1871463331 2425320272 74566211 -conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3484040069 664160900 3610888033 22347127 -conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1924855848 1382111427 2541177413 -conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 868180534 1764715518 3070473696 2392864704 -conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3437976747 666906244 3401957738 2050602745 -conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4195072693 1575210381 781892324 2848949054 -conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3457330201 2316839359 1539389419 4293781748 -conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 2469024119 2885305868 2693098375 -conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 2469024119 2885305868 1969608051 -conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1690216859 554790212 2885143346 780489333 -conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3184127693 835105643 3337423971 3866137775 -conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 927718585 4106152802 720400339 3989318043 -conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 3464637181 4051957661 126285749 -conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 3723472741 2044236350 2463899842 -conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2075083065 2042513140 3691286135 322550345 -conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4005590448 1116254439 2328237343 1918824440 -conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 181075276 1743485155 3526891198 1979405632 -conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 386662952 4057300775 1456746562 -conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 856324887 3954249564 2340393915 4127188930 -conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1300426008 2921497047 4145791960 -conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1300426008 4080981223 3076991942 -conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 447261065 3823545045 392205236 -conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3484040069 2966693627 3900095420 919511892 -conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 1759979610 4272621682 1029257940 -conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1906605830 2980501720 978889789 3136018973 -conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 805717279 3502822733 1810065278 1387739380 -conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 868180534 3289288595 209477462 4142168174 -conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3437976747 3391080565 97275649 4063718293 -conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4195072693 1669352457 2182133559 2494741804 -conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3457330201 1126870455 319272291 3811977088 -conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 1723074453 1660326213 3902884425 -conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 1723074453 1660326213 423159249 -conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1690216859 2413490039 223529410 3303697952 -conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3168796339 1601750164 1428743330 403295189 -conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 261954979 1300976652 2749562370 3058142403 -conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_f_f 3747142491 1747587481 3143977827 835130482 -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1736512560 49406874 846358010 3314905564 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1848484956 1432417472 1903569827 3750799351 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4236427320 3696009469 69852620 201921851 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 109006944 450017448 1793784844 903209915 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 813367872 2397796503 1928191746 3210229460 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 1307184141 46021356 1674017987 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 1212511562 3331767121 2446286369 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 2013675943 1681111033 1469213228 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1703349794 500298386 3218034344 4159283207 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1703349794 1123534155 145385311 4273847179 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3862659311 349459322 1503631520 1404971956 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1623686755 961217371 552550209 3980749384 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3554927580 1131648083 4149599295 3119557776 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1767639287 3350675774 128324027 1059816532 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3986143536 17411088 40173029 1694092310 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1157793540 3513299281 48848814 1435528367 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 988962069 4292634763 388976034 2674929544 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4202383208 3529769234 1046186503 3368902675 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 856448884 3057259762 2063087558 1995545427 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2281940872 144496548 2455451862 400986166 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2281940872 144496548 2455451862 1082696406 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2702905851 1992889713 731289041 608504198 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2742293143 4197915274 606840 3671124731 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 149434841 2288560511 2994968424 2881838300 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 2226824643 327135318 3718671210 2121176659 -conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 1027662440 4172720592 446082987 -conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 1101653138 3727072529 875733988 -conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 3906526127 655926291 939844058 -conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2784049299 2031878085 1709408312 1277173429 -conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 22652410 1700696921 2175632852 -conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1530672622 436588210 470857851 284463232 -conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1500864134 59350507 969037229 1510558485 -conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3344871528 856797938 2030818524 4231831552 -conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 966721255 2885833872 2829967135 3441569557 -conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 4148824382 2827420298 378131261 -conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 4148824382 2827420298 2955292920 -conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4028893260 1474248671 1302526250 4182204885 -conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1569788048 162506176 819639712 763595635 -conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 289918791 1266976707 942688231 3457364823 -conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 1027662440 2005082293 2235558527 -conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 3380032042 1370040310 1348846927 -conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 671982235 1423304149 2107662762 1234913781 -conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 798317794 1709026638 2421185623 3308071321 -conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1721270411 2519327328 2541413264 3185574975 -conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 2070174510 1364436192 3531942595 -conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2128738105 2056902987 3079166829 2329433528 -conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3857917762 3227877956 645422556 -conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3857917762 3817218800 985231315 -conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 1398036015 3630062764 2492522537 -conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2784049299 643733019 3649549642 2637869234 -conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 2332160299 302086821 3303132343 -conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1931093565 2458714707 2919710256 2311575036 -conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2472246681 2260022344 500095455 2760458995 -conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1530672622 3635363851 2402907878 4131497953 -conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1500864134 2536338700 2459524764 2504484273 -conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3344871528 2667385029 2714805835 3487838445 -conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 966721255 1547169349 3198573835 302049294 -conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 2440004820 1576818970 1317923157 -conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 2440004820 1576818970 3186679687 -conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4028893260 4220759192 2236533218 3731336532 -conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2956871200 1591352238 1756650151 1262787222 -conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 365467186 892422645 1334708242 1372556938 -conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 3347784734 150035460 2897171548 3701081496 -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 945660191 3750377696 2496492611 3515056508 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2806300501 2591577756 3148637036 3845512743 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2322444122 3525997046 281106520 3456307300 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 327345109 1137297282 1938163814 2551101563 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 797067973 481331945 350851834 2477733239 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1316460560 2044204046 1034822169 3340281844 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1316460560 4174274001 1597212204 1881272946 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1316460560 1535088984 3001492060 2308505016 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3190527989 3733991924 4211138051 3710311115 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3190527989 3430768821 1043108884 4185640072 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 943531303 1948306075 3877008798 2803592376 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3262141476 4125717435 2946529611 2221512094 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1599291337 3982786366 1581171257 1188352423 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2237070215 3046262465 1926804094 1435916873 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 721666814 2012769306 1712378956 1388990183 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1596349869 3775131163 355203300 1126174452 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1380587417 1208642645 2886387159 3113955983 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1332573203 1417735573 1422796372 3309229181 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2714027800 2106992819 1196036582 2095126659 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1105097447 1992731268 2198911423 3378137735 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1105097447 1992731268 2198911423 3868431311 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2552471160 2218470296 2332616929 923645661 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2231354584 4035702005 3839068434 8981294 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 4019719318 3985307916 3604065639 277096636 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 bf16nhwc_bf16nhwc_fnhwc_f_f 258381429 3482776077 2663631601 593179089 -conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 188810648 1623218578 2585892217 -conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 691990354 3253144559 2988350639 -conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2788041828 1670375523 2425320272 2553108650 -conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1049321188 1865889553 3610888033 1459693945 -conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3820648800 3236781482 1382111427 1986396315 -conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 463742721 2524037630 3070473696 210045128 -conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 738614177 4071452982 3401957738 2920893800 -conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2479111539 2662555669 781892324 2338234282 -conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2089076160 260434096 1539389419 1219120658 -conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 14838294 3344412669 2885305868 1926445693 -conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 14838294 3344412669 2885305868 1478058549 -conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3945616248 4118489020 2885143346 1545684873 -conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 295760528 1685244361 3337423971 772814550 -conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 623727338 942771643 2634710231 3063349371 -conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 188810648 2709881923 3532383400 -conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2788041828 3762161398 3733128758 3693097785 -conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 139944998 3812563855 253288229 1359907535 -conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 492562992 3677108443 525487530 445191233 -conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 594197095 3773864559 91136873 4170763393 -conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3820648800 1025574686 1127709182 677727764 -conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1901075489 3296829308 2591894666 2932517926 -conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 4223561525 1263618595 50680160 -conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 4223561525 1756414462 3209752057 -conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2788041828 1023542180 121940906 624551470 -conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1049321188 296097075 1423016429 1058165639 -conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3820648800 4160685370 2761559427 1788182893 -conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1859384988 222880684 1650970502 1632078530 -conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1704522433 2403392926 3985958544 1432584676 -conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 463742721 3455033786 385631111 1683348880 -conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 738614177 3199562330 1513955316 2131256035 -conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2479111539 2702777753 2608107448 4014212857 -conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2089076160 4042009058 106232038 1140762595 -conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 14838294 2260768172 1186911503 3194129408 -conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 14838294 2260768172 1186911503 1312312812 -conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3945616248 2287161276 36034283 4262860382 -conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2906914535 476297538 14375779 1340176713 -conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 4292101959 3378414564 4259930640 1392755176 -conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 bf16nhwc_bf16nhwc_fnhwc_f_f 3529371817 368260304 4137156526 122558013 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 2948718568 2631391783 3260825675 4278587299 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 1635109696 2835574424 4179385325 2803281440 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 3344954627 1649157278 2032056735 1176638626 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 61750237 3452849177 1697665310 3475459781 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 1394759191 1571308277 898534533 4125341936 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 3402206912 2433594404 1575577431 4106154211 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 98638790 2735493952 346473870 1911666301 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 98638790 2735493952 346473870 2124440208 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 2934485636 3286257323 541566528 1113783492 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 164942943 4259285988 1250700182 508419908 -conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3805460372 2607401558 3465030781 210641751 -conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 4200926784 1001915027 387475271 3360115596 -conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 331078659 469730619 2547196469 1620698703 -conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 431968022 1614654085 903827412 1349891842 -conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3674369485 1055554271 3217013807 1356703347 -conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2609462247 3227824772 365527403 2720889763 -conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2609462247 2150996976 2899308770 2371758816 -conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2609462247 2124373651 2711906981 3194739760 -conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 1070162100 2750964634 3090791018 3481982191 -conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 1070162100 1563941622 767747438 3163252390 -conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 884815233 3576251756 3216742798 3534462723 -conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3230717758 3192193994 1161445944 371179683 -conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2450454245 2905280248 910194866 839083662 -conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2948718568 2631391783 638794727 4292051282 -conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 1635109696 2835574424 1855687620 130932480 -conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3344954627 1649157278 4191418350 958044197 -conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 61750237 3452849177 3260472389 771128506 -conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 1394759191 1571308277 4279538191 956191103 -conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3402206912 2433594404 2021112123 2983097553 -conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 98638790 2735493952 3178839372 568554158 -conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 98638790 2735493952 3178839372 18194802 -conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2934485636 3286257323 2559221535 2310182528 -conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 164942943 4259285988 984016853 888753301 -conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2823094147 1681845497 4242738907 3244428635 -conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 s8nhwc_s8nhwc_inhwc_i_i 4060010502 2881035321 3927119619 3311661122 -conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 3464637181 1030377090 3211227145 -conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1479940693 2379046159 2482639965 -conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 1871463331 2718290800 1797658305 -conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3484040069 664160900 3954982568 985899371 -conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 1924855848 1728786974 3821277575 -conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 868180534 1764715518 3998637379 2782670608 -conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3437976747 666906244 2107859856 831363691 -conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4195072693 1575210381 2486552517 3268706408 -conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3457330201 2316839359 1729888024 2308314800 -conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 544154978 -conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 2469024119 464378888 3191247524 -conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1690216859 554790212 956712535 1281779197 -conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3184127693 835105643 4011933753 3207244654 diff --git a/tools/library/scripts/pycutlass/test/unit/test_sm80.py b/tools/library/scripts/pycutlass/test/unit/test_sm80.py deleted file mode 100644 index cf7eb53e..00000000 --- a/tools/library/scripts/pycutlass/test/unit/test_sm80.py +++ /dev/null @@ -1,464 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -## Test case generator for SM80 - -import pycutlass -from pycutlass import * -from pycutlass.test import * -from pycutlass.utils.device import device_cc -import unittest -import xmlrunner -import argparse - -# -# Create GEMM operation -# -@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") -def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False, - epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): - """ - Test GEMM Operation based on configuration - """ - - if "data_type" in kwargs.keys(): - data_type = kwargs["data_type"] - else: - if mixed or math_inst.element_a == cutlass.bfloat16: - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator - ] - else: - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator - ] - - tile_description = TileDescription( - tiling[0], tiling[1], tiling[2], - math_inst - ) - - A = TensorDescription( - data_type[0], layout[0], alignment[0] - ) - - B = TensorDescription( - data_type[1], layout[1], alignment[1] - ) - - C = TensorDescription( - data_type[2], layout[2], alignment[2] - ) - - element_epilogue = data_type[3] - if epilogue_functor is None: - epilogue_functor = LinearCombination( - C.element, C.alignment, - math_inst.element_accumulator, element_epilogue) - - if gemm_kind == GemmKind.Universal: - operation = GemmOperationUniversal( - arch=arch, tile_description=tile_description, - A=A, B=B, C=C, - epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor - ) - if A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]: - return test_all_gemm(operation, "interleaved") - else: - return test_all_gemm(operation, "universal") - - elif gemm_kind == GemmKind.Grouped: - operation = GemmOperationGrouped( - arch, tile_description, A, B, C, - epilogue_functor, swizzling_functor, - precompute_mode=kwargs["precompute_mode"] - ) - testbed = TestbedGrouped(operation=operation) - return testbed.run(24) - else: - raise NotImplementedError("the gemm kind is not implemented") - - -def TestConv2dOperator(math_inst, alignment, tiling, arch, - stride_supports=[StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided], - epilogue_functor=None, - swizzling_functor=cutlass.IdentitySwizzle1, interleaved=False, **kwargs): - """ - Test Conv2d Operation based on configurations - """ - - mixeds = [False, True, False] - conv_kinds = [cutlass.conv.Operator.fprop, cutlass.conv.Operator.dgrad, cutlass.conv.Operator.wgrad] - - results = [] - - default_swizzling_functor = swizzling_functor - - if "layout" in kwargs.keys(): - layout = kwargs["layout"] - else: - layout = (cutlass.TensorNHWC, cutlass.TensorNHWC, cutlass.TensorNHWC) - - for mixed, conv_kind, stride_support in zip(mixeds, conv_kinds, stride_supports): - - if "data_type" in kwargs.keys(): - data_type = kwargs["data_type"] - else: - if mixed or math_inst.element_a == cutlass.bfloat16: - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator - ] - else: - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator - ] - # skip Int8 Conv Backward - if data_type[0] == cutlass.int8 and conv_kind in [cutlass.conv.Operator.dgrad, cutlass.conv.Operator.wgrad]: - continue - - A = TensorDescription( - element=data_type[0], - layout=layout[0], - alignment=alignment[0]) - B = TensorDescription( - element=data_type[1], - layout=layout[1], - alignment=alignment[1]) - C = TensorDescription( - element=data_type[2], - layout=layout[2], - alignment=alignment[2]) - - tile_description = TileDescription( - threadblock_shape=tiling[0], stages=tiling[1], - warp_count=tiling[2], - math_instruction=math_inst - ) - - if conv_kind == cutlass.conv.Operator.dgrad and stride_support == StrideSupport.Strided: - swizzling_functor = cutlass.StridedDgradIdentitySwizzle1 - else: - swizzling_functor = default_swizzling_functor - - if epilogue_functor is None: - epilogue_functor_ = LinearCombination( - C.element, C.alignment, - math_inst.element_accumulator, data_type[3]) - - operation = Conv2dOperation( - conv_kind=conv_kind, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, - arch=arch, tile_description=tile_description, A=A, B=B, C=C, - stride_support=stride_support, - epilogue_functor=epilogue_functor_, - swizzling_functor=swizzling_functor - ) - - results.append(test_all_conv2d(operation, interleaved=interleaved)) - - return results - - - -class Test_SM80(unittest.TestCase): - def test_SM80_TensorOp_16816(self): - math_instructions = [ - MathInstruction( - [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float32, - cutlass.OpClass.TensorOp, MathOperation.multiply_add - ), - MathInstruction( - [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float16, - cutlass.OpClass.TensorOp, MathOperation.multiply_add - ), - MathInstruction( - [16, 8, 16], cutlass.bfloat16, cutlass.bfloat16, cutlass.float32, - cutlass.OpClass.TensorOp, MathOperation.multiply_add - ) - ] - - layouts = [ - (cutlass.RowMajor, cutlass.RowMajor, cutlass.RowMajor), - (cutlass.ColumnMajor, cutlass.RowMajor, cutlass.RowMajor), - (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajor) - ] - - alignments = [ - (8, 8, 8), (4, 8, 8), (8, 4, 8) - ] - - tilings = [ - ([256, 128, 32], 3, [4, 2, 1]), - ([64, 256, 32], 4, [1, 4, 1]), - ([128, 64, 64], 3, [2, 2, 1]) - ] - - for math_inst, layout, alignment, tiling in zip(math_instructions, layouts, alignments, tilings): - self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False)) - self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Host)) - stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] - results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports) - for res in results: - self.assertTrue(res) - - def test_SM80_TensorOp_1688(self): - # tf32 is not supported by most of python environment. Skip the test - self.assertTrue(True) - - def test_SM80_TensorOp_1688_fast_math(self): - math_instructions = [ - MathInstruction( - [16, 8, 8], cutlass.tfloat32, cutlass.tfloat32, cutlass.float32, - cutlass.OpClass.TensorOp, MathOperation.multiply_add - ), - MathInstruction( - [16, 8, 8], cutlass.float16, cutlass.float16, cutlass.float32, - cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_f16 - ), - MathInstruction( - [16, 8, 8], cutlass.bfloat16, cutlass.bfloat16, cutlass.float32, - cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_bf16 - ), - MathInstruction( - [16, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, - cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_f32 - ) - ] - - layouts = [ - (cutlass.RowMajor, cutlass.RowMajor, cutlass.ColumnMajor), - (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.ColumnMajor), - (cutlass.ColumnMajor, cutlass.RowMajor, cutlass.ColumnMajor), - (cutlass.ColumnMajor, cutlass.ColumnMajor, cutlass.RowMajor) - ] - alignments = [ - (4, 4, 4), (4, 2, 4), (2, 4, 4), (2, 2, 4) - ] - tilings = [ - ([128, 256, 16], 3, [4, 2, 1]), - ([64, 256, 16], 4, [1, 4, 1]), - ([128, 64, 32], 3, [2, 2, 1]), - ([256, 64, 32], 3, [4, 2, 1]) - ] - data_type = [ - cutlass.float32, cutlass.float32, cutlass.float32, cutlass.float32 - ] - for math_inst, layout, alignment, tiling in zip(math_instructions, layouts, alignments, tilings): - self.assertTrue( - TestGemmOperator( - GemmKind.Universal, math_inst, layout, - alignment, tiling, 80, False, data_type=data_type)) - self.assertTrue( - TestGemmOperator( - GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, - True, precompute_mode=SchedulerMode.Device, data_type=data_type)) - stride_supports = [StrideSupport.Unity, StrideSupport.Strided, StrideSupport.Unity] - results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) - for res in results: - self.assertTrue(res) - - def test_SM80_TensorOp_884(self): - math_inst = MathInstruction( - [8, 8, 4], cutlass.float64, cutlass.float64, cutlass.float64, - cutlass.OpClass.TensorOp, MathOperation.multiply_add - ) - layout = (cutlass.ColumnMajor, cutlass.ColumnMajor, cutlass.ColumnMajor) - alignment = (1, 1, 1) - - tiling = ([64, 256, 16], 3, [2, 4, 1]) - data_type = [cutlass.float64, cutlass.float64, cutlass.float64, cutlass.float64] - self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type)) - self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type)) - stride_supports = [StrideSupport.Unity, StrideSupport.Strided, StrideSupport.Unity] - results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) - for res in results: - self.assertTrue(res) - - def test_SM80_TensorOp_16832_TN(self): - math_inst = MathInstruction( - [16, 8, 32], cutlass.int8, cutlass.int8, cutlass.int32, - cutlass.OpClass.TensorOp, MathOperation.multiply_add_saturate - ) - layout = (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.ColumnMajor) - alignment = (16, 16, 4) - alignment_mixed = (16, 16, 16) - tiling = ([128, 256, 64], 3, [2, 4, 1]) - - data_type = [cutlass.int8, cutlass.int8, cutlass.int32, cutlass.int32] - data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32] - - self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type)) - self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment_mixed, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type_mixed)) - stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] - results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) - for res in results: - self.assertTrue(res) - - def test_SM80_Simt_f32(self): - math_inst = MathInstruction( - [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32, - cutlass.OpClass.Simt, MathOperation.multiply_add - ) - layout = (cutlass.RowMajor, cutlass.RowMajor, cutlass.RowMajor) - alignment = (1, 1, 1) - - tiling = ([128, 256, 8], 4, [2, 4, 1]) - data_type = [cutlass.float32, cutlass.float32, cutlass.float32, cutlass.float32] - self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type)) - self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Host, data_type=data_type)) - stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] - results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) - for res in results: - self.assertTrue(res) - - def test_SM80_Simt_f64(self): - math_inst = MathInstruction( - [1, 1, 1], cutlass.float64, cutlass.float64, cutlass.float64, - cutlass.OpClass.Simt, MathOperation.multiply_add - ) - layout = (cutlass.RowMajor, cutlass.RowMajor, cutlass.ColumnMajor) - alignment = (1, 1, 1) - - tiling = ([64, 128, 8], 5, [2, 2, 1]) - data_type = [cutlass.float64, cutlass.float64, cutlass.float64, cutlass.float64] - self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type)) - self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type)) - stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] - results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) - for res in results: - self.assertTrue(res) - - def test_SM80_TensorOp_16832_Interleaved(self): - math_inst = MathInstruction( - [16, 8, 32], cutlass.int8, cutlass.int8, cutlass.int32, - cutlass.OpClass.TensorOp, MathOperation.multiply_add_saturate - ) - - layout = (cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32) - alignment_mixed = (16, 16, 8) - tiling = ([256, 64, 64], 4, [4, 1, 1]) - data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32] - - epilogue_functor = FastLinearCombinationClamp( - data_type_mixed[2], alignment_mixed[2] - ) - - self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=epilogue_functor)) - stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] - layout = [cutlass.TensorNC32HW32, cutlass.TensorC32RSK32, cutlass.TensorNC32HW32] - results = TestConv2dOperator(math_inst, alignment_mixed, tiling, 80, stride_supports=stride_supports, data_type=data_type_mixed, layout=layout, interleaved=True) - for res in results: - self.assertTrue(res) - - def SM80_SparseTensorOp_16832(self): - pass - def SM80_PlanarComplexTensorOp_16816(self): - pass - def SM80_SparseTensorOp_16816_fast_math(self): - pass - def SM80_TensorOp_1688_complex(self): - pass - def SM80_TensorOp_1688_fast_fp32_math_complex(self): - pass - def SM80_TensorOp_1688_rank_k(self): - pass - def SM80_TensorOp_1688_rank_k_complex(self): - pass - def SM80_TensorOp_1688_trmm(self): - pass - def SM80_TensorOp_1688_trmm_complex(self): - pass - def SM80_TensorOp_1688_symm(self): - pass - def SM80_TensorOp_1688_symm_complex(self): - pass - def SM80_TensorOp_884_complex(self): - pass - def SM80_TensorOp_884_complex_gaussian(self): - pass - def SM80_TensorOp_884_rank_k(self): - pass - def SM80_TensorOp_884_rank_k_complex(self): - pass - def SM80_TensorOp_884_rank_k_complex_gaussian(self): - pass - def SM80_TensorOp_884_trmm(self): - pass - def SM80_TensorOp_884_trmm_complex(self): - pass - def SM80_TensorOp_884_trmm_complex_gaussian(self): - pass - def SM80_TensorOp_884_symm(self): - pass - def SM80_TensorOp_884_symm_complex(self): - pass - def SM80_TensorOp_884_symm_complex_gaussian(self): - pass - def SM80_SparseTensorOp_16864_TN(self): - pass - def SM80_TensorOp_16864_TN(self): - pass - def SM80_SparseTensorOp_168128_TN(self): - pass - def SM80_TensorOp_16864_Interleaved(self): - pass - def SM80_TensorOp_168256(self): - pass - def SM80_Simt_complex(self): - pass - - -def argumentParser(): - parser = argparse.ArgumentParser(description="Entrypoint for PyCutlass testing on Ampere architecture.") - parser.add_argument("-j", "--junit_path", help="The absolute path to the directory for generating a junit xml report", default="") - return parser.parse_args() - - -if __name__ == '__main__': - pycutlass.get_memory_pool(2**20, 2**34) - pycutlass.compiler.nvcc() - args = argumentParser() - if args.junit_path: - unittest.main(argv=[''], testRunner=xmlrunner.XMLTestRunner(output=args.junit_path)) - else: - unittest.main(argv=['']) diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index ab5704bd..62f07220 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -64,6 +64,8 @@ class GemmOperationBase : public Operation { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; // assuming all tensors use same type for StrideIndex using StrideIndex = typename Operator::LayoutA::Index; using ElementAccumulator = typename Operator::ElementAccumulator; @@ -121,6 +123,7 @@ class GemmOperationBase : public Operation { description_.A = make_TensorDescription(Operator::kAlignmentA); description_.B = make_TensorDescription(Operator::kAlignmentB); description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentC); description_.element_epilogue = NumericTypeMap::kId; description_.split_k_mode = SplitKMode::kNone; @@ -147,6 +150,8 @@ class GemmOperation : public GemmOperationBase { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; using OperatorArguments = typename Operator::Arguments; @@ -204,7 +209,7 @@ class GemmOperation : public GemmOperationBase { operator_args.ref_A.reset(static_cast(arguments->A)); operator_args.ref_B.reset(static_cast(arguments->B)); operator_args.ref_C.reset(static_cast(arguments->C)); - operator_args.ref_D.reset(static_cast(arguments->D)); + operator_args.ref_D.reset(static_cast(arguments->D)); return Status::kSuccess; } @@ -345,6 +350,8 @@ class GemmSparseOperation : public GemmOperationBase { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; using ElementE = typename Operator::ElementE; using LayoutE = typename Operator::LayoutE; using ElementAccumulator = typename Operator::ElementAccumulator; @@ -405,7 +412,7 @@ class GemmSparseOperation : public GemmOperationBase { operator_args.ref_A.reset(static_cast(arguments->A)); operator_args.ref_B.reset(static_cast(arguments->B)); operator_args.ref_C.reset(static_cast(arguments->C)); - operator_args.ref_D.reset(static_cast(arguments->D)); + operator_args.ref_D.reset(static_cast(arguments->D)); operator_args.ref_E.reset(static_cast(arguments->E)); return Status::kSuccess; @@ -547,6 +554,8 @@ class GemmUniversalOperation : public GemmOperationBase { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; @@ -751,6 +760,8 @@ class GemmPlanarComplexOperation : public GemmOperationBase { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; @@ -958,6 +969,8 @@ class GemmPlanarComplexArrayOperation : public GemmOperationBase { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; @@ -1159,6 +1172,8 @@ class GemmGroupedOperation : public GemmOperationBase { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; @@ -1218,7 +1233,7 @@ class GemmGroupedOperation : public GemmOperationBase { op_args.ptr_A = static_cast(arguments->ptr_A); op_args.ptr_B = static_cast(arguments->ptr_B); op_args.ptr_C = static_cast(arguments->ptr_C); - op_args.ptr_D = static_cast(arguments->ptr_D); + op_args.ptr_D = static_cast(arguments->ptr_D); op_args.lda = arguments->lda; op_args.ldb = arguments->ldb; diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index 895de5be..eec57169 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -34,7 +34,6 @@ #pragma once #include "cutlass/cutlass.h" -#include "cutlass/kernel_hardware_info.hpp" #include "cutlass/library/library.h" #include "library_internal.h" @@ -56,6 +55,8 @@ class GemmOperation3xBase : public Operation { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; // assuming all tensors use same type for StrideIndex using StrideIndex = typename Operator::LayoutA::Index; using ElementAccumulator = typename Operator::ElementAccumulator; @@ -117,6 +118,7 @@ class GemmOperation3xBase : public Operation { description_.A = make_TensorDescription(Operator::kAlignmentA); description_.B = make_TensorDescription(Operator::kAlignmentB); description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); description_.element_epilogue = NumericTypeMap::kId; description_.split_k_mode = SplitKMode::kNone; @@ -144,6 +146,8 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; @@ -167,9 +171,6 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { // Do nothing here and construct kernel arguments in update_arguments_ instead // We also cannot construct TMA descriptors without all the arguments available - if (operator_args.hw_info.sm_count <= 0) { - operator_args.hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(); - } operator_args.mode = configuration->mode; return Status::kSuccess; } @@ -181,13 +182,13 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { typename ThreadEpilogueOp::Params params( *static_cast(arguments->alpha), *static_cast(arguments->beta)); - operator_args.epilogue_params.thread_params = params; + operator_args.epilogue.thread = params; } else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { typename ThreadEpilogueOp::Params params( static_cast(arguments->alpha), static_cast(arguments->beta)); - operator_args.epilogue_params.thread_params = params; + operator_args.epilogue.thread = params; } else { return Status::kErrorInvalidProblem; @@ -201,18 +202,21 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { arguments->batch_count); // update arguments - operator_args.ptr_A = static_cast(arguments->A); - operator_args.ptr_B = static_cast(arguments->B); - operator_args.epilogue_params.ptr_C = static_cast(arguments->C); - operator_args.epilogue_params.ptr_D = static_cast(arguments->D); + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); - operator_args.dA = cute::make_int_tuple_from( + operator_args.mainloop.dA = cute::make_int_tuple_from( arguments->lda, arguments->batch_stride_A); - operator_args.dB = cute::make_int_tuple_from( + operator_args.mainloop.dB = cute::make_int_tuple_from( arguments->ldb, arguments->batch_stride_B); - operator_args.epilogue_params.dC = cute::make_int_tuple_from( + operator_args.epilogue.dC = cute::make_int_tuple_from( arguments->ldc, arguments->batch_stride_C); - operator_args.epilogue_params.dD = operator_args.epilogue_params.dC; + operator_args.epilogue.dD = operator_args.epilogue.dC; + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; return Status::kSuccess; } @@ -223,6 +227,8 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { Status can_implement( void const *configuration_ptr, void const *arguments_ptr) const override { + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); GemmUniversalArguments const *arguments = static_cast(arguments_ptr); @@ -232,6 +238,13 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { return status; } + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + return Operator::can_implement(args); } diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index 90f61126..bdea2f49 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -379,7 +379,10 @@ Status Handle::gemm( element_B, layout_B, transform_B, - element_C + element_C, // C/D are same type and col major default + LayoutTypeID::kColumnMajor, + element_C, + LayoutTypeID::kColumnMajor ); auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); @@ -498,26 +501,26 @@ Status Handle::gemm_universal( NumericTypeID element_A, /// Data type of A matrix elements LayoutTypeID layout_A, /// Layout of A matrix ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices - void const * ptr_A, /// Pointer to A matrix in Global Memory - int64_t lda, /// Leading dimension of A matrix + int64_t lda, /// Leading dimension of A matrix NumericTypeID element_B, /// Data type of B matrix elements LayoutTypeID layout_B, /// Layout of B matrix ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices - void const * ptr_B, /// Pointer to B matrix in Global Memory - int64_t ldb, /// Leading dimension of B matrix + int64_t ldb, /// Leading dimension of B matrix void const * beta, /// Pointer to beta scalar - NumericTypeID element_C, /// Data type of C and D matrices - + NumericTypeID element_C, /// Data type of C matrix + LayoutTypeID layout_C, /// Layout of D matrix void const * ptr_C, /// Pointer to C matrix - int64_t ldc, /// Leading dimension of C matrix + int64_t ldc, /// Leading dimension of C matrix + NumericTypeID element_D, /// Data type of D matrix + LayoutTypeID layout_D, /// Layout of D matrix void * ptr_D, /// Pointer to D matrix - int64_t ldd, /// Leading dimension of D matrix + int64_t ldd, /// Leading dimension of D matrix int batch_count, /// Batch count or number of split-K slices @@ -542,7 +545,10 @@ Status Handle::gemm_universal( element_B, layout_B, transform_B, - element_C + element_C, + layout_C, + element_D, + layout_D ); auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); @@ -741,7 +747,10 @@ Status Handle::gemm_planar_complex( element_B, layout_B, transform_B, - element_C + element_C, // C/D are same type + LayoutTypeID::kColumnMajor, + element_C, + LayoutTypeID::kColumnMajor ); auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); @@ -935,7 +944,10 @@ Status Handle::gemm_planar_complex_array( element_B, layout_B, transform_B, - element_C + element_C, // C/D are same type + LayoutTypeID::kColumnMajor, + element_C, + LayoutTypeID::kColumnMajor ); auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); @@ -1121,7 +1133,7 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope static_cast(operation->description()); // if the curren gemm operation accumulator and output data type match return operation - if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.C.element) { + if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.D.element) { return operation; } @@ -1137,7 +1149,10 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope gemm_desc.B.element, gemm_desc.B.layout, gemm_desc.transform_B, - gemm_desc.tile_description.math_instruction.element_accumulator); + gemm_desc.tile_description.math_instruction.element_accumulator, // C/D are same type + LayoutTypeID::kColumnMajor, + gemm_desc.tile_description.math_instruction.element_accumulator, + LayoutTypeID::kColumnMajor); // gemm operation table auto gemm_operations = Singleton::get().operation_table.gemm_operations; diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index e9739e32..5423edda 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -96,6 +96,14 @@ template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kU8; }; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE4M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE5M2; +}; + template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kU16; }; diff --git a/tools/library/src/manifest.cpp b/tools/library/src/manifest.cpp index 09481438..1f3c456a 100644 --- a/tools/library/src/manifest.cpp +++ b/tools/library/src/manifest.cpp @@ -57,7 +57,7 @@ Status Manifest::initialize() { // initialize procedurally generated cutlass op in manifest object initialize_all(*this); - // initialize manually instanced conv3d reference op in manifest object + // initialize manually instanced reference op in manifest object initialize_reference_operations(*this); // initialize manually instanced reduction reference op in manifest object diff --git a/tools/library/src/operation_table.cu b/tools/library/src/operation_table.cu index d3799c3a..113e48d2 100644 --- a/tools/library/src/operation_table.cu +++ b/tools/library/src/operation_table.cu @@ -66,7 +66,10 @@ void OperationTable::append(Manifest const &manifest) { gemm_desc.B.element, gemm_desc.B.layout, gemm_desc.transform_B, - gemm_desc.C.element + gemm_desc.C.element, + gemm_desc.C.layout, + gemm_desc.D.element, + gemm_desc.D.layout ); Operation const *op = operation.get(); diff --git a/tools/library/src/reduction/init_reduction_operations.cu b/tools/library/src/reduction/init_reduction_operations.cu index bd8d9bb1..2d14166b 100644 --- a/tools/library/src/reduction/init_reduction_operations.cu +++ b/tools/library/src/reduction/init_reduction_operations.cu @@ -58,7 +58,6 @@ void initialize_all_reduction_op(Manifest &manifest) { initialize_reduce_add_linear_combination_f32_f32_f32(manifest); initialize_reduce_add_linear_combination_f64_f64_f64(manifest); initialize_reduce_add_linear_combination_cf32_cf32_cf32(manifest); - } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reduction/reduction_device.cu b/tools/library/src/reduction/reduction_device.cu index dfe8568b..5ede2fdf 100644 --- a/tools/library/src/reduction/reduction_device.cu +++ b/tools/library/src/reduction/reduction_device.cu @@ -43,9 +43,10 @@ namespace library { // naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput] + void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) { - using ElementWorkspace = cutlass::half_t; + using ElementWorkspace = cutlass::half_t; using ElementAccumulator = cutlass::half_t; using ElementOutput = cutlass::half_t; using ElementCompute = cutlass::half_t; @@ -58,7 +59,7 @@ void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) { >; using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, + ElementAccumulator, typename EpilogueOutputOp::ElementAccumulator, EpilogueOutputOp::kCount >; @@ -79,7 +80,7 @@ void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) { void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) { - using ElementWorkspace = float; + using ElementWorkspace = float; using ElementAccumulator = float; using ElementOutput = cutlass::half_t; using ElementCompute = float; @@ -92,7 +93,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) { >; using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, + ElementAccumulator, typename EpilogueOutputOp::ElementAccumulator, EpilogueOutputOp::kCount >; @@ -114,7 +115,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) { void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) { - using ElementWorkspace = float; + using ElementWorkspace = float; using ElementAccumulator = float; using ElementOutput = float; using ElementCompute = float; @@ -127,7 +128,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) { >; using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, + ElementAccumulator, typename EpilogueOutputOp::ElementAccumulator, EpilogueOutputOp::kCount >; @@ -148,7 +149,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) { void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest) { - using ElementWorkspace = double; + using ElementWorkspace = double; using ElementAccumulator = double; using ElementOutput = double; using ElementCompute = double; @@ -161,7 +162,7 @@ void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest) { >; using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, + ElementAccumulator, typename EpilogueOutputOp::ElementAccumulator, EpilogueOutputOp::kCount >; @@ -182,7 +183,7 @@ void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest) { void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest) { - using ElementWorkspace = cutlass::complex; + using ElementWorkspace = cutlass::complex; using ElementAccumulator = cutlass::complex; using ElementOutput = cutlass::complex; using ElementCompute = cutlass::complex; @@ -195,7 +196,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest) >; using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, + ElementAccumulator, typename EpilogueOutputOp::ElementAccumulator, EpilogueOutputOp::kCount >; diff --git a/tools/library/src/reference/conv_reference_operation.h b/tools/library/src/reference/conv_reference_operation.h index 0b108c2b..4eac5deb 100644 --- a/tools/library/src/reference/conv_reference_operation.h +++ b/tools/library/src/reference/conv_reference_operation.h @@ -146,6 +146,7 @@ struct ConvReferenceDispatcher< LayoutC, ElementCompute, ElementAccumulator, + ElementC, ConvertOp, InnerProductOp >( diff --git a/tools/library/src/reference/gemm.cu b/tools/library/src/reference/gemm.cu index 890772e2..b1a6e830 100644 --- a/tools/library/src/reference/gemm.cu +++ b/tools/library/src/reference/gemm.cu @@ -137,6 +137,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int8_t, float, int32_t, + int8_t, NumericConverterClamp >(manifest); @@ -146,6 +147,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int32_t, float, int32_t, + int32_t, NumericConverterClamp >(manifest); @@ -163,6 +165,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int8_t, float, int32_t, + int8_t, NumericConverterClamp >(manifest); @@ -172,6 +175,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int32_t, float, int32_t, + int32_t, NumericConverterClamp >(manifest); @@ -191,6 +195,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int32_t, float, int32_t, + int32_t, NumericConverterClamp >(manifest); @@ -201,6 +206,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int8_t, float, int32_t, + int8_t, NumericConverterClamp >(manifest); @@ -220,6 +226,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int32_t, float, int32_t, + int32_t, NumericConverterClamp >(manifest); @@ -230,6 +237,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { uint8_t, float, int32_t, + uint8_t, NumericConverterClamp >(manifest); @@ -240,6 +248,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int8_t, float, int32_t, + int8_t, NumericConverterClamp >(manifest); @@ -259,6 +268,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int32_t, float, int32_t, + int32_t, NumericConverterClamp >(manifest); @@ -269,6 +279,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int4b_t, float, int32_t, + int4b_t, NumericConverterClamp >(manifest); @@ -288,6 +299,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int32_t, float, int32_t, + int32_t, NumericConverterClamp >(manifest); @@ -298,6 +310,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { uint4b_t, float, int32_t, + uint4b_t, NumericConverterClamp >(manifest); @@ -308,6 +321,7 @@ void initialize_gemm_reference_operations(Manifest &manifest) { int4b_t, float, int32_t, + int4b_t, NumericConverterClamp >(manifest); @@ -330,6 +344,359 @@ void initialize_gemm_reference_operations(Manifest &manifest) { complex, complex >(manifest); + + // + // FP8 GEMMs + // + ////////////////////////////////// + /// ElementC: half_t + ////////////////////////////////// + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float , // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float , // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + /// ElementC: bfloat16_t + ////////////////////////////////// + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + ////////////////////////////////// + /// ElementC: float + ////////////////////////////////// + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reference/gemm_reference_operation.h b/tools/library/src/reference/gemm_reference_operation.h index 5d4d150b..b300529b 100644 --- a/tools/library/src/reference/gemm_reference_operation.h +++ b/tools/library/src/reference/gemm_reference_operation.h @@ -67,7 +67,8 @@ template < typename LayoutC_, typename ElementCompute_, typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, typename InnerProductOp_ = multiply_add > class GemmReferenceOperation : public Operation { @@ -84,7 +85,9 @@ class GemmReferenceOperation : public Operation { static cutlass::ComplexTransform const kTransformB = TransformB; using ElementC = ElementC_; using LayoutC = LayoutC_; + using ElementD = ElementD_; using TensorRefC = TensorRef; + using TensorRefD = TensorRef; using ElementCompute = ElementCompute_; using ElementAccumulator = ElementAccumulator_; using ConvertOp = ConvertOp_; @@ -114,6 +117,7 @@ class GemmReferenceOperation : public Operation { description_.B = make_TensorDescription(); description_.transform_B = ComplexTransformMap::kId; description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); // Epilogue compute and accumulator type description description_.element_epilogue = NumericTypeMap::kId; @@ -196,7 +200,7 @@ class GemmReferenceOperation : public Operation { TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; - TensorRefC ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; + TensorRefD ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; if (kProvider == Provider::kReferenceHost) { @@ -209,6 +213,7 @@ class GemmReferenceOperation : public Operation { LayoutC, ElementCompute, ElementAccumulator, + ElementD, ConvertOp, InnerProductOp >( @@ -242,6 +247,7 @@ class GemmReferenceOperation : public Operation { LayoutC, ElementCompute, ElementAccumulator, + ElementD, ConvertOp, InnerProductOp >( @@ -282,7 +288,8 @@ template < typename LayoutC_, typename ElementCompute_, typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, typename InnerProductOp_ = multiply_add > void make_gemm(Manifest &manifest) { @@ -294,6 +301,7 @@ void make_gemm(Manifest &manifest) { ElementC_, LayoutC_, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >); @@ -305,6 +313,7 @@ void make_gemm(Manifest &manifest) { ElementC_, LayoutC_, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >); @@ -317,37 +326,42 @@ template < typename ElementC_, typename ElementCompute_, typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, typename InnerProductOp_ = multiply_add > void make_gemm_canonical_layouts(Manifest &manifest) { + // M Major outputs make_gemm< ElementA_, cutlass::layout::ColumnMajor, TransformA, ElementB_, cutlass::layout::ColumnMajor, TransformB, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); - + make_gemm< ElementA_, cutlass::layout::ColumnMajor, TransformA, ElementB_, cutlass::layout::RowMajor, TransformB, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); - + make_gemm< ElementA_, cutlass::layout::RowMajor, TransformA, ElementB_, cutlass::layout::ColumnMajor, TransformB, ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); @@ -358,6 +372,52 @@ void make_gemm_canonical_layouts(Manifest &manifest) { ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + // N Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); @@ -372,6 +432,7 @@ template < typename ElementC_, typename ElementCompute_, typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, typename ConvertOp_ = NumericConverter, typename InnerProductOp_ = multiply_add > @@ -383,6 +444,7 @@ void make_gemm_interleaved_layouts(Manifest &manifest) { ElementC_, cutlass::layout::ColumnMajor, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); @@ -396,7 +458,8 @@ template < typename ElementC_, typename ElementCompute_, typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, typename InnerProductOp_ = multiply_add > void make_gemm_real_canonical_layouts(Manifest &manifest) { @@ -406,6 +469,7 @@ void make_gemm_real_canonical_layouts(Manifest &manifest) { ElementC_, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); @@ -418,7 +482,8 @@ template < typename ElementC_, typename ElementCompute_, typename ElementAccumulator_ = ElementCompute_, - typename ConvertOp_ = NumericConverter, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, typename InnerProductOp_ = multiply_add > void make_gemm_complex_canonical_layouts(Manifest &manifest) { @@ -429,6 +494,7 @@ void make_gemm_complex_canonical_layouts(Manifest &manifest) { ElementC_, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); @@ -439,6 +505,7 @@ void make_gemm_complex_canonical_layouts(Manifest &manifest) { ElementC_, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); @@ -449,6 +516,7 @@ void make_gemm_complex_canonical_layouts(Manifest &manifest) { ElementC_, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); @@ -459,6 +527,7 @@ void make_gemm_complex_canonical_layouts(Manifest &manifest) { ElementC_, ElementCompute_, ElementAccumulator_, + ElementD_, ConvertOp_, InnerProductOp_ >(manifest); diff --git a/tools/library/src/util.cu b/tools/library/src/util.cu index 36334576..98a578c0 100644 --- a/tools/library/src/util.cu +++ b/tools/library/src/util.cu @@ -443,6 +443,8 @@ NumericTypeID_enumerants[] = { {"s16", "S16", NumericTypeID::kS16}, {"s32", "S32", NumericTypeID::kS32}, {"s64", "S64", NumericTypeID::kS64}, + {"fe4m3", "FE4M3", NumericTypeID::kFE4M3}, + {"fe5m2", "FE5M2", NumericTypeID::kFE5M2}, {"f16", "F16", NumericTypeID::kF16}, {"bf16", "BF16", NumericTypeID::kBF16}, {"f32", "F32", NumericTypeID::kF32}, @@ -504,6 +506,8 @@ NumericTypeID from_string(std::string const &str) { /// Returns the size of a data type in bits int sizeof_bits(NumericTypeID type) { switch (type) { + case NumericTypeID::kFE4M3: return 8; + case NumericTypeID::kFE5M2: return 8; case NumericTypeID::kF16: return 16; case NumericTypeID::kBF16: return 16; case NumericTypeID::kTF32: return 32; @@ -581,6 +585,8 @@ bool is_integer_type(NumericTypeID type) { /// Returns true if numeric type is signed bool is_signed_type(NumericTypeID type) { switch (type) { + case NumericTypeID::kFE4M3: return true; + case NumericTypeID::kFE5M2: return true; case NumericTypeID::kF16: return true; case NumericTypeID::kBF16: return true; case NumericTypeID::kTF32: return true; @@ -610,6 +616,8 @@ bool is_unsigned_integer(NumericTypeID type) { /// Returns true if numeric type is floating-point type bool is_float_type(NumericTypeID type) { switch (type) { + case NumericTypeID::kFE4M3: return true; + case NumericTypeID::kFE5M2: return true; case NumericTypeID::kF16: return true; case NumericTypeID::kBF16: return true; case NumericTypeID::kTF32: return true; @@ -1050,6 +1058,20 @@ bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string c ss >> *reinterpret_cast(bytes.data()); } break; + case NumericTypeID::kFE4M3: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; + case NumericTypeID::kFE5M2: + { + float tmp; + ss >> tmp; + *reinterpret_cast(bytes.data()) = static_cast(tmp); + } + break; case NumericTypeID::kF16: { float tmp; @@ -1187,6 +1209,18 @@ std::string lexical_cast(std::vector &bytes, NumericTypeID type) { ss << *reinterpret_cast(bytes.data()); } break; + case NumericTypeID::kFE4M3: + { + float tmp = *reinterpret_cast(bytes.data()); + ss << tmp; + } + break; + case NumericTypeID::kFE5M2: + { + float tmp = *reinterpret_cast(bytes.data()); + ss << tmp; + } + break; case NumericTypeID::kF16: { float tmp = *reinterpret_cast(bytes.data()); @@ -1329,6 +1363,16 @@ bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t sr *reinterpret_cast(bytes.data()) = static_cast(src); } break; + case NumericTypeID::kFE4M3: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFE5M2: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); @@ -1429,6 +1473,16 @@ bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t *reinterpret_cast(bytes.data()) = static_cast(src); } break; + case NumericTypeID::kFE4M3: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFE5M2: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); @@ -1530,6 +1584,16 @@ bool cast_from_double(std::vector &bytes, NumericTypeID type, double sr *reinterpret_cast(bytes.data()) = static_cast(src); } break; + case NumericTypeID::kFE4M3: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; + case NumericTypeID::kFE5M2: + { + *reinterpret_cast(bytes.data()) = static_cast(float(src)); + } + break; case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index f9f425cc..1675d42c 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -29,7 +29,7 @@ # # Sources for CUTLASS Profiler Tool # - +cmake_policy(SET CMP0112 NEW) set(CUTLASS_TOOLS_PROFILER_SOURCES src/main.cpp src/cutlass_profiler.cu diff --git a/tools/profiler/src/conv2d_operation_profiler.cu b/tools/profiler/src/conv2d_operation_profiler.cu index dfbce274..cca50d0b 100644 --- a/tools/profiler/src/conv2d_operation_profiler.cu +++ b/tools/profiler/src/conv2d_operation_profiler.cu @@ -611,7 +611,7 @@ Status Conv2dOperationProfiler::initialize_workspace( if (options.execution_mode != ExecutionMode::kDryRun) { - + int seed_shift = 0; conv_workspace_.A = device_context.allocate_tensor( options, "A", @@ -619,7 +619,8 @@ Status Conv2dOperationProfiler::initialize_workspace( operation_desc.A.layout, problem_.extent_a(operation_desc.conv_kind), conv_workspace_.configuration.stride_a, - conv_workspace_.problem_count + conv_workspace_.problem_count, + seed_shift++ ); conv_workspace_.B = device_context.allocate_tensor( @@ -629,7 +630,8 @@ Status Conv2dOperationProfiler::initialize_workspace( operation_desc.B.layout, problem_.extent_b(operation_desc.conv_kind), conv_workspace_.configuration.stride_b, - conv_workspace_.problem_count + conv_workspace_.problem_count, + seed_shift++ ); if(problem_.groups == problem_.c && problem_.groups == problem_.k){ @@ -641,7 +643,8 @@ Status Conv2dOperationProfiler::initialize_workspace( operation_desc.B.layout, problem_.extent_b(operation_desc.conv_kind), conv_workspace_.configuration.stride_b, - conv_workspace_.problem_count + conv_workspace_.problem_count, + seed_shift++ ); } @@ -652,7 +655,8 @@ Status Conv2dOperationProfiler::initialize_workspace( operation_desc.C.layout, problem_.extent_c(operation_desc.conv_kind), conv_workspace_.configuration.stride_c, - conv_workspace_.problem_count + conv_workspace_.problem_count, + seed_shift++ ); conv_workspace_.Computed = device_context.allocate_tensor( diff --git a/tools/profiler/src/conv3d_operation_profiler.cu b/tools/profiler/src/conv3d_operation_profiler.cu index da9c3653..24a53c39 100644 --- a/tools/profiler/src/conv3d_operation_profiler.cu +++ b/tools/profiler/src/conv3d_operation_profiler.cu @@ -651,7 +651,7 @@ Status Conv3dOperationProfiler::initialize_workspace( if (options.execution_mode != ExecutionMode::kDryRun) { - + int seed_shift = 0; conv_workspace_.A = device_context.allocate_tensor( options, "A", @@ -659,7 +659,8 @@ Status Conv3dOperationProfiler::initialize_workspace( operation_desc.A.layout, problem_.extent_a(operation_desc.conv_kind), conv_workspace_.stride_a(operation_desc.conv_kind), - conv_workspace_.problem_count + conv_workspace_.problem_count, + seed_shift++ ); conv_workspace_.B = device_context.allocate_tensor( @@ -669,7 +670,8 @@ Status Conv3dOperationProfiler::initialize_workspace( operation_desc.B.layout, problem_.extent_b(operation_desc.conv_kind), conv_workspace_.stride_b(operation_desc.conv_kind), - conv_workspace_.problem_count + conv_workspace_.problem_count, + seed_shift++ ); conv_workspace_.C = device_context.allocate_tensor( @@ -679,7 +681,8 @@ Status Conv3dOperationProfiler::initialize_workspace( operation_desc.C.layout, problem_.extent_c(operation_desc.conv_kind), conv_workspace_.stride_c(operation_desc.conv_kind), - conv_workspace_.problem_count + conv_workspace_.problem_count, + seed_shift++ ); conv_workspace_.Computed = device_context.allocate_tensor( diff --git a/tools/profiler/src/cublas_helpers.cu b/tools/profiler/src/cublas_helpers.cu index 2175b359..57000919 100644 --- a/tools/profiler/src/cublas_helpers.cu +++ b/tools/profiler/src/cublas_helpers.cu @@ -103,13 +103,22 @@ bool get_cublas_transpose_operation( /// Maps a CUTLASS numeric type to a cuBLAS data type enumeration bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type) { - switch (element_type) { + switch (element_type) { + case library::NumericTypeID::kFE4M3: + data_type = CUDA_R_8F_E4M3; + return true; + + case library::NumericTypeID::kFE5M2: + data_type = CUDA_R_8F_E5M2; + return true; + case library::NumericTypeID::kF16: data_type = CUDA_R_16F; return true; case library::NumericTypeID::kBF16: - break; + data_type = CUDA_R_16BF; + return true; case library::NumericTypeID::kTF32: break; diff --git a/tools/profiler/src/cudnn_helpers.cpp b/tools/profiler/src/cudnn_helpers.cpp index 844119d1..158799e2 100644 --- a/tools/profiler/src/cudnn_helpers.cpp +++ b/tools/profiler/src/cudnn_helpers.cpp @@ -68,7 +68,7 @@ Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status) { return Disposition::kFailed; } -/// Checks cudnnStatus_t converts to cutlass status and returns if Status::kSuccess o.w. throws exception +/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception Status checkCudnnErr(cudnnStatus_t cudnn_status) { Status cutlass_status = get_cutlass_status(cudnn_status); if(cutlass_status != Status::kSuccess) { diff --git a/tools/profiler/src/cudnn_helpers.h b/tools/profiler/src/cudnn_helpers.h index e1c4f644..d5a9af7b 100644 --- a/tools/profiler/src/cudnn_helpers.h +++ b/tools/profiler/src/cudnn_helpers.h @@ -55,7 +55,7 @@ Status get_cutlass_status(cudnnStatus_t cudnn_status); /// Converts a cuDNN status to cutlass::profiler::Disposition Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status); -/// Checks cudnnStatus_t converts to cutlass status and returns if Status::kSuccess o.w. throws exception +/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception Status checkCudnnErr(cudnnStatus_t cudnn_status); /// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 92679ef5..f464ccb7 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -549,6 +549,22 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kFE4M3: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFE5M2: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kF64: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -681,6 +697,22 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { std::vector host_data(bytes()); switch (type_) { + case library::NumericTypeID::kFE4M3: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kFE5M2: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kF16: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -942,6 +974,18 @@ bool DeviceAllocation::block_compare_equal( size_t capacity) { switch (numeric_type) { + case library::NumericTypeID::kFE4M3: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + + case library::NumericTypeID::kFE5M2: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + case library::NumericTypeID::kF16: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), @@ -1095,6 +1139,22 @@ bool DeviceAllocation::block_compare_relatively_equal( double nonzero_floor) { switch (numeric_type) { + case library::NumericTypeID::kFE4M3: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + + case library::NumericTypeID::kFE5M2: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + case library::NumericTypeID::kF16: return reference::device::BlockCompareRelativelyEqual( reinterpret_cast(ptr_A), @@ -1430,6 +1490,14 @@ void DeviceAllocation::write_tensor_csv( std::ostream &out) { switch (this->type()) { + case library::NumericTypeID::kFE4M3: + write_tensor_csv_static_type(out, *this); + break; + + case library::NumericTypeID::kFE5M2: + write_tensor_csv_static_type(out, *this); + break; + case library::NumericTypeID::kF16: write_tensor_csv_static_type(out, *this); break; @@ -1586,6 +1654,14 @@ static void tensor_fill(DeviceAllocation &allocation, Element val = Element()) { void DeviceAllocation::fill(double val = 0.0) { switch (this->type()) { + case library::NumericTypeID::kFE4M3: + tensor_fill(*this, static_cast(val)); + break; + + case library::NumericTypeID::kFE5M2: + tensor_fill(*this, static_cast(val)); + break; + case library::NumericTypeID::kF16: tensor_fill(*this, static_cast(val)); break; diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu index 117f78ba..43a5ebd3 100644 --- a/tools/profiler/src/device_context.cu +++ b/tools/profiler/src/device_context.cu @@ -76,7 +76,8 @@ DeviceAllocation *DeviceContext::allocate_tensor( library::LayoutTypeID layout_id, std::vector const &extent, std::vector const &stride, - int batch_count) { + int batch_count, + int seed_shift) { DeviceAllocation *allocation = allocate_tensor(name, type, layout_id, extent, stride, batch_count); @@ -88,6 +89,12 @@ DeviceAllocation *DeviceContext::allocate_tensor( if(!options.initialization.fix_data_distribution) { // change data distribution based on bit width switch(type) { + case library::NumericTypeID::kFE4M3: + data_distribution.set_uniform(-1, 1, 0); + break; + case library::NumericTypeID::kFE5M2: + data_distribution.set_uniform(-1, 1, 0); + break; case library::NumericTypeID::kF16: data_distribution.set_uniform(-3, 3, 0); break; @@ -118,12 +125,12 @@ DeviceAllocation *DeviceContext::allocate_tensor( if (options.initialization.provider == library::Provider::kReferenceDevice) { allocation->initialize_random_device( - options.initialization.seed, + options.initialization.seed + seed_shift, data_distribution); } else if (options.initialization.provider == library::Provider::kReferenceHost) { allocation->initialize_random_host( - options.initialization.seed, + options.initialization.seed + seed_shift, data_distribution); } } @@ -140,7 +147,8 @@ DeviceAllocation *DeviceContext::allocate_sparsemeta_tensor( library::NumericTypeID type_a, std::vector const &extent, std::vector const &stride, - int batch_count) { + int batch_count, + int seed_shift) { DeviceAllocation *allocation = allocate_tensor(name, type, layout_id, extent, stride, batch_count); @@ -151,12 +159,12 @@ DeviceAllocation *DeviceContext::allocate_sparsemeta_tensor( if (options.initialization.provider == library::Provider::kReferenceDevice) { allocation->initialize_random_sparsemeta_device( - options.initialization.seed, + options.initialization.seed + seed_shift, MetaSizeInBits); } else if (options.initialization.provider == library::Provider::kReferenceHost) { allocation->initialize_random_sparsemeta_host( - options.initialization.seed, + options.initialization.seed + seed_shift, MetaSizeInBits); } } diff --git a/tools/profiler/src/device_context.h b/tools/profiler/src/device_context.h index 16a72f9c..1f21dc3c 100644 --- a/tools/profiler/src/device_context.h +++ b/tools/profiler/src/device_context.h @@ -93,8 +93,9 @@ class DeviceContext { library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector(), - int batch_count = 1); + std::vector const &stride, + int batch_count, + int seed_shift = 0); /// Allocates memory for sparse meta data DeviceAllocation *allocate_sparsemeta_tensor( @@ -104,8 +105,9 @@ class DeviceContext { library::LayoutTypeID layout_id, library::NumericTypeID type_a, std::vector const &extent, - std::vector const &stride = std::vector(), - int batch_count = 1); + std::vector const &stride, + int batch_count, + int seed_shift = 0); /// Clears named allocations (but does not necessarily free memory) void clear(); diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index a929ee89..595c9084 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -61,7 +61,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options): options, library::OperationKind::kGemm, { - {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"}, + {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (universal, gemm, planar_complex, planar_complex_array)"}, {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, @@ -432,7 +432,7 @@ bool GemmOperationProfiler::initialize_reduction_configuration_( library::Provider::kCUTLASS, gemm_desc.tile_description.math_instruction.element_accumulator, // element workspace gemm_desc.tile_description.math_instruction.element_accumulator, // element accumulator - gemm_desc.C.element, // element output + gemm_desc.D.element, // element output gemm_desc.element_epilogue // element compute ); @@ -485,7 +485,7 @@ Status GemmOperationProfiler::initialize_workspace( } if (options.execution_mode != ExecutionMode::kDryRun) { - + int seed_shift = 0; gemm_workspace_.A = device_context.allocate_tensor( options, "A", @@ -493,7 +493,8 @@ Status GemmOperationProfiler::initialize_workspace( operation_desc.A.layout, {int(problem_.m), int(problem_.k)}, {int(problem_.lda)}, - problem_.batch_count * gemm_workspace_.problem_count + problem_.batch_count * gemm_workspace_.problem_count, + seed_shift++ ); gemm_workspace_.B = device_context.allocate_tensor( @@ -503,7 +504,8 @@ Status GemmOperationProfiler::initialize_workspace( operation_desc.B.layout, {int(problem_.k), int(problem_.n)}, {int(problem_.ldb)}, - problem_.batch_count * gemm_workspace_.problem_count + problem_.batch_count * gemm_workspace_.problem_count, + seed_shift++ ); gemm_workspace_.C = device_context.allocate_tensor( @@ -513,13 +515,14 @@ Status GemmOperationProfiler::initialize_workspace( operation_desc.C.layout, {int(problem_.m), int(problem_.n)}, {int(problem_.ldc)}, - problem_.batch_count * gemm_workspace_.problem_count + problem_.batch_count * gemm_workspace_.problem_count, + seed_shift++ ); gemm_workspace_.Computed = device_context.allocate_tensor( "D", - operation_desc.C.element, - operation_desc.C.layout, + operation_desc.D.element, + operation_desc.D.layout, {int(problem_.m), int(problem_.n)}, {int(problem_.ldc)}, problem_.batch_count * gemm_workspace_.problem_count @@ -527,8 +530,8 @@ Status GemmOperationProfiler::initialize_workspace( gemm_workspace_.Reference = device_context.allocate_tensor( "Reference", - operation_desc.C.element, - operation_desc.C.layout, + operation_desc.D.element, + operation_desc.D.layout, {int(problem_.m), int(problem_.n)}, {int(problem_.ldc)}, problem_.batch_count * gemm_workspace_.problem_count @@ -547,6 +550,9 @@ Status GemmOperationProfiler::initialize_workspace( gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + gemm_workspace_.arguments.sm_count = options.device.properties.multiProcessorCount; } // @@ -965,9 +971,12 @@ bool GemmOperationProfiler::verify_with_reference_( problem_.beta.data(), gemm_desc.C.element, + gemm_desc.C.layout, ptr_C, int(gemm_workspace_.configuration.ldc), + gemm_desc.D.element, + gemm_desc.D.layout, ptr_D, int(gemm_workspace_.configuration.ldd), @@ -975,8 +984,7 @@ bool GemmOperationProfiler::verify_with_reference_( gemm_workspace_.A->batch_stride(), gemm_workspace_.B->batch_stride(), gemm_workspace_.C->batch_stride(), - gemm_workspace_.Reference->batch_stride() - ); + gemm_workspace_.Reference->batch_stride()); if (status != Status::kSuccess) { results_.back().verification_map[provider] = Disposition::kNotRun; diff --git a/tools/profiler/src/gemm_operation_profiler.h b/tools/profiler/src/gemm_operation_profiler.h index a01c93a0..309aaad1 100644 --- a/tools/profiler/src/gemm_operation_profiler.h +++ b/tools/profiler/src/gemm_operation_profiler.h @@ -66,8 +66,9 @@ class GemmOperationProfiler : public OperationProfiler { /// Problem structure obtained from problem space struct GemmProblem { - + cutlass::library::GemmUniversalMode mode; + int64_t m; int64_t n; int64_t k; diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 3401d15b..ab2b4ed0 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -713,9 +713,10 @@ Options::Options(cutlass::CommandLine const &cmdline): } // Prevent launches on the device for anything other than CUTLASS operation + // Allow verification only on host if (execution_mode == ExecutionMode::kTrace) { initialization.provider = library::Provider::kReferenceHost; - verification.enabled = false; + verification.providers = {library::Provider::kReferenceHost}; profiling.enabled = false; } } diff --git a/tools/profiler/src/rank_2k_operation_profiler.cu b/tools/profiler/src/rank_2k_operation_profiler.cu index 2c2f2361..4ff4e4e6 100644 --- a/tools/profiler/src/rank_2k_operation_profiler.cu +++ b/tools/profiler/src/rank_2k_operation_profiler.cu @@ -391,14 +391,16 @@ Status Rank2KOperationProfiler::initialize_workspace( static_cast(operation->description()); if (options.execution_mode != ExecutionMode::kDryRun) { - + int seed_shift = 0; rank_k_workspace_.A = device_context.allocate_tensor( options, "A", operation_desc.A.element, operation_desc.A.layout, {int(problem_.n), int(problem_.k)}, - {int(problem_.lda)} + {int(problem_.lda)}, + 1, // batch_count + seed_shift++ ); rank_k_workspace_.B = device_context.allocate_tensor( @@ -407,7 +409,9 @@ Status Rank2KOperationProfiler::initialize_workspace( operation_desc.B.element, operation_desc.B.layout, {int(problem_.n), int(problem_.k)}, - {int(problem_.ldb)} + {int(problem_.ldb)}, + 1, // batch_count + seed_shift++ ); rank_k_workspace_.C = device_context.allocate_tensor( @@ -417,7 +421,8 @@ Status Rank2KOperationProfiler::initialize_workspace( operation_desc.C.layout, {int(problem_.n), int(problem_.n)}, {int(problem_.ldc)}, - 1 // batch_count = 1, default + 1, // batch_count + seed_shift++ ); rank_k_workspace_.Computed = device_context.allocate_tensor( diff --git a/tools/profiler/src/rank_k_operation_profiler.cu b/tools/profiler/src/rank_k_operation_profiler.cu index 7e452e70..5d3972ee 100644 --- a/tools/profiler/src/rank_k_operation_profiler.cu +++ b/tools/profiler/src/rank_k_operation_profiler.cu @@ -391,14 +391,16 @@ Status RankKOperationProfiler::initialize_workspace( static_cast(operation->description()); if (options.execution_mode != ExecutionMode::kDryRun) { - + int seed_shift = 0; rank_k_workspace_.A = device_context.allocate_tensor( options, "A", operation_desc.A.element, operation_desc.A.layout, {int(problem_.n), int(problem_.k)}, - {int(problem_.lda)} + {int(problem_.lda)}, + 1, // batch_count + seed_shift++ ); rank_k_workspace_.C = device_context.allocate_tensor( @@ -408,7 +410,8 @@ Status RankKOperationProfiler::initialize_workspace( operation_desc.C.layout, {int(problem_.n), int(problem_.n)}, {int(problem_.ldc)}, - 1 // batch_count = 1, default + 1, // batch_count + seed_shift++ ); rank_k_workspace_.Computed = device_context.allocate_tensor( diff --git a/tools/profiler/src/sparse_gemm_operation_profiler.cu b/tools/profiler/src/sparse_gemm_operation_profiler.cu index 2caf5f01..6499039a 100644 --- a/tools/profiler/src/sparse_gemm_operation_profiler.cu +++ b/tools/profiler/src/sparse_gemm_operation_profiler.cu @@ -56,7 +56,7 @@ SparseGemmOperationProfiler::SparseGemmOperationProfiler(Options const &options) options, library::OperationKind::kSparseGemm, { - {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (e.g. gemm, planar complex, batched, ...)"}, + {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (e.g. sparse, ...)"}, {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, @@ -348,14 +348,16 @@ Status SparseGemmOperationProfiler::initialize_workspace( static_cast(operation->description()); if (options.execution_mode != ExecutionMode::kDryRun) { - + int seed_shift = 0; gemm_workspace_.A = device_context.allocate_tensor( options, "A", operation_desc.A.element, operation_desc.A.layout, {int(problem_.m), int(problem_.k) / int(problem_.sparse)}, - {int(problem_.lda)} + {int(problem_.lda)}, + 1, // batch_count + seed_shift++ ); gemm_workspace_.B = device_context.allocate_tensor( @@ -364,7 +366,9 @@ Status SparseGemmOperationProfiler::initialize_workspace( operation_desc.B.element, operation_desc.B.layout, {int(problem_.k), int(problem_.n)}, - {int(problem_.ldb)} + {int(problem_.ldb)}, + 1, // batch_count + seed_shift++ ); gemm_workspace_.C = device_context.allocate_tensor( @@ -373,7 +377,9 @@ Status SparseGemmOperationProfiler::initialize_workspace( operation_desc.C.element, operation_desc.C.layout, {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)} + {int(problem_.ldc)}, + 1, // batch_count + seed_shift++ ); gemm_workspace_.Computed = device_context.allocate_tensor( @@ -391,7 +397,9 @@ Status SparseGemmOperationProfiler::initialize_workspace( operation_desc.E.layout, operation_desc.A.element, {int(problem_.m), int(problem_.k) / int(problem_.sparse) / int(problem_.elements_per_128b)}, - {int(problem_.lde)} + {int(problem_.lde)}, + 1, // batch_count + seed_shift++ ); gemm_workspace_.Reference = device_context.allocate_tensor( diff --git a/tools/profiler/src/symm_operation_profiler.cu b/tools/profiler/src/symm_operation_profiler.cu index 97cb34a1..2a344182 100644 --- a/tools/profiler/src/symm_operation_profiler.cu +++ b/tools/profiler/src/symm_operation_profiler.cu @@ -415,7 +415,7 @@ Status SymmOperationProfiler::initialize_workspace( static_cast(operation->description()); if (options.execution_mode != ExecutionMode::kDryRun) { - + int seed_shift = 0; if (operation_desc.side_mode == SideMode::kLeft) { symm_workspace_.A = device_context.allocate_tensor( options, @@ -424,7 +424,8 @@ Status SymmOperationProfiler::initialize_workspace( operation_desc.A.layout, {int(problem_.m), int(problem_.m)}, {int(problem_.lda)}, - 1 // batch_count = 1, default + 1, // batch_count + seed_shift++ ); } else if (operation_desc.side_mode == SideMode::kRight) { symm_workspace_.A = device_context.allocate_tensor( @@ -434,7 +435,8 @@ Status SymmOperationProfiler::initialize_workspace( operation_desc.A.layout, {int(problem_.n), int(problem_.n)}, {int(problem_.lda)}, - 1 // batch_count = 1, default + 1, // batch_count + seed_shift++ ); } @@ -444,7 +446,9 @@ Status SymmOperationProfiler::initialize_workspace( operation_desc.B.element, operation_desc.B.layout, {int(problem_.m), int(problem_.n)}, - {int(problem_.ldb)} + {int(problem_.ldb)}, + 1, // batch_count + seed_shift++ ); symm_workspace_.C = device_context.allocate_tensor( @@ -454,7 +458,8 @@ Status SymmOperationProfiler::initialize_workspace( operation_desc.C.layout, {int(problem_.m), int(problem_.n)}, {int(problem_.ldc)}, - 1 // batch_count = 1, default + 1, // batch_count + seed_shift++ ); symm_workspace_.Computed = device_context.allocate_tensor( diff --git a/tools/profiler/src/trmm_operation_profiler.cu b/tools/profiler/src/trmm_operation_profiler.cu index 19014d0b..14e6fb2d 100644 --- a/tools/profiler/src/trmm_operation_profiler.cu +++ b/tools/profiler/src/trmm_operation_profiler.cu @@ -372,7 +372,7 @@ Status TrmmOperationProfiler::initialize_workspace( static_cast(operation->description()); if (options.execution_mode != ExecutionMode::kDryRun) { - + int seed_shift = 0; if (operation_desc.side_mode == SideMode::kLeft) { trmm_workspace_.A = device_context.allocate_tensor( options, @@ -381,7 +381,8 @@ Status TrmmOperationProfiler::initialize_workspace( operation_desc.A.layout, {int(problem_.m), int(problem_.m)}, {int(problem_.lda)}, - 1 // batch_count = 1, default + 1, // batch_count + seed_shift++ ); } else if (operation_desc.side_mode == SideMode::kRight) { trmm_workspace_.A = device_context.allocate_tensor( @@ -391,7 +392,8 @@ Status TrmmOperationProfiler::initialize_workspace( operation_desc.A.layout, {int(problem_.n), int(problem_.n)}, {int(problem_.lda)}, - 1 // batch_count = 1, default + 1, // batch_count + seed_shift++ ); } @@ -401,7 +403,9 @@ Status TrmmOperationProfiler::initialize_workspace( operation_desc.B.element, operation_desc.B.layout, {int(problem_.m), int(problem_.n)}, - {int(problem_.ldb)} + {int(problem_.ldb)}, + 1, // batch_count + seed_shift++ ); trmm_workspace_.Computed = device_context.allocate_tensor( diff --git a/tools/util/CMakeLists.txt b/tools/util/CMakeLists.txt index 1999aaf2..bc96016e 100644 --- a/tools/util/CMakeLists.txt +++ b/tools/util/CMakeLists.txt @@ -25,7 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +cmake_policy(SET CMP0112 NEW) add_library(cutlass_tools_util_includes INTERFACE) add_library(nvidia::cutlass::tools::util ALIAS cutlass_tools_util_includes) set_target_properties(cutlass_tools_util_includes PROPERTIES EXPORT_NAME tools::util) diff --git a/tools/util/include/cutlass/util/gett_commandline.hpp b/tools/util/include/cutlass/util/gett_commandline.hpp index e2a992f8..d5be80dd 100644 --- a/tools/util/include/cutlass/util/gett_commandline.hpp +++ b/tools/util/include/cutlass/util/gett_commandline.hpp @@ -209,7 +209,7 @@ struct GettCommandLine { // // Permute the batched modes to promote coalescing - // Sort the batched modes by min(ldAl,ldBl) and tie-broken by the size + // Sort the batched modes by min(ldAl,ldBl) and in case of a tie by the size std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) { return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1]) < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]); @@ -227,7 +227,7 @@ struct GettCommandLine { } // Permute the reduction modes to promote coalescing - // Sort the reduction modes by min(ldAk,ldBk) and tie-broken by the size + // Sort the reduction modes by min(ldAk,ldBk) and in case of a tie by the size std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) { return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1]) < std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]); @@ -243,7 +243,7 @@ struct GettCommandLine { } // Permute the row modes to promote coalescing - // Sort the row modes by min(ldAm,ldCm) and tie-broken by ldAm + // Sort the row modes by min(ldAm,ldCm) and in case of a tie by ldAm std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) { return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1]) < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]); @@ -259,7 +259,7 @@ struct GettCommandLine { } // Permute the col modes to promote coalescing - // Sort the col modes by min(ldBn,ldCn) and tie-broken by ldBn + // Sort the col modes by min(ldBn,ldCn) and in case of a tie by ldBn std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) { return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1]) < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]); @@ -362,7 +362,7 @@ struct GettCommandLine { " A command delimited list of symbolic mode and its corresponding extent.\n" " Extents are defaulted to 1 if any are not provided.\n\n" - "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extent=m:4096,n:4096,k:4096\n"; + "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n"; } }; diff --git a/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/tools/util/include/cutlass/util/reference/device/gemm_complex.h index 0f3977be..04d308db 100644 --- a/tools/util/include/cutlass/util/reference/device/gemm_complex.h +++ b/tools/util/include/cutlass/util/reference/device/gemm_complex.h @@ -67,7 +67,8 @@ template < typename LayoutC, typename ScalarType, typename ComputeType, - typename ConvertOp = NumericConverter, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, typename InnerProductOp = multiply_add, int kMblock = 4, int kNblock = 4 @@ -81,7 +82,7 @@ __global__ void GemmComplex( ComplexTransform transform_b, ScalarType beta, TensorRef tensor_c, - TensorRef tensor_d, + TensorRef tensor_d, ComputeType initial_accum, int batch_count = 1, int64_t batch_stride_A = 0, @@ -198,7 +199,8 @@ template < typename LayoutC, typename ScalarType, typename ComputeType, - typename ConvertOp = NumericConverter, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, typename InnerProductOp = multiply_add > void GemmComplex( @@ -210,7 +212,7 @@ void GemmComplex( ComplexTransform transform_b, ScalarType beta, TensorRef tensor_c, - TensorRef tensor_d, + TensorRef tensor_d, ComputeType initial_accum, int batch_count = 1, int64_t batch_stride_A = 0, @@ -243,6 +245,7 @@ void GemmComplex( LayoutC, ScalarType, ComputeType, + ElementD, ConvertOp, InnerProductOp, kMblock, @@ -285,6 +288,7 @@ void GemmComplex( LayoutC, ScalarType, ComputeType, + ElementD, ConvertOp, InnerProductOp, kBigMblock, @@ -322,7 +326,8 @@ template < typename LayoutB, typename ElementC, typename LayoutC, - typename ScalarType + typename ScalarType, + typename ElementD = ElementC > void GemmComplex( gemm::GemmCoord problem_size, @@ -333,7 +338,7 @@ void GemmComplex( ComplexTransform transform_b, ScalarType beta, TensorRef tensor_c, - TensorRef tensor_d) { + TensorRef tensor_d) { GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index b4238a0a..1b5e62bc 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -58,6 +58,8 @@ #include "cutlass/tensor_view.h" #include "cutlass/blas3.h" +#include "cutlass/layout/vector.h" + #include "cutlass/util/reference/device/tensor_foreach.h" #include "cutlass/util/distribution.h" @@ -1646,6 +1648,15 @@ void BlockFillSequential( Element v = Element(1), Element s = Element(0)) { + using Layout = layout::PackedVectorLayout; + Layout::TensorCoord size(static_cast(capacity)); // -Wconversion + Layout layout = Layout::packed(size); + TensorView view(ptr, layout, size); + + Array c; + c[0] = v; + + TensorFillLinear(view, c, s); } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/convolution.h b/tools/util/include/cutlass/util/reference/host/convolution.h index 4d8a7fc0..64c1cd9b 100644 --- a/tools/util/include/cutlass/util/reference/host/convolution.h +++ b/tools/util/include/cutlass/util/reference/host/convolution.h @@ -65,7 +65,8 @@ template < typename LayoutC, typename ElementCompute, typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, typename InnerProductOp = multiply_add > void Conv2dFprop( @@ -73,7 +74,7 @@ void Conv2dFprop( TensorRef tensor_x, TensorRef tensor_w, TensorRef tensor_y_in, - TensorRef tensor_y_out, + TensorRef tensor_y_out, ElementCompute alpha, ElementCompute beta) { @@ -142,12 +143,13 @@ template , - typename InnerProductOp = multiply_add > + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> void Depsep_Fprop(cutlass::TensorView tensor_A, cutlass::TensorView tensor_B, cutlass::TensorView tensor_C, - cutlass::TensorView tensor_D, + cutlass::TensorView tensor_D, ElementCompute alpha, ElementCompute beta, cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), @@ -208,7 +210,8 @@ template < typename LayoutC, typename ElementCompute, typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, typename InnerProductOp = multiply_add > void Conv2dDgrad( @@ -216,7 +219,7 @@ void Conv2dDgrad( TensorRef tensor_dy, TensorRef tensor_w, TensorRef tensor_dx_in, - TensorRef tensor_dx_out, + TensorRef tensor_dx_out, ElementCompute alpha, ElementCompute beta) { @@ -309,7 +312,8 @@ template < typename LayoutC, typename ElementCompute, typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, typename InnerProductOp = multiply_add > void Conv2dWgrad( @@ -317,7 +321,7 @@ void Conv2dWgrad( TensorRef tensor_dy, TensorRef tensor_x, TensorRef tensor_dw_in, - TensorRef tensor_dw_out, + TensorRef tensor_dw_out, ElementCompute alpha, ElementCompute beta) { @@ -389,7 +393,8 @@ template < typename LayoutC, typename ElementCompute, typename ElementAccumulator = ElementCompute, - typename ConvertOp = NumericConverter, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, typename InnerProductOp = multiply_add > void Conv2d( @@ -398,7 +403,7 @@ void Conv2d( TensorRef tensor_A, TensorRef tensor_B, TensorRef tensor_C, - TensorRef tensor_D, + TensorRef tensor_D, ElementCompute alpha, ElementCompute beta) { @@ -409,7 +414,8 @@ void Conv2d( ElementB, LayoutB, ElementC, LayoutC, ElementCompute, - ElementAccumulator, + ElementAccumulator, + ElementD, ConvertOp, InnerProductOp >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); break; @@ -421,6 +427,7 @@ void Conv2d( ElementC, LayoutC, ElementCompute, ElementAccumulator, + ElementD, ConvertOp, InnerProductOp >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); break; @@ -431,7 +438,8 @@ void Conv2d( ElementB, LayoutB, ElementC, LayoutC, ElementCompute, - ElementAccumulator, + ElementAccumulator, + ElementD, ConvertOp, InnerProductOp >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); break; diff --git a/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/tools/util/include/cutlass/util/reference/host/gemm_complex.h index f16e19c1..a884023f 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm_complex.h +++ b/tools/util/include/cutlass/util/reference/host/gemm_complex.h @@ -67,7 +67,8 @@ template < typename LayoutC, typename ScalarType, typename ComputeType, - typename ConvertOp = NumericConverter, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, typename InnerProductOp = multiply_add > void GemmComplex( @@ -79,7 +80,7 @@ void GemmComplex( ComplexTransform transform_b, ScalarType beta, TensorRef tensor_c, - TensorRef tensor_d, + TensorRef tensor_d, ComputeType initial_accum, int batch_count = 1, int64_t batch_stride_A = 0, @@ -185,7 +186,8 @@ template < typename LayoutB, typename ElementC, typename LayoutC, - typename ScalarType + typename ScalarType, + typename ElementD = ElementC > void GemmComplex( gemm::GemmCoord problem_size, @@ -196,7 +198,7 @@ void GemmComplex( ComplexTransform transform_b, ScalarType beta, TensorRef tensor_c, - TensorRef tensor_d) { + TensorRef tensor_d) { GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); } diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 64a0600b..f87e3d8e 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -38,6 +38,7 @@ #include "cutlass/complex.h" #include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" #include "cute/tensor.hpp" @@ -75,7 +76,11 @@ template< class ElementAccumulator_, class ElementCompute_, class TensorC_, // (M, N, L) - class TensorD_ // (M, N, L) + class TensorD_, // (M, N, L) + class TensorBias_, // (M, 1) + class TensorT_, // (M, N, L) + class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class BiasBinaryOp_ = cutlass::plus > struct GettEpilogueParams { using ElementScalar = ElementScalar_; @@ -83,15 +88,26 @@ struct GettEpilogueParams { using ElementCompute = ElementCompute_; using TensorC = TensorC_; using TensorD = TensorD_; + using TensorBias = TensorBias_; + using TensorT = TensorT_; + using ActivationFunctor = ActivationFunctor_; + using BiasBinaryOp = BiasBinaryOp_; + using EngineC = typename TensorC::engine_type; using LayoutC = typename TensorC::layout_type; using EngineD = typename TensorD::engine_type; using LayoutD = typename TensorD::layout_type; + using EngineBias = typename TensorBias::engine_type; + using LayoutBias = typename TensorBias::layout_type; + using EngineT = typename TensorT::engine_type; + using LayoutT = typename TensorT::layout_type; ElementScalar alpha = ElementScalar(1); ElementScalar beta = ElementScalar(0); TensorC C{}; TensorD D{}; + TensorBias Bias{}; + TensorT T{}; }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -204,19 +220,33 @@ void gett_epilogue( using ElementC = typename EpilogueParams::EngineC::value_type; using ElementD = typename EpilogueParams::EngineD::value_type; + using ElementBias = typename EpilogueParams::EngineBias::value_type; + using ElementT = typename EpilogueParams::EngineT::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + using ActivationFunctor = typename EpilogueParams::ActivationFunctor; + using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + // Input related converter NumericConverter accumulator_converter; NumericConverter source_converter; + NumericConverter bias_converter; // Scale related converter NumericConverter scale_converter; // Output related converter NumericConverter destination_converter; + NumericConverter temporary_converter; // Epilogue operations multiply_add epilogue_fma; multiplies mul; + // Activation operation + ActivationFunctor activation; + + // Bias binary operation + BiasBinaryOp bias_op; + // Do conversion ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); ElementCompute converted_beta = scale_converter(epilogue_params.beta); @@ -225,10 +255,24 @@ void gett_epilogue( if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { // Convert every type to ElementCompute first, do compute, convert to output type, write it out ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - ElementScalar output = epilogue_fma(converted_alpha, converted_acc, ElementCompute(0)); - output = epilogue_fma(converted_beta, converted_src, output); + ElementCompute output = mul(converted_alpha, converted_acc); + if (epilogue_params.Bias.data()) { + ElementCompute converted_bias = bias_converter(epilogue_params.Bias(m + m_b)); + output = bias_op(output, converted_bias); + } + + if (epilogue_params.C.data()) { + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + output = epilogue_fma(converted_beta, converted_src, output); + } + + if (epilogue_params.T.data()) { + // Store intermediate output + epilogue_params.T(m + m_b, n + n_b, l) = temporary_converter(output); + } + + output = activation(output); epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output); } @@ -238,6 +282,14 @@ void gett_epilogue( ///////////////////////////////////////////////////////////////////////////////////////////////// +template +auto make_layout_rank3(const TensorType& tensor) { + // append a batch mode of size 1 if we do not have tensors that are rank 3 + return make_layout( + make_shape(get<0>(tensor.shape()), get<1>(tensor.shape()), Int<1>{}), + make_stride(get<0>(tensor.stride()), get<1>(tensor.stride()), int64_t(cosize(tensor.layout())))); +} + /// GEMM - General Matrix-Matrix contraction without conjugation options template < class MainloopParams, @@ -254,26 +306,20 @@ void Gemm3x( static_assert(rank(typename MainloopParams::LayoutA{}) == rank(typename EpilogueParams::LayoutC{})); if constexpr (rank(typename MainloopParams::LayoutA{}) == 2) { - // append a batch mode of size 1 if we do not have tensors that are rank 3 - Layout layout_A = make_layout( - make_shape(get<0>(mainloop_params.A.shape()), get<1>(mainloop_params.A.shape()), Int<1>{}), - make_stride(get<0>(mainloop_params.A.stride()), get<1>(mainloop_params.A.stride()), int64_t(cosize(mainloop_params.A.layout())))); - - Layout layout_B = make_layout( - make_shape(get<0>(mainloop_params.B.shape()), get<1>(mainloop_params.B.shape()), Int<1>{}), - make_stride(get<0>(mainloop_params.B.stride()), get<1>(mainloop_params.B.stride()), int64_t(cosize(mainloop_params.B.layout())))); - - Layout layout_C = make_layout( - make_shape(get<0>(epilogue_params.C.shape()), get<1>(epilogue_params.C.shape()), Int<1>{}), - make_stride(get<0>(epilogue_params.C.stride()), get<1>(epilogue_params.C.stride()), int64_t(cosize(epilogue_params.C.layout())))); - - Layout layout_D = make_layout( - make_shape(get<0>(epilogue_params.D.shape()), get<1>(epilogue_params.D.shape()), Int<1>{}), - make_stride(get<0>(epilogue_params.D.stride()), get<1>(epilogue_params.D.stride()), int64_t(cosize(epilogue_params.D.layout())))); + Layout layout_A = make_layout_rank3(mainloop_params.A); + Layout layout_B = make_layout_rank3(mainloop_params.B); + Layout layout_C = make_layout_rank3(epilogue_params.C); + Layout layout_D = make_layout_rank3(epilogue_params.D); + Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); + Layout layout_T = make_layout_rank3(epilogue_params.T); + auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); + auto TensorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); + auto TensorT = make_tensor(epilogue_params.T.data(), layout_T); + // Reconstruct mainloop params GettMainloopParams epilogue_params_converted{epilogue_params.alpha, epilogue_params.beta, TensorC, - TensorD + TensorD, + TensorBias, + TensorT }; Gett(mainloop_params_converted, epilogue_params_converted); diff --git a/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/tools/util/include/cutlass/util/reference/host/tensor_compare.h index f9a362e3..20187aba 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_compare.h @@ -39,11 +39,11 @@ // Cutlass includes #include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" #include "cutlass/tensor_view.h" #include "cutlass/tensor_view_planar_complex.h" #include "cutlass/util/distribution.h" -//#include "cutlass/util/type_traits.h" #include "tensor_foreach.h" namespace cutlass { @@ -83,10 +83,55 @@ struct TensorEqualsFunc { Element lhs_ = lhs.at(coord); Element rhs_ = rhs.at(coord); - + if (lhs_ != rhs_) { result = false; - } + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorRelativelyEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + Element epsilon; + Element nonzero_floor; + bool result; + + /// Ctor + TensorRelativelyEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_, + Element epsilon_, + Element nonzero_floor_ + ) : + lhs(lhs_), + rhs(rhs_), + epsilon(epsilon_), + nonzero_floor(nonzero_floor_), + result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) { + result = false; + } } /// Returns true if equal @@ -104,7 +149,7 @@ template < typename Element, ///< Element type typename Layout> ///< Layout function bool TensorEquals( - TensorView const &lhs, + TensorView const &lhs, TensorView const &rhs) { // Extents must be identical @@ -126,7 +171,7 @@ template < typename Element, ///< Element type typename Layout> ///< Layout function bool TensorEquals( - TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &lhs, TensorViewPlanarComplex const &rhs) { // Extents must be identical @@ -135,7 +180,7 @@ bool TensorEquals( } detail::TensorEqualsFunc real_func( - {lhs.data(), lhs.layout(), lhs.extent()}, + {lhs.data(), lhs.layout(), lhs.extent()}, {rhs.data(), rhs.layout(), rhs.extent()} ); @@ -164,12 +209,85 @@ bool TensorEquals( /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorView const &lhs, + TensorView const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc func(lhs, rhs, epsilon, nonzero_floor); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + /// Returns true if two tensor views are NOT equal. template < typename Element, ///< Element type typename Layout> ///< Layout function bool TensorNotEquals( - TensorView const &lhs, + TensorView const &lhs, TensorView const &rhs) { // Extents must be identical @@ -191,7 +309,7 @@ template < typename Element, ///< Element type typename Layout> ///< Layout function bool TensorNotEquals( - TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &lhs, TensorViewPlanarComplex const &rhs) { return !TensorEquals(lhs, rhs); @@ -235,7 +353,7 @@ struct TensorContainsFunc { if (view.at(coord) == value) { if (!contains) { - location = coord; + location = coord; } contains = true; } diff --git a/tools/util/include/cutlass/util/tensor_view_io.h b/tools/util/include/cutlass/util/tensor_view_io.h index 6a352df2..51e47b92 100644 --- a/tools/util/include/cutlass/util/tensor_view_io.h +++ b/tools/util/include/cutlass/util/tensor_view_io.h @@ -96,12 +96,16 @@ inline std::ostream & TensorView_WriteRank( if (rank + 2 == Layout::kRank) { // Write least significant ranks asa matrix with rows delimited by "\n" - out << (idx ? ",\n" : ""); + if (idx) { + out << ",\n"; + } TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); } else { // Higher ranks are separated by newlines - out << (idx ? ",\n\n" : ""); + if (idx) { + out << ",\n\n"; + } TensorView_WriteRank(out, view, coord, rank + 1, width); } } @@ -166,12 +170,16 @@ inline std::ostream & TensorViewPlanarComplex_WriteRank( if (rank + 2 == Layout::kRank) { // Write least significant ranks asa matrix with rows delimited by ";\n" - out << (idx ? ";\n" : ""); + if (idx) { + out << ";\n"; + } TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); } else { // Higher ranks are separated by newlines - out << (idx ? "\n" : ""); + if (idx) { + out << "\n"; + } TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); } }