From fc211a70a8c0f929a491123f3e3031b2e58fe8d4 Mon Sep 17 00:00:00 2001
From: Vladimir Jovanovic
Date: Wed, 4 Sep 2024 12:27:01 +0200
Subject: [PATCH] Added support for named attrs in DecomposingContext. (#213)

Added support for named attrs in DecomposingContext, since they are needed
for lowering to MLIR.
---
 forge/csrc/passes/python_bindings.cpp | 42 +++++++++++++++++++++++++++
 forge/forge/op/eval/forge/nn.py       |  2 +-
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/forge/csrc/passes/python_bindings.cpp b/forge/csrc/passes/python_bindings.cpp
index 197e21cb..102d2af3 100644
--- a/forge/csrc/passes/python_bindings.cpp
+++ b/forge/csrc/passes/python_bindings.cpp
@@ -4,6 +4,7 @@
 #include "graph_lib/defines.hpp"
 #include "graph_lib/graph.hpp"
 #include "graph_lib/node_types.hpp"
+#include "lower_to_buda/common.hpp"
 #include "shared_utils/sparse_matmul_utils.hpp"
 
 #include "python_bindings_common.hpp"
@@ -208,6 +209,47 @@ void PassesModule(py::module &m_passes)
             py::arg("dont_decompose") = false,
             py::arg("optimize_hoist") = false,
             py::arg("output_df") = DataFormat::Invalid)
+        .def(
+            "op_with_named_attrs",
+            [](tt::DecomposingContext &self,
+               std::variant<std::string, py::object> const &type,
+               std::vector<tt::graphlib::NodeContext> const &operands,
+               BudaOpAttrs const &named_attrs,
+               std::vector<graphlib::OpType::Attr> const &attrs = {},
+               bool copy_tms = true,
+               bool dont_decompose = false,
+               bool optimize_hoist = false,
+               DataFormat output_df = DataFormat::Invalid)
+            {
+                if (std::holds_alternative<std::string>(type))
+                {
+                    TT_LOG_ASSERT(
+                        not has_newstyle_interface(std::get<std::string>(type), false),
+                        "Error decomposing a type with old OpType interface, expects new OpType interface {}",
+                        std::get<std::string>(type));
+                    return self.op(
+                        graphlib::OpType(std::get<std::string>(type), attrs, {}, named_attrs),
+                        operands,
+                        copy_tms,
+                        dont_decompose,
+                        optimize_hoist,
+                        output_df);
+                }
+                else
+                {
+                    TT_ASSERT(attrs.size() == 0, "Illegal mixing of API modes");
+                    auto const &op_type = std::get<py::object>(type).attr("op_type").cast<graphlib::OpType>();
+                    return self.op(op_type, operands, copy_tms, dont_decompose, optimize_hoist, output_df);
+                }
+            },
+            py::arg("type"),
+            py::arg("operands"),
+            py::arg("named_attrs"),
+            py::arg("attrs") = std::vector<graphlib::OpType::Attr>{},
+            py::arg("copy_tms") = true,
+            py::arg("dont_decompose") = false,
+            py::arg("optimize_hoist") = false,
+            py::arg("output_df") = DataFormat::Invalid)
         .def("fuse", &tt::DecomposingContext::fuse, py::arg("operand"), py::arg("producer_output_port_id") = 0)
         .def(
             "tensor",
diff --git a/forge/forge/op/eval/forge/nn.py b/forge/forge/op/eval/forge/nn.py
index e7cda4d4..ef223634 100644
--- a/forge/forge/op/eval/forge/nn.py
+++ b/forge/forge/op/eval/forge/nn.py
@@ -476,7 +476,7 @@ def decompose_post_autograd(op_type, attr, dc, inputs):
         assert len(out_shape) > dim and dim >= -len(out_shape), "Given dimension is out of the shape"
 
         grad_out = dc.op("multiply", (grad, output), ())
-        gout_sum = dc.op("reduce_sum", (grad_out, ), (dim, ))
+        gout_sum = dc.op_with_named_attrs("reduce_sum", (grad_out, ), {"keep_dim": True}, (dim, ))
         gout_sub = dc.op("subtract", (grad, gout_sum), ())
         result = dc.op("multiply", (gout_sub, output), ())
         dc.fuse(result)
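
Usage note (illustrative sketch, not part of the patch itself): from Python decompose code, the new binding is called like the existing DecomposingContext.op, with a dict of named attributes inserted before the legacy positional attrs tuple. The snippet below mirrors the nn.py change above; `grad_out` and `dim` are hypothetical stand-ins for the operand NodeContext and reduce dimension produced earlier in the handler.

    # Named attrs ({"keep_dim": True}) are forwarded into graphlib::OpType so they
    # survive lowering to MLIR; (dim,) remains the old positional attr tuple.
    gout_sum = dc.op_with_named_attrs(
        "reduce_sum",           # op type name
        (grad_out,),            # operands
        {"keep_dim": True},     # named attrs
        (dim,),                 # legacy positional attrs
    )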