Skip to content

Commit

Permalink
Done graph builder support for interference edges
Browse files Browse the repository at this point in the history
  • Loading branch information
9Tempest committed Dec 8, 2023
1 parent a51a79f commit 9514b4c
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 28 deletions.
4 changes: 2 additions & 2 deletions gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,13 @@ class InstructionOperand {
// Returns the list of tokens representing this instruction.
std::vector<std::string> AsTokenList() const;

std::vector<std::string> getInterferedRegisters() const {
const std::vector<std::string>& getInterferedRegisters() const {
assert(type_ == OperandType::kVirtualRegister);
assert(interfered_registers_.size() == interfered_registers_size_.size());
return interfered_registers_;
}

std::vector<int> getInterferedRegistersSize() const {
const std::vector<int>& getInterferedRegistersSize() const {
assert(type_ == OperandType::kVirtualRegister);
assert(interfered_registers_.size() == interfered_registers_size_.size());
return interfered_registers_size_;
Expand Down
180 changes: 155 additions & 25 deletions gematria/granite/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@
#include "gematria/basic_block/basic_block.h"
#include "gematria/model/oov_token_behavior.h"

#define DEBUG

#ifdef DEBUG
#define LOG(X) \
std::cerr << X << "\n"
#define LOG(X) std::cerr << X << "\n"
#else
#define LOG(X)
#endif

#define IS_VREG(X) (X[0] == '%')

namespace gematria {
namespace {

Expand Down Expand Up @@ -101,6 +104,7 @@ std::ostream& operator<<(std::ostream& os, EdgeType edge_type) {
EXEGESIS_ENUM_CASE(os, EdgeType::kAddressDisplacement);
EXEGESIS_ENUM_CASE(os, EdgeType::kReverseStructuralDependency);
EXEGESIS_ENUM_CASE(os, EdgeType::kInstructionPrefix);
EXEGESIS_ENUM_CASE(os, EdgeType::kInterference);
}
return os;
}
Expand All @@ -119,7 +123,8 @@ BasicBlockGraphBuilder::AddBasicBlockTransaction::AddBasicBlockTransaction(
prev_edge_senders_size_(graph_builder->edge_senders_.size()),
prev_edge_receivers_size_(graph_builder->edge_receivers_.size()),
prev_edge_types_size_(graph_builder->edge_receivers_.size()),
prev_global_features_size_(graph_builder->global_features_.size()) {}
prev_global_features_size_(graph_builder->global_features_.size()),
prev_interference_groups_(graph_builder->interference_groups_) {}

BasicBlockGraphBuilder::AddBasicBlockTransaction::~AddBasicBlockTransaction() {
if (!is_committed_) Rollback();
Expand Down Expand Up @@ -148,6 +153,7 @@ void BasicBlockGraphBuilder::AddBasicBlockTransaction::Rollback() {
GEMATRIA_CHECK_AND_RESIZE(edge_receivers_);
GEMATRIA_CHECK_AND_RESIZE(edge_types_);
GEMATRIA_CHECK_AND_RESIZE(global_features_);
graph_builder_.interference_groups_ = std::move(prev_interference_groups_);
}

#undef GEMATRIA_CHECK_AND_RESIZE
Expand Down Expand Up @@ -185,6 +191,7 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions(
// Clear the maps that are maintained per basic block.
register_nodes_.clear();
alias_group_nodes_.clear();
interference_groups_.clear();

const int prev_num_nodes = num_nodes();
const int prev_num_edges = num_edges();
Expand Down Expand Up @@ -260,6 +267,65 @@ void BasicBlockGraphBuilder::Reset() {
edge_types_.clear();

global_features_.clear();
interference_groups_.clear();
}

bool BasicBlockGraphBuilder::AddInterference(
const std::string& src_name, const std::string& src_token,
const std::vector<std::string>& dest_names,
const std::vector<int>& dest_sizes) {
assert(dest_names.size() == dest_sizes.size() &&
"dest_names and dest_sizes should have the same size");
assert(
register_nodes_.find(src_name) != register_nodes_.end() &&
"Register not found; It should be added before calling AddInterference");
LOG("Adding Interference: " << src_name << " " << src_token);
for (int i = 0; i < dest_names.size(); ++i) {
LOG(" " << dest_names[i] << " " << dest_sizes[i]);
}
LOG("=====================================");
NodeIndex& operand_node =
LookupOrInsert(register_nodes_, src_name, kInvalidNode);
if (operand_node == kInvalidNode) {
// since we already checked the existence of the register, this should never
// happen
return false;
}
std::vector<std::string> dest_tokens(dest_names.size());
for (int i = 0; i < dest_names.size(); ++i) {
if (IS_VREG(dest_names[i]))
dest_tokens[i] = getVREG_TOKEN(dest_sizes[i]);
else
dest_tokens[i] = dest_names[i];
}
for (int i = 0; i < dest_names.size(); ++i) {
if (interference_groups_[src_name].count(dest_names[i]) > 0) {
continue;
}
auto added = AddDependencyOnRegister(
operand_node, dest_names[i], dest_tokens[i], EdgeType::kInterference);
if (!added) return false;
added = AddDependencyToRegister(operand_node, dest_names[i], dest_tokens[i],
EdgeType::kInterference);
if (!added) return false;
interference_groups_[src_name].insert(dest_names[i]);
interference_groups_[dest_names[i]].insert(src_name);
}
LOG("Done adding interference");
LOG("=====================================");
LOG("Current interference groups: ");
for (auto& [key, value] : interference_groups_) {
LOG(key << " -> ");
for (auto& v : value) {
LOG(" " << v);
}
}
LOG("Current register nodes: ");
for (auto& [key, value] : register_nodes_) {
LOG(key << " -> " << value);
}
LOG("=====================================");
return true;
}

bool BasicBlockGraphBuilder::AddInputOperand(
Expand All @@ -270,7 +336,8 @@ bool BasicBlockGraphBuilder::AddInputOperand(
switch (operand.type()) {
case OperandType::kRegister: {
if (!AddDependencyOnRegister(instruction_node, operand.register_name(),
operand.register_name(), EdgeType::kInputOperands)) {
operand.register_name(),
EdgeType::kInputOperands)) {
return false;
}
} break;
Expand All @@ -280,6 +347,11 @@ bool BasicBlockGraphBuilder::AddInputOperand(
vreg_name, EdgeType::kInputOperands)) {
return false;
}
if (!AddInterference(operand.register_name(), vreg_name,
operand.getInterferedRegisters(),
operand.getInterferedRegistersSize())) {
return false;
}
} break;
case OperandType::kImmediateValue: {
AddEdge(EdgeType::kInputOperands,
Expand All @@ -296,33 +368,52 @@ bool BasicBlockGraphBuilder::AddInputOperand(
AddNode(NodeType::kAddressOperand, address_token_);
const AddressTuple& address_tuple = operand.address();
if (!address_tuple.base_register.empty()) {
bool is_virtual_reg = address_tuple.base_register[0] == '%';
bool is_virtual_reg = IS_VREG(address_tuple.base_register);
std::string vreg_token = getVREG_TOKEN(64);
bool result = AddDependencyOnRegister(address_node, address_tuple.base_register,
is_virtual_reg ? vreg_token : address_tuple.base_register,
EdgeType::kAddressBaseRegister);
bool result = AddDependencyOnRegister(
address_node, address_tuple.base_register,
is_virtual_reg ? vreg_token : address_tuple.base_register,
EdgeType::kAddressBaseRegister);
if (is_virtual_reg) {
result &= AddInterference(
address_tuple.base_register, vreg_token,
address_tuple.base_register_intefered_register,
address_tuple.base_register_intefered_register_sizes);
}
if (result == false) {
return false;
}
}
if (!address_tuple.index_register.empty()) {
bool is_virtual_reg = address_tuple.index_register[0] == '%';
bool is_virtual_reg = IS_VREG(address_tuple.index_register);
std::string vreg_token = getVREG_TOKEN(64);
bool result = AddDependencyOnRegister(address_node,
address_tuple.index_register,
is_virtual_reg ? vreg_token : address_tuple.index_register,
EdgeType::kAddressIndexRegister);
bool result = AddDependencyOnRegister(
address_node, address_tuple.index_register,
is_virtual_reg ? vreg_token : address_tuple.index_register,
EdgeType::kAddressIndexRegister);
if (is_virtual_reg) {
result &= AddInterference(
address_tuple.index_register, vreg_token,
address_tuple.index_register_intefered_register,
address_tuple.index_register_intefered_register_sizes);
}
if (result == false) {
return false;
}
}
if (!address_tuple.segment_register.empty()) {
bool is_virtual_reg = address_tuple.segment_register[0] == '%';
bool is_virtual_reg = IS_VREG(address_tuple.segment_register);
std::string vreg_token = getVREG_TOKEN(64);
bool result = AddDependencyOnRegister(address_node,
address_tuple.segment_register,
is_virtual_reg ? vreg_token : address_tuple.segment_register,
EdgeType::kAddressSegmentRegister);
bool result = AddDependencyOnRegister(
address_node, address_tuple.segment_register,
is_virtual_reg ? vreg_token : address_tuple.segment_register,
EdgeType::kAddressSegmentRegister);
if (is_virtual_reg) {
result &= AddInterference(
address_tuple.segment_register, vreg_token,
address_tuple.segment_register_intefered_register,
address_tuple.segment_register_intefered_register_sizes);
}
if (result == false) {
return false;
}
Expand Down Expand Up @@ -364,12 +455,16 @@ bool BasicBlockGraphBuilder::AddOutputOperand(
register_nodes_[operand.register_name()] = register_node;
} break;
case OperandType::kVirtualRegister: {
std::string vreg_name = getVREG_TOKEN(operand.size());
const NodeIndex register_node =
AddNode(NodeType::kRegister, vreg_name);
if (register_node == kInvalidNode) return false;
AddEdge(EdgeType::kOutputOperands, instruction_node, register_node);
register_nodes_[operand.register_name()] = register_node;
std::string vreg_token = getVREG_TOKEN(operand.size());
bool result =
AddDependencyToRegister(instruction_node, operand.register_name(),
vreg_token, EdgeType::kOutputOperands);
result &= AddInterference(operand.register_name(), vreg_token,
operand.getInterferedRegisters(),
operand.getInterferedRegistersSize());
if (result == false) {
return false;
}
} break;
case OperandType::kImmediateValue:
case OperandType::kFpImmediateValue:
Expand Down Expand Up @@ -407,6 +502,21 @@ bool BasicBlockGraphBuilder::AddDependencyOnRegister(
return true;
}

bool BasicBlockGraphBuilder::AddDependencyToRegister(
NodeIndex dependent_node, const std::string& register_name,
const std::string& register_token, EdgeType edge_type) {
NodeIndex& operand_node =
LookupOrInsert(register_nodes_, register_name, kInvalidNode);
if (operand_node == kInvalidNode) {
// Add a node for the register if it doesn't exist. This also updates the
// node index in `node_by_register`.
operand_node = AddNode(NodeType::kRegister, register_token);
}
if (operand_node == kInvalidNode) return false;
AddEdge(edge_type, dependent_node, operand_node);
return true;
}

BasicBlockGraphBuilder::NodeIndex BasicBlockGraphBuilder::AddNode(
NodeType node_type, TokenIndex token_index) {
const NodeIndex new_node_index = num_nodes();
Expand Down Expand Up @@ -493,7 +603,7 @@ void StrAppendList(std::stringstream& buffer, std::string_view list_name,
buffer << ",";
first = false;
}
buffer << item;
buffer << item << ",";
}
buffer << "]\n";
}
Expand All @@ -508,6 +618,26 @@ std::string BasicBlockGraphBuilder::DebugString() const {
buffer << "num_node_tokens = " << num_node_tokens() << "\n";
StrAppendList(buffer, "num_nodes_per_block", num_nodes_per_block());
StrAppendList(buffer, "num_edges_per_block", num_edges_per_block());
buffer << "register_nodes :"
<< "\n";
for (const auto& [register_name, node_index] : register_nodes_) {
buffer << " " << register_name << " -> " << node_index << "\n";
}
buffer << "alias_group_nodes :"
<< "\n";
for (const auto& [alias_group_id, node_index] : alias_group_nodes_) {
buffer << " " << alias_group_id << " -> " << node_index << "\n";
}
buffer << "interference_groups :"
<< "\n";
for (const auto& [register_name, interfered_registers] :
interference_groups_) {
buffer << " " << register_name << " -> [";
for (const auto& interfered_register : interfered_registers) {
buffer << " " << interfered_register;
}
buffer << " ]\n";
}
StrAppendList(buffer, "node_types", node_types());
StrAppendList(buffer, "edge_senders", edge_senders());
StrAppendList(buffer, "edge_receivers", edge_receivers());
Expand Down
18 changes: 17 additions & 1 deletion gematria/granite/graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "gematria/basic_block/basic_block.h"
Expand Down Expand Up @@ -131,6 +132,7 @@ enum class EdgeType {
// it would invalidate existing checkpoints.
kReverseStructuralDependency = 7,
kInstructionPrefix = 8,
kInterference = 9,
};

std::ostream& operator<<(std::ostream& os, NodeType node_type);
Expand Down Expand Up @@ -339,6 +341,8 @@ class BasicBlockGraphBuilder {
size_t prev_edge_receivers_size_;
size_t prev_edge_types_size_;
size_t prev_global_features_size_;
std::unordered_map<std::string_view, std::unordered_set<std::string>>
prev_interference_groups_;
};

// Adds nodes and edges for a single input operand of an instruction.
Expand All @@ -352,9 +356,19 @@ class BasicBlockGraphBuilder {
// a register. Adds the register node if it doesn't exist in the graph.
bool AddDependencyOnRegister(NodeIndex dependent_node,
const std::string& register_name,
const std::string& register_token,
const std::string& register_token,
EdgeType edge_type);

bool AddDependencyToRegister(NodeIndex dependent_node,
const std::string& register_name,
const std::string& register_token,
EdgeType edge_type);

bool AddInterference(const std::string& src_name,
const std::string& src_token,
const std::vector<std::string>& dest_names,
const std::vector<int>& dest_sizes);

// Adds a new node to the batch; the feature of the node is given directly by
// the caller.
NodeIndex AddNode(NodeType node_type, TokenIndex token_index);
Expand Down Expand Up @@ -392,6 +406,8 @@ class BasicBlockGraphBuilder {

std::unordered_map<std::string_view, NodeIndex> register_nodes_;
std::unordered_map<int, NodeIndex> alias_group_nodes_;
std::unordered_map<std::string_view, std::unordered_set<std::string>>
interference_groups_;
};

} // namespace gematria
Expand Down
Loading

0 comments on commit 9514b4c

Please sign in to comment.