diff --git a/driver/driver.cc b/driver/driver.cc index c460c71c0..6a93654b7 100644 --- a/driver/driver.cc +++ b/driver/driver.cc @@ -147,7 +147,8 @@ static llvm::cl::opt IpuNum( "ipu-num", llvm::cl::desc("Num of ipus, should consistent with subgraph num"), llvm::cl::init(1), llvm::cl::cat(HaloOptCat)); - +static llvm::cl::opt Shards("shards", llvm::cl::desc("Num of shards"), + llvm::cl::init(-1), llvm::cl::cat(HaloOptCat)); static llvm::cl::opt BatchesPerStep( "batches-per-step", llvm::cl::desc("Specify batches num for each step"), llvm::cl::init(1), llvm::cl::cat(HaloOptCat)); @@ -358,6 +359,7 @@ int main(int argc, char** argv) { cg_opts.enable_ipu_device = EnableIpuDevice; cg_opts.use_ipu_model = UseIpuModel; cg_opts.ipu_num = IpuNum; + cg_opts.num_shards = Shards; cg_opts.batches_per_step = BatchesPerStep; cg_opts.api = Api; cg_opts.disable_broadcasting = DisableBroadcasting; diff --git a/include/halo/halo.h b/include/halo/halo.h index b63a7bfb2..fc72f677f 100644 --- a/include/halo/halo.h +++ b/include/halo/halo.h @@ -91,6 +91,7 @@ struct CXXCodeGenOpts { bool emit_shared_lib = false; const char* linked_odla_lib = nullptr; bool save_temps = false; + int num_shards = -1; }; #define HALO_MODEL_INFO_MAX_OUTPUT_NR 64 diff --git a/include/halo/lib/pass/pass_manager.h b/include/halo/lib/pass/pass_manager.h index 7aee9267e..9f6f75c4b 100644 --- a/include/halo/lib/pass/pass_manager.h +++ b/include/halo/lib/pass/pass_manager.h @@ -110,6 +110,7 @@ class HL_API_EXPORT PassManager final { Pass* AddRISCVLLVMIRCodeGenPass(ConstantDataStorage constant_data_storage, const std::string& rt_lib_name); + Pass* AddShardingPass(int num_shards); Pass* AddSplittingPass(); Pass* AddTFExtensionLegalizerPass(); Pass* AddTFLiteExtensionLegalizerPass(); diff --git a/include/halo/lib/transforms/sharding.h b/include/halo/lib/transforms/sharding.h new file mode 100644 index 000000000..a1e34eb44 --- /dev/null +++ b/include/halo/lib/transforms/sharding.h @@ -0,0 +1,38 @@ +//===- splitting.h --------------------------------------------------------===// +// +// Copyright (C) 2019-2020 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef HALO_LIB_TRANSFORMS_SHARDING_H_ +#define HALO_LIB_TRANSFORMS_SHARDING_H_ + +#include "halo/lib/pass/pass.h" + +namespace halo { + +/// This pass splitting a function into sub functions. +class Sharding final : public FunctionPass { + public: + Sharding(int shards) : FunctionPass("Sharding"), shards_(shards) {} + + bool RunOnFunction(Function* func) override; + + private: + int shards_; +}; + +} // end namespace halo. + +#endif // HALO_LIB_TRANSFORMS_SHARDING_H_ \ No newline at end of file diff --git a/include/halo/utils/passes_helper.h b/include/halo/utils/passes_helper.h index 280f727b6..92af12767 100644 --- a/include/halo/utils/passes_helper.h +++ b/include/halo/utils/passes_helper.h @@ -162,6 +162,7 @@ static void PopulateOptPasses(PassManager* pm, const std::string& target, if (opts.enable_type_cast) { pm->AddTypeCastPass(); } + pm->AddShardingPass(opts.num_shards); } } // namespace halo diff --git a/lib/pass/pass_manager.cc b/lib/pass/pass_manager.cc index 53cf58ac1..4720af765 100644 --- a/lib/pass/pass_manager.cc +++ b/lib/pass/pass_manager.cc @@ -39,6 +39,7 @@ #include "halo/lib/transforms/onnxextension_legalizer.h" #include "halo/lib/transforms/output_rewriter.h" #include "halo/lib/transforms/reorder_channel.h" +#include "halo/lib/transforms/sharding.h" #include "halo/lib/transforms/splitting.h" #include "halo/lib/transforms/tfextension_legalizer.h" #include "halo/lib/transforms/tfliteextension_legalizer.h" @@ -382,6 +383,10 @@ Pass* PassManager::AddRISCVLLVMIRCodeGenPass( return AddPass(constant_data_storage); } +Pass* PassManager::AddShardingPass(int num_shards) { + return AddPass(num_shards); +} + Pass* PassManager::AddSplittingPass() { return AddPass(); } Pass* PassManager::AddTFExtensionLegalizerPass() { diff --git a/lib/transforms/CMakeLists.txt b/lib/transforms/CMakeLists.txt index 942002fe2..43925d177 100644 --- a/lib/transforms/CMakeLists.txt +++ b/lib/transforms/CMakeLists.txt @@ -30,6 +30,7 @@ set(SRCS onnxextension_legalizer.cc output_rewriter.cc reorder_channel.cc + sharding.cc splitting.cc tfextension_legalizer.cc tfliteextension_legalizer.cc diff --git a/lib/transforms/sharding.cc b/lib/transforms/sharding.cc new file mode 100644 index 000000000..d9b2171bd --- /dev/null +++ b/lib/transforms/sharding.cc @@ -0,0 +1,133 @@ +//===- sharding.cc --------------------------------------------------------===// +// +// Copyright (C) 2019-2021 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "halo/lib/transforms/sharding.h" + +#include +#include + +#include "halo/lib/ir/ir_builder.h" + +namespace halo { + +static unsigned CountNonConstOps(const Function& func) { + unsigned ret = func.Args().size(); + for (auto& bb : func) { + for (auto& ir : *bb) { + ret += IsA(*ir) ? 0 : 1; + } + } + return ret; +} + +// Simple sharding scheme: try to get equal shards of op. +static std::unordered_map GetSimpleSharding( + const Function& func, unsigned num_shards) { + std::unordered_map shardings; + auto num_ops = CountNonConstOps(func); + const unsigned threshold = (num_shards == 0) + ? std::numeric_limits::max() + : (num_ops + num_shards - 1) / num_shards; + unsigned curr_shard = 0; + unsigned allocated = 0; + auto is_avail = [&shardings](const IRObject* n) { + for (auto& op : n->GetOperands()) { + if (!IsA(op) && shardings.count(op.GetOwner()) == 0) { + return false; + } + } + return true; + }; + + // BFS visit. + std::unordered_set workset; + for (auto& arg : func.Args()) { + workset.insert(arg.get()); + } + for (auto& bb : func) { + for (auto& ir : *bb) { + if (!IsA(ir.get()) && is_avail(ir.get())) { + workset.insert(ir.get()); + } + } + } + + while (!workset.empty()) { + std::unordered_set next; + HLCHECK(curr_shard < num_shards); + for (auto& node : workset) { + HLCHECK(shardings.count(node) == 0); + shardings[node] = curr_shard; + if (!IsA(node)) { + ++allocated; + } + for (auto& op : node->GetOperands()) { + if (shardings.count(op.GetOwner()) == 0 && IsA(op)) { + shardings[op.GetOwner()] = curr_shard; + } + } + } + // check successors + for (auto& node : workset) { + for (auto& uses : node->GetResultsUses()) { + for (auto& user : uses) { + auto n = user.GetUse(); + if (is_avail(n) && shardings.count(n) == 0) { + next.insert(n); + } + } + } + } + if (allocated >= threshold) { + ++curr_shard; + allocated = 0; + for (auto& n : next) { + std::cout << " ---- cut:" << n->GetName() << std::endl; + } + } + workset.swap(next); + } + + for (auto& bb : func) { + for (auto& ir : *bb) { + auto inst = ir.get(); + HLCHECK(shardings.count(inst) != 0); + if (shardings.count(inst) == 0) { + std::cerr << "[sharding] unallocated: " << inst->GetName() << "\n"; + shardings[inst] = num_shards - 1; + } + } + } + return shardings; +} + +static void ApplySharding(const std::unordered_map& shardings) { + for (const auto& kv : shardings) { + std::cout << kv.second << ":" << kv.first->GetName() << ":" << std::endl; + } +} + +bool Sharding::RunOnFunction(Function* func) { + if (shards_ <= 0) { + return false; + } + auto sharding_scheme = GetSimpleSharding(*func, shards_); + ApplySharding(sharding_scheme); + return false; +} + +} // end namespace halo \ No newline at end of file