Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed proto structure & Added virtual register name #6

Merged
merged 9 commits into from
Dec 3, 2023
19 changes: 19 additions & 0 deletions gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ std::ostream& operator<<(std::ostream& os, OperandType operand_type) {
GEMATRIA_PRINT_ENUM_VALUE_TO_OS(os, OperandType::kFpImmediateValue);
GEMATRIA_PRINT_ENUM_VALUE_TO_OS(os, OperandType::kAddress);
GEMATRIA_PRINT_ENUM_VALUE_TO_OS(os, OperandType::kMemory);
GEMATRIA_PRINT_ENUM_VALUE_TO_OS(os, OperandType::kVirtualRegister);
}
return os;
}
Expand Down Expand Up @@ -101,9 +102,20 @@ bool InstructionOperand::operator==(const InstructionOperand& other) const {
return address() == other.address();
case OperandType::kMemory:
return alias_group_id() == other.alias_group_id();
case OperandType::kVirtualRegister:
return register_name() == other.register_name() && size() == other.size();
}
}

InstructionOperand InstructionOperand::VirtualRegister(
const std::string register_name, size_t size) {
InstructionOperand result;
result.type_ = OperandType::kVirtualRegister;
result.register_name_ = std::move(register_name);
result.size_ = size;
return result;
}

InstructionOperand InstructionOperand::Register(
const std::string register_name) {
InstructionOperand result;
Expand Down Expand Up @@ -189,6 +201,9 @@ void InstructionOperand::AddTokensToList(
case OperandType::kMemory:
tokens.emplace_back(kMemoryToken);
break;
case OperandType::kVirtualRegister:
tokens.emplace_back(getVREG_TOKEN(size()));
break;
}
}

Expand Down Expand Up @@ -220,6 +235,10 @@ std::string InstructionOperand::ToString() const {
case OperandType::kMemory:
buffer << ".from_memory(" << alias_group_id() << ")";
break;
case OperandType::kVirtualRegister:
buffer << ".from_virtual_register('" << register_name() << "', " << size()
<< ")";
break;
}
return buffer.str();
}
Expand Down
15 changes: 13 additions & 2 deletions gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <string_view>
#include <utility>
#include <vector>

namespace gematria {

// Tokens used for instruction canonicalization in Gematria. The values used
Expand All @@ -42,6 +41,10 @@ inline constexpr std::string_view kAddressToken = "_ADDRESS_";
inline constexpr std::string_view kMemoryToken = "_MEMORY_";
inline constexpr std::string_view kNoRegisterToken = "_NO_REGISTER_";
inline constexpr std::string_view kDisplacementToken = "_DISPLACEMENT_";
inline constexpr std::string_view kVirtualRegisterToken = "_VREG";
inline std::string getVREG_TOKEN(size_t size) {
return std::string(kVirtualRegisterToken) + std::to_string(size) + "_";
}

// The type of an operand of an instruction.
enum class OperandType {
Expand All @@ -68,6 +71,9 @@ enum class OperandType {
// The operand is a memory access. Instructions with this operand often have
// also an operand of type kAddress.
kMemory,

// The operand is a virtual register.
kVirtualRegister,
};

std::ostream& operator<<(std::ostream& os, OperandType operand_type);
Expand Down Expand Up @@ -140,6 +146,8 @@ class InstructionOperand {
InstructionOperand& operator=(InstructionOperand&&) = default;

// The operands must be created through one of the factory functions.
static InstructionOperand VirtualRegister(std::string register_name,
size_t size);
static InstructionOperand Register(std::string register_name);
static InstructionOperand ImmediateValue(uint64_t immediate_value);
static InstructionOperand FpImmediateValue(double fp_immediate_value);
Expand Down Expand Up @@ -172,9 +180,11 @@ class InstructionOperand {
// kUnknown.
OperandType type() const { return type_; }

const size_t size() const { return size_; }
// Returns the name of the register. Valid only when type() is kRegister.
const std::string& register_name() const {
assert(type_ == OperandType::kRegister);
assert(type_ == OperandType::kRegister ||
type_ == OperandType::kVirtualRegister);
return register_name_;
}

Expand Down Expand Up @@ -209,6 +219,7 @@ class InstructionOperand {
private:
OperandType type_ = OperandType::kUnknown;

size_t size_;
std::string register_name_;
uint64_t immediate_value_ = 0;
double fp_immediate_value_ = 0.0;
Expand Down
10 changes: 10 additions & 0 deletions gematria/basic_block/basic_block_protos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ InstructionOperand InstructionOperandFromProto(
case CanonicalizedOperandProto::kMemory:
return InstructionOperand::MemoryLocation(
proto.memory().alias_group_id());
case CanonicalizedOperandProto::kVirtualRegister:
return InstructionOperand::VirtualRegister(
proto.virtual_register().name(), proto.virtual_register().size());
}
}

Expand All @@ -85,6 +88,13 @@ CanonicalizedOperandProto ProtoFromInstructionOperand(
case OperandType::kMemory:
proto.mutable_memory()->set_alias_group_id(operand.alias_group_id());
break;
case OperandType::kVirtualRegister: {
CanonicalizedOperandProto::VirtualRegister* virtual_register =
proto.mutable_virtual_register();
virtual_register->set_name(operand.register_name());
virtual_register->set_size(operand.size());
break;
}
case OperandType::kUnknown:
break;
}
Expand Down
24 changes: 24 additions & 0 deletions gematria/basic_block/basic_block_protos_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,5 +234,29 @@ TEST(BasicBlockFromProtoTest, SomeInstructions) {
/* implicit_output_operands = */ {})}));
}

TEST(BasicBlockFromProtoTest, VRegInstructions) {
const BasicBlockProto proto = ParseTextProto(R"pb(
canonicalized_instructions {
mnemonic: "CMP64RI32"
llvm_mnemonic: "CMP64ri32"
input_operands { virtual_register { name: "%60" size: 64 } }
input_operands { immediate_value: 0 }
implicit_output_operands { register_name: "EFLAGS" }
}
)pb");
const BasicBlock block = BasicBlockFromProto(proto);
EXPECT_EQ(block,
BasicBlock({Instruction(
/* mnemonic = */ "CMP64RI32", /* llvm_mnemonic = */ "CMP64ri32",
/* prefixes = */ {},
/* input_operands = */
{InstructionOperand::VirtualRegister("%60", 64),
InstructionOperand::ImmediateValue(0)},
/* implicit_input_operands = */ {},
/* output_operands = */ {},
/* implicit_output_operands = */
{InstructionOperand::Register("EFLAGS")})}));
}

} // namespace
} // namespace gematria
32 changes: 28 additions & 4 deletions gematria/basic_block/python/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ InstructionOperandPropertyOrNone(const InstructionOperand& operand) {
return &(operand.*getter_member_ptr)();
}

// Safe version of reading a field of InstructionOperand for cases when the
// getter returns a reference. Preserves const-ness of the returned value.
// - when operand.type() == expected_type1 or operand.type() == expected_type2, returns a pointer to the returned
// object.
// - otherwise, returns nullptr; this is converted by pybind11 to None.
// This wrapper can be used together with
// py::return_value_policy::reference_internal to avoid unnecessary copying of
// the address tuple.
template <OperandType expected_type1, OperandType expected_type2, auto getter_member_ptr,
typename ResultType = decltype((
InstructionOperand::ImmediateValue(0).*getter_member_ptr)())>
std::enable_if_t<std::is_lvalue_reference_v<ResultType>,
std::add_pointer_t<ResultType>>
InstructionOperandVregPropertyOrNone(const InstructionOperand& operand) {
if (operand.type() != expected_type1 || operand.type() != expected_type2 ) {
return nullptr;
}
return &(operand.*getter_member_ptr)();
}

PYBIND11_MODULE(basic_block, m) {
m.doc() = "Data structures representing instructions and basic blocks.";

Expand All @@ -96,7 +116,8 @@ PYBIND11_MODULE(basic_block, m) {
.value("IMMEDIATE_VALUE", OperandType::kImmediateValue)
.value("FP_IMMEDIATE_VALUE", OperandType::kFpImmediateValue)
.value("ADDRESS", OperandType::kAddress)
.value("MEMORY", OperandType::kMemory);
.value("MEMORY", OperandType::kMemory)
.value("VIRTUAL_REGISTER", OperandType::kVirtualRegister);

py::class_<AddressTuple> address_tuple(m, "AddressTuple");
address_tuple
Expand Down Expand Up @@ -133,6 +154,9 @@ PYBIND11_MODULE(basic_block, m) {
.def_static("from_fp_immediate_value",
&InstructionOperand::FpImmediateValue,
py::arg("fp_immediate_value"))
.def_static("from_virtual_register",
&InstructionOperand::VirtualRegister,
py::arg("register_name"), py::arg("size") = 0)
.def_static<InstructionOperand (*)(
std::string /* base_register */, int64_t /* displacement */,
std::string /* index_register */, int /* scaling */,
Expand Down Expand Up @@ -162,9 +186,9 @@ PYBIND11_MODULE(basic_block, m) {
.def("as_token_list", &InstructionOperand::AsTokenList)
.def_property_readonly("type", &InstructionOperand::type)
.def_property_readonly(
"register_name",
InstructionOperandPropertyOrNone<OperandType::kRegister,
&InstructionOperand::register_name>)
"register_name",InstructionOperandVregPropertyOrNone<OperandType::kRegister, OperandType::kVirtualRegister, &InstructionOperand::register_name>)
.def_property_readonly(
"size", &InstructionOperand::size)
.def_property_readonly("immediate_value",
InstructionOperandPropertyOrNone<
OperandType::kImmediateValue,
Expand Down
2 changes: 1 addition & 1 deletion gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG


#ifdef DEBUG
#define LOG(X) \
Expand Down
13 changes: 13 additions & 0 deletions gematria/datasets/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,16 @@ gematria_py_binary(
"//gematria/utils/python:pybind11_abseil_status",
],
)

gematria_py_binary(
name = "gen_tokens",
srcs = ["gen_tokens.py"],
deps = [
"//gematria/basic_block/python:basic_block_protos",
"//gematria/basic_block/python:basic_block",
"//gematria/io/python:tfrecord",
"//gematria/proto:basic_block_py_pb2",
"//gematria/proto:canonicalized_instruction_py_pb2",
"//gematria/proto:throughput_py_pb2",
],
)
64 changes: 64 additions & 0 deletions gematria/datasets/python/gen_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from gematria.basic_block.python import basic_block
from gematria.basic_block.python import basic_block_protos
from gematria.proto import basic_block_pb2
from gematria.proto import throughput_pb2
from gematria.proto import canonicalized_instruction_pb2
from gematria.io.python import tfrecord

from collections.abc import Sequence

from absl import app
from absl import flags
from absl import logging

_CanonicalizedInstructionProto = (
canonicalized_instruction_pb2.CanonicalizedInstructionProto
)

r"""Generates tokens from a Gematria data set.


Usage:
gen_tokens \
--gematria_input_tfrecord=/tmp/bhive/skl.tfrecord \
--gematria_output_tokens=/tmp/bhive/skl_tokens.txt \

"""

_INPUT_TFRECORD_FILE = flags.DEFINE_string(
'gematria_input_tfrecord',
None,
'The name of the TFRecord file to read the tokens from.',
required=True,
)

_OUTPUT_TOKENS_FILE = flags.DEFINE_string(
'gematria_output_tokens',
None,
'The name of the file to write the tokens to.',
required=True,
)

def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
output_blocks = list(
tfrecord.read_protos((_INPUT_TFRECORD_FILE.value,), throughput_pb2.BasicBlockWithThroughputProto)
)
token_set = set()
for block in output_blocks:
for instruction in block.basic_block.canonicalized_instructions:
ginstruction = basic_block_protos.instruction_from_proto(instruction)
for token in ginstruction.as_token_list():
if not token.startswith('%'):
token_set.add(token)
print(token_set)
with open(_OUTPUT_TOKENS_FILE.value, 'w') as f:
for token in token_set:
f.write(token)
f.write('\n')



if __name__ == '__main__':
app.run(main)
Loading
Loading