Skip to content

Commit

Permalink
Merge pull request #6 from GranLte/lukezhuz/training_wo_itf
Browse files Browse the repository at this point in the history
Milestone1: MIR model without live intervals
  • Loading branch information
9Tempest authored Dec 3, 2023
2 parents 8d644f6 + ad8aaf8 commit b404fb6
Show file tree
Hide file tree
Showing 16 changed files with 350 additions and 71 deletions.
19 changes: 19 additions & 0 deletions gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ std::ostream& operator<<(std::ostream& os, OperandType operand_type) {
GEMATRIA_PRINT_ENUM_VALUE_TO_OS(os, OperandType::kFpImmediateValue);
GEMATRIA_PRINT_ENUM_VALUE_TO_OS(os, OperandType::kAddress);
GEMATRIA_PRINT_ENUM_VALUE_TO_OS(os, OperandType::kMemory);
GEMATRIA_PRINT_ENUM_VALUE_TO_OS(os, OperandType::kVirtualRegister);
}
return os;
}
Expand Down Expand Up @@ -101,9 +102,20 @@ bool InstructionOperand::operator==(const InstructionOperand& other) const {
return address() == other.address();
case OperandType::kMemory:
return alias_group_id() == other.alias_group_id();
case OperandType::kVirtualRegister:
return register_name() == other.register_name() && size() == other.size();
}
}

InstructionOperand InstructionOperand::VirtualRegister(
const std::string register_name, size_t size) {
InstructionOperand result;
result.type_ = OperandType::kVirtualRegister;
result.register_name_ = std::move(register_name);
result.size_ = size;
return result;
}

InstructionOperand InstructionOperand::Register(
const std::string register_name) {
InstructionOperand result;
Expand Down Expand Up @@ -189,6 +201,9 @@ void InstructionOperand::AddTokensToList(
case OperandType::kMemory:
tokens.emplace_back(kMemoryToken);
break;
case OperandType::kVirtualRegister:
tokens.emplace_back(getVREG_TOKEN(size()));
break;
}
}

Expand Down Expand Up @@ -220,6 +235,10 @@ std::string InstructionOperand::ToString() const {
case OperandType::kMemory:
buffer << ".from_memory(" << alias_group_id() << ")";
break;
case OperandType::kVirtualRegister:
buffer << ".from_virtual_register('" << register_name() << "', " << size()
<< ")";
break;
}
return buffer.str();
}
Expand Down
15 changes: 13 additions & 2 deletions gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <string_view>
#include <utility>
#include <vector>

namespace gematria {

// Tokens used for instruction canonicalization in Gematria. The values used
Expand All @@ -42,6 +41,10 @@ inline constexpr std::string_view kAddressToken = "_ADDRESS_";
inline constexpr std::string_view kMemoryToken = "_MEMORY_";
inline constexpr std::string_view kNoRegisterToken = "_NO_REGISTER_";
inline constexpr std::string_view kDisplacementToken = "_DISPLACEMENT_";
inline constexpr std::string_view kVirtualRegisterToken = "_VREG";
inline std::string getVREG_TOKEN(size_t size) {
return std::string(kVirtualRegisterToken) + std::to_string(size) + "_";
}

// The type of an operand of an instruction.
enum class OperandType {
Expand All @@ -68,6 +71,9 @@ enum class OperandType {
// The operand is a memory access. Instructions with this operand often have
// also an operand of type kAddress.
kMemory,

// The operand is a virtual register.
kVirtualRegister,
};

std::ostream& operator<<(std::ostream& os, OperandType operand_type);
Expand Down Expand Up @@ -140,6 +146,8 @@ class InstructionOperand {
InstructionOperand& operator=(InstructionOperand&&) = default;

// The operands must be created through one of the factory functions.
static InstructionOperand VirtualRegister(std::string register_name,
size_t size);
static InstructionOperand Register(std::string register_name);
static InstructionOperand ImmediateValue(uint64_t immediate_value);
static InstructionOperand FpImmediateValue(double fp_immediate_value);
Expand Down Expand Up @@ -172,9 +180,11 @@ class InstructionOperand {
// kUnknown.
OperandType type() const { return type_; }

const size_t size() const { return size_; }
// Returns the name of the register. Valid only when type() is kRegister.
const std::string& register_name() const {
assert(type_ == OperandType::kRegister);
assert(type_ == OperandType::kRegister ||
type_ == OperandType::kVirtualRegister);
return register_name_;
}

Expand Down Expand Up @@ -209,6 +219,7 @@ class InstructionOperand {
private:
OperandType type_ = OperandType::kUnknown;

size_t size_;
std::string register_name_;
uint64_t immediate_value_ = 0;
double fp_immediate_value_ = 0.0;
Expand Down
10 changes: 10 additions & 0 deletions gematria/basic_block/basic_block_protos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ InstructionOperand InstructionOperandFromProto(
case CanonicalizedOperandProto::kMemory:
return InstructionOperand::MemoryLocation(
proto.memory().alias_group_id());
case CanonicalizedOperandProto::kVirtualRegister:
return InstructionOperand::VirtualRegister(
proto.virtual_register().name(), proto.virtual_register().size());
}
}

Expand All @@ -85,6 +88,13 @@ CanonicalizedOperandProto ProtoFromInstructionOperand(
case OperandType::kMemory:
proto.mutable_memory()->set_alias_group_id(operand.alias_group_id());
break;
case OperandType::kVirtualRegister: {
CanonicalizedOperandProto::VirtualRegister* virtual_register =
proto.mutable_virtual_register();
virtual_register->set_name(operand.register_name());
virtual_register->set_size(operand.size());
break;
}
case OperandType::kUnknown:
break;
}
Expand Down
24 changes: 24 additions & 0 deletions gematria/basic_block/basic_block_protos_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,5 +234,29 @@ TEST(BasicBlockFromProtoTest, SomeInstructions) {
/* implicit_output_operands = */ {})}));
}

TEST(BasicBlockFromProtoTest, VRegInstructions) {
const BasicBlockProto proto = ParseTextProto(R"pb(
canonicalized_instructions {
mnemonic: "CMP64RI32"
llvm_mnemonic: "CMP64ri32"
input_operands { virtual_register { name: "%60" size: 64 } }
input_operands { immediate_value: 0 }
implicit_output_operands { register_name: "EFLAGS" }
}
)pb");
const BasicBlock block = BasicBlockFromProto(proto);
EXPECT_EQ(block,
BasicBlock({Instruction(
/* mnemonic = */ "CMP64RI32", /* llvm_mnemonic = */ "CMP64ri32",
/* prefixes = */ {},
/* input_operands = */
{InstructionOperand::VirtualRegister("%60", 64),
InstructionOperand::ImmediateValue(0)},
/* implicit_input_operands = */ {},
/* output_operands = */ {},
/* implicit_output_operands = */
{InstructionOperand::Register("EFLAGS")})}));
}

} // namespace
} // namespace gematria
32 changes: 28 additions & 4 deletions gematria/basic_block/python/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ InstructionOperandPropertyOrNone(const InstructionOperand& operand) {
return &(operand.*getter_member_ptr)();
}

// Safe version of reading a field of InstructionOperand for cases when the
// getter returns a reference. Preserves const-ness of the returned value.
// - when operand.type() == expected_type1 or operand.type() == expected_type2, returns a pointer to the returned
// object.
// - otherwise, returns nullptr; this is converted by pybind11 to None.
// This wrapper can be used together with
// py::return_value_policy::reference_internal to avoid unnecessary copying of
// the address tuple.
template <OperandType expected_type1, OperandType expected_type2, auto getter_member_ptr,
typename ResultType = decltype((
InstructionOperand::ImmediateValue(0).*getter_member_ptr)())>
std::enable_if_t<std::is_lvalue_reference_v<ResultType>,
std::add_pointer_t<ResultType>>
InstructionOperandVregPropertyOrNone(const InstructionOperand& operand) {
if (operand.type() != expected_type1 || operand.type() != expected_type2 ) {
return nullptr;
}
return &(operand.*getter_member_ptr)();
}

PYBIND11_MODULE(basic_block, m) {
m.doc() = "Data structures representing instructions and basic blocks.";

Expand All @@ -96,7 +116,8 @@ PYBIND11_MODULE(basic_block, m) {
.value("IMMEDIATE_VALUE", OperandType::kImmediateValue)
.value("FP_IMMEDIATE_VALUE", OperandType::kFpImmediateValue)
.value("ADDRESS", OperandType::kAddress)
.value("MEMORY", OperandType::kMemory);
.value("MEMORY", OperandType::kMemory)
.value("VIRTUAL_REGISTER", OperandType::kVirtualRegister);

py::class_<AddressTuple> address_tuple(m, "AddressTuple");
address_tuple
Expand Down Expand Up @@ -133,6 +154,9 @@ PYBIND11_MODULE(basic_block, m) {
.def_static("from_fp_immediate_value",
&InstructionOperand::FpImmediateValue,
py::arg("fp_immediate_value"))
.def_static("from_virtual_register",
&InstructionOperand::VirtualRegister,
py::arg("register_name"), py::arg("size") = 0)
.def_static<InstructionOperand (*)(
std::string /* base_register */, int64_t /* displacement */,
std::string /* index_register */, int /* scaling */,
Expand Down Expand Up @@ -162,9 +186,9 @@ PYBIND11_MODULE(basic_block, m) {
.def("as_token_list", &InstructionOperand::AsTokenList)
.def_property_readonly("type", &InstructionOperand::type)
.def_property_readonly(
"register_name",
InstructionOperandPropertyOrNone<OperandType::kRegister,
&InstructionOperand::register_name>)
"register_name",InstructionOperandVregPropertyOrNone<OperandType::kRegister, OperandType::kVirtualRegister, &InstructionOperand::register_name>)
.def_property_readonly(
"size", &InstructionOperand::size)
.def_property_readonly("immediate_value",
InstructionOperandPropertyOrNone<
OperandType::kImmediateValue,
Expand Down
2 changes: 1 addition & 1 deletion gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG


#ifdef DEBUG
#define LOG(X) \
Expand Down
13 changes: 13 additions & 0 deletions gematria/datasets/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,16 @@ gematria_py_binary(
"//gematria/utils/python:pybind11_abseil_status",
],
)

gematria_py_binary(
name = "gen_tokens",
srcs = ["gen_tokens.py"],
deps = [
"//gematria/basic_block/python:basic_block_protos",
"//gematria/basic_block/python:basic_block",
"//gematria/io/python:tfrecord",
"//gematria/proto:basic_block_py_pb2",
"//gematria/proto:canonicalized_instruction_py_pb2",
"//gematria/proto:throughput_py_pb2",
],
)
64 changes: 64 additions & 0 deletions gematria/datasets/python/gen_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from gematria.basic_block.python import basic_block
from gematria.basic_block.python import basic_block_protos
from gematria.proto import basic_block_pb2
from gematria.proto import throughput_pb2
from gematria.proto import canonicalized_instruction_pb2
from gematria.io.python import tfrecord

from collections.abc import Sequence

from absl import app
from absl import flags
from absl import logging

_CanonicalizedInstructionProto = (
canonicalized_instruction_pb2.CanonicalizedInstructionProto
)

r"""Generates tokens from a Gematria data set.
Usage:
gen_tokens \
--gematria_input_tfrecord=/tmp/bhive/skl.tfrecord \
--gematria_output_tokens=/tmp/bhive/skl_tokens.txt \
"""

_INPUT_TFRECORD_FILE = flags.DEFINE_string(
'gematria_input_tfrecord',
None,
'The name of the TFRecord file to read the tokens from.',
required=True,
)

_OUTPUT_TOKENS_FILE = flags.DEFINE_string(
'gematria_output_tokens',
None,
'The name of the file to write the tokens to.',
required=True,
)

def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
output_blocks = list(
tfrecord.read_protos((_INPUT_TFRECORD_FILE.value,), throughput_pb2.BasicBlockWithThroughputProto)
)
token_set = set()
for block in output_blocks:
for instruction in block.basic_block.canonicalized_instructions:
ginstruction = basic_block_protos.instruction_from_proto(instruction)
for token in ginstruction.as_token_list():
if not token.startswith('%'):
token_set.add(token)
print(token_set)
with open(_OUTPUT_TOKENS_FILE.value, 'w') as f:
for token in token_set:
f.write(token)
f.write('\n')



if __name__ == '__main__':
app.run(main)
Loading

0 comments on commit b404fb6

Please sign in to comment.