diff --git a/include/LLVMSPIRVLib.h b/include/LLVMSPIRVLib.h index 48834d8be8..ec12043688 100644 --- a/include/LLVMSPIRVLib.h +++ b/include/LLVMSPIRVLib.h @@ -227,7 +227,7 @@ ModulePass * createSPIRVLowerLLVMIntrinsicLegacy(const SPIRV::TranslatorOpts &Opts); /// Create a pass for regularize LLVM module to be translated to SPIR-V. -ModulePass *createSPIRVRegularizeLLVMLegacy(); +ModulePass *createSPIRVRegularizeLLVMLegacy(const SPIRV::TranslatorOpts &Opts); /// Create a pass for translating SPIR-V Instructions to desired /// representation in LLVM IR (OpenCL built-ins, SPIR-V Friendly IR, etc.) diff --git a/include/LLVMSPIRVOpts.h b/include/LLVMSPIRVOpts.h index ab577d757b..a91193620e 100644 --- a/include/LLVMSPIRVOpts.h +++ b/include/LLVMSPIRVOpts.h @@ -233,6 +233,14 @@ class TranslatorOpts { PreserveOCLKernelArgTypeMetadataThroughString = Value; } + void setGenerateKernelEntryPoints(bool Value) noexcept { + GenerateEntryPoints = Value; + } + + bool getGenerateKernelEntryPoints() const noexcept { + return GenerateEntryPoints; + } + void setBuiltinFormat(BuiltinFormat Value) noexcept { SPIRVBuiltinFormat = Value; } @@ -281,6 +289,8 @@ class TranslatorOpts { // kernel_arg_type_qual metadata through OpString bool PreserveOCLKernelArgTypeMetadataThroughString = false; + bool GenerateEntryPoints = true; + bool PreserveAuxData = false; BuiltinFormat SPIRVBuiltinFormat = BuiltinFormat::Function; diff --git a/lib/SPIRV/SPIRVRegularizeLLVM.cpp b/lib/SPIRV/SPIRVRegularizeLLVM.cpp index f32814217e..f94edec364 100644 --- a/lib/SPIRV/SPIRVRegularizeLLVM.cpp +++ b/lib/SPIRV/SPIRVRegularizeLLVM.cpp @@ -500,7 +500,8 @@ void regularizeWithOverflowInstrinsics(StringRef MangledName, CallInst *Call, /// Remove entities not representable by SPIR-V bool SPIRVRegularizeLLVMBase::regularize() { eraseUselessFunctions(M); - addKernelEntryPoint(M); + if (Opts.getGenerateKernelEntryPoints()) + addKernelEntryPoint(M); expandSYCLTypeUsing(M); cleanupConversionToNonStdIntegers(M); @@ -767,6 +768,6 @@ void SPIRVRegularizeLLVMBase::addKernelEntryPoint(Module *M) { INITIALIZE_PASS(SPIRVRegularizeLLVMLegacy, "spvregular", "Regularize LLVM for SPIR-V", false, false) -ModulePass *llvm::createSPIRVRegularizeLLVMLegacy() { - return new SPIRVRegularizeLLVMLegacy(); +ModulePass *llvm::createSPIRVRegularizeLLVMLegacy(const TranslatorOpts &Opts) { + return new SPIRVRegularizeLLVMLegacy(Opts); } diff --git a/lib/SPIRV/SPIRVRegularizeLLVM.h b/lib/SPIRV/SPIRVRegularizeLLVM.h index 90e14433bd..c785568a62 100644 --- a/lib/SPIRV/SPIRVRegularizeLLVM.h +++ b/lib/SPIRV/SPIRVRegularizeLLVM.h @@ -45,7 +45,8 @@ namespace SPIRV { class SPIRVRegularizeLLVMBase { public: - SPIRVRegularizeLLVMBase() : M(nullptr), Ctx(nullptr) {} + SPIRVRegularizeLLVMBase(const TranslatorOpts& TOpts) + : M(nullptr), Ctx(nullptr), Opts(TOpts) {} bool runRegularizeLLVM(llvm::Module &M); // Lower functions @@ -116,12 +117,15 @@ class SPIRVRegularizeLLVMBase { private: llvm::Module *M; llvm::LLVMContext *Ctx; + const TranslatorOpts &Opts; }; class SPIRVRegularizeLLVMPass : public llvm::PassInfoMixin, public SPIRVRegularizeLLVMBase { public: + SPIRVRegularizeLLVMPass(const TranslatorOpts &TOpts) + : SPIRVRegularizeLLVMBase(TOpts) {} llvm::PreservedAnalyses run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { return runRegularizeLLVM(M) ? llvm::PreservedAnalyses::none() @@ -134,7 +138,8 @@ class SPIRVRegularizeLLVMPass class SPIRVRegularizeLLVMLegacy : public llvm::ModulePass, public SPIRVRegularizeLLVMBase { public: - SPIRVRegularizeLLVMLegacy() : ModulePass(ID) { + SPIRVRegularizeLLVMLegacy(const TranslatorOpts &TOpts) + : ModulePass(ID), SPIRVRegularizeLLVMBase(TOpts) { initializeSPIRVRegularizeLLVMLegacyPass(*PassRegistry::getPassRegistry()); } diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index ef684e936e..7decafa378 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -871,7 +871,7 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) { static_cast(mapValue(F, BM->addFunction(BFT))); BF->setFunctionControlMask(transFunctionControlMask(F)); if (F->hasName()) { - if (isKernel(F)) { + if (isKernel(F) && BM->getGenerateKernelEntryPoints()) { /* strip the prefix as the runtime will be looking for this name */ std::string Prefix = kSPIRVName::EntrypointPrefix; std::string Name = F->getName().str(); @@ -5635,7 +5635,9 @@ void LLVMToSPIRVBase::transFunction(Function *I) { if (isKernel(I)) { auto Interface = collectEntryPointInterfaces(BF, I); - BM->addEntryPoint(ExecutionModelKernel, BF->getId(), BF->getName(), + auto Name = + BM->getGenerateKernelEntryPoints() ? BF->getName() : I->getName().str(); + BM->addEntryPoint(ExecutionModelKernel, BF->getId(), Name, Interface); } } @@ -5998,7 +6000,9 @@ bool LLVMToSPIRVBase::transMetadata() { static void transKernelArgTypeMD(SPIRVModule *BM, Function *F, MDNode *MD, std::string MDName) { std::string Prefix = kSPIRVName::EntrypointPrefix; - std::string Name = F->getName().str().substr(Prefix.size()); + std::string Name = F->getName().str(); + if (BM->getGenerateKernelEntryPoints()) + Name = Name.substr(Prefix.size()); std::string KernelArgTypesMDStr = std::string(MDName) + "." + Name + "."; for (const auto &TyOp : MD->operands()) KernelArgTypesMDStr += cast(TyOp)->getString().str() + ","; @@ -6647,7 +6651,7 @@ void addPassesForSPIRV(ModulePassManager &PassMgr, PassMgr.addPass(PreprocessMetadataPass()); PassMgr.addPass(SPIRVLowerOCLBlocksPass()); PassMgr.addPass(OCLToSPIRVPass()); - PassMgr.addPass(SPIRVRegularizeLLVMPass()); + PassMgr.addPass(SPIRVRegularizeLLVMPass(Opts)); PassMgr.addPass(SPIRVLowerConstExprPass()); PassMgr.addPass(SPIRVLowerBoolPass()); PassMgr.addPass(SPIRVLowerMemmovePass()); diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.h b/lib/SPIRV/libSPIRV/SPIRVModule.h index d2879fbcc4..d0dd2e9d6f 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.h +++ b/lib/SPIRV/libSPIRV/SPIRVModule.h @@ -578,6 +578,10 @@ class SPIRVModule { return TranslationOpts.getDesiredBIsRepresentation(); } + bool getGenerateKernelEntryPoints() const { + return TranslationOpts.getGenerateKernelEntryPoints(); + } + // I/O functions friend spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M); friend std::istream &operator>>(std::istream &I, SPIRVModule &M); diff --git a/tools/llvm-spirv/llvm-spirv.cpp b/tools/llvm-spirv/llvm-spirv.cpp index c90f16480f..0c38f03e4a 100644 --- a/tools/llvm-spirv/llvm-spirv.cpp +++ b/tools/llvm-spirv/llvm-spirv.cpp @@ -278,6 +278,10 @@ static cl::opt SPIRVBuiltinFormat( clEnumValN(SPIRV::BuiltinFormat::Global, "global", "Use globals to represent SPIR-V builtin variables"))); +static cl::opt + NoEntry("no-entry", cl::init(false), + cl::desc("Disable kernel entry point generation. This is a workaround for some OpenCL drivers that are incompatible with entry points, leading them to have twice the number of kernel arguments. See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1486 for more details.")); + static std::string removeExt(const std::string &FileName) { size_t Pos = FileName.find_last_of("."); if (Pos != std::string::npos) @@ -820,6 +824,8 @@ int main(int Ac, char **Av) { if (PreserveOCLKernelArgTypeMetadataThroughString.getNumOccurrences() != 0) Opts.setPreserveOCLKernelArgTypeMetadataThroughString(true); + Opts.setGenerateKernelEntryPoints(!NoEntry); + #ifdef _SPIRV_SUPPORT_TEXT_FMT if (ToText && (ToBinary || IsReverse || IsRegularization)) { errs() << "Cannot use -to-text with -to-binary, -r, -s\n";