diff --git a/.gitignore b/.gitignore index 070f28e689..2feb324b11 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,6 @@ train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz train-images-idx3-ubyte train-labels-idx1-ubyte + +# Logs +logs/ diff --git a/include/flexflow/config.h b/include/flexflow/config.h index 288a119ba7..d82b1377c7 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -117,6 +117,7 @@ class FFConfig { int epochs, batchSize, printFreq; // int inputHeight, inputWidth; int numNodes, cpusPerNode, workersPerNode; + float device_mem; // The device (GPU) memory threshold; given by -ll:fsize float learningRate, weightDecay; size_t workSpaceSize; Legion::Context lg_ctx; @@ -155,6 +156,7 @@ class FFConfig { int base_optimize_threshold; bool enable_control_replication; int python_data_loader_type; + bool perform_memory_search{false}; }; class FFIterationConfig { diff --git a/include/flexflow/graph.h b/include/flexflow/graph.h index e52911853e..2e0cf1ca4b 100644 --- a/include/flexflow/graph.h +++ b/include/flexflow/graph.h @@ -17,6 +17,7 @@ #define _FLEXFLOW_GRAPH_H_ #include "flexflow/basic_graph.h" #include "flexflow/graph_structures.h" +#include "flexflow/memory_optimization.h" #include "flexflow/model.h" #include "flexflow/utils/dot/dot_file.h" #include "flexflow/utils/recursive_logger.h" @@ -114,6 +115,31 @@ struct GraphCostResult { friend std::ostream &operator<<(std::ostream &, GraphCostResult const &); }; +/** + * @brief Holds the cost information of a PCG. + */ +struct GraphCostResultWithMemory { + float cost; ///< Run time cost + MemoryUsage mem_cost; ///< Memory usage + ///< Corresponding machine views (device placement views) + std::unordered_map views; + + /** + * @brief Get the multi-objective cost that combines the run time and memory + * cost. + * + * @return float Numerical value to represent the overall cost + */ + float get_multi_obj_cost() const; + + static GraphCostResultWithMemory invalid(); + + bool operator<(GraphCostResultWithMemory const &other) const; + + friend std::ostream &operator<<(std::ostream &, + GraphCostResultWithMemory const &); +}; + template T sequence_cost(T const &first, T const &second); @@ -157,6 +183,12 @@ class SearchHelper { NodeAssignment const &source, NodeAssignment const &sink, MachineResource const &resources) const; + /** + * @brief Starting point to get parallel split time cost. + * + * @tparam T float or GraphCostResult (or GraphCostResultWithMemory in memory + * optimization) + */ template T find_optimal_nonsequence_graph_time(Graph const *g, NodeAssignment const &source, @@ -200,6 +232,20 @@ class SearchHelper { template void add_operator_cost(NodeAssignment const &, float, T *) const; + template + void add_sink_node_costs(NodeAssignment const &sink, + CostMetrics metrics, + T *result) const; + + /** + * @brief Add run time cost and memory cost of the operator to the graph cost. + * This is a temp workaround and should be refactored eventually. 
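+ * + * @param node Node and its assigned machine view + * @param node_run_time_cost Run time cost of the operator on that view + * @param node_mem_cost Memory usage of the operator + * @param cost Graph cost result updated in place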
+ */ + void add_operator_cost_with_memory(NodeAssignment const &node, + float node_run_time_cost, + MemoryUsage node_mem_cost, + GraphCostResultWithMemory *cost) const; + template float get_cost(T const &) const; @@ -209,6 +255,8 @@ class SearchHelper { public: mutable std::unique_ptr logger; + void clear_cache(); + private: template T execute_nonsequence_split(std::unique_ptr const &first_graph, @@ -260,6 +308,7 @@ class Graph { Graph subgraph(std::unordered_set const &nodes) const; void contract_out_node(Node const &); float optimal_cost() const; + float optimal_cost_with_memory(float run_time_cost_factor) const; std::unordered_map optimal_views() const; void remove_input_nodes(); void duplicate_input_node(Node const &); @@ -335,6 +384,20 @@ struct GraphOptimizeResult { friend std::ostream &operator<<(std::ostream &, GraphOptimizeResult const &); }; +/** + * @brief Hold the optimization results with memory information. + */ +struct GraphOptimizeResultWithMemory { + tl::optional graph; ///< Optimized PCG + float cost; ///< Run time cost + MemoryUsage mem_cost; ///< Memory usage + ///< Corresponding machine views (device placement views) + std::unordered_map views; + + friend std::ostream &operator<<(std::ostream &, + GraphOptimizeResultWithMemory const &); +}; + namespace Utils { template <> struct GraphStructure { diff --git a/include/flexflow/memory_optimization.h b/include/flexflow/memory_optimization.h new file mode 100644 index 0000000000..4cb0d670b9 --- /dev/null +++ b/include/flexflow/memory_optimization.h @@ -0,0 +1,107 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _FLEXFLOW_MEMORY_OPTIMIZATION_H_ +#define _FLEXFLOW_MEMORY_OPTIMIZATION_H_ + +#include +#include + +namespace FlexFlow { + +enum class MemoryUsageType { + // Use global memory of a PCG as the measure of memory usage. No device + // mapping consideration. + GLOBAL, + + // Use the max of peak per-device memory usage among devices as the measure. + // Need associated device mapping views. + PER_DEVICE_MAX, +}; + +enum class MemorySearchAlgo { + // Multiple objective DP search. Combine memory cost and run time cost into + // one single cost function and add a factor to balance them. + MULTI_OBJECTIVE, +}; + +/** + * @brief Config class to control memory optimizations. This should be put into + * config.h and be stored in FFConfig. But for easy turnaround, put this here + * for now. 
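+ * + * A minimal usage sketch (the values below are illustrative, not recommended + * settings): + * @code + * MemoryOptimConfig config; // defaults: GLOBAL usage, MULTI_OBJECTIVE, 0.5 + * MemoryOptimConfig biased{0.9}; // weight run time cost at 0.9, memory at 0.1 + * @endcode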
+ */ +class MemoryOptimConfig { +public: + MemoryUsageType mem_usage_type; ///< How to represent memory cost + MemorySearchAlgo mem_search_algo; ///< How to search for the optimal schedule + float run_time_cost_factor; ///< The weight factor of run time cost in the + ///< overall cost function; used in + ///< MULTI_OBJECTIVE algorithm + ///< Valid between and including 0 and 1 + + MemoryOptimConfig() + : mem_usage_type{MemoryUsageType::GLOBAL}, + mem_search_algo{MemorySearchAlgo::MULTI_OBJECTIVE}, + run_time_cost_factor{0.5} {} + MemoryOptimConfig(float factor) + : mem_usage_type{MemoryUsageType::GLOBAL}, + mem_search_algo{MemorySearchAlgo::MULTI_OBJECTIVE}, + run_time_cost_factor{factor} {} +}; + +/** + * @brief Hold the result (including memory information) of a graph_optimize on + * a PCG. + */ +class MemorySearchResult { +public: + float run_time_cost{}; + float memory_cost{}; + float search_time{}; + ///< The max of per-device memory usage among all devices + float max_per_device_mem_all_deivces = 0.0; +}; + +namespace PCG { + +/** + * @brief Class to hold memory usage information of a (sub-)PCG. + */ +class MemoryUsage { +public: + MemoryUsageType usage_type; ///< What "num" means + float num; ///< The numerical number of memory usage + + MemoryUsage() : usage_type{MemoryUsageType::GLOBAL}, num{0.0} {} + MemoryUsage(MemoryUsageType _usage_type, float _num) + : usage_type{_usage_type}, num{_num} {} + + std::string to_string() const; + + MemoryUsage &operator+=(MemoryUsage const &rhs); + + /** + * @brief Combine the memory usage of two PCGs flexibly based on + * MemoryUsageType. + */ + friend MemoryUsage operator+(MemoryUsage lhs, MemoryUsage const &rhs); + + friend std::ostream &operator<<(std::ostream &s, MemoryUsage const &usage); +}; + +} // namespace PCG +} // namespace FlexFlow + +#endif // _FLEXFLOW_MEMORY_OPTIMIZATION_H_ diff --git a/include/flexflow/model.h b/include/flexflow/model.h index c6bc6929ad..01733228ba 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -17,6 +17,7 @@ #include "accessor.h" #include "config.h" #include "device.h" +#include "flexflow/memory_optimization.h" #include "flexflow/node.h" #include "flexflow/operator_params.h" #include "flexflow/utils/hash_utils.h" @@ -784,6 +785,13 @@ class FFModel { bool only_data_parallel, std::unique_ptr &best_graph, std::unordered_map &optimal_view); + void graph_optimize(size_t budget, + bool only_data_parallel, + std::unique_ptr &best_graph, + std::unordered_map &optimal_view, + bool perform_memory_search, + MemoryOptimConfig new_config, + MemorySearchResult &search_result); void mcmc_optimize(std::map &best, size_t budget, float alpha, @@ -821,6 +829,11 @@ class FFModel { public: void set_iteration_config_sequence_length(int seq_length); + /** + * @brief Clear the cache of the GraphSearchHelper and SearchHelper. 
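+ * @details Called before each search trial (see try_one_lambda in + * src/runtime/graph.cc) so that costs cached under one run time cost factor + * do not leak into the next trial.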
+ */ + void clear_graph_search_cache(); + public: size_t op_global_guid, layer_global_guid; size_t tensor_global_guid, parallel_tensor_global_guid, node_global_guid; diff --git a/include/flexflow/parallel_tensor.h b/include/flexflow/parallel_tensor.h index aed99c8204..db77b49030 100644 --- a/include/flexflow/parallel_tensor.h +++ b/include/flexflow/parallel_tensor.h @@ -63,9 +63,10 @@ struct ParallelDim { return false; } - int size = 0; - int degree = UNKNOWN_DEGREE; - int parallel_idx = UNKNOWN_INDEX; + int size = 0; // Actual size of tensor + int degree = UNKNOWN_DEGREE; // Degree of sharding + int parallel_idx = UNKNOWN_INDEX; // Runtime information, unique id of each + // degree of sharding bool is_replica_dim = false; }; diff --git a/include/flexflow/simulator.h b/include/flexflow/simulator.h index 990dee28e0..9ee1b1eb09 100644 --- a/include/flexflow/simulator.h +++ b/include/flexflow/simulator.h @@ -53,10 +53,17 @@ class FFModel; */ struct CostMetrics { /** - * @brief Return the sum of the memory usage recorded in this CostMetrics. + * @brief Return the sum of inputs_memory, outputs_memory, and weights_memory + * recorded in this CostMetrics. */ size_t total_memory() const; + /** + * @brief Return the sum of memory recorded in this CostMetrics, but in MB, + * instead of Bytes. + */ + float total_memory_in_mb() const; + /** * @brief Get the incremental difference between the total memory in * CostMetrics and sim->offset. @@ -76,6 +83,8 @@ struct CostMetrics { // 2. we call Simulator::free_all before measuring an operator // Therefore, the current memory usage of an operator is (size_t)sim->offset size_t inputs_memory = 0, outputs_memory = 0, weights_memory = 0; + ///< Memory usage of Op* considering parallelization over devices + size_t op_total_mem = 0; }; class Device { diff --git a/include/flexflow/substitution.h b/include/flexflow/substitution.h index f78b70822a..669044825d 100644 --- a/include/flexflow/substitution.h +++ b/include/flexflow/substitution.h @@ -128,6 +128,18 @@ class GraphCompare { } }; +class GraphCompareWithMemory { +public: + GraphCompareWithMemory(float factor) : run_time_cost_factor{factor} {} + bool operator()(Graph *lhs, Graph *rhs) { + return lhs->optimal_cost_with_memory(run_time_cost_factor) > + rhs->optimal_cost_with_memory(run_time_cost_factor); + } + +private: + float run_time_cost_factor; +}; + class GraphXferMatch { public: GraphXferMatch(GraphXfer const *); @@ -203,15 +215,17 @@ class GraphXfer { std::string get_name() const; - void run(int depth, - Graph *graph, - std::priority_queue, GraphCompare> &, - std::unordered_set &, - float threshold, - int maxNumOps, - SimplificationSettings const &simplification_settings, - int &num_matches_found, - int &num_matches_rejected); + template + void + run(int depth, + Graph *graph, + std::priority_queue, GraphComparator> &, + std::unordered_set &, + float threshold, + int maxNumOps, + SimplificationSettings const &simplification_settings, + int &num_matches_found, + int &num_matches_rejected); void find_matches(Graph const *, std::vector &matches); GraphXferMatch get_match_record(Graph const *) const; @@ -239,11 +253,26 @@ class GraphSearchHelper { bool only_data_parallel, std::unique_ptr &best_graph, std::unordered_map &optimal_views); + void graph_optimize_with_memory( + size_t budget, + bool only_data_parallel, + std::unique_ptr &best_graph, + std::unordered_map &optimal_views, + MemorySearchResult &search_result); void graph_optimize_no_split( size_t budget, bool only_data_parallel, std::unique_ptr 
&best_graph, std::unordered_map &optimal_views); + /** + * @brief Substitute the mem_config with new_config. + */ + void update_mem_optim_config(MemoryOptimConfig const &new_config); + + /** + * @brief Clear the optimized graph cache of this helper. + */ + void clear_cache(); private: template @@ -253,6 +282,13 @@ class GraphSearchHelper { tl::optional const &output_shape, tl::optional const &input_shape); + template + T generic_sequence_optimize_with_memory( + Graph const *graph, + Node const &sink_node, + tl::optional const &output_shape, + tl::optional const &input_shape); + float sequence_optimize(Graph const *graph, Node const &sink_node, tl::optional const &output_shape, @@ -267,6 +303,16 @@ class GraphSearchHelper { Node const &sink_node, Node const &bottleneck, ParallelTensorShape const &bottleneck_output_shape); + template + T execute_sequence_split_with_memory( + std::unique_ptr const &pre_graph, + std::unique_ptr const &post_graph, + tl::optional const &output_shape, + tl::optional const &input_shape, + Node const &sink_node, + Node const &bottleneck, + ParallelTensorShape const &bottleneck_output_shape); + void generate_all_pcg_xfers(); void load_graph_substitutions(std::vector &xfers) const; Graph *construct_graph(); @@ -276,6 +322,9 @@ class GraphSearchHelper { base_optimize(Graph const *, SimplificationSettings const &simplification_settings); + std::unique_ptr base_optimize_with_memory( + Graph const *, SimplificationSettings const &simplification_settings); + std::vector possible_split_output_tensor_shapes(Node const &) const; @@ -298,6 +347,7 @@ class GraphSearchHelper { std::vector all_pcg_xfers; FFModel *model; FFConfig const &config; + MemoryOptimConfig mem_config; std::unique_ptr logger; }; diff --git a/src/loss_functions/loss_functions.cu b/src/loss_functions/loss_functions.cu index 01766347b0..f78311980c 100644 --- a/src/loss_functions/loss_functions.cu +++ b/src/loss_functions/loss_functions.cu @@ -122,19 +122,17 @@ void Loss::mean_squared_error_avg_loss_backward_kernel_wrapper( logit_grad_ptr, logit_grad_volume, 0, scale_factor); } -void Loss::identity_loss_backward_kernel_wrapper( - float *loss_grad_ptr, - float const *loss_ptr, - size_t loss_volume, - size_t loss_grad_volume, - float scale_factor) { +void Loss::identity_loss_backward_kernel_wrapper(float *loss_grad_ptr, + float const *loss_ptr, + size_t loss_volume, + size_t loss_grad_volume, + float scale_factor) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); identity_loss_backward<<>>( - loss_grad_ptr, loss_ptr, loss_volume); + stream>>>(loss_grad_ptr, loss_ptr, loss_volume); // Scale logit gradients by loss->scale_factor scale_kernel<<>>( loss_grad_ptr, loss_grad_volume, 0, scale_factor); diff --git a/src/parallel_ops/fused_parallel_op.cc b/src/parallel_ops/fused_parallel_op.cc index c90e292f11..c0a97bdda1 100644 --- a/src/parallel_ops/fused_parallel_op.cc +++ b/src/parallel_ops/fused_parallel_op.cc @@ -252,6 +252,11 @@ bool FusedParallelOp::measure_operator_cost(Simulator *sim, cost_metrics = CostMetrics(); cost_metrics.forward_time = 0.1f; cost_metrics.backward_time = 0.1f; + + cost_metrics.sync_time = 0; + cost_metrics.inputs_memory = 0; + cost_metrics.outputs_memory = 0; + cost_metrics.weights_memory = 0; return true; } diff --git a/src/parallel_ops/partition.cc b/src/parallel_ops/partition.cc index 3ff02db766..727ffd3264 100644 --- a/src/parallel_ops/partition.cc +++ b/src/parallel_ops/partition.cc @@ -215,6 +215,11 @@ bool Repartition::measure_operator_cost(Simulator *sim, cost_metrics = 
CostMetrics(); cost_metrics.forward_time = 0.0f; cost_metrics.backward_time = 0.0f; + + cost_metrics.sync_time = 0; + cost_metrics.inputs_memory = 0; + cost_metrics.outputs_memory = 0; + cost_metrics.weights_memory = 0; return true; } diff --git a/src/parallel_ops/reduction.cc b/src/parallel_ops/reduction.cc index 61b4d4626d..737f86239c 100644 --- a/src/parallel_ops/reduction.cc +++ b/src/parallel_ops/reduction.cc @@ -173,6 +173,11 @@ bool Reduction::measure_operator_cost(Simulator *sim, cost_metrics = CostMetrics(); cost_metrics.forward_time = 0.0f; cost_metrics.backward_time = 0.0f; + + cost_metrics.sync_time = 0; + cost_metrics.inputs_memory = 0; + cost_metrics.outputs_memory = 0; + cost_metrics.weights_memory = 0; return true; } diff --git a/src/parallel_ops/replicate.cc b/src/parallel_ops/replicate.cc index 031166e63e..fee78043bd 100644 --- a/src/parallel_ops/replicate.cc +++ b/src/parallel_ops/replicate.cc @@ -196,6 +196,11 @@ bool Replicate::measure_operator_cost(Simulator *sim, cost_metrics = CostMetrics(); cost_metrics.forward_time = 0.0f; cost_metrics.backward_time = 0.0f; + + cost_metrics.sync_time = 0; + cost_metrics.inputs_memory = 0; + cost_metrics.outputs_memory = 0; + cost_metrics.weights_memory = 0; return true; } diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index ad298f5c93..c6c6aeb9a0 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -90,6 +90,9 @@ SearchHelper::SearchHelper(FFModel *model) : model(model) { this->logger = std::unique_ptr<RecursiveLogger>(new RecursiveLogger("DP")); } +/** + * @brief Combine results from sequential sub-problems. + */ template <typename T> T SearchHelper::execute_sequence_split(std::unique_ptr<Graph> const &pre_graph, std::unique_ptr<Graph> const &post_graph, @@ -102,6 +105,12 @@ T SearchHelper::execute_sequence_split(std::unique_ptr<Graph> const &pre_graph, this->graph_cost(post_graph.get(), bn, sink, resources, false)); } +/** + * @brief Starting point to get sequential split time cost.
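+ * @details Evaluates candidate machine views for the bottleneck node bn_node + * and keeps the cheapest combination of the pre- and post-subgraph costs + * produced by execute_sequence_split above.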
+ * + * @tparam T float or GraphCostResult (or GraphCostResultWithMemory in memory + * optimization) + */ template <typename T> T SearchHelper::find_optimal_sequence_graph_time( Graph const *g, @@ -170,6 +179,11 @@ T SearchHelper::find_optimal_sequence_graph_time( return optimal; } +void SearchHelper::clear_cache() { + cached_graph_costs.clear(); + cached_operator_valid_views.clear(); +} + template <typename T> T SearchHelper::execute_nonsequence_split( std::unique_ptr<Graph> const &first_graph, @@ -1134,20 +1148,46 @@ GraphCostResult GraphCostResult::invalid() { return {std::numeric_limits<float>::infinity(), {}}; } +GraphCostResultWithMemory GraphCostResultWithMemory::invalid() { + return {std::numeric_limits<float>::infinity(), MemoryUsage{}, {}}; +} + +float GraphCostResultWithMemory::get_multi_obj_cost() const { + return this->cost; +} + bool GraphCostResult::operator<(GraphCostResult const &other) const { return this->cost < other.cost; } +bool GraphCostResultWithMemory::operator<( + GraphCostResultWithMemory const &other) const { + return this->get_multi_obj_cost() < other.get_multi_obj_cost(); +} + std::ostream &operator<<(std::ostream &s, GraphCostResult const &r) { s << "GraphCostResult{cost=" << r.cost << "}"; return s; } +std::ostream &operator<<(std::ostream &s, GraphCostResultWithMemory const &r) { + s << "GraphCostResultWithMemory{run_time_cost=" << r.cost + << ", memory_cost=" << r.mem_cost << "}"; + return s; +} + std::ostream &operator<<(std::ostream &s, GraphOptimizeResult const &r) { s << "GraphOptimizeResult{cost=" << r.cost << "}"; return s; } +std::ostream &operator<<(std::ostream &s, + GraphOptimizeResultWithMemory const &r) { + s << "GraphOptimizeResultWithMemory{run_time_cost=" << r.cost + << ", memory_cost=" << r.mem_cost << "}"; + return s; +} + template <> GraphCostResult sequence_cost(GraphCostResult const &first, GraphCostResult const &second) { @@ -1157,6 +1197,17 @@ GraphCostResult sequence_cost(GraphCostResult const &first, result.views.insert(second.views.cbegin(), second.views.cend()); return result; } +template <> +GraphCostResultWithMemory sequence_cost( + GraphCostResultWithMemory const &first, + GraphCostResultWithMemory const &second) { + GraphCostResultWithMemory result{first}; + result.cost += second.cost; + result.mem_cost += second.mem_cost; + result.views.insert(second.views.cbegin(), second.views.cend()); + return result; +} + template <> float sequence_cost(float const &first, float const &second) { return first + second; } @@ -1177,6 +1228,30 @@ GraphOptimizeResult return result; } +/** + * @brief Specialization of sequence_cost to combine two + * GraphOptimizeResultWithMemory. This reuses the parts of combining run time + * costs. This should be merged with other versions of sequence_cost.
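+ * + * Run time costs add up, machine views are merged, the two sub-PCGs are + * stitched back together with replace_subgraph, and memory costs are combined + * by MemoryUsage::operator+ (summed for GLOBAL, maxed for PER_DEVICE_MAX).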
+ */ +template <> +GraphOptimizeResultWithMemory sequence_cost( + GraphOptimizeResultWithMemory const &first, + GraphOptimizeResultWithMemory const &second) { + GraphOptimizeResultWithMemory result; + result.cost = first.cost + second.cost; + result.views.insert(first.views.cbegin(), first.views.cend()); + result.views.insert(second.views.cbegin(), second.views.cend()); + + result.graph = second.graph; + Node second_src = result.graph.value().find_source_node(); + result.graph.value().replace_subgraph({second_src}, first.graph.value()); + + // New: Combine memory cost + result.mem_cost = first.mem_cost + second.mem_cost; + + return result; +} + template <> GraphCostResult parallel_cost(GraphCostResult const &first, GraphCostResult const &second) { @@ -1188,6 +1263,19 @@ GraphCostResult parallel_cost(GraphCostResult const &first, return result; } +template <> +GraphCostResultWithMemory parallel_cost( + GraphCostResultWithMemory const &first, + GraphCostResultWithMemory const &second) { + GraphCostResultWithMemory result; + result.cost = std::max(first.cost, second.cost); + result.mem_cost = first.mem_cost + second.mem_cost; + result.views.insert(first.views.cbegin(), first.views.cend()); + result.views.insert(second.views.cbegin(), second.views.cend()); + + return result; +} + template <> float parallel_cost(float const &first, float const &second) { return std::max(first, second); @@ -1204,6 +1292,12 @@ bool SearchHelper::is_invalid( return cost.cost == std::numeric_limits::infinity(); } +template <> +bool SearchHelper::is_invalid( + GraphCostResultWithMemory const &cost) const { + return cost.cost == std::numeric_limits::infinity(); +} + /** * @brief Asserts that the results of graph optimization are valid for the graph * @@ -1232,6 +1326,28 @@ void SearchHelper::check_matches_graph( assert(g_nodes == r_nodes); } +template <> +void SearchHelper::check_matches_graph( + Graph const *g, + GraphCostResultWithMemory const &r, + Node const &sink) const { + using FlexFlow::PCG::Utils::nodes; + + if (this->is_invalid(r)) { + return; + } + + std::unordered_set g_nodes = nodes(*g); + g_nodes.erase(sink); + + std::unordered_set r_nodes; + for (auto const &kv : r.views) { + r_nodes.insert(kv.first); + } + + assert(g_nodes == r_nodes); +} + template <> void SearchHelper::check_matches_graph(Graph const *g, float const &r, @@ -1253,6 +1369,13 @@ std::pair return {false, GraphCostResult::invalid()}; } +template <> +std::pair + SearchHelper::try_get_cost_from_cache( + size_t hash) const { + return {false, GraphCostResultWithMemory::invalid()}; +} + template <> void SearchHelper::try_cache_result(size_t hash, float const &value) const { @@ -1268,6 +1391,14 @@ void SearchHelper::try_cache_result( this->cached_graph_costs[hash] = value.cost; } +template <> +void SearchHelper::try_cache_result( + size_t hash, GraphCostResultWithMemory const &value) const { + this->logger->debug() << "cached_graph_costs[" << hash << "=" + << value.get_multi_obj_cost() << "]"; + this->cached_graph_costs[hash] = value.get_multi_obj_cost(); +} + template <> float SearchHelper::infinity() const { return std::numeric_limits::infinity(); @@ -1278,6 +1409,15 @@ GraphCostResult SearchHelper::infinity() const { return {std::numeric_limits::infinity(), {}}; } +template <> +GraphCostResultWithMemory + SearchHelper::infinity() const { + return {std::numeric_limits::infinity(), + MemoryUsage(MemoryUsageType::GLOBAL, + std::numeric_limits::infinity()), + {}}; +} + template <> float SearchHelper::empty() const { return 0.0f; @@ -1288,6 
+1428,12 @@ GraphCostResult SearchHelper::empty() const { return {0.0f, {}}; } +template <> +GraphCostResultWithMemory + SearchHelper::empty() const { + return {0.0f, MemoryUsage{}, {}}; +} + template T SearchHelper::estimate_xfer_cost(Graph const *graph, NodeAssignment const &source, @@ -1318,6 +1464,46 @@ T SearchHelper::estimate_xfer_cost(Graph const *graph, return result; } +/** + * @brief Specialization to avoid changing many calls. + * @details Note that this function is only called when the graph has no more + * than 2 nodes + */ +template <> +GraphCostResultWithMemory + SearchHelper::estimate_xfer_cost( + Graph const *graph, + NodeAssignment const &source, + NodeAssignment const &sink) const { + GraphCostResultWithMemory result = this->empty(); + + if (source.node != Node::INVALID_NODE) { + // Get the in-edges of the sink node + auto const &inList = graph->inEdges.find(sink.node)->second; + float op_cost = 0.0f; // run time cost + for (auto const &it2 : inList) { + // For all edges between source node and sink node + assert(it2.srcOp == source.node); + assert(sink.node.ptr->inputs[it2.dstIdx]->is_valid_machine_view( + source.view)); + + float estimated_xfer_cost = this->model->simulator->estimate_xfer_cost( + sink.node.ptr, it2.dstIdx, source.view, sink.view); + op_cost += estimated_xfer_cost; + } + this->add_operator_cost_with_memory( + source, op_cost, MemoryUsage{}, &result); + } else { + // The real source must be an input operator + Node real_source = graph->find_source_node(); + assert(real_source.ptr->op_type == OP_INPUT); + this->add_operator_cost_with_memory( + {real_source, MachineView::NO_VIEW}, 0.0f, MemoryUsage{}, &result); + } + + return result; +} + template <> void SearchHelper::add_operator_cost(NodeAssignment const &node, float node_cost, @@ -1332,6 +1518,20 @@ void SearchHelper::add_operator_cost( cost->views[node.node] = node.view; } +/** + * @brief Add an operator's run time and memory cost to the graph cost. + * "cost" is updated within this function. + */ +void SearchHelper::add_operator_cost_with_memory( + NodeAssignment const &node, + float node_run_time_cost, + MemoryUsage node_mem_cost, + GraphCostResultWithMemory *cost) const { + cost->cost += node_run_time_cost; + cost->mem_cost += node_mem_cost; + cost->views[node.node] = node.view; +} + template <> float SearchHelper::get_cost(float const &f) const { return f; @@ -1343,6 +1543,45 @@ float SearchHelper::get_cost( return gcr.cost; } +template <> +float SearchHelper::get_cost( + GraphCostResultWithMemory const &gcr) const { + return gcr.get_multi_obj_cost(); +} + +template +void SearchHelper::add_sink_node_costs(NodeAssignment const &sink, + CostMetrics metrics, + T *result) const { + this->add_operator_cost(sink, + metrics.forward_time + metrics.backward_time + + metrics.sync_time, + result); +} + +/** + * @brief Specialization of add_sink_node_costs to handle + * GraphCostResultWithMemory + */ +template <> +void SearchHelper::add_sink_node_costs( + NodeAssignment const &sink, + CostMetrics metrics, + GraphCostResultWithMemory *result) const { + float op_total_mem_mb = ((float)(metrics.op_total_mem / 1e4)) / 1e2; + this->add_operator_cost_with_memory( + sink, + metrics.forward_time + metrics.backward_time + metrics.sync_time, + MemoryUsage{MemoryUsageType::GLOBAL, op_total_mem_mb}, + result); +} + +/** + * @brief Core function to analyze the cost of a graph. 
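+ * @details Recursive DP over the PCG: consult the cost cache first; for + * graphs with at most two nodes fall back to estimate_xfer_cost; otherwise + * split at a bottleneck node (sequence split) or into two parallel halves + * (nonsequence split) and combine the sub-results.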
+ * + * @tparam T float or GraphCostResult (or GraphCostResultWithMemory in memory + * optimization) + */ template T SearchHelper::graph_cost(Graph const *graph, NodeAssignment const &source, @@ -1350,7 +1589,8 @@ T SearchHelper::graph_cost(Graph const *graph, MachineResource const &resources, bool include_sink_compute_time) const { TAG_ENTER(this->logger); - this->logger->debug() << "sink(" << sink.node.guid << ") " + this->logger->debug() << "PCG::SearchHelper::graph_cost: sink(" + << sink.node.guid << ") " << "sink.view(" << sink.view.ndims << " " << sink.view.start_device_id << " " << sink.view.dim[0] << ") " @@ -1382,9 +1622,11 @@ T SearchHelper::graph_cost(Graph const *graph, result = from_cache.second; } else { if (graph->inEdges.size() <= 2) { + // When there are no more than 2 nodes in the graph result = this->estimate_xfer_cost(graph, source, sink); this->logger->debug() - << "Estimated xfer cost is " << this->get_cost(result); + << "[PCG::SearchHelper::graph_cost] Estimated xfer cost is " + << this->get_cost(result); } else { Node bn_node = graph->find_bottleneck_node(sink.node, source.node); if (bn_node != Node::INVALID_NODE) { @@ -1415,26 +1657,114 @@ T SearchHelper::graph_cost(Graph const *graph, check_matches_graph(graph, result, sink.node); + // This is where we really add the costs of an operator if (include_sink_compute_time) { + // Sink node costs CostMetrics metrics = this->model->simulator->measure_operator_cost(sink.node.ptr, sink.view); - this->logger->debug() << "Sink node cost: " + + // Adjust operator memory usage + this->logger->spew() + << "[PCG::SearchHelper::graph_cost] Analyzing sink op memory cost [" + << sink.node.to_string() << "]:"; + + int input_num_parts = 0; + int output_num_parts = 0; + int weight_num_parts = 0; + auto op = sink.node.ptr; + this->logger->spew() << " input ParallelTensor shape|num_replicas:"; + for (int i = 0; i < op->numInputs; i++) { + auto shape = op->inputs[i]->get_shape(); + this->logger->spew() << shape << "|" << shape.get_num_replicas() << "; "; + if (input_num_parts == 0) { + input_num_parts = op->inputs[i]->get_total_num_parts(); + } + } + this->logger->spew() << " output ParallelTensor shape|num_replicas:"; + for (int i = 0; i < op->numOutputs; i++) { + auto shape = op->outputs[i]->get_shape(); + this->logger->spew() << shape << "|" << shape.get_num_replicas() << "; "; + if (output_num_parts == 0) { + output_num_parts = op->outputs[i]->get_total_num_parts(); + } + } + this->logger->spew() << " weight ParallelTensor shape|num_replicas:"; + for (int i = 0; i < op->numWeights; i++) { + if (op->weights[i] == nullptr) { + continue; + } + auto shape = op->weights[i]->get_shape(); + this->logger->spew() << shape << "|" << shape.get_num_replicas() << "; "; + if (weight_num_parts == 0) { + weight_num_parts = op->weights[i]->get_total_num_parts(); + } + } + input_num_parts = std::max(input_num_parts, 1); + output_num_parts = std::max(output_num_parts, 1); + weight_num_parts = std::max(weight_num_parts, 1); + if (input_num_parts > weight_num_parts) { + weight_num_parts = input_num_parts; + } + this->logger->spew() + << " Total number of parts of inputs|outputs|weights: " + << input_num_parts << "|" << output_num_parts << "|" + << weight_num_parts; + + // Real memory usage of this Op* considering parallelization over devices + this->logger->spew() << " cost_metrics input|output|weight memory: " + << metrics.inputs_memory << "|" + << metrics.outputs_memory << "|" + << metrics.weights_memory; + + metrics.op_total_mem = input_num_parts * 
metrics.inputs_memory + + output_num_parts * metrics.outputs_memory + + weight_num_parts * metrics.weights_memory; + + this->logger->spew() << " op_total_mem: " << metrics.op_total_mem; + float op_total_mem_mb = (float)((metrics.op_total_mem) / 1e4) / 1e2; + this->logger->debug() << "[PCG::SearchHelper::graph_cost] Sink node cost [" + << sink.node.to_string() << "]: " << "forward(" << metrics.forward_time << ") " << "backward(" << metrics.backward_time << ") " - << "sync(" << metrics.sync_time << ")"; - this->add_operator_cost(sink, - metrics.forward_time + metrics.backward_time + - metrics.sync_time, - &result); + << "sync(" << metrics.sync_time << ") " + << "memory(" << op_total_mem_mb << " MB)"; + this->add_sink_node_costs(sink, metrics, &result); } return result; } +/** + * @brief Get the optimal run time cost of a PCG. + * @details This is the current metric used to decide which PCG is better + * in Unity's search algorithm. + */ float Graph::optimal_cost() const { return this->generic_optimal_cost(); } +/** + * @brief Get a single number to represent the multi-objective cost of a PCG. + */ +float Graph::optimal_cost_with_memory(float run_time_cost_factor) const { + auto optimal = this->generic_optimal_cost(); + float run_time_cost = optimal.cost; + float mem_cost = optimal.mem_cost.num; + // This is where we combine two costs to get the multi-objective cost + auto combined_cost = (run_time_cost_factor * run_time_cost + + (1 - run_time_cost_factor) * mem_cost); + std::string output_str = + "Multi-objective cost in Graph::optimal_cost_with_memory:" + "run time cost: " + + std::to_string(run_time_cost) + + ", mem cost: " + std::to_string(mem_cost) + + ", combined cost: " + std::to_string(combined_cost) + + " (with run time cost factor: " + std::to_string(run_time_cost_factor) + + ")"; + this->search->logger->spew() << output_str; + return combined_cost; +} + std::unordered_map Graph::optimal_views() const { return this->generic_optimal_cost().views; } @@ -1466,9 +1796,8 @@ Graph Graph::reduced() const { * two versions is almost identical. By using a few template specializations we * can avoid duplicating all this code. * - * @tparam T the result type (can be either float or GraphCostResult) - * @return T the cost of the graph (along with any additional data in the return - * type) + * @tparam T Result type (float, GraphCostResult, or GraphCostResultWithMemory) + * @return T Cost of the graph (along with any additional data) */ template T Graph::generic_optimal_cost() const { @@ -1484,7 +1813,9 @@ T Graph::generic_optimal_cost() const { // } Node sink_node = reduced_graph.find_sink_node(); - this->search->logger->info() << "Found sink node: " << sink_node.to_string(); + this->search->logger->info() + << "Graph::generic_optimal_cost: Found sink node: " + << sink_node.to_string(); MachineResource resource(model->config); @@ -1543,12 +1874,21 @@ size_t dp_state_hash(Graph const *graph, return key; } -GraphOptimalViewSerialized - Graph::graph_optimize_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { +namespace { + +/** + * @brief Given a lambda value, perform the search and return the optimized PCG + * and corresponding MachineView. 
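+ * + * @param[in] lambda Run time cost factor to try, paired with the + * MemorySearchResult to be filled in + * @param[in] task Legion task carrying the FFModel to optimize + * @param[in,out] cached_simulator Simulator reused across trials; created on + * the first call + * @param[in] perform_memory_search Whether the memory-aware search is enabled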
+ */ +std::pair, std::unordered_map> + try_one_lambda(std::pair &lambda, + Task const *task, + std::shared_ptr &cached_simulator, + bool perform_memory_search) { + // Create a new fresh model FFModel *model = *((FFModel **)task->args); + model->clear_graph_search_cache(); + if (model->config.search_num_nodes.has_value()) { model->config.numNodes = model->config.search_num_nodes.value(); } @@ -1581,11 +1921,21 @@ GraphOptimalViewSerialized "machine-model-file should not be empty."); } // Assume this task is running on GPU0 - std::shared_ptr simulator( - new Simulator(model, model->handlers[0], gpu_mem, machine)); - model->simulator = simulator.get(); - std::unique_ptr best_graph; - std::unordered_map optimal_views; + if (!cached_simulator) { + cached_simulator = std::make_shared( + model, model->handlers[0], gpu_mem, machine); + } else { + // Update simulator with the new stuff + cached_simulator->handler = model->handlers[0]; + cached_simulator->memory = gpu_mem; + cached_simulator->machine = machine; + } + model->simulator = cached_simulator.get(); + + // Perform the search + std::unique_ptr curr_best_graph; + std::unordered_map curr_optimal_views; + if (model->config.only_data_parallel) { Graph *graph = new Graph(model); std::unordered_map op_to_node_map; @@ -1601,7 +1951,7 @@ GraphOptimalViewSerialized graph->add_edge(srcNode, dstNode, dstOp->inputs[j]->owner_idx, j); } } - best_graph = std::unique_ptr(graph); + curr_best_graph = std::unique_ptr(graph); MachineView data_parallel_view; data_parallel_view.device_type = MachineView::GPU; data_parallel_view.ndims = 1; @@ -1609,15 +1959,208 @@ GraphOptimalViewSerialized model->config.numNodes * model->config.workersPerNode; data_parallel_view.stride[0] = 1; data_parallel_view.start_device_id = 0; - for (auto const &node : best_graph->inEdges) { - optimal_views[node.first] = data_parallel_view; + for (auto const &node : curr_best_graph->inEdges) { + curr_optimal_views[node.first] = data_parallel_view; } } else { + // Main step to optimize the PCG of an FFModel model->graph_optimize(model->config.search_budget, model->config.only_data_parallel, - best_graph, - optimal_views); + curr_best_graph, + curr_optimal_views, + perform_memory_search, + MemoryOptimConfig{lambda.first}, + lambda.second); + } + // Return the best result of the current search + return std::make_pair(std::move(curr_best_graph), curr_optimal_views); +}; + +/** + * @brief Analyze the per-device memory cost and compare with the memory + * threshold of each device. + */ +bool is_valid_strategy( + std::vector> &lambdas_results, + Graph *curr_graph, + std::unordered_map &curr_views, + std::shared_ptr const cached_simulator, + float memory_threshold) { + std::cout << "try to check valid for lambda " << lambdas_results.back().first + << std::endl; + assert(cached_simulator.get() != nullptr && + "cached_simulator cannot be nullptr"); + + // Analyze the strategy and update max_per_device_mem_all_deivces in the + // lambda_result. 
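+ // Each operator's total memory (in MB) is charged to every device in its + // machine view; the per-device sums are then compared against the memory + // threshold (-ll:fsize) below.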
+ std::unordered_map device_to_mem{}; + for (auto const &view : curr_views) { + CostMetrics op_cost = + cached_simulator->measure_operator_cost(view.first.ptr, view.second); + float node_mem_as_mb = op_cost.total_memory_in_mb(); + + for (auto const d_id : view.second.device_ids()) { + if (device_to_mem.find(d_id) == device_to_mem.end()) { + device_to_mem.emplace(std::make_pair(d_id, node_mem_as_mb)); + } else { + device_to_mem[d_id] += node_mem_as_mb; + } + } + } + + float max_per_device_mem = 0.0; + float total_device_mem = 0.0; + for (auto const &d : device_to_mem) { + std::cout << "d_id: " << d.first << ", mem: " << d.second << std::endl; + total_device_mem += d.second; + if (d.second > max_per_device_mem) { + max_per_device_mem = d.second; + } + } + + lambdas_results.back().second.max_per_device_mem_all_deivces = + max_per_device_mem; + + std::cout << "max_per_device_mem: " + << lambdas_results.back().second.max_per_device_mem_all_deivces + << ", total_device_mem: " << total_device_mem << std::endl; + + if (max_per_device_mem >= memory_threshold) { + return false; + } + return true; +}; + +}; // namespace + +/** + * @brief Starting point of Unity search procedure. Registered on Legion + * runtime. Legion task to launch as one step of model.compile(). + * + * @param task Legion task to get FFModel and other configs + * @param regions Not used + * @param ctx Not used + * @param runtime Not used + * @return GraphOptimalViewSerialized Serialized optimal PCG + */ +GraphOptimalViewSerialized + Graph::graph_optimize_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + auto model_config = (*((FFModel **)task->args))->config; + bool perform_memory_search = model_config.perform_memory_search; + float memory_threshold = model_config.device_mem; + bool only_data_parallel = model_config.only_data_parallel; + + std::vector> lambdas{}; + + std::shared_ptr cached_simulator{}; + + // Optimized graph from the search + std::unique_ptr best_graph; + std::unordered_map optimal_views; + + // Be optimistic + lambdas.emplace_back(std::make_pair(1.0, MemorySearchResult{})); + auto try_result = try_one_lambda( + lambdas.back(), task, cached_simulator, perform_memory_search); + best_graph = std::move(try_result.first); + optimal_views = try_result.second; + + bool has_valid_strategy = false; + int best_lambda_index = -1; + int binary_search_budget = 10; + + if (perform_memory_search && !is_valid_strategy(lambdas, + best_graph.get(), + optimal_views, + cached_simulator, + memory_threshold)) { + // Not found the strategy; need to do binary search + lambdas.emplace_back(std::make_pair(0.0, MemorySearchResult{})); + try_result = try_one_lambda( + lambdas.back(), task, cached_simulator, perform_memory_search); + best_graph = std::move(try_result.first); + optimal_views = try_result.second; + + if (!is_valid_strategy(lambdas, + best_graph.get(), + optimal_views, + cached_simulator, + memory_threshold)) { + // Cannot find a valid strategy + has_valid_strategy = false; + } else { + has_valid_strategy = true; + best_lambda_index = 1; + + // Do a binary search between 0 and 1 for the best lambda + int bianry_search_num = 0; + float lower = 0.0; + float upper = 1.0; + + while (bianry_search_num < binary_search_budget) { + bianry_search_num++; + + float mid = (lower + upper) * 0.5; + + lambdas.emplace_back(std::make_pair(mid, MemorySearchResult{})); + try_result = try_one_lambda( + lambdas.back(), task, cached_simulator, perform_memory_search); + + if (!is_valid_strategy(lambdas, + 
try_result.first.get(), + try_result.second, + cached_simulator, + memory_threshold)) { + upper = mid; + } else { + // Found a better and valid strategy + best_graph = std::move(try_result.first); + optimal_views = try_result.second; + + lower = mid; + best_lambda_index = 1 + bianry_search_num; + } + } + } + } else { + has_valid_strategy = true; + best_lambda_index = 0; + } + + // Print out the results + if (perform_memory_search) { + if (has_valid_strategy) { + auto &best_l = lambdas[best_lambda_index]; + std::cout << "Found valid strategy with memory_threshold: " + << memory_threshold << " | lambda index: " << best_lambda_index + << ", lambda value: " << best_l.first + << ", result: run time cost: " << best_l.second.run_time_cost + << ", memory cost: " << best_l.second.memory_cost + << ", search time: " << best_l.second.search_time + << ", per-device max memory: " + << best_l.second.max_per_device_mem_all_deivces << std::endl; + } else { + std::cout << "Failed to find a valid strategy" << std::endl; + } + + std::cout << "All lambda results:" << std::endl; + for (auto l : lambdas) { + std::cout << "lambda: " << l.first + << ", run time cost: " << l.second.run_time_cost + << ", memory cost: " << l.second.memory_cost + << ", search time: " << l.second.search_time + << ", per-device max memory: " + << l.second.max_per_device_mem_all_deivces << std::endl; + } + } else if (!only_data_parallel) { + std::cout << "\nNot doing memory search" << std::endl; } + + // Following lines are to serialize the optimized PCG. + // Only need best_graph and optimal_views below. Serializer sez; // First serialize graph sez.serialize(best_graph->inEdges.size()); @@ -1758,7 +2301,7 @@ GraphOptimalViewSerialized } assert(node_idx == best_graph->inEdges.size()); // Second, serialize optimal machine view - printf("opotimal_views.size = %zu\n", optimal_views.size()); + printf("optimal_views.size = %zu\n", optimal_views.size()); sez.serialize(optimal_views.size()); for (auto const &it : optimal_views) { sez.serialize((size_t)98765432); // safe guard diff --git a/src/runtime/memory_optimization.cc b/src/runtime/memory_optimization.cc new file mode 100644 index 0000000000..b9d6d3bb55 --- /dev/null +++ b/src/runtime/memory_optimization.cc @@ -0,0 +1,64 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/memory_optimization.h" + +namespace FlexFlow { + +namespace PCG { + +std::string MemoryUsage::to_string() const { + std::string type_name; + switch (usage_type) { + case MemoryUsageType::GLOBAL: + type_name = "GLOBAL"; + break; + case MemoryUsageType::PER_DEVICE_MAX: + type_name = "PER_DEVICE_MAX"; + break; + } + return "(MemoryUsageType:" + type_name + ", Usage:" + std::to_string(num) + + ")"; +} + +MemoryUsage &MemoryUsage::operator+=(MemoryUsage const &rhs) { + assert(usage_type == rhs.usage_type); + + // Handle the merge of memory usage differently here. 
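+ // GLOBAL memory usage is additive across sub-PCGs (e.g. 10 + 20 -> 30 MB), + // while PER_DEVICE_MAX keeps the larger peak (e.g. 10 + 20 -> 20 MB).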
+ switch (usage_type) { + case MemoryUsageType::GLOBAL: + num += rhs.num; + break; + case MemoryUsageType::PER_DEVICE_MAX: + num = std::max(num, rhs.num); + break; + } + + return *this; +} + +MemoryUsage operator+(MemoryUsage lhs, MemoryUsage const &rhs) { + lhs += rhs; + return lhs; +} + +std::ostream &operator<<(std::ostream &s, MemoryUsage const &usage) { + s << usage.to_string(); + return s; +} + +} // namespace PCG + +} // namespace FlexFlow diff --git a/src/runtime/model.cc b/src/runtime/model.cc index ca1ab33343..a3ebc87b8d 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1221,6 +1221,11 @@ FFModel::FFModel(FFConfig &_config) } } +void FFModel::clear_graph_search_cache() { + this->graph_search->clear_cache(); + this->search->clear_cache(); +} + #ifdef FF_USE_NCCL ncclComm_t *FFModel::find_nccl_comms(MachineView const &view) const { auto const &it = view_hash_to_nccl_comms.find(view.hash()); @@ -3529,6 +3534,7 @@ FFConfig::FFConfig() { syntheticInput = false; perform_fusion = false; base_optimize_threshold = DefaultConfig::base_optimize_threshold; + perform_memory_search = false; // Parse input arguments { @@ -3615,6 +3621,10 @@ void FFConfig::parse_args(char **argv, int argc) { workersPerNode = atoi(argv[++i]); continue; } + if (!strcmp(argv[i], "-ll:fsize")) { + device_mem = atoi(argv[++i]); + continue; + } if (!strcmp(argv[i], "--nodes")) { fprintf(stderr, "[Warning] --nodes is deprecated. " @@ -3701,6 +3711,10 @@ void FFConfig::parse_args(char **argv, int argc) { substitution_json_path = std::string(argv[++i]); continue; } + if (!strcmp(argv[i], "--memory-search")) { + perform_memory_search = true; + continue; + } } } diff --git a/src/runtime/simulator.cc b/src/runtime/simulator.cc index f50ab0a514..c363cdd296 100644 --- a/src/runtime/simulator.cc +++ b/src/runtime/simulator.cc @@ -42,6 +42,11 @@ size_t CostMetrics::total_memory() const { return inputs_memory + outputs_memory + weights_memory; } +float CostMetrics::total_memory_in_mb() const { + float mem_mb = (float)((total_memory()) / 1e4) / 1e2; + return mem_mb; +} + size_t CostMetrics::total_mem_diff_from(off_t sim_offset) const { return static_cast(sim_offset) - total_memory(); } @@ -537,7 +542,7 @@ CostMetrics Simulator::measure_operator_cost(Op const *op, ProfilingRecordKey key{params, mv}; if (this->strict_hash_to_operator_cost.find(key) == this->strict_hash_to_operator_cost.end()) { - CostMetrics cost_metrics; + CostMetrics cost_metrics{}; bool is_implemented = op->measure_operator_cost(this, mv, cost_metrics); if (!is_implemented) { handle_measure_operator_cost_unimplemented(op); @@ -558,7 +563,7 @@ CostMetrics Simulator::measure_operator_cost(Op const *op, hash_to_operator_cost.find(hash); if (iter == hash_to_operator_cost.end()) { - CostMetrics cost_metrics; + CostMetrics cost_metrics{}; bool is_implemented = op->measure_operator_cost(this, mv, cost_metrics); if (!is_implemented) { handle_measure_operator_cost_unimplemented(op); diff --git a/src/runtime/substitution.cc b/src/runtime/substitution.cc index 9d3d6057f1..f852acaa6b 100644 --- a/src/runtime/substitution.cc +++ b/src/runtime/substitution.cc @@ -592,10 +592,11 @@ void GraphXfer::find_matches(int depth, } } +template void GraphXfer::run( int depth, Graph *graph, - std::priority_queue, GraphCompare> + std::priority_queue, GraphComparator> &candidates, std::unordered_set &hashmap, float threshold, @@ -1173,11 +1174,6 @@ OpX *GraphXfer::create_combine(TensorX const &input, return part; } -/* std::vector MachineView::get_devices() const { */ -/* 
std::vector devices; */ - -/* } */ - void Graph::print_strategy_computation_graph( std::unordered_map const &strategy) const { DotFile dot(std::cout); @@ -1200,8 +1196,8 @@ void Graph::export_strategy_computation_graph( GraphStructure s; for (auto const &node : s.get_nodes(*this)) { + // Add node if (strategy.find(node) == strategy.end()) { - dot.add_node(node, {{"label", node.to_string()}}); // Check FusedParallel node here and print out the detailed information if (node.ptr->op_type == OperatorType::OP_FUSED_PARALLEL) { RecordFormatter rf; @@ -1266,6 +1262,8 @@ void Graph::export_strategy_computation_graph( } } } + + // Fetch machine view information for (int device_id : mv.device_ids()) { machine_view_row << std::to_string(device_id); } @@ -1305,6 +1303,7 @@ void Graph::export_strategy_computation_graph( dot.add_record_node(node, rf); } + // Add edges for (auto const &edge : s.get_incoming_edges(*this, node)) { dot.add_edge(s.get_src(*this, edge), s.get_dst(*this, edge)); } @@ -1710,11 +1709,15 @@ std::vector create_xfers(FFModel *model, } GraphSearchHelper::GraphSearchHelper(FFModel *model) - : model(model), config(model->config) { + : model(model), config(model->config), mem_config(1.0) { this->logger = std::unique_ptr(new RecursiveLogger("gs")); generate_all_pcg_xfers(); } +void GraphSearchHelper::clear_cache() { + cached_optimized_graphs.clear(); +} + void GraphSearchHelper::load_graph_substitutions( std::vector &xfers) const { xfers = all_pcg_xfers; @@ -1883,6 +1886,15 @@ Graph *GraphSearchHelper::construct_graph() { return graph; } +/** + * @brief Unity search algorithm main entrance. + * + * @param[in] budget Not used + * @param[in] only_data_parallel Not used + * @param[out] best_graph The best possible PCG after optimization + * @param[out] optimal_views The corresponding device placement views of the + * best graph + */ void GraphSearchHelper::graph_optimize( size_t budget, bool only_data_parallel, @@ -1932,6 +1944,97 @@ void GraphSearchHelper::graph_optimize( optimal_views = real_optimal_views; } +/** + * @brief Experimental DP algorithm to optimize PCG with the consideration of + * memory usage. This is to avoid polluting the current Unity search algorithm + * above. And this should be merged to GraphSearchHelper::graph_optimize + * eventually. + * + * @param[in] budget Not used + * @param[in] only_data_parallel Not used + * @param[out] best_graph The best possible PCG after optimization + * @param[out] optimal_views The corresponding device placement views of the + * best graph + * @param[out] search_result The performance result of the search + */ +void GraphSearchHelper::graph_optimize_with_memory( + size_t budget, + bool only_data_parallel, + std::unique_ptr &best_graph, + std::unordered_map &optimal_views, + MemorySearchResult &search_result) { + this->logger->debug() + << "Starting graph optimization with memory consideration"; + + // Construct graph structure + Graph *graph = this->construct_graph(); + + // The input nodes may need to be duplicated because the PCG was constructed + // to have one input node for one input, but the actual execution graph should + // have the distributed version of inputs (i.e. multiple nodes). + graph->duplicate_input_nodes(); + + // Export an empty schedule if needed. 
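+ // (This lets the pre-optimization PCG be compared with the dot graph of + // the searched strategy printed at the end of this function.)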
+ std::unordered_map empty_strategy; + if (!this->config.export_strategy_computation_graph_file.empty()) { + graph->export_strategy_computation_graph( + empty_strategy, this->config.export_strategy_computation_graph_file); + } + + Node sink_node = graph->find_sink_node(); + + auto const start = std::chrono::system_clock::now(); + GraphOptimizeResultWithMemory optimal = + this->generic_sequence_optimize_with_memory< + GraphOptimizeResultWithMemory>( + graph, sink_node, tl::nullopt, tl::nullopt); + auto const end = std::chrono::system_clock::now(); + + this->logger->debug() << "Total cache size: " + << this->cached_optimized_graphs.size(); + std::cout << "Optimal run time cost: " << optimal.cost + << ", Memory usage: " << optimal.mem_cost + << " | run_time_cost_factor: " + << this->mem_config.run_time_cost_factor << std::endl; + + // Save the search performance results to the output argument + search_result.run_time_cost = optimal.cost; + search_result.memory_cost = optimal.mem_cost.num; + search_result.search_time = + std::chrono::duration_cast(end - start) + .count(); + + // Further simplify the "optimal" graph/schedule to have a more efficient + // graph and more accurate cost. + best_graph = std::unique_ptr(new Graph(optimal.graph.value())); + SimplificationSettings settings; + // Simplify to consider parallel op fusion + settings.fuse_parallel_ops = true; + settings.remove_noops = true; + settings.remove_trailing_parallel_ops = true; + settings.simplify_parallel_ops = true; + best_graph->simplify(settings); + + // Get the real optimal machine views. + std::unordered_map duplicated_optimal_views = + best_graph->optimal_views(); + std::unordered_map deduplication_map = + best_graph->deduplicate_input_nodes(); + std::unordered_map real_optimal_views; + for (auto const &kv : duplicated_optimal_views) { + if (deduplication_map.find(kv.first) != deduplication_map.end()) { + real_optimal_views[deduplication_map.at(kv.first)] = kv.second; + } else { + real_optimal_views[kv.first] = kv.second; + } + } + std::cout << "Dot graph of searched strategy:" << std::endl; + best_graph->print_strategy_computation_graph(optimal.views); + std::cout << std::endl; + + optimal_views = real_optimal_views; +} + void GraphSearchHelper::graph_optimize_no_split( size_t budget, bool only_data_parallel, @@ -1969,6 +2072,11 @@ static void graph_log_representation(Graph const *graph, } } +void GraphSearchHelper::update_mem_optim_config( + MemoryOptimConfig const &new_config) { + mem_config = new_config; +} + void GraphSearchHelper::find_rewrite_matches( Graph const *graph, std::vector &matches) const { std::vector xfers; @@ -2111,6 +2219,13 @@ tl::optional return best; } +/** + * @brief Base case of Unity's DP search algorithm. + * + * @param r_graph Graph to be optimized + * @param simplification_settings Settings to simplify the PCG + * @return std::unique_ptr Optimized PCG + */ std::unique_ptr GraphSearchHelper::base_optimize( Graph const *r_graph, SimplificationSettings const &simplification_settings) { @@ -2195,6 +2310,112 @@ std::unique_ptr GraphSearchHelper::base_optimize( return std::unique_ptr(best_graph); } +/** + * @brief Experimental. Base case of Unity's DP search algorithm with + * memory consideration. 
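+ * @details Best-first search over substitution-generated candidates ordered + * by optimal_cost_with_memory; candidates worse than best_cost * search_alpha + * are pruned, mirroring base_optimize above.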
+ * + * @param r_graph Graph to be optimized + * @param simplification_settings Settings to simplify the resulting PCG + * @return std::unique_ptr Optimized PCG + */ +std::unique_ptr GraphSearchHelper::base_optimize_with_memory( + Graph const *r_graph, + SimplificationSettings const &simplification_settings) { + TAG_ENTER(this->logger); + this->logger->debug() << "Optimizing base graph with memory: "; + { + TAG_ENTER(this->logger); + /* graph_log_representation(r_graph, *this->logger); */ + // r_graph->print_dot(); + } + this->logger->debug() << "Starting cost: " + << r_graph->optimal_cost_with_memory( + mem_config.run_time_cost_factor); + + // Construct graph substitutions + std::vector xfers; + this->load_graph_substitutions(xfers); + + // Prepare for the search + std::priority_queue, GraphCompareWithMemory> + candidates(GraphCompareWithMemory{mem_config.run_time_cost_factor}); + std::unordered_set hashmap; + + Graph *graph = new Graph(*r_graph); + candidates.push(graph); + hashmap.insert(graph->hash()); + + Graph *best_graph = new Graph(*graph); + float best_cost = + best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); + + int counter = 0; + float const alpha = this->model->config.search_alpha; + int budget = model->config.search_budget; + if (budget == 0) { + log_xfers.warning() + << "Base search budget is set to 0. This is probably not what you want " + "(use the --budget flag to set the base search budget)"; + } + + // Actual exploration + for (int iter = 0; iter < budget || budget == -1; iter++) { + log_xfers.spew() << "Considering " << candidates.size() + << " candidates in base_optimize_with_memory"; + if (candidates.empty()) { + break; + } + + Graph *cur_graph = candidates.top(); + candidates.pop(); + if (cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor) < + best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor)) { + delete best_graph; + best_graph = cur_graph; + best_cost = + cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); + } else if (cur_graph->optimal_cost_with_memory( + mem_config.run_time_cost_factor) > best_cost * alpha) { + continue; + } + + log_xfers.info( + "[%d] cur_cost(%.4lf) best_cost(%.4lf) candidates.size(%zu)", + counter, + cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor), + best_cost, + candidates.size()); + + log_xfers.debug() << "Considering " << xfers.size() + << " possible xfers in base_optimize_with_memory"; + for (size_t i = 0; i < xfers.size(); i++) { + int num_matches_found = 0, num_matches_rejected = 0; + log_xfers.debug() << "Considering xfer: " << xfers[i]->get_name(); + xfers[i]->run(0, + cur_graph, + candidates, + hashmap, + best_cost * alpha, + 1000, + simplification_settings, + num_matches_found, + num_matches_rejected); + log_xfers.debug() << "Rejected [ " << num_matches_rejected << " / " + << num_matches_found << " ] matches"; + } + + if (best_graph != cur_graph) { + delete cur_graph; + } + } + + this->logger->debug() + << "Optimized cost at the end of base_optimize_with_memory: " + << best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); + + return std::unique_ptr(best_graph); +} + size_t gs_dp_state_hash(Graph const *graph, Node const &sink_node, tl::optional const &output_shape, @@ -2249,6 +2470,20 @@ GraphOptimizeResult GraphSearchHelper::get_optimal_cost( return result; } +template <> +GraphOptimizeResultWithMemory + GraphSearchHelper::get_optimal_cost( + std::unique_ptr optimized) const { + GraphOptimizeResultWithMemory result; + 
+  result.graph = *optimized;
+  GraphCostResultWithMemory gcr =
+      optimized->generic_optimal_cost<GraphCostResultWithMemory>();
+  result.cost = gcr.cost;
+  result.views = gcr.views;
+  result.mem_cost = gcr.mem_cost;
+  return result;
+}
+
 template <>
 tl::optional<GraphOptimizeResult>
     GraphSearchHelper::try_get_cost_from_cache<GraphOptimizeResult>(
@@ -2263,6 +2498,13 @@ tl::optional<GraphOptimizeResult>
   return tl::nullopt;
 }
 
+template <>
+tl::optional<GraphOptimizeResultWithMemory>
+    GraphSearchHelper::try_get_cost_from_cache<GraphOptimizeResultWithMemory>(
+        size_t hash) const {
+  return tl::nullopt;
+}
+
 template <>
 void GraphSearchHelper::try_cache_result<float>(size_t hash,
                                                 float const &value) {
@@ -2277,6 +2519,16 @@ template <>
 void GraphSearchHelper::try_cache_result<GraphOptimizeResult>(
     size_t hash, GraphOptimizeResult const &value) {}
 
+template <>
+void GraphSearchHelper::try_cache_result<GraphOptimizeResultWithMemory>(
+    size_t hash, GraphOptimizeResultWithMemory const &value) {}
+
+/**
+ * @brief Get the cost/result of a PCG when it is sequentially split.
+ *
+ * @details This function combines the search results of the DP sub-problems.
+ * The sub-problems are solved by generic_sequence_optimize().
+ */
 template <typename T>
 T GraphSearchHelper::execute_sequence_split(
     std::unique_ptr<Graph> const &pre_graph,
@@ -2293,6 +2545,29 @@ T GraphSearchHelper::execute_sequence_split(
       post_graph.get(), sink_node, output_shape, bottleneck_output_shape));
 }
 
+/**
+ * @brief Experimental. Considers memory usage when splitting the PCG during
+ * the DP search. This should be merged with execute_sequence_split().
+ */
+template <typename T>
+T GraphSearchHelper::execute_sequence_split_with_memory(
+    std::unique_ptr<Graph> const &pre_graph,
+    std::unique_ptr<Graph> const &post_graph,
+    tl::optional<ParallelTensorShape> const &output_shape,
+    tl::optional<ParallelTensorShape> const &input_shape,
+    Node const &sink_node,
+    Node const &bottleneck,
+    ParallelTensorShape const &bottleneck_output_shape) {
+  return sequence_cost<T>(
+      this->generic_sequence_optimize_with_memory<T>(
+          pre_graph.get(), bottleneck, bottleneck_output_shape, input_shape),
+      this->generic_sequence_optimize_with_memory<T>(
+          post_graph.get(), sink_node, output_shape, bottleneck_output_shape));
+}
+
+/**
+ * @brief Top-level DP search procedure for Unity.
+ */
 template <typename T>
 T GraphSearchHelper::generic_sequence_optimize(
     Graph const *graph,
@@ -2475,6 +2750,194 @@ T GraphSearchHelper::generic_sequence_optimize(
   return return_value;
 }
 
+/**
+ * @brief Top-level DP search procedure for Unity that also considers memory
+ * usage.
+ *
+ * @tparam T Returned type
+ * @param graph Pre-optimization PCG
+ * @param sink_node Sink node of the PCG
+ * @param output_shape Required shape of the sink node's output tensor, if any
+ * @param input_shape Shape of the input tensor fed to the sub-PCG, if any
+ * @return T Optimal result
+ */
+template <typename T>
+T GraphSearchHelper::generic_sequence_optimize_with_memory(
+    Graph const *graph,
+    Node const &sink_node,
+    tl::optional<ParallelTensorShape> const &output_shape,
+    tl::optional<ParallelTensorShape> const &input_shape) {
+  TAG_ENTER(this->logger);
+
+  // Try to find the result from cache first. But this will only get the
+  // cached result if the returned type is float. The float number means the
+  // best run time cost with only machine quantity (without distinguishing
+  // machine identities).
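+  //
+  // For example: when T = float, the lookup below may hit the cache, while
+  // the GraphOptimizeResultWithMemory specializations of
+  // try_get_cost_from_cache() and try_cache_result() above are deliberate
+  // no-ops, so full results with memory information are always recomputed.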
+  size_t hash = gs_dp_state_hash(graph, sink_node, output_shape, input_shape);
+  tl::optional<T> cached = this->try_get_cost_from_cache<T>(hash);
+  if (cached.has_value()) {
+    this->logger->spew() << "Optimizing graph with " << graph->inEdges.size()
+                         << " nodes";
+    {
+      TAG_ENTER(this->logger);
+      this->logger->spew() << "Nodes: ";
+      {
+        TAG_ENTER(this->logger);
+        graph_log_representation(graph, *this->logger);
+      }
+      this->logger->spew() << "Retrieved value from cache: " << cached.value();
+    }
+    return cached.value();
+  }
+
+  // Couldn't find the result from cache. Try to optimize and get one.
+  this->logger->debug() << "Optimizing graph with " << graph->inEdges.size()
+                        << " nodes";
+  T return_value;
+  {
+    // Print out debug information
+    TAG_ENTER(this->logger);
+    this->logger->spew() << "Nodes: ";
+    {
+      TAG_ENTER(this->logger);
+      graph_log_representation(graph, *this->logger);
+    }
+    this->logger->debug() << "Graph hash: " << std::setw(32)
+                          << std::setfill('0') << graph->hash();
+    if (input_shape.has_value()) {
+      this->logger->debug() << "Input shape: " << input_shape.value();
+    } else {
+      this->logger->debug() << "Input shape: <none>";
+    }
+    if (output_shape.has_value()) {
+      this->logger->debug() << "Output shape: " << output_shape.value();
+    } else {
+      this->logger->debug() << "Output shape: <none>";
+    }
+
+    // Find the node at which to sequentially split the PCG. This also
+    // decides whether the search has reached the base condition.
+    tl::optional<Node> bottleneck =
+        this->find_split_node(graph, this->config.base_optimize_threshold);
+
+    if (!bottleneck.has_value()) {
+      this->logger->debug() << "Applying base case";
+
+      // Construct the PCG to optimize based on input_shape and output_shape
+      // information.
+      Graph to_optimize(*graph);
+      if (input_shape.has_value()) {
+        Node input_node =
+            this->model->get_or_create_input_node(input_shape.value());
+        Node noop_node =
+            this->model->get_or_create_noop_node(input_node.ptr->outputs[0]);
+        Graph input_graph(this->model);
+        Edge e(input_node, noop_node, 0, 0);
+        input_graph.add_edge(e);
+
+        Node old_source_node = graph->find_source_node();
+        ParallelTensorShape old_source_output_shape =
+            old_source_node.ptr->outputs[0]->get_shape();
+        input_graph.reshape_output_tensor(old_source_output_shape);
+
+        Node new_sink_node = input_graph.find_sink_node();
+        assert(new_sink_node.ptr->numOutputs == 1);
+        assert(new_sink_node.ptr->outputs[0]->get_shape() ==
+               old_source_output_shape);
+
+        to_optimize.replace_subgraph({old_source_node}, input_graph);
+      }
+      SimplificationSettings settings;
+      if (output_shape.has_value()) {
+        to_optimize.reshape_output_tensor(output_shape.value());
+        Node sink_node = to_optimize.find_sink_node();
+        Node noop_node =
+            this->model->get_or_create_noop_node(sink_node.ptr->outputs[0]);
+        to_optimize.add_edge(sink_node, noop_node, 0, 0);
+      } else {
+        settings.remove_trailing_parallel_ops = true;
+      }
+      settings.simplify_parallel_ops = true;
+
+      // Call base optimization to perform graph substitution.
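+      // At this point, to_optimize is a self-contained PCG whose boundary
+      // shapes match input_shape/output_shape, so the substitutions stay
+      // local to this DP sub-problem.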
+      std::unique_ptr<Graph> optimized =
+          this->base_optimize_with_memory(&to_optimize, settings);
+      return_value = get_optimal_cost<T>(std::move(optimized));
+    } else {
+      this->logger->debug() << "Applying recursive case on bottleneck "
+                            << bottleneck.value().guid;
+
+      std::unique_ptr<Graph> pre_graph, post_graph;
+      std::tie(pre_graph, post_graph) =
+          graph->split_at_node(bottleneck.value());
+
+      MachineResource resources(this->model->config);
+      std::vector<MachineView> valid_machine_views =
+          this->model->search->get_valid_machine_views(bottleneck.value().ptr,
+                                                       resources);
+
+      // Try to find the best cost and the corresponding best bottleneck
+      // shape. This search process is based on the float version of
+      // execute_sequence_split_with_memory().
+      float best_cost = std::numeric_limits<float>::infinity();
+      tl::optional<ParallelTensorShape> best_shape = tl::nullopt;
+      {
+        TAG_ENTER(this->logger);
+        for (auto const &bottleneck_output_shape :
+             this->possible_split_output_tensor_shapes(bottleneck.value())) {
+          this->logger->debug()
+              << "Considering boundary shape " << bottleneck_output_shape;
+          float current_cost;
+          {
+            TAG_ENTER(this->logger);
+            // Get the cost from execute_sequence_split_with_memory by
+            // only changing bottleneck_output_shape.
+            current_cost = this->execute_sequence_split_with_memory<float>(
+                pre_graph,
+                post_graph,
+                output_shape,
+                input_shape,
+                sink_node,
+                bottleneck.value(),
+                bottleneck_output_shape);
+
+            if (current_cost < best_cost) {
+              best_cost = current_cost;
+              best_shape = bottleneck_output_shape;
+            }
+          }
+          this->logger->debug() << "Boundary shape " << bottleneck_output_shape
+                                << " has cost: " << current_cost;
+        }
+      }
+
+      if (best_shape.has_value()) {
+        this->logger->debug()
+            << "Best intermediate shape found: " << best_shape.value();
+      } else {
+        this->logger->debug() << "No valid intermediate shapes found";
+      }
+
+      // Open question: if best_cost is still infinity, no valid boundary
+      // shape exists and return_value is left default-initialized.
+      if (best_cost != std::numeric_limits<float>::infinity()) {
+        // Get the return value of the correct type with the previously found
+        // best_shape.
+        return_value =
+            this->execute_sequence_split_with_memory<T>(pre_graph,
+                                                        post_graph,
+                                                        output_shape,
+                                                        input_shape,
+                                                        sink_node,
+                                                        bottleneck.value(),
+                                                        best_shape.value());
+      }
+    }
+    // Try to cache the result (only effective when T is float).
+    this->try_cache_result<T>(hash, return_value);
+  }
+  return return_value;
+}
+
 std::vector<ParallelTensorShape>
     GraphSearchHelper::possible_split_output_tensor_shapes(
         Node const &source_node) const {
@@ -3110,13 +3573,35 @@ using PCG::Edge;
 using PCG::Graph;
 using PCG::Node;
 
+/**
+ * @brief Optimize the graph stored in FFModel.
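+ *
+ * @details A minimal usage sketch (illustrative only; assumes a constructed
+ * FFModel *model and a default MemoryOptimConfig):
+ * @code
+ * std::unique_ptr<Graph> best_graph;
+ * std::unordered_map<Node, MachineView> optimal_views;
+ * MemoryOptimConfig mem_config;
+ * MemorySearchResult search_result;
+ * // budget = 10, only_data_parallel = false, perform_memory_search = true
+ * model->graph_optimize(
+ *     10, false, best_graph, optimal_views, true, mem_config, search_result);
+ * @endcode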
+ *
+ * @param[in] budget The search budget
+ * @param[in] only_data_parallel True if we are only doing data-parallel
+ * training
+ * @param[out] best_graph The best graph found by the search
+ * @param[out] optimal_views The corresponding machine views of best_graph
+ * @param[in] perform_memory_search True if memory usage should be considered
+ * during the search
+ * @param[in] new_config Memory optimization config to use if this is a
+ * memory search
+ * @param[out] search_result The performance result of this search
+ */
 void FFModel::graph_optimize(
     size_t budget,
     bool only_data_parallel,
     std::unique_ptr<Graph> &best_graph,
-    std::unordered_map<Node, MachineView> &optimal_views) {
-  this->graph_search->graph_optimize(
-      budget, only_data_parallel, best_graph, optimal_views);
+    std::unordered_map<Node, MachineView> &optimal_views,
+    bool perform_memory_search,
+    MemoryOptimConfig new_config,
+    MemorySearchResult &search_result) {
+  if (perform_memory_search) {
+    this->graph_search->update_mem_optim_config(new_config);
+    this->graph_search->graph_optimize_with_memory(
+        budget, only_data_parallel, best_graph, optimal_views, search_result);
+  } else {
+    this->graph_search->graph_optimize(
+        budget, only_data_parallel, best_graph, optimal_views);
+  }
+}
 
 bool FFModel::convert_graph_to_operators(