diff --git a/examples/40_cutlass_py/README.md b/examples/40_cutlass_py/README.md index c4556d1c..4d13ea3c 100644 --- a/examples/40_cutlass_py/README.md +++ b/examples/40_cutlass_py/README.md @@ -2,7 +2,6 @@ This directory contains examples of using CUTLASS's Python interface. It consists of two types of examples: * _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations * [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel ->>>>>>> Add simplified examples ## Setting up the Python interface Please follow the instructions [here](/tools/library/scripts/pycutlass/README.md#installation) to set up the Python API. diff --git a/examples/40_cutlass_py/conv2d.py b/examples/40_cutlass_py/conv2d.py index 50854d38..a4d2dcef 100644 --- a/examples/40_cutlass_py/conv2d.py +++ b/examples/40_cutlass_py/conv2d.py @@ -41,7 +41,7 @@ import cutlass import pycutlass from pycutlass import * -import util +from pycutlass.utils.device import device_cc parser = argparse.ArgumentParser( @@ -62,7 +62,7 @@ sys.exit(0) # Check that the device is of a sufficient compute capability -cc = util.get_device_cc() +cc = device_cc() assert cc >= 70, "The CUTLASS Python Conv2d example requires compute capability greater than or equal to 70." 
alignment = 1 @@ -82,8 +82,17 @@ element_acc = cutlass.float32 element_epilogue = cutlass.float32 +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + instruction_shape = [16, 8, 16] + math_inst = MathInstruction( - [16, 8, 8], # Shape of the Tensor Core instruction + instruction_shape, A.element, B.element, element_acc, cutlass.OpClass.TensorOp, MathOperation.multiply_add diff --git a/examples/40_cutlass_py/customizable/conv2d.py b/examples/40_cutlass_py/customizable/conv2d.py index 365ce3e5..d890cf1b 100644 --- a/examples/40_cutlass_py/customizable/conv2d.py +++ b/examples/40_cutlass_py/customizable/conv2d.py @@ -34,6 +34,7 @@ from pycutlass import * from pycutlass.conv2d_operation import * from pycutlass.utils import reference_model +from pycutlass.utils.device import device_cc import sys import torch.nn.functional as F @@ -146,6 +147,11 @@ except: sys.exit(0) +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) np.random.seed(0) diff --git a/examples/40_cutlass_py/customizable/gemm.py b/examples/40_cutlass_py/customizable/gemm.py index dd6a7a4a..294914e7 100644 --- a/examples/40_cutlass_py/customizable/gemm.py +++ b/examples/40_cutlass_py/customizable/gemm.py @@ -34,6 +34,7 @@ from pycutlass import * import cutlass from bfloat16 import bfloat16 +from pycutlass.utils.device import device_cc import sys import argparse @@ -131,12 +132,16 @@ parser.add_argument('--print_cuda', action="store_true", help="print the underlying CUDA kernel") - try: args = parser.parse_args() except: sys.exit(0) +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter 
--compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) pycutlass.compiler.nvcc() diff --git a/examples/40_cutlass_py/customizable/gemm_grouped.py b/examples/40_cutlass_py/customizable/gemm_grouped.py index 40f2bc8d..f995e882 100644 --- a/examples/40_cutlass_py/customizable/gemm_grouped.py +++ b/examples/40_cutlass_py/customizable/gemm_grouped.py @@ -32,6 +32,7 @@ import numpy as np import pycutlass from pycutlass import * +from pycutlass.utils.device import device_cc import csv import sys @@ -129,6 +130,11 @@ except: sys.exit(0) +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) np.random.seed(0) diff --git a/examples/40_cutlass_py/gemm.py b/examples/40_cutlass_py/gemm.py index 341177cf..88f2b0a7 100644 --- a/examples/40_cutlass_py/gemm.py +++ b/examples/40_cutlass_py/gemm.py @@ -40,7 +40,7 @@ import cutlass import pycutlass from pycutlass import * -import util +from pycutlass.utils.device import device_cc parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'") @@ -55,7 +55,7 @@ sys.exit(0) # Check that the device is of a sufficient compute capability -cc = util.get_device_cc() +cc = device_cc() assert cc >= 70, "The CUTLASS Python GEMM example requires compute capability greater than or equal to 70." 
alignment = 8 @@ -78,13 +78,23 @@ element_acc = cutlass.float32 element_epilogue = cutlass.float32 +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + instruction_shape = [16, 8, 16] + math_inst = MathInstruction( - [16, 8, 8], # Shape of the Tensor Core instruction + instruction_shape, A.element, B.element, element_acc, cutlass.OpClass.TensorOp, MathOperation.multiply_add ) + tile_description = TileDescription( [128, 128, 32], # Threadblock shape 2, # Number of stages diff --git a/examples/40_cutlass_py/gemm_grouped.py b/examples/40_cutlass_py/gemm_grouped.py index f62d8009..0ac804e4 100644 --- a/examples/40_cutlass_py/gemm_grouped.py +++ b/examples/40_cutlass_py/gemm_grouped.py @@ -40,7 +40,7 @@ import cutlass import pycutlass from pycutlass import * -import util +from pycutlass.utils.device import device_cc parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python") @@ -52,7 +52,7 @@ sys.exit(0) # Check that the device is of a sufficient compute capability -cc = util.get_device_cc() +cc = device_cc() assert cc >= 70, "The CUTLASS Python grouped GEMM example requires compute capability greater than or equal to 70." 
np.random.seed(0) @@ -71,8 +71,17 @@ element_acc = cutlass.float32 element_epilogue = cutlass.float32 +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + instruction_shape = [16, 8, 16] + math_inst = MathInstruction( - [16, 8, 8], # Shape of the Tensor Core instruction + instruction_shape, A.element, B.element, element_acc, cutlass.OpClass.TensorOp, MathOperation.multiply_add diff --git a/tools/library/scripts/pycutlass/README.md b/tools/library/scripts/pycutlass/README.md index 1fd905e6..2843298b 100644 --- a/tools/library/scripts/pycutlass/README.md +++ b/tools/library/scripts/pycutlass/README.md @@ -102,8 +102,10 @@ Examples can be found in [$CUTLASS_PATH/examples/40_cutlass_py](examples/40_cutl ## Test The test cases are listed in `$CUTLASS_PATH//tools/library/scripts/pycutlass/test`. The unit test can be run with ```shell +# Each of these tests is only supported on devices with compute capability of SM80.
For other devices, +# see the basic examples in $CUTLASS_PATH/examples/40_cutlass_py cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/unit && python test_sm80.py -cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && run_all_example.sh +cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && bash run_all_example.sh ``` ## build documentation diff --git a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py b/tools/library/scripts/pycutlass/src/pycutlass/compiler.py index 5b50c2f7..088c4a85 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/compiler.py @@ -308,7 +308,7 @@ def emit_compile_(self, operation_list, compilation_options): cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host for opt in options: opt = opt.decode("utf-8") - if opt not in ['-default-device', '-std=c++11', '-arch=sm_80', '-Xcicc', '-Xllc']: + if opt not in ['-default-device', '-std=c++11', '-Xcicc', '-Xllc'] and '-arch=sm_' not in opt: if '--include-path=' in opt: cmd += " " + opt.replace('--include-path=', '-I') else: diff --git a/examples/40_cutlass_py/util.py b/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py similarity index 75% rename from examples/40_cutlass_py/util.py rename to tools/library/scripts/pycutlass/src/pycutlass/utils/device.py index d37bd045..a69929f4 100644 --- a/examples/40_cutlass_py/util.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py @@ -31,14 +31,22 @@ ################################################################################################# """ -Utility functions for interacting with device +Utility functions for interacting with the device """ from cuda import cudart -# Raises an exception if `result` returned an error. Otherwise returns the result. def check_cuda_errors(result: list): + """ + Checks whether `result` contains a CUDA error and, if so, raises the error as an exception.
Otherwise, + returns the result contained in the remaining fields of `result`. + + :param result: the results of the `cudart` method, consisting of an error code and any method results + :type result: list + + :return: non-error-code results from the `result` parameter + """ # `result` is of the format : (cudaError_t, result...) err = result[0] if err.value: @@ -52,9 +60,17 @@ def check_cuda_errors(result: list): return result[1:] -# Returns the integer representation of the device compute capability -def get_device_cc(device: int = 0): +def device_cc(device: int = 0) -> int: + """ + Returns the compute capability of the device with ID `device`. + + :param device: ID of the device to query + :type device: int + + :return: compute capability of the queried device (e.g., 80 for SM80) + :rtype: int + """ deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) major = str(deviceProp.major) minor = str(deviceProp.minor) - return int(major + minor) + return int(major + minor) \ No newline at end of file diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 24d70376..6948d274 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -2,9 +2,11 @@ from pycutlass.conv2d_operation import * from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase): def
test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index e9ce7460..26741ced 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase): def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage3(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index 0351bc5c..821f99c7 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -3,8 +3,11 @@ from pycutlass.conv2d_operation import * from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase): def 
test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index 21061729..210c2ba3 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase): def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py index fb4f2434..54dbea96 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -1,8 +1,11 @@ # test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") def conv2d_few_channel_problemsizes(channels): 
problem_sizes = [ cutlass.conv.Conv2dProblemSize( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py index cf46d0b5..4be81f99 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -1,8 +1,11 @@ # test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu import pycutlass from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") def conv2d_fixed_channel_problemsizes(channels): problem_sizes = [ cutlass.conv.Conv2dProblemSize( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index e2a4ccc3..49d59c1a 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase): def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( diff --git 
a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 101aaa1d..36d115e4 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dFpropImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase): def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index 412e199a..578b5fd8 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -3,8 +3,11 @@ from pycutlass.conv2d_operation import * from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase): def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( diff --git 
a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index 4585d66c..aa9f1da6 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase): def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 4ce627c4..5e4ce635 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittest.TestCase): def 
test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py index 533e66c4..64b40dd7 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase): def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py index 2399a1e1..96f9ff36 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase): def 
test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py index c932d808..a42a0980 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -3,8 +3,11 @@ from pycutlass.conv2d_operation import * from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase): def test_SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py index a69274fc..b64bd39f 100644 --- a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -2,8 +2,11 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase): def 
test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py index 59b5549a..f547776d 100644 --- a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +++ b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py @@ -33,6 +33,7 @@ import pycutlass import unittest from pycutlass import * +from pycutlass.utils.device import device_cc import torch import cupy as cp @@ -42,13 +43,18 @@ def setUp(self) -> None: # # define the cutlass operator # + cc = device_cc() math_inst = MathInstruction( [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32, cutlass.OpClass.Simt, MathOperation.multiply_add ) + # Stages > 2 is supported only for compute capability 80 and beyond + stages = 4 if cc >= 80 else 2 + + tile_description = TileDescription( - [128, 128, 8], 4, [2, 4, 1], + [128, 128, 8], stages, [2, 4, 1], math_inst ) @@ -69,7 +75,7 @@ def setUp(self) -> None: math_inst.element_accumulator, cutlass.float32) self.operation = GemmOperationUniversal( - arch=80, tile_description=tile_description, + arch=cc, tile_description=tile_description, A=A, B=B, C=C, epilogue_functor=epilogue_functor, swizzling_functor=cutlass.IdentitySwizzle1 diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py index b4505c65..b03e2431 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py @@ -4,7 +4,10 @@ import unittest from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class GemmBF16TensorOpSm80(unittest.TestCase): def 
SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32_64x128x64_32x64x64(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py index 5bef482d..6ffb04a5 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py @@ -4,8 +4,10 @@ import unittest from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class GemmF16Sm80(unittest.TestCase): def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py index 960bdd39..ad48d0dd 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py @@ -5,8 +5,10 @@ import unittest from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py index 1e1778a9..11d26683 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py @@ -4,7 +4,10 @@ import unittest from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 
tests.") class GemmF64TensorOpSm80(unittest.TestCase): def test_SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64_32x32x16_16x16x16(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py index 451a91ac..c7acd742 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py @@ -4,8 +4,10 @@ import unittest from pycutlass.test.gemm_grouped_testbed import TestbedGrouped +from pycutlass.utils.device import device_cc +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class GemmGroupedSm80(unittest.TestCase): def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x32(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py index 0a76198a..7ddeebbc 100644 --- a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py @@ -5,7 +5,10 @@ import unittest from pycutlass.test.gemm_testbed import test_all_gemm +from pycutlass.utils.device import device_cc + +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") class GemmS8TensorOpF32Sm80(unittest.TestCase): def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_64x64x64_32x32x64(self): math_inst = MathInstruction( diff --git a/tools/library/scripts/pycutlass/test/unit/test_sm80.py b/tools/library/scripts/pycutlass/test/unit/test_sm80.py index 0dd685de..446d3720 100644 --- a/tools/library/scripts/pycutlass/test/unit/test_sm80.py +++ b/tools/library/scripts/pycutlass/test/unit/test_sm80.py @@ -35,12 +35,14 @@ import pycutlass from pycutlass import * from pycutlass.test import * +from pycutlass.utils.device import device_cc import unittest # # Create GEMM 
operation # +@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.") def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False, epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): """