diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
index c8b2a773e0ca..ff677d14987a 100644
--- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
+++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
@@ -2399,6 +2399,112 @@ struct GridwiseGemmAccelRewritePattern
     : public OpRewritePattern<GridwiseGemmAccelOp> {
   using OpRewritePattern<GridwiseGemmAccelOp>::OpRewritePattern;
 
+  // Generate only the compute loop, i.e., we assume here that all
+  // the data we need is already in LDS.
+  void generateComputeLoop(
+      Location loc, PatternRewriter &b,
+      const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
+      Value regsA, Value regsB, Value regsC, StringAttr arch,
+      GemmFeaturesAttr features,
+      const RockAccelTuningParamAttrInterface tuningParams) const {
+
+    rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams();
+    int64_t mRepeats = params.mRepeats;
+    int64_t nRepeats = params.nRepeats;
+    int64_t kBasePerThread = params.kBasePerThread;
+
+    auto mLoop = b.create<affine::AffineForOp>(loc, 0, mRepeats);
+    {
+      OpBuilder::InsertionGuard guard(b);
+      b.setInsertionPointToStart(mLoop.getBody());
+      Value i = mLoop.getInductionVar();
+
+      auto nLoop = b.create<affine::AffineForOp>(loc, 0, nRepeats);
+      {
+        OpBuilder::InsertionGuard guard(b);
+        b.setInsertionPointToStart(nLoop.getBody());
+        Value j = nLoop.getInductionVar();
+
+        // regsC += regsA * regsB
+        auto kLoop = b.create<affine::AffineForOp>(loc, 0, kBasePerThread);
+        {
+          OpBuilder::InsertionGuard guard(b);
+          b.setInsertionPointToStart(kLoop.getBody());
+          Value viewA =
+              accelEmitterPtr->generateThreadwiseViewBufferA(b, loc, regsA);
+          Value viewB =
+              accelEmitterPtr->generateThreadwiseViewBufferB(b, loc, regsB);
+          Value viewC =
+              accelEmitterPtr->generateThreadwiseViewBufferC(b, loc, regsC);
+          Value k = kLoop.getInductionVar();
+          b.create<ThreadwiseAccelGemmOp>(loc, viewA, viewB, viewC,
+                                          ValueRange{i, j, k}, arch, features,
+                                          tuningParams);
+        }
+      }
+    }
+  }
+
+  // Generate the read loop from LDS: we read A[0:mRepeats, 0:kBasePerThread]
+  // and B[0:nRepeats, 0:kBasePerThread] before entering the MMA loop.
+  void generateReadLoop(
+      Location loc, PatternRewriter &b,
+      const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
+      Value tid, Value ldsAView, Value ldsBView, Value regsA, Value regsB,
+      Value regsC, int64_t blockSize, int64_t inMPerThread,
+      int64_t inNPerThread, bool rotateMWithK, bool rotateNWithK) const {
+
+    // wrapLDSBufferForLoad reads a single set of Ks into private memory,
+    // i.e., A/B[m/n, 0:kBasePerThread].
+    Value ldsA = accelEmitterPtr->wrapLDSBufferForLoad(
+        b, loc, ldsAView, blockSize, inMPerThread, "m", rotateMWithK);
+
+    Value ldsB = accelEmitterPtr->wrapLDSBufferForLoad(
+        b, loc, ldsBView, blockSize, inNPerThread, "n", rotateNWithK);
+
+    rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams();
+    int64_t mRepeats = params.mRepeats;
+    int64_t nRepeats = params.nRepeats;
+    int64_t kBasePerThread = params.kBasePerThread;
+
+    // We extend the transformation from wrapLDSBufferForLoad with a builder
+    // that, given a single flat index, splits it into "m" ("n") and "k" and
+    // lets tid pass through. We can then feed those indices to the view
+    // produced by wrapLDSBufferForLoad, which computes the right transform.
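+    // For example (hypothetical sizes), with mRepeats = 2 and
+    // kBasePerThread = 4 the flat index mk = 5 is split into
+    // m = mk / kBasePerThread = 1 and k = mk % kBasePerThread = 1.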
+
+    // Read from LDS buffer for A
+    {
+      TopDownTMBuilder mkBuilder(b, {"tid", "mk"},
+                                 {blockSize, mRepeats * kBasePerThread}, loc);
+      mkBuilder.passThrough("tid");
+      mkBuilder.merge({"m", "k"}, {1, 2}, "mk", {mRepeats, kBasePerThread});
+
+      auto [ldsBufferA, ldsTransformsA, ignoreA] = rock::untransform(b, ldsA);
+      ldsTransformsA = rock::prependUpperViews(
+          b, b.getArrayAttr({mkBuilder.get()}), ldsTransformsA);
+      ldsA = rock::transform(b, ldsBufferA, ldsTransformsA);
+      b.create<ThreadwiseReadIntoOp>(loc, ldsA, regsA, b.getArrayAttr({}),
+                                     ValueRange{tid}, /*forceUnroll=*/true,
+                                     /*useIndexDiffs=*/true);
+    }
+
+    // Read from LDS buffer for B
+    {
+      TopDownTMBuilder nkBuilder(b, {"tid", "nk"},
+                                 {blockSize, nRepeats * kBasePerThread}, loc);
+      nkBuilder.passThrough("tid");
+      nkBuilder.merge({"n", "k"}, {1, 2}, "nk", {nRepeats, kBasePerThread});
+
+      auto [ldsBufferB, ldsTransformsB, ignoreB] = rock::untransform(b, ldsB);
+      ldsTransformsB = rock::prependUpperViews(
+          b, b.getArrayAttr({nkBuilder.get()}), ldsTransformsB);
+      ldsB = rock::transform(b, ldsBufferB, ldsTransformsB);
+      b.create<ThreadwiseReadIntoOp>(loc, ldsB, regsB, b.getArrayAttr({}),
+                                     ValueRange{tid}, /*forceUnroll=*/true,
+                                     /*useIndexDiffs=*/true);
+    }
+  }
+
   LogicalResult matchAndRewrite(GridwiseGemmAccelOp op,
                                 PatternRewriter &b) const override {
     Location loc = op.getLoc();
@@ -2692,11 +2798,21 @@ struct GridwiseGemmAccelRewritePattern
     Value ldsViewForGemmB = viewBufferAs(b, ldsByteBufferB, ldsReadTypeB);
     int64_t nOutputVectors = nResultVectors * mRepeats * nRepeats;
 
+    // TODO: add a heuristic to decide whether the II should be 1 or 2. For now
+    // this is not worth it, since any form of double buffering results in poor
+    // assembly being generated, so we need to stick with II=2.
+    int64_t initiationInterval = 2;
+
     // Logic to setup buffers for blockwise_gemm_accel.
-    auto arrayA =
-        gpuAlloc(b, loc, kBasePerThread, argTypeA, AddressSpace::Private);
-    auto arrayB =
-        gpuAlloc(b, loc, kBasePerThread, argTypeB, AddressSpace::Private);
+    int64_t arrayALen = kBasePerThread;
+    int64_t arrayBLen = kBasePerThread;
+    if (initiationInterval == 1) {
+      arrayALen *= mRepeats;
+      arrayBLen *= nRepeats;
+    }
+
+    auto arrayA = gpuAlloc(b, loc, arrayALen, argTypeA, AddressSpace::Private);
+    auto arrayB = gpuAlloc(b, loc, arrayBLen, argTypeB, AddressSpace::Private);
     auto regCAllocOp = gpuAlloc(b, loc, nOutputVectors, accVectorType,
                                 AddressSpace::Private);
@@ -2709,8 +2825,9 @@ struct GridwiseGemmAccelRewritePattern
     BlockwiseGemmAccelOp blockwiseGemmAccelOp;
     auto loopOp =
         b.create<scf::ForOp>(loc, zeroConstantOp, nIterations, step);
-    loopOp->setAttr(PipelineAttr::getMnemonic(),
-                    rock::PipelineAttr::get(b.getContext(), 2));
+    loopOp->setAttr(
+        PipelineAttr::getMnemonic(),
+        rock::PipelineAttr::get(b.getContext(), initiationInterval));
     {
       PatternRewriter::InsertionGuard guard(b);
       b.setInsertionPointToStart(loopOp.getBody());
@@ -2772,20 +2889,48 @@ struct GridwiseGemmAccelRewritePattern
         b.create<rock::YieldOp>(loc);
       }
 
-      // Emit blockwise GEMM.
-      auto stage2 = b.create<StageOp>(loc, "MMA");
-      {
-        PatternRewriter::InsertionGuard guard(b);
-        b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
-        blockwiseGemmAccelOp = b.create<BlockwiseGemmAccelOp>(
-            loc, ldsViewForGemmA, ldsViewForGemmB,
-            b.getI32IntegerAttr(copyMPerThread),
-            b.getI32IntegerAttr(copyNPerThread),
-            (rotateMWithK ? b.getUnitAttr() : nullptr),
-            (rotateNWithK ? b.getUnitAttr() : nullptr), arrayA, arrayB,
-            regCAllocOp, op.getArchAttr(), op.getFeaturesAttr(),
-            op.getBlockSizeAttr(), op.getParamsAttr());
-        b.create<rock::YieldOp>(loc);
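+      // Depending on the initiation interval we emit one of two stage
+      // layouts below: with II=2 the LDS read and the MMA are fused into a
+      // single "MMA" stage, while with II=1 they are split into separate
+      // "LDSRead" and "MMA" stages so the pipeliner can overlap them.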
+      if (initiationInterval > 1) {
+        // Emit blockwise GEMM. This will load data from LDS and
+        // compute the MMA at the same time.
+        auto stage2 = b.create<StageOp>(loc, "MMA");
+        {
+          PatternRewriter::InsertionGuard guard(b);
+          b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
+          blockwiseGemmAccelOp = b.create<BlockwiseGemmAccelOp>(
+              loc, ldsViewForGemmA, ldsViewForGemmB,
+              b.getI32IntegerAttr(copyMPerThread),
+              b.getI32IntegerAttr(copyNPerThread),
+              (rotateMWithK ? b.getUnitAttr() : nullptr),
+              (rotateNWithK ? b.getUnitAttr() : nullptr), arrayA, arrayB,
+              regCAllocOp, op.getArchAttr(), op.getFeaturesAttr(),
+              op.getBlockSizeAttr(), op.getParamsAttr());
+          b.create<rock::YieldOp>(loc);
+        }
+      } else {
+        // If we are running double-buffered pipelines, it makes sense to also
+        // parallelize the LDSRead/MMA stages. We do this here by splitting the
+        // MMA loop into two separate stages.
+        auto stage2 = b.create<StageOp>(loc, "LDSRead");
+        {
+          // Read from LDS into registers
+          PatternRewriter::InsertionGuard guard(b);
+          b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
+          generateReadLoop(loc, b, accelEmitterPtr, tid, ldsViewForGemmA,
+                           ldsViewForGemmB, arrayA, arrayB, regCAllocOp,
+                           blockSize, copyMPerThread, copyNPerThread,
+                           rotateMWithK, rotateNWithK);
+          b.create<rock::YieldOp>(loc);
+        }
+        auto stage3 = b.create<StageOp>(loc, "MMA");
+        {
+          // Compute the matrix multiplication
+          PatternRewriter::InsertionGuard guard(b);
+          b.setInsertionPointToStart(&stage3.getRegion().emplaceBlock());
+          generateComputeLoop(loc, b, accelEmitterPtr, arrayA, arrayB,
+                              regCAllocOp, op.getArchAttr(),
+                              op.getFeaturesAttr(), tuningParams);
+          b.create<rock::YieldOp>(loc);
+        }
+      }
     }
   }
diff --git a/mlir/test/Dialect/Rock/test_rock_pipeline.mlir b/mlir/test/Dialect/Rock/test_rock_pipeline.mlir
index 708d75b64298..bc5665fe0b99 100644
--- a/mlir/test/Dialect/Rock/test_rock_pipeline.mlir
+++ b/mlir/test/Dialect/Rock/test_rock_pipeline.mlir
@@ -246,3 +246,88 @@ func.func @rock_pipeline_4_stages_ii_2(%input : memref<16xi8, #gpu.address_space<global>>, %output : memref<16xi8, #gpu.address_space<global>>){
   memref.store %out, %output[%c0] : memref<16xi8, #gpu.address_space<global>>
   return
 }
+
+// CHECK-LABEL: rock_pipeline_4_stages_ii_1
+func.func @rock_pipeline_4_stages_ii_1(%input : memref<16xi8, #gpu.address_space<global>>, %output : memref<16xi8, #gpu.address_space<global>>){
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : i8
+  %c16 = arith.constant 16 : index
+
+  %rawLds = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
+  %rawReg0 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
+  %rawReg1 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
+  %rawReg2 = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
+
+  %lds = memref.view %rawLds[%c0][] : memref<16xi8, #gpu.address_space<workgroup>> to memref<16xi8, #gpu.address_space<workgroup>>
+  %reg0 = memref.view %rawReg0[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
+  %reg1 = memref.view %rawReg1[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
+  %reg2 = memref.view %rawReg2[%c0][] : memref<16xi8, #gpu.address_space<private>> to memref<16xi8, #gpu.address_space<private>>
+  // CHECK: %[[rawLds0:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
+  // CHECK: %[[rawLds1:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<workgroup>>
+  // CHECK: %[[rawReg0:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
+  // CHECK: %[[rawReg1:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
+  // CHECK: %[[rawReg2:.*]] = rock.alloc() : memref<16xi8, #gpu.address_space<private>>
+  // CHECK: %[[ldsView0:.*]] = memref.view %[[rawLds0]]
+  // CHECK: %[[ldsView1:.*]] = memref.view %[[rawLds1]]
+  // CHECK: %[[regView0:.*]] = memref.view %[[rawReg0]]
+  // CHECK: %[[regView1:.*]] = memref.view %[[rawReg1]]
+  // CHECK: %[[regView2:.*]] = memref.view %[[rawReg2]]
+
+  // Please note how we swap S0/S1 and S2/S3 to avoid private multi-buffers
+  // CHECK: name = "S0"
+  // CHECK: name = "S1"
+  // CHECK: name = "S0"
+  // CHECK: name = "__fwd_barrier__"
+  // CHECK: name = "S1"
+  // CHECK: name = "S0"
+  // CHECK: name = "S2"
+  // CHECK: scf.for
+  // CHECK: name = "__fwd_barrier__"
+  // CHECK: rock.extract_multibuffer(%[[regView0]])
+  // CHECK: rock.extract_multibuffer(%[[ldsView0]], %[[ldsView1]])
+  // CHECK: name = "S1"
+  // CHECK: rock.extract_multibuffer(%[[regView0]])
+  // CHECK: name = "S0"
+  // CHECK: rock.extract_multibuffer(%[[regView1]])
+  // CHECK: name = "S3"
+  // CHECK: rock.extract_multibuffer(%[[ldsView0]], %[[ldsView1]])
+  // CHECK: rock.extract_multibuffer(%[[regView1]])
+  // CHECK: name = "S2"
+  // CHECK: name = "__fwd_barrier__"
+  // CHECK: name = "S1"
+  // CHECK: name = "S3"
+  // CHECK: name = "S2"
+  // CHECK: name = "__fwd_barrier__"
+  // CHECK: name = "S3"
+  // CHECK: name = "S2"
+  // CHECK: name = "S3"
+  scf.for %arg3 = %c0 to %c16 step %c1 {
+    rock.stage {
+      %tmp = memref.load %input[%arg3] : memref<16xi8, #gpu.address_space<global>>
+      memref.store %tmp, %reg0[%arg3] : memref<16xi8, #gpu.address_space<private>>
+      rock.yield
+    }{name="S0"}
+    rock.stage {
+      %tmp = memref.load %reg0[%arg3] : memref<16xi8, #gpu.address_space<private>>
+      memref.store %tmp, %lds[%arg3] : memref<16xi8, #gpu.address_space<workgroup>>
+      rock.yield
+    }{name="S1"}
+    rock.stage {
+      %tmp = memref.load %lds[%arg3] : memref<16xi8, #gpu.address_space<workgroup>>
+      %comp = arith.addi %tmp, %c2 : i8
+      memref.store %tmp, %reg1[%arg3] : memref<16xi8, #gpu.address_space<private>>
+      rock.yield
+    }{name="S2"}
+    rock.stage {
+      %tmp = memref.load %reg1[%arg3] : memref<16xi8, #gpu.address_space<private>>
+      %comp = arith.addi %tmp, %c2 : i8
+      memref.store %comp, %reg2[%arg3] : memref<16xi8, #gpu.address_space<private>>
+      rock.yield
+    }{name="S3"}
+  }{pipeline = #rock.pipeline<1>}
+
+  %out = memref.load %reg2[%c0] : memref<16xi8, #gpu.address_space<private>>
+  memref.store %out, %output[%c0] : memref<16xi8, #gpu.address_space<global>>
+  return
+}