Skip to content

Commit

Permalink
Make annotations more generalized.
Browse files Browse the repository at this point in the history
 * Places `AnnotationProto` up one level so that it can be used with `BasicBlock`s along with `CanonicalizedInstruction`s.
 * Annotations are now stored as `repeated` in protos, and in `vector`s or `list`s rather than each type of annotation having a corresponding data member.
 * Naming has changed accordingly e.g. `RuntimeAnnotation` becomes `Annotation`.
  • Loading branch information
virajbshah committed Nov 27, 2023
1 parent 70b0982 commit 555ef25
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 106 deletions.
28 changes: 14 additions & 14 deletions gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,21 +229,21 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand) {
return os;
}

RuntimeAnnotation::RuntimeAnnotation(std::string pmu_event, double value)
: pmu_event(std::move(pmu_event)), value(value) {}
Annotation::Annotation(const std::string &name, double value)
: name(std::move(name)), value(value) {}

bool RuntimeAnnotation::operator==(const RuntimeAnnotation& other) const {
const auto as_tuple = [](const RuntimeAnnotation& annotation) {
return std::tie(annotation.pmu_event, annotation.value);
bool Annotation::operator==(const Annotation& other) const {
const auto as_tuple = [](const Annotation& annotation) {
return std::tie(annotation.name, annotation.value);
};
return as_tuple(*this) == as_tuple(other);
}

std::string RuntimeAnnotation::ToString() const {
std::string Annotation::ToString() const {
std::stringstream buffer;
buffer << "RuntimeAnnotation(";
if (!pmu_event.empty()) {
buffer << "pmu_event='" << pmu_event << "', ";
buffer << "Annotation(";
if (!name.empty()) {
buffer << "name='" << name << "', ";
}
if (value != -1) {
buffer << "value=" << value << ", ";
Expand All @@ -258,7 +258,7 @@ std::string RuntimeAnnotation::ToString() const {
}

std::ostream& operator<<(std::ostream& os,
const RuntimeAnnotation& annotation) {
const Annotation& annotation) {
os << annotation.ToString();
return os;
}
Expand All @@ -270,23 +270,23 @@ Instruction::Instruction(
std::vector<InstructionOperand> implicit_input_operands,
std::vector<InstructionOperand> output_operands,
std::vector<InstructionOperand> implicit_output_operands,
RuntimeAnnotation cache_miss_frequency)
std::vector<Annotation> instruction_annotations)
: mnemonic(std::move(mnemonic)),
llvm_mnemonic(std::move(llvm_mnemonic)),
prefixes(std::move(prefixes)),
input_operands(std::move(input_operands)),
implicit_input_operands(std::move(implicit_input_operands)),
output_operands(std::move(output_operands)),
implicit_output_operands(std::move(implicit_output_operands)),
cache_miss_frequency(std::move(cache_miss_frequency)) {}
instruction_annotations(std::move(instruction_annotations)) {}

bool Instruction::operator==(const Instruction& other) const {
const auto as_tuple = [](const Instruction& instruction) {
return std::tie(
instruction.mnemonic, instruction.llvm_mnemonic, instruction.prefixes,
instruction.input_operands, instruction.implicit_input_operands,
instruction.output_operands, instruction.implicit_output_operands,
instruction.cache_miss_frequency);
instruction.instruction_annotations);
};
return as_tuple(*this) == as_tuple(other);
}
Expand Down Expand Up @@ -329,7 +329,7 @@ std::string Instruction::ToString() const {
add_operand_list("output_operands", output_operands);
add_operand_list("implicit_output_operands", implicit_output_operands);

// TODO(virajbshah): Include cache_miss_frequency at the end of the string.
// TODO(virajbshah): Include instruction annotations at the end of the string.

auto msg = buffer.str();
assert(msg.size() >= 2);
Expand Down
33 changes: 18 additions & 15 deletions gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,27 +218,28 @@ class InstructionOperand {

std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand);

// Represents a runtime-related measure/statistic paired with the instruction.
struct RuntimeAnnotation {
RuntimeAnnotation() : value(-1){};
// Represents an annotation holding a value such as some measure/statistic
// paired with the instruction.
struct Annotation {
Annotation() : value(-1){};

// Initializes all fields of the annotation.
RuntimeAnnotation(std::string pmu_event, double value);
Annotation(const std::string &name, double value);

RuntimeAnnotation(const RuntimeAnnotation&) = default;
RuntimeAnnotation(RuntimeAnnotation&&) = default;
Annotation(const Annotation&) = default;
Annotation(Annotation&&) = default;

RuntimeAnnotation& operator=(const RuntimeAnnotation&) = default;
RuntimeAnnotation& operator=(RuntimeAnnotation&&) = default;
Annotation& operator=(const Annotation&) = default;
Annotation& operator=(Annotation&&) = default;

bool operator==(const RuntimeAnnotation& other) const;
bool operator!=(const RuntimeAnnotation& other) const {
bool operator==(const Annotation& other) const;
bool operator!=(const Annotation& other) const {
return !(*this == other);
}

std::string ToString() const;

std::string pmu_event;
std::string name;
double value;
};

Expand All @@ -254,7 +255,8 @@ struct Instruction {
std::vector<InstructionOperand> implicit_input_operands,
std::vector<InstructionOperand> output_operands,
std::vector<InstructionOperand> implicit_output_operands,
RuntimeAnnotation cache_miss_frequency = RuntimeAnnotation());
std::vector<Annotation> instruction_annotations =
std::vector<Annotation>{Annotation()});

Instruction(const Instruction&) = default;
Instruction(Instruction&&) = default;
Expand Down Expand Up @@ -305,9 +307,10 @@ struct Instruction {
// to the ML models explicitly.
std::vector<InstructionOperand> implicit_output_operands;

// The cache miss frequency of the instruction. Used to better model the
// overhead coming from LLC misses.
RuntimeAnnotation cache_miss_frequency;
// The list of instruction level annotations used to supply additional
// information to the model. Currently includes the cache miss frequency of
// the instruction. Used to better model the overhead coming from LLC misses.
std::vector<Annotation> instruction_annotations;

// The address of the instruction.
uint64_t address = 0;
Expand Down
40 changes: 29 additions & 11 deletions gematria/basic_block/basic_block_protos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "gematria/basic_block/basic_block.h"
#include "gematria/proto/canonicalized_instruction.pb.h"
#include "gematria/proto/annotation.pb.h"
#include "google/protobuf/repeated_ptr_field.h"

namespace gematria {
Expand Down Expand Up @@ -112,20 +113,35 @@ void ToRepeatedPtrField(
ProtoFromInstructionOperand);
}

std::vector<Annotation> ToVector(
const google::protobuf::RepeatedPtrField<AnnotationProto>& protos) {
std::vector<Annotation> result(protos.size());
std::transform(protos.begin(), protos.end(), result.begin(),
AnnotationFromProto);
return result;
}

void ToRepeatedPtrField(
const std::vector<Annotation>& annotations,
google::protobuf::RepeatedPtrField<AnnotationProto>* repeated_field) {
repeated_field->Reserve(annotations.size());
std::transform(annotations.begin(), annotations.end(),
google::protobuf::RepeatedFieldBackInserter(repeated_field),
ProtoFromAnnotation);
}

} // namespace

RuntimeAnnotation RuntimeAnnotationFromProto(
const CanonicalizedInstructionProto::RuntimeAnnotation& proto) {
return RuntimeAnnotation(
/* pmu_event = */ proto.pmu_event(),
Annotation AnnotationFromProto(const AnnotationProto& proto) {
return Annotation(
/* name = */ proto.name(),
/* value = */ proto.value());
}

CanonicalizedInstructionProto::RuntimeAnnotation ProtoFromRuntimeAnnotation(
const RuntimeAnnotation& runtime_annotation) {
CanonicalizedInstructionProto::RuntimeAnnotation proto;
proto.set_pmu_event(runtime_annotation.pmu_event);
proto.set_value(runtime_annotation.value);
AnnotationProto ProtoFromAnnotation(const Annotation& annotation) {
AnnotationProto proto;
proto.set_name(annotation.name);
proto.set_value(annotation.value);
return proto;
}

Expand All @@ -141,8 +157,8 @@ Instruction InstructionFromProto(const CanonicalizedInstructionProto& proto) {
/* output_operands = */ ToVector(proto.output_operands()),
/* implicit_output_operands = */
ToVector(proto.implicit_output_operands()),
/* cache_miss_frequency = */
RuntimeAnnotationFromProto(proto.cache_miss_frequency()));
/* instruction_annotations = */
ToVector(proto.instruction_annotations()));
}

CanonicalizedInstructionProto ProtoFromInstruction(
Expand All @@ -160,6 +176,8 @@ CanonicalizedInstructionProto ProtoFromInstruction(
proto.mutable_output_operands());
ToRepeatedPtrField(instruction.implicit_output_operands,
proto.mutable_implicit_output_operands());
ToRepeatedPtrField(instruction.instruction_annotations,
proto.mutable_instruction_annotations());
return proto;
}

Expand Down
10 changes: 4 additions & 6 deletions gematria/basic_block/basic_block_protos.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,11 @@ InstructionOperand InstructionOperandFromProto(
CanonicalizedOperandProto ProtoFromInstructionOperand(
const InstructionOperand& operand);

// Creates a runtime annotation data structure from a proto.
RuntimeAnnotation RuntimeAnnotationFromProto(
const CanonicalizedInstructionProto::RuntimeAnnotation& proto);
// Creates a annotation data structure from a proto.
Annotation AnnotationFromProto(const AnnotationProto& proto);

// Creates a proto representing the given runtime annotation.
CanonicalizedInstructionProto::RuntimeAnnotation ProtoFromRuntimeAnnotation(
const RuntimeAnnotation& runtime_annotation);
// Creates a proto representing the given annotation.
AnnotationProto ProtoFromAnnotation(const Annotation& annotation);

// Creates an instruction data structure from a proto.
Instruction InstructionFromProto(const CanonicalizedInstructionProto& proto);
Expand Down
33 changes: 22 additions & 11 deletions gematria/basic_block/basic_block_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ TEST(InstructionOperandTest, AsTokenList) {
}
}

// TODO(virajbshah): Add tests for RuntimeAnnotation.
// TODO(virajbshah): Add tests for Annotation.

TEST(InstructionTest, Constructor) {
constexpr char kMnemonic[] = "MOV";
Expand All @@ -332,7 +332,8 @@ TEST(InstructionTest, Constructor) {
InstructionOperand::MemoryLocation(3)};
const std::vector<InstructionOperand> kImplicitOutputOperands = {
InstructionOperand::Register("EFLAGS")};
const RuntimeAnnotation kCacheMissFrequency("r20d1", 0.875);
const std::vector<Annotation> kInstructionAnnotations = {
Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)};

const Instruction instruction(
/* mnemonic = */ kMnemonic,
Expand All @@ -342,15 +343,15 @@ TEST(InstructionTest, Constructor) {
/* implicit_input_operands = */ kImplicitInputOperands,
/* output_operands = */ kOutputOperands,
/* implicit_output_operands = */ kImplicitOutputOperands,
/* cache_miss_frequency = */ kCacheMissFrequency);
/* instruction_annotations = */ kInstructionAnnotations);
EXPECT_EQ(instruction.mnemonic, kMnemonic);
EXPECT_EQ(instruction.llvm_mnemonic, kLlvmMnemonic);
EXPECT_EQ(instruction.prefixes, kPrefixes);
EXPECT_EQ(instruction.input_operands, kInputOperands);
EXPECT_EQ(instruction.implicit_input_operands, kImplicitInputOperands);
EXPECT_EQ(instruction.output_operands, kOutputOperands);
EXPECT_EQ(instruction.implicit_output_operands, kImplicitOutputOperands);
EXPECT_EQ(instruction.cache_miss_frequency, kCacheMissFrequency);
EXPECT_EQ(instruction.instruction_annotations, kInstructionAnnotations);
}

TEST(InstructionTest, AsTokenList) {
Expand All @@ -365,7 +366,8 @@ TEST(InstructionTest, AsTokenList) {
InstructionOperand::MemoryLocation(3)};
const std::vector<InstructionOperand> kImplicitOutputOperands = {
InstructionOperand::Register("EFLAGS")};
const RuntimeAnnotation kCacheMissFrequency("r20d1", 0.875);
const std::vector<Annotation> kInstructionAnnotations = {
Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)};

const Instruction instruction(
/* mnemonic = */ kMnemonic,
Expand All @@ -375,7 +377,7 @@ TEST(InstructionTest, AsTokenList) {
/* implicit_input_operands = */ kImplicitInputOperands,
/* output_operands = */ kOutputOperands,
/* implicit_output_operands = */ kImplicitOutputOperands,
/* cache_miss_frequency = */ kCacheMissFrequency);
/* instruction_annotations = */ kInstructionAnnotations);

EXPECT_THAT(instruction.AsTokenList(),
ElementsAre(kPrefixes[0], kPrefixes[1], kMnemonic,
Expand All @@ -395,7 +397,8 @@ TEST(InstructionTest, ToString) {
/* output_operands = */ {InstructionOperand::Register("RAX")},
/* implicit_output_operands = */
{InstructionOperand::Register("EFLAGS")},
RuntimeAnnotation("r20d1", 0.875));
/* instruction_annotations = */
{Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)});
constexpr char kExpectedString[] =
"Instruction(mnemonic='ADC', llvm_mnemonic='ADC32rr', "
"prefixes=('LOCK',), "
Expand Down Expand Up @@ -501,7 +504,9 @@ TEST(BasicBlockTest, Constructor) {
/* output_operands = */ {InstructionOperand::Register("RAX")},
/* implicit_output_operands = */
{InstructionOperand::Register("EFLAGS")},
RuntimeAnnotation("r20d1", 0.875));
/* instruction_annotations = */
{Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)});

const BasicBlock block({instruction});
EXPECT_THAT(block.instructions, ElementsAre(instruction));
}
Expand All @@ -525,7 +530,9 @@ TEST(BasicBlockTest, Equality) {
/* output_operands = */ {InstructionOperand::Register("RAX")},
/* implicit_output_operands = */
{InstructionOperand::Register("EFLAGS")},
RuntimeAnnotation("r20d1", 0.875)));
/* instruction_annotations = */
{Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)}));

EXPECT_NE(block_1, block_2);
EXPECT_FALSE(block_1 == block_2);

Expand All @@ -541,7 +548,9 @@ TEST(BasicBlockTest, Equality) {
/* output_operands = */ {InstructionOperand::Register("RAX")},
/* implicit_output_operands = */
{InstructionOperand::Register("EFLAGS")},
RuntimeAnnotation("r20d1", 0.875)));
/* instruction_annotations = */
{Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)}));

EXPECT_EQ(block_1, block_2);
EXPECT_FALSE(block_1 != block_2);

Expand All @@ -563,7 +572,9 @@ TEST(BasicBlockTest, ToString) {
/* output_operands = */ {InstructionOperand::Register("RAX")},
/* implicit_output_operands = */
{InstructionOperand::Register("EFLAGS")},
RuntimeAnnotation("r20d1", 0.875));
/* instruction_annotations = */
{Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)});

BasicBlock block({instruction});
constexpr char kExpectedString[] =
"BasicBlock(instructions=InstructionList((Instruction(mnemonic='ADC', "
Expand Down
Loading

0 comments on commit 555ef25

Please sign in to comment.