Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/maxpool2d #861

Merged
merged 7 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions xformer/IR/XCoreOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,24 @@ def XC_Beta_FcF32Op : XC_Op<"beta_fcf32", [Pure]> {
let results = (outs TensorOf<[F32]> : $output);
}

def XC_MaxPool2DOp : XC_Op<"maxpool2d", [Pure]> {
let summary = "MaxPool2D op";

let description = [{MaxPool2D op.}];

let arguments = (ins
TensorOf<[QI8]>:$input,
StrAttr:$memcpy_fn_param,
StrAttr:$aggregate_fn_param,
StrAttr:$output_transform_fn_param,
I32Attr:$scratch_bytes,
I32Attr:$thread_count,
StrArrayAttr:$abstract_kernel_params
);

let results = (outs TensorOf<[QI8]> : $output);
}

def XC_Conv2DV2Op : XC_Op<"conv2d_v2", [Pure]> {
let summary = "Conv2D V2 op";

Expand Down
1 change: 1 addition & 0 deletions xformer/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void buildXCorePassPipeline(OpPassManager &pm) {

// XC passes
pm.addPass(createReplaceAddPass());
pm.addPass(createReplaceMaxPoolPass());
pm.addPass(createReplaceMulPass());
pm.addPass(createReplaceStridedSlicePass());
pm.addPass(createReplaceConv2DPass());
Expand Down
1 change: 1 addition & 0 deletions xformer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createOpSplitPass();
std::unique_ptr<OperationPass<func::FuncOp>> createApplyTFLPatternsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceAddPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMulPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMaxPoolPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceStridedSlicePass();
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceConv2DPass();
std::unique_ptr<OperationPass<func::FuncOp>> createApplyXCPatternsPass();
Expand Down
3 changes: 1 addition & 2 deletions xformer/Transforms/ReplaceConv2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,11 @@ ReplaceWithXCConv2DBase<ConcreteType, ConvOpType, ArgsType>::matchAndRewrite(
std::vector<int16_t> mulsBiasesOrThresholdsData;

// Obtain thread count from command-line option
const int threadCount = threadCountOption;
llvm::SmallVector<std::string> strParams;
int scratchBytes = 0;
// Get image region splits for multiple threads
args.imageRegionSplits = utils::getImageRegionThreadSplits(
threadCount, args.Y.height, args.Y.width);
threadCountOption, args.Y.height, args.Y.width);

// Obtain serialized params and calculated tensors from lib_nn for the
// conv2d kernel type
Expand Down
113 changes: 113 additions & 0 deletions xformer/Transforms/ReplaceMaxPool2D.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#include "IR/XCoreOps.h"
#include "Transforms/Options.h"

#include "Utils/ThreadSupport.h"
#include "lib_nn/api/AbstractKernel.hpp"
#include "lib_nn/api/AggregateFn.hpp"
#include "lib_nn/api/MemCpyFn.hpp"
#include "lib_nn/api/OutputTransformFn.hpp"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace xcore {

namespace {
struct ReplaceMaxPool2D
: public PassWrapper<ReplaceMaxPool2D, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceMaxPool2D)
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<TFL::TensorFlowLiteDialect>();
}
StringRef getArgument() const final { return "xcore-replace-maxpool2d"; }
StringRef getDescription() const final {
return "Replace TFL MaxPool2D with MaxPool2D for XCore.";
}
void runOnOperation() override;
};

struct ReplaceMaxPool2DPattern : public OpRewritePattern<TFL::MaxPool2DOp> {
using OpRewritePattern<TFL::MaxPool2DOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TFL::MaxPool2DOp mPoolOp,
PatternRewriter &rewriter) const override {
auto inputType =
mPoolOp.getInput().getType().template dyn_cast<RankedTensorType>();
auto outputType =
mPoolOp.getOutput().getType().template dyn_cast<RankedTensorType>();
auto inputHeight = inputType.getDimSize(1);
auto inputWidth = inputType.getDimSize(2);
auto inputDepth = inputType.getDimSize(3);
auto outputHeight = outputType.getDimSize(1);
auto outputWidth = outputType.getDimSize(2);
auto outputDepth = outputType.getDimSize(3);
auto splits = utils::getImageRegionThreadSplits(threadCountOption,
outputHeight, outputWidth);

auto actualThreadCount = splits.size();
// Create a string array attr from a vector of strings
auto getStringArrayAttr = [&](llvm::SmallVector<std::string> value) {
auto attrs = llvm::to_vector<8>(
llvm::map_range(value, [&](std::string v) -> Attribute {
return rewriter.getStringAttr(v);
}));
return rewriter.getArrayAttr(attrs);
};
int32_t scratchByteParam =
nn::MatMulInt8::get_scratch_mem_bytes(mPoolOp.getFilterWidth() *
mPoolOp.getFilterHeight()) +
32; //[asj] FIXME
nn::ImageGeometry X(inputHeight, inputWidth, inputDepth);
nn::ImageGeometry Y(outputHeight, outputWidth, outputDepth);
llvm::SmallVector<std::string> akp;
for (auto &region : splits) {
nn::ImageRegion ir(region[0], region[1], 0, region[2], region[3],
outputDepth);
nn::AbstractKernel ak(Y, ir, VPU_INT8_ACC_PERIOD);
auto akParams = ak.getParams();
auto akpStr = std::string((char *)&akParams, sizeof(akParams));
akp.push_back(akpStr);
}
nn::ImageRegion ir(0, 0, 0, outputHeight, outputWidth, outputDepth);
nn::WindowGeometry window(
mPoolOp.getFilterHeight(), mPoolOp.getFilterWidth(), 1, 0, 0,
mPoolOp.getStrideH(), mPoolOp.getStrideW(), 1, 1, 1);
nn::DerefInputFn mf(X, window);
nn::MatMulDirectFn_DW af(X, window);
// TODO
nn::OT_int8_channelwise ot(outputDepth, 0);
auto mfParams = mf.getParams();
auto afParams = af.getParams();
auto otParams = ot.getParams();
auto mfStr = std::string((char *)&mfParams, sizeof(mfParams));
auto afStr = std::string((char *)&afParams, sizeof(afParams));
auto otStr = std::string((char *)&otParams, sizeof(otParams));

auto xcMaxPool2DOp = rewriter.create<MaxPool2DOp>(
mPoolOp.getLoc(), mPoolOp.getType(), mPoolOp.getInput(),
rewriter.getStringAttr(mfStr), rewriter.getStringAttr(afStr),
rewriter.getStringAttr(otStr),
rewriter.getI32IntegerAttr(scratchByteParam),
rewriter.getI32IntegerAttr(actualThreadCount), getStringArrayAttr(akp));
rewriter.replaceOp(mPoolOp, xcMaxPool2DOp.getOutput());
return success();
}
};

void ReplaceMaxPool2D::runOnOperation() {
auto *ctx = &getContext();
func::FuncOp func = getOperation();
RewritePatternSet patterns(ctx);
patterns.insert<ReplaceMaxPool2DPattern>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> createReplaceMaxPoolPass() {
return std::make_unique<ReplaceMaxPool2D>();
}

static PassRegistration<ReplaceMaxPool2D> pass;

} // namespace xcore
} // namespace mlir
26 changes: 26 additions & 0 deletions xformer/Transforms/TranslateToCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,31 @@ std::vector<uint8_t> Conv2DV2Op::buildCustomOptions() {
return fbb.GetBuffer();
}

std::vector<uint8_t> MaxPool2DOp::buildCustomOptions() {
// TODO: Is the alignement messed up?
flexbuffers::Builder fbb;
auto rootMap = fbb.StartMap();
fbb.String("mp", getMemcpyFnParam().str());
fbb.String("a", getAggregateFnParam().str());
fbb.String("o", getOutputTransformFnParam().str());
int threadCount = (int)getThreadCount();
auto akpVec = fbb.StartVector("p");
for (int i = 0; i < threadCount; ++i) {
fbb.String(getAbstractKernelParams()
.cast<ArrayAttr>()[i]
.cast<StringAttr>()
.getValue()
.str() +
"00");
}
fbb.EndVector(akpVec, false, false);
fbb.Int("s", (int32_t)getScratchBytes());

fbb.EndMap(rootMap);
fbb.Finish();
return fbb.GetBuffer();
}

namespace {
/// This pass translates XCore ops to TFLite custom ops.
struct TranslateToCustomOp
Expand Down Expand Up @@ -172,6 +197,7 @@ void TranslateToCustomOp::runOnOperation() {
patterns.insert<RewriteToCustomOp<AddOp>>(ctx);
patterns.insert<RewriteToCustomOp<Bsign8Op>>(ctx);
patterns.insert<RewriteToCustomOp<Conv2DV2Op>>(ctx);
patterns.insert<RewriteToCustomOp<MaxPool2DOp>>(ctx);
patterns.insert<RewriteToCustomOp<LoadFlashOp>>(ctx);
patterns.insert<RewriteToCustomOp<LookupOp>>(ctx);
patterns.insert<RewriteToCustomOp<MulOp>>(ctx);
Expand Down
2 changes: 1 addition & 1 deletion xformer/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths")
############################### Compile Commands ###############################
# Hedron's Compile Commands Extractor for Bazel, used to get clangd to work
# Replace commit hash with latest version, later setup automatic update tool maybe?
BCCE_HASH = "e16062717d9b098c3c2ac95717d2b3e661c50608"
BCCE_HASH = "eca42c63700fccdc49cf58177e0a96f0f6075a68"
http_archive(
name = "hedron_compile_commands",
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/{hash}.tar.gz".format(hash = BCCE_HASH),
Expand Down
1 change: 1 addition & 0 deletions xformer/lib_tflite_micro.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ filegroup(
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_custom_options.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_bsign.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_conv2d_v2.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_maxpool2d.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_detection_post.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_load_from_flash.cc",
"lib_tflite_micro/src/tflite-xcore-kernels/xcore_lookup.cc",
Expand Down
72 changes: 0 additions & 72 deletions xformer/toolchain/BUILD

This file was deleted.

Loading
Loading