Skip to content

Commit

Permalink
broadcast scalar fix
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelpoluektov committed May 17, 2024
1 parent 7ee609b commit dadedcb
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
func.func @main(%arg0: tensor<!quant.uniform<i8:f32, 6.5359479049220681E-4:-128>>) -> tensor<1x1x13x64x!quant.uniform<i8:f32, 6.5359479049220681E-4:-128>> attributes {tf.entry_function = {inputs = "arg0", outputs = "0"}} {
%cst = arith.constant dense<[1, 1, 13, 64]> : tensor<4xi32>
%0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<!quant.uniform<i8:f32, 6.5359479049220681E-4:-128>>, tensor<4xi32>) -> tensor<1x1x13x64x!quant.uniform<i8:f32, 6.5359479049220681E-4:-128>>
return %0 : tensor<1x1x13x64x!quant.uniform<i8:f32, 6.5359479049220681E-4:-128>>
}
Binary file not shown.
3 changes: 2 additions & 1 deletion xformer/Transforms/ReplaceBroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ struct ReplaceBroadcastPattern : public OpRewritePattern<TFL::BroadcastToOp> {
auto inputType = broadcastOp.getInput().getType().cast<RankedTensorType>();
auto outputType =
broadcastOp.getOutput().getType().cast<RankedTensorType>();

llvm::outs() << "inputType: " << inputType << "\n";
llvm::outs() << "outputType: " << outputType << "\n";
if (!inputType.hasStaticShape())
return failure();

Expand Down
5 changes: 5 additions & 0 deletions xformer/Utils/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// XMOS Public License: Version 1

#include "Utils/Util.h"
#include <iostream>

#include "mlir/Dialect/Quant/QuantTypes.h"
#include "llvm/ADT/ArrayRef.h"
Expand Down Expand Up @@ -87,6 +88,10 @@ ArrayRef<int64_t> getValShape(Value tensor) {

bool checkSliceNoOp(RankedTensorType inputType, RankedTensorType outputType) {
const int rank = inputType.getRank();
if (rank != outputType.getRank()) {
return false;
}
std::cout << "Rank: " << rank << std::endl;
bool isNoOp = true;
for (int i = 0; i < rank; i++) {
if (inputType.getDimSize(i) != outputType.getDimSize(i)) {
Expand Down

0 comments on commit dadedcb

Please sign in to comment.