forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgraph_executor.h
144 lines (112 loc) · 4.45 KB
/
graph_executor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#pragma once
#include <atomic>
#include <memory>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/variable_tensor_list.h>
// Declares the boolean c10 flag `torch_jit_enable_new_executor` for use from
// this header; the flag itself is defined in a translation unit elsewhere
// (presumably via C10_DEFINE_bool). NOTE(review): likely consulted by
// IsNewExecutorEnabled() below — confirm against graph_executor.cpp.
C10_DECLARE_bool(torch_jit_enable_new_executor);
namespace torch {
namespace jit {
// Forward declarations for types referenced before their full definitions.
struct Code;
struct GraphExecutorState;

// Selects how a GraphExecutor runs its graph. SIMPLE corresponds to running
// without profiling-driven specialization (the getPlanFor documentation on
// GraphExecutor equates a zero bailout depth with the SIMPLE executor);
// PROFILING enables profiling. Kept as an unscoped enum so the enumerators
// remain usable unqualified.
enum ExecutorExecutionMode {
  SIMPLE = 0,
  PROFILING = 1,
};
struct ExecutionPlan {
ExecutionPlan() = default;
ExecutionPlan(std::shared_ptr<Graph> graph, std::string function_name)
: code(graph, std::move(function_name)), graph(std::move(graph)) {}
operator bool() const {
return static_cast<bool>(graph);
}
Code code;
std::shared_ptr<Graph> graph;
};
// Note that these structs don't manage the lifetime of their members.
// They are valid only immediately after calling getDebugState() and must
// not be used again once any other GraphExecutor function has been called.
// Debug snapshot of a GraphExecutor's compiled state.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GraphExecutorState {
  // Non-owning pointer to the executor's graph; null unless populated.
  const Graph* graph{nullptr};
  // Fallback plan. It may be left empty, in which case it converts to false.
  ExecutionPlan fallback;
  // Compiled plans keyed by ArgumentSpec (presumably one per observed input
  // specialization — confirm against getDebugState()).
  std::unordered_map<ArgumentSpec, ExecutionPlan> execution_plans;
};
struct TORCH_API EnableProfilingGuard {
EnableProfilingGuard();
~EnableProfilingGuard();
private:
bool old_executor_mode = false;
bool old_get_optimize = false;
};
struct GraphExecutorImplBase;
struct TORCH_API GraphExecutor {
GraphExecutor() = default;
GraphExecutor(const std::shared_ptr<Graph>& graph, std::string function_name);
GraphExecutor(
const std::shared_ptr<Graph>& graph,
std::string function_name,
ExecutorExecutionMode executor_mode);
void run(Stack& inputs);
c10::intrusive_ptr<Future> runAsync(
Stack& stack,
TaskLauncher taskLauncher = at::launch);
// `remaining_bailout_depth` stands for the maximum number of profiled and
// specialized recompilations allowed for the current `GraphExecutor`. if
// remaining_bailout_depth is equal to 0, `GraphExecutor` won't perform any
// profiling and specialization. This is also equivalent to the
// SIMPLE_EXECUTOR mode. if remaining_bailout_depth is greater than 0,
// `GraphExecutor` will profile and specialize its input graph based on the
// profiled information whenever a bailout check is failed/triggered, a new
// `GraphExecutor` will be created. This new `GraphExecutor`'s
// remaining_bailout_depth will be reduced by 1.
// If no bailout depth is passed, the depth will be initialized from the
// current global fusion strategy settings.
const ExecutionPlan& getPlanFor(
Stack& inputs,
c10::optional<size_t> remaining_bailout_depth = c10::nullopt);
GraphExecutorState getDebugState();
void debugFlushCompilationCache();
bool isOptimized() const;
private:
std::shared_ptr<GraphExecutorImplBase> pImpl;
};
// Replaces block `b` with a fallback graph over `inputs` and returns the
// resulting node. (Behavior inferred from the name — see the definition.)
TORCH_API Node* replaceBlockWithFallbackGraph(
    Block* b, ArrayRef<Value*> inputs);

// Passes that must run before a graph is valid input to the interpreter,
// whether or not its sizes have been specialized.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);

// Debug switches for inlining fusion groups and autodiff subgraphs.
TORCH_API void debugSetFusionGroupInlining(bool state);
TORCH_API bool getFusionGroupInlining();
TORCH_API void debugSetAutodiffSubgraphInlining(bool state);

// The optimized graph that was most recently executed. It is exposed so that
// tests and debugging tools can inspect optimized graphs.
TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();

// Global executor configuration. The atomics are returned by reference, so
// callers can both read and update them.
TORCH_API std::atomic<bool>& getProfilingMode();
TORCH_API std::atomic<bool>& getExecutorMode();
TORCH_API std::atomic<size_t>& getNumProfiledRuns();
TORCH_API size_t getBailoutDepth();
TORCH_API bool IsNewExecutorEnabled();
struct TORCH_API GraphOptimizerEnabledGuard {
GraphOptimizerEnabledGuard(bool state)
: old_state_(getGraphExecutorOptimize()) {
setGraphExecutorOptimize(state);
}
~GraphOptimizerEnabledGuard() {
setGraphExecutorOptimize(old_state_);
}
bool old_state_;
};
namespace detail {
// Retrieve the GraphExecutor associated with an Operation. A raw pointer is
// returned (presumably non-owning — confirm in the definitions).
GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op);
GraphExecutor* getGradExecutor(Operation& op);

// Historical note (it appears to describe lastExecutedOptimizedGraph(), which
// is declared in torch::jit): for debugging, we expose the graph that was
// actually run last. Earlier approaches let callers ask a GraphExecutor which
// graph it *would* run in a given situation (graphFor). That was fragile,
// because the decision logic changes over time. Exposing the last executed
// graph still lets tests inspect optimized graphs, with less plumbing.
} // namespace detail
} // namespace jit
} // namespace torch