Skip to content

Commit

Permalink
Replace __backward attribute by isForwardNode
Browse files Browse the repository at this point in the history
  • Loading branch information
mtavenrath committed Jan 25, 2024
1 parent 1f7eea4 commit a1a68ed
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 55 deletions.
3 changes: 0 additions & 3 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,4 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";

// For Priority based graph topology sorting.
constexpr const char* kBackwardNodeAttributeName = "__backwardpass";

} // namespace onnxruntime
12 changes: 7 additions & 5 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,10 @@ class Node {
const NodeAttributes& GetAttributes() const noexcept { return attributes_; }

/** @returns true if the Node is a forward node, false otherwise. **/
bool isForwardNode() const noexcept { return isForwardNode_; }
bool isForwardNode() const noexcept { return is_forward_node_; }

/* Sets the forward node status */
void setForwardNode(bool is_forward_node) noexcept { is_forward_node_ = is_forward_node; }

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/** Remove the specified attribute from this Node */
Expand Down Expand Up @@ -464,7 +467,7 @@ class Node {
std::unordered_map<std::string, gsl::not_null<const Graph*>> GetAttributeNameToSubgraphMap() const;

/** Gets the execution ProviderType that this node will be executed by. */
ProviderType const& GetExecutionProviderType() const noexcept { return execution_provider_type_; }
ProviderType GetExecutionProviderType() const noexcept { return execution_provider_type_; }

/** Sets the execution ProviderType that this Node will be executed by. */
void SetExecutionProviderType(ProviderType execution_provider_type) {
Expand Down Expand Up @@ -633,9 +636,8 @@ class Node {
// Execution priority, lower value for higher priority
int priority_ = 0;

// True is Node is a forwardNode and thus doesn't contain a attribute
// named kBackwardNodeAttributeName. False otherwise.
bool isForwardNode_;
// This node is a forward node if value, otherwise it is a backward node.
bool is_forward_node_;

// set from op_->SinceVersion() or via deserialization when OpSchema is not available
int since_version_ = -1;
Expand Down
28 changes: 6 additions & 22 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ void Node::Init(std::string_view name,
gsl::span<NodeArg* const> output_args,
const NodeAttributes* attributes,
std::string_view domain) {
isForwardNode_ = true;
is_forward_node_ = true;
name_ = name;
op_type_ = op_type;
description_ = description;
Expand All @@ -871,12 +871,8 @@ void Node::Init(std::string_view name,
if (attributes) {
attributes_ = *attributes;

isForwardNode_ = true;
is_forward_node_ = true;
for (auto& name_to_attr : attributes_) {
if (!isForwardNode_ && name_to_attr.first == kBackwardNodeAttributeName) {
isForwardNode_ = false;
}

if (utils::HasGraph(name_to_attr.second)) {
#if !defined(ORT_MINIMAL_BUILD)
CreateSubgraph(name_to_attr.first);
Expand Down Expand Up @@ -920,9 +916,6 @@ void Node::CreateSubgraph(const std::string& attr_name) {
#endif // !defined(ORT_MINIMAL_BUILD)

void Node::AddAttributeProto(AttributeProto value) {
if (value.name() == kBackwardNodeAttributeName) {
isForwardNode_ = false;
}
utils::SetNodeAttribute(std::move(value), attributes_);
if (graph_) {
graph_->SetGraphResolveNeeded();
Expand Down Expand Up @@ -959,7 +952,6 @@ ADD_ATTR_IMPLS(TypeProto)
#undef ADD_ATTR_LIST_IMPL
#undef ADD_ATTR_IMPLS

// TODO why isn't attr_name a const&?
void Node::AddAttribute(std::string attr_name, GraphProto value) {
// Do not move attr_name as it is needed below
AttributeProto a = utils::MakeAttribute(attr_name, std::move(value));
Expand All @@ -975,11 +967,7 @@ void Node::AddAttribute(std::string attr_name, GraphProto value) {
bool Node::ClearAttribute(const std::string& attr_name) {
graph_->SetGraphResolveNeeded();
graph_->SetGraphProtoSyncNeeded();
size_t erased = attributes_.erase(attr_name);
if (erased && attr_name == kBackwardNodeAttributeName) {
isForwardNode_ = true;
}
return erased > 0;
return attributes_.erase(attr_name) > 0;
}

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand All @@ -989,11 +977,7 @@ int Node::PruneRemovableAttributes(gsl::span<const std::string> removable_attrib
graph_->SetGraphProtoSyncNeeded();
int n_removed = 0;
for (const auto& name : removable_attributes) {
bool erased = attributes_.erase(name);
if (erased && name == kBackwardNodeAttributeName) {
isForwardNode_ = true;
}
n_removed += static_cast<int>(erased);
n_removed += static_cast<int>(attributes_.erase(name));
}
can_be_saved_ = can_be_saved_ && n_removed == 0;
return n_removed;
Expand Down Expand Up @@ -1813,7 +1797,7 @@ void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
const std::function<bool(const Node*, const Node*)>& comp) const {
std::vector<size_t> in_degree(MaxNodeIndex(), 0);
std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
std::vector<NodeIndex> topo_order;
InlinedVector<NodeIndex> topo_order;

for (auto& node : Nodes()) {
size_t input_edge_count = node.GetInputEdgesCount();
Expand Down Expand Up @@ -2034,7 +2018,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
}
}

std::vector<TypeProto> const& InferredOutputTypes() const { return node_output_types_; }
std::vector<TypeProto> const& InferredOutputTypes() const noexcept { return node_output_types_; }

const AttributeProto* getAttribute(const std::string& name) const override {
auto& attribute_value_map = node_.GetAttributes();
Expand Down
8 changes: 2 additions & 6 deletions onnxruntime/core/graph/graph_viewer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,8 @@ struct PriorityNodeCompare {
}

// nodes of forward pass will be output first
auto const& n1_attrs = n1->GetAttributes();
auto const& n2_attrs = n2->GetAttributes();
int64_t n1_is_forward = static_cast<int64_t>(n1->isForwardNode()) ||
(n1_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
int64_t n2_is_forward = static_cast<int64_t>(n2->isForwardNode()) ||
(n2_attrs.at(kBackwardNodeAttributeName).i() + 1) % 2;
int64_t n1_is_forward = n1->isForwardNode();
int64_t n2_is_forward = n2->isForwardNode();
if (n1_is_forward != n2_is_forward) {
return n2_is_forward > n1_is_forward;
}
Expand Down
6 changes: 1 addition & 5 deletions onnxruntime/core/optimizer/matmul_scale_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,7 @@ Status ProcessNode(

matmul_scale_node.SetExecutionProviderType(node.GetExecutionProviderType());
#ifdef USE_ROCM
// forward the __backwardpass, if present
auto& attrs = node.GetAttributes();
if (attrs.count("__backwardpass")) {
matmul_scale_node.AddAttribute("__backwardpass", static_cast<int64_t>(attrs.at("__backwardpass").i()));
}
matmul_scale_node.setForwardNode(node.GetForwardNode());
#endif

{
Expand Down
5 changes: 1 addition & 4 deletions onnxruntime/core/optimizer/matmul_transpose_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
matmul_node.SetExecutionProviderType(node.GetExecutionProviderType());
#ifdef USE_ROCM
// forward the __backwardpass, if present
auto& attrs = node.GetAttributes();
if (attrs.count("__backwardpass")) {
matmul_node.AddAttribute("__backwardpass", static_cast<int64_t>(attrs.at("__backwardpass").i()));
}
malmul_node.setForwardPass(node.getForwardPass());
#endif

graph_utils::FinalizeNodeFusion(graph, matmul_node, node);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/rocm_blas_alt_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Status RocmBlasAltImpl::ApplyImpl(Graph& graph, bool& modified, int graph_level,
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

if (is_backward_pass) {
node.AddAttribute(std::string("__backwardpass"), static_cast<int64_t>(1));
node.setForwardNode(false);
modified = true;
}
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/rocm_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class RocmKernel : public OpKernel {

Status Compute(OpKernelContext* p_op_kernel_context) const override {
Status s;
auto is_backward_pass = Info().GetAttrOrDefault<int64_t>("__backwardpass", 0);
auto is_backward_pass = !Node().isForwardNode();
if (is_backward_pass) {
BackwardPassGuard guard;
s = ComputeInternal(p_op_kernel_context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,13 @@ Status ResetNodeBackwardPassAttribute(Graph& graph, bool& modified) {
// Set the attribute to true for all backward nodes.
for (auto& node : graph.Nodes()) {
if (std::find(fw_nodes.begin(), fw_nodes.end(), &node) == fw_nodes.end()) {
auto& attrs = node.GetAttributes();
if (attrs.count(kBackwardNodeAttributeName)) {
continue;
if (node.isForwardNode()) {
node.setForwardNode(false);
modified = true;
}
node.AddAttribute(kBackwardNodeAttributeName, static_cast<int64_t>(1));
modified = true;
} else {
auto& attrs = node.GetAttributes();
if (attrs.count(kBackwardNodeAttributeName)) {
node.ClearAttribute(kBackwardNodeAttributeName);
if (!node.isForwardNode()) {
node.setForwardNode(true);
modified = true;
}
}
Expand Down

0 comments on commit a1a68ed

Please sign in to comment.