From 529c34d8642071914dfd725476e8b4d87b863c18 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Thu, 2 Dec 2021 23:45:33 -0800 Subject: [PATCH] add missing file --- include/halo/lib/transforms/convert_tf_cfg.h | 37 +++ lib/transforms/convert_tf_cfg.cc | 329 +++++++++++++++++++ 2 files changed, 366 insertions(+) create mode 100644 include/halo/lib/transforms/convert_tf_cfg.h create mode 100644 lib/transforms/convert_tf_cfg.cc diff --git a/include/halo/lib/transforms/convert_tf_cfg.h b/include/halo/lib/transforms/convert_tf_cfg.h new file mode 100644 index 000000000..6264dd738 --- /dev/null +++ b/include/halo/lib/transforms/convert_tf_cfg.h @@ -0,0 +1,37 @@ +//===- convert_tf_cfg.h ---------------------------------------------------===// +// +// Copyright (C) 2019-2020 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_ +#define HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_ + +#include "halo/lib/pass/pass.h" + +namespace halo { + +/// This pass eliminates dead IRs. +class ConvertTFCFG final : public FunctionPass { + public: + ConvertTFCFG() : FunctionPass("Convert TF CFG"), converted_(false) {} + bool RunOnFunction(Function* func) override; + + private: + bool converted_; +}; + +} // end namespace halo. + +#endif // HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_ \ No newline at end of file diff --git a/lib/transforms/convert_tf_cfg.cc b/lib/transforms/convert_tf_cfg.cc new file mode 100644 index 000000000..960882e05 --- /dev/null +++ b/lib/transforms/convert_tf_cfg.cc @@ -0,0 +1,329 @@ +//===- convert_tf_cfg.cc --------------------------------------------------===// +// +// Copyright (C) 2019-2021 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "halo/lib/transforms/convert_tf_cfg.h" + +#include +#include + +#include "halo/api/halo_data.h" +#include "halo/lib/ir/controlflow_instructions.h" +#include "halo/lib/ir/extension_instructions.h" +#include "halo/lib/ir/instruction.h" +#include "halo/lib/ir/ir_builder.h" + +namespace halo { + +static bool MergeIfs(BasicBlock* bb) { + bool changed = false; + std::unordered_map ifs; + for (const auto& it : *bb) { + IfInst* inst = DynCast(it.get()); + if (inst == nullptr) { + continue; + } + const Def& cond = inst->GetOperand(0); + if (ifs.count(cond) == 0) { + ifs[cond] = inst; + continue; + } + + IfInst* dst = ifs[cond]; + for (int i = 1, e = inst->GetNumOfOperands(); i < e; ++i) { + dst->AddOneOperand(inst->GetOperand(i)); + } + inst->DropAllOperands(); + for (unsigned i = 0, dst_idx = dst->GetNumOfResults(); + i < inst->GetNumOfResults(); ++i) { + const auto& ty = inst->GetResultsTypes()[i]; + dst->GetResultsTypes().push_back(ty); + inst->ReplaceAllUsesWith(i, Def{dst, static_cast(dst_idx++)}); + } + inst->GetThenBranch()->MoveTo(dst->GetThenBranch()); + inst->GetElseBranch()->MoveTo(dst->GetElseBranch()); + // merge arguments + auto merge_args = [](BasicBlock* bb) { + if (bb->GetNumOfOperands() == 0) { // who removed args? + return; + } + HLCHECK(bb->GetNumOfOperands() == 2); + Def arg0{bb->arg_front(), 0}; + bb->arg_back()->ReplaceAllUsesWith({arg0}); + bb->Args().pop_back(); + }; + merge_args(inst->GetThenBranch()); + merge_args(inst->GetElseBranch()); + + changed = true; + } + return changed; +} + +static void RewriteOutput(IfInst* if_inst, const std::vector& ops, + bool is_taken) { + auto bb = is_taken ? if_inst->GetThenBranch() : if_inst->GetElseBranch(); + ReturnInst* ret = bb->GetReturnInst(); + HLCHECK(ret != nullptr); + ret->DropAllOperands(); + HLCHECK(if_inst->GetNumOfResults() == (if_inst->GetNumOfOperands() - 1) * 2); + for (auto op : ops) { + // if's output: [v1_f, v1_t, v2_f, v2_t, ...], inputs: [cond, v1, v2, v3] + // Branch bb's args: [arg1, arg2, arg3] + if (op.GetOwner() == if_inst) { + op = Def{std::next(bb->arg_begin(), op.GetIdx() / 2)->get(), 0}; + } + ret->AddOneOperand(op); + } +} + +static bool RunOnBasicBlock(BasicBlock* bb) { + // run on main bb only. Fixme: need to deal with nested if. + if (bb != bb->GetParent()->begin()->get()) { + return false; + } + bool changed = false; + changed |= MergeIfs(bb); + std::unordered_map branch_bbs; + for (const auto& it : *bb) { + IfInst* inst = DynCast(it.get()); + if (inst != nullptr) { + branch_bbs[inst->GetThenBranch()] = inst; + branch_bbs[inst->GetElseBranch()] = inst; + } + } + + for (const auto& it : *bb) { + Instruction* inst = it.get(); + // tf_merge will be handled later. + if (auto ext = DynCast(inst); + ext != nullptr && ext->GetExtOpCode() == TFExtOpCode::MERGE) { + continue; + } + + /* + if (auto ext = DynCast(inst); + ext != nullptr && ext->GetExtOpCode() == TFExtOpCode::MERGE) { + // Handle tf_Merge: all the operands should come from if's output or + some + // branch's output. + IfInst* if_inst = nullptr; + int idx = 0; + for (auto& op : ext->GetOperands()) { + Instruction* op_inst = DynCast(op); + if (op_inst->GetOpCode() == OpCode::IF) { + // some branch is empty. nested if? + if (if_inst != nullptr && if_inst != op_inst) { + HLCHECK(0); + if_inst = nullptr; // merge inputs are from different "if" + break; + } + if_inst = DynCast(op_inst); + idx = op.GetIdx(); + } else { + BasicBlock* bb = op_inst->GetParent(); + auto it = branch_bbs.find(bb); + HLCHECK(it != branch_bbs.end()); + HLCHECK(if_inst == nullptr || if_inst == it->second); + if_inst = it->second; + } + } + if (if_inst != nullptr) { + // FIXME: + std::cout << "Replace with " << idx << "\n"; + // ext->ReplaceAllUsesWith(0, Def{if_inst, idx});// work as a barrier + } + continue; + } + */ + BasicBlock* new_parent = nullptr; + for (int i = 0, e = inst->GetNumOfOperands(); i < e; ++i) { + const auto& op = inst->GetOperand(i); + auto if_inst = DynCast(op); + if (if_inst != nullptr) { + int idx = op.GetIdx(); + auto bb = (idx & 1) == 0 ? if_inst->GetElseBranch() + : if_inst->GetThenBranch(); + if (new_parent == nullptr) { + new_parent = bb; + } else { + HLCHECK(new_parent == bb); + } + } else { + Instruction* op_inst = DynCast(op); + BasicBlock* op_bb = op_inst == nullptr ? nullptr : op_inst->GetParent(); + if (branch_bbs.count(op_bb) > 0) { + if (new_parent == nullptr) { + new_parent = op_bb; + } else { + HLCHECK(new_parent == op_bb); + } + } + } + } + if (new_parent != nullptr) { + IfInst* if_inst = branch_bbs[new_parent]; + HLCHECK(if_inst != nullptr); + IRBuilder new_builder(new_parent); + new_builder.SetInsertBefore(new_parent->GetReturnInst()); + std::vector operands = inst->GetOperands(); + for (auto& op : operands) { + if (op.GetOwner() == if_inst) { + op = Def{std::next(new_parent->arg_begin(), op.GetIdx() / 2)->get(), + 0}; + } + } + auto new_inst = new_builder.Clone(*inst, operands); + new_inst->GetResultsTypes() = inst->GetResultsTypes(); + HLCHECK(new_inst->GetOpCode() != OpCode::RETURN); + for (int i = 0, e = inst->GetNumOfResults(); i < e; ++i) { + inst->ReplaceAllUsesWith(i, Def{new_inst, i}); + } + changed |= true; + } /*else { + std::vector operands = inst->GetOperands(); + BasicBlock* new_parent = nullptr; + for (auto& op : operands) { + Instruction* op_inst = DynCast(op); + BasicBlock* op_bb = op_inst == nullptr ? nullptr : op_inst->GetParent(); + if (op_bb != nullptr && op_bb != bb) { + if (new_parent != nullptr && new_parent != op_bb) { + std::cerr << "unexpected parent\n"; + } + new_parent = op_bb; + } + } + if (new_parent != nullptr) { + IRBuilder new_builder(new_parent); + new_builder.SetInsertBefore(new_parent->GetReturnInst()); + auto new_inst = new_builder.Clone(*inst, operands); + new_inst->GetResultsTypes() = inst->GetResultsTypes(); + + HLCHECK(new_inst->GetOpCode() != OpCode::RETURN); + // inst->Dump(); + for (int i = 0, e = inst->GetNumOfResults(); i < e; ++i) { + inst->ReplaceAllUsesWith(i, Def{new_inst, i}); + } + changed |= true; + } + } + }*/ + } + + // Merge multiple tf_merge that associates with same if. + // All the inputs of tf_merge should associate with the same if. + std::unordered_map> if2merge; + + for (const auto& it : *bb) { + TFExtensionInst* inst = DynCast(it.get()); + if (inst == nullptr || inst->GetExtOpCode() != TFExtOpCode::MERGE) { + continue; + } + IfInst* if_inst = nullptr; + // int idx = 0; + for (auto& op : inst->GetOperands()) { + Instruction* op_inst = DynCast(op); + if (op_inst->GetOpCode() == OpCode::IF) { + // some branch is empty. nested if? + HLCHECK(if_inst == nullptr || if_inst == op_inst); + if_inst = DynCast(op_inst); + // idx = op.GetIdx(); + } else { + BasicBlock* bb = op_inst->GetParent(); + auto it = branch_bbs.find(bb); + HLCHECK(it != branch_bbs.end()); + HLCHECK(if_inst == nullptr || if_inst == it->second); + if_inst = it->second; + // Make it be the output of if. + } + } + HLCHECK(if_inst != nullptr); + if2merge[if_inst].push_back(inst); + } + for (auto& if_merge : if2merge) { + std::vector true_ops; + std::vector false_ops; + IfInst* if_inst = if_merge.first; + std::set op_indices; + for (Instruction* merge : if_merge.second) { + for (auto& op : merge->GetOperands()) { + bool is_taken = false; + if (op.GetOwner() == if_inst) { + is_taken = (op.GetIdx() & 1) == 1; + } else { + const Instruction* inst = DynCast(op.GetOwner()); + HLCHECK(inst != nullptr); + const auto& bb = inst->GetParent(); + auto it = branch_bbs.find(bb); + HLCHECK(branch_bbs.end() != it); + HLCHECK(bb == it->second->GetThenBranch() || + bb == it->second->GetElseBranch()); + is_taken = bb == it->second->GetThenBranch(); + } + if (is_taken) { + true_ops.push_back(op); + } else { + false_ops.push_back(op); + } + } + } + HLCHECK(true_ops.size() == false_ops.size()); + + RewriteOutput(if_inst, true_ops, true); + RewriteOutput(if_inst, false_ops, false); + std::vector rets; + rets.reserve(true_ops.size()); + for (int i = 0, e = true_ops.size(); i < e; ++i) { + const auto& true_ty = true_ops[i].GetType(); + const auto& false_ty = false_ops[i].GetType(); + // The output type is dynamic. Here we just pick a valid one. + rets.push_back(true_ty.IsValid() ? true_ty : false_ty); + } + if_inst->GetResultsTypes() = rets; + for (int i = 0, e = if_merge.second.size(); i < e; ++i) { + if_merge.second[i]->ReplaceAllUsesWith({Def{if_inst, i}}); + } + } + // Modify TF_Merge and associated "if": + // Before: + // if_results(true_val, false_val) = if (...) + // out = merge(if_results) + // After: + // if_result(val) = if(...) + // out = val + + return changed; +} // namespace halo + +bool ConvertTFCFG::RunOnFunction(Function* func) { + bool changed = false; + if (converted_) { + return false; + } + for (auto it = func->begin(), e = func->end(); it != e;) { + BasicBlock* bb = it->get(); + if (bb->Instructions().empty()) { + it = func->BasicBlocks().erase(it); + continue; + } + changed |= RunOnBasicBlock(bb); + it = std::next(it); + } + converted_ = true; + return changed; +} + +} // end namespace halo