Skip to content

Commit

Permalink
Merge pull request numba#1069 from yashssh/pic-callback-bug
Browse files Browse the repository at this point in the history
Add instrumentation callback hook for new pass managers
  • Loading branch information
kc611 authored Jul 24, 2024
2 parents b15213f + 664d2a2 commit 78ebf9b
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 21 deletions.
47 changes: 42 additions & 5 deletions ffi/newpassmanagers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar/JumpThreading.h"
#include "llvm/Transforms/Scalar/LoopRotation.h"
Expand Down Expand Up @@ -53,16 +54,31 @@ LLVMPY_CreateNewModulePassManager() {

API_EXPORT(void)
LLVMPY_RunNewModulePassManager(LLVMModulePassManagerRef MPMRef,
LLVMPassBuilderRef PBRef, LLVMModuleRef mod) {
LLVMModuleRef mod, LLVMPassBuilderRef PBRef) {

ModulePassManager *MPM = llvm::unwrap(MPMRef);
PassBuilder *PB = llvm::unwrap(PBRef);
Module *M = llvm::unwrap(mod);
PassBuilder *PB = llvm::unwrap(PBRef);

// TODO: Make these settable by the user
bool DebugLogging = false;
bool VerifyEach = false;

LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;

PrintPassOptions PrintPassOpts;

#if LLVM_VERSION_MAJOR < 16
StandardInstrumentations SI(DebugLogging, VerifyEach, PrintPassOpts);
#else
StandardInstrumentations SI(M->getContext(), DebugLogging, VerifyEach,
PrintPassOpts);
#endif
SI.registerCallbacks(*PB->getPassInstrumentationCallbacks(), &FAM);

PB->registerLoopAnalyses(LAM);
PB->registerFunctionAnalyses(FAM);
PB->registerCGSCCAnalyses(CGAM);
Expand Down Expand Up @@ -126,20 +142,36 @@ LLVMPY_CreateNewFunctionPassManager() {

API_EXPORT(void)
LLVMPY_RunNewFunctionPassManager(LLVMFunctionPassManagerRef FPMRef,
LLVMPassBuilderRef PBRef, LLVMValueRef FRef) {
LLVMValueRef FRef, LLVMPassBuilderRef PBRef) {

FunctionPassManager *FPM = llvm::unwrap(FPMRef);
PassBuilder *PB = llvm::unwrap(PBRef);
Function *F = reinterpret_cast<Function *>(FRef);
PassBuilder *PB = llvm::unwrap(PBRef);

// Don't try to optimize function declarations
if (F->isDeclaration())
return;

// TODO: Make these settable by the user
bool DebugLogging = false;
bool VerifyEach = false;

LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;

// TODO: Can expose this in ffi layer
PrintPassOptions PrintPassOpts;

#if LLVM_VERSION_MAJOR < 16
StandardInstrumentations SI(DebugLogging, VerifyEach, PrintPassOpts);
#else
StandardInstrumentations SI(F->getContext(), DebugLogging, VerifyEach,
PrintPassOpts);
#endif
SI.registerCallbacks(*PB->getPassInstrumentationCallbacks(), &FAM);

PB->registerLoopAnalyses(LAM);
PB->registerFunctionAnalyses(FAM);
PB->registerCGSCCAnalyses(CGAM);
Expand Down Expand Up @@ -254,7 +286,12 @@ LLVMPY_CreatePassBuilder(LLVMTargetMachineRef TM,
LLVMPipelineTuningOptionsRef PTO) {
TargetMachine *target = llvm::unwrap(TM);
PipelineTuningOptions *pt = llvm::unwrap(PTO);
return llvm::wrap(new PassBuilder(target, *pt));
PassInstrumentationCallbacks *PIC = new PassInstrumentationCallbacks();
#if LLVM_VERSION_MAJOR < 16
return llvm::wrap(new PassBuilder(target, *pt, None, PIC));
#else
return llvm::wrap(new PassBuilder(target, *pt, std::nullopt, PIC));
#endif
}

API_EXPORT(void)
Expand Down
15 changes: 7 additions & 8 deletions llvmlite/binding/newpassmanagers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, ptr=None):
super().__init__(ptr)

def run(self, module, pb):
ffi.lib.LLVMPY_RunNewModulePassManager(self, pb, module)
ffi.lib.LLVMPY_RunNewModulePassManager(self, module, pb)

def add_verifier(self):
ffi.lib.LLVMPY_AddVerifierPass(self)
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(self, ptr=None):
super().__init__(ptr)

def run(self, fun, pb):
ffi.lib.LLVMPY_RunNewFunctionPassManager(self, pb, fun)
ffi.lib.LLVMPY_RunNewFunctionPassManager(self, fun, pb)

def add_aa_eval_pass(self):
ffi.lib.LLVMPY_AddAAEvalPass_function(self)
Expand Down Expand Up @@ -193,9 +193,9 @@ def _dispose(self):

ffi.lib.LLVMPY_CreateNewModulePassManager.restype = ffi.LLVMModulePassManagerRef

ffi.lib.LLVMPY_RunNewModulePassManager.argtypes = [ffi.LLVMModulePassManagerRef,
ffi.LLVMPassBuilderRef,
ffi.LLVMModuleRef,]
ffi.lib.LLVMPY_RunNewModulePassManager.argtypes = [
ffi.LLVMModulePassManagerRef, ffi.LLVMModuleRef,
ffi.LLVMPassBuilderRef,]

ffi.lib.LLVMPY_AddVerifierPass.argtypes = [ffi.LLVMModulePassManagerRef,]
ffi.lib.LLVMPY_AddAAEvalPass_module.argtypes = [ffi.LLVMModulePassManagerRef,]
Expand Down Expand Up @@ -223,9 +223,8 @@ def _dispose(self):
ffi.LLVMFunctionPassManagerRef

ffi.lib.LLVMPY_RunNewFunctionPassManager.argtypes = [
ffi.LLVMFunctionPassManagerRef,
ffi.LLVMPassBuilderRef,
ffi.LLVMValueRef,]
ffi.LLVMFunctionPassManagerRef, ffi.LLVMValueRef,
ffi.LLVMPassBuilderRef,]

ffi.lib.LLVMPY_AddAAEvalPass_function.argtypes = [
ffi.LLVMFunctionPassManagerRef,]
Expand Down
89 changes: 81 additions & 8 deletions llvmlite/tests/test_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,23 @@ def no_de_locale():
declare i8* @a_arg0_return_func(i8* returned, i32*)
"""

# LLVM IR text for a function `foo` declared with the `optnone noinline`
# attributes. Its body routes both arguments through `alloca` slots, so a
# test can tell whether the function was optimized by checking whether the
# string "alloca" still appears in the printed IR after running a pipeline.
asm_alloca_optnone = r"""
define double @foo(i32 %i, double %j) optnone noinline {
%I = alloca i32 ; <i32*> [#uses=4]
%J = alloca double ; <double*> [#uses=2]
store i32 %i, i32* %I
store double %j, double* %J
%t1 = load i32, i32* %I ; <i32> [#uses=1]
%t2 = add i32 %t1, 1 ; <i32> [#uses=1]
store i32 %t2, i32* %I
%t3 = load i32, i32* %I ; <i32> [#uses=1]
%t4 = sitofp i32 %t3 to double ; <double> [#uses=1]
%t5 = load double, double* %J ; <double> [#uses=1]
%t6 = fmul double %t4, %t5 ; <double> [#uses=1]
ret double %t6
}
"""

asm_declaration = r"""
declare void @test_declare(i32* )
"""
Expand Down Expand Up @@ -3023,20 +3040,29 @@ class TestNewModulePassManager(BaseTest, NewPassManagerMixin):
def pm(self):
    """Return a newly created module pass manager."""
    manager = llvm.create_new_module_pass_manager()
    return manager

def test_close(self):
mpm = self.pm()
mpm.close()

def test_run(self):
pb = self.pb(speed_level=3, size_level=0)
def run_o_n(self, level):
    """Run the module pipeline for speed level *level* on the test module.

    Returns a ``(before, after)`` tuple with the string form of the module
    taken before and after the pass manager has run.
    """
    module = self.module()
    before = str(module)
    builder = self.pb(speed_level=level, size_level=0)
    builder.getModulePassManager().run(module, builder)
    return before, str(module)

def test_close(self):
    """Closing a freshly created pass manager must not raise."""
    self.pm().close()

def test_run_o3(self):
    """At speed level 3, ``%.4`` is present before the run and gone after."""
    before, after = self.run_o_n(3)
    self.assertIn("%.4", before)
    self.assertNotIn("%.4", after)

def test_run_o0(self):
    """At speed level 0, ``%.4`` is present both before and after the run."""
    before, after = self.run_o_n(0)
    self.assertIn("%.4", before)
    self.assertIn("%.4", after)

def test_instcombine(self):
pb = self.pb()
mpm = self.pm()
Expand All @@ -3048,6 +3074,25 @@ def test_instcombine(self):
self.assertIn("%.3", orig_asm)
self.assertNotIn("%.3", optimized_asm)

def test_optnone(self):
    """Check that `optnone` stops the speed-level-3 pipeline from
    removing the allocas in ``asm_alloca_optnone``."""
    builder = self.pb(speed_level=3, size_level=0)

    # Baseline: with the attribute stripped, the allocas get optimized away
    plain_ir = str(asm_alloca_optnone.replace("optnone ", ""))
    plain_mod = llvm.parse_assembly(plain_ir)
    builder.getModulePassManager().run(plain_mod, builder)
    self.assertIn("alloca", plain_ir)
    self.assertNotIn("alloca", str(plain_mod))

    # Module shouldn't be optimized if the function has `optnone` attached
    optnone_ir = str(asm_alloca_optnone)
    optnone_pm = builder.getModulePassManager()
    optnone_mod = llvm.parse_assembly(optnone_ir)
    optnone_pm.run(optnone_mod, builder)
    self.assertIn("alloca", optnone_ir)
    self.assertIn("alloca", str(optnone_mod))

def test_add_passes(self):
mpm = self.pm()
mpm.add_verifier()
Expand All @@ -3067,17 +3112,45 @@ def test_close(self):
fpm = self.pm()
fpm.close()

def test_run(self):
pb = self.pb(3)
def run_o_n(self, level):
    """Run the function pipeline for speed level *level* on ``sum``.

    Returns a ``(before, after)`` tuple with the string form of the
    function taken before and after the pass manager has run.
    """
    module = self.module()
    target = module.get_function("sum")
    before = str(target)
    builder = self.pb(speed_level=level, size_level=0)
    builder.getFunctionPassManager().run(target, builder)
    return before, str(target)

def test_run_o3(self):
    """At speed level 3, ``%.4`` is present before the run and gone after."""
    before, after = self.run_o_n(3)
    self.assertIn("%.4", before)
    self.assertNotIn("%.4", after)

def test_run_o0(self):
    """At speed level 0, ``%.4`` is present both before and after the run."""
    before, after = self.run_o_n(0)
    self.assertIn("%.4", before)
    self.assertIn("%.4", after)

def test_optnone(self):
    """Check that `optnone` stops the speed-level-3 function pipeline from
    removing the allocas in ``foo``."""
    builder = self.pb(speed_level=3, size_level=0)

    # Baseline: with the attribute stripped, the allocas get optimized away
    plain_ir = str(asm_alloca_optnone.replace("optnone ", ""))
    plain_fn = llvm.parse_assembly(plain_ir).get_function("foo")
    builder.getFunctionPassManager().run(plain_fn, builder)
    self.assertIn("alloca", plain_ir)
    self.assertNotIn("alloca", str(plain_fn))

    # Function shouldn't be optimized if the function has `optnone` attached
    optnone_ir = str(asm_alloca_optnone)
    optnone_fn = llvm.parse_assembly(optnone_ir).get_function("foo")
    optnone_pm = builder.getFunctionPassManager()
    optnone_pm.run(optnone_fn, builder)
    self.assertIn("alloca", optnone_ir)
    self.assertIn("alloca", str(optnone_fn))

def test_instcombine(self):
pb = self.pb()
fpm = self.pm()
Expand Down

0 comments on commit 78ebf9b

Please sign in to comment.