diff --git a/xformer/Transforms/WriteFlashImage.cpp b/xformer/Transforms/WriteFlashImage.cpp deleted file mode 100644 index 6c87b66f8..000000000 --- a/xformer/Transforms/WriteFlashImage.cpp +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2021 XMOS LIMITED. This Software is subject to the terms of the -// XMOS Public License: Version 1 - -#include "IR/XCoreOps.h" -#include "Transforms/Options.h" -#include "Utils/FileIO.h" -#include "Utils/TileRamSupport.h" - -#include "mlir/Pass/Pass.h" -#include "mlir/Support/FileUtilities.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "llvm/Support/ToolOutputFile.h" - -namespace mlir { -namespace xcore { - -namespace { -// Write flash image -struct WriteWeights - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WriteWeights) - - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - } - StringRef getArgument() const final { return "xcore-write-flash-image"; } - StringRef getDescription() const final { return "Write flash image"; } - void runOnOperation() override; -}; - -struct WriteWeightsPattern : public OpRewritePattern { - WriteWeightsPattern(std::vector> *tensorsVec, - MLIRContext *context) - : OpRewritePattern(context), tensorsVec_(tensorsVec) {} - - std::vector getTensorData(LoadConstantOp loadOp) const { - DenseElementsAttr attr; - if (loadOp.getInput() - .getType() - .cast() - .getElementType() - .isa()) { - auto qConstOp = - dyn_cast(loadOp.getInput().getDefiningOp()); - attr = qConstOp.getValue().template cast(); - } else { - matchPattern(loadOp.getInput(), m_Constant(&attr)); - } - - std::vector tensorData; - int n = attr.isSplat() ? attr.getNumElements() : 1; - for (int i = 0; i < n; ++i) { - tensorData.insert(tensorData.end(), attr.getRawData().begin(), - attr.getRawData().end()); - } - return tensorData; - } - - LogicalResult matchAndRewrite(LoadConstantOp loadOp, - PatternRewriter &rewriter) const override { - std::vector tensorData; - SmallVector dataSizes; - - int address = 0; - for (auto const &t : *tensorsVec_) { - address += t.size(); - } - - if (loadOp.getResult().hasOneUse()) { - auto use = loadOp->use_begin(); - Operation *ownerOp = use->getOwner(); - - SmallVector outputTypes; - SmallVector opNums; - - for (int i = 0; i < ownerOp->getNumOperands(); i++) { - auto loadOpForOwnerOp = dyn_cast_or_null( - ownerOp->getOperand(i).getDefiningOp()); - - if (loadOpForOwnerOp) { - std::vector loadOpData = getTensorData(loadOpForOwnerOp); - dataSizes.push_back(rewriter.getI32IntegerAttr(loadOpData.size())); - tensorData.insert(tensorData.end(), loadOpData.begin(), - loadOpData.end()); - outputTypes.push_back(loadOpForOwnerOp.getType()); - opNums.push_back(i); - } - } - - auto loadFlashOp = - rewriter.create(loadOp.getLoc(), outputTypes, address, - rewriter.getArrayAttr(dataSizes)); - - for (int i = 0; i < opNums.size(); i++) { - ownerOp->setOperand(opNums[i], loadFlashOp.getResult(i)); - } - - loadFlashOp->moveBefore(ownerOp); - loadOp.erase(); - } else { - std::vector loadOpData = getTensorData(loadOp); - dataSizes.push_back(rewriter.getI32IntegerAttr(loadOpData.size())); - tensorData.insert(tensorData.end(), loadOpData.begin(), loadOpData.end()); - auto loadFlashOp = rewriter.create( - loadOp.getLoc(), loadOp.getType(), address, - rewriter.getArrayAttr(dataSizes)); - rewriter.replaceOp(loadOp, loadFlashOp.getOutput()); - - // Find all uses of loadFlashOp and find the first Owner op - // so that we can move the loading to just before that op. - mlir::Operation *firstOwnerOp = - loadFlashOp->getResult(0).getUses().begin()->getOwner(); - for (const mlir::OpOperand &use : loadFlashOp->getResult(0).getUses()) { - mlir::Operation *op = use.getOwner(); - if (op->isBeforeInBlock(firstOwnerOp)) { - firstOwnerOp = op; - } - } - loadFlashOp->moveBefore(firstOwnerOp); - } - - tensorsVec_->push_back(tensorData); - - return success(); - } - -private: - std::vector> *tensorsVec_; -}; - -void WriteWeights::runOnOperation() { - func::FuncOp f = getOperation(); - if (weightsFilenameOption.empty()) { - f.emitError("Flash image file option should be provided to run this pass!"); - signalPassFailure(); - return; - } - - auto *ctx = &getContext(); - func::FuncOp func = getOperation(); - // For each LoadOp in the graph, save the tensor data, and replace the LoadOp - // with a LoadFlashOp - std::vector> tensorsVec; - RewritePatternSet patterns(ctx); - patterns.insert(&tensorsVec, ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); - - if (tileLoadOption) { - if (failed(utils::writeTileServerDataToFile(weightsFilenameOption, - tensorsVec))) { - f.emitError("Failed to write tile data!"); - signalPassFailure(); - return; - } - } - // Write tensor data to flash image file - else if (failed( - utils::writeWeightsToFile(weightsFilenameOption, tensorsVec))) { - f.emitError("Failed to write flash image!"); - signalPassFailure(); - return; - } -} -} // namespace - -// Creates an instance of the WriteWeights pass. -std::unique_ptr> createWriteWeightsPass() { - return std::make_unique(); -} - -static PassRegistration pass; - -} // namespace xcore -} // namespace mlir