diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc index b25f55062e109..5eb05edefdcfc 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc @@ -4,6 +4,8 @@ #include #include #include +#include + #include "python/tools/kernel_explorer/device_array.h" #include "python/tools/kernel_explorer/kernel_explorer_interface.h" @@ -13,6 +15,10 @@ namespace onnxruntime { static py::module::module_def _kernel_explorer_module_def; +bool TuningInfo::collect_enabled_{false}; +std::vector TuningInfo::collected_tuning_results_ = {}; +std::optional TuningInfo::max_tuning_duration_ms_ = {}; + py::module GetKernelExplorerModule() { static pybind11::module_ m = []() { auto tmp = pybind11::module_::create_extension_module( @@ -36,11 +42,29 @@ KE_REGISTER(m) { .def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray) .def("UpdateDeviceArray", &DeviceArray::UpdateDeviceArray); + m.def("enable_collect_tuning_results", TuningInfo::EnableCollect, pybind11::arg("enable") = true); + + m.def("max_tuning_duration_ms", TuningInfo::SetMaxTuningDurationMs); + + m.def("get_collected_tuning_results", []() { + py::list ret; + for (const auto& trs : TuningInfo::GetCollectedTuningResults()) { + py::dict py_trs; + py_trs["ep"] = trs.ep; + py_trs["results"] = trs.results; + py_trs["validators"] = trs.validators; + ret.append(std::move(py_trs)); + } + return ret; + }); + + // clang-format ill-format the following code below version 18 + // clang-format off m.def("is_composable_kernel_available", []() { #ifdef USE_COMPOSABLE_KERNEL return true; #else - return false; + return false; #endif }); @@ -48,7 +72,7 @@ KE_REGISTER(m) { #ifdef USE_HIPBLASLT return true; #else - return false; + return false; #endif }); @@ -56,9 +80,10 @@ KE_REGISTER(m) { #ifndef DISABLE_FLOAT8_TYPES return true; #else - return false; + return false; #endif }); + // clang-format on } } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h b/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h index 9eb0adcede04b..1c7232e6a5cd0 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer_interface.h @@ -36,6 +36,24 @@ using TuningContextT = onnxruntime::rocm::tunable::RocmTuningContext; namespace onnxruntime { +struct TuningInfo { + static void EnableCollect(bool b) { + collect_enabled_ = b; + } + + static std::vector GetCollectedTuningResults() { + return collected_tuning_results_; + } + + static void SetMaxTuningDurationMs(int milliseconds) { + max_tuning_duration_ms_ = milliseconds; + } + + static bool collect_enabled_; + static std::vector collected_tuning_results_; + static std::optional max_tuning_duration_ms_; +}; + /// Wrapping around Op and TunableOp class IKernelExplorer { public: @@ -59,7 +77,11 @@ class IKernelExplorer { return timer.Duration() / repeats_; } - virtual ~IKernelExplorer() = default; + virtual ~IKernelExplorer() { + if (TuningInfo::collect_enabled_) { + TuningInfo::collected_tuning_results_.emplace_back(this->ep_->GetTuningContext()->GetTuningResults()); + } + } protected: ExecutionProvider* GetEp() { @@ -73,6 +95,15 @@ class IKernelExplorer { auto tuning_ctx = this->ep_->GetTuningContext(); if (nullptr != tuning_ctx) { tuning_ctx->RegisterAllocatorsView(&this->allocators_); + for (const auto& tr : TuningInfo::collected_tuning_results_) { + auto status = tuning_ctx->LoadTuningResults(tr); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << status; + } + } + if (TuningInfo::max_tuning_duration_ms_.has_value()) { + tuning_ctx->SetMaxTuningDurationMs(*TuningInfo::max_tuning_duration_ms_); + } } stream_ = std::make_unique(nullptr, this->ep_->GetOrtDeviceByMemType(OrtMemTypeDefault)); }); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi b/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi index 94213aceed08c..4682f7135d7a3 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi +++ b/onnxruntime/python/tools/kernel_explorer/kernels/_kernel_explorer.pyi @@ -14,3 +14,7 @@ class qkv_format: # noqa: N801 Q_KV_BSNH_BSN2H: int def is_composable_kernel_available(*args, **kwargs): ... +def is_hipblaslt_available(*args, **kwargs): ... + +def enable_collect_tuning_results(*args, **kwargs): ... +def get_collected_tuning_results(*args, **kwargs): ... diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py index cc5a918735536..5f24867901570 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/batched_gemm_test.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import os -import sys from dataclasses import dataclass from itertools import product @@ -23,6 +22,7 @@ def dtype_to_suffix(dtype): }[dtype] +@ke.dispatchable def _test_batched_gemm( func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int, alpha=1.0, beta=0.0 ): @@ -148,6 +148,7 @@ def report(self): return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common +@ke.dispatchable(pattern_arg=0) def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int): a_shape = (k, m) if transa else (m, k) b_shape = (n, k) if transb else (k, n) @@ -177,12 +178,13 @@ def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, ke.report(BatchedGemmMetric(impl, dtype, duration_ms, flops, transa, transb, m, n, k, batch)) -def profile_with_args(dtype, transa, transb, m, n, k, batch, sort): +@ke.dispatchable +def profile_with_args(dtype, transa, transb, m, n, k, batch): dtype_suffix = "_" + dtype_to_suffix(dtype) transab_suffix = "_" + transab_to_suffix((transa, transb)) fn_rocblas = getattr(ke, "RocBlasBatchedGemm" + dtype_suffix) fn_tunable = getattr(ke, "BatchedGemmTunable" + dtype_suffix + transab_suffix) - with ke.benchmark(sort): + with ke.benchmark(): profile_gemm_func(fn_rocblas, dtype, transa, transb, m, n, k, batch) profile_gemm_func(fn_tunable, dtype, transa, transb, m, n, k, batch) print() @@ -192,14 +194,12 @@ def profile(): for dtype in dtypes: for m, n, k in get_gemm_bert_sizes(full=False): for batch in [1, 32, 64]: - profile_with_args(dtype, False, False, m, n, k, batch, True) + profile_with_args(dtype, False, False, m, n, k, batch) if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("dtype", choices=dtypes) group.add_argument("transa", choices="NT") group.add_argument("transb", choices="NT") @@ -207,12 +207,9 @@ def profile(): group.add_argument("n", type=int) group.add_argument("k", type=int) group.add_argument("batch", type=int) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args( - args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.batch, args.sort - ) + args.dispatch(args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.batch) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py index 7088039f9e531..ba049fad773aa 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import sys from dataclasses import dataclass import kernel_explorer as ke @@ -31,6 +30,7 @@ def report(self): return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} n={self.n} k={self.k} {self.name}" +@ke.dispatchable(pattern_arg=3) def profile_dequantize_int4_func(n, k, dtype, func): np.random.seed(0) output = np.random.rand(n, k).astype(dtype) @@ -48,8 +48,9 @@ def profile_dequantize_int4_func(n, k, dtype, func): ke.report(DequantizeInt4Metric(func, dtype, duration_ms, total_bytes, n, k)) -def profile_with_args(n, k, dtype, sort): - with ke.benchmark(sort): +@ke.dispatchable +def profile_with_args(n, k, dtype): + with ke.benchmark(): for func in dtype_to_funcs(dtype): profile_dequantize_int4_func(n, k, dtype, func) @@ -57,22 +58,19 @@ def profile_with_args(n, k, dtype, sort): def profile(): for dt in dtypes: for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): - profile_with_args(n, k, dt, True) + profile_with_args(n, k, dt) print() if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("n", type=int) group.add_argument("k", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args(args.n, args.k, args.dtype, args.sort) + args.dispatch(args.n, args.k, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/elementwise_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/elementwise_test.py index 913a2c31a5f10..425d8843814c3 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/elementwise_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/elementwise_test.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import re -import sys from dataclasses import dataclass from itertools import product @@ -90,6 +89,7 @@ def report(self): return "not supported " + common +@ke.dispatchable(pattern_arg=4) def profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func): x_size = [batch_size, seq_len, hidden_size] bias_size = hidden_size @@ -112,8 +112,9 @@ def profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func): ke.report(ElementwiseMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size)) -def profile_with_args(batch_size, seq_len, hidden_size, fn_name, dtype, sort): - with ke.benchmark(sort): +@ke.dispatchable +def profile_with_args(batch_size, seq_len, hidden_size, fn_name, dtype): + with ke.benchmark(): for func in dtype_to_funcs(fn_name, dtype): profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func) @@ -121,24 +122,21 @@ def profile_with_args(batch_size, seq_len, hidden_size, fn_name, dtype, sort): def profile(): for dtype in dtypes: for bert_size in get_bert_sizes(): - profile_with_args(*bert_size, "FastGeLU", dtype, True) + profile_with_args(*bert_size, "FastGeLU", dtype) print() if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("batch_size", type=int) group.add_argument("seq_len", type=int) group.add_argument("hidden_size", type=int) group.add_argument("fn_name", choices=fn_names) group.add_argument("dtype", choices=dtypes) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.fn_name, args.dtype, args.sort) + args.dispatch(args.batch_size, args.seq_len, args.hidden_size, args.fn_name, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py index 9b308c09811d1..8ee9c6bc0f040 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_fast_gelu_test.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import sys from dataclasses import dataclass from itertools import product @@ -120,6 +119,7 @@ def report(self): return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common +@ke.dispatchable(pattern_arg=0) def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool): a_shape = (k, m) if transa else (m, k) b_shape = (n, k) if transb else (k, n) @@ -153,10 +153,11 @@ def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, trans ke.report(GemmFastGeluMetric(impl, dtype, duration_ms, floating_point_operations, transa, transb, m, n, k)) -def profile_with_args(transa, transb, dtype, m, n, k, sort): +@ke.dispatchable +def profile_with_args(transa, transb, dtype, m, n, k): dtype_suffix = "_" + dtype_to_suffix(dtype) transab_suffix = "_" + transab_to_suffix((transa, transb)) - with ke.benchmark(sort): + with ke.benchmark(): profile_gemmfastgelu_func(getattr(ke, "GemmFastGeluUnfused" + dtype_suffix), dtype, m, n, k, transa, transb) profile_gemmfastgelu_func( getattr(ke, "CKGemmFastGelu" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb @@ -173,24 +174,22 @@ def profile_with_args(transa, transb, dtype, m, n, k, sort): def profile(): for dtype in dtypes: for m, n, k in get_gemm_bert_sizes(full=True): - profile_with_args(False, False, dtype, m, n, k, True) + profile_with_args(False, False, dtype, m, n, k) print() if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("transa", choices="NT") group.add_argument("transb", choices="NT") group.add_argument("dtype", choices=dtypes) group.add_argument("m", type=int) group.add_argument("n", type=int) group.add_argument("k", type=int) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k, args.sort) + args.dispatch(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py index 19a1008b3947a..76d0b2a3138bc 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_float8_test.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import sys from dataclasses import dataclass import kernel_explorer as ke @@ -43,6 +42,7 @@ def cast_and_scale(a, dtype: str): raise ValueError(dtype) +@ke.dispatchable(pattern_arg=0) def _test_gemm( func, dta: str, dtb: str, dtc: str, transa: bool, transb: bool, m: int, n: int, k: int, alpha=1.0, beta=0.0 ): @@ -154,6 +154,7 @@ def _test_gemm( ) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dta, dtb, dtc", dtypes) +@ke.dispatchable def test_ck_gemm(dta, dtb, dtc, transa, transb, m, n, k): if dtb == "float16" and transb: pytest.skip("Only supports transb when b is fp8") @@ -206,6 +207,7 @@ def report(self): return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops {self.gbps:5.2f} GB/s " + common +@ke.dispatchable(pattern_arg=0) def profile_gemm_func( func, dta: str, dtb: str, dtc: str, transa: bool, transb: bool, m: int, n: int, k: int, alpha=1.0, beta=0.0 ): @@ -264,10 +266,11 @@ def profile_gemm_func( ke.report(GemmMetric(impl, f"{dta}_{dtb}_{dtc}", duration_ms, FLOPs, total_bytes, transa, transb, m, n, k)) -def profile_with_args(dta, dtb, dtc, transa, transb, m, n, k, sort): +@ke.dispatchable +def profile_with_args(dta, dtb, dtc, transa, transb, m, n, k): dtype_suffix = "_" + dtype_to_suffix(dta) + "_" + dtype_to_suffix(dtb) + "_" + dtype_to_suffix(dtc) transab_suffix = "_" + transab_to_suffix((transa, transb)) - with ke.benchmark(sort): + with ke.benchmark(): profile_gemm_func( getattr(ke, "GemmFloat8CK" + dtype_suffix + transab_suffix), dta, dtb, dtc, transa, transb, m, n, k ) @@ -280,14 +283,12 @@ def profile_with_args(dta, dtb, dtc, transa, transb, m, n, k, sort): def profile(): for dta, dtb, dtc in dtypes: for m, n, k in get_gemm_bert_sizes(full=True): - profile_with_args(dta, dtb, dtc, False, False, m, n, k, True) + profile_with_args(dta, dtb, dtc, False, False, m, n, k) if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("dta", choices=["float8_e4m3fn", "float8_e4m3fnuz", "float16"]) group.add_argument("dtb", choices=["float8_e4m3fn", "float8_e4m3fnuz", "float16"]) group.add_argument("dtc", choices=["float8_e4m3fn", "float8_e4m3fnuz", "float16"]) @@ -296,12 +297,9 @@ def profile(): group.add_argument("m", type=int) group.add_argument("n", type=int) group.add_argument("k", type=int) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args( - args.dta, args.dtb, args.dtc, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.sort - ) + args.dispatch(args.dta, args.dtb, args.dtc, args.transa == "T", args.transb == "T", args.m, args.n, args.k) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py index 802d924c27b62..8a6713f6e03a1 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py @@ -5,7 +5,6 @@ import os -import sys from dataclasses import dataclass from itertools import product @@ -432,6 +431,7 @@ def report(self): return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common +@ke.dispatchable(pattern_arg=0) def profile_gemm_softmax_gemm_permute_func( f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format ): @@ -519,10 +519,21 @@ def profile_gemm_softmax_gemm_permute_func( ) +@ke.dispatchable def profile_with_args( - dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, causal, mask_dim, scale, qkv_format, *, sort=False + dtype, + batch, + seqlen, + total_seqlen, + num_heads, + head_size, + biased, + causal, + mask_dim, + scale, + qkv_format, ): - with ke.benchmark(sort): + with ke.benchmark(): args = (dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format) if qkv_format == ke.qkv_format.Q_K_V_BNSH: profile_gemm_softmax_gemm_permute_func( @@ -551,21 +562,17 @@ def profile(): mask_dim=0, qkv_format=getattr(ke.qkv_format, qkv_format_name), scale=0.125, - sort=True, ) print() for args in product(dtypes, batches, seqlens, total_seqlens, num_heads, head_sizes, biaseds, causals, mask_dims): - profile_with_args(*args, qkv_format=ke.qkv_format.Q_K_V_BNSH, scale=0.125, sort=True) + profile_with_args(*args, qkv_format=ke.qkv_format.Q_K_V_BNSH, scale=0.125) print() if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") - group.add_argument("--sort", action="store_true") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("dtype", choices=dtypes) group.add_argument("batch", type=int) group.add_argument("seqlen", type=int) @@ -587,12 +594,11 @@ def profile(): ], ) - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - print(args) - profile_with_args( + args.dispatch( args.dtype, args.batch, args.seqlen, @@ -604,5 +610,4 @@ def profile(): args.mask_dim, args.scale, getattr(ke.qkv_format, args.qkv_format), - sort=args.sort, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py index 8182cdb17567c..23ffa5735d2c1 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import sys from dataclasses import dataclass from itertools import product @@ -13,8 +12,9 @@ from utils import dtype_to_suffix, get_gemm_basic_sizes, get_gemm_bert_sizes, get_gemm_bound, matmul, transab_to_suffix +@ke.dispatchable def _test_gemm(func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, alpha=1.0, beta=0.0): - assert dtype in ["float32", "float16"] + assert dtype in ["float32", "float16", "float8_e4m3"] a_shape = (k, m) if transa else (m, k) b_shape = (n, k) if transb else (k, n) @@ -76,6 +76,7 @@ def _test_gemm(func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: @pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=True) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k): _test_gemm(getattr(ke, "RocBlasGemm_" + dtype_to_suffix(dtype)), dtype, transa, transb, m, n, k) @@ -84,6 +85,7 @@ def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k): @pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_ck_gemm_bert_cases(dtype, transa, transb, m, n, k): wrapper_name = f"CKGemm_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" _test_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k) @@ -93,6 +95,7 @@ def test_ck_gemm_bert_cases(dtype, transa, transb, m, n, k): @pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_gemm_tunable_bert_cases(dtype, transa, transb, m, n, k): wrapper_name = f"GemmTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" _test_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k) @@ -142,6 +145,7 @@ def report(self): return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common +@ke.dispatchable(pattern_arg=0) def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int): a_shape = (k, m) if transa else (m, k) b_shape = (n, k) if transb else (k, n) @@ -172,14 +176,17 @@ def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, ke.report(GemmMetric(impl, dtype, duration_ms, FLOPs, transa, transb, m, n, k)) -def profile_with_args(dtype, transa, transb, m, n, k, sort): +@ke.dispatchable +def profile_with_args(dtype, transa, transb, m, n, k): dtype_suffix = "_" + dtype_to_suffix(dtype) transab_suffix = "_" + transab_to_suffix((transa, transb)) - with ke.benchmark(sort): - profile_gemm_func(getattr(ke, "RocBlasGemm" + dtype_suffix), dtype, transa, transb, m, n, k) - profile_gemm_func(getattr(ke, "CKGemm" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) + with ke.benchmark(): + if ke.is_rocm_available(): + profile_gemm_func(getattr(ke, "RocBlasGemm" + dtype_suffix), dtype, transa, transb, m, n, k) + profile_gemm_func(getattr(ke, "CKGemm" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) profile_gemm_func(getattr(ke, "GemmTunable" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) - profile_gemm_func(getattr(ke, "GemmBenchmark" + dtype_suffix), dtype, transa, transb, m, n, k) + if ke.is_cuda_available(): + profile_gemm_func(getattr(ke, "GemmBenchmark" + dtype_suffix), dtype, transa, transb, m, n, k) if ke.is_hipblaslt_available(): profile_gemm_func( getattr(ke, "GemmHipBlasLt" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k @@ -190,24 +197,21 @@ def profile_with_args(dtype, transa, transb, m, n, k, sort): def profile(): for dtype in dtypes: for m, n, k in get_gemm_bert_sizes(full=True): - profile_with_args(dtype, False, False, m, n, k, True) + profile_with_args(dtype, False, False, m, n, k) if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("dtype", choices=dtypes) group.add_argument("transa", choices="NT") group.add_argument("transb", choices="NT") group.add_argument("m", type=int) group.add_argument("n", type=int) group.add_argument("k", type=int) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args(args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.sort) + args.dispatch(args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index 400a9d8a7a187..a45b9e80500cc 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import re -import sys from dataclasses import dataclass from itertools import product @@ -181,6 +180,7 @@ def report(self): return "not supported " + common +@ke.dispatchable(pattern_arg=8) def profile_group_norm_func( batch_size: int, height: int, @@ -257,8 +257,9 @@ def profile_group_norm_func( ) -def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True, sort=True): - with ke.benchmark(sort): +@ke.dispatchable +def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True): + with ke.benchmark(): for func in dtype_to_funcs(dtype): profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func) # ck function @@ -293,10 +294,8 @@ def profile(): if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("batch_size", type=int) group.add_argument("height", type=int) group.add_argument("width", type=int) @@ -305,13 +304,12 @@ def profile(): group.add_argument("dtype", choices=dtypes) group.add_argument("--silu", action="store_true") group.add_argument("--has_skip", action="store_true") - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args( + args.dispatch( args.batch_size, args.height, args.width, @@ -320,5 +318,4 @@ def profile(): args.dtype, args.silu, args.has_skip, - args.sort, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py b/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py index 289dad22379bc..66e1a8052ce84 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/kernel_explorer.py @@ -5,12 +5,19 @@ """This file provides wrapper for native _kernel_explorer.so library and benchmark reporter for operator""" +from __future__ import annotations + import ctypes +import json import os import sys from abc import abstractmethod +from argparse import Action, ArgumentParser from contextlib import contextmanager from dataclasses import dataclass +from fnmatch import fnmatch +from functools import wraps +from typing import Callable build_dir = os.environ.get("KERNEL_EXPLORER_BUILD_DIR", None) if build_dir is None: @@ -38,10 +45,14 @@ "onnxruntime_pybind11_state.so", "libonnxruntime_providers_shared.so", ] +_is_cuda_available = False +_is_rocm_available = False if "CUDAExecutionProvider" in available_providers: library_files_to_load.append("libonnxruntime_providers_cuda.so") + _is_cuda_available = True if "ROCMExecutionProvider" in available_providers: library_files_to_load.append("libonnxruntime_providers_rocm.so") + _is_rocm_available = True library_to_load = [] @@ -56,15 +67,37 @@ # use RTLD_GLOBAL to bring all symbols to global name space -libraries = [ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) for lib_path in library_to_load] +_libraries = [ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) for lib_path in library_to_load] +del library_files_to_load, library_to_load # pylint: disable=wrong-import-position, disable=unused-import -import _kernel_explorer # noqa: E402, F401 +import _kernel_explorer # noqa: E402 # pylint: disable=wrong-import-position, disable=unused-import, disable=wildcard-import from _kernel_explorer import * # noqa: F403, E402 +@dataclass +class _KeContext: + sort: bool = False + + pattern = "*" + + # mapping the module to dispatch to + dispatchable: dict | None = None + instance_dispatchable: dict | None = None # can be filtered with pattern + + dispatch_depth = 0 + + save_tuning_results: str | None = None + return_tuning_results: bool = False + + +_ke_context = _KeContext() +_ke_context.dispatchable = {} +_ke_context.instance_dispatchable = {} + + # Benchmark Reporter @dataclass class MetricBase: @@ -114,30 +147,34 @@ class ComputeAndBandwidthMetric(ComputeMetric, BandwidthMetric): class InstanceBenchmarkReporter: def __init__(self): - self.sort = False + self.best = float("inf") self.reporters = [] - def set_sort(self, sort): - self.sort = sort - def make_report(self): self.reporters.sort() for item in self.reporters: - print(item.report()) + if not _ke_context.sort and item.milliseconds_duration > 0 and item.milliseconds_duration < self.best: + self.best = item.milliseconds_duration + print(item.report(), "*") + else: + print(item.report()) self.reporters.clear() def receive(self, status): self.reporters.append(status) - if not self.sort: + if not _ke_context.sort: self.make_report() + def _reset_best(self): + self.best = float("inf") + _reporter = InstanceBenchmarkReporter() @contextmanager -def benchmark(sort): - _reporter.set_sort(sort) +def benchmark(): + _reporter._reset_best() try: yield finally: @@ -146,3 +183,182 @@ def benchmark(sort): def report(status): _reporter.receive(status) + + +def set_ort_severity(v): + v = int(v) + onnxruntime_pybind11_state.set_default_logger_severity(v) + return v + + +def set_ort_verbosity(v): + v = int(v) + onnxruntime_pybind11_state.set_default_logger_verbosity(v) + return v + + +def register_common_arguments(parser: ArgumentParser): + class SortAction(Action): + def __init__(self, option_strings, dest, default=False, help=None): + super().__init__(option_strings=option_strings, dest=dest, nargs=0, default=default, help=help) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, True) + _ke_context.sort = True + + def set_dispatch(name): + if name in _ke_context.dispatchable: + dispatch = _ke_context.dispatchable[name] + _ke_context.dispatch = dispatch + return dispatch + + if name in _ke_context.instance_dispatchable: + msg = f"'{name}' needs an instance to dispatch, thus it is not dispatchable from commandline." + print(msg) + raise ValueError(msg) + + from difflib import SequenceMatcher as Matcher + + valid_names = list(_ke_context.dispatchable.keys()) + scored_names = list(reversed(sorted([(Matcher(None, name, a).ratio(), a) for a in valid_names]))) + top10 = "\n ".join([a for _, a in scored_names[:10]]) + msg = f"'{name}' is not registered for dispatch. Top 10 matches are:\n {top10}" + print(msg) + raise ValueError(msg) + + def set_pattern(pattern): + pattern = str(pattern) + _ke_context.pattern = pattern + + def set_save_tuning_results(path): + _ke_context.save_tuning_results = path + return path + + group = parser.add_argument_group("kernel explorer args", "Common arguments for kernel explorer") + group.add_argument( + "--sort", + action=SortAction, + help="control the sort of ke benchmark results based on timing", + ) + group.add_argument( + "--ort_default_logger_severity", + default=2, + choices=[0, 1, 2, 3, 4], + type=set_ort_severity, + help="0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal", + ) + group.add_argument("--ort_default_logger_verbosity", default=0, type=set_ort_verbosity) + group.add_argument( + "--dispatch", + default="profile_with_args", + help="dispatch a registered dispatchable.", + type=set_dispatch, + ) + group.add_argument( + "--pattern", + default="*", + help="filter the register instanced dispatchables, only matched pattern will be run.", + type=set_pattern, + ) + group.add_argument( + "--save_tuning_results", + default=None, + type=set_save_tuning_results, + help="patch the dispatch function to save tuning results to the specified path.", + ) + + return parser + + +def get_argument_parser(): + parser = ArgumentParser() + return register_common_arguments(parser) + + +def has_args(): + if "--help" in sys.argv or "-h" in sys.argv or "--func" in sys.argv: + return True + + # parse the KE args group + parser = get_argument_parser() + _, remainder = parser.parse_known_args(sys.argv) + return len(remainder) > 1 # the file path is always the remainder + + +def is_cuda_available(): + return _is_cuda_available + + +def is_rocm_available(): + return _is_rocm_available + + +def dispatchable(f: Callable | None = None, *, pattern_arg: int | None = None): + def wrap_dispatch(f, *args, **kwargs): + if _ke_context.dispatch_depth == 0: + if _ke_context.save_tuning_results is not None: + _kernel_explorer.enable_collect_tuning_results() + _ke_context.dispatch_depth += 1 + ret = f(*args, **kwargs) + _ke_context.dispatch_depth -= 1 + if _ke_context.dispatch_depth == 0: + if _ke_context.save_tuning_results is not None: + try: + trs = _kernel_explorer.get_collected_tuning_results() + with open(_ke_context.save_tuning_results, "x") as f: + json.dump(trs, f) + finally: + pass + + if _ke_context.return_tuning_results: + if ret is not None: + print( + f"WARNING: kernel explorer wants to override the return value of {f.__name__},", + "but original return value is not None!", + ) + return ret + try: + trs = _kernel_explorer.get_collected_tuning_results() + return trs + finally: + pass + + return ret + + if f is None: # Used with ke.dispatchable(...) + assert pattern_arg is not None + + def decorator(f): + _ke_context.instance_dispatchable[f.__name__] = f + + @wraps(f) + def wrapper(*args, **kwargs): + func_name = args[pattern_arg] if isinstance(args[pattern_arg], str) else args[pattern_arg].__name__ + if not fnmatch(func_name, _ke_context.pattern): + print( + f"Trying to run {func_name},", + f"does not match allowed function name pattern '{_ke_context.pattern}', skip...", + ) + return + return wrap_dispatch(f, *args, **kwargs) + + return wrapper + + return decorator + + else: # Used with @ke.dispatchable + _ke_context.dispatchable[f.__name__] = f + + @wraps(f) + def wrapper(*args, **kwargs): + return wrap_dispatch(f, *args, **kwargs) + + return wrapper + + +def set_dispatchable_pattern(p: str = "*"): + _ke_context.pattern = p + + +def set_return_tuning_results(b: bool = True): + _ke_context.return_tuning_results = b diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py index 9b8a261f728d2..df35fa4e6c411 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import sys from dataclasses import dataclass import kernel_explorer as ke @@ -50,6 +49,7 @@ def report(self): return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} is_symmetric={self.is_symmetric} {self.name}" +@ke.dispatchable(pattern_arg=4) def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric): np.random.seed(0) output = np.random.rand(m, n).astype(dtype) @@ -76,6 +76,7 @@ def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric): ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k, is_symmetric)) +@ke.dispatchable(pattern_arg=4) def profile_gemm_func(m, n, k, dtype, func): np.random.seed(0) output = np.random.rand(m, n).astype(dtype) @@ -93,8 +94,9 @@ def profile_gemm_func(m, n, k, dtype, func): ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k)) -def profile_with_args(m, n, k, dtype, sort): - with ke.benchmark(sort): +@ke.dispatchable +def profile_with_args(m, n, k, dtype): + with ke.benchmark(): for func in dtype_to_funcs(dtype): profile_matmul_fp_int4_func(m, n, k, dtype, func, True) @@ -117,23 +119,20 @@ def profile(): (11008, 4096), (2 * 11008, 4096), ): - profile_with_args(m, n, k, dt, False) + profile_with_args(m, n, k, dt) print() if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("m", type=int) group.add_argument("n", type=int) group.add_argument("k", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args(args.m, args.n, args.k, args.dtype, args.sort) + args.dispatch(args.m, args.n, args.k, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py index a31e8b851fa36..bfe13fac2a148 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import re -import sys from dataclasses import dataclass from itertools import product @@ -51,6 +50,7 @@ def simplified_skip_layer_norm(input_x, skip, bias, gamma, epsilon): return output, val +@ke.dispatchable(pattern_arg=4) def run_skip_layer_norm( batch_size: int, seq_len: int, hidden_size: int, dtype: str, func, simplified=False, has_optional_output=False ): @@ -130,6 +130,7 @@ def report(self): return "not supported " + common +@ke.dispatchable(pattern_arg=4) def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output): np.random.seed(0) input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) @@ -175,8 +176,9 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, ke.report(SkipLayerNormMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size)) -def profile_with_args(batch_size, seq_len, hidden_size, dtype, sort=True, has_optional_output=False, simplified=False): - with ke.benchmark(sort): +@ke.dispatchable +def profile_with_args(batch_size, seq_len, hidden_size, dtype, has_optional_output=False, simplified=False): + with ke.benchmark(): for func in dtype_to_funcs(dtype, simplified): profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func, has_optional_output) @@ -189,28 +191,24 @@ def profile(): if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("batch_size", type=int) group.add_argument("seq_len", type=int) group.add_argument("hidden_size", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--sort", action="store_true") group.add_argument("--has_optional_output", "-o", action="store_true") group.add_argument("--simplified", "-s", action="store_true", default=False) - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args( + args.dispatch( args.batch_size, args.seq_len, args.hidden_size, args.dtype, - args.sort, args.has_optional_output, args.simplified, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py index c8de619fe96d1..3a7e4442108f5 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/softmax_test.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import re -import sys from dataclasses import dataclass from itertools import product @@ -57,6 +56,7 @@ def _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, func): @pytest.mark.parametrize("batch_count, softmax_elements, is_log_softmax", get_test_sizes()) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_softmax(batch_count, softmax_elements, is_log_softmax, dtype): for f in dtype_to_funcs(dtype): _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, f) @@ -64,6 +64,7 @@ def test_softmax(batch_count, softmax_elements, is_log_softmax, dtype): @pytest.mark.parametrize("batch_count, softmax_elements, is_log_softmax", get_test_sizes()) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_ck_softmax(batch_count, softmax_elements, is_log_softmax, dtype): ck_f_name = "CKSoftmax_" + dtype_to_suffix(dtype) _test_softmax(batch_count, softmax_elements, is_log_softmax, dtype, ck_f_name) @@ -82,6 +83,7 @@ def report(self): return "not supported " + common +@ke.dispatchable(pattern_arg=4) def profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func): np.random.seed(0) x = np.random.rand(batch_count, softmax_elements).astype(dtype) @@ -104,8 +106,9 @@ def profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, f ke.report(SoftmaxMetric(impl, dtype, duration_ms, total_bytes, batch_count, softmax_elements, is_log_softmax)) -def profile_with_args(batch_count, softmax_elements, is_log_softmax, dtype, sort): - with ke.benchmark(sort): +@ke.dispatchable +def profile_with_args(batch_count, softmax_elements, is_log_softmax, dtype): + with ke.benchmark(): for func in dtype_to_funcs(dtype): profile_softmax_func(batch_count, softmax_elements, is_log_softmax, dtype, func) # ck function @@ -119,23 +122,20 @@ def profile_with_args(batch_count, softmax_elements, is_log_softmax, dtype, sort def profile(): for dtype in dtypes: for batch_count, softmax_elements in profile_size: - profile_with_args(batch_count, softmax_elements, False, dtype, True) + profile_with_args(batch_count, softmax_elements, False, dtype) print() if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("batch_count", type=int) group.add_argument("softmax_elements", type=int) group.add_argument("is_log_softmax", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args(args.batch_count, args.softmax_elements, args.is_log_softmax, args.dtype, args.sort) + args.dispatch(args.batch_count, args.softmax_elements, args.is_log_softmax, args.dtype) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py index b5504cbd4944d..b8c9c6f6a4ab6 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/strided_batched_gemm_test.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import os -import sys from dataclasses import dataclass from itertools import product @@ -23,6 +22,7 @@ def dtype_to_suffix(dtype): }[dtype] +@ke.dispatchable def _test_strided_batched_gemm( func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int, alpha=1.0, beta=0.0 ): @@ -101,6 +101,7 @@ def _test_strided_batched_gemm( @pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k, batch): wrapper_name = "RocBlasStridedBatchedGemm_" + dtype_to_suffix(dtype) _test_strided_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k, batch) @@ -111,6 +112,7 @@ def test_rocblas_gemm_all_cases(dtype, transa, transb, m, n, k, batch): @pytest.mark.parametrize("m, n, k", get_gemm_basic_sizes(full=False) + get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_ck_gemm_all_cases(dtype, transa, transb, m, n, k, batch): wrapper_name = f"CKStridedBatchedGemm_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" _test_strided_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k, batch) @@ -121,6 +123,7 @@ def test_ck_gemm_all_cases(dtype, transa, transb, m, n, k, batch): @pytest.mark.parametrize("m, n, k", get_gemm_bert_sizes(full=False)) @pytest.mark.parametrize("transa, transb", all_transabs) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_gemm_tunable_bert_cases(dtype, transa, transb, m, n, k, batch): wrapper_name = f"StridedBatchedGemmTunable_{dtype_to_suffix(dtype)}_{transab_to_suffix((transa, transb))}" _test_strided_batched_gemm(getattr(ke, wrapper_name), dtype, transa, transb, m, n, k, batch) @@ -177,6 +180,7 @@ def report(self): return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common +@ke.dispatchable(pattern_arg=0) def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int): a_shape = (k, m) if transa else (m, k) b_shape = (n, k) if transb else (k, n) @@ -209,7 +213,8 @@ def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, ke.report(StridedBatchedGemmMetric(impl, dtype, duration_ms, FLOPs, transa, transb, m, n, k, batch)) -def profile_with_args(dtype, transa, transb, m, n, k, batch, sort): +@ke.dispatchable +def profile_with_args(dtype, transa, transb, m, n, k, batch): dtype_suffix = "_" + dtype_to_suffix(dtype) transab_suffix = "_" + transab_to_suffix((transa, transb)) fn_rocblas = getattr(ke, "RocBlasStridedBatchedGemm" + dtype_suffix) @@ -217,7 +222,7 @@ def profile_with_args(dtype, transa, transb, m, n, k, batch, sort): fn_tunable = getattr(ke, "StridedBatchedGemmTunable" + dtype_suffix + transab_suffix) if ke.is_hipblaslt_available(): fn_hipblaslt = getattr(ke, "StridedBatchedGemmHipBlasLt" + dtype_suffix + transab_suffix) - with ke.benchmark(sort): + with ke.benchmark(): profile_gemm_func(fn_rocblas, dtype, transa, transb, m, n, k, batch) profile_gemm_func(fn_ck, dtype, transa, transb, m, n, k, batch) profile_gemm_func(fn_tunable, dtype, transa, transb, m, n, k, batch) @@ -230,14 +235,12 @@ def profile(): for dtype in dtypes: for m, n, k in get_gemm_bert_sizes(full=False): for batch in [1, 32, 64]: - profile_with_args(dtype, False, False, m, n, k, batch, True) + profile_with_args(dtype, False, False, m, n, k, batch) if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("dtype", choices=dtypes) group.add_argument("transa", choices="NT") group.add_argument("transb", choices="NT") @@ -245,12 +248,9 @@ def profile(): group.add_argument("n", type=int) group.add_argument("k", type=int) group.add_argument("batch", type=int) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args( - args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.batch, args.sort - ) + args.dispatch(args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.batch) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py index dcb51f1db145c..8edf55f68c11f 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/vector_add_test.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import sys from dataclasses import dataclass import kernel_explorer as ke @@ -43,6 +42,7 @@ def run_vector_add(size, dtype, func): @pytest.mark.parametrize("size", [1, 3, 4, 16, 124, 125, 126, 127, 128, 129, 130, 131, 132, 1024]) @pytest.mark.parametrize("dtype", dtypes) +@ke.dispatchable def test_vector_add(size, dtype): for dtype in dtypes: for f in dtype_to_funcs(dtype): @@ -57,6 +57,7 @@ def report(self): return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} size={self.size:<4} {self.name}" +@ke.dispatchable(pattern_arg=2) def profile_vector_add_func(size, dtype, func): np.random.seed(0) x = np.random.rand(size).astype(dtype) @@ -74,8 +75,9 @@ def profile_vector_add_func(size, dtype, func): ke.report(VectorAddMetric(func, dtype, duration_ms, total_bytes, size)) -def profile_with_args(size, dtype, sort): - with ke.benchmark(sort): +@ke.dispatchable +def profile_with_args(size, dtype): + with ke.benchmark(): for func in dtype_to_funcs(dtype): profile_vector_add_func(size, dtype, func) @@ -84,21 +86,18 @@ def profile(): sizes = [10000, 100000, 1000000, 10000000] for dt in dtypes: for s in sizes: - profile_with_args(s, dt, True) + profile_with_args(s, dt) print() if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - group = parser.add_argument_group("profile with args") + parser = ke.get_argument_parser() + group = parser.add_argument_group() group.add_argument("size", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--sort", action="store_true") - if len(sys.argv) == 1: + if not ke.has_args(): profile() else: args = parser.parse_args() - profile_with_args(args.size, args.dtype, args.sort) + args.dispatch(args.size, args.dtype)