Skip to content

Commit

Permalink
[Transform][AnyWidthInteger] Fix issues with float type arg (#189)
Browse files Browse the repository at this point in the history
* Support non-memref func args

* Add test case for skipping irrelevant args

* Fix format
  • Loading branch information
zzzDavid authored Jul 23, 2023
1 parent fc7f169 commit 2fdcd51
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 28 deletions.
71 changes: 43 additions & 28 deletions lib/Transforms/AnyWidthInteger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ void updateTopFunctionSignature(func::FuncOp &funcOp) {
} else {
new_result_types.push_back(memrefType);
}
} else {
new_result_types.push_back(t);
}
}

Expand All @@ -69,6 +71,8 @@ void updateTopFunctionSignature(func::FuncOp &funcOp) {
} else {
new_arg_types.push_back(memrefType);
}
} else {
new_arg_types.push_back(t);
}
}

Expand All @@ -89,25 +93,35 @@ void updateTopFunctionSignature(func::FuncOp &funcOp) {
OpBuilder builder(funcOp->getRegion(0));
for (Block &block : funcOp.getBlocks()) {
for (unsigned i = 0; i < block.getNumArguments(); i++) {
Type argType = block.getArgument(i).getType();
if (MemRefType memrefType = argType.cast<MemRefType>()) {
for (unsigned i = 0; i < block.getNumArguments(); ++i) {
MemRefType memrefType =
block.getArgument(i).getType().dyn_cast<MemRefType>();
if (!memrefType) {
blockArgs.push_back(block.getArgument(i));
continue;
}

Type et = memrefType.getElementType();
if (et.isa<IntegerType>()) {
size_t width = 64;
Type newType = IntegerType::get(funcOp.getContext(), width);
Type newMemRefType = memrefType.clone(newType);
size_t oldWidth = et.cast<IntegerType>().getWidth();
block.getArgument(i).setType(newMemRefType);
bool is_unsigned = false;
if (i < itypes.length()) {
is_unsigned = itypes[i] == 'u';
}
Value newMemRef =
castIntMemRef(builder, funcOp->getLoc(), block.getArgument(i),
oldWidth, is_unsigned);
newMemRefs.push_back(newMemRef);
if (!et.isa<IntegerType>()) {
blockArgs.push_back(block.getArgument(i));
continue;
}

size_t width = 64;
Type newType = IntegerType::get(funcOp.getContext(), width);
Type newMemRefType = memrefType.clone(newType);
block.getArgument(i).setType(newMemRefType);

bool is_unsigned = false;
if (i < itypes.length()) {
is_unsigned = itypes[i] == 'u';
}

Value newMemRef =
castIntMemRef(builder, funcOp->getLoc(), block.getArgument(i),
et.cast<IntegerType>().getWidth(), is_unsigned);
newMemRefs.push_back(newMemRef);
blockArgs.push_back(block.getArgument(i));
}
}
}
Expand All @@ -124,19 +138,20 @@ void updateTopFunctionSignature(func::FuncOp &funcOp) {
// Cast the return values
for (unsigned i = 0; i < op->getNumOperands(); i++) {
Value arg = op->getOperand(i);
MemRefType type = arg.getType().cast<MemRefType>();
Type etype = type.getElementType();
if (etype.isa<IntegerType>()) {
if (auto allocOp = dyn_cast<memref::AllocOp>(arg.getDefiningOp())) {
bool is_unsigned = false;
if (i < otypes.length()) {
is_unsigned = otypes[i] == 'u';
if (MemRefType type = arg.getType().dyn_cast<MemRefType>()) {
Type etype = type.getElementType();
if (etype.isa<IntegerType>()) {
if (auto allocOp = dyn_cast<memref::AllocOp>(arg.getDefiningOp())) {
bool is_unsigned = false;
if (i < otypes.length()) {
is_unsigned = otypes[i] == 'u';
}
Value newMemRef =
castIntMemRef(returnRewriter, op->getLoc(), allocOp.getResult(),
64, is_unsigned, false);
// Only replace the single use of oldMemRef: returnOp
op->setOperand(i, newMemRef);
}
Value newMemRef =
castIntMemRef(returnRewriter, op->getLoc(), allocOp.getResult(),
64, is_unsigned, false);
// Only replace the single use of oldMemRef: returnOp
op->setOperand(i, newMemRef);
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions test/Transforms/datatype/anywidth-skip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright HeteroCL authors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

// RUN: hcl-opt %s --lower-anywidth-integer
module {
func.func @kernel(%arg0: memref<4x4xf32>, %arg1: f32, %arg2: f32) -> memref<4x4xf32> attributes {"top"} {
return %arg0 : memref<4x4xf32>
}
}

0 comments on commit 2fdcd51

Please sign in to comment.