Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][OpenMP][OMPIRBuilder] Support omp.target 'if' translation to LLVM IR #157

Merged
merged 1 commit into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2871,6 +2871,7 @@ class OpenMPIRBuilder {
/// \param Loc where the target data construct was encountered.
/// \param IsSPMD whether this is an SPMD target launch.
/// \param IsOffloadEntry whether it is an offload entry.
/// \param IfCond value of the IF clause for the TARGET construct or nullptr.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
Expand All @@ -2884,7 +2885,7 @@ class OpenMPIRBuilder {
/// \param Dependencies A vector of DependData objects that carry
// dependency information as passed in the depend clause.
InsertPointTy createTarget(const LocationDescription &Loc, bool IsSPMD,
bool IsOffloadEntry,
bool IsOffloadEntry, Value *IfCond,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
Expand Down
195 changes: 115 additions & 80 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7266,7 +7266,7 @@ static void emitTargetCall(
const OpenMPIRBuilder::TargetKernelDefaultBounds &DefaultBounds,
const OpenMPIRBuilder::TargetKernelRuntimeBounds &RuntimeBounds,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
SmallVectorImpl<Value *> &Args, Value *IfCond,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
// Generate a function call to the host fallback implementation of the target
Expand All @@ -7283,9 +7283,7 @@ static void emitTargetCall(
bool HasDependencies = Dependencies.size() > 0;
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
auto &&EmitTargetCallElse = [&]() {
if (RequiresOuterTargetTask) {
// Arguments that are intended to be directly forwarded to an
// emitKernelLaunch call are pased as nullptr, since OutlinedFnID=nullptr
Expand All @@ -7298,96 +7296,132 @@ static void emitTargetCall(
} else {
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
}
return;
}

OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

SmallVector<Value *, 3> NumTeamsC;
for (auto [DefNumTeams, RtNumTeams] :
llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) {
NumTeamsC.push_back(RtNumTeams ? RtNumTeams
: Builder.getInt32(DefNumTeams));
}

// Calculate number of threads: 0 if no clauses specified, otherwise it is the
// minimum between optional THREAD_LIMIT and MAX_THREADS clauses. Perform a
// type cast to uint32.
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
if (Clause)
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
/*isSigned=*/false);
return Clause;
};

auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
if (Clause)
Result = Result
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
auto &&EmitTargetCallThen = [&]() {
OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
/*SeparateBeginEndCalls=*/true);

OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
RTArgs, MapInfo,
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

SmallVector<Value *, 3> NumTeamsC;
for (auto [DefNumTeams, RtNumTeams] :
llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) {
NumTeamsC.push_back(RtNumTeams ? RtNumTeams
: Builder.getInt32(DefNumTeams));
}

// Calculate number of threads: 0 if no clauses specified, otherwise it is
// the minimum between optional THREAD_LIMIT and MAX_THREADS clauses.
// Perform a type cast to uint32.
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
if (Clause)
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
/*isSigned=*/false);
return Clause;
};

auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
if (Clause)
Result =
Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
Result, Clause)
: Clause;
};
};

// TODO: Check if this is the correct handling for multi-dim thread_limit.
SmallVector<Value *, 3> NumThreadsC;
Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads);

// TODO: Check if this is the correct handling for multi-dim thread_limit.
SmallVector<Value *, 3> NumThreadsC;
Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads);
for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal(
RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit);
Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit);

for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal(
RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) {
Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit);
Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit);
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);

CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
}

NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);

Value *TripCount = RuntimeBounds.LoopTripCount
? Builder.CreateIntCast(RuntimeBounds.LoopTripCount,
Builder.getInt64Ty(),
/*isSigned=*/false)
: Builder.getInt64(0);

// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount,
NumTeamsC, NumThreadsC,
DynCGGroupMem, HasNoWait);

// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask) {
Builder.restoreIP(OMPBuilder.emitTargetTask(
OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
RTLoc, AllocaIP, Dependencies, HasNoWait));
} else {
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP));
}
};

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
EmitTargetCallElse();
return;
}

unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);

Value *TripCount = RuntimeBounds.LoopTripCount
? Builder.CreateIntCast(RuntimeBounds.LoopTripCount,
Builder.getInt64Ty(),
/*isSigned=*/false)
: Builder.getInt64(0);

// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount,
NumTeamsC, NumThreadsC, DynCGGroupMem,
HasNoWait);

// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
if (RequiresOuterTargetTask) {
Builder.restoreIP(OMPBuilder.emitTargetTask(
OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
RTLoc, AllocaIP, Dependencies, HasNoWait));
} else {
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
DeviceID, RTLoc, AllocaIP));
// If there's no IF clause, only generate the kernel launch code path.
if (!IfCond) {
EmitTargetCallThen();
return;
}

// Create if-else to handle IF clause.
llvm::BasicBlock *ThenBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.then");
llvm::BasicBlock *ElseBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.else");
llvm::BasicBlock *ContBlock =
BasicBlock::Create(Builder.getContext(), "omp_if.end");
Builder.CreateCondBr(IfCond, ThenBlock, ElseBlock);

Function *CurFn = Builder.GetInsertBlock()->getParent();

// Emit the 'then' code.
OMPBuilder.emitBlock(ThenBlock, CurFn);
EmitTargetCallThen();
OMPBuilder.emitBranch(ContBlock);
// Emit the 'else' code.
OMPBuilder.emitBlock(ElseBlock, CurFn);
EmitTargetCallElse();
OMPBuilder.emitBranch(ContBlock);
// Emit the continuation block.
OMPBuilder.emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsSPMD, bool IsOffloadEntry,
InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
Value *IfCond, InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultBounds &DefaultBounds,
const TargetKernelRuntimeBounds &RuntimeBounds,
Expand Down Expand Up @@ -7415,7 +7449,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, DefaultBounds, RuntimeBounds,
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies);
OutlinedFn, OutlinedFnID, Args, IfCond, GenMapInfoCB,
Dependencies);
return Builder.saveIP();
}

Expand Down
18 changes: 9 additions & 9 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6017,9 +6017,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
RuntimeBounds.MaxTeams.push_back(nullptr);
Builder.restoreIP(OMPBuilder.createTarget(
OmpLoc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, Builder.saveIP(),
Builder.saveIP(), EntryInfo, DefaultBounds, RuntimeBounds, Inputs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
OmpLoc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultBounds,
RuntimeBounds, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
OMPBuilder.finalize();
Builder.CreateRetVoid();

Expand Down Expand Up @@ -6134,9 +6134,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
RuntimeBounds.MaxTeams.push_back(nullptr);
Builder.restoreIP(OMPBuilder.createTarget(
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
EntryIP, EntryIP, EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down Expand Up @@ -6290,9 +6290,9 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
RuntimeBounds.MaxTeams.push_back(nullptr);
Builder.restoreIP(OMPBuilder.createTarget(
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
EntryIP, EntryIP, EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));

Builder.CreateRetVoid();
OMPBuilder.finalize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3476,10 +3476,6 @@ static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,

static bool targetOpSupported(Operation &opInst) {
auto targetOp = cast<omp::TargetOp>(opInst);
if (targetOp.getIfExpr()) {
opInst.emitError("If clause not yet supported");
return false;
}

if (targetOp.getDevice()) {
opInst.emitError("Device clause not yet supported");
Expand Down Expand Up @@ -3955,8 +3951,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
if (Value targetThreadLimit = targetOp.getThreadLimit())
llvmTargetThreadLimit = moduleTranslation.lookupValue(targetThreadLimit);

llvm::Value *ifCond = nullptr;
if (Value targetIfCond = targetOp.getIfExpr())
ifCond = moduleTranslation.lookupValue(targetIfCond);

builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, targetOp.isTargetSPMDLoop(), isOffloadEntry, allocaIP,
ompLoc, targetOp.isTargetSPMDLoop(), isOffloadEntry, ifCond, allocaIP,
builder.saveIP(), entryInfo, defaultBounds, runtimeBounds, kernelInput,
genMapInfoCB, bodyCB, argAccessorCB, dds));

Expand Down