Skip to content

Commit

Permalink
consider op fusion in mem analysis.
Browse files Browse the repository at this point in the history
implemented missing ops to support SSD model.
  • Loading branch information
alishenli authored and weimingzha0 committed Jul 29, 2021
1 parent 857a35e commit 7505ead
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
15 changes: 11 additions & 4 deletions include/halo/lib/transforms/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class Analyzer final : public ModulePass {
std::vector<std::vector<int64_t>> input_shape;
std::vector<int64_t> output_shape;

float io_mem = 0;
size_t io_mem = 0;
// op fusion estimate
size_t op_fs_mem = 0;
float weight_mem = 0;

// Note that FLOPS and FLOPs are different:
Expand All @@ -61,8 +63,11 @@ class Analyzer final : public ModulePass {
};

// Bookkeeping record stored per entry of the alive-tensor map (keyed by the
// name of the instruction that produced the tensor).
// NOTE(review): this span comes from a diff view without +/- markers; the
// capitalized `Liveness` / `Size` fields look like the pre-change spellings of
// `liveness` / `op_size` -- confirm against the real header before relying on
// them.
struct TensorInfo {
size_t Liveness = 0;
size_t Size = 0;
// Remaining number of consumers. Initialized from the use count of result 0,
// decremented each time the tensor is read as an operand; the map entry is
// erased once this reaches 0.
size_t liveness = 0;
// Output footprint in bytes: use count of result 0 multiplied by the byte
// size of the result type. Added to io_mem of every node visited while the
// entry is still alive.
size_t op_size = 0;
// Value of the producing node's io_mem at the moment this entry is created;
// accumulated into op_fs_mem when a conv2d fusion candidate is found.
size_t ip_size = 0;
// Product of the last two dimensions of a CONV2D's constant operand 1;
// left at 0 for every other producer.
size_t knl_sz = 0;
// Opcode of the producing instruction, used to match CONV2D producers.
halo::OpCode op;
};

Analyzer(std::ostream* os, const AnalyzerOpts& opts)
Expand All @@ -82,6 +87,7 @@ class Analyzer final : public ModulePass {
void RunOnInstruction(Conv2DInst* inst);
void RunOnInstruction(GemmInst* inst);
void RunOnInstruction(MatMulInst* inst);
void RunOnInstruction(NonMaxSuppressionInst* inst);
void RunOnInstruction(PoolingMaxInst* inst);
void RunOnInstruction(PoolingAvgInst* inst);
void RunOnInstruction(BatchNormInst* inst);
Expand Down Expand Up @@ -110,7 +116,8 @@ class Analyzer final : public ModulePass {
std::ostream* os_;
std::vector<Analyzer::NodeInfo> node_infos_;
AnalyzerOpts opts_;
std::unordered_map<std::string, TensorInfo> AliveTensor; // live tensor buffer
// alive tensor buffer
std::unordered_map<std::string, TensorInfo> AliveTensor;
};

} // namespace halo
Expand Down
54 changes: 46 additions & 8 deletions lib/transforms/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "halo/lib/transforms/analyzer.h"

#include <algorithm>
#include <cmath>
#include <iostream>

namespace halo {
Expand Down Expand Up @@ -49,6 +50,7 @@ Analyzer::NodeInfo& Analyzer::GenerateCommonInfo(const Instruction* inst) {

// input shape
auto ip_num = inst->GetNumOfOperands();
size_t knl_sz = 0;
for (size_t i = 0; i < ip_num; ++i) {
const auto& ip_type = inst->GetOperand(i).GetType();
if (ip_type.IsScalar()) {
Expand All @@ -65,31 +67,50 @@ Analyzer::NodeInfo& Analyzer::GenerateCommonInfo(const Instruction* inst) {
size_t size = Dl->Bytes(ip_type);
if (IsA<Constant>(inst->GetOperand(i))) {
node_info.weight_mem += size;

// for conv2d op fusion, save kernel size
if (node_info.type == OpCode::CONV2D && i == 1) {
const auto& wt = inst->GetOperand(i).GetType();
const size_t dims = wt.GetNumOfDims();
knl_sz = wt.GetNumOfElementsInDim(dims - 1);
knl_sz *= wt.GetNumOfElementsInDim(dims - 2);
}

} else {
node_info.io_mem += size;
std::string name = inst->GetOperand(i).GetOwner()->GetName();
if (AliveTensor.find(name) != AliveTensor.end()) {
AliveTensor[name].Liveness--;
AliveTensor[name].liveness--;
}
if (AliveTensor[name].liveness == 0) {
node_info.io_mem += size;
}
}
}

const size_t op_tgts = inst->GetResultsUses()[0].GetUses().size();
// todo: only output[0] is processed here
const auto& op_type = inst->GetResultType();
if (op_type.IsScalar()) {
node_info.output_shape.push_back(1);
} else {
node_info.output_shape = op_type.GetDimSizes();
}
TensorInfo tif = {inst->GetResultsUses()[0].GetUses().size(),
Dl->Bytes(op_type)};
TensorInfo tif = {op_tgts, op_tgts * Dl->Bytes(op_type), node_info.io_mem,
knl_sz, node_info.type};
AliveTensor[node_info.name] = tif;

for (auto iter = AliveTensor.begin(); iter != AliveTensor.end();) {
if (iter->second.Liveness == 0) {
if (iter->second.liveness == 0) {
iter = AliveTensor.erase(iter);
} else {
node_info.io_mem += iter->second.Size;
node_info.io_mem += iter->second.op_size;

// conv2d kernel fusion
if (node_info.name != iter->first && node_info.type == OpCode::CONV2D &&
iter->second.op == OpCode::CONV2D && iter->second.knl_sz == knl_sz) {
node_info.op_fs_mem += iter->second.ip_size;
}

iter++;
}
}
Expand Down Expand Up @@ -149,15 +170,21 @@ void Analyzer::RunOnInstruction(Instruction* inst) {
case OpCode::RSQRT:
case OpCode::FLOOR:
case OpCode::SITOFP:
case OpCode::SIGMOID: {
case OpCode::SIGMOID:
case OpCode::EXP:
case OpCode::TOPK: {
auto& node_info = GenerateCommonInfo(inst);
node_info.flops = inst->GetResultType().GetTotalNumOfElements();
if (op_code == OpCode::SIGMOID) {
node_info.flops *= 3;
}
if (op_code == OpCode::TOPK) {
node_info.flops *= std::log2f(node_info.flops);
}
break;
}
default: {
std::cout << "Error OP: " << static_cast<int>(inst->GetOpCode()) << "\n";
HLCHECK(0 && "Unimplemented");
}
}
Expand Down Expand Up @@ -281,6 +308,17 @@ void Analyzer::RunOnInstruction(MatMulInst* inst) {
node_info.flops = static_cast<float>(row * (2 * matb_size - col));
}

// Estimates the FLOPs of a NonMaxSuppression instruction.
//
// The box count is read from dimension 1 of operand 0 and drives two terms:
//   * IOU term:  12 operations per box
//   * sort term: num_box * log2(num_box)
// The result is stored in the node's `flops` field; everything else in the
// node record is filled in by GenerateCommonInfo().
void Analyzer::RunOnInstruction(NonMaxSuppressionInst* inst) {
  auto& node_info = GenerateCommonInfo(inst);

  // IOU operation: 2x min, 4x max, 4x +-, 2x */
  const int iou_op = 12;
  const size_t num_box =
      inst->GetOperand(0).GetType().GetNumOfElementsInDim(1);
  const auto num_box_f = static_cast<float>(num_box);

  // log2(0) is -inf and 0 * -inf is NaN, which would make flops NaN for an
  // empty box list. Sorting zero or one box costs nothing, so use 0 for
  // those cases.
  const float sort_op = num_box > 1 ? num_box_f * std::log2(num_box_f) : 0.0F;

  node_info.flops = static_cast<float>(iou_op * num_box) + sort_op;
}

// Handles max-pooling instructions by forwarding to the CalPoolingInst
// template, instantiated for PoolingMaxInst.
void Analyzer::RunOnInstruction(PoolingMaxInst* inst) {
CalPoolingInst<PoolingMaxInst>(inst);
}
Expand Down Expand Up @@ -399,7 +437,7 @@ void Analyzer::WriteCSVReport(std::ostream& os) {

float total_flops = 0;
float total_weights = 0;
float max_io = 0;
size_t max_io = 0;
for (const auto& it : node_infos_) {
total_flops += it.flops;
total_weights += it.weight_mem;
Expand Down

0 comments on commit 7505ead

Please sign in to comment.