From e2c71999e1d859f1a0f3e97dfe1f8712aec53bd9 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Tue, 24 Sep 2024 14:14:15 +0100 Subject: [PATCH 1/2] Memory overlap should be done only if output space is larger than input --- xformer/Analysis/MemoryPlan.cpp | 65 ++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/xformer/Analysis/MemoryPlan.cpp b/xformer/Analysis/MemoryPlan.cpp index d08ccd8f1..d7618b7f7 100644 --- a/xformer/Analysis/MemoryPlan.cpp +++ b/xformer/Analysis/MemoryPlan.cpp @@ -188,34 +188,47 @@ std::vector MemoryPlan::getAllocatedOffsets(const bool overlapOps, inputVals.push_back(inVal); auto outVal = o->getResult(0); - auto nextOp = *outVal.getUsers().begin(); - // Identify chain of overlappable Ops - while (outVal.hasOneUse() && !alreadyVisited.contains(nextOp) && - nextOp->hasTrait()) { - inVal = outVal; - inputVals.push_back(inVal); - alreadyVisited.insert(nextOp); - outVal = nextOp->getResult(0); - nextOp = *outVal.getUsers().begin(); - } - // Set first Used of output Val to the first input Val - vInfo[outVal].firstUsed = vInfo[inputVals[0]].firstUsed; - auto unalignedSizeOutVal = - utils::getShapedTypeSize(outVal.getType().dyn_cast()); - size_t maxSizeNeeded = 0; - for (auto inV : inputVals) { - auto unalignedSizeInV = - utils::getShapedTypeSize(inV.getType().dyn_cast()); - auto unalignedOffset = unalignedSizeOutVal - unalignedSizeInV; - // Align offset up to double word = 8 bytes - auto offset = ((unalignedOffset + 7) / 8) * 8; - maxSizeNeeded = std::max(vInfo[inV].size + offset, maxSizeNeeded); - inOutMap[inV] = {outVal, offset}; + // Only overlap if the output value size is equal or larger than the + // input value size We use the allocated space for the output value to + // store the input value + if ((utils::getShapedTypeSize( + outVal.getType().dyn_cast()) >= + utils::getShapedTypeSize( + inVal.getType().dyn_cast()))) { + auto nextOp = *outVal.getUsers().begin(); + // Identify chain of overlappable Ops + while (outVal.hasOneUse() && !alreadyVisited.contains(nextOp) && + nextOp->hasTrait() && + (utils::getShapedTypeSize( + outVal.getType().dyn_cast()) >= + utils::getShapedTypeSize( + inVal.getType().dyn_cast()))) { + inVal = outVal; + inputVals.push_back(inVal); + alreadyVisited.insert(nextOp); + outVal = nextOp->getResult(0); + nextOp = *outVal.getUsers().begin(); + } + + // Set first Used of output Val to the first input Val + vInfo[outVal].firstUsed = vInfo[inputVals[0]].firstUsed; + auto unalignedSizeOutVal = utils::getShapedTypeSize( + outVal.getType().dyn_cast()); + size_t maxSizeNeeded = 0; + for (auto inV : inputVals) { + auto unalignedSizeInV = utils::getShapedTypeSize( + inV.getType().dyn_cast()); + auto unalignedOffset = unalignedSizeOutVal - unalignedSizeInV; + // Align offset up to double word = 8 bytes + auto offset = ((unalignedOffset + 7) / 8) * 8; + maxSizeNeeded = std::max(vInfo[inV].size + offset, maxSizeNeeded); + inOutMap[inV] = {outVal, offset}; + } + // The aligned input val size plus aligned offset might be larger + // than aligned output val size + vInfo[outVal].size = std::max(vInfo[outVal].size, maxSizeNeeded); } - // The aligned input val size plus aligned offset might be larger than - // aligned output val size - vInfo[outVal].size = std::max(vInfo[outVal].size, maxSizeNeeded); } } } From 7e44391c15cb1c32a8d18ccf34ea88226b56d965 Mon Sep 17 00:00:00 2001 From: panickal-xmos Date: Tue, 24 Sep 2024 14:15:05 +0100 Subject: [PATCH 2/2] Enable channelwise split for i16 and tidy up pattern matching --- xformer/Transforms/OptimizeConv2D.cpp | 38 ++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/xformer/Transforms/OptimizeConv2D.cpp b/xformer/Transforms/OptimizeConv2D.cpp index b6642df89..62d637106 100644 --- a/xformer/Transforms/OptimizeConv2D.cpp +++ b/xformer/Transforms/OptimizeConv2D.cpp @@ -31,14 +31,17 @@ struct ChannelwiseSplitConv2DOutputPattern LogicalResult matchAndRewrite(TFL::Conv2DOp op, PatternRewriter &rewriter) const override { // Check for invalid types and return - // Input type must be QI8 + // Input type must be QI8 or QI16 auto inputElementType = - op.getInput().getType().cast().getElementType(); + op.getInput().getType().template cast().getElementType(); + if (!utils::isNBitSignedQType<8>(inputElementType) && + !utils::isNBitSignedQType<16>(inputElementType)) { + return failure(); + } + auto filterElementType = op.getFilter().getType().cast().getElementType(); - - if (!utils::isNBitSignedQType<8>(inputElementType) || - !utils::isNBitSignedQType<8>(filterElementType)) { + if (!utils::isNBitSignedQType<8>(filterElementType)) { return failure(); } @@ -91,13 +94,16 @@ struct ChannelwiseSplitConv2DOutputPattern // We want to try to keep the split filtersize less than specified size int numSplits = ceil(filterSize / convChannelwiseSplitSizeOption); - // Only try to split if at least two splits are possible - if (numSplits < 2) { - return failure(); - } + // Let's split the filter batch size as that's the same as bias size and // output channel size + // We want splits to be multiples of four auto filterBatchSize = filterType.getShape()[0]; + // Only try to split if at least two splits are possible + if (numSplits < 2 || filterBatchSize / 4 < 2) { + return failure(); + } + // We want splits to be multiples of four, so we divide here and multiply // after calculating the split sizes int tmp = filterBatchSize / 4; @@ -790,7 +796,6 @@ void OptimizeConv2D::runOnOperation() { auto *ctx = &getContext(); func::FuncOp func = getOperation(); RewritePatternSet patterns(ctx); - // Convert TransposeConv2D with SAME padding to VALID padding + slice patterns.insert(ctx); @@ -813,12 +818,21 @@ void OptimizeConv2D::runOnOperation() { // conv2d output channels, and add a slice to remove the padded // section patterns.insert(ctx); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + (void)applyPatternsAndFoldGreedily(func, frozenPatterns); + + // We apply channelwise splitting after padding. + RewritePatternSet patterns2(ctx); // When the filter is too large, we channelwise split the conv2d output to // make multiple conv2ds so that the filter for each can be loaded separately. // This means the filter batch gets split. We also have to split the // quantization params. - patterns.insert(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + patterns2.insert(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns2)); + + // We apply padding once after channelwise splitting if there are some convs + // that are left unoptimized + (void)applyPatternsAndFoldGreedily(func, frozenPatterns); } } // namespace