Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SWP] Attempt to move all scheduling logic to a scheduling pass #4618

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

manman-ren
Copy link
Collaborator

The main purpose is to move all scheduling related logic including scheduleDeps/scheduleDistOne/scheduleRemaining to a new loopScheduling pass. The new pass will generate (stage, cluster) attributes for each operation inside a loop.

For operations that are created during createAsyncOps, we manually add the (stage, cluster) based on the attributes prior to lowering. In most cases, the lowered operations should have (stage, cluster) of the original loadOp, or should have (stage, cluster) of the first use of the loadOp.

This patch also gets rid of prefetchCluster, instead it prefetches one stage before the actual use
setStageCluster(forOp, wait, stageForFirstUse - 1, clusterForFirstUse + 1);
setStageCluster(forOp, wait, stageForFirstUse - 1, clusterForFirstUse + 1);
comparing to
schedule.insert(wait, numStages - 2, prefetchCluster);
schedule.insert(viewLoad, numStages - 2, prefetchCluster);

At end of createAsyncOps, we make sure all operations inside the loop have (stage, cluster) attributes.

There is no need to maintain CoarseSchedule in SWP, instead we will just use (stage, cluster) attributes.

@manman-ren manman-ren marked this pull request as draft August 30, 2024 18:39
@@ -402,7 +429,6 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
&loadOpToIndLevelAndUse,
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
llvm::MapVector<Operation *, LoadInfo> loadToInfo;

for (auto &[op, dist, use] : loadOpToIndLevelAndUse) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am keeping the logic for loadOpToIndLvelAndUse inside assignMemoryLayout. It is not clear to me how to split the logic between loopScheduling and SWP. This function sets fields such as

  ttg::SharedEncodingAttr sharedEncoding = nullptr;
  ttg::BlockedEncodingAttr blockedEncoding = nullptr;
  bool loadIsMMAV3 = false;
  bool usedByDot = false;

The decision of whether or not to pipeline a loadOp depends on earlier decisions based on encoding etc. Maybe we can move this whole function to loopScheduling? But we will need to duplicate some of the logic to set the fields inside SWP.
CC @pawelszczerbuk @ThomasRaoux

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This diff is not ready for detailed review. But I would appreciate looking at changes to MatmulLoopPipeline.cpp to see if the overall direction looks fine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assignMemoryLayouts still uses loadOpToIndLevelAndUse in order to match functionality prior to the refactoring. But we can simplify it to find the first Dot use.

schedule, prefetchCluster, loadToInfo, numStages);
int retCode =
createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
loadToInfo, numStages, maxClusterId);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have set (stage, cluster) for extractIdx to the first Use of all loadOps (see helper getExtractIdxStageCluster). If wait/viewLoad are put in the prefetch stage, cluster, we need to reset the attributes for extractIdx here.
It is not clear to me if we need multiple copies of extractIdx here.

// The assignMemoryLayouts helper is split into two parts, one is to detect
// if a load should be pipelined. The rest is still in SWP.
static llvm::DenseSet<Operation *>
filterPipelinedLoad(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pawelszczerbuk @ThomasRaoux Part of assignMemoryLayouts ends up in LoopScheduling. It calls getSharedEncIfAllUsersAreDotEnc at line 136.

@@ -402,7 +429,6 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
&loadOpToIndLevelAndUse,
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
llvm::MapVector<Operation *, LoadInfo> loadToInfo;

for (auto &[op, dist, use] : loadOpToIndLevelAndUse) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assignMemoryLayouts still uses loadOpToIndLevelAndUse in order to match functionality prior to the refactoring. But we can simplify it to find the first Dot use.

// Convert from attributes to CoarseSchedule
tt::CoarseSchedule coarseSchedule(numStages);
getCoarseSchedule(forOp, coarseSchedule);
scheduleDependencies(forOp, coarseSchedule, numStages);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of calling scheduleDependencies on CoarseSchedule, we can have a different helper scheduleDependenciesWithAttributes(forOp, numStages) to apply (stage, cluster) attributes on the dependent ops.

for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) {
if (loadsToPipeline.count(loadOp) == 0)
continue;
loadOp->setAttr("loop.load",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have these strings defined in some shared file?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the meaning of loop.load? Ideally we should only annotate with stage and schedule information otherwise we are adding ad hoc conventions between the different passes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I haven't paid much attention to this yet. We should define them in some header.

The plan is to only annotate with (stage, cluster), but the last problem of separating out assignMemoryLayouts caused some complications. We are deciding which loads to pipeline in part1 of assignMemoryLayouts, which is in LoopScheduling, and the second part of assignMemoryLayouts in SWP needs to know which loads are going to be pipelined.

@manman-ren manman-ren force-pushed the pr-comp-swp-move-lowering branch 2 times, most recently from d7f6a65 to c8e9647 Compare September 27, 2024 04:16
Summary: add scheduleDeps based on attributes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants