Merge pull request #933 from xmos/memory-fix
Channelwise conv for i16 and memory offset fix
panickal-xmos authored Sep 24, 2024
2 parents 7d753aa + 7e44391 commit 2c72c42
Showing 2 changed files with 65 additions and 38 deletions.
65 changes: 39 additions & 26 deletions xformer/Analysis/MemoryPlan.cpp
@@ -188,34 +188,47 @@ std::vector<int> MemoryPlan::getAllocatedOffsets(const bool overlapOps,
inputVals.push_back(inVal);

auto outVal = o->getResult(0);
auto nextOp = *outVal.getUsers().begin();
// Identify chain of overlappable Ops
while (outVal.hasOneUse() && !alreadyVisited.contains(nextOp) &&
nextOp->hasTrait<OpTrait::xcore::MemoryOverlappable>()) {
inVal = outVal;
inputVals.push_back(inVal);
alreadyVisited.insert(nextOp);
outVal = nextOp->getResult(0);
nextOp = *outVal.getUsers().begin();
}

// Set first Used of output Val to the first input Val
vInfo[outVal].firstUsed = vInfo[inputVals[0]].firstUsed;
auto unalignedSizeOutVal =
utils::getShapedTypeSize(outVal.getType().dyn_cast<ShapedType>());
size_t maxSizeNeeded = 0;
for (auto inV : inputVals) {
auto unalignedSizeInV =
utils::getShapedTypeSize(inV.getType().dyn_cast<ShapedType>());
auto unalignedOffset = unalignedSizeOutVal - unalignedSizeInV;
// Align offset up to double word = 8 bytes
auto offset = ((unalignedOffset + 7) / 8) * 8;
maxSizeNeeded = std::max(vInfo[inV].size + offset, maxSizeNeeded);
inOutMap[inV] = {outVal, offset};
// Only overlap if the output value size is equal or larger than the
// input value size. We use the allocated space for the output value to
// store the input value
if ((utils::getShapedTypeSize(
outVal.getType().dyn_cast<ShapedType>()) >=
utils::getShapedTypeSize(
inVal.getType().dyn_cast<ShapedType>()))) {
auto nextOp = *outVal.getUsers().begin();
// Identify chain of overlappable Ops
while (outVal.hasOneUse() && !alreadyVisited.contains(nextOp) &&
nextOp->hasTrait<OpTrait::xcore::MemoryOverlappable>() &&
(utils::getShapedTypeSize(
outVal.getType().dyn_cast<ShapedType>()) >=
utils::getShapedTypeSize(
inVal.getType().dyn_cast<ShapedType>()))) {
inVal = outVal;
inputVals.push_back(inVal);
alreadyVisited.insert(nextOp);
outVal = nextOp->getResult(0);
nextOp = *outVal.getUsers().begin();
}

// Set first Used of output Val to the first input Val
vInfo[outVal].firstUsed = vInfo[inputVals[0]].firstUsed;
auto unalignedSizeOutVal = utils::getShapedTypeSize(
outVal.getType().dyn_cast<ShapedType>());
size_t maxSizeNeeded = 0;
for (auto inV : inputVals) {
auto unalignedSizeInV = utils::getShapedTypeSize(
inV.getType().dyn_cast<ShapedType>());
auto unalignedOffset = unalignedSizeOutVal - unalignedSizeInV;
// Align offset up to double word = 8 bytes
auto offset = ((unalignedOffset + 7) / 8) * 8;
maxSizeNeeded = std::max(vInfo[inV].size + offset, maxSizeNeeded);
inOutMap[inV] = {outVal, offset};
}
// The aligned input val size plus aligned offset might be larger
// than aligned output val size
vInfo[outVal].size = std::max(vInfo[outVal].size, maxSizeNeeded);
}
// The aligned input val size plus aligned offset might be larger than
// aligned output val size
vInfo[outVal].size = std::max(vInfo[outVal].size, maxSizeNeeded);
}
}
}
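Note on the MemoryPlan.cpp change: buffer overlapping is now gated on the output value being at least as large as the input value, and the input is still placed at the tail of the output allocation with the start offset rounded up to a double word (8 bytes). Below is a minimal standalone sketch of that offset arithmetic with made-up sizes; the function and variable names are illustrative and not part of the xformer codebase.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Round the tail offset up to a double word (8 bytes), mirroring
// ((unalignedOffset + 7) / 8) * 8 in MemoryPlan.cpp above.
static size_t alignedOverlapOffset(size_t outSize, size_t inSize) {
  size_t unalignedOffset = outSize - inSize; // only valid when outSize >= inSize
  return ((unalignedOffset + 7) / 8) * 8;
}

int main() {
  size_t outSize = 1000, inSize = 300;
  size_t offset = alignedOverlapOffset(outSize, inSize); // 700 rounds up to 704
  // After alignment, offset + inSize can exceed outSize, so the slot for the
  // output value may need to grow, as vInfo[outVal].size is updated above.
  size_t slotSize = std::max(outSize, offset + inSize);  // max(1000, 1004) = 1004
  std::printf("offset=%zu slot=%zu\n", offset, slotSize);
  return 0;
}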
38 changes: 26 additions & 12 deletions xformer/Transforms/OptimizeConv2D.cpp
@@ -31,14 +31,17 @@ struct ChannelwiseSplitConv2DOutputPattern
LogicalResult matchAndRewrite(TFL::Conv2DOp op,
PatternRewriter &rewriter) const override {
// Check for invalid types and return
// Input type must be QI8
// Input type must be QI8 or QI16
auto inputElementType =
op.getInput().getType().cast<ShapedType>().getElementType();
op.getInput().getType().template cast<ShapedType>().getElementType();
if (!utils::isNBitSignedQType<8>(inputElementType) &&
!utils::isNBitSignedQType<16>(inputElementType)) {
return failure();
}

auto filterElementType =
op.getFilter().getType().cast<ShapedType>().getElementType();

if (!utils::isNBitSignedQType<8>(inputElementType) ||
!utils::isNBitSignedQType<8>(filterElementType)) {
if (!utils::isNBitSignedQType<8>(filterElementType)) {
return failure();
}

@@ -91,13 +94,16 @@ struct ChannelwiseSplitConv2DOutputPattern

// We want to try to keep the split filtersize less than specified size
int numSplits = ceil(filterSize / convChannelwiseSplitSizeOption);
// Only try to split if at least two splits are possible
if (numSplits < 2) {
return failure();
}

// Let's split the filter batch size as that's the same as bias size and
// output channel size
// We want splits to be multiples of four
auto filterBatchSize = filterType.getShape()[0];
// Only try to split if at least two splits are possible
if (numSplits < 2 || filterBatchSize / 4 < 2) {
return failure();
}

// We want splits to be multiples of four, so we divide here and multiply
// after calculating the split sizes
int tmp = filterBatchSize / 4;
Expand Down Expand Up @@ -790,7 +796,6 @@ void OptimizeConv2D::runOnOperation() {
auto *ctx = &getContext();
func::FuncOp func = getOperation();
RewritePatternSet patterns(ctx);

// Convert TransposeConv2D with SAME padding to VALID padding + slice
patterns.insert<SameToValidTransposeConvPattern>(ctx);

@@ -813,12 +818,21 @@ void OptimizeConv2D::runOnOperation() {
// conv2d output channels, and add a slice to remove the padded
// section
patterns.insert<PadTo4DepthwiseConv2DPattern>(ctx);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
(void)applyPatternsAndFoldGreedily(func, frozenPatterns);

// We apply channelwise splitting after padding.
RewritePatternSet patterns2(ctx);
// When the filter is too large, we channelwise split the conv2d output to
// make multiple conv2ds so that the filter for each can be loaded separately.
// This means the filter batch gets split. We also have to split the
// quantization params.
patterns.insert<ChannelwiseSplitConv2DOutputPattern>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
patterns2.insert<ChannelwiseSplitConv2DOutputPattern>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns2));

// We apply padding once after channelwise splitting if there are some convs
// that are left unoptimized
(void)applyPatternsAndFoldGreedily(func, frozenPatterns);
}
} // namespace

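Note on the OptimizeConv2D.cpp change: ChannelwiseSplitConv2DOutputPattern only fires when at least two splits are possible (numSplits = ceil(filterSize / convChannelwiseSplitSizeOption) >= 2) and the filter batch has at least eight output channels (filterBatchSize / 4 >= 2), and each split is kept a multiple of four channels by working in units of four. The distribution loop itself is elided in this diff, so the sketch below is only an assumed illustration of how such sizes could be distributed; the helper name is made up.

#include <cstdio>
#include <vector>

// Illustrative-only sketch: spread filterBatchSize output channels over
// numSplits conv2ds, keeping every split a multiple of four channels.
// Assumes filterBatchSize is already a multiple of four (the padding
// patterns applied earlier in this pass take care of that).
static std::vector<int> channelwiseSplitSizes(int filterBatchSize, int numSplits) {
  int unitsOfFour = filterBatchSize / 4; // splits are sized in units of 4 channels
  std::vector<int> sizes;
  for (int i = 0; i < numSplits; ++i) {
    int units = unitsOfFour / numSplits + (i < unitsOfFour % numSplits ? 1 : 0);
    sizes.push_back(units * 4);
  }
  return sizes;
}

int main() {
  // Example: 40 output channels and a filter large enough to need 3 splits.
  for (int size : channelwiseSplitSizes(/*filterBatchSize=*/40, /*numSplits=*/3))
    std::printf("%d ", size); // prints: 16 12 12
  std::printf("\n");
  return 0;
}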
