diff --git a/python/cutlass/backend/evt/frontend/python_ast.py b/python/cutlass/backend/evt/frontend/python_ast.py index faffce65df..3f33485456 100644 --- a/python/cutlass/backend/evt/frontend/python_ast.py +++ b/python/cutlass/backend/evt/frontend/python_ast.py @@ -70,6 +70,8 @@ def ast_op_to_bindings(op): ast.Sub: FunctionalOp.Minus, ast.Mult: FunctionalOp.Multiplies, ast.Div: FunctionalOp.Divides, + "maximum": FunctionalOp.Maximum, + "minimum": FunctionalOp.Minimum, "relu": relu.binding_type, "multiply_add": FunctionalOp.MultiplyAdd, "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd), diff --git a/python/cutlass/epilogue/__init__.py b/python/cutlass/epilogue/__init__.py index 2b22b5f582..423deccebc 100644 --- a/python/cutlass/epilogue/__init__.py +++ b/python/cutlass/epilogue/__init__.py @@ -49,5 +49,7 @@ multiply_add, sum, permute, - reshape + reshape, + maximum, + minimum, ) diff --git a/python/cutlass/epilogue/evt_ops.py b/python/cutlass/epilogue/evt_ops.py index a9b9b5bf4b..575767d03f 100644 --- a/python/cutlass/epilogue/evt_ops.py +++ b/python/cutlass/epilogue/evt_ops.py @@ -59,6 +59,17 @@ def max(x, dim): elif is_torch_tensor(x): return torch.amax(x, dim) +def maximum(x, y): + if is_numpy_tensor(x): + return np.maximum(x, y) + elif is_torch_tensor(x): + return torch.maximum(x, torch.tensor(y)) + +def minimum(x, y): + if is_numpy_tensor(x): + return np.minimum(x, y) + elif is_torch_tensor(x): + return torch.minimum(x, torch.tensor(y)) ############################################################################## # Layout manipulate nodes diff --git a/test/python/cutlass/evt/evt_compute_sm80_90.py b/test/python/cutlass/evt/evt_compute_sm80_90.py index 36cee7878c..3f9996cfcf 100644 --- a/test/python/cutlass/evt/evt_compute_sm80_90.py +++ b/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -95,6 +95,29 @@ def evt_func_call(accum, C, alpha, beta, gamma): result_keys = ["D"] launcher.verify((m, n, k), input_keys, result_keys, l) + def test_func_call2(self): + """ + Test Function call + """ + + def evt_func_call2(accum, C, alpha, beta): + D = maximum(alpha * accum + beta * C, 0.0) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.5, + "beta": 0.5, + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_func_call2, example_inputs) + input_keys = ["C", "alpha", "beta"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + if __name__ == '__main__': unittest.main()