Skip to content

Commit

Permalink
Support fuse bn into ConvTranspose.
Browse files Browse the repository at this point in the history
  • Loading branch information
wenyuchi.wyc committed Nov 22, 2022
1 parent 74fdf9c commit 44658d1
Showing 1 changed file with 55 additions and 4 deletions.
59 changes: 55 additions & 4 deletions onnxoptimizer/passes/fuse_bn_into_conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,42 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
}
}

bool modify_conv(Node* conv, Node* bn, Graph& graph) {
// Multiplies W in place by the 1-D tensor `s`, broadcasting `s` along
// dimension `axis` of W: every element whose index along `axis` is j is
// multiplied by s[j].
//
// Requirements (enforced with ONNX_ASSERT):
//   * 0 <= axis < rank(W), and rank(W) > 1
//   * s is 1-D with s.sizes()[0] == W.sizes()[axis]
//   * W and s have the same element type
// Only FLOAT and DOUBLE element types are handled; any other type trips
// TENSOR_ASSERTM.
//
// The buffer returned by W.data<T>() is walked linearly as
// [outer][axis][inner], i.e. it is treated as contiguous and row-major.
// Callers in this pass use axis 0 for Conv and axis 1 for ConvTranspose.
void scale_by_dim(Tensor& W, Tensor& s, const int axis) {
  const auto& w_sizes = W.sizes();
  const auto& s_sizes = s.sizes();

  // Validate `axis` before it is used as an index into the shape vector.
  ONNX_ASSERT(axis >= 0 && static_cast<size_t>(axis) < w_sizes.size());
  ONNX_ASSERT(w_sizes.size() > 1 && s_sizes.size() == 1 && s_sizes[0] == w_sizes[axis]);
  ONNX_ASSERT(s.elem_type() == W.elem_type());

  // Number of elements spanned by the dimensions after `axis`.
  const int64_t inner_size = W.size_from_dim(axis + 1);
  // Product of the dimensions before `axis` (1 when axis == 0). The
  // accumulator and the functor are both int64_t so that large shapes are not
  // truncated to int while multiplying.
  const int64_t outer_size = std::accumulate(
      w_sizes.begin(), w_sizes.begin() + axis, int64_t{1},
      std::multiplies<int64_t>());
  const int64_t axis_size = w_sizes[axis];

#define DO_SCALE(TENSOR_TYPE)                             \
  TENSOR_TYPE* ptr = W.data<TENSOR_TYPE>();               \
  const TENSOR_TYPE* s_ptr = s.data<TENSOR_TYPE>();       \
  int64_t counter = 0;                                    \
  for (int64_t i = 0; i < outer_size; ++i) {              \
    for (int64_t j = 0; j < axis_size; ++j) {             \
      for (int64_t k = 0; k < inner_size; ++k) {          \
        ptr[counter++] *= s_ptr[j];                       \
      }                                                   \
    }                                                     \
  }

  switch (s.elem_type()) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
      DO_SCALE(float)
      break;
    }
    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
      DO_SCALE(double)
      break;
    }
    default:
      TENSOR_ASSERTM(
          false, "Operation scale_by_dim not supported for data type %s", to_string(W.elem_type()).c_str());
  }
#undef DO_SCALE
}

bool modify_conv(Node* conv, Node* bn, Graph& graph, const bool is_conv) {
const auto& bn_inputs = bn->inputs();
const auto& conv_inputs = conv->inputs();
auto end_iter = graph.initializers().end();
Expand Down Expand Up @@ -136,7 +171,6 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
var.add(eps); \
var.sqrt(); \
s.divide(var); \
W.scale_by_first_dim(s); \
bc.subtract(m); \
bc.multiply(s); \
bc.add(bbn);
Expand All @@ -154,21 +188,38 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
return false;
}
#undef DO_COMPUTATION
if (is_conv) {
scale_by_dim(W, s, 0);
} else {
scale_by_dim(W, s, 1);
}
replace_inputs(W, bc, conv, graph);
return true;
}

bool patternMatchPredicate(Node* node) override {
// Returns true when `node` is a BatchNormalization node whose first input is
// produced by a Conv node.
inline bool matchConvBn(Node *node) {
  if (node->kind() != kBatchNormalization) {
    return false;
  }
  Node *producer = node->inputs()[0]->node();
  return producer->kind() == kConv;
}

// Returns true when `node` is a BatchNormalization node whose first input is
// produced by a ConvTranspose node.
inline bool matchConvTransposeBn(Node *node) {
  if (node->kind() != kBatchNormalization) {
    return false;
  }
  Node *producer = node->inputs()[0]->node();
  return producer->kind() == kConvTranspose;
}

// The pass applies to a BatchNormalization node fed by either a Conv or a
// ConvTranspose node; the Conv pattern is tested first.
bool patternMatchPredicate(Node *node) override {
  if (matchConvBn(node)) {
    return true;
  }
  return matchConvTransposeBn(node);
}

bool runTransform(Node* n, Graph& graph,
NodeDestroyType& destroy_current) override {
const bool is_conv = matchConvBn(n);

Node* bn = n;
Node* conv = n->inputs()[0]->node();
auto origInput = bn->inputs()[0];
if (origInput->uses().size() > 1 || bn->outputs().size() > 1 ||
!modify_conv(conv, bn, graph)) {
!modify_conv(conv, bn, graph, is_conv)) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
Expand Down

0 comments on commit 44658d1

Please sign in to comment.