From 9e1624c23e5c46ab3326eaeab31ee5e1c1dbfa26 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Fri, 4 Oct 2024 16:07:32 +0200 Subject: [PATCH] [HWToSMT] ArrayCreateOp and ArrayGetOp support --- integration_test/circt-lec/hw.mlir | 24 +++++++++++ lib/Conversion/HWToSMT/HWToSMT.cpp | 67 +++++++++++++++++++++++++++++- 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/integration_test/circt-lec/hw.mlir b/integration_test/circt-lec/hw.mlir index f264398b3040..a4c728cbbd3b 100644 --- a/integration_test/circt-lec/hw.mlir +++ b/integration_test/circt-lec/hw.mlir @@ -76,3 +76,27 @@ hw.module @onePlusTwoNonSSA(out out: i2) { // RUN: circt-lec %s -c1=onePlusTwo -c2=onePlusTwoNonSSA --shared-libs=%libz3 | FileCheck %s --check-prefix=HW_MODULE_GRAPH // HW_MODULE_GRAPH: c1 == c2 +// array_create + array_get test +// RUN: circt-lec %s -c1=MultibitMux -c2=MultibitMux2 --shared-libs=%libz3 | FileCheck %s --check-prefix=ARRAY_GET +// ARRAY_GET: c1 == c2 + +hw.module @MultibitMux(in %a_0 : i1, in %a_1 : i1, in %sel : i1, out b : i1) { + %0 = hw.array_create %a_1, %a_0 : i1 + %1 = hw.array_get %0[%sel] : !hw.array<2xi1>, i1 + hw.output %1 : i1 +} + +hw.module @MultibitMux2(in %a_0 : i1, in %a_1 : i1, in %sel : i1, out b : i1) { + %0 = comb.mux bin %sel, %a_1, %a_0 : i1 + hw.output %0 : i1 +} + +// array_get out-of-bounds must not be equivalent +// RUN: circt-lec %s -c1=ArrayOOB -c2=ArrayOOB --shared-libs=%libz3 | FileCheck %s --check-prefix=ARRAY_OOB +// ARRAY_OOB: c1 != c2 + +hw.module @ArrayOOB(in %a : !hw.array<3xi1>, out b : i1) { + %0 = hw.constant 3 : i2 + %1 = hw.array_get %a[%0] : !hw.array<3xi1>, i2 + hw.output %1 : i1 +} diff --git a/lib/Conversion/HWToSMT/HWToSMT.cpp b/lib/Conversion/HWToSMT/HWToSMT.cpp index 04f18493985b..121e55e01c05 100644 --- a/lib/Conversion/HWToSMT/HWToSMT.cpp +++ b/lib/Conversion/HWToSMT/HWToSMT.cpp @@ -93,6 +93,61 @@ struct InstanceOpConversion : OpConversionPattern { } }; +/// Lower a hw::ArrayCreateOp operation to smt::DeclareFun and an +/// smt::ArrayStoreOp for each operand. +struct ArrayCreateOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Type arrTy = typeConverter->convertType(op.getType()); + if (!arrTy) + return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type"); + + unsigned width = adaptor.getInputs().size(); + + Value arr = rewriter.create(loc, arrTy); + for (auto [i, el] : llvm::enumerate(adaptor.getInputs())) { + Value idx = rewriter.create(loc, width - i - 1, + llvm::Log2_64_Ceil(width)); + arr = rewriter.create(loc, arr, idx, el); + } + + rewriter.replaceOp(op, arr); + return success(); + } +}; + +/// Lower a hw::ArrayGetOp operation to smt::ArraySelectOp +struct ArrayGetOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + unsigned numElements = + cast(op.getInput().getType()).getNumElements(); + + Type type = typeConverter->convertType(op.getType()); + if (!type) + return rewriter.notifyMatchFailure(op.getLoc(), + "unsupported array element type"); + + Value oobVal = rewriter.create(loc, type); + Value numElementsVal = rewriter.create( + loc, numElements - 1, llvm::Log2_64_Ceil(numElements)); + Value inBounds = rewriter.create( + loc, smt::BVCmpPredicate::ule, adaptor.getIndex(), numElementsVal); + Value indexed = rewriter.create(loc, adaptor.getInput(), + adaptor.getIndex()); + rewriter.replaceOpWithNewOp(op, inBounds, indexed, oobVal); + return success(); + } +}; + /// Remove redundant (seq::FromClock and seq::ToClock) ops. template struct ReplaceWithInput : OpConversionPattern { @@ -136,6 +191,14 @@ void circt::populateHWToSMTTypeConverter(TypeConverter &converter) { converter.addConversion([](seq::ClockType type) -> std::optional { return smt::BitVectorType::get(type.getContext(), 1); }); + converter.addConversion([&](ArrayType type) -> std::optional { + auto rangeType = converter.convertType(type.getElementType()); + if (!rangeType) + return {}; + auto domainType = smt::BitVectorType::get( + type.getContext(), llvm::Log2_64_Ceil(type.getNumElements())); + return smt::ArrayType::get(type.getContext(), domainType, rangeType); + }); // Default target materialization to convert from illegal types to legal // types, e.g., at the boundary of an inlined child block. @@ -219,8 +282,8 @@ void circt::populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns) { patterns.add, - ReplaceWithInput>(converter, - patterns.getContext()); + ReplaceWithInput, ArrayCreateOpConversion, + ArrayGetOpConversion>(converter, patterns.getContext()); } void ConvertHWToSMTPass::runOnOperation() {