diff --git a/torchbenchmark/operators/gemm/operator.py b/torchbenchmark/operators/gemm/operator.py
index 199608d64..23b5c65c3 100644
--- a/torchbenchmark/operators/gemm/operator.py
+++ b/torchbenchmark/operators/gemm/operator.py
@@ -184,6 +184,24 @@ def triton_ops_matmul(self, a, b, bias) -> Callable:
             return lambda: kernels.matmul(a, b)
         return lambda: kernels.matmul(a, b) + bias
 
+    @register_benchmark(enabled=False, ci=False)
+    def triton_ops_with_tf32x3(self, a, b, bias) -> Callable:
+        input_precision = None
+        if self.dtype == torch.float32:
+            input_precision = "tf32x3"
+        if bias is None:
+            return lambda: kernels.matmul(a, b, None, input_precision)
+        return lambda: kernels.matmul(a, b, None, input_precision) + bias
+
+    @register_benchmark(enabled=False, ci=False)
+    def triton_ops_with_fp32_strict(self, a, b, bias) -> Callable:
+        input_precision = None
+        if self.dtype == torch.float32:
+            input_precision = "ieee"
+        if bias is None:
+            return lambda: kernels.matmul(a, b, self.dtype, input_precision)
+        return lambda: kernels.matmul(a, b, self.dtype, input_precision) + bias
+
     @register_benchmark(baseline=True)
     def aten_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
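
For reference, the `input_precision` string that these wrappers forward to `kernels.matmul` is the same precision mode that Triton matmul kernels pass to `tl.dot`. Below is a minimal sketch, not from this PR, of what the three modes control at the kernel level. The names `single_tile_matmul` and `_single_tile_matmul_kernel` are hypothetical, and it assumes a Triton version (3.x) where `tl.dot` accepts an `input_precision` argument.

import torch
import triton
import triton.language as tl


@triton.jit
def _single_tile_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    INPUT_PRECISION: tl.constexpr,
):
    # Toy kernel: one program computes the whole (M, N) output, so
    # M/N/K must be powers of two >= 16 that fit in a single tile.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # "tf32": fast tensor-core path (default on Ampere and newer);
    # "tf32x3": emulates fp32 accuracy with three tf32 products;
    # "ieee": strict fp32 multiply-accumulate.
    c = tl.dot(a, b, input_precision=INPUT_PRECISION)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


def single_tile_matmul(a, b, input_precision="ieee"):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    _single_tile_matmul_kernel[(1,)](a, b, c, M, N, K, INPUT_PRECISION=input_precision)
    return c


# Usage, e.g.:
#   single_tile_matmul(torch.randn(64, 64, device="cuda"),
#                      torch.randn(64, 64, device="cuda"),
#                      input_precision="tf32x3")

This is also why both new benchmarks only set `input_precision` when `self.dtype == torch.float32`: the tf32/tf32x3/ieee distinction only applies to fp32 inputs, and the kernel's default behavior is left untouched for other dtypes.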