From 6b21b2ae5f8fe8b0fbf7059ff80bfd5300c968cf Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Tue, 21 May 2024 10:48:37 +0530 Subject: [PATCH 1/8] Port refprune to NewPassManager Based on the changes introduced in #1042 by @modiking --- ffi/custom_passes.cpp | 222 ++++++++++++++++++++-------- llvmlite/binding/newpassmanagers.py | 45 ++++++ llvmlite/binding/passmanagers.py | 6 +- llvmlite/tests/test_binding.py | 2 + 4 files changed, 207 insertions(+), 68 deletions(-) diff --git a/ffi/custom_passes.cpp b/ffi/custom_passes.cpp index e3269f2be..816569373 100644 --- a/ffi/custom_passes.cpp +++ b/ffi/custom_passes.cpp @@ -26,10 +26,22 @@ using namespace llvm; namespace llvm { -void initializeRefNormalizePassPass(PassRegistry &Registry); -void initializeRefPrunePassPass(PassRegistry &Registry); +void initializeRefNormalizeLegacyPassPass(PassRegistry &Registry); +void initializeRefPruneLegacyPassPass(PassRegistry &Registry); } // namespace llvm +namespace llvm { +struct OpaqueModulePassManager; +typedef OpaqueModulePassManager *LLVMModulePassManagerRef; +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ModulePassManager, LLVMModulePassManagerRef) + +struct OpaqueFunctionPassManager; +typedef OpaqueFunctionPassManager *LLVMFunctionPassManagerRef; +DEFINE_SIMPLE_CONVERSION_FUNCTIONS(FunctionPassManager, + LLVMFunctionPassManagerRef) +} // namespace llvm + +namespace { /** * Checks if a call instruction is an incref * @@ -104,13 +116,9 @@ template struct raiiStack { * A FunctionPass to reorder incref/decref instructions such that decrefs occur * logically after increfs. This is a pre-requisite pass to the pruner passes. 
*/ -struct RefNormalizePass : public FunctionPass { - static char ID; - RefNormalizePass() : FunctionPass(ID) { - initializeRefNormalizePassPass(*PassRegistry::getPassRegistry()); - } +struct RefNormalize { - bool runOnFunction(Function &F) override { + bool runOnFunction(Function &F) { bool mutated = false; // For each basic block in F for (BasicBlock &bb : F) { @@ -158,7 +166,16 @@ struct RefNormalizePass : public FunctionPass { } }; -struct RefPrunePass : public FunctionPass { +typedef enum { + None = 0b0000, + PerBasicBlock = 0b0001, + Diamond = 0b0010, + Fanout = 0b0100, + FanoutRaise = 0b1000, + All = PerBasicBlock | Diamond | Fanout | FanoutRaise +} Subpasses; + +struct RefPrune { static char ID; static size_t stats_per_bb; static size_t stats_diamond; @@ -175,25 +192,21 @@ struct RefPrunePass : public FunctionPass { /** * Enum for setting which subpasses to run, there is no interdependence. */ - enum Subpasses { - None = 0b0000, - PerBasicBlock = 0b0001, - Diamond = 0b0010, - Fanout = 0b0100, - FanoutRaise = 0b1000, - All = PerBasicBlock | Diamond | Fanout | FanoutRaise - } flags; - RefPrunePass(Subpasses flags = Subpasses::All, size_t subgraph_limit = -1) - : FunctionPass(ID), flags(flags), subgraph_limit(subgraph_limit) { - initializeRefPrunePassPass(*PassRegistry::getPassRegistry()); - } + Subpasses flags; + + DominatorTree &DT; + PostDominatorTree &PDT; + + RefPrune(DominatorTree &DT, PostDominatorTree &PDT, + Subpasses flags = Subpasses::All, size_t subgraph_limit = -1) + : DT(DT), PDT(PDT), flags(flags), subgraph_limit(subgraph_limit) {} bool isSubpassEnabledFor(Subpasses expected) { return (flags & expected) == expected; } - bool runOnFunction(Function &F) override { + bool runOnFunction(Function &F) { // state for LLVM function pass mutated IR bool mutated = false; @@ -361,11 +374,6 @@ struct RefPrunePass : public FunctionPass { */ bool runDiamondPrune(Function &F) { bool mutated = false; - // gets the dominator tree - auto &domtree = 
getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - // gets the post-dominator tree - auto &postdomtree = - getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); // Find all increfs and decrefs in the Function and store them in // incref_list and decref_list respectively. @@ -394,8 +402,8 @@ struct RefPrunePass : public FunctionPass { continue; // incref DOM decref && decref POSTDOM incref - if (domtree.dominates(incref, decref) && - postdomtree.dominates(decref, incref)) { + if (DT.dominates(incref, decref) && + PDT.dominates(decref, incref)) { // check that the decref cannot be executed multiple times SmallBBSet tail_nodes; tail_nodes.insert(decref->getParent()); @@ -1028,14 +1036,6 @@ struct RefPrunePass : public FunctionPass { return NULL; } - /** - * getAnalysisUsage() LLVM plumbing for the pass - */ - void getAnalysisUsage(AnalysisUsage &Info) const override { - Info.addRequired<DominatorTreeWrapperPass>(); - Info.addRequired<PostDominatorTreeWrapperPass>(); - } - /** * Checks if the first argument to the supplied call_inst is NULL and * returns true if so, false otherwise. @@ -1163,34 +1163,128 @@ struct RefPrunePass : public FunctionPass { } } } -}; // end of struct RefPrunePass +}; // end of struct RefPrune + +} // namespace + +class RefPrunePass : public PassInfoMixin<RefPrunePass> { + + public: + Subpasses flags; + size_t subgraph_limit; + RefPrunePass(Subpasses flags = Subpasses::All, size_t subgraph_limit = -1) + : flags(flags), subgraph_limit(subgraph_limit) {} + + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + if (RefPrune(DT, PDT, flags, subgraph_limit).runOnFunction(F)) { + return PreservedAnalyses::none(); + } + + return PreservedAnalyses::all(); + } +}; -char RefNormalizePass::ID = 0; -char RefPrunePass::ID = 0; +class RefNormalizePass : public PassInfoMixin<RefNormalizePass> { -size_t RefPrunePass::stats_per_bb = 0; -size_t RefPrunePass::stats_diamond = 0; -size_t RefPrunePass::stats_fanout = 0; -size_t RefPrunePass::stats_fanout_raise = 0; + public: + RefNormalizePass() = default; 
-INITIALIZE_PASS(RefNormalizePass, "nrtrefnormalizepass", "Normalize NRT refops", - false, false) + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) { + RefNormalize().runOnFunction(F); -INITIALIZE_PASS_BEGIN(RefPrunePass, "nrtrefprunepass", "Prune NRT refops", - false, false) + return PreservedAnalyses::all(); + } +}; + +class RefNormalizeLegacyPass : public FunctionPass { + public: + static char ID; + RefNormalizeLegacyPass() : FunctionPass(ID) { + initializeRefNormalizeLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + return RefNormalize().runOnFunction(F); + }; +}; + +class RefPruneLegacyPass : public FunctionPass { + + public: + static char ID; // Pass identification, replacement for typeid + // The maximum number of nodes that the fanout pruners will look at. + size_t subgraph_limit; + Subpasses flags; + RefPruneLegacyPass(Subpasses flags = Subpasses::All, + size_t subgraph_limit = -1) + : FunctionPass(ID), flags(flags), subgraph_limit(subgraph_limit) { + initializeRefPruneLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + + auto &PDT = + getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + + return RefPrune(DT, PDT, flags, subgraph_limit).runOnFunction(F); + }; + + /** + * getAnalysisUsage() LLVM plumbing for the pass + */ + void getAnalysisUsage(AnalysisUsage &Info) const override { + Info.addRequired<DominatorTreeWrapperPass>(); + Info.addRequired<PostDominatorTreeWrapperPass>(); + } +}; + +char RefNormalizeLegacyPass::ID = 0; +char RefPruneLegacyPass::ID = 0; + +size_t RefPrune::stats_per_bb = 0; +size_t RefPrune::stats_diamond = 0; +size_t RefPrune::stats_fanout = 0; +size_t RefPrune::stats_fanout_raise = 0; + +INITIALIZE_PASS(RefNormalizeLegacyPass, "nrtRefNormalize", + "Normalize NRT refops", false, false) + +INITIALIZE_PASS_BEGIN(RefPruneLegacyPass, "nrtRefPruneLegacyPass", + "Prune NRT refops", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_END(RefPrunePass, "refprunepass", "Prune NRT refops", false, - false) +INITIALIZE_PASS_END(RefPruneLegacyPass, "RefPruneLegacyPass", + "Prune NRT refops", false, false) + extern "C" { API_EXPORT(void) -LLVMPY_AddRefPrunePass(LLVMPassManagerRef PM, int subpasses, - size_t subgraph_limit) { - unwrap(PM)->add(new RefNormalizePass()); +LLVMPY_AddLegacyRefPrunePass(LLVMPassManagerRef PM, int subpasses, + size_t subgraph_limit) { + unwrap(PM)->add(new RefNormalizeLegacyPass()); unwrap(PM)->add( - new RefPrunePass((RefPrunePass::Subpasses)subpasses, subgraph_limit)); + new RefPruneLegacyPass((Subpasses)subpasses, subgraph_limit)); +} + +API_EXPORT(void) +LLVMPY_AddRefPrunePass_module(LLVMModulePassManagerRef MPM, int subpasses, + size_t subgraph_limit) { + llvm::unwrap(MPM)->addPass( + createModuleToFunctionPassAdaptor(RefNormalizePass())); + llvm::unwrap(MPM)->addPass(createModuleToFunctionPassAdaptor( + RefPrunePass((Subpasses)subpasses, subgraph_limit))); +} + +API_EXPORT(void) +LLVMPY_AddRefPrunePass_function(LLVMFunctionPassManagerRef FPM, int subpasses, + size_t subgraph_limit) { + llvm::unwrap(FPM)->addPass(RefNormalizePass()); + llvm::unwrap(FPM)->addPass( + RefPrunePass((Subpasses)subpasses, subgraph_limit)); } /** @@ -1207,24 +1301,22 @@ typedef struct PruneStats { API_EXPORT(void) LLVMPY_DumpRefPruneStats(PRUNESTATS *buf, bool do_print) { /* PRUNESTATS is updated with the statistics about what has been pruned from - * the RefPrunePass static state vars. This isn't threadsafe but neither is + * the RefPrune static state vars. This isn't threadsafe but neither is * the LLVM pass infrastructure so it's all done under a python thread lock. * * do_print if set will print the stats to stderr. 
*/ if (do_print) { - errs() << "refprune stats " - << "per-BB " << RefPrunePass::stats_per_bb << " " - << "diamond " << RefPrunePass::stats_diamond << " " - << "fanout " << RefPrunePass::stats_fanout << " " - << "fanout+raise " << RefPrunePass::stats_fanout_raise << " " - << "\n"; + errs() << "refprune stats " << "per-BB " << RefPrune::stats_per_bb + << " " << "diamond " << RefPrune::stats_diamond << " " + << "fanout " << RefPrune::stats_fanout << " " << "fanout+raise " + << RefPrune::stats_fanout_raise << " " << "\n"; }; - buf->basicblock = RefPrunePass::stats_per_bb; - buf->diamond = RefPrunePass::stats_diamond; - buf->fanout = RefPrunePass::stats_fanout; - buf->fanout_raise = RefPrunePass::stats_fanout_raise; + buf->basicblock = RefPrune::stats_per_bb; + buf->diamond = RefPrune::stats_diamond; + buf->fanout = RefPrune::stats_fanout; + buf->fanout_raise = RefPrune::stats_fanout_raise; } } // extern "C" diff --git a/llvmlite/binding/newpassmanagers.py b/llvmlite/binding/newpassmanagers.py index f1dd20cc7..def42721b 100644 --- a/llvmlite/binding/newpassmanagers.py +++ b/llvmlite/binding/newpassmanagers.py @@ -1,4 +1,5 @@ from ctypes import c_bool, c_int +from enum import IntFlag from llvmlite.binding import ffi @@ -18,6 +19,14 @@ def create_pipeline_tuning_options(speed_level=2, size_level=0): return PipelineTuningOptions(speed_level, size_level) +class RefPruneSubpasses(IntFlag): + PER_BB = 0b0001 # noqa: E221 + DIAMOND = 0b0010 # noqa: E221 + FANOUT = 0b0100 # noqa: E221 + FANOUT_RAISE = 0b1000 + ALL = PER_BB | DIAMOND | FANOUT | FANOUT_RAISE + + class ModulePassManager(ffi.ObjectRef): def __init__(self, ptr=None): @@ -52,6 +61,24 @@ def add_jump_threading_pass(self, threshold=-1): def _dispose(self): ffi.lib.LLVMPY_DisposeNewModulePassManger(self) + # Non-standard LLVM passes + def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL, + subgraph_limit=1000): + """Add Numba specific Reference count pruning pass. 
+ + Parameters + ---------- + subpasses_flags : RefPruneSubpasses + A bitmask to control the subpasses to be enabled. + subgraph_limit : int + Limit the fanout pruners to working on a subgraph no bigger than + this number of basic-blocks to avoid spending too much time in very + large graphs. Default is 1000. Subject to change in future + versions. + """ + iflags = RefPruneSubpasses(subpasses_flags) + ffi.lib.LLVMPY_AddRefPrunePass_module(self, iflags, subgraph_limit) + class FunctionPassManager(ffi.ObjectRef): @@ -84,6 +111,24 @@ def add_jump_threading_pass(self, threshold=-1): def _dispose(self): ffi.lib.LLVMPY_DisposeNewFunctionPassManger(self) + # Non-standard LLVM passes + def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL, + subgraph_limit=1000): + """Add Numba specific Reference count pruning pass. + + Parameters + ---------- + subpasses_flags : RefPruneSubpasses + A bitmask to control the subpasses to be enabled. + subgraph_limit : int + Limit the fanout pruners to working on a subgraph no bigger than + this number of basic-blocks to avoid spending too much time in very + large graphs. Default is 1000. Subject to change in future + versions. + """ + iflags = RefPruneSubpasses(subpasses_flags) + ffi.lib.LLVMPY_AddRefPrunePass_function(self, iflags, subgraph_limit) + class PipelineTuningOptions(ffi.ObjectRef): diff --git a/llvmlite/binding/passmanagers.py b/llvmlite/binding/passmanagers.py index 3fecbc721..0ce61aefc 100644 --- a/llvmlite/binding/passmanagers.py +++ b/llvmlite/binding/passmanagers.py @@ -672,7 +672,7 @@ def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL, versions. 
""" iflags = RefPruneSubpasses(subpasses_flags) - ffi.lib.LLVMPY_AddRefPrunePass(self, iflags, subgraph_limit) + ffi.lib.LLVMPY_AddLegacyRefPrunePass(self, iflags, subgraph_limit) class ModulePassManager(PassManager): @@ -940,7 +940,7 @@ def run_with_remarks(self, function, remarks_format='yaml', c_char_p] ffi.lib.LLVMPY_AddInstructionNamerPass.argtypes = [ffi.LLVMPassManagerRef] -ffi.lib.LLVMPY_AddRefPrunePass.argtypes = [ffi.LLVMPassManagerRef, c_int, - c_size_t] +ffi.lib.LLVMPY_AddLegacyRefPrunePass.argtypes = [ffi.LLVMPassManagerRef, c_int, + c_size_t] ffi.lib.LLVMPY_DumpRefPruneStats.argtypes = [POINTER(_c_PruneStats), c_bool] diff --git a/llvmlite/tests/test_binding.py b/llvmlite/tests/test_binding.py index 265682f2d..043b41f09 100644 --- a/llvmlite/tests/test_binding.py +++ b/llvmlite/tests/test_binding.py @@ -3054,6 +3054,7 @@ def test_add_passes(self): mpm.add_loop_rotate_pass() mpm.add_instruction_combine_pass() mpm.add_jump_threading_pass() + mpm.add_refprune_pass() class TestNewFunctionPassManager(BaseTest, NewPassManagerMixin): @@ -3094,6 +3095,7 @@ def test_add_passes(self): fpm.add_loop_rotate_pass() fpm.add_instruction_combine_pass() fpm.add_jump_threading_pass() + fpm.add_refprune_pass() if __name__ == "__main__": From 2d9f8e609d4cf1103e25981bb5af74c7b1096619 Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Wed, 10 Jul 2024 10:01:18 +0530 Subject: [PATCH 2/8] clang format --- ffi/custom_passes.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ffi/custom_passes.cpp b/ffi/custom_passes.cpp index 816569373..b675cd516 100644 --- a/ffi/custom_passes.cpp +++ b/ffi/custom_passes.cpp @@ -1299,7 +1299,8 @@ typedef struct PruneStats { } PRUNESTATS; API_EXPORT(void) -LLVMPY_DumpRefPruneStats(PRUNESTATS *buf, bool do_print) { +LLVMPY_DumpRefPruneStats(PRUNESTATS* buf, bool do_print) +{ /* PRUNESTATS is updated with the statistics about what has been pruned from * the RefPrune static state vars. 
This isn't threadsafe but neither is * the LLVM pass infrastructure so it's all done under a python thread lock. @@ -1307,10 +1308,14 @@ LLVMPY_DumpRefPruneStats(PRUNESTATS *buf, bool do_print) { * do_print if set will print the stats to stderr. */ if (do_print) { - errs() << "refprune stats " << "per-BB " << RefPrune::stats_per_bb - << " " << "diamond " << RefPrune::stats_diamond << " " - << "fanout " << RefPrune::stats_fanout << " " << "fanout+raise " - << RefPrune::stats_fanout_raise << " " << "\n"; + errs() << "refprune stats " + << "per-BB " << RefPrune::stats_per_bb + << " " + << "diamond " << RefPrune::stats_diamond << " " + << "fanout " << RefPrune::stats_fanout << " " + << "fanout+raise " + << RefPrune::stats_fanout_raise << " " + << "\n"; }; buf->basicblock = RefPrune::stats_per_bb; From 4554776ee7e4ba67a3982d478defe41e4163ae47 Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Thu, 18 Jul 2024 10:26:01 +0530 Subject: [PATCH 3/8] Revert "clang format" This reverts commit 2d9f8e609d4cf1103e25981bb5af74c7b1096619. --- ffi/custom_passes.cpp | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/ffi/custom_passes.cpp b/ffi/custom_passes.cpp index b675cd516..816569373 100644 --- a/ffi/custom_passes.cpp +++ b/ffi/custom_passes.cpp @@ -1299,8 +1299,7 @@ typedef struct PruneStats { } PRUNESTATS; API_EXPORT(void) -LLVMPY_DumpRefPruneStats(PRUNESTATS* buf, bool do_print) -{ +LLVMPY_DumpRefPruneStats(PRUNESTATS *buf, bool do_print) { /* PRUNESTATS is updated with the statistics about what has been pruned from * the RefPrune static state vars. This isn't threadsafe but neither is * the LLVM pass infrastructure so it's all done under a python thread lock. @@ -1308,14 +1307,10 @@ LLVMPY_DumpRefPruneStats(PRUNESTATS* buf, bool do_print) * do_print if set will print the stats to stderr. 
*/ if (do_print) { - errs() << "refprune stats " - << "per-BB " << RefPrune::stats_per_bb - << " " - << "diamond " << RefPrune::stats_diamond << " " - << "fanout " << RefPrune::stats_fanout << " " - << "fanout+raise " - << RefPrune::stats_fanout_raise << " " - << "\n"; + errs() << "refprune stats " << "per-BB " << RefPrune::stats_per_bb + << " " << "diamond " << RefPrune::stats_diamond << " " + << "fanout " << RefPrune::stats_fanout << " " << "fanout+raise " + << RefPrune::stats_fanout_raise << " " << "\n"; }; buf->basicblock = RefPrune::stats_per_bb; From 1e85a8b323bf60e9cedaea60508bce9b9054f317 Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Fri, 19 Jul 2024 12:07:01 +0530 Subject: [PATCH 4/8] Port refprune pass tests to npm --- llvmlite/tests/test_refprune.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/llvmlite/tests/test_refprune.py b/llvmlite/tests/test_refprune.py index 0c4208a26..e2b90c745 100644 --- a/llvmlite/tests/test_refprune.py +++ b/llvmlite/tests/test_refprune.py @@ -3,7 +3,7 @@ from llvmlite import binding as llvm from llvmlite.tests import TestCase -from . import refprune_proto as proto +import llvmlite.tests.refprune_proto as proto def _iterate_cases(generate_test): @@ -18,6 +18,15 @@ def wrapped(self): yield f'test_{k}', wrap(case_fn) +class PassManagerMixin(): + + def pb(self): + llvm.initialize_native_target() + tm = llvm.Target.from_default_triple().create_target_machine() + pto = llvm.create_pipeline_tuning_options(speed_level=0, size_level=0) + return llvm.create_pass_builder(tm, pto) + + class TestRefPrunePrototype(TestCase): """ Test that the prototype is working. @@ -35,7 +44,7 @@ def generate_test(self, case_gen): ptr_ty = ir.IntType(8).as_pointer() -class TestRefPrunePass(TestCase): +class TestRefPrunePass(TestCase, PassManagerMixin): """ Test that the C++ implementation matches the expected behavior as for the prototype. 
@@ -116,9 +125,10 @@ def generate_ir(self, nodes, edges): def apply_refprune(self, irmod): mod = llvm.parse_assembly(str(irmod)) - pm = llvm.ModulePassManager() + pb = self.pb() + pm = pb.getModulePassManager() pm.add_refprune_pass() - pm.run(mod) + pm.run(mod, pb) return mod def check(self, mod, expected, nodes): @@ -158,7 +168,7 @@ def generate_test(self, case_gen): locals()[name] = case -class BaseTestByIR(TestCase): +class BaseTestByIR(TestCase, PassManagerMixin): refprune_bitmask = 0 prologue = r""" @@ -168,14 +178,15 @@ class BaseTestByIR(TestCase): def check(self, irmod, subgraph_limit=None): mod = llvm.parse_assembly(f"{self.prologue}\n{irmod}") - pm = llvm.ModulePassManager() + pb = self.pb() + pm = pb.getModulePassManager() if subgraph_limit is None: pm.add_refprune_pass(self.refprune_bitmask) else: pm.add_refprune_pass(self.refprune_bitmask, subgraph_limit=subgraph_limit) before = llvm.dump_refprune_stats() - pm.run(mod) + pm.run(mod, pb) after = llvm.dump_refprune_stats() return mod, after - before From 2203726fc7610b049c2188ebb40f793c2a890804 Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Fri, 19 Jul 2024 12:11:13 +0530 Subject: [PATCH 5/8] clang-format 14 based formatting --- ffi/custom_passes.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ffi/custom_passes.cpp b/ffi/custom_passes.cpp index 816569373..a7906617b 100644 --- a/ffi/custom_passes.cpp +++ b/ffi/custom_passes.cpp @@ -1307,10 +1307,12 @@ LLVMPY_DumpRefPruneStats(PRUNESTATS *buf, bool do_print) { * do_print if set will print the stats to stderr. 
*/ if (do_print) { - errs() << "refprune stats " << "per-BB " << RefPrune::stats_per_bb - << " " << "diamond " << RefPrune::stats_diamond << " " - << "fanout " << RefPrune::stats_fanout << " " << "fanout+raise " - << RefPrune::stats_fanout_raise << " " << "\n"; + errs() << "refprune stats " + << "per-BB " << RefPrune::stats_per_bb << " " + << "diamond " << RefPrune::stats_diamond << " " + << "fanout " << RefPrune::stats_fanout << " " + << "fanout+raise " << RefPrune::stats_fanout_raise << " " + << "\n"; }; buf->basicblock = RefPrune::stats_per_bb; From a619eb829e6fc1604cabe0adf60d530b5b1677e6 Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Mon, 19 Aug 2024 11:41:26 +0530 Subject: [PATCH 6/8] Doc for add_refprune_pass() --- docs/source/user-guide/binding/optimization-passes.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/user-guide/binding/optimization-passes.rst b/docs/source/user-guide/binding/optimization-passes.rst index 4dd129e9a..a0a6f5eaa 100644 --- a/docs/source/user-guide/binding/optimization-passes.rst +++ b/docs/source/user-guide/binding/optimization-passes.rst @@ -169,6 +169,11 @@ The ``add_*`` methods supported by both pass manager classes are: Add the `Simplify CFG `_ pass. +.. method:: add_refprune_pass() + + Add the `Reference pruning + `_ pass. + .. currentmodule:: llvmlite.binding Legacy Pass Manager APIs @@ -301,6 +306,11 @@ create and configure a :class:`PassManagerBuilder`. See `instnamer pass documentation `_. + * .. function:: add_refprune_pass() + + Add the `Reference pruning + `_ pass. + .. 
class:: ModulePassManager() :no-index: From 5e5b5c61cf5f2dbf55bac5da1cfb21a0cb422bd4 Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Mon, 19 Aug 2024 12:14:47 +0530 Subject: [PATCH 7/8] Add back tests for legacy version of refprune_pass --- llvmlite/tests/test_refprune.py | 117 ++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/llvmlite/tests/test_refprune.py b/llvmlite/tests/test_refprune.py index e2b90c745..3fb93e9a4 100644 --- a/llvmlite/tests/test_refprune.py +++ b/llvmlite/tests/test_refprune.py @@ -5,6 +5,8 @@ import llvmlite.tests.refprune_proto as proto +# TODO:: Get rid of Legacy tests once completely transitioned to NewPassManager + def _iterate_cases(generate_test): def wrap(fn): @@ -131,6 +133,13 @@ def apply_refprune(self, irmod): pm.run(mod, pb) return mod + def apply_refprune_legacy(self, irmod): + mod = llvm.parse_assembly(str(irmod)) + pm = llvm.ModulePassManager() + pm.add_refprune_pass() + pm.run(mod) + return mod + def check(self, mod, expected, nodes): # preprocess incref/decref locations d = {} @@ -163,10 +172,19 @@ def generate_test(self, case_gen): outmod = self.apply_refprune(irmod) self.check(outmod, expected, nodes) + def generate_test_legacy(self, case_gen): + nodes, edges, expected = case_gen() + irmod = self.generate_ir(nodes, edges) + outmod = self.apply_refprune_legacy(irmod) + self.check(outmod, expected, nodes) + # Generate tests for name, case in _iterate_cases(generate_test): locals()[name] = case + for name, case in _iterate_cases(generate_test_legacy): + locals()[name + "_legacy"] = case + class BaseTestByIR(TestCase, PassManagerMixin): refprune_bitmask = 0 @@ -190,6 +208,19 @@ def check(self, irmod, subgraph_limit=None): after = llvm.dump_refprune_stats() return mod, after - before + def check_legacy(self, irmod, subgraph_limit=None): + mod = llvm.parse_assembly(f"{self.prologue}\n{irmod}") + pm = llvm.ModulePassManager() + if subgraph_limit is None: + 
pm.add_refprune_pass(self.refprune_bitmask) + else: + pm.add_refprune_pass(self.refprune_bitmask, + subgraph_limit=subgraph_limit) + before = llvm.dump_refprune_stats() + pm.run(mod) + after = llvm.dump_refprune_stats() + return mod, after - before + class TestPerBB(BaseTestByIR): refprune_bitmask = llvm.RefPruneSubpasses.PER_BB @@ -206,6 +237,10 @@ def test_per_bb_1(self): mod, stats = self.check(self.per_bb_ir_1) self.assertEqual(stats.basicblock, 2) + def test_per_bb_1_legacy(self): + mod, stats = self.check_legacy(self.per_bb_ir_1) + self.assertEqual(stats.basicblock, 2) + per_bb_ir_2 = r""" define void @main(i8* %ptr) { call void @NRT_incref(i8* %ptr) @@ -223,6 +258,12 @@ def test_per_bb_2(self): # not pruned self.assertIn("call void @NRT_incref(i8* %ptr)", str(mod)) + def test_per_bb_2_legacy(self): + mod, stats = self.check_legacy(self.per_bb_ir_2) + self.assertEqual(stats.basicblock, 4) + # not pruned + self.assertIn("call void @NRT_incref(i8* %ptr)", str(mod)) + per_bb_ir_3 = r""" define void @main(i8* %ptr, i8* %other) { call void @NRT_incref(i8* %ptr) @@ -239,6 +280,12 @@ def test_per_bb_3(self): # not pruned self.assertIn("call void @NRT_decref(i8* %other)", str(mod)) + def test_per_bb_3_legacy(self): + mod, stats = self.check_legacy(self.per_bb_ir_3) + self.assertEqual(stats.basicblock, 2) + # not pruned + self.assertIn("call void @NRT_decref(i8* %other)", str(mod)) + per_bb_ir_4 = r""" ; reordered define void @main(i8* %ptr, i8* %other) { @@ -257,6 +304,12 @@ def test_per_bb_4(self): # not pruned self.assertIn("call void @NRT_decref(i8* %other)", str(mod)) + def test_per_bb_4_legacy(self): + mod, stats = self.check_legacy(self.per_bb_ir_4) + self.assertEqual(stats.basicblock, 4) + # not pruned + self.assertIn("call void @NRT_decref(i8* %other)", str(mod)) + class TestDiamond(BaseTestByIR): refprune_bitmask = llvm.RefPruneSubpasses.DIAMOND @@ -276,6 +329,10 @@ def test_per_diamond_1(self): mod, stats = self.check(self.per_diamond_1) 
self.assertEqual(stats.diamond, 2) + def test_per_diamond_1_legacy(self): + mod, stats = self.check_legacy(self.per_diamond_1) + self.assertEqual(stats.diamond, 2) + per_diamond_2 = r""" define void @main(i8* %ptr, i1 %cond) { bb_A: @@ -295,6 +352,10 @@ def test_per_diamond_2(self): mod, stats = self.check(self.per_diamond_2) self.assertEqual(stats.diamond, 2) + def test_per_diamond_2_legacy(self): + mod, stats = self.check_legacy(self.per_diamond_2) + self.assertEqual(stats.diamond, 2) + per_diamond_3 = r""" define void @main(i8* %ptr, i1 %cond) { bb_A: @@ -315,6 +376,10 @@ def test_per_diamond_3(self): mod, stats = self.check(self.per_diamond_3) self.assertEqual(stats.diamond, 0) + def test_per_diamond_3_legacy(self): + mod, stats = self.check_legacy(self.per_diamond_3) + self.assertEqual(stats.diamond, 0) + per_diamond_4 = r""" define void @main(i8* %ptr, i1 %cond) { bb_A: @@ -335,6 +400,10 @@ def test_per_diamond_4(self): mod, stats = self.check(self.per_diamond_4) self.assertEqual(stats.diamond, 2) + def test_per_diamond_4_legacy(self): + mod, stats = self.check_legacy(self.per_diamond_4) + self.assertEqual(stats.diamond, 2) + per_diamond_5 = r""" define void @main(i8* %ptr, i1 %cond) { bb_A: @@ -356,6 +425,10 @@ def test_per_diamond_5(self): mod, stats = self.check(self.per_diamond_5) self.assertEqual(stats.diamond, 4) + def test_per_diamond_5_legacy(self): + mod, stats = self.check_legacy(self.per_diamond_5) + self.assertEqual(stats.diamond, 4) + class TestFanout(BaseTestByIR): """More complex cases are tested in TestRefPrunePass @@ -381,6 +454,10 @@ def test_fanout_1(self): mod, stats = self.check(self.fanout_1) self.assertEqual(stats.fanout, 3) + def test_fanout_1_legacy(self): + mod, stats = self.check_legacy(self.fanout_1) + self.assertEqual(stats.fanout, 3) + fanout_2 = r""" define void @main(i8* %ptr, i1 %cond, i8** %excinfo) { bb_A: @@ -399,6 +476,10 @@ def test_fanout_2(self): mod, stats = self.check(self.fanout_2) self.assertEqual(stats.fanout, 0) + 
def test_fanout_2_legacy(self): + mod, stats = self.check_legacy(self.fanout_2) + self.assertEqual(stats.fanout, 0) + fanout_3 = r""" define void @main(i8* %ptr, i1 %cond) { bb_A: @@ -427,6 +508,16 @@ def test_fanout_3_limited(self): mod, stats = self.check(self.fanout_3, subgraph_limit=1) self.assertEqual(stats.fanout, 0) + def test_fanout_3_legacy(self): + mod, stats = self.check_legacy(self.fanout_3) + self.assertEqual(stats.fanout, 6) + + def test_fanout_3_limited_legacy(self): + # With subgraph limit at 1, it is essentially turning off the fanout + # pruner. + mod, stats = self.check_legacy(self.fanout_3, subgraph_limit=1) + self.assertEqual(stats.fanout, 0) + class TestFanoutRaise(BaseTestByIR): refprune_bitmask = llvm.RefPruneSubpasses.FANOUT_RAISE @@ -450,6 +541,10 @@ def test_fanout_raise_1(self): mod, stats = self.check(self.fanout_raise_1) self.assertEqual(stats.fanout_raise, 2) + def test_fanout_raise_1_legacy(self): + mod, stats = self.check_legacy(self.fanout_raise_1) + self.assertEqual(stats.fanout_raise, 2) + fanout_raise_2 = r""" define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) { bb_A: @@ -472,6 +567,12 @@ def test_fanout_raise_2(self): mod, stats = self.check(self.fanout_raise_2) self.assertEqual(stats.fanout_raise, 0) + def test_fanout_raise_2_legacy(self): + # This is ensuring that fanout_raise is not pruning when the metadata + # is incorrectly named. 
+ mod, stats = self.check_legacy(self.fanout_raise_2) + self.assertEqual(stats.fanout_raise, 0) + fanout_raise_3 = r""" define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) { bb_A: @@ -492,6 +593,10 @@ def test_fanout_raise_3(self): mod, stats = self.check(self.fanout_raise_3) self.assertEqual(stats.fanout_raise, 2) + def test_fanout_raise_3_legacy(self): + mod, stats = self.check_legacy(self.fanout_raise_3) + self.assertEqual(stats.fanout_raise, 2) + fanout_raise_4 = r""" define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) { bb_A: @@ -510,6 +615,10 @@ def test_fanout_raise_4(self): mod, stats = self.check(self.fanout_raise_4) self.assertEqual(stats.fanout_raise, 0) + def test_fanout_raise_4_legacy(self): + mod, stats = self.check_legacy(self.fanout_raise_4) + self.assertEqual(stats.fanout_raise, 0) + fanout_raise_5 = r""" define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) { bb_A: @@ -532,6 +641,10 @@ def test_fanout_raise_5(self): mod, stats = self.check(self.fanout_raise_5) self.assertEqual(stats.fanout_raise, 2) + def test_fanout_raise_5_legacy(self): + mod, stats = self.check_legacy(self.fanout_raise_5) + self.assertEqual(stats.fanout_raise, 2) + # test case 6 is from https://github.com/numba/llvmlite/issues/1023 fanout_raise_6 = r""" define i32 @main(i8* %ptr, i1 %cond1, i1 %cond2, i1 %cond3, i8** %excinfo) { @@ -563,6 +676,10 @@ def test_fanout_raise_6(self): mod, stats = self.check(self.fanout_raise_6) self.assertEqual(stats.fanout_raise, 7) + def test_fanout_raise_6_legacy(self): + mod, stats = self.check_legacy(self.fanout_raise_6) + self.assertEqual(stats.fanout_raise, 7) + if __name__ == '__main__': unittest.main() From 0b99743e914f5810ed382e1a4bdcb9fc65b1ce57 Mon Sep 17 00:00:00 2001 From: Yashwant Singh Date: Thu, 19 Sep 2024 11:25:44 +0530 Subject: [PATCH 8/8] Fix corrupted test after merge --- llvmlite/tests/test_refprune.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/llvmlite/tests/test_refprune.py 
b/llvmlite/tests/test_refprune.py index a8dff2940..299852001 100644 --- a/llvmlite/tests/test_refprune.py +++ b/llvmlite/tests/test_refprune.py @@ -349,7 +349,11 @@ def test_per_bb_4_legacy(self): mod, stats = self.check_legacy(self.per_bb_ir_4) self.assertEqual(stats.basicblock, 4) # not pruned - self.assertIn("call void @NRT_decref(i8* %other)", str(mod)) + # FIXME: Remove `else' once TP are no longer supported. + if opaque_pointers_enabled: + self.assertIn("call void @NRT_decref(ptr %other)", str(mod)) + else: + self.assertIn("call void @NRT_decref(i8* %other)", str(mod)) class TestDiamond(BaseTestByIR):