Double buffering improvements #1511

Open · wants to merge 1 commit into base: develop
185 changes: 165 additions & 20 deletions mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
@@ -2399,6 +2399,112 @@ struct GridwiseGemmAccelRewritePattern
: public OpRewritePattern<GridwiseGemmAccelOp> {
using OpRewritePattern<GridwiseGemmAccelOp>::OpRewritePattern;

// Generate only the compute loop, i.e., we assume here that all the data
// we need has already been read from LDS into the register buffers.
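// As a rough illustration (schematic syntax, not verified compiler output),
// the loop nest emitted below looks like:
//
//   affine.for %i = 0 to mRepeats {
//     affine.for %j = 0 to nRepeats {
//       affine.for %k = 0 to kBasePerThread {
//         // viewC[i, j] += viewA[i, k] * viewB[j, k]
//         rock.threadwise_accel_gemm(%viewA, %viewB, %viewC)[%i, %j, %k]
//       }
//     }
//   }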
void generateComputeLoop(
Location loc, PatternRewriter &b,
const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
Value regsA, Value regsB, Value regsC, StringAttr arch,
GemmFeaturesAttr features,
const RockAccelTuningParamAttrInterface tuningParams) const {

rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams();
int64_t mRepeats = params.mRepeats;
int64_t nRepeats = params.nRepeats;
int64_t kBasePerThread = params.kBasePerThread;

auto mLoop = b.create<affine::AffineForOp>(loc, 0, mRepeats);
{
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(mLoop.getBody());
Value i = mLoop.getInductionVar();

auto nLoop = b.create<affine::AffineForOp>(loc, 0, nRepeats);
{
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(nLoop.getBody());
Value j = nLoop.getInductionVar();

// regsC += regsA * regsB
auto kLoop = b.create<affine::AffineForOp>(loc, 0, kBasePerThread);
{
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(kLoop.getBody());
Value viewA =
accelEmitterPtr->generateThreadwiseViewBufferA(b, loc, regsA);
Value viewB =
accelEmitterPtr->generateThreadwiseViewBufferB(b, loc, regsB);
Value viewC =
accelEmitterPtr->generateThreadwiseViewBufferC(b, loc, regsC);
Value k = kLoop.getInductionVar();
b.create<ThreadwiseAccelGemmOp>(loc, viewA, viewB, viewC,
ValueRange{i, j, k}, arch, features,
tuningParams);
}
}
}
}

// Generate the read loop from LDS, i.e., read A[0:mRepeats, 0:kBasePerThread]
// and B[0:nRepeats, 0:kBasePerThread] into registers before entering the MMA loop
void generateReadLoop(
Location loc, PatternRewriter &b,
const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
Value tid, Value ldsAView, Value ldsBView, Value regsA, Value regsB,
Value regsC, int64_t blockSize, int64_t inMPerThread,
int64_t inNPerThread, bool rotateMWithK, bool rotateNWithK) const {

// wrapLDSBufferForLoad reads a single set of Ks into private memory, i.e.,
// A/B[m/n, 0:kBasePerThread]
Value ldsA = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, ldsAView, blockSize, inMPerThread, "m", rotateMWithK);

Value ldsB = accelEmitterPtr->wrapLDSBufferForLoad(
b, loc, ldsBView, blockSize, inNPerThread, "n", rotateNWithK);

rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams();
int64_t mRepeats = params.mRepeats;
int64_t nRepeats = params.nRepeats;
int64_t kBasePerThread = params.kBasePerThread;

// We enhance the transformation from wrapLDSBufferForLoad with a builder
// that, given a single merged index, splits it into "m" ("n") and "k" and
// lets tid pass through. We then feed those indices to the view produced by
// wrapLDSBufferForLoad, which computes the right transform.
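// For example (illustrative numbers, and assuming the last merged dimension
// is the fastest-varying one): with mRepeats = 4 and kBasePerThread = 8 the
// merged "mk" index ranges over [0, 32), and mk = 19 maps to
// m = mk / kBasePerThread = 2 and k = mk % kBasePerThread = 3.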

// Read from LDS buffer for A
{
TopDownTMBuilder mkBuilder(b, {"tid", "mk"},
{blockSize, mRepeats * kBasePerThread}, loc);
mkBuilder.passThrough("tid");
mkBuilder.merge({"m", "k"}, {1, 2}, "mk", {mRepeats, kBasePerThread});

auto [ldsBufferA, ldsTransformsA, ignoreA] = rock::untransform(b, ldsA);
ldsTransformsA = rock::prependUpperViews(
b, b.getArrayAttr({mkBuilder.get()}), ldsTransformsA);
ldsA = rock::transform(b, ldsBufferA, ldsTransformsA);
b.create<ThreadwiseReadIntoOp>(loc, ldsA, regsA, b.getArrayAttr({}),
ValueRange{tid}, /*forceUnroll=*/true,
/*useIndexDiffs=*/true);
}

// Read from LDS buffer for B
{
TopDownTMBuilder nkBuilder(b, {"tid", "nk"},
{blockSize, nRepeats * kBasePerThread}, loc);
nkBuilder.passThrough("tid");
nkBuilder.merge({"n", "k"}, {1, 2}, "nk", {nRepeats, kBasePerThread});

auto [ldsBufferB, ldsTransformsB, ignoreB] = rock::untransform(b, ldsB);
ldsTransformsB = rock::prependUpperViews(
b, b.getArrayAttr({nkBuilder.get()}), ldsTransformsB);
ldsB = rock::transform(b, ldsBufferB, ldsTransformsB);
b.create<ThreadwiseReadIntoOp>(loc, ldsB, regsB, b.getArrayAttr({}),
ValueRange{tid}, /*forceUnroll=*/true,
/*useIndexDiffs=*/true);
}
}

LogicalResult matchAndRewrite(GridwiseGemmAccelOp op,
PatternRewriter &b) const override {
Location loc = op.getLoc();
@@ -2692,11 +2798,21 @@ struct GridwiseGemmAccelRewritePattern
Value ldsViewForGemmB = viewBufferAs(b, ldsByteBufferB, ldsReadTypeB);
int64_t nOutputVectors = nResultVectors * mRepeats * nRepeats;

// TODO: add a heuristic to decide whether the II should be 1 or 2. For now
// this is not worth it, since any form of double buffering results in poor
// assembly being generated, so we need to stick with II=2.
int64_t initiationInterval = 2;
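// For context (our reading of #rock.pipeline<II>; see
// test_rock_pipeline.mlir): with four stages and II = 2, two consecutive
// iterations are in flight, so stages {0, 1} of iteration i + 1 overlap
// with stages {2, 3} of iteration i. With II = 1 a new iteration starts
// every stage slot, so up to four iterations overlap and the LDSRead/MMA
// stages below can run alongside the global-read/LDS-write stages of
// later iterations, at the cost of extra buffering.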

// Logic to setup buffers for blockwise_gemm_accel.
auto arrayA =
gpuAlloc(b, loc, kBasePerThread, argTypeA, AddressSpace::Private);
auto arrayB =
gpuAlloc(b, loc, kBasePerThread, argTypeB, AddressSpace::Private);
int64_t arrayALen = kBasePerThread;
int64_t arrayBLen = kBasePerThread;
if (initiationInterval == 1) {
arrayALen *= mRepeats;
arrayBLen *= nRepeats;
}
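// E.g. (illustrative sizes): with kBasePerThread = 8, mRepeats = 4 and
// nRepeats = 2, II = 1 needs arrayALen = 32 and arrayBLen = 16 register
// elements (the whole per-thread A/B tile read ahead of the MMA stage),
// while with II = 2 the fused LDS-read/MMA stage can presumably keep
// reusing a single kBasePerThread-sized buffer per repeat.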

auto arrayA = gpuAlloc(b, loc, arrayALen, argTypeA, AddressSpace::Private);
auto arrayB = gpuAlloc(b, loc, arrayBLen, argTypeB, AddressSpace::Private);
auto regCAllocOp =
gpuAlloc(b, loc, nOutputVectors, accVectorType, AddressSpace::Private);

@@ -2709,8 +2825,9 @@ struct GridwiseGemmAccelRewritePattern
BlockwiseGemmAccelOp blockwiseGemmAccelOp;

auto loopOp = b.create<scf::ForOp>(loc, zeroConstantOp, nIterations, step);
loopOp->setAttr(PipelineAttr::getMnemonic(),
rock::PipelineAttr::get(b.getContext(), 2));
loopOp->setAttr(
PipelineAttr::getMnemonic(),
rock::PipelineAttr::get(b.getContext(), initiationInterval));
{
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(loopOp.getBody());
@@ -2772,20 +2889,48 @@ struct GridwiseGemmAccelRewritePattern
b.create<rock::YieldOp>(loc);
}

// Emit blockwise GEMM.
auto stage2 = b.create<StageOp>(loc, "MMA");
{
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
blockwiseGemmAccelOp = b.create<BlockwiseGemmAccelOp>(
loc, ldsViewForGemmA, ldsViewForGemmB,
b.getI32IntegerAttr(copyMPerThread),
b.getI32IntegerAttr(copyNPerThread),
(rotateMWithK ? b.getUnitAttr() : nullptr),
(rotateNWithK ? b.getUnitAttr() : nullptr), arrayA, arrayB,
regCAllocOp, op.getArchAttr(), op.getFeaturesAttr(),
op.getBlockSizeAttr(), op.getParamsAttr());
b.create<rock::YieldOp>(loc);
if (initiationInterval > 1) {
// Emit blockwise GEMM. This will load data from LDS and
// compute the MMA at the same time
auto stage2 = b.create<StageOp>(loc, "MMA");
{
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
blockwiseGemmAccelOp = b.create<BlockwiseGemmAccelOp>(
loc, ldsViewForGemmA, ldsViewForGemmB,
b.getI32IntegerAttr(copyMPerThread),
b.getI32IntegerAttr(copyNPerThread),
(rotateMWithK ? b.getUnitAttr() : nullptr),
(rotateNWithK ? b.getUnitAttr() : nullptr), arrayA, arrayB,
regCAllocOp, op.getArchAttr(), op.getFeaturesAttr(),
op.getBlockSizeAttr(), op.getParamsAttr());
b.create<rock::YieldOp>(loc);
}
} else {
// If we are running double-buffered pipelines, it makes sense to also
// parallelize the LDSRead/MMA stages. We do this here by splitting the
// MMA loop into two separate stages
auto stage2 = b.create<StageOp>(loc, "LDSRead");
{
// Read from LDS into registers
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
generateReadLoop(loc, b, accelEmitterPtr, tid, ldsViewForGemmA,
ldsViewForGemmB, arrayA, arrayB, regCAllocOp,
blockSize, copyMPerThread, copyNPerThread,
rotateMWithK, rotateNWithK);
b.create<rock::YieldOp>(loc);
}
auto stage3 = b.create<StageOp>(loc, "MMA");
{
// Compute the matrix-multiplication
PatternRewriter::InsertionGuard guard(b);
b.setInsertionPointToStart(&stage3.getRegion().emplaceBlock());
generateComputeLoop(loc, b, accelEmitterPtr, arrayA, arrayB,
regCAllocOp, op.getArchAttr(),
op.getFeaturesAttr(), tuningParams);
b.create<rock::YieldOp>(loc);
}
}
}

85 changes: 85 additions & 0 deletions mlir/test/Dialect/Rock/test_rock_pipeline.mlir
@@ -246,3 +246,88 @@ func.func @rock_pipeline_4_stages_ii_2(%input : memref<16xi8, #gpu.address_space<global>>, %output : memref<16xi8, #gpu.address_space<global>>){
memref.store %out, %output[%c0] : memref<16xi8, #gpu.address_space<global>>
return
}

// CHECK-LABEL: rock_pipeline_4_stages_ii_1
func.func @rock_pipeline_4_stages_ii_1(%input : memref<16xi8, #gpu.address_space<global>>, %output : memref<16xi8, #gpu.address_space<global>>){
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : i8
%c16 = arith.constant 16 : index

%rawLds = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
%rawReg0 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
%rawReg1 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
%rawReg2 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>

%lds = memref.view %rawLds[%c0][] : memref<16xi8, #gpu.address_space<workgroup>> to memref<16xi8, #gpu.address_space<workgroup>>
%reg0 = memref.view %rawReg0[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
%reg1 = memref.view %rawReg1[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
%reg2 = memref.view %rawReg2[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
// CHECK: %[[rawLds0:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
// CHECK: %[[rawLds1:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
// CHECK: %[[rawReg0:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
// CHECK: %[[rawReg1:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
// CHECK: %[[rawReg2:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
// CHECK: %[[ldsView0:.*]] = memref.view %[[rawLds0]]
// CHECK: %[[ldsView1:.*]] = memref.view %[[rawLds1]]
// CHECK: %[[regView0:.*]] = memref.view %[[rawReg0]]
// CHECK: %[[regView1:.*]] = memref.view %[[rawReg1]]
// CHECK: %[[regView2:.*]] = memref.view %[[rawReg2]]

// Please note how we swap S0/S1 and S2/S3 to avoid private multi-buffers
// CHECK: name = "S0"
// CHECK: name = "S1"
// CHECK: name = "S0"
// CHECK: name = "__fwd_barrier__"
// CHECK: name = "S1"
// CHECK: name = "S0"
// CHECK: name = "S2"
// CHECK: scf.for
// CHECK: name = "__fwd_barrier__"
// CHECK: rock.extract_multibuffer(%[[regView0]])
// CHECK: rock.extract_multibuffer(%[[ldsView0]], %[[ldsView1]])
// CHECK: name = "S1"
// CHECK: rock.extract_multibuffer(%[[regView0]])
// CHECK: name = "S0"
// CHECK: rock.extract_multibuffer(%[[regView1]])
// CHECK: name = "S3"
// CHECK: rock.extract_multibuffer(%[[ldsView0]], %[[ldsView1]])
// CHECK: rock.extract_multibuffer(%[[regView1]])
// CHECK: name = "S2"
// CHECK: name = "__fwd_barrier__"
// CHECK: name = "S1"
// CHECK: name = "S3"
// CHECK: name = "S2"
// CHECK: name = "__fwd_barrier__"
// CHECK: name = "S3"
// CHECK: name = "S2"
// CHECK: name = "S3"
scf.for %arg3 = %c0 to %c16 step %c1 {
rock.stage {
%tmp = memref.load %input[%arg3] : memref<16xi8, #gpu.address_space<global>>
memref.store %tmp, %reg0[%arg3] : memref<16xi8, #gpu.address_space<private>>
rock.yield
}{name="S0"}
rock.stage {
%tmp = memref.load %reg0[%arg3] : memref<16xi8, #gpu.address_space<private>>
memref.store %tmp, %lds[%arg3] : memref<16xi8, #gpu.address_space<workgroup>>
rock.yield
}{name="S1"}
rock.stage {
%tmp = memref.load %lds[%arg3] : memref<16xi8, #gpu.address_space<workgroup>>
%comp = arith.addi %tmp, %c2 : i8
memref.store %tmp, %reg1[%arg3] : memref<16xi8, #gpu.address_space<private>>
rock.yield
}{name="S2"}
rock.stage {
%tmp = memref.load %reg1[%arg3] : memref<16xi8, #gpu.address_space<private>>
%comp = arith.addi %tmp, %c2 : i8
memref.store %comp, %reg2[%arg3] : memref<16xi8, #gpu.address_space<private>>
rock.yield
}{name="S3"}
}{pipeline = #rock.pipeline<1>}

%out = memref.load %reg2[%c0] : memref<16xi8, #gpu.address_space<private>>
memref.store %out, %output[%c0] : memref<16xi8, #gpu.address_space<global>>
return
}