Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
uenoku committed Oct 25, 2024
1 parent 1d15913 commit 4f70efe
Show file tree
Hide file tree
Showing 25 changed files with 162 additions and 195 deletions.
4 changes: 2 additions & 2 deletions include/circt/Conversion/CombToAIG.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- CombToAIG.h - Comb to AIG dialect conversion ---------*- C++ -*-===//
//===- CombToAIG.h - Comb to AIG dialect conversion --------------- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -22,4 +22,4 @@ class HWModuleOp;

} // namespace circt

#endif // CIRCT_CONVERSION_COMBTOARITH_H
#endif // CIRCT_CONVERSION_COMBTOAIG_H
5 changes: 1 addition & 4 deletions include/circt/Dialect/AIG/AIG.td
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- AIG.td - AIG Definitions ----------*- tablegen -*-===//
//===- AIG.td - AIG Definitions ----------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -10,14 +10,11 @@
#define CIRCT_AIG_DIALECT_TD

include "mlir/IR/DialectBase.td"
include "mlir/IR/OpBase.td"

def AIG_Dialect : Dialect {
let name = "aig";
let cppNamespace = "::circt::aig";
let summary = "Representation of AIGs";

let usePropertiesForAttributes = 0;
}

include "circt/Dialect/AIG/AIGOps.td"
Expand Down
2 changes: 1 addition & 1 deletion include/circt/Dialect/AIG/AIGDialect.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- AIGDialect.h - AIG Definitions --------------------------*- C++ -*-===//
//===- AIGDialect.h - AIG Definitions ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
2 changes: 1 addition & 1 deletion include/circt/Dialect/AIG/AIGOps.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- AIGOps.h - AIG Op Definitions ---------------------------*- C++ -*-===//
//===- AIGOps.h - AIG Op Definitions ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
17 changes: 10 additions & 7 deletions include/circt/Dialect/AIG/AIGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def AndInverterOp : AIGOp<"and_inv", [SameOperandsAndResultType, Pure]> {
let summary = "AIG dialect AND operation";
let description = [{
The `aig.and_inv` operation represents an And-Inverter in the AIG dialect.
Unlike comb.and, operands can be inverted respectively.
Unlike `comb.and`, operands can be inverted respectively.

Example:
```mlir
Expand Down Expand Up @@ -81,27 +81,30 @@ def CutOp : AIGOp<"cut", [IsolatedFromAbove, SingleBlock]> {
^bb0(%arg0: i1, %arg1: i1, %arg2: i1, %arg3: i1):
%0 = aig.and_inv not %arg0, %arg1 : i1
%1 = aig.and_inv %arg1, %arg3 : i1
aig.output %0, %1 : i1
}
aig.output %0, %1 : i1 }
```

}];

// TODO: Restrict to HWIntegerType.
let arguments = (ins Variadic<AnyType>:$inputs);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$body);
let regions = (region SizedRegion<1>:$bodyRegion);
let assemblyFormat = [{
$inputs attr-dict `:` functional-type($inputs, $results) $body
$inputs attr-dict `:` functional-type($inputs, $results) $bodyRegion
}];

let builders = [
OpBuilder<(ins
CArg<"TypeRange", "{}">:$resultTypes,
CArg<"ValueRange", "{}">:$inputs,
CArg<"std::function<void()>", "{}">:$ctor)>
CArg<"std::function<void(mlir::Block::BlockArgListType)>", "{}">:$ctor)>
];

let hasVerifier = 1;

let extraClassDeclaration = [{
Block *getBodyBlock() { return &getBody().front(); }
Block *getBodyBlock() { return getBody(); }
}];
}

Expand Down
6 changes: 0 additions & 6 deletions include/circt/Dialect/AIG/AIGPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,6 @@ namespace aig {
#define GEN_PASS_DECL
#include "circt/Dialect/AIG/AIGPasses.h.inc"

std::unique_ptr<mlir::Pass> createLowerCutToLUTPass();
std::unique_ptr<mlir::Pass> createLowerVariadicPass();
std::unique_ptr<mlir::Pass> createLowerWordToBitsPass();
std::unique_ptr<mlir::Pass>
createGreedyCutDecompPass(const GreedyCutDecompOptions &options = {});

#define GEN_PASS_REGISTRATION
#include "circt/Dialect/AIG/AIGPasses.h.inc"

Expand Down
12 changes: 4 additions & 8 deletions include/circt/Dialect/AIG/AIGPasses.td
Original file line number Diff line number Diff line change
@@ -1,41 +1,37 @@
//===- ArcPasses.td - Arc dialect passes -------------------*- tablegen -*-===//
//===- AIGPasses.td - AIG dialect passes -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_DIALECT_ARC_ARCPASSES_TD
#define CIRCT_DIALECT_ARC_ARCPASSES_TD
#ifndef CIRCT_DIALECT_AIG_AIGPASSES_TD
#define CIRCT_DIALECT_AIG_AIGPASSES_TD

include "mlir/Pass/PassBase.td"

def LowerCutToLUT : Pass<"aig-lower-cut-to-lut", "hw::HWModuleOp"> {
let summary = "Lower a cut to a LUT";
let dependentDialects = ["comb::CombDialect"];
let constructor = "circt::aig::createLowerCutToLUTPass()";
}

def LowerVariadic : Pass<"aig-lower-variadic", "hw::HWModuleOp"> {
let summary = "Lower variadic AndInverter operations to binary AndInverter";
let constructor = "circt::aig::createLowerVariadicPass()";
}

def LowerWordToBits : Pass<"aig-lower-word-to-bits", "hw::HWModuleOp"> {
let summary = "Lower multi-bit AIG operations to single-bit ones";
let dependentDialects = ["comb::CombDialect"];
let constructor = "circt::aig::createLowerWordToBitsPass()";
}

def GreedyCutDecomp : Pass<"aig-greedy-cut-decomp", "hw::HWModuleOp"> {
let summary = "Decompose AIGs into k-feasible Cuts using a greedy algorithm";
let dependentDialects = ["comb::CombDialect"];
let constructor = "circt::aig::createGreedyCutDecompPass()";
let options = [
Option<"cutSizes", "cut-sizes", "uint32_t", "6",
"The sizes of the cuts to decompose">,
];
}

#endif // CIRCT_DIALECT_ARC_ARCPASSES_TD
#endif // CIRCT_DIALECT_AIG_AIGPASSES_TD
3 changes: 3 additions & 0 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ struct CombXorOpConversion : OpConversionPattern<XorOp> {
LogicalResult
matchAndRewrite(XorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getNumOperands() != 2)
return failure();
// Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)

// (a | b) = ~(~a & ~b)
// (~a | ~b) = ~(a & b)
auto inputs = adaptor.getInputs();
Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/AIG/AIGDialect.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- AIGDialect.cpp - Implement the AIG dialect -----------------------===//
//===- AIGDialect.cpp - Implement the AIG dialect -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -12,7 +12,6 @@

#include "circt/Dialect/AIG/AIGDialect.h"
#include "circt/Dialect/AIG/AIGOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
Expand Down
66 changes: 28 additions & 38 deletions lib/Dialect/AIG/AIGOps.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- LoopScheduleOps.cpp - LoopSchedule CIRCT Operations ------*- C++ -*-===//
//===- AIGOps.cpp - AIG Dialect Operations ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -101,11 +101,7 @@ mlir::ParseResult AndInverterOp::parse(mlir::OpAsmParser &parser,
auto loc = parser.getCurrentLocation();

while (true) {
if (succeeded(parser.parseOptionalKeyword("not"))) {
inverts.push_back(true);
} else {
inverts.push_back(false);
}
inverts.push_back(succeeded(parser.parseOptionalKeyword("not")));
operands.push_back(OpAsmParser::UnresolvedOperand());

if (parser.parseOperand(operands.back()))
Expand All @@ -114,28 +110,17 @@ mlir::ParseResult AndInverterOp::parse(mlir::OpAsmParser &parser,
break;
}

if (parser.parseOptionalAttrDict(result.attributes))
return mlir::failure();

if (parser.parseColon())
return mlir::failure();

mlir::Type resultRawType{};
llvm::ArrayRef<mlir::Type> resultTypes(&resultRawType, 1);

{
mlir::Type type;
if (parser.parseCustomTypeWithFallback(type))
return mlir::failure();
resultRawType = type;
}
mlir::Type type;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseCustomTypeWithFallback(type))
return failure();

result.addTypes(resultTypes);
result.addTypes({type});
result.addAttribute("inverted",
parser.getBuilder().getDenseBoolArrayAttr(inverts));
if (parser.resolveOperands(operands, resultTypes[0], loc, result.operands))
return mlir::failure();
return mlir::success();
if (parser.resolveOperands(operands, type, loc, result.operands))
return failure();
return success();
}

void AndInverterOp::print(mlir::OpAsmPrinter &odsPrinter) {
Expand All @@ -148,18 +133,9 @@ void AndInverterOp::print(mlir::OpAsmPrinter &odsPrinter) {
}
odsPrinter << input;
});
llvm::SmallVector<llvm::StringRef, 2> elidedAttrs;
elidedAttrs.push_back("inverted");
llvm::SmallVector<llvm::StringRef, 2> elidedAttrs{"inverted"};
odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
odsPrinter << ' ' << ":";
odsPrinter << ' ';
{
auto type = getResult().getType();
if (auto validType = llvm::dyn_cast<mlir::Type>(type))
odsPrinter.printStrippedAttrOrType(validType);
else
odsPrinter << type;
}
odsPrinter << " : " << getResult().getType();
}

APInt AndInverterOp::evaluate(ArrayRef<APInt> inputs) {
Expand All @@ -178,7 +154,7 @@ APInt AndInverterOp::evaluate(ArrayRef<APInt> inputs) {

void CutOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, ValueRange inputs,
std::function<void()> ctor) {
std::function<void(mlir::Block::BlockArgListType)> ctor) {
OpBuilder::InsertionGuard guard(builder);

auto *block = builder.createBlock(result.addRegion());
Expand All @@ -188,5 +164,19 @@ void CutOp::build(OpBuilder &builder, OperationState &result,
block->addArgument(input.getType(), input.getLoc());

if (ctor)
ctor();
ctor(block->getArguments());
}

LogicalResult CutOp::verify() {
if (getInputs().size() != getBodyBlock()->getNumArguments())
return emitOpError("the number of inputs and the number of block arguments "
"do not match. Expected ")
<< getInputs().size() << " but got "
<< getBodyBlock()->getNumArguments();
if (getNumResults() != getBodyBlock()->getTerminator()->getNumOperands())
return emitOpError("the number of results and the number of terminator "
"operands do not match. Expected ")
<< getNumResults() << " but got "
<< getBodyBlock()->getTerminator()->getNumOperands();
return success();
}
3 changes: 1 addition & 2 deletions lib/Dialect/AIG/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ add_circt_dialect_library(CIRCTAIG
CIRCTHW

DEPENDS
CIRCTHW
MLIRAIGIncGen
)

add_subdirectory(Transforms)
add_subdirectory(Transforms)
41 changes: 14 additions & 27 deletions lib/Dialect/AIG/Transforms/GreedyCutDecomp.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- GreedyCutDecomp.cpp ---------------------------------------------===//
//===- GreedyCutDecomp.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This pass performs cut decomposition on AIGs based on a naive greedy
// algorithm. We first convert all `aig.and_inv` to `aig.cut` that has a single
// algorithm. We first convert all `aig.and_inv` to `aig.cut` that have a single
// operation and then try to merge cut operations on inputs.
//
//===----------------------------------------------------------------------===//
Expand All @@ -19,6 +19,7 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "aig-greedy-cut-decomp"

namespace circt {
Expand All @@ -42,14 +43,15 @@ struct AndInverterOpToCutPattern : public OpRewritePattern<aig::AndInverterOp> {

LogicalResult matchAndRewrite(aig::AndInverterOp op,
PatternRewriter &rewriter) const override {
if (isa<aig::CutOp>(op->getParentOp()))
if (op->getParentOfType<aig::CutOp>())
return failure();

auto cutOp = rewriter.create<aig::CutOp>(
op.getLoc(), op.getResult().getType(), op.getInputs(), [&]() {
op.getLoc(), op.getResult().getType(), op.getInputs(),
[&](Block::BlockArgListType args) {
auto result = rewriter.create<aig::AndInverterOp>(
op.getLoc(), op.getResult().getType(),
rewriter.getBlock()->getArguments(), op.getInvertedAttr());
op.getLoc(), op.getResult().getType(), args,
op.getInvertedAttr());
rewriter.create<aig::OutputOp>(op.getLoc(), ValueRange{result});
});

Expand All @@ -67,10 +69,10 @@ static aig::CutOp mergeCuts(Location loc, MutableArrayRef<Operation *> cuts,
assert(cuts.size() >= 2);

DenseMap<Value, Value> valueToNewValue, inputsToBlockArg;
auto cutOp =
rewriter.create<aig::CutOp>(loc, output.getType(), inputs, [&]() {
auto cutOp = rewriter.create<aig::CutOp>(
loc, output.getType(), inputs, [&](Block::BlockArgListType args) {
for (auto [i, input] : llvm::enumerate(inputs))
inputsToBlockArg[input] = rewriter.getBlock()->getArgument(i);
inputsToBlockArg[input] = args[i];

for (auto [i, cut] : llvm::enumerate(cuts)) {
auto cutOp = cast<aig::CutOp>(cut);
Expand Down Expand Up @@ -103,7 +105,7 @@ static aig::CutOp mergeCuts(Location loc, MutableArrayRef<Operation *> cuts,
for (auto oldCut : llvm::reverse(cuts)) {
auto *oldCutBlock = cast<aig::CutOp>(oldCut).getBodyBlock();
auto oldCutOutput = oldCutBlock->getTerminator();
oldCutOutput->erase();
rewriter.eraseOp(oldCutOutput);
// Erase arguments before inlining. Arguments are already replaced.
oldCutBlock->eraseArguments([](BlockArgument block) { return true; });
rewriter.inlineBlockBefore(oldCutBlock, cutOp.getBodyBlock(),
Expand Down Expand Up @@ -241,18 +243,8 @@ struct SinkConstantPattern : public mlir::OpRewritePattern<aig::CutOp> {
block->eraseArguments(eraseArgs);
}

auto newCut = rewriter.create<aig::CutOp>(op.getLoc(), op.getResultTypes(),
oldInputs, [&]() {});

for (auto [newArg, oldArg] :
llvm::zip(newCut.getBodyBlock()->getArguments(), oldArgs))
rewriter.replaceAllUsesWith(oldArg, newArg);

// Erase arguments before inlining. Arguments are already replaced.
block->eraseArguments([](BlockArgument arg) { return true; });
rewriter.inlineBlockBefore(block, newCut.getBodyBlock(),
newCut.getBodyBlock()->begin());
rewriter.replaceOp(op, newCut);
rewriter.modifyOpInPlace(
op, [&]() { op.getInputsMutable().assign(oldInputs); });
return success();
}
};
Expand Down Expand Up @@ -285,8 +277,3 @@ void GreedyCutDecompPass::runOnOperation() {
mlir::applyPatternsAndFoldGreedily(getOperation(), frozen, config)))
return signalPassFailure();
}

std::unique_ptr<mlir::Pass>
aig::createGreedyCutDecompPass(const GreedyCutDecompOptions &options) {
return std::make_unique<GreedyCutDecompPass>(options);
}
Loading

0 comments on commit 4f70efe

Please sign in to comment.