Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new flag "--no-entry" #2585

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/LLVMSPIRVLib.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ ModulePass *
createSPIRVLowerLLVMIntrinsicLegacy(const SPIRV::TranslatorOpts &Opts);

/// Create a pass for regularize LLVM module to be translated to SPIR-V.
ModulePass *createSPIRVRegularizeLLVMLegacy();
ModulePass *createSPIRVRegularizeLLVMLegacy(const SPIRV::TranslatorOpts &Opts);

/// Create a pass for translating SPIR-V Instructions to desired
/// representation in LLVM IR (OpenCL built-ins, SPIR-V Friendly IR, etc.)
Expand Down
10 changes: 10 additions & 0 deletions include/LLVMSPIRVOpts.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,14 @@ class TranslatorOpts {
PreserveOCLKernelArgTypeMetadataThroughString = Value;
}

void setGenerateKernelEntryPoints(bool Value) noexcept {
GenerateEntryPoints = Value;
}

bool getGenerateKernelEntryPoints() const noexcept {
return GenerateEntryPoints;
}

void setBuiltinFormat(BuiltinFormat Value) noexcept {
SPIRVBuiltinFormat = Value;
}
Expand Down Expand Up @@ -281,6 +289,8 @@ class TranslatorOpts {
// kernel_arg_type_qual metadata through OpString
bool PreserveOCLKernelArgTypeMetadataThroughString = false;

bool GenerateEntryPoints = true;

bool PreserveAuxData = false;

BuiltinFormat SPIRVBuiltinFormat = BuiltinFormat::Function;
Expand Down
7 changes: 4 additions & 3 deletions lib/SPIRV/SPIRVRegularizeLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,8 @@ void regularizeWithOverflowInstrinsics(StringRef MangledName, CallInst *Call,
/// Remove entities not representable by SPIR-V
bool SPIRVRegularizeLLVMBase::regularize() {
eraseUselessFunctions(M);
addKernelEntryPoint(M);
if (Opts.getGenerateKernelEntryPoints())
addKernelEntryPoint(M);
expandSYCLTypeUsing(M);
cleanupConversionToNonStdIntegers(M);

Expand Down Expand Up @@ -767,6 +768,6 @@ void SPIRVRegularizeLLVMBase::addKernelEntryPoint(Module *M) {
INITIALIZE_PASS(SPIRVRegularizeLLVMLegacy, "spvregular",
"Regularize LLVM for SPIR-V", false, false)

ModulePass *llvm::createSPIRVRegularizeLLVMLegacy() {
return new SPIRVRegularizeLLVMLegacy();
ModulePass *llvm::createSPIRVRegularizeLLVMLegacy(const TranslatorOpts &Opts) {
return new SPIRVRegularizeLLVMLegacy(Opts);
}
9 changes: 7 additions & 2 deletions lib/SPIRV/SPIRVRegularizeLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ namespace SPIRV {

class SPIRVRegularizeLLVMBase {
public:
SPIRVRegularizeLLVMBase() : M(nullptr), Ctx(nullptr) {}
SPIRVRegularizeLLVMBase(const TranslatorOpts& TOpts)
: M(nullptr), Ctx(nullptr), Opts(TOpts) {}

bool runRegularizeLLVM(llvm::Module &M);
// Lower functions
Expand Down Expand Up @@ -116,12 +117,15 @@ class SPIRVRegularizeLLVMBase {
private:
llvm::Module *M;
llvm::LLVMContext *Ctx;
const TranslatorOpts &Opts;
};

class SPIRVRegularizeLLVMPass
: public llvm::PassInfoMixin<SPIRVRegularizeLLVMPass>,
public SPIRVRegularizeLLVMBase {
public:
SPIRVRegularizeLLVMPass(const TranslatorOpts &TOpts)
: SPIRVRegularizeLLVMBase(TOpts) {}
llvm::PreservedAnalyses run(llvm::Module &M,
llvm::ModuleAnalysisManager &MAM) {
return runRegularizeLLVM(M) ? llvm::PreservedAnalyses::none()
Expand All @@ -134,7 +138,8 @@ class SPIRVRegularizeLLVMPass
class SPIRVRegularizeLLVMLegacy : public llvm::ModulePass,
public SPIRVRegularizeLLVMBase {
public:
SPIRVRegularizeLLVMLegacy() : ModulePass(ID) {
SPIRVRegularizeLLVMLegacy(const TranslatorOpts &TOpts)
: ModulePass(ID), SPIRVRegularizeLLVMBase(TOpts) {
initializeSPIRVRegularizeLLVMLegacyPass(*PassRegistry::getPassRegistry());
}

Expand Down
12 changes: 8 additions & 4 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) {
static_cast<SPIRVFunction *>(mapValue(F, BM->addFunction(BFT)));
BF->setFunctionControlMask(transFunctionControlMask(F));
if (F->hasName()) {
if (isKernel(F)) {
if (isKernel(F) && BM->getGenerateKernelEntryPoints()) {
/* strip the prefix as the runtime will be looking for this name */
std::string Prefix = kSPIRVName::EntrypointPrefix;
std::string Name = F->getName().str();
Expand Down Expand Up @@ -5635,7 +5635,9 @@ void LLVMToSPIRVBase::transFunction(Function *I) {

if (isKernel(I)) {
auto Interface = collectEntryPointInterfaces(BF, I);
BM->addEntryPoint(ExecutionModelKernel, BF->getId(), BF->getName(),
auto Name =
BM->getGenerateKernelEntryPoints() ? BF->getName() : I->getName().str();
BM->addEntryPoint(ExecutionModelKernel, BF->getId(), Name,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change necessary? If !getGenerateKernelEntryPoints, then BF->getName()==I->getName().str() since BF->getName() will be set by line 885.

Interface);
}
}
Expand Down Expand Up @@ -5998,7 +6000,9 @@ bool LLVMToSPIRVBase::transMetadata() {
static void transKernelArgTypeMD(SPIRVModule *BM, Function *F, MDNode *MD,
std::string MDName) {
std::string Prefix = kSPIRVName::EntrypointPrefix;
std::string Name = F->getName().str().substr(Prefix.size());
std::string Name = F->getName().str();
if (BM->getGenerateKernelEntryPoints())
Name = Name.substr(Prefix.size());
std::string KernelArgTypesMDStr = std::string(MDName) + "." + Name + ".";
for (const auto &TyOp : MD->operands())
KernelArgTypesMDStr += cast<MDString>(TyOp)->getString().str() + ",";
Expand Down Expand Up @@ -6647,7 +6651,7 @@ void addPassesForSPIRV(ModulePassManager &PassMgr,
PassMgr.addPass(PreprocessMetadataPass());
PassMgr.addPass(SPIRVLowerOCLBlocksPass());
PassMgr.addPass(OCLToSPIRVPass());
PassMgr.addPass(SPIRVRegularizeLLVMPass());
PassMgr.addPass(SPIRVRegularizeLLVMPass(Opts));
PassMgr.addPass(SPIRVLowerConstExprPass());
PassMgr.addPass(SPIRVLowerBoolPass());
PassMgr.addPass(SPIRVLowerMemmovePass());
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,10 @@ class SPIRVModule {
return TranslationOpts.getDesiredBIsRepresentation();
}

bool getGenerateKernelEntryPoints() const {
return TranslationOpts.getGenerateKernelEntryPoints();
}

// I/O functions
friend spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M);
friend std::istream &operator>>(std::istream &I, SPIRVModule &M);
Expand Down
6 changes: 6 additions & 0 deletions tools/llvm-spirv/llvm-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ static cl::opt<SPIRV::BuiltinFormat> SPIRVBuiltinFormat(
clEnumValN(SPIRV::BuiltinFormat::Global, "global",
"Use globals to represent SPIR-V builtin variables")));

static cl::opt<bool>
NoEntry("no-entry", cl::init(false),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with other llvm-spirv options:

Suggested change
NoEntry("no-entry", cl::init(false),
NoEntry("spirv-no-entry", cl::init(false),

cl::desc("Disable kernel entry point generation. This is a workaround for some OpenCL drivers that are incompatible with entry points, leading them to have twice the number of kernel arguments. See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1486 for more details."));

static std::string removeExt(const std::string &FileName) {
size_t Pos = FileName.find_last_of(".");
if (Pos != std::string::npos)
Expand Down Expand Up @@ -820,6 +824,8 @@ int main(int Ac, char **Av) {
if (PreserveOCLKernelArgTypeMetadataThroughString.getNumOccurrences() != 0)
Opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);

Opts.setGenerateKernelEntryPoints(!NoEntry);

#ifdef _SPIRV_SUPPORT_TEXT_FMT
if (ToText && (ToBinary || IsReverse || IsRegularization)) {
errs() << "Cannot use -to-text with -to-binary, -r, -s\n";
Expand Down
Loading