Skip to content

Commit

Permalink
If: WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Weiming Zhao committed Dec 3, 2021
1 parent 6aff283 commit b588a43
Show file tree
Hide file tree
Showing 15 changed files with 127 additions and 7 deletions.
3 changes: 3 additions & 0 deletions ODLA/platforms/tensorrt/odla_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ static odla_value CreateValue(T* t, const odla_value_type& type,
auto v = std::make_unique<_odla_value>(t, type, name);
auto ret = v.get();
g_comp->vals.push_back(std::move(v));
if (!g_comp->branchs.empty()) {
g_comp->branchs.top().branch->addInput(*ret);
}
return ret;
}

Expand Down
15 changes: 14 additions & 1 deletion include/halo/lib/ir/values.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <iostream>
#include <list>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -248,7 +249,16 @@ class IRObject {
}
}

/// Reset `idx`-th operand to null. The number of operands remains unchanged.
/// Append an IRObject that depends on this object.
void AddDependant(IRObject* dependent) { dependents_.insert(dependent); }

/// Return dependents
const std::set<IRObject*>& GetDependents() const noexcept {
return dependents_;
}

/// Reset `idx`-th operand to null. The number of operands remains
/// unchanged.
void ResetOperand(size_t idx);

/// Drop all operands and reset the operand counter.
Expand Down Expand Up @@ -417,6 +427,9 @@ class IRObject {
// Results' use_list
std::vector<UseList> results_uses_;

// Results' dependents
std::set<IRObject*> dependents_;

// Attribute list
std::vector<std::unique_ptr<Attribute>> attributes_;
};
Expand Down
1 change: 1 addition & 0 deletions include/halo/lib/pass/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class HL_API_EXPORT PassManager final {
Pass* AddCodeFormatterPass(std::ostringstream& code,
std::ostringstream& header,
const CXXCodeGenOpts& opts);
Pass* AddConvertTFCFGPass();
Pass* AddDCEPass();
Pass* AddDevicePlacementPass();
Pass* AddFusionPass(const FusionOptions& opts);
Expand Down
3 changes: 3 additions & 0 deletions include/halo/utils/passes_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ static void PopulateOptPasses(PassManager* pm, const std::string& target,
if (opts.enable_type_cast) {
pm->AddTypeCastPass();
}
if (format == ModelFormat::TENSORFLOW) {
pm->AddConvertTFCFGPass();
}
if (opts.constant_decombine) {
pm->AddConstantDecombinePass();
}
Expand Down
17 changes: 17 additions & 0 deletions lib/parser/tensorflow/tf_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ Status TFParser::ConvertToHaloIR(const tensorflow::GraphDef& graph_def) {
HLCHECK(graph_def.node_size() == i);
ConvertReturnNodes(ir_builder_.get(), ret_vals);

// Add control dependents.
for (const auto& kv : control_edges_) {
auto it = inst_name_to_ptr_.find(kv.first);
HLCHECK(it != inst_name_to_ptr_.end());
for (const auto& dep_name : kv.second) {
auto it_d = inst_name_to_ptr_.find(dep_name);
HLCHECK(it_d != inst_name_to_ptr_.end());
it->second->AddDependant(it_d->second);
}
}
return Status::SUCCESS;
}

Expand Down Expand Up @@ -898,6 +908,13 @@ Status TFParser::ConvertConstNode(IRBuilder* ir_builder,
const tensorflow::NodeDef& node_def) {
TFAttrs attrs(node_def);
DataType data_type = DataType::INVALID;
// Check for control deps
for (int i = 0, e = node_def.input_size(); i < e; ++i) {
const auto& dep = node_def.input(i);
HLCHECK(!dep.empty() && dep.front() == '^');
control_edges_[dep.substr(1)].push_back(node_def.name());
}

if (attrs.Process<DataType>("dtype", &data_type)) {
switch (data_type) {
case DataType::BOOL: {
Expand Down
1 change: 1 addition & 0 deletions lib/parser/tensorflow/tf_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class TFParser : public Parser {
using CallBack =
std::function<Status(IRBuilder*, const tensorflow::NodeDef&)>;
std::unordered_map<std::string, CallBack> func_lists_;
std::unordered_map<std::string, std::vector<std::string>> control_edges_;
};

/// Convert pb to ipu graphdef
Expand Down
3 changes: 3 additions & 0 deletions lib/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "halo/lib/transforms/analyzer.h"
#include "halo/lib/transforms/caffeextension_legalizer.h"
#include "halo/lib/transforms/constant_decombine.h"
#include "halo/lib/transforms/convert_tf_cfg.h"
#include "halo/lib/transforms/dce.h"
#include "halo/lib/transforms/device_placement.h"
#include "halo/lib/transforms/fusion.h"
Expand Down Expand Up @@ -265,6 +266,8 @@ Pass* PassManager::AddCodeFormatterPass(std::ostringstream& buf_code,
return AddPass<CodeFormatter>(buf_code, buf_header, opts);
}

Pass* PassManager::AddConvertTFCFGPass() { return AddPass<ConvertTFCFG>(); }

Pass* PassManager::AddDCEPass() { return AddPass<DCE>(); }

Pass* PassManager::AddDevicePlacementPass() {
Expand Down
7 changes: 7 additions & 0 deletions lib/target/generic_cpp/generic_cxx_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ bool GenericCXXCodeGen::RunOnModule(Module* module) {
}

if (entry_func != nullptr) {
entry_func->Dump();
if (module->Functions().size() > 1) {
RunOnHostFunction(*entry_func);
} else {
Expand Down Expand Up @@ -268,6 +269,9 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
auto nr_outputs = ret_inst.GetNumOfOperands();
model_info.num_outputs = nr_outputs;
for (const auto& out : ret_inst.GetOperands()) {
if (out.IsNull()) {
continue;
}
const auto& type = out.GetType();
if (ir_mapping_.find(out) == ir_mapping_.end()) {
CXXValue cv(out.GetDef()->GetName(), TensorTypeToCXXType(type, false));
Expand Down Expand Up @@ -836,6 +840,9 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
index = 0;
// Pre-launch binding.
for (auto& op : return_inst->GetOperands()) {
if (op.IsNull()) {
continue;
}
auto& cv = ir_mapping_[op];
std::string arg_name = (opts_.emit_inference_func_sig || is_sub)
? (is_sub ? "outputs.values[" : "outputs[") +
Expand Down
3 changes: 3 additions & 0 deletions lib/target/generic_cpp/return.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ namespace halo {
void GenericCXXCodeGen::RunOnInstruction(ReturnInst* inst) {
bool is_compile_mode = opts_.exec_mode == ExecMode::Compile;
for (auto& op : inst->GetOperands()) {
if (op.IsNull()) {
continue;
}
const CXXValue& val = ir_mapping_[op];
if (is_compile_mode) {
bool is_entry_with_calls =
Expand Down
1 change: 1 addition & 0 deletions lib/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(SRCS
analyzer.cc
caffeextension_legalizer.cc
constant_decombine.cc
convert_tf_cfg.cc
dce.cc
device_placement.cc
fusion.cc
Expand Down
20 changes: 17 additions & 3 deletions lib/transforms/dce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,28 @@

namespace halo {

static void RemoveLoopBody(LoopInst* loop_inst) {
auto body = loop_inst->GetBody();
auto return_inst = body->GetReturnInst();
static void RemoveBody(BasicBlock* bb) {
auto return_inst = bb->GetReturnInst();
if (return_inst != nullptr) {
// Drop all the operands of the return instruction so the rest of the body
// loop will be DCE'ed automatically.
// Note that the return inst cannot be erased because the current legalizer
// will try to append one if no return inst exists for a block.
return_inst->DropAllOperands();
if (bb->Instructions().size() == 1) {
bb->Instructions().clear();
return;
}
}
}
static void RemoveLoopBody(LoopInst* loop_inst) {
RemoveBody(loop_inst->GetBody());
}

static void RemoveIfBody(IfInst* if_inst) {
RemoveBody(if_inst->GetThenBranch());
RemoveBody(if_inst->GetElseBranch());
}

// For instructions with `undef` operands, they are unreachable except for
// `tf_merge` and optional operands.
Expand Down Expand Up @@ -85,6 +96,9 @@ bool DCE::RunOnBasicBlock(BasicBlock* bb) {
if (inst->GetOpCode() == OpCode::LOOP) {
RemoveLoopBody(DynCast<LoopInst>(inst));
}
if (inst->GetOpCode() == OpCode::IF) {
RemoveIfBody(DynCast<IfInst>(inst));
}
it = bb->Instructions().erase(it);
} else {
it = std::next(it);
Expand Down
1 change: 1 addition & 0 deletions lib/transforms/input_legalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ bool InputLegalizer::RunOnFunction(Function* func) {
: it->second.GetDimSizes();
arg->GetResultsTypes()[0] = halo::Type(dt, dims);
specified_shapes.erase(it);
changed = true;
}

auto dims = ty.GetDimSizes();
Expand Down
30 changes: 29 additions & 1 deletion lib/transforms/tfextension_legalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,43 @@ static std::vector<Def> ConvertStridedSlice(const TFExtensionInst* ext,
static std::vector<Def> ConvertSwitch(const TFExtensionInst* ext,
IRBuilder* builder) {
const auto& data = ext->GetOperand(0);
if (const Constant* pred = DynCast<Constant>(ext->GetOperand(1));
const auto& cond = ext->GetOperand(1);
#if 0
if (const Constant* pred = DynCast<Constant>(cond);
pred != nullptr) {
HLCHECK(pred->GetResultType().GetTotalNumOfElements() == 1);
bool cond = pred->GetDataAsInt64(0) != 0;
std::vector<Def> ret_true{Def::GetUndefined(), data};
std::vector<Def> ret_false{data, Def::GetUndefined()};
return cond ? ret_true : ret_false;
}
#endif
// TODO(unknown): move to separate pass?
#if 1
builder->SetInsertAfter(ext);
BasicBlockBuilder bb_builder(ext->GetParent()->GetParent());
const auto& name = ext->GetName();
auto if_inst = builder->CreateIf(ext->GetName(), cond);
if_inst->AddOneOperand(data);

BasicBlock* bb_t = bb_builder.CreateBasicBlock(name + "_true");
if_inst->SetThenBranch(bb_t);
IRBuilder builder_t(bb_t);
auto arg_builder_t = std::make_unique<ArgumentBuilder>(bb_t);
auto arg_t = arg_builder_t->CreateArgument(name + "_t", data.GetType());
builder_t.CreateReturn(name + "ret_t", *arg_t);

BasicBlock* bb_f = bb_builder.CreateBasicBlock(name + "_false");
IRBuilder builder_f(bb_f);
if_inst->SetElseBranch(bb_f);
auto arg_builder_f = std::make_unique<ArgumentBuilder>(bb_f);
auto arg_f = arg_builder_f->CreateArgument(name + "_f", data.GetType());
builder_f.CreateReturn(name + "ret_f", *arg_f);
if_inst->SetNumOfResults(2);
return {Def(if_inst, 0), Def(if_inst, 1)};
#else
return {};
#endif
}

static std::vector<Def> ConvertMerge(const TFExtensionInst* ext,
Expand Down
3 changes: 2 additions & 1 deletion lib/transforms/transforms_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ bool AppendReturnInst(BasicBlock* bb) {
std::vector<Def> outputs;
for (auto& inst_t : *bb) {
Instruction* inst = inst_t.get();
if (inst->GetNumberOfUses() == 0 && inst->GetOpCode() != OpCode::RETURN) {
if (inst->GetNumberOfUses() == 0 && inst->GetOpCode() != OpCode::RETURN &&
inst->GetDependents().empty()) {
outputs.push_back(*inst);
}
}
Expand Down
26 changes: 25 additions & 1 deletion lib/transforms/type_legalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1437,7 +1437,7 @@ static void RunOnInstruction(TFIDFVectorizeInst* inst) {

static void RunOnInstruction(ReturnInst* inst) {
std::vector<Type> types;
types.reserve(inst->GetNumOfOperands());
types.reserve(inst->GetNumOfResults());
for (auto& op : inst->GetOperands()) {
types.push_back(op.GetType());
}
Expand Down Expand Up @@ -1512,6 +1512,25 @@ static void RunOnInstruction(BitcastInst* inst) {
inst->GetResultsTypes()[0] = result_type;
}

static void RunOnInstruction(TFExtensionInst* inst) {
if (inst->GetExtOpCode() == TFExtOpCode::MERGE) {
for (auto& op : inst->GetOperands()) {
if (op.GetType().IsValid()) {
inst->GetResultsTypes()[0] = op.GetType();
return;
}
}
return;
}
if (inst->GetExtOpCode() == TFExtOpCode::SWITCH) {
const auto& ty = inst->GetOperand(0).GetType();
if (ty.IsValid()) {
inst->GetResultsTypes() = {ty, ty};
}
return;
}
}

static void RunOnInstruction(UniqueInst* inst) {
const auto& type0 = inst->GetOperand(0).GetType();
if (!type0.IsValid()) {
Expand Down Expand Up @@ -1550,6 +1569,11 @@ bool TypeLegalizer::RunOnBasicBlock(BasicBlock* bb) {
#define GET_INST_DOWNCAST_SWITCH
#include "halo/lib/ir/instructions_info.def"
#undef GET_INST_DOWNCAST_SWITCH
case OpCode::EXTENSION: {
TFExtensionInst* ext = DynCast<TFExtensionInst>(inst);
RunOnInstruction(ext);
break;
}
default: {
if (!relaxed_) {
// HLCHECK(0 && "Unreachable");
Expand Down

0 comments on commit b588a43

Please sign in to comment.