forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fuse_relu.cpp
71 lines (60 loc) · 2.16 KB
/
fuse_relu.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
namespace torch {
namespace jit {
namespace {
void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
SubgraphRewriter rewriter;
std::string add_relu_0 = R"(
graph(%a, %b, %alpha):
%add_res = aten::add(%a, %b, %alpha)
%res = aten::relu(%add_res)
return (%res))";
std::string add_relu_fused = R"(
graph(%a, %b, %alpha):
%res = aten::_add_relu(%a, %b, %alpha)
return (%res))";
rewriter.RegisterRewritePattern(add_relu_0, add_relu_fused);
std::string add_relu_1 = R"(
graph(%a, %b, %alpha):
%add_res = aten::add(%a, %b, %alpha)
%res = aten::relu_(%add_res)
return (%res))";
rewriter.RegisterRewritePattern(add_relu_1, add_relu_fused);
std::string add_inplace_relu_1 = R"(
graph(%a, %b, %alpha):
%add_res = aten::add_(%a, %b, %alpha)
%res = aten::relu_(%add_res)
return (%res))";
std::string add_inplace_relu_fused = R"(
graph(%a, %b, %alpha):
%res = aten::_add_relu_(%a, %b, %alpha)
return (%res))";
rewriter.RegisterRewritePattern(add_inplace_relu_1, add_inplace_relu_fused);
std::string add_out_relu = R"(
graph(%a, %b, %alpha, %out):
%add_res = aten::add(%a, %b, %alpha, %out)
%res = aten::relu_(%add_res)
return (%res))";
std::string add_out_relu_fused = R"(
graph(%a, %b, %alpha, %out):
%res = aten::_add_relu(%a, %b, %alpha, %out)
return (%res))";
rewriter.RegisterRewritePattern(add_out_relu, add_out_relu_fused);
rewriter.runOnGraph(graph);
// NB: Patterns that are left out are add_ + relu and add_out + relu
// This is because inplace mutation of the testor done by add_ will be lost if
// inplace mutatation of the same tensor actually does add+relu
}
} // namespace
// Module entry point: applies the add+relu fusion rewrite to the graph of
// the module's "forward" method. No other method of `module` is looked up.
void FuseAddRelu(script::Module& module) {
  // fuseAddReluImpl takes a non-const lvalue reference, so the graph handle
  // has to be held in a named local before it is passed on.
  std::shared_ptr<Graph> forward_graph = module.get_method("forward").graph();
  fuseAddReluImpl(forward_graph);
}
// Graph entry point: applies the add+relu fusion rewrite directly to
// `graph` by forwarding to the file-local fuseAddReluImpl.
void FuseAddRelu(std::shared_ptr<Graph>& graph) {
fuseAddReluImpl(graph);
}
} // namespace jit
} // namespace torch