diff --git a/lib/Transforms/AnyWidthInteger.cpp b/lib/Transforms/AnyWidthInteger.cpp index 6b3a810b..c8a0323f 100644 --- a/lib/Transforms/AnyWidthInteger.cpp +++ b/lib/Transforms/AnyWidthInteger.cpp @@ -54,6 +54,8 @@ void updateTopFunctionSignature(func::FuncOp &funcOp) { } else { new_result_types.push_back(memrefType); } + } else { + new_result_types.push_back(t); } } @@ -69,6 +71,8 @@ void updateTopFunctionSignature(func::FuncOp &funcOp) { } else { new_arg_types.push_back(memrefType); } + } else { + new_arg_types.push_back(t); } } @@ -89,25 +93,35 @@ void updateTopFunctionSignature(func::FuncOp &funcOp) { OpBuilder builder(funcOp->getRegion(0)); for (Block &block : funcOp.getBlocks()) { for (unsigned i = 0; i < block.getNumArguments(); i++) { - Type argType = block.getArgument(i).getType(); - if (MemRefType memrefType = argType.cast()) { + for (unsigned i = 0; i < block.getNumArguments(); ++i) { + MemRefType memrefType = + block.getArgument(i).getType().dyn_cast(); + if (!memrefType) { + blockArgs.push_back(block.getArgument(i)); + continue; + } + Type et = memrefType.getElementType(); - if (et.isa()) { - size_t width = 64; - Type newType = IntegerType::get(funcOp.getContext(), width); - Type newMemRefType = memrefType.clone(newType); - size_t oldWidth = et.cast().getWidth(); - block.getArgument(i).setType(newMemRefType); - bool is_unsigned = false; - if (i < itypes.length()) { - is_unsigned = itypes[i] == 'u'; - } - Value newMemRef = - castIntMemRef(builder, funcOp->getLoc(), block.getArgument(i), - oldWidth, is_unsigned); - newMemRefs.push_back(newMemRef); + if (!et.isa()) { blockArgs.push_back(block.getArgument(i)); + continue; } + + size_t width = 64; + Type newType = IntegerType::get(funcOp.getContext(), width); + Type newMemRefType = memrefType.clone(newType); + block.getArgument(i).setType(newMemRefType); + + bool is_unsigned = false; + if (i < itypes.length()) { + is_unsigned = itypes[i] == 'u'; + } + + Value newMemRef = + castIntMemRef(builder, funcOp->getLoc(), block.getArgument(i), + et.cast().getWidth(), is_unsigned); + newMemRefs.push_back(newMemRef); + blockArgs.push_back(block.getArgument(i)); } } } @@ -124,19 +138,20 @@ void updateTopFunctionSignature(func::FuncOp &funcOp) { // Cast the return values for (unsigned i = 0; i < op->getNumOperands(); i++) { Value arg = op->getOperand(i); - MemRefType type = arg.getType().cast(); - Type etype = type.getElementType(); - if (etype.isa()) { - if (auto allocOp = dyn_cast(arg.getDefiningOp())) { - bool is_unsigned = false; - if (i < otypes.length()) { - is_unsigned = otypes[i] == 'u'; + if (MemRefType type = arg.getType().dyn_cast()) { + Type etype = type.getElementType(); + if (etype.isa()) { + if (auto allocOp = dyn_cast(arg.getDefiningOp())) { + bool is_unsigned = false; + if (i < otypes.length()) { + is_unsigned = otypes[i] == 'u'; + } + Value newMemRef = + castIntMemRef(returnRewriter, op->getLoc(), allocOp.getResult(), + 64, is_unsigned, false); + // Only replace the single use of oldMemRef: returnOp + op->setOperand(i, newMemRef); } - Value newMemRef = - castIntMemRef(returnRewriter, op->getLoc(), allocOp.getResult(), - 64, is_unsigned, false); - // Only replace the single use of oldMemRef: returnOp - op->setOperand(i, newMemRef); } } } diff --git a/test/Transforms/datatype/anywidth-skip.mlir b/test/Transforms/datatype/anywidth-skip.mlir new file mode 100644 index 00000000..c74558e2 --- /dev/null +++ b/test/Transforms/datatype/anywidth-skip.mlir @@ -0,0 +1,9 @@ +// Copyright HeteroCL authors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// RUN: hcl-opt %s --lower-anywidth-integer +module { + func.func @kernel(%arg0: memref<4x4xf32>, %arg1: f32, %arg2: f32) -> memref<4x4xf32> attributes {"top"} { + return %arg0 : memref<4x4xf32> + } +} \ No newline at end of file