diff --git a/forge/csrc/forge_bindings.cpp b/forge/csrc/forge_bindings.cpp
index 84046a5f5..e3e8f3703 100644
--- a/forge/csrc/forge_bindings.cpp
+++ b/forge/csrc/forge_bindings.cpp
@@ -17,6 +17,7 @@ namespace py = pybind11;
 #include "autograd/python_bindings.hpp"
 #include "backend_api/device_config.hpp"
 #include "forge_passes.hpp"
+#include "forge_graph_module.hpp"
 #include "graph_lib/graph.hpp"
 #include "graph_lib/python_bindings.hpp"
 #include "lower_to_forge/common.hpp"
@@ -111,6 +112,16 @@ PYBIND11_MODULE(_C, m) {
     py::module_ m_graph = m.def_submodule("graph", "Submodule defining forge graph functions");
     GraphModule(m_graph);
 
+    py::enum_<tt::GraphType>(m, "GraphType")
+        .value("Forward", tt::GraphType::Forward)
+        .value("Backward", tt::GraphType::Backward)
+        .value("Optimizer", tt::GraphType::Optimizer)
+        .export_values();
+
+    py::class_<tt::ForgeGraphModule>(m, "ForgeGraphModule")
+        .def(py::init<std::string, tt::graphlib::Graph*>(), py::arg("name"), py::arg("forward_graph"))
+        .def("set_graph", &tt::ForgeGraphModule::set_graph);
+
     py::module_ m_autograd = m.def_submodule("autograd", "Submodule defining autograd_engine.");
     AutogradModule(m_autograd);
 
diff --git a/forge/csrc/forge_graph_module.hpp b/forge/csrc/forge_graph_module.hpp
new file mode 100644
index 000000000..45ccfda5b
--- /dev/null
+++ b/forge/csrc/forge_graph_module.hpp
@@ -0,0 +1,110 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <array>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "utils/assert.hpp"
+
+namespace tt
+{
+
+namespace graphlib
+{
+    class Graph;
+}
+
+/// Kinds of graph a module can hold; each value is also the graph's index in StaticGraphArray.
+enum class GraphType : std::uint8_t
+{
+    Forward = 0,
+    Backward = 1,
+    Loss = 2,
+    Optimizer = 3,
+    GraphTypeCount = 4,  // Sentinel: the number of graph types; not a valid array index.
+};
+
+/// Convert an enumerator to the value of its underlying integer type.
+template <typename T>
+constexpr std::underlying_type_t<T> to_underlying(T e) noexcept
+{
+    return static_cast<std::underlying_type_t<T>>(e);
+}
+
+constexpr std::uint8_t GRAPH_TYPE_COUNT = to_underlying(GraphType::GraphTypeCount);
+using StaticGraphArray = std::array<graphlib::Graph*, GRAPH_TYPE_COUNT>;
+
+/**
+ * @brief ForgeGraphModule is a container for all the graphs that are part of a module.
+ * The graphs are stored in an array by their type (enum GraphType).
+ * ForgeGraphModule is initialized with a Forward graph,
+ * while the other graphs can be set later (if the module is compiled for training).
+ * Only raw pointers are stored; this class never deletes the graphs it refers to.
+ */
+class ForgeGraphModule
+{
+public:
+    /// Create a module called `name` whose Forward graph is `forward_graph` (must not be null).
+    ForgeGraphModule(std::string name, graphlib::Graph* forward_graph) : name_(std::move(name)), graphs_{nullptr}
+    {
+        TT_ASSERT(forward_graph != nullptr);
+        graphs_[to_underlying(GraphType::Forward)] = forward_graph;
+    }
+
+    /// Store `graph` (must not be null) as the graph of the given `type`, replacing any previous one.
+    void set_graph(GraphType type, graphlib::Graph* graph)
+    {
+        TT_ASSERT(graph != nullptr);
+        graphs_[index_of(type)] = graph;
+    }
+
+    /// Return the graph stored for `type`; asserts that one has been set.
+    graphlib::Graph* get_graph(GraphType type) const
+    {
+        graphlib::Graph* graph = graphs_[index_of(type)];
+        TT_ASSERT(graph != nullptr);
+        return graph;
+    }
+
+    /**
+     * @brief Get all existing graphs in the module.
+     * @return A vector of pointers to the graphs.
+     */
+    std::vector<graphlib::Graph*> graphs() const
+    {
+        std::vector<graphlib::Graph*> res;
+        res.reserve(graphs_.size());
+        for (auto graph : graphs_) {
+            if (graph != nullptr) {
+                res.push_back(graph);
+            }
+        }
+        return res;
+    }
+
+    std::string name() const { return name_; }
+
+private:
+    /// Map `type` to its index in graphs_. The GraphTypeCount sentinel (or any other
+    /// out-of-range value) is rejected so graphs_ is never indexed past its end.
+    static std::uint8_t index_of(GraphType type)
+    {
+        const std::uint8_t index = to_underlying(type);
+        TT_ASSERT(index < GRAPH_TYPE_COUNT, "Invalid GraphType");
+        return index;
+    }
+
+    std::string name_;
+
+    // Static array of graphs, indexed by GraphType.
+    StaticGraphArray graphs_;
+};
+
+}  // namespace tt
diff --git a/forge/csrc/passes/lower_to_mlir.cpp b/forge/csrc/passes/lower_to_mlir.cpp
index 82f816c64..469bdc4bc 100644
--- a/forge/csrc/passes/lower_to_mlir.cpp
+++ b/forge/csrc/passes/lower_to_mlir.cpp
@@ -9,6 +9,7 @@
 #include
 
 // TTForge headers
+#include "forge_graph_module.hpp"
 #include "graph_lib/graph.hpp"
 #include "graph_lib/node.hpp"
 #include "graph_lib/utils.hpp"
@@ -43,7 +44,7 @@ namespace {
 using namespace tt;
 
 /**
- * @brief Implementation of TT-MLIR emission from the TTForge graph.
+ * @brief Implementation of TT-MLIR emission from the Forge module (set of graphs).
*/ class MLIRGenerator @@ -55,27 +56,32 @@ class MLIRGenerator init_lowering_handler_map(); } - /// Public API: Convert the TTForge graph into an MLIR module operation for TTIR. - mlir::ModuleOp emit_mlir(graphlib::Graph *graph) + /// Public API: Convert the ForgeGraphModule into an MLIR module operation for TTIR. + mlir::ModuleOp emit_mlir(tt::ForgeGraphModule& module) { - graphModule_ = mlir::ModuleOp::create(get_module_location(graph), "tt-forge-graph"); + graphModule_ = mlir::ModuleOp::create(get_module_location(module), module.name()); graphModule_->setAttr(mlir::tt::SystemDescAttr::name, mlir::tt::SystemDescAttr::getDefault(builder_.getContext())); builder_.setInsertionPointToStart(&graphModule_.getBodyRegion().front()); + // Emit MLIR functions for each graph in the module. + for (auto graph : module.graphs()) { - auto traversal_context = graphlib::get_subgraph_traversal_context(graph); - emit_mlir_function(graph); - } + // Currently there is only one graph in the ForgeGraphModule. This will change after completion of issue #100. + // For now, we keep the hack for splitting the single graph into forward and backward subgraphs. 
+ TT_ASSERT(module.graphs().size() == 1, "Expected only one graph in ForgeGraphModule"); - if (graph->training()) - { - auto traversal_context = graphlib::get_subgraph_traversal_context(graph); - emit_mlir_function(graph, "backward"); - } + { + auto traversal_context = graphlib::get_subgraph_traversal_context(graph); + emit_mlir_function(graph); + } - log_info(LogMLIRCompiler, "MLIR module generated successfully."); - graphModule_.dump(); + if (graph->training()) + { + auto traversal_context = graphlib::get_subgraph_traversal_context(graph); + emit_mlir_function(graph, "backward"); + } + } /// Verify the module after we have finished constructing it, this will check /// the structural properties of the IR and invoke any specific verifiers we @@ -86,6 +92,9 @@ class MLIRGenerator throw std::runtime_error("Generated MLIR module failed verification."); } + log_info(LogMLIRCompiler, "MLIR module generated successfully."); + graphModule_.dump(); + #ifdef DEBUG // Create a string to store the output std::string moduleStr; @@ -98,14 +107,15 @@ class MLIRGenerator rso.flush(); - log_trace(LogMLIRCompiler, "MLIR module after lowering TT-Forge graph:\n{}", moduleStr); + log_trace(LogMLIRCompiler, "MLIR module after lowering ForgeGraphModule:\n{}", moduleStr); #endif return graphModule_; } private: - /// A "module" matches a TTForge graph: containing a single function to exectue. + /// A "module" matches the set of graphs contained in ForgeGraphModule. + /// Where each graph will lower into a separate MLIR function inside the module. mlir::ModuleOp graphModule_; /// The builder is a helper class to create IR. The builder @@ -166,6 +176,8 @@ class MLIRGenerator /// A function represents a set of TTForge operations that are executed to produce output results. /// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function. 
mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph, std::string fn_name = "forward") { + + log_info(LogMLIRCompiler, "Emitting MLIR for function {}", fn_name); // Assemble the function arguments (inputs and parameters) llvm::SmallVector argument_types; llvm::SmallVector argument_nodes; @@ -489,10 +501,10 @@ class MLIRGenerator } /// Get the location for a module. - mlir::Location get_module_location(tt::graphlib::Graph *graph) + mlir::Location get_module_location(tt::ForgeGraphModule& module) { return mlir::FileLineColLoc::get( - builder_.getContext(), graph->name(), graph->id(), 0); + builder_.getContext(), module.name(), 0, 0); } /// Get the simple location for a node in a format "graph_name", (graph_id), (node_id) @@ -545,9 +557,9 @@ class MLIRGenerator } namespace tt::passes { - /// Public API for generating MLIR from the TTForge graph. - mlir::OwningOpRef lower_to_mlir(graphlib::Graph * graph, mlir::MLIRContext& context) + /// Public API for generating MLIR from the Forge module (set of graphs). + mlir::OwningOpRef lower_to_mlir(tt::ForgeGraphModule& module, mlir::MLIRContext& context) { - return MLIRGenerator(context).emit_mlir(graph); + return MLIRGenerator(context).emit_mlir(module); } } diff --git a/forge/csrc/passes/lower_to_mlir.hpp b/forge/csrc/passes/lower_to_mlir.hpp index 25f7fe2ad..2513515db 100644 --- a/forge/csrc/passes/lower_to_mlir.hpp +++ b/forge/csrc/passes/lower_to_mlir.hpp @@ -2,9 +2,10 @@ // // SPDX-License-Identifier: Apache-2.0 #pragma once -namespace tt::graphlib + +namespace tt { -class Graph; +class ForgeGraphModule; } namespace mlir { @@ -15,7 +16,7 @@ namespace mlir { namespace tt::passes { - // Public API for generating MLIR from the TT-Forge graph. - mlir::OwningOpRef lower_to_mlir(tt::graphlib::Graph * graph, mlir::MLIRContext& context); + // Public API for generating MLIR from a Forge module (set of graphs). 
+ mlir::OwningOpRef lower_to_mlir(tt::ForgeGraphModule& module, mlir::MLIRContext& context); } // namespace tt:passes diff --git a/forge/csrc/passes/mlir_compiler.cpp b/forge/csrc/passes/mlir_compiler.cpp index cb66b985d..664050ddb 100644 --- a/forge/csrc/passes/mlir_compiler.cpp +++ b/forge/csrc/passes/mlir_compiler.cpp @@ -35,7 +35,7 @@ namespace tt::passes { /// Public API for lowering to MLIR, running MLIR passes and generate runtime binary. - runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph) + runtime::Binary run_mlir_compiler(tt::ForgeGraphModule& module) { // Register all the required dialects. mlir::DialectRegistry registry; @@ -58,9 +58,7 @@ namespace tt::passes context.loadAllAvailableDialects(); // Generate MLIR from the Forge graph. - mlir::OwningOpRef mlir_module = lower_to_mlir(graph, context); - - tt::log_info(LogMLIRCompiler, "MLIR module generated successfully."); + mlir::OwningOpRef mlir_module = lower_to_mlir(module, context); // Run MLIR registered passes. run_mlir_passes(mlir_module); diff --git a/forge/csrc/passes/mlir_compiler.hpp b/forge/csrc/passes/mlir_compiler.hpp index eed44b24a..8c2d81d4c 100644 --- a/forge/csrc/passes/mlir_compiler.hpp +++ b/forge/csrc/passes/mlir_compiler.hpp @@ -2,20 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 #pragma once -#include #include "tt/runtime/types.h" namespace tt { - namespace graphlib - { - class Graph; - } + class ForgeGraphModule; } namespace tt::passes { /// Public API for running MLIR passes and generating binary. 
- runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph); + runtime::Binary run_mlir_compiler(tt::ForgeGraphModule& module); } diff --git a/forge/forge/compile.py b/forge/forge/compile.py index b29c2f485..8ddfa6638 100644 --- a/forge/forge/compile.py +++ b/forge/forge/compile.py @@ -27,6 +27,7 @@ run_pre_lowering_passes, dump_graph, ) +from forge._C import ForgeGraphModule import forge._C.autograd as pyautograd import forge._C.graph as pygraph from forge._C.graph import Graph @@ -129,6 +130,7 @@ class CompileContext: in_recompile: bool = False recompile_count: int = 0 target_cycles_offset: int = 0 + forge_module: Optional[ForgeGraphModule] = None compiled_binary: Optional[Binary] = None def calculate_grads( @@ -609,6 +611,9 @@ def generate_initial_graph(context: CompileContext) -> CompileDepth: for name, value in module.named_parameters(): context.parameter_dict[name] = value + forge_module = ForgeGraphModule(context.graph_name, context.graph) + context.forge_module = forge_module + return CompileDepth.POST_INITIAL_GRAPH_PASS def run_post_initial_graph_pass(context: CompileContext) -> CompileDepth: @@ -821,9 +826,9 @@ def run_pre_lowering_pass(context: CompileContext) -> CompileDepth: return CompileDepth.RUN_MLIR_COMPILER def run_mlir_compiler(context: CompileContext) -> CompileDepth: - graph = context.graph + assert context.forge_module is not None - context.compiled_binary = forge._C.run_mlir_compiler(graph) + context.compiled_binary = forge._C.run_mlir_compiler(context.forge_module) return CompileDepth.FINISH_COMPILE