diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 0dd406003e5e4c..fff54eddf174af 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -33,7 +33,6 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Transforms/RegionUtils.h" @@ -46,24 +45,18 @@ using namespace Fortran::lower::omp; // Code generation helper functions //===----------------------------------------------------------------------===// -/// Add to the given target operation a host_eval argument, which must be -/// defined outside. -/// -/// \return the entry block argument to represent \c hostVar inside of the -/// target region. -static mlir::Value addHostEvalVar(mlir::omp::TargetOp targetOp, - mlir::Value hostVar) { - assert(!targetOp.getRegion().isAncestor(hostVar.getParentRegion()) && - "variable must be defined outside of the target region"); - - auto argIface = llvm::cast(*targetOp); - unsigned insertIndex = - argIface.getHostEvalBlockArgsStart() + argIface.numHostEvalBlockArgs(); +static void genOMPDispatch(lower::AbstractConverter &converter, + lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, mlir::Location loc, + const ConstructQueue &queue, + ConstructQueue::const_iterator item); - targetOp.getHostEvalVarsMutable().append(hostVar); - return targetOp.getRegion().insertArgument(insertIndex, hostVar.getType(), - hostVar.getLoc()); -} +static void processHostEvalClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, + lower::pft::Evaluation &eval, + mlir::Location loc); namespace { /// Structure holding the information needed to create and bind entry block @@ -83,7 +76,7 @@ struct EntryBlockArgsEntry { /// Structure holding the information needed to create and bind entry block /// arguments associated to all clauses that can define them. struct EntryBlockArgs { - EntryBlockArgsEntry hostEval; + llvm::ArrayRef hostEvalVars; EntryBlockArgsEntry inReduction; EntryBlockArgsEntry map; EntryBlockArgsEntry priv; @@ -93,8 +86,8 @@ struct EntryBlockArgs { EntryBlockArgsEntry useDevicePtr; bool isValid() const { - return hostEval.isValid() && inReduction.isValid() && map.isValid() && - priv.isValid() && reduction.isValid() && taskReduction.isValid() && + return inReduction.isValid() && map.isValid() && priv.isValid() && + reduction.isValid() && taskReduction.isValid() && useDeviceAddr.isValid() && useDevicePtr.isValid(); } @@ -106,201 +99,146 @@ struct EntryBlockArgs { auto getVars() const { return llvm::concat( - inReduction.vars, map.vars, priv.vars, reduction.vars, + hostEvalVars, inReduction.vars, map.vars, priv.vars, reduction.vars, taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars); } }; -} // namespace - -/// Get the directive enumeration value corresponding to the given OpenMP -/// construct PFT node. 
-llvm::omp::Directive -extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) { - return common::visit( - common::visitors{ - [](const parser::OpenMPAllocatorsConstruct &c) { - return llvm::omp::OMPD_allocators; - }, - [](const parser::OpenMPAtomicConstruct &c) { - return llvm::omp::OMPD_atomic; - }, - [](const parser::OpenMPBlockConstruct &c) { - return std::get( - std::get(c.t).t) - .v; - }, - [](const parser::OpenMPCriticalConstruct &c) { - return llvm::omp::OMPD_critical; - }, - [](const parser::OpenMPDeclarativeAllocate &c) { - return llvm::omp::OMPD_allocate; - }, - [](const parser::OpenMPExecutableAllocate &c) { - return llvm::omp::OMPD_allocate; - }, - [](const parser::OpenMPLoopConstruct &c) { - return std::get( - std::get(c.t).t) - .v; - }, - [](const parser::OpenMPSectionConstruct &c) { - return llvm::omp::OMPD_section; - }, - [](const parser::OpenMPSectionsConstruct &c) { - return std::get( - std::get(c.t).t) - .v; - }, - [](const parser::OpenMPStandaloneConstruct &c) { - return common::visit( - common::visitors{ - [](const parser::OpenMPSimpleStandaloneConstruct &c) { - return std::get(c.t) - .v; - }, - [](const parser::OpenMPFlushConstruct &c) { - return llvm::omp::OMPD_flush; - }, - [](const parser::OpenMPCancelConstruct &c) { - return llvm::omp::OMPD_cancel; - }, - [](const parser::OpenMPCancellationPointConstruct &c) { - return llvm::omp::OMPD_cancellation_point; - }, - [](const parser::OpenMPDepobjConstruct &c) { - return llvm::omp::OMPD_depobj; - }}, - c.u); - }}, - ompConstruct.u); -} - -/// Check whether the parent of the given evaluation contains other evaluations. -static bool evalHasSiblings(const lower::pft::Evaluation &eval) { - auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) { - for (auto &sibling : siblings) - if (&sibling != &eval && !sibling.isEndStmt()) - return true; - - return false; - }; - - return eval.parent.visit(common::visitors{ - [&](const lower::pft::Program &parent) { - return parent.getUnits().size() + parent.getCommonBlocks().size() > 1; - }, - [&](const lower::pft::Evaluation &parent) { - return checkSiblings(*parent.evaluationList); - }, - [&](const auto &parent) { - return checkSiblings(parent.evaluationList); - }}); -} - -/// Check whether the given omp.target operation exists and we're compiling for -/// the host device. -static bool isHostTarget(mlir::omp::TargetOp targetOp) { - if (!targetOp) - return false; - - auto offloadModOp = llvm::cast( - *targetOp->getParentOfType()); - - return !offloadModOp.getIsTargetDevice(); -} - -/// Check whether a given evaluation points to an OpenMP loop construct that -/// represents a target SPMD kernel. For this to be true, it must be a `target -/// teams distribute parallel do [simd]` or equivalent construct. -/// -/// Currently, this is limited to cases where all relevant OpenMP constructs are -/// either combined or directly nested within the same function. Also, the -/// composite `distribute parallel do` is not identified if split into two -/// explicit nested loops (i.e. a `distribute` loop and a `parallel do` loop). -static bool isTargetSPMDLoop(const lower::pft::Evaluation &eval) { - using namespace llvm::omp; - - const auto *ompEval = eval.getIf(); - if (!ompEval) - return false; - - switch (extractOmpDirective(*ompEval)) { - case OMPD_distribute_parallel_do: - case OMPD_distribute_parallel_do_simd: { - // It will return true only if one of these are true: - // - It has a 'target teams' parent and no siblings. 
- // - It has a 'teams' parent and no siblings, and the 'teams' has a - // 'target' parent and no siblings. - if (evalHasSiblings(eval)) - return false; - - const auto *parentEval = eval.parent.getIf(); - if (!parentEval) - return false; - const auto *parentOmpEval = parentEval->getIf(); - if (!parentOmpEval) - return false; +/// Structure holding information that is needed to pass host-evaluated +/// information to later lowering stages. +class HostEvalInfo { +public: + // Allow this function access to private members in order to initialize them. + friend void ::processHostEvalClauses(lower::AbstractConverter &, + semantics::SemanticsContext &, + lower::StatementContext &, + lower::pft::Evaluation &, + mlir::Location); + + /// Fill \c vars with values stored in \c ops. + /// + /// The order in which values are stored matches the one expected by \see + /// bindOperands(). + void collectValues(llvm::SmallVectorImpl &vars) const { + vars.append(ops.loopLowerBounds); + vars.append(ops.loopUpperBounds); + vars.append(ops.loopSteps); + + if (ops.numTeamsLower) + vars.push_back(ops.numTeamsLower); + + if (ops.numTeamsUpper) + vars.push_back(ops.numTeamsUpper); + + if (ops.numThreads) + vars.push_back(ops.numThreads); + + if (ops.threadLimit) + vars.push_back(ops.threadLimit); + } - auto parentDir = extractOmpDirective(*parentOmpEval); - if (parentDir == OMPD_target_teams) - return true; + /// Update \c ops, replacing all values with the corresponding block argument + /// in \c args. + /// + /// The order in which values are stored in \c args is the same as the one + /// used by \see collectValues(). + void bindOperands(llvm::ArrayRef args) { + assert(args.size() == + ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + + (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + + (ops.threadLimit ? 1 : 0) && + "invalid block argument list"); + int argIndex = 0; + for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i) + ops.loopLowerBounds[i] = args[argIndex++]; + + for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i) + ops.loopUpperBounds[i] = args[argIndex++]; + + for (size_t i = 0; i < ops.loopSteps.size(); ++i) + ops.loopSteps[i] = args[argIndex++]; + + if (ops.numTeamsLower) + ops.numTeamsLower = args[argIndex++]; + + if (ops.numTeamsUpper) + ops.numTeamsUpper = args[argIndex++]; + + if (ops.numThreads) + ops.numThreads = args[argIndex++]; + + if (ops.threadLimit) + ops.threadLimit = args[argIndex++]; + } - if (parentDir != OMPD_teams) + /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated + /// values and Fortran symbols, respectively, if they have already been + /// initialized but not yet applied. + /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::LoopNestOperands &clauseOps, + llvm::SmallVectorImpl &ivOut) { + if (iv.empty() || loopNestApplied) { + loopNestApplied = true; return false; + } - if (evalHasSiblings(*parentEval)) - return false; + loopNestApplied = true; + clauseOps.loopLowerBounds = ops.loopLowerBounds; + clauseOps.loopUpperBounds = ops.loopUpperBounds; + clauseOps.loopSteps = ops.loopSteps; + ivOut.append(iv); + return true; + } - const auto *parentOfParentEval = - parentEval->parent.getIf(); - if (!parentOfParentEval) + /// Update \p clauseOps with the corresponding host-evaluated values if they + /// have already been initialized but not yet applied. 
+ /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::ParallelOperands &clauseOps) { + if (!ops.numThreads || parallelApplied) { + parallelApplied = true; return false; + } - const auto *parentOfParentOmpEval = - parentOfParentEval->getIf(); - return parentOfParentOmpEval && - extractOmpDirective(*parentOfParentOmpEval) == OMPD_target; + parallelApplied = true; + clauseOps.numThreads = ops.numThreads; + return true; } - case OMPD_teams_distribute_parallel_do: - case OMPD_teams_distribute_parallel_do_simd: { - // Check there's a 'target' parent and no siblings. - if (evalHasSiblings(eval)) - return false; - const auto *parentEval = eval.parent.getIf(); - if (!parentEval) + /// Update \p clauseOps with the corresponding host-evaluated values if they + /// have already been initialized. + /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::TeamsOperands &clauseOps) { + if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit) return false; - const auto *parentOmpEval = parentEval->getIf(); - return parentOmpEval && extractOmpDirective(*parentOmpEval) == OMPD_target; - } - case OMPD_target_teams_distribute_parallel_do: - case OMPD_target_teams_distribute_parallel_do_simd: + clauseOps.numTeamsLower = ops.numTeamsLower; + clauseOps.numTeamsUpper = ops.numTeamsUpper; + clauseOps.threadLimit = ops.threadLimit; return true; - default: - return false; } -} - -static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) { - mlir::Operation *parentOp = builder.getBlock()->getParentOp(); - if (!parentOp) - return nullptr; - auto targetOp = llvm::dyn_cast(parentOp); - if (!targetOp) - targetOp = parentOp->getParentOfType(); - - return targetOp; -} +private: + mlir::omp::HostEvaluatedOperands ops; + llvm::SmallVector iv; + bool loopNestApplied = false, parallelApplied = false; +}; +} // namespace -static void genOMPDispatch(lower::AbstractConverter &converter, - lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, mlir::Location loc, - const ConstructQueue &queue, - ConstructQueue::const_iterator item); +/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target +/// operations being created. +/// +/// The current implementation prevents nested 'target' regions from breaking +/// the handling of the outer region by keeping a stack of information +/// structures, but it will probably still require some further work to support +/// reverse offloading. +static llvm::SmallVector hostEvalInfo; /// Bind symbols to their corresponding entry block arguments. /// @@ -423,8 +361,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter, }; // Process in clause name alphabetical order to match block arguments order. - bindPrivateLike(args.hostEval.syms, args.hostEval.vars, - op.getHostEvalBlockArgs()); + // Do not bind host_eval variables because they cannot be used inside of the + // corresponding region, except for very specific cases handled separately. bindPrivateLike(args.inReduction.syms, args.inReduction.vars, op.getInReductionBlockArgs()); bindMapLike(args.map.syms, op.getMapBlockArgs()); @@ -462,6 +400,246 @@ extractMappedBaseValues(llvm::ArrayRef vars, }); } +/// Get the directive enumeration value corresponding to the given OpenMP +/// construct PFT node. 
+llvm::omp::Directive +extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) { + return common::visit( + common::visitors{ + [](const parser::OpenMPAllocatorsConstruct &c) { + return llvm::omp::OMPD_allocators; + }, + [](const parser::OpenMPAtomicConstruct &c) { + return llvm::omp::OMPD_atomic; + }, + [](const parser::OpenMPBlockConstruct &c) { + return std::get( + std::get(c.t).t) + .v; + }, + [](const parser::OpenMPCriticalConstruct &c) { + return llvm::omp::OMPD_critical; + }, + [](const parser::OpenMPDeclarativeAllocate &c) { + return llvm::omp::OMPD_allocate; + }, + [](const parser::OpenMPExecutableAllocate &c) { + return llvm::omp::OMPD_allocate; + }, + [](const parser::OpenMPLoopConstruct &c) { + return std::get( + std::get(c.t).t) + .v; + }, + [](const parser::OpenMPSectionConstruct &c) { + return llvm::omp::OMPD_section; + }, + [](const parser::OpenMPSectionsConstruct &c) { + return std::get( + std::get(c.t).t) + .v; + }, + [](const parser::OpenMPStandaloneConstruct &c) { + return common::visit( + common::visitors{ + [](const parser::OpenMPSimpleStandaloneConstruct &c) { + return std::get(c.t) + .v; + }, + [](const parser::OpenMPFlushConstruct &c) { + return llvm::omp::OMPD_flush; + }, + [](const parser::OpenMPCancelConstruct &c) { + return llvm::omp::OMPD_cancel; + }, + [](const parser::OpenMPCancellationPointConstruct &c) { + return llvm::omp::OMPD_cancellation_point; + }, + [](const parser::OpenMPDepobjConstruct &c) { + return llvm::omp::OMPD_depobj; + }}, + c.u); + }}, + ompConstruct.u); +} + +/// Populate the global \see hostEvalInfo after processing clauses for the given +/// \p eval OpenMP target construct, or nested constructs, if these must be +/// evaluated outside of the target region per the spec. +/// +/// In particular, this will ensure that in 'target teams' and equivalent nested +/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated +/// in the host. Additionally, loop bounds, steps and the \c num_threads clause +/// will also be evaluated in the host if a target SPMD construct is detected +/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting). +/// +/// The result, stored as a global, is intended to be used to populate the \c +/// host_eval operands of the associated \c omp.target operation, and also to be +/// checked and used by later lowering steps to populate the corresponding +/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest +/// operations. +static void processHostEvalClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, + lower::pft::Evaluation &eval, + mlir::Location loc) { + // Obtain the list of clauses of the given OpenMP block or loop construct + // evaluation. Other evaluations passed to this lambda keep `clauses` + // unchanged. 
+ auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval, + List &clauses) { + const auto *ompEval = eval.getIf(); + if (!ompEval) + return; + + const parser::OmpClauseList *beginClauseList = nullptr; + const parser::OmpClauseList *endClauseList = nullptr; + common::visit( + common::visitors{ + [&](const parser::OpenMPBlockConstruct &ompConstruct) { + const auto &beginDirective = + std::get(ompConstruct.t); + beginClauseList = + &std::get(beginDirective.t); + endClauseList = &std::get( + std::get(ompConstruct.t).t); + }, + [&](const parser::OpenMPLoopConstruct &ompConstruct) { + const auto &beginDirective = + std::get(ompConstruct.t); + beginClauseList = + &std::get(beginDirective.t); + + if (auto &endDirective = + std::get>( + ompConstruct.t)) + endClauseList = + &std::get(endDirective->t); + }, + [&](const auto &) {}}, + ompEval->u); + + assert(beginClauseList && "expected begin directive"); + clauses.append(makeClauses(*beginClauseList, semaCtx)); + + if (endClauseList) + clauses.append(makeClauses(*endClauseList, semaCtx)); + }; + + // Return the directive that is immediately nested inside of the given + // `parent` evaluation, if it is its only non-end-statement nested evaluation + // and it represents an OpenMP construct. + auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent) + -> std::optional { + if (!parent.hasNestedEvaluations()) + return std::nullopt; + + llvm::omp::Directive dir; + auto &nested = parent.getFirstNestedEvaluation(); + if (const auto *ompEval = nested.getIf()) + dir = extractOmpDirective(*ompEval); + else + return std::nullopt; + + for (auto &sibling : parent.getNestedEvaluations()) + if (&sibling != &nested && !sibling.isEndStmt()) + return std::nullopt; + + return dir; + }; + + // Process the given evaluation assuming it's part of a 'target' construct or + // captured by one, and store results in the global `hostEvalInfo`. + std::function &)> + processEval; + processEval = [&](lower::pft::Evaluation &eval, const List &clauses) { + using namespace llvm::omp; + ClauseProcessor cp(converter, semaCtx, clauses); + + // Call `processEval` recursively with the immediately nested evaluation and + // its corresponding clauses if there is a single nested evaluation + // representing an OpenMP directive that passes the given test. + auto processSingleNestedIf = [&](llvm::function_ref test) { + std::optional nestedDir = extractOnlyOmpNestedDir(eval); + if (!nestedDir || !test(*nestedDir)) + return; + + lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation(); + List nestedClauses; + extractClauses(nestedEval, nestedClauses); + processEval(nestedEval, nestedClauses); + }; + + const auto *ompEval = eval.getIf(); + if (!ompEval) + return; + + HostEvalInfo &hostInfo = hostEvalInfo.back(); + + switch (extractOmpDirective(*ompEval)) { + // Cases where 'teams' and target SPMD clauses might be present. + case OMPD_teams_distribute_parallel_do: + case OMPD_teams_distribute_parallel_do_simd: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_distribute_parallel_do: + case OMPD_target_teams_distribute_parallel_do_simd: + cp.processNumTeams(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_distribute_parallel_do: + case OMPD_distribute_parallel_do_simd: + cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processNumThreads(stmtCtx, hostInfo.ops); + break; + + // Cases where 'teams' clauses might be present, and target SPMD is + // possible by looking at nested evaluations. 
+ case OMPD_teams: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams: + cp.processNumTeams(stmtCtx, hostInfo.ops); + processSingleNestedIf([](Directive nestedDir) { + return nestedDir == OMPD_distribute_parallel_do || + nestedDir == OMPD_distribute_parallel_do_simd; + }); + break; + + // Cases where only 'teams' host-evaluated clauses might be present. + case OMPD_teams_distribute: + case OMPD_teams_distribute_simd: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_distribute: + case OMPD_target_teams_distribute_simd: + cp.processNumTeams(stmtCtx, hostInfo.ops); + break; + + // Standalone 'target' case. + case OMPD_target: { + processSingleNestedIf( + [](Directive nestedDir) { return topTeamsSet.test(nestedDir); }); + break; + } + default: + break; + } + }; + + assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure"); + + const auto *ompEval = eval.getIf(); + assert(ompEval && + llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && + "expected TARGET construct evaluation"); + + // Use the whole list of clauses passed to the construct here, rather than the + // ones only applied to omp.target. + List clauses; + extractClauses(eval, clauses); + processEval(eval, clauses); +} + static lower::pft::Evaluation * getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -489,214 +667,6 @@ static void genNestedEvaluations(lower::AbstractConverter &converter, converter.genEval(e); } -static bool mustEvalTeamsOutsideTarget(const lower::pft::Evaluation &eval, - mlir::omp::TargetOp targetOp) { - if (!isHostTarget(targetOp)) - return false; - - llvm::omp::Directive dir = - extractOmpDirective(eval.get()); - - assert(llvm::omp::allTeamsSet.test(dir) && "expected a teams construct"); - return llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval); -} - -static bool mustEvalTargetSPMDOutsideTarget(const lower::pft::Evaluation &eval, - mlir::omp::TargetOp targetOp) { - if (!isHostTarget(targetOp)) - return false; - - return isTargetSPMDLoop(eval); -} - -//===----------------------------------------------------------------------===// -// HostClausesInsertionGuard -//===----------------------------------------------------------------------===// - -/// If the insertion point of the builder is located inside of an omp.target -/// region, this RAII guard moves the insertion point to just before that -/// omp.target operation and then restores the original insertion point when -/// destroyed. If not currently inserting inside an omp.target, it remains -/// unchanged. -class HostClausesInsertionGuard { -public: - HostClausesInsertionGuard(mlir::OpBuilder &builder) : builder(builder) { - targetOp = findParentTargetOp(builder); - if (targetOp) { - ip = builder.saveInsertionPoint(); - builder.setInsertionPoint(targetOp); - } - } - - ~HostClausesInsertionGuard() { - if (ip.isSet()) { - fixupExtractedHostOps(); - builder.restoreInsertionPoint(ip); - } - } - - mlir::omp::TargetOp getTargetOp() const { return targetOp; } - -private: - mlir::OpBuilder &builder; - mlir::OpBuilder::InsertPoint ip; - mlir::omp::TargetOp targetOp; - - // Finds the list of op operands that escape the target op's region; that is: - // the operands that are used outside the target op but defined inside it. 
- void - findEscapingOpOperands(llvm::DenseSet &escapingOperands) { - if (!targetOp) - return; - - mlir::Region *targetParentRegion = targetOp->getParentRegion(); - assert(targetParentRegion != nullptr && - "Expected omp.target op to be nested in a parent region."); - - llvm::DenseSet visitedOps; - - // Walk the parent region in pre-order to make sure we visit `targetOp` - // before its nested ops. - targetParentRegion->walk( - [&](mlir::Operation *op) { - // Once we come across `targetOp`, we interrupt the walk since we - // already visited all the ops that come before it in the region. - if (op == targetOp) - return mlir::WalkResult::interrupt(); - - for (mlir::OpOperand &operand : op->getOpOperands()) { - mlir::Operation *operandDefiningOp = operand.get().getDefiningOp(); - - if (operandDefiningOp == nullptr) - continue; - - if (visitedOps.contains(operandDefiningOp)) - continue; - - visitedOps.insert(operandDefiningOp); - auto parentTargetOp = - operandDefiningOp->getParentOfType(); - - if (parentTargetOp != targetOp) - continue; - - escapingOperands.insert(&operand); - } - - return mlir::WalkResult::advance(); - }); - } - - // For an escaping operand, clone its use-def chain (i.e. its backward slice) - // outside the target region. - // - // \return the last op in the chain (this is the op that defines the escaping - // operand). - mlir::Operation * - cloneOperandSliceOutsideTargetOp(mlir::OpOperand *escapingOperand) { - mlir::Operation *operandDefiningOp = escapingOperand->get().getDefiningOp(); - llvm::SetVector backwardSlice; - mlir::BackwardSliceOptions sliceOptions; - sliceOptions.inclusive = true; - mlir::getBackwardSlice(operandDefiningOp, &backwardSlice, sliceOptions); - - auto ip = builder.saveInsertionPoint(); - - mlir::IRMapping mapper; - builder.setInsertionPoint(escapingOperand->getOwner()); - - mlir::Operation *lastSliceOp = nullptr; - llvm::SetVector opsToClone; - - for (auto *op : backwardSlice) { - // DeclareOps need special handling by searching for the corresponding ops - // in the host. Therefore, do not clone them since this special handling - // is done later in the fix-up process. - // - // TODO this might need a more elaborate handling in the future but for - // now this seems sufficient for our purposes. - if (llvm::isa(op)) { - opsToClone.clear(); - break; - } - - opsToClone.insert(op); - } - - for (mlir::Operation *op : opsToClone) - lastSliceOp = builder.clone(*op, mapper); - - builder.restoreInsertionPoint(ip); - return lastSliceOp; - } - - /// Fixup any uses of target region block arguments that we have just created - /// outside of the target region, and replace them by their host values. - void fixupExtractedHostOps() { - llvm::DenseSet escapingOperands; - findEscapingOpOperands(escapingOperands); - - for (mlir::OpOperand *operand : escapingOperands) { - mlir::Operation *operandDefiningOp = operand->get().getDefiningOp(); - assert(operandDefiningOp != nullptr && - "Expected escaping operand to have a defining op (i.e. not to be " - "a block argument)"); - mlir::Operation *lastSliceOp = cloneOperandSliceOutsideTargetOp(operand); - - if (lastSliceOp == nullptr) - continue; - - // Find the index of the operand in the list of results produced by its - // defining op. 
- unsigned operandResultIdx = 0; - for (auto [idx, res] : llvm::enumerate(operandDefiningOp->getResults())) { - if (res == operand->get()) { - operandResultIdx = idx; - break; - } - } - - // Replace the escaping operand with the corresponding value from the - // op that we cloned outside the target op. - operand->getOwner()->setOperand(operand->getOperandNumber(), - lastSliceOp->getResult(operandResultIdx)); - } - - auto useOutsideTargetRegion = [](mlir::OpOperand &operand) { - if (mlir::Operation *owner = operand.getOwner()) - return !owner->getParentOfType(); - return false; - }; - - auto argIface = llvm::cast(*targetOp); - for (auto [map, arg] : - llvm::zip_equal(targetOp.getMapVars(), argIface.getMapBlockArgs())) { - mlir::Value hostVal = - map.getDefiningOp().getVarPtr(); - - // Replace instances of omp.target block arguments used outside with their - // corresponding host value. - arg.replaceUsesWithIf(hostVal, [&](mlir::OpOperand &operand) -> bool { - // If the use is an hlfir.declare, we need to search for the matching - // one within host code. - if (auto declareOp = llvm::dyn_cast_if_present( - operand.getOwner())) { - if (auto hostDeclareOp = hostVal.getDefiningOp()) { - declareOp->replaceUsesWithIf(hostDeclareOp.getResults(), - useOutsideTargetRegion); - } else if (auto hostBoxOp = hostVal.getDefiningOp()) { - declareOp->replaceUsesWithIf(hostBoxOp.getVal() - .getDefiningOp() - .getResults(), - useOutsideTargetRegion); - } - } - return useOutsideTargetRegion(operand); - }); - } - } -}; - static fir::GlobalOp globalInitialization(lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder, const semantics::Symbol &sym, @@ -1053,7 +1023,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter, llvm::SmallVector types; llvm::SmallVector locs; unsigned numVars = - args.hostEval.vars.size() + args.inReduction.vars.size() + + args.hostEvalVars.size() + args.inReduction.vars.size() + args.map.vars.size() + args.priv.vars.size() + args.reduction.vars.size() + args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size(); @@ -1069,7 +1039,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter, // Populate block arguments in clause name alphabetical order to match // expected order by the BlockArgOpenMPOpInterface. - extractTypeLoc(args.hostEval.vars); + extractTypeLoc(args.hostEvalVars); extractTypeLoc(args.inReduction.vars); extractTypeLoc(args.map.vars); extractTypeLoc(args.priv.vars); @@ -1418,6 +1388,8 @@ static void genBodyOfTargetOp( dsp.processStep2(); bindEntryBlockArgs(converter, targetOp, args); + if (!hostEvalInfo.empty()) + hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs()); // Check if cloning the bounds introduced any dependency on the outer region. // If so, then either clone them as well if they are MemoryEffectFree, or else @@ -1592,29 +1564,13 @@ static void genLoopNestClauses(lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const List &clauses, - mlir::Location loc, bool evalOutsideTarget, - mlir::omp::LoopNestOperands &clauseOps, + mlir::Location loc, mlir::omp::LoopNestOperands &clauseOps, llvm::SmallVectorImpl &iv) { ClauseProcessor cp(converter, semaCtx, clauses); - // Evaluate loop bounds on the host device, if the operation is defining part - // of a target SPMD kernel. 
- if (evalOutsideTarget) { - HostClausesInsertionGuard guard(converter.getFirOpBuilder()); + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv)) cp.processCollapse(loc, eval, clauseOps, iv); - for (unsigned i = 0; i < clauseOps.loopLowerBounds.size(); ++i) { - clauseOps.loopLowerBounds[i] = - addHostEvalVar(guard.getTargetOp(), clauseOps.loopLowerBounds[i]); - clauseOps.loopUpperBounds[i] = - addHostEvalVar(guard.getTargetOp(), clauseOps.loopUpperBounds[i]); - clauseOps.loopSteps[i] = - addHostEvalVar(guard.getTargetOp(), clauseOps.loopSteps[i]); - } - } else { - cp.processCollapse(loc, eval, clauseOps, iv); - } - clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr(); } @@ -1639,23 +1595,14 @@ genOrderedRegionClauses(lower::AbstractConverter &converter, static void genParallelClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx, const List &clauses, - mlir::Location loc, bool evalOutsideTarget, - mlir::omp::ParallelOperands &clauseOps, + mlir::Location loc, mlir::omp::ParallelOperands &clauseOps, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - // Evaluate NUM_THREADS on the host device, if the operation is defining part - // of a target SPMD kernel. - if (evalOutsideTarget) { - HostClausesInsertionGuard guard(converter.getFirOpBuilder()); - if (cp.processNumThreads(stmtCtx, clauseOps)) - clauseOps.numThreads = - addHostEvalVar(guard.getTargetOp(), clauseOps.numThreads); - } else { + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) cp.processNumThreads(stmtCtx, clauseOps); - } cp.processProcBind(clauseOps); cp.processReduction(loc, clauseOps, reductionSyms); @@ -1702,8 +1649,8 @@ static void genSingleClauses(lower::AbstractConverter &converter, static void genTargetClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, const List &clauses, - mlir::Location loc, bool processHostOnlyClauses, + lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval, + const List &clauses, mlir::Location loc, mlir::omp::TargetOperands &clauseOps, llvm::SmallVectorImpl &hasDeviceAddrSyms, llvm::SmallVectorImpl &isDevicePtrSyms, @@ -1712,13 +1659,15 @@ static void genTargetClauses( cp.processDepend(clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms); + if (!hostEvalInfo.empty()) { + // Only process host_eval if compiling for the host device. + processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc); + hostEvalInfo.back().collectValues(clauseOps.hostEvalVars); + } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); cp.processMap(loc, stmtCtx, clauseOps, &mapSyms); - - if (processHostOnlyClauses) - cp.processNowait(clauseOps); - + cp.processNowait(clauseOps); cp.processThreadLimit(stmtCtx, clauseOps); cp.processTODO &clauses, - mlir::Location loc, bool evalOutsideTarget, - mlir::omp::TeamsOperands &clauseOps, + mlir::Location loc, mlir::omp::TeamsOperands &clauseOps, llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - // Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside - // of an omp.target operation. 
- if (evalOutsideTarget) { - HostClausesInsertionGuard guard(converter.getFirOpBuilder()); - if (cp.processNumTeams(stmtCtx, clauseOps)) - clauseOps.numTeamsUpper = - addHostEvalVar(guard.getTargetOp(), clauseOps.numTeamsUpper); - - if (cp.processThreadLimit(stmtCtx, clauseOps)) - clauseOps.threadLimit = - addHostEvalVar(guard.getTargetOp(), clauseOps.threadLimit); - } else { + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) { cp.processNumTeams(stmtCtx, clauseOps); cp.processThreadLimit(stmtCtx, clauseOps); } + cp.processReduction(loc, clauseOps, reductionSyms); } @@ -1994,13 +1932,14 @@ genOrderedRegionOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } -static mlir::omp::ParallelOp genParallelOp( - lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - mlir::Location loc, const ConstructQueue &queue, - ConstructQueue::const_iterator item, mlir::omp::ParallelOperands &clauseOps, - const EntryBlockArgs &args, DataSharingProcessor *dsp, - bool isComposite = false, mlir::omp::TargetOp parentTarget = nullptr) { +static mlir::omp::ParallelOp +genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, mlir::Location loc, + const ConstructQueue &queue, ConstructQueue::const_iterator item, + mlir::omp::ParallelOperands &clauseOps, + const EntryBlockArgs &args, DataSharingProcessor *dsp, + bool isComposite = false) { auto genRegionEntryCB = [&](mlir::Operation *op) { genEntryBlock(converter, args, op->getRegion(0)); bindEntryBlockArgs( @@ -2181,17 +2120,19 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); lower::StatementContext stmtCtx; + bool isTargetDevice = + llvm::cast(*converter.getModuleOp()) + .getIsTargetDevice(); - bool processHostOnlyClauses = - !llvm::cast(*converter.getModuleOp()) - .getIsTargetDevice(); + // Introduce a new host_eval information structure for this target region. + if (!isTargetDevice) + hostEvalInfo.emplace_back(); mlir::omp::TargetOperands clauseOps; llvm::SmallVector mapSyms, isDevicePtrSyms, hasDeviceAddrSyms; - genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - processHostOnlyClauses, clauseOps, hasDeviceAddrSyms, - isDevicePtrSyms, mapSyms); + genTargetClauses(converter, semaCtx, stmtCtx, eval, item->clauses, loc, + clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms); DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ @@ -2309,7 +2250,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, extractMappedBaseValues(clauseOps.mapVars, mapBaseValues); EntryBlockArgs args; - // TODO: Fill hostEval in advance rather than adding to it later on. + args.hostEvalVars = clauseOps.hostEvalVars; // TODO: Add in_reduction syms and vars. args.map.syms = mapSyms; args.map.vars = mapBaseValues; @@ -2318,6 +2259,10 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc, queue, item, dsp); + + // Remove the host_eval information structure created for this target region. 
+ if (!isTargetDevice) + hostEvalInfo.pop_back(); return targetOp; } @@ -2468,14 +2413,10 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { lower::StatementContext stmtCtx; - mlir::omp::TargetOp targetOp = - findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = mustEvalTeamsOutsideTarget(eval, targetOp); - mlir::omp::TeamsOperands clauseOps; llvm::SmallVector reductionSyms; - genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - evalOutsideTarget, clauseOps, reductionSyms); + genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps, + reductionSyms); EntryBlockArgs args; // TODO: Add private syms and vars. @@ -2527,7 +2468,7 @@ static void genStandaloneDistribute(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); EntryBlockArgs distributeArgs; distributeArgs.priv.syms = dsp.getDelayedPrivSymbols(); @@ -2562,7 +2503,7 @@ static void genStandaloneDo(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); EntryBlockArgs wsloopArgs; // TODO: Add private syms and vars. @@ -2588,8 +2529,7 @@ static void genStandaloneParallel(lower::AbstractConverter &converter, mlir::omp::ParallelOperands parallelClauseOps; llvm::SmallVector parallelReductionSyms; genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - /*evalOutsideTarget=*/false, parallelClauseOps, - parallelReductionSyms); + parallelClauseOps, parallelReductionSyms); std::optional dsp; if (enableDelayedPrivatization) { @@ -2633,7 +2573,7 @@ static void genStandaloneSimd(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); EntryBlockArgs simdArgs; // TODO: Add private syms and vars. @@ -2673,16 +2613,11 @@ static void genCompositeDistributeParallelDo( ConstructQueue::const_iterator parallelItem = std::next(distributeItem); ConstructQueue::const_iterator doItem = std::next(parallelItem); - mlir::omp::TargetOp targetOp = - findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = mustEvalTargetSPMDOutsideTarget(eval, targetOp); - // Create parent omp.parallel first. mlir::omp::ParallelOperands parallelClauseOps; llvm::SmallVector parallelReductionSyms; genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc, - evalOutsideTarget, parallelClauseOps, - parallelReductionSyms); + parallelClauseOps, parallelReductionSyms); DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, @@ -2696,8 +2631,7 @@ static void genCompositeDistributeParallelDo( parallelArgs.reduction.syms = parallelReductionSyms; parallelArgs.reduction.vars = parallelClauseOps.reductionVars; genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem, - parallelClauseOps, parallelArgs, &dsp, - /*isComposite=*/true, evalOutsideTarget ? targetOp : nullptr); + parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true); // Clause processing. 
mlir::omp::DistributeOperands distributeClauseOps; @@ -2712,7 +2646,7 @@ static void genCompositeDistributeParallelDo( mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc, - evalOutsideTarget, loopNestClauseOps, iv); + loopNestClauseOps, iv); // Operation creation. EntryBlockArgs distributeArgs; @@ -2748,16 +2682,11 @@ static void genCompositeDistributeParallelDoSimd( ConstructQueue::const_iterator doItem = std::next(parallelItem); ConstructQueue::const_iterator simdItem = std::next(doItem); - mlir::omp::TargetOp targetOp = - findParentTargetOp(converter.getFirOpBuilder()); - bool evalOutsideTarget = mustEvalTargetSPMDOutsideTarget(eval, targetOp); - // Create parent omp.parallel first. mlir::omp::ParallelOperands parallelClauseOps; llvm::SmallVector parallelReductionSyms; genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc, - evalOutsideTarget, parallelClauseOps, - parallelReductionSyms); + parallelClauseOps, parallelReductionSyms); DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, @@ -2771,8 +2700,7 @@ static void genCompositeDistributeParallelDoSimd( parallelArgs.reduction.syms = parallelReductionSyms; parallelArgs.reduction.vars = parallelClauseOps.reductionVars; genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem, - parallelClauseOps, parallelArgs, &dsp, - /*isComposite=*/true, evalOutsideTarget ? targetOp : nullptr); + parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true); // Clause processing. mlir::omp::DistributeOperands distributeClauseOps; @@ -2792,7 +2720,7 @@ static void genCompositeDistributeParallelDoSimd( mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc, - evalOutsideTarget, loopNestClauseOps, iv); + loopNestClauseOps, iv); // Operation creation. EntryBlockArgs distributeArgs; @@ -2860,7 +2788,7 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); // Operation creation. EntryBlockArgs distributeArgs; @@ -2918,7 +2846,7 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter, mlir::omp::LoopNestOperands loopNestClauseOps; llvm::SmallVector iv; genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc, - /*evalOutsideTarget=*/false, loopNestClauseOps, iv); + loopNestClauseOps, iv); // Operation creation. EntryBlockArgs wsloopArgs; diff --git a/flang/test/Lower/OpenMP/eval-outside-target.f90 b/flang/test/Lower/OpenMP/eval-outside-target.f90 index d0925971e4b2bc..32c52462b86a76 100644 --- a/flang/test/Lower/OpenMP/eval-outside-target.f90 +++ b/flang/test/Lower/OpenMP/eval-outside-target.f90 @@ -33,7 +33,7 @@ end subroutine teams subroutine distribute_parallel_do() ! BOTH: omp.target - ! HOST-SAME: host_eval(%{{.*}} -> %[[NUM_THREADS:.*]], %{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]] : i32, i32, i32, i32) + ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32) ! DEVICE-NOT: host_eval({{.*}}) ! DEVICE-SAME: { @@ -95,7 +95,7 @@ end subroutine distribute_parallel_do subroutine distribute_parallel_do_simd() ! 
BOTH: omp.target - ! HOST-SAME: host_eval(%{{.*}} -> %[[NUM_THREADS:.*]], %{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]] : i32, i32, i32, i32) + ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32) ! DEVICE-NOT: host_eval({{.*}}) ! DEVICE-SAME: { diff --git a/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 b/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 deleted file mode 100644 index b4d5cffffac1d6..00000000000000 --- a/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 +++ /dev/null @@ -1,95 +0,0 @@ -! Verifies that if expressions are used to compute a target parallel loop, that -! no values escape the target region when flang emits the ops corresponding to -! these expressions (for example the compute the trip count for the target region). - -! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s - -subroutine foo(upper_bound) - implicit none - integer :: upper_bound - integer :: nodes(1 : upper_bound) - integer :: i - - !$omp target teams distribute parallel do - do i = 1, ubound(nodes,1) - nodes(i) = i - end do - !$omp end target teams distribute parallel do -end subroutine - -! CHECK: func.func @_QPfoo(%[[FUNC_ARG:.*]]: !fir.ref {fir.bindc_name = "upper_bound"}) { -! CHECK: %[[UB_ALLOC:.*]] = fir.alloca i32 -! CHECK: fir.dummy_scope : !fir.dscope -! CHECK: %[[UB_DECL:.*]]:2 = hlfir.declare %[[FUNC_ARG]] {{.*}} {uniq_name = "_QFfooEupper_bound"} - -! CHECK: omp.map.info -! CHECK: omp.map.info -! CHECK: omp.map.info - -! Verify that we load from the original/host allocation of the `upper_bound` -! variable rather than the corresponding target region arg. - -! CHECK: fir.load %[[UB_ALLOC]] : !fir.ref -! CHECK: omp.target - -! CHECK: } - -subroutine foo_with_dummy_arg(nodes) - implicit none - integer, intent(inout) :: nodes( : ) - integer :: i - - !$omp target teams distribute parallel do - do i = 1, ubound(nodes, 1) - nodes(i) = i - end do - !$omp end target teams distribute parallel do -end subroutine - -! CHECK: func.func @_QPfoo_with_dummy_arg(%[[FUNC_ARG:.*]]: !fir.box> {fir.bindc_name = "nodes"}) { - -! CHECK: %[[ARR_DECL:.*]]:2 = hlfir.declare %[[FUNC_ARG]] dummy_scope - -! CHECK: omp.map.info -! CHECK: omp.map.info -! CHECK: omp.map.info - -! Verify that we get the box dims of the host array declaration not the target -! one. - -! CHECK: fir.box_dims %[[ARR_DECL]] - -! CHECK: omp.target - -! CHECK: } - - -subroutine bounds_expr_in_loop_control(array) - real, intent(out) :: array(:,:) - integer :: bounds(2), i, j - bounds = shape(array) - - !$omp target teams distribute parallel do simd collapse(2) - do j = 1,bounds(2) - do i = 1,bounds(1) - array(i,j) = 0. - enddo - enddo -end subroutine bounds_expr_in_loop_control - - -! CHECK: func.func @_QPbounds_expr_in_loop_control(%[[FUNC_ARG:.*]]: {{.*}}) { - -! CHECK: %[[BOUNDS_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "{{.*}}Ebounds"} : (!fir.ref>, !fir.shape<1>) -> ({{.*}}) - -! Verify that the host declaration of `bounds` (i.e. not the target/mapped one) -! is used for the trip count calculation. Trip count is calculation ops are emitted -! directly before the `omp.target` op and after all `omp.map.info` op; hence the -! `CHECK-NOT: ...` line. - -! CHECK: hlfir.designate %[[BOUNDS_DECL:.*]]#0 (%c2{{.*}}) -! CHECK: hlfir.designate %[[BOUNDS_DECL:.*]]#0 (%c1{{.*}}) -! CHECK-NOT: omp.map.info -! CHECK: omp.target - -! 
CHECK: } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h index 1247a871f93c6d..f9a85626a3f149 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -41,6 +41,12 @@ struct DeviceTypeClauseOps { // Extra operation operand structures. //===----------------------------------------------------------------------===// +/// Clauses that correspond to operations other than omp.target, but might have +/// to be evaluated outside of a parent target region. +using HostEvaluatedOperands = + detail::Clauses; + // TODO: Add `indirect` clause. using DeclareTargetOperands = detail::Clauses; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index c142747ed2a7a9..e11e8c8f8dd587 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -4210,31 +4210,54 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, /// corresponding global `ConfigurationEnvironmentTy` structure. static void initTargetDefaultBounds( omp::TargetOp targetOp, - llvm::OpenMPIRBuilder::TargetKernelDefaultBounds &bounds, bool isGPU) { - Value hostNumThreads, hostNumTeamsLower, hostNumTeamsUpper, hostThreadLimit; - extractHostEvalClauses(targetOp, hostNumThreads, hostNumTeamsLower, - hostNumTeamsUpper, hostThreadLimit); + llvm::OpenMPIRBuilder::TargetKernelDefaultBounds &bounds, + bool isTargetDevice, bool isGPU) { + // TODO: Handle constant 'if' clauses. + Operation *capturedOp = targetOp.getInnermostCapturedOmpOp(); + + // Extract values for host-evaluated clauses. + Value numThreads, numTeamsLower, numTeamsUpper, threadLimit; + if (!isTargetDevice) { + extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper, + threadLimit); + } else { + // In the target device, values for these clauses are not passed as + // host_eval, but instead evaluated prior to entry to the region. This + // ensures values are mapped and available inside of the target region. + if (auto teamsOp = castOrGetParentOfType(capturedOp)) { + numTeamsLower = teamsOp.getNumTeamsLower(); + numTeamsUpper = teamsOp.getNumTeamsUpper(); + threadLimit = teamsOp.getThreadLimit(); + } + + if (auto parallelOp = castOrGetParentOfType(capturedOp)) + numThreads = parallelOp.getNumThreads(); + } - // TODO Handle constant IF clauses - Operation *innermostCapturedOmpOp = targetOp.getInnermostCapturedOmpOp(); + auto extractConstInteger = [](Value value) -> std::optional { + if (auto constOp = + dyn_cast_if_present(value.getDefiningOp())) + if (auto constAttr = dyn_cast(constOp.getValue())) + return constAttr.getInt(); + + return std::nullopt; + }; // Handle clauses impacting the number of teams. + int32_t minTeamsVal = 1, maxTeamsVal = -1; - if (castOrGetParentOfType(innermostCapturedOmpOp)) { + if (castOrGetParentOfType(capturedOp)) { // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match // clang and set min and max to the same value. 
- if (hostNumTeamsUpper) { - if (auto constOp = dyn_cast_if_present( - hostNumTeamsUpper.getDefiningOp())) { - if (auto constAttr = dyn_cast(constOp.getValue())) - minTeamsVal = maxTeamsVal = constAttr.getInt(); - } + if (numTeamsUpper) { + if (auto val = extractConstInteger(numTeamsUpper)) + minTeamsVal = maxTeamsVal = *val; } else { minTeamsVal = maxTeamsVal = 0; } - } else if (castOrGetParentOfType(innermostCapturedOmpOp, + } else if (castOrGetParentOfType(capturedOp, /*immediateParent=*/true) || - castOrGetParentOfType(innermostCapturedOmpOp, + castOrGetParentOfType(capturedOp, /*immediateParent=*/true)) { minTeamsVal = maxTeamsVal = 1; } else { @@ -4242,32 +4265,31 @@ static void initTargetDefaultBounds( } // Handle clauses impacting the number of threads. - int32_t targetThreadLimitVal = -1; - int32_t teamsThreadLimitVal = -1; - int32_t maxThreadsVal = -1; - auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) { - if (clauseValue) { - if (auto constOp = dyn_cast_if_present( - clauseValue.getDefiningOp())) { - if (auto constAttr = dyn_cast(constOp.getValue())) - result = constAttr.getInt(); - } - // Found an applicable clause, so it's not undefined. Mark as unknown - // because it's not constant. - if (result < 0) - result = 0; - } + auto setMaxValueFromClause = [&extractConstInteger](Value clauseValue, + int32_t &result) { + if (!clauseValue) + return; + + if (auto val = extractConstInteger(clauseValue)) + result = *val; + + // Found an applicable clause, so it's not undefined. Mark as unknown + // because it's not constant. + if (result < 0) + result = 0; }; // Extract THREAD_LIMIT clause from TARGET and TEAMS directives. + int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1; setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal); - setMaxValueFromClause(hostThreadLimit, teamsThreadLimitVal); + setMaxValueFromClause(threadLimit, teamsThreadLimitVal); // Extract MAX_THREADS clause from PARALLEL or set to 1 if it's SIMD. - if (castOrGetParentOfType(innermostCapturedOmpOp)) - setMaxValueFromClause(hostNumThreads, maxThreadsVal); - else if (castOrGetParentOfType(innermostCapturedOmpOp, + int32_t maxThreadsVal = -1; + if (castOrGetParentOfType(capturedOp)) + setMaxValueFromClause(numThreads, maxThreadsVal); + else if (castOrGetParentOfType(capturedOp, /*immediateParent=*/true)) maxThreadsVal = 1; @@ -4285,9 +4307,8 @@ static void initTargetDefaultBounds( // Calculate reduction data size, limited to single reduction variable // for now. int32_t reductionDataSize = 0; - if (isGPU && innermostCapturedOmpOp) { - if (auto teamsOp = - castOrGetParentOfType(innermostCapturedOmpOp)) { + if (isGPU && capturedOp) { + if (auto teamsOp = castOrGetParentOfType(capturedOp)) { reductionDataSize = getTeamsReductionDataSize(teamsOp); } } @@ -4514,7 +4535,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, llvm::SmallVector kernelInput; llvm::OpenMPIRBuilder::TargetKernelDefaultBounds defaultBounds; - initTargetDefaultBounds(targetOp, defaultBounds, isGPU); + initTargetDefaultBounds(targetOp, defaultBounds, isTargetDevice, isGPU); // Collect host-evaluated values needed to properly launch the kernel from the // host.
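Illustrative sketch of the host_eval flow introduced above, written in the style of flang/test/Lower/OpenMP/eval-outside-target.f90. The subroutine name and the CHECK lines other than the host_eval pattern are assumptions added for illustration, not checks taken verbatim from the patch. When compiling for the host, the loop bounds, step and NUM_THREADS of a target SPMD construct are evaluated before the omp.target op and forwarded as host_eval operands in the order produced by HostEvalInfo::collectValues (loop lower bounds, upper bounds, steps, then num_teams lower/upper, num_threads, thread_limit), which is also the order the updated FileCheck lines above expect.

subroutine host_eval_sketch(n)
  ! Hypothetical example: on the host, the bounds, step and NUM_THREADS below
  ! are computed in host code and passed into the target region via host_eval,
  ! then bound to the corresponding omp.parallel and omp.loop_nest operands.
  implicit none
  integer :: n, i

  ! HOST: omp.target
  ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32)
  ! HOST: omp.parallel
  ! HOST-SAME: num_threads(%[[NUM_THREADS]] : i32)
  ! HOST: omp.loop_nest ({{.*}}) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]])
  !$omp target teams distribute parallel do num_threads(4)
  do i = 1, n
  end do
  !$omp end target teams distribute parallel do
end subroutine host_eval_sketch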