Skip to content

Commit

Permalink
Generate teams reductions inside the distribute operation if the redu…
Browse files Browse the repository at this point in the history
…ction

computation is contained within the distribute op.
  • Loading branch information
jsjodin committed Nov 21, 2024
1 parent 109ac78 commit 85629e8
Showing 1 changed file with 90 additions and 26 deletions.
116 changes: 90 additions & 26 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,19 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
return success();
}

/// Returns true when every use of every reduction block argument of `teamsOp`
/// is owned by an operation nested inside one single omp.distribute op, and
/// that distribute op is itself nested inside `teamsOp`.
///
/// The translation of omp.teams skips its own reduction handling when this
/// returns true, and the translation of a nested omp.distribute then performs
/// the reduction setup and finalization using the teams op's reduction
/// declarations. Requiring a *single* distribute op matters: if the uses were
/// spread over several distribute ops, each of them would allocate,
/// initialize and finalize the same teams reduction on its own. In that case
/// (and when a use is outside any distribute op) this returns false so the
/// teams op keeps responsibility for the reduction.
///
/// Note that the result is vacuously true when there are no reduction block
/// arguments or none of them has a use.
static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
  auto iface =
      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());

  // The one distribute op that must contain all uses; null until the first
  // use has been visited.
  omp::DistributeOp commonDistribute;

  // Check that all uses of the reduction block args have the same distribute
  // op parent.
  for (auto reductionArg : iface.getReductionBlockArgs()) {
    for (auto &use : reductionArg.getUses()) {
      auto *userOp = use.getOwner();
      auto enclosingDistribute = userOp->getParentOfType<omp::DistributeOp>();

      // Use is not inside any distribute op.
      if (!enclosingDistribute)
        return false;

      // getParentOfType walks all ancestors, so make sure the distribute op
      // found is nested in the teams op rather than enclosing it.
      if (!teamsOp->isProperAncestor(enclosingDistribute))
        return false;

      // Uses are split across more than one distribute op.
      if (commonDistribute && commonDistribute != enclosingDistribute)
        return false;

      commonDistribute = enclosingDistribute;
    }
  }
  return true;
}

// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
static LogicalResult
convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
Expand All @@ -1662,32 +1675,39 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*op)))
return failure();

llvm::ArrayRef<bool> isByRef = getIsByRef(op.getReductionByref());
assert(isByRef.size() == op.getNumReductionVars());

DenseMap<Value, llvm::Value *> reductionVariableMap;
unsigned numReductionVars = op.getNumReductionVars();
SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(op, reductionDecls);
SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
llvm::ArrayRef<bool> isByRef;
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);

SmallVector<llvm::Value *> privateReductionVariables(
op.getNumReductionVars());
DenseMap<Value, llvm::Value *> reductionVariableMap;
// Only do teams reduction if there is no distribute op that captures the
// reduction instead.
bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
if (doTeamsReduction) {
isByRef = getIsByRef(op.getReductionByref());

MutableArrayRef<BlockArgument> reductionArgs =
llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
assert(isByRef.size() == op.getNumReductionVars());

if (failed(allocAndInitializeReductionVars(
op, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
return failure();
MutableArrayRef<BlockArgument> reductionArgs =
llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();

// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating
// omp.reduce operations in a separate call.
LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
moduleTranslation, reductionVariableMap);
collectReductionDecls(op, reductionDecls);

if (failed(allocAndInitializeReductionVars(
op, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
return failure();

// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating
// omp.reduce operations in a separate call.
LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
moduleTranslation, reductionVariableMap);
}

auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
Expand Down Expand Up @@ -1723,13 +1743,13 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
return failure();

builder.restoreIP(*afterIP);

// Process the reductions if required.
return createReductionsAndCleanup(
op, builder, moduleTranslation, allocaIP, reductionDecls,
privateReductionVariables, isByRef,
/*isNoWait*/ false, /*isTeamsReduction*/ true);

if (doTeamsReduction) {
// Process the reductions if required.
return createReductionsAndCleanup(
op, builder, moduleTranslation, allocaIP, reductionDecls,
privateReductionVariables, isByRef,
/*isNoWait*/ false, /*isTeamsReduction*/ true);
}
return success();
}

Expand Down Expand Up @@ -3815,6 +3835,43 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(opInst)))
return failure();

/// Process teams op reduction in distribute if the reduction is contained in
/// the distribute op.
omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
bool doDistributeReduction =
teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;

DenseMap<Value, llvm::Value *> reductionVariableMap;
unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
SmallVector<omp::DeclareReductionOp> reductionDecls;
SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
llvm::ArrayRef<bool> isByRef;

if (doDistributeReduction) {
isByRef = getIsByRef(teamsOp.getReductionByref());
assert(isByRef.size() == teamsOp.getNumReductionVars());

collectReductionDecls(teamsOp, reductionDecls);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);

MutableArrayRef<BlockArgument> reductionArgs =
llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
.getReductionBlockArgs();

if (failed(allocAndInitializeReductionVars(
teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
return failure();
}

// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating
// omp.reduce operations in a separate call.
LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
moduleTranslation, reductionVariableMap);

auto loopOp = cast<omp::LoopNestOp>(distributeOp.getWrappedLoop());

SmallVector<omp::LoopWrapperInterface> loopWrappers;
Expand Down Expand Up @@ -3861,6 +3918,13 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
return opInst.emitError(llvm::toString(afterIP.takeError()));
builder.restoreIP(*afterIP);

if (doDistributeReduction) {
// Process the reductions if required.
return createReductionsAndCleanup(
teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
privateReductionVariables, isByRef,
/*isNoWait*/ false, /*isTeamsReduction*/ true);
}
return success();
}

Expand Down

0 comments on commit 85629e8

Please sign in to comment.