Fuse batch normalization into convolution weights
mvpant committed Nov 18, 2024
1 parent b1c1115 commit da6cff5
Showing 2 changed files with 211 additions and 0 deletions.
60 changes: 60 additions & 0 deletions stablehlo/testdata/bn_conv_fuse_float32.mlir

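The new test file's contents are not rendered in the diff. A minimal sketch of the kind of IR it exercises, under the pattern's constraints ([b, f, ...]x[o, i, ...]->[b, f, ...] layout, constant weight and batch norm parameters) — shapes, values, and names here are illustrative, not the actual testdata:

    func.func @bn_conv_fuse(%input: tensor<1x2x5x5xf32>) -> tensor<1x2x3x3xf32> {
      %weight = stablehlo.constant dense<1.000000e+00> : tensor<2x2x3x3xf32>
      %scale = stablehlo.constant dense<2.000000e+00> : tensor<2xf32>
      %offset = stablehlo.constant dense<5.000000e-01> : tensor<2xf32>
      %mean = stablehlo.constant dense<0.000000e+00> : tensor<2xf32>
      %variance = stablehlo.constant dense<1.000000e+00> : tensor<2xf32>
      %conv = stablehlo.convolution(%input, %weight)
                dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
                window = {stride = [1, 1], pad = [[0, 0], [0, 0]]}
                {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
                : (tensor<1x2x5x5xf32>, tensor<2x2x3x3xf32>) -> tensor<1x2x3x3xf32>
      %result = "stablehlo.batch_norm_inference"(%conv, %scale, %offset, %mean, %variance)
                  {epsilon = 1.000000e-05 : f32, feature_index = 1 : i64}
                  : (tensor<1x2x3x3xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
                  -> tensor<1x2x3x3xf32>
      return %result : tensor<1x2x3x3xf32>
    }

After simplification, the batch_norm_inference should be gone: the weight constant is rescaled per output channel, and the remaining work is a broadcast_in_dim of the folded bias followed by a stablehlo.add.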

151 changes: 151 additions & 0 deletions stablehlo/transforms/StablehloAggressiveSimplification.cpp
@@ -21,6 +21,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
@@ -1467,6 +1468,154 @@ struct ReorderElementwiseAndShapeOp final
}
};

// Fuses a batch normalization operation into the convolution weights:
// X = conv(input, weight)
// Y = batch_norm_inference(X, ...)
// is rewritten into:
// X = conv(input, weight(new))
// Y = add(X, broadcast_in_dim(Bias(new)))
//
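// For example (a sketch; value names and shapes are illustrative):
//
// %0 = stablehlo.convolution(%input, %weight) ... : tensor<1x2x3x3xf32>
// %1 = "stablehlo.batch_norm_inference"(%0, %scale, %offset, %mean, %var)
// {epsilon = 1.0e-5 : f32, feature_index = 1 : i64} ...
//
// becomes
//
// %0 = stablehlo.convolution(%input, %fused_weight) ... : tensor<1x2x3x3xf32>
// %1 = stablehlo.broadcast_in_dim %fused_bias, dims = [1] ...
// %2 = stablehlo.add %0, %1 : tensor<1x2x3x3xf32>
//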
struct FuseConvolutionBatchNormalization final
: OpRewritePattern<BatchNormInferenceOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(BatchNormInferenceOp op,
PatternRewriter &rewriter) const override {
auto bnOperandType = op.getOperand().getType();
auto bnOperandShape = bnOperandType.getShape();
auto bnResultType = op.getResult().getType();
uint64_t bnFeatureIndex = op.getFeatureIndex();

auto convOp = op.getOperand().getDefiningOp<ConvolutionOp>();
if (!convOp) return failure();

auto convWeight = convOp.getRhs();
auto convWeightType = convWeight.getType();
auto convWeightShape = convWeightType.getShape();

auto dimNumbers = convOp.getDimensionNumbers();
if (dimNumbers.getInputBatchDimension() != 0 ||
dimNumbers.getInputFeatureDimension() != 1 ||
dimNumbers.getOutputBatchDimension() != 0 ||
dimNumbers.getOutputFeatureDimension() != 1 ||
dimNumbers.getKernelOutputFeatureDimension() != 0 ||
dimNumbers.getKernelInputFeatureDimension() != 1)
return rewriter.notifyMatchFailure(convOp,
"Only [b, f, ...]x[o, i, ...]->[b, f, "
"...] configuration is supported");

if (convOp.getFeatureGroupCount() > 1 || convOp.getBatchGroupCount() > 1)
return rewriter.notifyMatchFailure(
convOp, "feature or batch grouping is not supported");

if (bnFeatureIndex != 1 ||
bnOperandShape[bnFeatureIndex] != convWeightShape.front())
return failure();

DenseFPElementsAttr convWeightElems;
DenseFPElementsAttr scaleElems;
DenseFPElementsAttr offsetElems;
DenseFPElementsAttr meanElems;
DenseFPElementsAttr varianceElems;

auto epsilon = op.getEpsilon();

if (!matchPattern(convWeight, m_Constant(&convWeightElems)))
return rewriter.notifyMatchFailure(
op, "expected constant convolution weight");

if (!matchPattern(op.getScale(), m_Constant(&scaleElems)) ||
!matchPattern(op.getOffset(), m_Constant(&offsetElems)) ||
!matchPattern(op.getMean(), m_Constant(&meanElems)) ||
!matchPattern(op.getVariance(), m_Constant(&varianceElems)))
return failure();

const auto &convWeightSemantics =
cast<FloatType>(convWeightType.getElementType()).getFloatSemantics();

// W(new) = W(old) * gamma * rsqrt(variance + epsilon)
// B(new) = (B(old) - mean) * rsqrt(variance + epsilon) * gamma + beta
// where: gamma - scaling factor (scale)
// beta - shifting factor (offset)
// rsqrt - reciprocal square root function
// W - weight
// B - bias
//
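// Derivation: for each feature channel, batch_norm_inference computes
// Y = (X - mean) * gamma * rsqrt(variance + epsilon) + beta
// Substituting X = conv(input, W) + B(old) and using the linearity of
// convolution in its weights gives
// Y = conv(input, W * gamma * rsqrt(variance + epsilon))
// + (B(old) - mean) * gamma * rsqrt(variance + epsilon) + beta
// which is exactly W(new) and B(new) above.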
const SmallVector<double> multipliers = llvm::map_to_vector(
llvm::zip_equal(varianceElems, scaleElems),
[&epsilon](const std::tuple<APFloat, APFloat> &pack) -> double {
const auto &[variance, scale] = pack;
auto varEps = (variance + epsilon).convertToDouble();
auto rsqrt = 1.0 / std::sqrt(varEps);
return rsqrt * scale.convertToDouble();
});

SmallVector<APFloat> newWeight;
newWeight.reserve(convWeightType.getNumElements());

const size_t outFeatureTileSize =
computeProduct(convWeightShape.drop_front());
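// DenseFPElementsAttr iterates in row-major order over the [o, i, ...]
// shape, so each output feature channel owns one contiguous tile of
// outFeatureTileSize elements and multipliers[o] can be applied tile by
// tile.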
auto it = convWeightElems.begin();
for (const auto &multiplier : multipliers) {
for (size_t i = 0; i < outFeatureTileSize; ++i) {
double v = (*it).convertToDouble() * multiplier;
APFloat result(v);
bool losesInfo;
if (APFloat::opStatus::opInvalidOp ==
result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
&losesInfo))
return failure();
newWeight.push_back(result);
++it;
}
}

SmallVector<APFloat> biasValues;
biasValues.reserve(multipliers.size());

for (const auto &[off, multiplier, mean] :
llvm::zip_equal(offsetElems, multipliers, meanElems)) {
// The stablehlo convolution op has no built-in bias, so B(old) = 0.
double convBias = 0;
double v = (convBias - mean.convertToDouble()) * multiplier +
off.convertToDouble();
APFloat result(v);

bool losesInfo;
if (APFloat::opStatus::opInvalidOp ==
result.convert(convWeightSemantics, APFloat::rmNearestTiesToEven,
&losesInfo))
return failure();

biasValues.push_back(result);
}

rewriter.setInsertionPoint(op);
auto newConvWeight = rewriter.create<ConstantOp>(
convWeight.getLoc(), convWeightType,
DenseFPElementsAttr::get(convWeightType, newWeight));

// Keep the old convolution intact, as it might have other users.
auto newConvOp = rewriter.create<ConvolutionOp>(
convOp.getLoc(), convOp->getResultTypes(),
ValueRange{convOp.getLhs(), newConvWeight}, convOp->getAttrs());

SmallVector<int64_t> biasShape{static_cast<int64_t>(biasValues.size())};
auto biasType =
convWeightType.cloneWith(biasShape, convWeightType.getElementType());
auto bias = rewriter.create<ConstantOp>(
op.getLoc(), biasType, DenseFPElementsAttr::get(biasType, biasValues));
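// Broadcast the rank-1 per-feature bias along the feature dimension of
// the result, then add it in place of the batch norm.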

auto indices =
rewriter.getDenseI64ArrayAttr({static_cast<int64_t>(bnFeatureIndex)});
auto bcast = rewriter.create<BroadcastInDimOp>(op.getLoc(), bnResultType,
bias, indices);
auto add = rewriter.create<AddOp>(op.getLoc(), newConvOp, bcast);

rewriter.replaceOp(op, add);
return success();
}
};

struct StablehloAggressiveSimplificationPass final
: impl::StablehloAggressiveSimplificationPassBase<
StablehloAggressiveSimplificationPass> {
@@ -1513,6 +1662,8 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
patterns
->add<GetDimensionSizeOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);

patterns->add<FuseConvolutionBatchNormalization>(context);
}

} // namespace stablehlo
