Skip to content

Commit

Permalink
Support u3 gate
Browse files Browse the repository at this point in the history
  • Loading branch information
ibm-wakizaka committed Jul 31, 2023
1 parent 003e1af commit 3b3e861
Showing 1 changed file with 84 additions and 8 deletions.
92 changes: 84 additions & 8 deletions target_simulator/Conversion/QUIRToAer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,53 @@ namespace qssc::targets::simulator::conversion {
namespace {

std::map<std::string, LLVM::LLVMFuncOp> aerFuncTable;
std::map<int, Value> qubitTable;

// TODO: Take care of errors, nan and inf.
using AngleTable = std::map<double, Value>;
using QubitTable = std::map<int, Value>;

// TODO: We should map some IDs (attributes?) to values
class ValueTable {
public:
static void registerAngle(double angle, Value v) {
assert(!std::isnan(angle) && !std::isinf(angle));
angleTable[angle] = v;
}

static void registerQubit(int qid, Value v) {
qubitTable[qid] = v;
}

static Value lookupAngle(double angle) {
return angleTable.at(angle);
}

static Value lookupQubit(int qid) {
return qubitTable.at(qid);
}

static QubitTable getQubits() {
return qubitTable;
}

static void clear() {
angleTable.clear();
qubitTable.clear();
}

static void clearQubitTable() {
qubitTable.clear();
}

private:
static AngleTable angleTable;
static QubitTable qubitTable;
};
AngleTable ValueTable::angleTable;
QubitTable ValueTable::qubitTable;

void buildQubitTable(ModuleOp moduleOp, Value aerState) {
qubitTable.clear();
ValueTable::clearQubitTable();
OpBuilder builder(moduleOp);

moduleOp.walk([&](quir::DeclareQubitOp declOp){
Expand All @@ -71,8 +114,18 @@ void buildQubitTable(ModuleOp moduleOp, Value aerState) {
ValueRange{aerState, constOp});

const int id = *declOp.id();
qubitTable[id] = alloc.getResult(0);
ValueTable::registerQubit(id, alloc.getResult(0));
});

// TODO
// Assume that qubit declaration with the biggest id is called after any other declaration
auto qubitTable = ValueTable::getQubits();
if(!qubitTable.empty()) {
auto declOp = qubitTable.rbegin()->second.getDefiningOp();
builder.setInsertionPointAfter(declOp);
builder.create<LLVM::CallOp>(
declOp->getLoc(), aerFuncTable.at("aer_state_initialize"), aerState);
}
}

}
Expand All @@ -98,7 +151,7 @@ struct QCSInitConversionPat : public OpConversionPattern<qcs::SystemInitOp> {
const auto with_null = config_str + std::string("\0", 1);
globals[config_str] = LLVM::createGlobalString(
initOp->getLoc(), rewriter, var_name,
with_null, LLVM::Linkage::Internal);
with_null, LLVM::Linkage::Private);
}
// configure
// aer_state_configure(state, "method", "statevector")
Expand Down Expand Up @@ -166,7 +219,23 @@ struct BuiltinUopConversionPat : public OpConversionPattern<quir::Builtin_UOp> {
quir::Builtin_UOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override
{
//TODO
// qubit, angle{1, 2, 3}
assert(op.getOperands().size() == 4);

std::vector<Value> args;
args.emplace_back(aerState);
{ // qubit
const auto qID = quir::lookupQubitId(op.target());
assert(qID);
args.emplace_back(ValueTable::lookupQubit(*qID));
}
for(auto val : {op.theta(), op.phi(), op.lambda()}) {
auto angleOp = dyn_cast<quir::ConstantOp>(val.getDefiningOp());
const double angle = angleOp.value().dyn_cast<quir::AngleAttr>().getValue().convertToDouble();
args.emplace_back(ValueTable::lookupAngle(angle));
}
rewriter.create<LLVM::CallOp>(
op.getLoc(), aerFuncTable.at("aer_apply_u3"), args);
rewriter.eraseOp(op);
return success();
}
Expand All @@ -188,8 +257,8 @@ struct BuiltinCXConversionPat : public OpConversionPattern<quir::BuiltinCXOp> {
auto qID1 = quir::lookupQubitId(op->getOperand(0));
auto qID2 = quir::lookupQubitId(op->getOperand(1));
if(qID1 && qID2) {
auto q1 = qubitTable.at(*qID1);
auto q2 = qubitTable.at(*qID2);
auto q1 = ValueTable::lookupQubit(*qID1);
auto q2 = ValueTable::lookupQubit(*qID2);
rewriter.create<LLVM::CallOp>(
op->getLoc(),
aerFuncTable.at("aer_apply_cx"),
Expand Down Expand Up @@ -221,7 +290,7 @@ struct MeasureOpConversionPat : public OpConversionPattern<quir::MeasureOp> {
const unsigned arrSize = 1; // TODO
const IntegerAttr arraySizeAttr = rewriter.getIntegerAttr(i64Type, arrSize);
const unsigned int align = 8; // TODO
const auto qubit = qubitTable.at(*qID);
const auto qubit = ValueTable::lookupQubit(*qID);
auto arrSizeOp = rewriter.create<arith::ConstantOp>(
op->getLoc(), i64Type, arraySizeAttr);
auto qubitArr = rewriter.create<LLVM::AllocaOp>(
Expand Down Expand Up @@ -259,6 +328,8 @@ struct AngleConversionPat : public OpConversionPattern<quir::ConstantOp> {
FloatAttr fAttr = rewriter.getFloatAttr(fType, angle);
auto constOp = rewriter.create<arith::ConstantOp>(op->getLoc(), fType, fAttr);
rewriter.replaceOp(op, {constOp});
// TODO: We must build angle table before applying QUIRToAerPass
ValueTable::registerAngle(angle, constOp);
}
return success();
}
Expand Down Expand Up @@ -400,6 +471,7 @@ void QUIRToAERPass::declareAerFunctions(ModuleOp moduleOp) {
const auto voidType = LLVM::LLVMVoidType::get(context);
const auto i8Type = IntegerType::get(context, 8);
const auto i64Type = IntegerType::get(context, 64);
const auto f64Type = Float64Type::get(context);
const auto aerStateType = LLVM::LLVMPointerType::get(i8Type);
const auto strType = LLVM::LLVMPointerType::get(i8Type);
// @aer_state(...) -> i8*
Expand All @@ -416,6 +488,10 @@ void QUIRToAERPass::declareAerFunctions(ModuleOp moduleOp) {
// @aer_state_initialize(...) -> i8*
const auto aerStateInitType = LLVMFunctionType::get(aerStateType, {}, true);
registerFunc("aer_state_initialize", aerStateInitType);
// @aer_apply_u3(i8* noundef, i64 noundef, i64 noundef, i64 noundef) -> void
const auto aerApplyU3Type = LLVMFunctionType::get(
voidType, {aerStateType, i64Type, f64Type, f64Type, f64Type});
registerFunc("aer_apply_u3", aerApplyU3Type);
// @aer_apply_cx(i8* noundef, i64 noundef, i64 noundef) -> void
const auto aerApplyCXType = LLVMFunctionType::get(
voidType, {aerStateType, i64Type, i64Type});
Expand Down

0 comments on commit 3b3e861

Please sign in to comment.