diff --git a/include/halo/lib/transforms/analyzer.h b/include/halo/lib/transforms/analyzer.h index 3dac2c1cb..dcb30dcd9 100644 --- a/include/halo/lib/transforms/analyzer.h +++ b/include/halo/lib/transforms/analyzer.h @@ -48,7 +48,9 @@ class Analyzer final : public ModulePass { std::vector> input_shape; std::vector output_shape; - float io_mem = 0; + size_t io_mem = 0; + // op fusion estimate + size_t op_fs_mem = 0; float weight_mem = 0; // Note that FLOPS and FLOPs are different: @@ -61,8 +63,11 @@ class Analyzer final : public ModulePass { }; struct TensorInfo { - size_t Liveness = 0; - size_t Size = 0; + size_t liveness = 0; + size_t op_size = 0; + size_t ip_size = 0; + size_t knl_sz = 0; + halo::OpCode op; }; Analyzer(std::ostream* os, const AnalyzerOpts& opts) @@ -82,6 +87,7 @@ class Analyzer final : public ModulePass { void RunOnInstruction(Conv2DInst* inst); void RunOnInstruction(GemmInst* inst); void RunOnInstruction(MatMulInst* inst); + void RunOnInstruction(NonMaxSuppressionInst* inst); void RunOnInstruction(PoolingMaxInst* inst); void RunOnInstruction(PoolingAvgInst* inst); void RunOnInstruction(BatchNormInst* inst); @@ -110,7 +116,8 @@ class Analyzer final : public ModulePass { std::ostream* os_; std::vector node_infos_; AnalyzerOpts opts_; - std::unordered_map AliveTensor; // live tensor buffer + // alive tensor buffer + std::unordered_map AliveTensor; }; } // namespace halo diff --git a/lib/transforms/analyzer.cc b/lib/transforms/analyzer.cc index 3ce06149c..213a11c76 100644 --- a/lib/transforms/analyzer.cc +++ b/lib/transforms/analyzer.cc @@ -18,6 +18,7 @@ #include "halo/lib/transforms/analyzer.h" #include +#include #include namespace halo { @@ -49,6 +50,7 @@ Analyzer::NodeInfo& Analyzer::GenerateCommonInfo(const Instruction* inst) { // input shape auto ip_num = inst->GetNumOfOperands(); + size_t knl_sz = 0; for (size_t i = 0; i < ip_num; ++i) { const auto& ip_type = inst->GetOperand(i).GetType(); if (ip_type.IsScalar()) { @@ -65,15 +67,27 @@ Analyzer::NodeInfo& Analyzer::GenerateCommonInfo(const Instruction* inst) { size_t size = Dl->Bytes(ip_type); if (IsA(inst->GetOperand(i))) { node_info.weight_mem += size; + + // for conv2d op fusion, save kernel size + if (node_info.type == OpCode::CONV2D && i == 1) { + const auto& wt = inst->GetOperand(i).GetType(); + const size_t dims = wt.GetNumOfDims(); + knl_sz = wt.GetNumOfElementsInDim(dims - 1); + knl_sz *= wt.GetNumOfElementsInDim(dims - 2); + } + } else { - node_info.io_mem += size; std::string name = inst->GetOperand(i).GetOwner()->GetName(); if (AliveTensor.find(name) != AliveTensor.end()) { - AliveTensor[name].Liveness--; + AliveTensor[name].liveness--; + } + if (AliveTensor[name].liveness == 0) { + node_info.io_mem += size; } } } + const size_t op_tgts = inst->GetResultsUses()[0].GetUses().size(); // todo: only output[0] is processed here const auto& op_type = inst->GetResultType(); if (op_type.IsScalar()) { @@ -81,15 +95,22 @@ Analyzer::NodeInfo& Analyzer::GenerateCommonInfo(const Instruction* inst) { } else { node_info.output_shape = op_type.GetDimSizes(); } - TensorInfo tif = {inst->GetResultsUses()[0].GetUses().size(), - Dl->Bytes(op_type)}; + TensorInfo tif = {op_tgts, op_tgts * Dl->Bytes(op_type), node_info.io_mem, + knl_sz, node_info.type}; AliveTensor[node_info.name] = tif; for (auto iter = AliveTensor.begin(); iter != AliveTensor.end();) { - if (iter->second.Liveness == 0) { + if (iter->second.liveness == 0) { iter = AliveTensor.erase(iter); } else { - node_info.io_mem += iter->second.Size; + node_info.io_mem += iter->second.op_size; + + // conv2d kernel fusion + if (node_info.name != iter->first && node_info.type == OpCode::CONV2D && + iter->second.op == OpCode::CONV2D && iter->second.knl_sz == knl_sz) { + node_info.op_fs_mem += iter->second.ip_size; + } + iter++; } } @@ -149,15 +170,21 @@ void Analyzer::RunOnInstruction(Instruction* inst) { case OpCode::RSQRT: case OpCode::FLOOR: case OpCode::SITOFP: - case OpCode::SIGMOID: { + case OpCode::SIGMOID: + case OpCode::EXP: + case OpCode::TOPK: { auto& node_info = GenerateCommonInfo(inst); node_info.flops = inst->GetResultType().GetTotalNumOfElements(); if (op_code == OpCode::SIGMOID) { node_info.flops *= 3; } + if (op_code == OpCode::TOPK) { + node_info.flops *= std::log2f(node_info.flops); + } break; } default: { + std::cout << "Error OP: " << static_cast(inst->GetOpCode()) << "\n"; HLCHECK(0 && "Unimplemented"); } } @@ -281,6 +308,17 @@ void Analyzer::RunOnInstruction(MatMulInst* inst) { node_info.flops = static_cast(row * (2 * matb_size - col)); } +void Analyzer::RunOnInstruction(NonMaxSuppressionInst* inst) { + auto& node_info = GenerateCommonInfo(inst); + + // IOU operation: 2x min, 4x max, 4x +-, 2x */ + const int iou_op = 12; + size_t num_box = inst->GetOperand(0).GetType().GetNumOfElementsInDim(1); + float sort_op = num_box * std::log2f(num_box); + + node_info.flops = static_cast(iou_op * num_box) + sort_op; +} + void Analyzer::RunOnInstruction(PoolingMaxInst* inst) { CalPoolingInst(inst); } @@ -399,7 +437,7 @@ void Analyzer::WriteCSVReport(std::ostream& os) { float total_flops = 0; float total_weights = 0; - float max_io = 0; + size_t max_io = 0; for (const auto& it : node_infos_) { total_flops += it.flops; total_weights += it.weight_mem;