[compile] introduce forge graph module
To encapsulate multiple graphs for a given module, the
`ForgeGraphModule` class is introduced.

For now, the class simply stores pointers to graphs, indexed by graph
type. The graph type can be one of the following: `Forward`,
`Backward`, `Loss`, `Optimizer`.

When lowering to MLIR, we will iterate through all the graphs inside
the `ForgeGraphModule` and lower each of them as an MLIR function.
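
A rough sketch of that lowering loop (illustrative only, not part of
this commit; `emit_mlir_function` stands in for the per-graph lowering
helper used in `lower_to_mlir.cpp`):

#include "forge_graph_module.hpp"

// Sketch: each graph stored in the module (Forward, plus Backward/Loss/Optimizer
// when compiling for training) is lowered into its own MLIR function.
void lower_module(tt::ForgeGraphModule& module)
{
    for (tt::graphlib::Graph* graph : module.graphs())
    {
        emit_mlir_function(graph);  // hypothetical per-graph lowering step
    }
}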

This is the first in a set of changes needed to separate the forward
and backward graphs (for training) - Issue #100.

Closes #222
pilkicTT committed Sep 9, 2024
1 parent 66ace96 commit e7d8eba
Showing 7 changed files with 152 additions and 37 deletions.
11 changes: 11 additions & 0 deletions forge/csrc/forge_bindings.cpp
@@ -17,6 +17,7 @@ namespace py = pybind11;
#include "autograd/python_bindings.hpp"
#include "backend_api/device_config.hpp"
#include "forge_passes.hpp"
#include "forge_graph_module.hpp"
#include "graph_lib/graph.hpp"
#include "graph_lib/python_bindings.hpp"
#include "lower_to_forge/common.hpp"
@@ -111,6 +112,16 @@ PYBIND11_MODULE(_C, m) {
py::module_ m_graph = m.def_submodule("graph", "Submodule defining forge graph functions");
GraphModule(m_graph);

py::enum_<tt::GraphType>(m, "GraphType")
.value("Forward", tt::GraphType::Forward)
.value("Backward", tt::GraphType::Backward)
.value("Optimizer", tt::GraphType::Optimizer)
.export_values();

py::class_<tt::ForgeGraphModule>(m, "ForgeGraphModule")
.def(py::init<std::string, tt::graphlib::Graph *>(), py::arg("name"), py::arg("forward_graph"))
.def("set_graph", &tt::ForgeGraphModule::set_graph);

py::module_ m_autograd = m.def_submodule("autograd", "Submodule defining autograd_engine.");
AutogradModule(m_autograd);

92 changes: 92 additions & 0 deletions forge/csrc/forge_graph_module.hpp
@@ -0,0 +1,92 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <array>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>

#include "utils/assert.hpp"

namespace tt
{

namespace graphlib
{
class Graph;
}

enum class GraphType : std::uint8_t
{
Forward = 0,
Backward = 1,
Loss = 2,
Optimizer = 3,
GraphTypeCount = 4,
};

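// Equivalent of std::to_underlying (C++23): converts an enum value to its underlying integral type.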
template <typename T>
constexpr std::underlying_type_t<T> to_underlying(T e) noexcept
{
return static_cast<std::underlying_type_t<T>>(e);
}

constexpr std::uint8_t GRAPH_TYPE_COUNT = to_underlying(GraphType::GraphTypeCount);
using StaticGraphArray = std::array<graphlib::Graph*, GRAPH_TYPE_COUNT>;

/**
* @brief ForgeGraphModule is a container for all the graphs that are part of a module.
* The graphs are stored in an array by their type (enum GraphType).
* ForgeGraphModule is initialized with a Forward graph,
* while the other graphs can be set later (if the module is compiled for training).
*/
class ForgeGraphModule
{
public:
ForgeGraphModule(std::string name, graphlib::Graph* forward_graph) : name_(name), graphs_{nullptr}
{
TT_ASSERT(forward_graph != nullptr);
graphs_[to_underlying(GraphType::Forward)] = forward_graph;
}

void set_graph(GraphType type, graphlib::Graph* graph)
{
TT_ASSERT(graph != nullptr);
graphs_[to_underlying(type)] = graph;
}

graphlib::Graph* get_graph(GraphType type) const
{
TT_ASSERT(graphs_[to_underlying(type)] != nullptr);
return graphs_[to_underlying(type)];
}

/**
* @brief Get all existing graphs in the module.
* @return A vector of pointers to the graphs.
*/
std::vector<graphlib::Graph*> graphs() const
{
std::vector<graphlib::Graph*> res;
res.reserve(graphs_.size());
for (auto graph : graphs_) {
if (graph != nullptr) {
res.push_back(graph);
}
}
return res;
}

std::string name() const { return name_; }

private:
std::string name_;

// Static array of graphs, indexed by GraphType.
StaticGraphArray graphs_;
};

} // namespace tt
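
For context, a minimal usage sketch of the class above (hypothetical
call site, not part of this commit; the graph pointers are assumed to
be produced by earlier compile passes):

#include "forge_graph_module.hpp"

// Sketch: populate a module for a training compile. `fwd`, `bwd`, `loss` and `opt`
// are assumed to come from the autograd/optimizer passes.
tt::ForgeGraphModule make_training_module(
    tt::graphlib::Graph* fwd,
    tt::graphlib::Graph* bwd,
    tt::graphlib::Graph* loss,
    tt::graphlib::Graph* opt)
{
    tt::ForgeGraphModule module("my_module", fwd);    // Forward graph is mandatory
    module.set_graph(tt::GraphType::Backward, bwd);   // remaining graphs are optional
    module.set_graph(tt::GraphType::Loss, loss);
    module.set_graph(tt::GraphType::Optimizer, opt);
    return module;                                    // module.graphs() now returns all four
}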
54 changes: 33 additions & 21 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -9,6 +9,7 @@
#include <string>

// TTForge headers
#include "forge_graph_module.hpp"
#include "graph_lib/graph.hpp"
#include "graph_lib/node.hpp"
#include "graph_lib/utils.hpp"
@@ -43,7 +44,7 @@ namespace
{
using namespace tt;
/**
* @brief Implementation of TT-MLIR emission from the TTForge graph.
* @brief Implementation of TT-MLIR emission from the Forge module (set of graphs).
*/

class MLIRGenerator
@@ -55,27 +56,32 @@ class MLIRGenerator
init_lowering_handler_map();
}

/// Public API: Convert the TTForge graph into an MLIR module operation for TTIR.
mlir::ModuleOp emit_mlir(graphlib::Graph *graph)
/// Public API: Convert the ForgeGraphModule into an MLIR module operation for TTIR.
mlir::ModuleOp emit_mlir(tt::ForgeGraphModule& module)
{
graphModule_ = mlir::ModuleOp::create(get_module_location(graph), "tt-forge-graph");
graphModule_ = mlir::ModuleOp::create(get_module_location(module), module.name());
graphModule_->setAttr(mlir::tt::SystemDescAttr::name,
mlir::tt::SystemDescAttr::getDefault(builder_.getContext()));
builder_.setInsertionPointToStart(&graphModule_.getBodyRegion().front());

// Emit MLIR functions for each graph in the module.
for (auto graph : module.graphs())
{
auto traversal_context = graphlib::get_subgraph_traversal_context<graphlib::SubgraphType::Forward>(graph);
emit_mlir_function(graph);
}
// Currently there is only one graph in the ForgeGraphModule. This will change after completion of issue #100.
// For now, we keep the hack for splitting the single graph into forward and backward subgraphs.
TT_ASSERT(module.graphs().size() == 1, "Expected only one graph in ForgeGraphModule");

if (graph->training())
{
auto traversal_context = graphlib::get_subgraph_traversal_context<graphlib::SubgraphType::Backward>(graph);
emit_mlir_function(graph, "backward");
}
{
auto traversal_context = graphlib::get_subgraph_traversal_context<graphlib::SubgraphType::Forward>(graph);
emit_mlir_function(graph);
}

log_info(LogMLIRCompiler, "MLIR module generated successfully.");
graphModule_.dump();
if (graph->training())
{
auto traversal_context = graphlib::get_subgraph_traversal_context<graphlib::SubgraphType::Backward>(graph);
emit_mlir_function(graph, "backward");
}
}

/// Verify the module after we have finished constructing it, this will check
/// the structural properties of the IR and invoke any specific verifiers we
@@ -86,6 +92,9 @@ class MLIRGenerator
throw std::runtime_error("Generated MLIR module failed verification.");
}

log_info(LogMLIRCompiler, "MLIR module generated successfully.");
graphModule_.dump();

#ifdef DEBUG
// Create a string to store the output
std::string moduleStr;
@@ -98,14 +107,15 @@

rso.flush();

log_trace(LogMLIRCompiler, "MLIR module after lowering TT-Forge graph:\n{}", moduleStr);
log_trace(LogMLIRCompiler, "MLIR module after lowering ForgeGraphModule:\n{}", moduleStr);
#endif

return graphModule_;
}

private:
/// A "module" matches a TTForge graph: containing a single function to exectue.
/// A "module" matches the set of graphs contained in ForgeGraphModule.
/// Each graph is lowered into a separate MLIR function inside the module.
mlir::ModuleOp graphModule_;

/// The builder is a helper class to create IR. The builder
@@ -166,6 +176,8 @@ class MLIRGenerator
/// A function represents a set of TTForge operations that are executed to produce output results.
/// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function.
mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph, std::string fn_name = "forward") {

log_info("Emmiting mlir for function {}", fn_name);
// Assemble the function arguments (inputs and parameters)
llvm::SmallVector<mlir::Type> argument_types;
llvm::SmallVector<graphlib::Node *> argument_nodes;
@@ -489,10 +501,10 @@ class MLIRGenerator
}

/// Get the location for a module.
mlir::Location get_module_location(tt::graphlib::Graph *graph)
mlir::Location get_module_location(tt::ForgeGraphModule& module)
{
return mlir::FileLineColLoc::get(
builder_.getContext(), graph->name(), graph->id(), 0);
builder_.getContext(), module.name(), 0, 0);
}

/// Get the simple location for a node in a format "graph_name", (graph_id), (node_id)
@@ -545,9 +557,9 @@
}
namespace tt::passes
{
/// Public API for generating MLIR from the TTForge graph.
mlir::OwningOpRef<mlir::ModuleOp> lower_to_mlir(graphlib::Graph * graph, mlir::MLIRContext& context)
/// Public API for generating MLIR from the Forge module (set of graphs).
mlir::OwningOpRef<mlir::ModuleOp> lower_to_mlir(tt::ForgeGraphModule& module, mlir::MLIRContext& context)
{
return MLIRGenerator(context).emit_mlir(graph);
return MLIRGenerator(context).emit_mlir(module);
}
}
9 changes: 5 additions & 4 deletions forge/csrc/passes/lower_to_mlir.hpp
@@ -2,9 +2,10 @@
//
// SPDX-License-Identifier: Apache-2.0
#pragma once
namespace tt::graphlib

namespace tt
{
class Graph;
class ForgeGraphModule;
}

namespace mlir {
@@ -15,7 +16,7 @@ namespace mlir {

namespace tt::passes
{
// Public API for generating MLIR from the TT-Forge graph.
mlir::OwningOpRef<mlir::ModuleOp> lower_to_mlir(tt::graphlib::Graph * graph, mlir::MLIRContext& context);
// Public API for generating MLIR from a Forge module (set of graphs).
mlir::OwningOpRef<mlir::ModuleOp> lower_to_mlir(tt::ForgeGraphModule& module, mlir::MLIRContext& context);
} // namespace tt::passes

6 changes: 2 additions & 4 deletions forge/csrc/passes/mlir_compiler.cpp
@@ -35,7 +35,7 @@
namespace tt::passes
{
/// Public API for lowering to MLIR, running MLIR passes and generate runtime binary.
runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph)
runtime::Binary run_mlir_compiler(tt::ForgeGraphModule& module)
{
// Register all the required dialects.
mlir::DialectRegistry registry;
@@ -58,9 +58,7 @@ namespace tt::passes
context.loadAllAvailableDialects();

// Generate MLIR from the Forge graph.
mlir::OwningOpRef<mlir::ModuleOp> mlir_module = lower_to_mlir(graph, context);

tt::log_info(LogMLIRCompiler, "MLIR module generated successfully.");
mlir::OwningOpRef<mlir::ModuleOp> mlir_module = lower_to_mlir(module, context);

// Run MLIR registered passes.
run_mlir_passes(mlir_module);
8 changes: 2 additions & 6 deletions forge/csrc/passes/mlir_compiler.hpp
@@ -2,20 +2,16 @@
//
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <memory>

#include "tt/runtime/types.h"

namespace tt
{
namespace graphlib
{
class Graph;
}
class ForgeGraphModule;
}

namespace tt::passes
{
/// Public API for running MLIR passes and generating binary.
runtime::Binary run_mlir_compiler(tt::graphlib::Graph *graph);
runtime::Binary run_mlir_compiler(tt::ForgeGraphModule& module);
}
9 changes: 7 additions & 2 deletions forge/forge/compile.py
@@ -27,6 +27,7 @@
run_pre_lowering_passes,
dump_graph,
)
from forge._C import ForgeGraphModule
import forge._C.autograd as pyautograd
import forge._C.graph as pygraph
from forge._C.graph import Graph
@@ -129,6 +130,7 @@ class CompileContext:
in_recompile: bool = False
recompile_count: int = 0
target_cycles_offset: int = 0
forge_module: Optional[ForgeGraphModule] = None
compiled_binary: Optional[Binary] = None

def calculate_grads(
@@ -609,6 +611,9 @@ def generate_initial_graph(context: CompileContext) -> CompileDepth:
for name, value in module.named_parameters():
context.parameter_dict[name] = value

forge_module = ForgeGraphModule(context.graph_name, context.graph)
context.forge_module = forge_module

return CompileDepth.POST_INITIAL_GRAPH_PASS

def run_post_initial_graph_pass(context: CompileContext) -> CompileDepth:
@@ -821,9 +826,9 @@ def run_pre_lowering_pass(context: CompileContext) -> CompileDepth:
return CompileDepth.RUN_MLIR_COMPILER

def run_mlir_compiler(context: CompileContext) -> CompileDepth:
graph = context.graph
assert context.forge_module is not None

context.compiled_binary = forge._C.run_mlir_compiler(graph)
context.compiled_binary = forge._C.run_mlir_compiler(context.forge_module)

return CompileDepth.FINISH_COMPILE
