Skip to content

Commit

Permalink
[Util] Add support for casting to ExtensionInst
Browse files Browse the repository at this point in the history
Add support for DynCast / IsA for extension instructions.
  • Loading branch information
Weiming Zhao authored and weimingzha0 committed Dec 3, 2021
1 parent 09a77e5 commit 9212669
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions include/halo/lib/ir/extension_instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ class ExtensionInst : public Instruction {
void SetOpname(const std::string& name) noexcept { opname_ = name; }
void PrintOpcode(std::ostream& os) const final { os << opname_; }
virtual ExtensionKind GetExtensionKind() const noexcept = 0;
static inline bool Classof(const IRObject* obj) {
if (!Instruction::Classof(obj)) {
return false;
}
const Instruction* inst = Downcast<const Instruction>(obj);
return inst->GetOpCode() == OpCode::EXTENSION;
}

private:
std::string opname_;
Expand All @@ -71,6 +78,15 @@ class TFExtensionInst final : public ExtensionInst {
}
std::unique_ptr<Instruction> Clone() const override;

static inline bool Classof(const IRObject* obj) {
if (!ExtensionInst::Classof(obj)) {
return false;
}
const ExtensionInst* inst = Downcast<const ExtensionInst>(obj);
return inst->GetExtensionKind() ==
ExtensionInst::ExtensionKind::kExtension_TENSORFLOW;
}

private:
// A string name to TF extension opcode map.
static const NameToTFOpMap TFMap;
Expand Down Expand Up @@ -130,6 +146,15 @@ class ONNXExtensionInst final : public ExtensionInst {
}
std::unique_ptr<Instruction> Clone() const override;

static inline bool Classof(const IRObject* obj) {
if (!ExtensionInst::Classof(obj)) {
return false;
}
const ExtensionInst* inst = Downcast<const ExtensionInst>(obj);
return inst->GetExtensionKind() ==
ExtensionInst::ExtensionKind::kExtension_ONNX;
}

private:
// A string name to extension opcode map.
static const NameToOpMap ONNXMap;
Expand All @@ -153,6 +178,15 @@ class TFLITEExtensionInst final : public ExtensionInst {
}
std::unique_ptr<Instruction> Clone() const override;

static inline bool Classof(const IRObject* obj) {
if (!ExtensionInst::Classof(obj)) {
return false;
}
const ExtensionInst* inst = Downcast<const ExtensionInst>(obj);
return inst->GetExtensionKind() ==
ExtensionInst::ExtensionKind::kExtension_TFLITE;
}

private:
// A string name to extension opcode map.
static const NameToOpMap TFLITEMap;
Expand All @@ -176,6 +210,15 @@ class CAFFEExtensionInst final : public ExtensionInst {
}
std::unique_ptr<Instruction> Clone() const override;

static inline bool Classof(const IRObject* obj) {
if (!ExtensionInst::Classof(obj)) {
return false;
}
const ExtensionInst* inst = Downcast<const ExtensionInst>(obj);
return inst->GetExtensionKind() ==
ExtensionInst::ExtensionKind::kExtension_CAFFE;
}

private:
// A string name to extension opcode map.
static const NameToOpMap CAFFEMap;
Expand Down

0 comments on commit 9212669

Please sign in to comment.