-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[compile] introduce forge graph module
For the purpose of encapsulating multiple graphs for a given module, `ForgeGraphModule` class is introduced. The class for now simply stores pointers to graphs by the type of the graph. Graph type can be one of the following: `Forward`, `Backward`, `Loss`, `Optimizer`. When lowering to mlir, we will iterate through all the graphs inside the `ForgeGraphModule` and lower them as mlir functions. This is a first change in a set of changes needed for separating forward and backward graphs (for training) - Issue #100. Closes #222
- Loading branch information
Showing
7 changed files
with
152 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <array> | ||
#include <cstdint> | ||
#include <string> | ||
#include <type_traits> | ||
|
||
#include "utils/assert.hpp" | ||
|
||
namespace tt | ||
{ | ||
|
||
namespace graphlib | ||
{ | ||
class Graph; | ||
} | ||
|
||
enum class GraphType : std::uint8_t | ||
{ | ||
Forward = 0, | ||
Backward = 1, | ||
Loss = 2, | ||
Optimizer = 3, | ||
GraphTypeCount = 4, | ||
}; | ||
|
||
template <typename T> | ||
constexpr std::underlying_type_t<T> to_underlying(T e) noexcept | ||
{ | ||
return static_cast<std::underlying_type_t<T>>(e); | ||
} | ||
|
||
constexpr std::uint8_t GRAPH_TYPE_COUNT = to_underlying(GraphType::GraphTypeCount); | ||
using StaticGraphArray = std::array<graphlib::Graph*, GRAPH_TYPE_COUNT>; | ||
|
||
/** | ||
* @brief ForgeGraphModule is a container for all the graphs that are part of a module. | ||
* The graphs are stored in an array by their type (enum GraphType). | ||
* ForgeGraphModule is initialized with a Forward graph, | ||
* while the other graphs can be set later (if the module is compiled for training). | ||
*/ | ||
class ForgeGraphModule | ||
{ | ||
public: | ||
ForgeGraphModule(std::string name, graphlib::Graph* forward_graph) : name_(name), graphs_{nullptr} | ||
{ | ||
TT_ASSERT(forward_graph != nullptr); | ||
graphs_[to_underlying(GraphType::Forward)] = forward_graph; | ||
} | ||
|
||
void set_graph(GraphType type, graphlib::Graph* graph) | ||
{ | ||
TT_ASSERT(graph != nullptr); | ||
graphs_[to_underlying(type)] = graph; | ||
} | ||
|
||
graphlib::Graph* get_graph(GraphType type) const | ||
{ | ||
TT_ASSERT(graphs_[to_underlying(type)] != nullptr); | ||
return graphs_[to_underlying(type)]; | ||
} | ||
|
||
/** | ||
* @brief Get all existing graphs in the module. | ||
* @return A vector of pointers to the graphs. | ||
*/ | ||
std::vector<graphlib::Graph*> graphs() const | ||
{ | ||
std::vector<graphlib::Graph*> res; | ||
res.reserve(graphs_.size()); | ||
for (auto graph : graphs_) { | ||
if (graph != nullptr) { | ||
res.push_back(graph); | ||
} | ||
} | ||
return res; | ||
} | ||
|
||
std::string name() const { return name_; } | ||
|
||
private: | ||
std::string name_; | ||
|
||
// Static array of graphs, indexed by GraphType. | ||
StaticGraphArray graphs_; | ||
}; | ||
|
||
} // namespace tt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters