Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port refprune pass to NewPassManager #1057

Merged
merged 10 commits into from
Sep 24, 2024
220 changes: 157 additions & 63 deletions ffi/custom_passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,22 @@
using namespace llvm;

namespace llvm {
void initializeRefNormalizePassPass(PassRegistry &Registry);
void initializeRefPrunePassPass(PassRegistry &Registry);
void initializeRefNormalizeLegacyPassPass(PassRegistry &Registry);
void initializeRefPruneLegacyPassPass(PassRegistry &Registry);
Comment on lines +29 to +30
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename the passes to include "Legacy" in their names so that they are used with the legacy pass manager.

} // namespace llvm

namespace llvm {
struct OpaqueModulePassManager;
typedef OpaqueModulePassManager *LLVMModulePassManagerRef;
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ModulePassManager, LLVMModulePassManagerRef)

struct OpaqueFunctionPassManager;
typedef OpaqueFunctionPassManager *LLVMFunctionPassManagerRef;
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(FunctionPassManager,
LLVMFunctionPassManagerRef)
} // namespace llvm

namespace {
/**
* Checks if a call instruction is an incref
*
Expand Down Expand Up @@ -104,13 +116,9 @@ template <class Tstack> struct raiiStack {
* A FunctionPass to reorder incref/decref instructions such that decrefs occur
* logically after increfs. This is a pre-requisite pass to the pruner passes.
*/
struct RefNormalizePass : public FunctionPass {
static char ID;
RefNormalizePass() : FunctionPass(ID) {
initializeRefNormalizePassPass(*PassRegistry::getPassRegistry());
}
struct RefNormalize {

bool runOnFunction(Function &F) override {
bool runOnFunction(Function &F) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core of the RefNormalize pass; we will call this function from both the legacy and the new versions of the RefNormalize pass.

bool mutated = false;
// For each basic block in F
for (BasicBlock &bb : F) {
Expand Down Expand Up @@ -158,7 +166,16 @@ struct RefNormalizePass : public FunctionPass {
}
};

struct RefPrunePass : public FunctionPass {
/**
 * Bit flags selecting which refprune subpasses to run. Every subpass owns a
 * single bit, so values can be OR-ed together; there is no interdependence
 * between the subpasses. `All` enables every subpass.
 */
enum Subpasses {
    None = 0,
    PerBasicBlock = 1 << 0,
    Diamond = 1 << 1,
    Fanout = 1 << 2,
    FanoutRaise = 1 << 3,
    All = PerBasicBlock | Diamond | Fanout | FanoutRaise
};

struct RefPrune {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have moved all the computation related functions from inside struct RefPrunePass : public FunctionPass to struct RefPrune so that we can re-use these functions from both the Legacy pass and the new pass.

static char ID;
static size_t stats_per_bb;
static size_t stats_diamond;
Expand All @@ -175,25 +192,21 @@ struct RefPrunePass : public FunctionPass {
/**
* Enum for setting which subpasses to run, there is no interdependence.
*/
enum Subpasses {
None = 0b0000,
PerBasicBlock = 0b0001,
Diamond = 0b0010,
Fanout = 0b0100,
FanoutRaise = 0b1000,
All = PerBasicBlock | Diamond | Fanout | FanoutRaise
} flags;

RefPrunePass(Subpasses flags = Subpasses::All, size_t subgraph_limit = -1)
: FunctionPass(ID), flags(flags), subgraph_limit(subgraph_limit) {
initializeRefPrunePassPass(*PassRegistry::getPassRegistry());
}
Subpasses flags;

DominatorTree &DT;
PostDominatorTree &PDT;

RefPrune(DominatorTree &DT, PostDominatorTree &PDT,
Subpasses flags = Subpasses::All, size_t subgraph_limit = -1)
: DT(DT), PDT(PDT), flags(flags), subgraph_limit(subgraph_limit) {}

/**
 * Returns true when every bit set in `expected` is also set in `flags`.
 */
bool isSubpassEnabledFor(Subpasses expected) {
    const auto enabled_bits = flags & expected;
    return enabled_bits == expected;
}

bool runOnFunction(Function &F) override {
bool runOnFunction(Function &F) {
// state for LLVM function pass mutated IR
bool mutated = false;

Expand Down Expand Up @@ -361,11 +374,6 @@ struct RefPrunePass : public FunctionPass {
*/
bool runDiamondPrune(Function &F) {
bool mutated = false;
// gets the dominator tree
auto &domtree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
// gets the post-dominator tree
auto &postdomtree =
getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();

// Find all increfs and decrefs in the Function and store them in
// incref_list and decref_list respectively.
Expand Down Expand Up @@ -394,8 +402,8 @@ struct RefPrunePass : public FunctionPass {
continue;

// incref DOM decref && decref POSTDOM incref
if (domtree.dominates(incref, decref) &&
postdomtree.dominates(decref, incref)) {
if (DT.dominates(incref, decref) &&
PDT.dominates(decref, incref)) {
// check that the decref cannot be executed multiple times
SmallBBSet tail_nodes;
tail_nodes.insert(decref->getParent());
Expand Down Expand Up @@ -1028,14 +1036,6 @@ struct RefPrunePass : public FunctionPass {
return NULL;
}

/**
* getAnalysisUsage() LLVM plumbing for the pass
*/
void getAnalysisUsage(AnalysisUsage &Info) const override {
Info.addRequired<DominatorTreeWrapperPass>();
Info.addRequired<PostDominatorTreeWrapperPass>();
}

/**
* Checks if the first argument to the supplied call_inst is NULL and
* returns true if so, false otherwise.
Expand Down Expand Up @@ -1163,34 +1163,128 @@ struct RefPrunePass : public FunctionPass {
}
}
}
}; // end of struct RefPrunePass
}; // end of struct RefPrune

} // namespace

/**
 * New pass manager wrapper around RefPrune.
 *
 * The dominator and post-dominator trees are requested from the function
 * analysis manager and handed to RefPrune, which does the actual work.
 */
class RefPrunePass : public PassInfoMixin<RefPrunePass> {

  public:
    Subpasses flags;
    size_t subgraph_limit;

    RefPrunePass(Subpasses flags = Subpasses::All, size_t subgraph_limit = -1)
        : flags(flags), subgraph_limit(subgraph_limit) {}

    /**
     * Runs the pruner on F. Reports that no analyses are preserved when the
     * pruner changed the function, and that all are preserved otherwise.
     */
    PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
        auto &dom_tree = AM.getResult<DominatorTreeAnalysis>(F);
        auto &post_dom_tree = AM.getResult<PostDominatorTreeAnalysis>(F);

        RefPrune pruner(dom_tree, post_dom_tree, flags, subgraph_limit);
        const bool mutated = pruner.runOnFunction(F);

        return mutated ? PreservedAnalyses::none() : PreservedAnalyses::all();
    }
};
Comment on lines +1170 to +1187
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have renamed the legacy pass with the "Legacy" keyword, we are free to create another pass with the name RefPrunePass; we will call this pass from the NewPassManager. Passes for the new pass manager follow the template mixin inheritance syntax that is visible here. There is also a difference in how we get the analysis results needed for the pass to run: instead of explicitly specifying dependencies in the getAnalysisUsage function, we can directly query the NPM analysis manager for the result.

Note that wrapped in all this new pass manager syntax we are calling runOnFunction(F) which actually does the computation for this pass and is also used by legacy version of this pass.


char RefNormalizePass::ID = 0;
char RefPrunePass::ID = 0;
class RefNormalizePass : public PassInfoMixin<RefNormalizePass> {

size_t RefPrunePass::stats_per_bb = 0;
size_t RefPrunePass::stats_diamond = 0;
size_t RefPrunePass::stats_fanout = 0;
size_t RefPrunePass::stats_fanout_raise = 0;
public:
RefNormalizePass() = default;

INITIALIZE_PASS(RefNormalizePass, "nrtrefnormalizepass", "Normalize NRT refops",
false, false)
/**
 * Runs reference-op normalization on F under the new pass manager.
 *
 * RefNormalize::runOnFunction() reports whether it changed the function.
 * That result must be forwarded: returning PreservedAnalyses::all() after
 * the IR was modified would tell the pass manager that cached analyses are
 * still valid. When the function was changed no analyses are claimed to be
 * preserved, matching what RefPrunePass::run() does.
 */
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
    if (RefNormalize().runOnFunction(F)) {
        return PreservedAnalyses::none();
    }

    return PreservedAnalyses::all();
}
};

class RefNormalizeLegacyPass : public FunctionPass {
public:
static char ID;
RefNormalizeLegacyPass() : FunctionPass(ID) {
initializeRefNormalizeLegacyPassPass(*PassRegistry::getPassRegistry());
}

bool runOnFunction(Function &F) override {
return RefNormalize().runOnFunction(F);
};
};

/**
 * Legacy pass manager wrapper around RefPrune.
 *
 * The dominator and post-dominator trees are obtained through the legacy
 * analysis plumbing (see getAnalysisUsage()) and handed to RefPrune, which
 * does the actual work.
 */
class RefPruneLegacyPass : public FunctionPass {

  public:
    static char ID; // Pass identification, replacement for typeid
    // The maximum number of nodes that the fanout pruners will look at.
    size_t subgraph_limit;
    // Bitmask of the subpasses to run.
    Subpasses flags;

    // Members are initialized in declaration order (subgraph_limit, then
    // flags); the initializer list is written in that same order so it
    // matches what actually happens and does not trigger -Wreorder.
    RefPruneLegacyPass(Subpasses flags = Subpasses::All,
                       size_t subgraph_limit = -1)
        : FunctionPass(ID), subgraph_limit(subgraph_limit), flags(flags) {
        initializeRefPruneLegacyPassPass(*PassRegistry::getPassRegistry());
    }

    /**
     * Runs the pruner on F and returns whether it changed the function.
     */
    bool runOnFunction(Function &F) override {
        auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();

        auto &PDT =
            getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();

        return RefPrune(DT, PDT, flags, subgraph_limit).runOnFunction(F);
    }

    /**
     * getAnalysisUsage() LLVM plumbing for the pass
     */
    void getAnalysisUsage(AnalysisUsage &Info) const override {
        Info.addRequired<DominatorTreeWrapperPass>();
        Info.addRequired<PostDominatorTreeWrapperPass>();
    }
};

char RefNormalizeLegacyPass::ID = 0;
char RefPruneLegacyPass::ID = 0;

size_t RefPrune::stats_per_bb = 0;
size_t RefPrune::stats_diamond = 0;
size_t RefPrune::stats_fanout = 0;
size_t RefPrune::stats_fanout_raise = 0;

INITIALIZE_PASS(RefNormalizeLegacyPass, "nrtRefNormalize",
"Normalize NRT refops", false, false)

INITIALIZE_PASS_BEGIN(RefPruneLegacyPass, "nrtRefPruneLegacyPass",
"Prune NRT refops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)

INITIALIZE_PASS_END(RefPrunePass, "refprunepass", "Prune NRT refops", false,
false)
INITIALIZE_PASS_END(RefPruneLegacyPass, "RefPruneLegacyPass",
"Prune NRT refops", false, false)

extern "C" {

API_EXPORT(void)
LLVMPY_AddRefPrunePass(LLVMPassManagerRef PM, int subpasses,
size_t subgraph_limit) {
unwrap(PM)->add(new RefNormalizePass());
LLVMPY_AddLegacyRefPrunePass(LLVMPassManagerRef PM, int subpasses,
size_t subgraph_limit) {
unwrap(PM)->add(new RefNormalizeLegacyPass());
unwrap(PM)->add(
new RefPrunePass((RefPrunePass::Subpasses)subpasses, subgraph_limit));
new RefPruneLegacyPass((Subpasses)subpasses, subgraph_limit));
}

/**
 * Adds the reference-op normalization pass followed by the refprune pass to
 * a new-pass-manager module pipeline, each wrapped in a module-to-function
 * adaptor. `subpasses` is the Subpasses bitmask passed as an int.
 */
API_EXPORT(void)
LLVMPY_AddRefPrunePass_module(LLVMModulePassManagerRef MPM, int subpasses,
                              size_t subgraph_limit) {
    auto *manager = llvm::unwrap(MPM);
    const Subpasses enabled = static_cast<Subpasses>(subpasses);

    manager->addPass(createModuleToFunctionPassAdaptor(RefNormalizePass()));
    manager->addPass(createModuleToFunctionPassAdaptor(
        RefPrunePass(enabled, subgraph_limit)));
}

API_EXPORT(void)
LLVMPY_AddRefPrunePass_function(LLVMFunctionPassManagerRef FPM, int subpasses,
size_t subgraph_limit) {
llvm::unwrap(FPM)->addPass(RefNormalizePass());
llvm::unwrap(FPM)->addPass(
RefPrunePass((Subpasses)subpasses, subgraph_limit));
Comment on lines +1273 to +1287
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API to register new passes in MPM and FPM respectively.

}

/**
Expand All @@ -1207,24 +1301,24 @@ typedef struct PruneStats {
API_EXPORT(void)
LLVMPY_DumpRefPruneStats(PRUNESTATS *buf, bool do_print) {
/* PRUNESTATS is updated with the statistics about what has been pruned from
* the RefPrunePass static state vars. This isn't threadsafe but neither is
* the RefPrune static state vars. This isn't threadsafe but neither is
* the LLVM pass infrastructure so it's all done under a python thread lock.
*
* do_print if set will print the stats to stderr.
*/
if (do_print) {
errs() << "refprune stats "
<< "per-BB " << RefPrunePass::stats_per_bb << " "
<< "diamond " << RefPrunePass::stats_diamond << " "
<< "fanout " << RefPrunePass::stats_fanout << " "
<< "fanout+raise " << RefPrunePass::stats_fanout_raise << " "
<< "per-BB " << RefPrune::stats_per_bb << " "
<< "diamond " << RefPrune::stats_diamond << " "
<< "fanout " << RefPrune::stats_fanout << " "
<< "fanout+raise " << RefPrune::stats_fanout_raise << " "
<< "\n";
};

buf->basicblock = RefPrunePass::stats_per_bb;
buf->diamond = RefPrunePass::stats_diamond;
buf->fanout = RefPrunePass::stats_fanout;
buf->fanout_raise = RefPrunePass::stats_fanout_raise;
buf->basicblock = RefPrune::stats_per_bb;
buf->diamond = RefPrune::stats_diamond;
buf->fanout = RefPrune::stats_fanout;
buf->fanout_raise = RefPrune::stats_fanout_raise;
}

} // extern "C"
45 changes: 45 additions & 0 deletions llvmlite/binding/newpassmanagers.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copied the APIs from passmanagers.py

Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ctypes import c_bool, c_int
from enum import IntFlag
from llvmlite.binding import ffi


Expand All @@ -18,6 +19,14 @@ def create_pipeline_tuning_options(speed_level=2, size_level=0):
return PipelineTuningOptions(speed_level, size_level)


# Bitmask flags selecting which refprune subpasses to enable. The bit values
# match the C++ `Subpasses` enum in ffi/custom_passes.cpp (PerBasicBlock,
# Diamond, Fanout, FanoutRaise); ALL is the OR of every individual flag.
class RefPruneSubpasses(IntFlag):
PER_BB = 0b0001 # noqa: E221
DIAMOND = 0b0010 # noqa: E221
FANOUT = 0b0100 # noqa: E221
FANOUT_RAISE = 0b1000
ALL = PER_BB | DIAMOND | FANOUT | FANOUT_RAISE


class ModulePassManager(ffi.ObjectRef):

def __init__(self, ptr=None):
Expand Down Expand Up @@ -52,6 +61,24 @@ def add_jump_threading_pass(self, threshold=-1):
def _dispose(self):
ffi.lib.LLVMPY_DisposeNewModulePassManger(self)

# Non-standard LLVM passes
def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL,
subgraph_limit=1000):
"""Add Numba specific Reference count pruning pass.

Parameters
----------
subpasses_flags : RefPruneSubpasses
A bitmask to control the subpasses to be enabled.
subgraph_limit : int
Limit the fanout pruners to working on a subgraph no bigger than
this number of basic-blocks to avoid spending too much time in very
large graphs. Default is 1000. Subject to change in future
versions.
"""
# Convert the argument to a RefPruneSubpasses value before passing it to
# the C API entry point for the module pass manager.
iflags = RefPruneSubpasses(subpasses_flags)
ffi.lib.LLVMPY_AddRefPrunePass_module(self, iflags, subgraph_limit)


class FunctionPassManager(ffi.ObjectRef):

Expand Down Expand Up @@ -84,6 +111,24 @@ def add_jump_threading_pass(self, threshold=-1):
def _dispose(self):
ffi.lib.LLVMPY_DisposeNewFunctionPassManger(self)

# Non-standard LLVM passes
def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL,
subgraph_limit=1000):
"""Add Numba specific Reference count pruning pass.

Parameters
----------
subpasses_flags : RefPruneSubpasses
A bitmask to control the subpasses to be enabled.
subgraph_limit : int
Limit the fanout pruners to working on a subgraph no bigger than
this number of basic-blocks to avoid spending too much time in very
large graphs. Default is 1000. Subject to change in future
versions.
"""
# Convert the argument to a RefPruneSubpasses value before passing it to
# the C API entry point for the function pass manager.
iflags = RefPruneSubpasses(subpasses_flags)
ffi.lib.LLVMPY_AddRefPrunePass_function(self, iflags, subgraph_limit)


class PipelineTuningOptions(ffi.ObjectRef):

Expand Down
Loading