Skip to content

Commit

Permalink
Added support for named attrs in DecomposingContext. (#213)
Browse files Browse the repository at this point in the history
Added support for named attrs in DecomposingContext, since they are needed for lowering to MLIR.
  • Loading branch information
vladimirjovanovicTT authored Sep 4, 2024
1 parent ce8d1d5 commit fc211a7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
42 changes: 42 additions & 0 deletions forge/csrc/passes/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "graph_lib/defines.hpp"
#include "graph_lib/graph.hpp"
#include "graph_lib/node_types.hpp"
#include "lower_to_buda/common.hpp"
#include "shared_utils/sparse_matmul_utils.hpp"
#include "python_bindings_common.hpp"

Expand Down Expand Up @@ -208,6 +209,47 @@ void PassesModule(py::module &m_passes)
py::arg("dont_decompose") = false,
py::arg("optimize_hoist") = false,
py::arg("output_df") = DataFormat::Invalid)
.def(
"op_with_named_attrs",
[](tt::DecomposingContext &self,
std::variant<std::string, py::object> const &type,
std::vector<NodeContext> const &operands,
BudaOpAttrs const &named_attrs,
std::vector<graphlib::OpType::Attr> const &attrs = {},
bool copy_tms = true,
bool dont_decompose = false,
bool optimize_hoist = false,
DataFormat output_df = DataFormat::Invalid)
{
if (std::holds_alternative<std::string>(type))
{
TT_LOG_ASSERT(
not has_newstyle_interface(std::get<std::string>(type), false),
"Error decomposing a type with old OpType interface, expects new OpType interface {}",
std::get<std::string>(type));
return self.op(
graphlib::OpType(std::get<std::string>(type), attrs, {}, named_attrs),
operands,
copy_tms,
dont_decompose,
optimize_hoist,
output_df);
}
else
{
TT_ASSERT(attrs.size() == 0, "Illegal mixing of API modes");
auto const &op_type = std::get<py::object>(type).attr("op_type").cast<graphlib::OpType>();
return self.op(op_type, operands, copy_tms, dont_decompose, optimize_hoist, output_df);
}
},
py::arg("type"),
py::arg("operands"),
py::arg("named_attrs"),
py::arg("attrs") = std::vector<int>{},
py::arg("copy_tms") = true,
py::arg("dont_decompose") = false,
py::arg("optimize_hoist") = false,
py::arg("output_df") = DataFormat::Invalid)
.def("fuse", &tt::DecomposingContext::fuse, py::arg("operand"), py::arg("producer_output_port_id") = 0)
.def(
"tensor",
Expand Down
2 changes: 1 addition & 1 deletion forge/forge/op/eval/forge/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def decompose_post_autograd(op_type, attr, dc, inputs):
assert len(out_shape) > dim and dim >= -len(out_shape), "Given dimension is out of the shape"

grad_out = dc.op("multiply", (grad, output), ())
gout_sum = dc.op("reduce_sum", (grad_out, ), (dim, ))
gout_sum = dc.op_with_named_attrs("reduce_sum", (grad_out, ), {"keep_dim": True}, (dim, ))
gout_sub = dc.op("subtract", (grad, gout_sum), ())
result = dc.op("multiply", (gout_sub, output), ())
dc.fuse(result)
Expand Down

0 comments on commit fc211a7

Please sign in to comment.