From de2a5adc6a833610c0ecc49e5d68bd2a05b0c8e2 Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Tue, 21 May 2024 10:48:37 +0530 Subject: [PATCH] Port refprune to NewPassManager Based on the changes introduced in #1042 by @modiking --- ffi/custom_passes.cpp | 222 ++++++++++++++++++++-------- llvmlite/binding/newpassmanagers.py | 45 ++++++ llvmlite/binding/passmanagers.py | 6 +- llvmlite/tests/test_binding.py | 2 + 4 files changed, 207 insertions(+), 68 deletions(-) diff --git a/ffi/custom_passes.cpp b/ffi/custom_passes.cpp index e3269f2be..816569373 100644 --- a/ffi/custom_passes.cpp +++ b/ffi/custom_passes.cpp @@ -26,10 +26,22 @@ using namespace llvm; namespace llvm { -void initializeRefNormalizePassPass(PassRegistry &Registry); -void initializeRefPrunePassPass(PassRegistry &Registry); +void initializeRefNormalizeLegacyPassPass(PassRegistry &Registry); +void initializeRefPruneLegacyPassPass(PassRegistry &Registry); } // namespace llvm +namespace llvm { +struct OpaqueModulePassManager; +typedef OpaqueModulePassManager *LLVMModulePassManagerRef; +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ModulePassManager, LLVMModulePassManagerRef) + +struct OpaqueFunctionPassManager; +typedef OpaqueFunctionPassManager *LLVMFunctionPassManagerRef; +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(FunctionPassManager, + LLVMFunctionPassManagerRef) +} // namespace llvm + +namespace { /** * Checks if a call instruction is an incref * @@ -104,13 +116,9 @@ template struct raiiStack { * A FunctionPass to reorder incref/decref instructions such that decrefs occur * logically after increfs. This is a pre-requisite pass to the pruner passes. */ -struct RefNormalizePass : public FunctionPass { - static char ID; - RefNormalizePass() : FunctionPass(ID) { - initializeRefNormalizePassPass(*PassRegistry::getPassRegistry()); - } +struct RefNormalize { - bool runOnFunction(Function &F) override { + bool runOnFunction(Function &F) { bool mutated = false; // For each basic block in F for (BasicBlock &bb : F) { @@ -158,7 +166,16 @@ struct RefNormalizePass : public FunctionPass { } }; -struct RefPrunePass : public FunctionPass { +typedef enum { + None = 0b0000, + PerBasicBlock = 0b0001, + Diamond = 0b0010, + Fanout = 0b0100, + FanoutRaise = 0b1000, + All = PerBasicBlock | Diamond | Fanout | FanoutRaise +} Subpasses; + +struct RefPrune { static char ID; static size_t stats_per_bb; static size_t stats_diamond; @@ -175,25 +192,21 @@ struct RefPrunePass : public FunctionPass { /** * Enum for setting which subpasses to run, there is no interdependence. */ - enum Subpasses { - None = 0b0000, - PerBasicBlock = 0b0001, - Diamond = 0b0010, - Fanout = 0b0100, - FanoutRaise = 0b1000, - All = PerBasicBlock | Diamond | Fanout | FanoutRaise - } flags; - RefPrunePass(Subpasses flags = Subpasses::All, size_t subgraph_limit = -1) - : FunctionPass(ID), flags(flags), subgraph_limit(subgraph_limit) { - initializeRefPrunePassPass(*PassRegistry::getPassRegistry()); - } + Subpasses flags; + + DominatorTree &DT; + PostDominatorTree &PDT; + + RefPrune(DominatorTree &DT, PostDominatorTree &PDT, + Subpasses flags = Subpasses::All, size_t subgraph_limit = -1) + : DT(DT), PDT(PDT), flags(flags), subgraph_limit(subgraph_limit) {} bool isSubpassEnabledFor(Subpasses expected) { return (flags & expected) == expected; } - bool runOnFunction(Function &F) override { + bool runOnFunction(Function &F) { // state for LLVM function pass mutated IR bool mutated = false; @@ -361,11 +374,6 @@ struct RefPrunePass : public FunctionPass { */ bool runDiamondPrune(Function &F) { bool mutated = false; - // gets the dominator tree - auto &domtree = getAnalysis().getDomTree(); - // gets the post-dominator tree - auto &postdomtree = - getAnalysis().getPostDomTree(); // Find all increfs and decrefs in the Function and store them in // incref_list and decref_list respectively. @@ -394,8 +402,8 @@ struct RefPrunePass : public FunctionPass { continue; // incref DOM decref && decref POSTDOM incref - if (domtree.dominates(incref, decref) && - postdomtree.dominates(decref, incref)) { + if (DT.dominates(incref, decref) && + PDT.dominates(decref, incref)) { // check that the decref cannot be executed multiple times SmallBBSet tail_nodes; tail_nodes.insert(decref->getParent()); @@ -1028,14 +1036,6 @@ struct RefPrunePass : public FunctionPass { return NULL; } - /** - * getAnalysisUsage() LLVM plumbing for the pass - */ - void getAnalysisUsage(AnalysisUsage &Info) const override { - Info.addRequired(); - Info.addRequired(); - } - /** * Checks if the first argument to the supplied call_inst is NULL and * returns true if so, false otherwise. @@ -1163,34 +1163,128 @@ struct RefPrunePass : public FunctionPass { } } } -}; // end of struct RefPrunePass +}; // end of struct RefPrune + +} // namespace + +class RefPrunePass : public PassInfoMixin { + + public: + Subpasses flags; + size_t subgraph_limit; + RefPrunePass(Subpasses flags = Subpasses::All, size_t subgraph_limit = -1) + : flags(flags), subgraph_limit(subgraph_limit) {} + + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) { + auto &DT = AM.getResult(F); + auto &PDT = AM.getResult(F); + if (RefPrune(DT, PDT, flags, subgraph_limit).runOnFunction(F)) { + return PreservedAnalyses::none(); + } + + return PreservedAnalyses::all(); + } +}; -char RefNormalizePass::ID = 0; -char RefPrunePass::ID = 0; +class RefNormalizePass : public PassInfoMixin { -size_t RefPrunePass::stats_per_bb = 0; -size_t RefPrunePass::stats_diamond = 0; -size_t RefPrunePass::stats_fanout = 0; -size_t RefPrunePass::stats_fanout_raise = 0; + public: + RefNormalizePass() = default; -INITIALIZE_PASS(RefNormalizePass, "nrtrefnormalizepass", "Normalize NRT refops", - false, false) + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) { + RefNormalize().runOnFunction(F); -INITIALIZE_PASS_BEGIN(RefPrunePass, "nrtrefprunepass", "Prune NRT refops", - false, false) + return PreservedAnalyses::all(); + } +}; + +class RefNormalizeLegacyPass : public FunctionPass { + public: + static char ID; + RefNormalizeLegacyPass() : FunctionPass(ID) { + initializeRefNormalizeLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + return RefNormalize().runOnFunction(F); + }; +}; + +class RefPruneLegacyPass : public FunctionPass { + + public: + static char ID; // Pass identification, replacement for typeid + // The maximum number of nodes that the fanout pruners will look at. + size_t subgraph_limit; + Subpasses flags; + RefPruneLegacyPass(Subpasses flags = Subpasses::All, + size_t subgraph_limit = -1) + : FunctionPass(ID), flags(flags), subgraph_limit(subgraph_limit) { + initializeRefPruneLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + auto &DT = getAnalysis().getDomTree(); + + auto &PDT = + getAnalysis().getPostDomTree(); + + return RefPrune(DT, PDT, flags, subgraph_limit).runOnFunction(F); + }; + + /** + * getAnalysisUsage() LLVM plumbing for the pass + */ + void getAnalysisUsage(AnalysisUsage &Info) const override { + Info.addRequired(); + Info.addRequired(); + } +}; + +char RefNormalizeLegacyPass::ID = 0; +char RefPruneLegacyPass::ID = 0; + +size_t RefPrune::stats_per_bb = 0; +size_t RefPrune::stats_diamond = 0; +size_t RefPrune::stats_fanout = 0; +size_t RefPrune::stats_fanout_raise = 0; + +INITIALIZE_PASS(RefNormalizeLegacyPass, "nrtRefNormalize", + "Normalize NRT refops", false, false) + +INITIALIZE_PASS_BEGIN(RefPruneLegacyPass, "nrtRefPruneLegacyPass", + "Prune NRT refops", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_END(RefPrunePass, "refprunepass", "Prune NRT refops", false, - false) +INITIALIZE_PASS_END(RefPruneLegacyPass, "RefPruneLegacyPass", + "Prune NRT refops", false, false) + extern "C" { API_EXPORT(void) -LLVMPY_AddRefPrunePass(LLVMPassManagerRef PM, int subpasses, - size_t subgraph_limit) { - unwrap(PM)->add(new RefNormalizePass()); +LLVMPY_AddLegacyRefPrunePass(LLVMPassManagerRef PM, int subpasses, + size_t subgraph_limit) { + unwrap(PM)->add(new RefNormalizeLegacyPass()); unwrap(PM)->add( - new RefPrunePass((RefPrunePass::Subpasses)subpasses, subgraph_limit)); + new RefPruneLegacyPass((Subpasses)subpasses, subgraph_limit)); +} + +API_EXPORT(void) +LLVMPY_AddRefPrunePass_module(LLVMModulePassManagerRef MPM, int subpasses, + size_t subgraph_limit) { + llvm::unwrap(MPM)->addPass( + createModuleToFunctionPassAdaptor(RefNormalizePass())); + llvm::unwrap(MPM)->addPass(createModuleToFunctionPassAdaptor( + RefPrunePass((Subpasses)subpasses, subgraph_limit))); +} + +API_EXPORT(void) +LLVMPY_AddRefPrunePass_function(LLVMFunctionPassManagerRef FPM, int subpasses, + size_t subgraph_limit) { + llvm::unwrap(FPM)->addPass(RefNormalizePass()); + llvm::unwrap(FPM)->addPass( + RefPrunePass((Subpasses)subpasses, subgraph_limit)); } /** @@ -1207,24 +1301,22 @@ typedef struct PruneStats { API_EXPORT(void) LLVMPY_DumpRefPruneStats(PRUNESTATS *buf, bool do_print) { /* PRUNESTATS is updated with the statistics about what has been pruned from - * the RefPrunePass static state vars. This isn't threadsafe but neither is + * the RefPrune static state vars. This isn't threadsafe but neither is * the LLVM pass infrastructure so it's all done under a python thread lock. * * do_print if set will print the stats to stderr. */ if (do_print) { - errs() << "refprune stats " - << "per-BB " << RefPrunePass::stats_per_bb << " " - << "diamond " << RefPrunePass::stats_diamond << " " - << "fanout " << RefPrunePass::stats_fanout << " " - << "fanout+raise " << RefPrunePass::stats_fanout_raise << " " - << "\n"; + errs() << "refprune stats " << "per-BB " << RefPrune::stats_per_bb + << " " << "diamond " << RefPrune::stats_diamond << " " + << "fanout " << RefPrune::stats_fanout << " " << "fanout+raise " + << RefPrune::stats_fanout_raise << " " << "\n"; }; - buf->basicblock = RefPrunePass::stats_per_bb; - buf->diamond = RefPrunePass::stats_diamond; - buf->fanout = RefPrunePass::stats_fanout; - buf->fanout_raise = RefPrunePass::stats_fanout_raise; + buf->basicblock = RefPrune::stats_per_bb; + buf->diamond = RefPrune::stats_diamond; + buf->fanout = RefPrune::stats_fanout; + buf->fanout_raise = RefPrune::stats_fanout_raise; } } // extern "C" diff --git a/llvmlite/binding/newpassmanagers.py b/llvmlite/binding/newpassmanagers.py index f1dd20cc7..def42721b 100644 --- a/llvmlite/binding/newpassmanagers.py +++ b/llvmlite/binding/newpassmanagers.py @@ -1,4 +1,5 @@ from ctypes import c_bool, c_int +from enum import IntFlag from llvmlite.binding import ffi @@ -18,6 +19,14 @@ def create_pipeline_tuning_options(speed_level=2, size_level=0): return PipelineTuningOptions(speed_level, size_level) +class RefPruneSubpasses(IntFlag): + PER_BB = 0b0001 # noqa: E221 + DIAMOND = 0b0010 # noqa: E221 + FANOUT = 0b0100 # noqa: E221 + FANOUT_RAISE = 0b1000 + ALL = PER_BB | DIAMOND | FANOUT | FANOUT_RAISE + + class ModulePassManager(ffi.ObjectRef): def __init__(self, ptr=None): @@ -52,6 +61,24 @@ def add_jump_threading_pass(self, threshold=-1): def _dispose(self): ffi.lib.LLVMPY_DisposeNewModulePassManger(self) + # Non-standard LLVM passes + def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL, + subgraph_limit=1000): + """Add Numba specific Reference count pruning pass. + + Parameters + ---------- + subpasses_flags : RefPruneSubpasses + A bitmask to control the subpasses to be enabled. + subgraph_limit : int + Limit the fanout pruners to working on a subgraph no bigger than + this number of basic-blocks to avoid spending too much time in very + large graphs. Default is 1000. Subject to change in future + versions. + """ + iflags = RefPruneSubpasses(subpasses_flags) + ffi.lib.LLVMPY_AddRefPrunePass_module(self, iflags, subgraph_limit) + class FunctionPassManager(ffi.ObjectRef): @@ -84,6 +111,24 @@ def add_jump_threading_pass(self, threshold=-1): def _dispose(self): ffi.lib.LLVMPY_DisposeNewFunctionPassManger(self) + # Non-standard LLVM passes + def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL, + subgraph_limit=1000): + """Add Numba specific Reference count pruning pass. + + Parameters + ---------- + subpasses_flags : RefPruneSubpasses + A bitmask to control the subpasses to be enabled. + subgraph_limit : int + Limit the fanout pruners to working on a subgraph no bigger than + this number of basic-blocks to avoid spending too much time in very + large graphs. Default is 1000. Subject to change in future + versions. + """ + iflags = RefPruneSubpasses(subpasses_flags) + ffi.lib.LLVMPY_AddRefPrunePass_function(self, iflags, subgraph_limit) + class PipelineTuningOptions(ffi.ObjectRef): diff --git a/llvmlite/binding/passmanagers.py b/llvmlite/binding/passmanagers.py index d983ce9d5..30e6c402b 100644 --- a/llvmlite/binding/passmanagers.py +++ b/llvmlite/binding/passmanagers.py @@ -668,7 +668,7 @@ def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL, versions. """ iflags = RefPruneSubpasses(subpasses_flags) - ffi.lib.LLVMPY_AddRefPrunePass(self, iflags, subgraph_limit) + ffi.lib.LLVMPY_AddLegacyRefPrunePass(self, iflags, subgraph_limit) class ModulePassManager(PassManager): @@ -933,7 +933,7 @@ def run_with_remarks(self, function, remarks_format='yaml', c_char_p] ffi.lib.LLVMPY_AddInstructionNamerPass.argtypes = [ffi.LLVMPassManagerRef] -ffi.lib.LLVMPY_AddRefPrunePass.argtypes = [ffi.LLVMPassManagerRef, c_int, - c_size_t] +ffi.lib.LLVMPY_AddLegacyRefPrunePass.argtypes = [ffi.LLVMPassManagerRef, c_int, + c_size_t] ffi.lib.LLVMPY_DumpRefPruneStats.argtypes = [POINTER(_c_PruneStats), c_bool] diff --git a/llvmlite/tests/test_binding.py b/llvmlite/tests/test_binding.py index c1ce3d08d..717641322 100644 --- a/llvmlite/tests/test_binding.py +++ b/llvmlite/tests/test_binding.py @@ -2745,6 +2745,7 @@ def test_add_passes(self): mpm.add_loop_rotate_pass() mpm.add_instruction_combine_pass() mpm.add_jump_threading_pass() + mpm.add_refprune_pass() class TestNewFunctionPassManager(BaseTest, NewPassManagerMixin): @@ -2785,6 +2786,7 @@ def test_add_passes(self): fpm.add_loop_rotate_pass() fpm.add_instruction_combine_pass() fpm.add_jump_threading_pass() + fpm.add_refprune_pass() if __name__ == "__main__":