diff --git a/target_simulator/Conversion/QUIRToAer.cpp b/target_simulator/Conversion/QUIRToAer.cpp index 433eb4178..71043425a 100644 --- a/target_simulator/Conversion/QUIRToAer.cpp +++ b/target_simulator/Conversion/QUIRToAer.cpp @@ -50,10 +50,53 @@ namespace qssc::targets::simulator::conversion { namespace { std::map aerFuncTable; -std::map qubitTable; + +// TODO: Take care of errors, nan and inf. +using AngleTable = std::map; +using QubitTable = std::map; + +// TODO: We should map some IDs (attributes?) to values +class ValueTable { +public: + static void registerAngle(double angle, Value v) { + assert(!std::isnan(angle) && !std::isinf(angle)); + angleTable[angle] = v; + } + + static void registerQubit(int qid, Value v) { + qubitTable[qid] = v; + } + + static Value lookupAngle(double angle) { + return angleTable.at(angle); + } + + static Value lookupQubit(int qid) { + return qubitTable.at(qid); + } + + static QubitTable getQubits() { + return qubitTable; + } + + static void clear() { + angleTable.clear(); + qubitTable.clear(); + } + + static void clearQubitTable() { + qubitTable.clear(); + } + +private: + static AngleTable angleTable; + static QubitTable qubitTable; +}; +AngleTable ValueTable::angleTable; +QubitTable ValueTable::qubitTable; void buildQubitTable(ModuleOp moduleOp, Value aerState) { - qubitTable.clear(); + ValueTable::clearQubitTable(); OpBuilder builder(moduleOp); moduleOp.walk([&](quir::DeclareQubitOp declOp){ @@ -71,8 +114,18 @@ void buildQubitTable(ModuleOp moduleOp, Value aerState) { ValueRange{aerState, constOp}); const int id = *declOp.id(); - qubitTable[id] = alloc.getResult(0); + ValueTable::registerQubit(id, alloc.getResult(0)); }); + + // TODO + // Assume that qubit declaration with the biggest id is called after any other declaration + auto qubitTable = ValueTable::getQubits(); + if(!qubitTable.empty()) { + auto declOp = qubitTable.rbegin()->second.getDefiningOp(); + builder.setInsertionPointAfter(declOp); + builder.create( + declOp->getLoc(), aerFuncTable.at("aer_state_initialize"), aerState); + } } } @@ -98,7 +151,7 @@ struct QCSInitConversionPat : public OpConversionPattern { const auto with_null = config_str + std::string("\0", 1); globals[config_str] = LLVM::createGlobalString( initOp->getLoc(), rewriter, var_name, - with_null, LLVM::Linkage::Internal); + with_null, LLVM::Linkage::Private); } // configure // aer_state_configure(state, "method", "statevector") @@ -166,7 +219,23 @@ struct BuiltinUopConversionPat : public OpConversionPattern { quir::Builtin_UOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - //TODO + // qubit, angle{1, 2, 3} + assert(op.getOperands().size() == 4); + + std::vector args; + args.emplace_back(aerState); + { // qubit + const auto qID = quir::lookupQubitId(op.target()); + assert(qID); + args.emplace_back(ValueTable::lookupQubit(*qID)); + } + for(auto val : {op.theta(), op.phi(), op.lambda()}) { + auto angleOp = dyn_cast(val.getDefiningOp()); + const double angle = angleOp.value().dyn_cast().getValue().convertToDouble(); + args.emplace_back(ValueTable::lookupAngle(angle)); + } + rewriter.create( + op.getLoc(), aerFuncTable.at("aer_apply_u3"), args); rewriter.eraseOp(op); return success(); } @@ -188,8 +257,8 @@ struct BuiltinCXConversionPat : public OpConversionPattern { auto qID1 = quir::lookupQubitId(op->getOperand(0)); auto qID2 = quir::lookupQubitId(op->getOperand(1)); if(qID1 && qID2) { - auto q1 = qubitTable.at(*qID1); - auto q2 = qubitTable.at(*qID2); + auto q1 = ValueTable::lookupQubit(*qID1); + auto q2 = ValueTable::lookupQubit(*qID2); rewriter.create( op->getLoc(), aerFuncTable.at("aer_apply_cx"), @@ -221,7 +290,7 @@ struct MeasureOpConversionPat : public OpConversionPattern { const unsigned arrSize = 1; // TODO const IntegerAttr arraySizeAttr = rewriter.getIntegerAttr(i64Type, arrSize); const unsigned int align = 8; // TODO - const auto qubit = qubitTable.at(*qID); + const auto qubit = ValueTable::lookupQubit(*qID); auto arrSizeOp = rewriter.create( op->getLoc(), i64Type, arraySizeAttr); auto qubitArr = rewriter.create( @@ -259,6 +328,8 @@ struct AngleConversionPat : public OpConversionPattern { FloatAttr fAttr = rewriter.getFloatAttr(fType, angle); auto constOp = rewriter.create(op->getLoc(), fType, fAttr); rewriter.replaceOp(op, {constOp}); + // TODO: We must build angle table before applying QUIRToAerPass + ValueTable::registerAngle(angle, constOp); } return success(); } @@ -400,6 +471,7 @@ void QUIRToAERPass::declareAerFunctions(ModuleOp moduleOp) { const auto voidType = LLVM::LLVMVoidType::get(context); const auto i8Type = IntegerType::get(context, 8); const auto i64Type = IntegerType::get(context, 64); + const auto f64Type = Float64Type::get(context); const auto aerStateType = LLVM::LLVMPointerType::get(i8Type); const auto strType = LLVM::LLVMPointerType::get(i8Type); // @aer_state(...) -> i8* @@ -416,6 +488,10 @@ void QUIRToAERPass::declareAerFunctions(ModuleOp moduleOp) { // @aer_state_initialize(...) -> i8* const auto aerStateInitType = LLVMFunctionType::get(aerStateType, {}, true); registerFunc("aer_state_initialize", aerStateInitType); + // @aer_apply_u3(i8* noundef, i64 noundef, i64 noundef, i64 noundef) -> void + const auto aerApplyU3Type = LLVMFunctionType::get( + voidType, {aerStateType, i64Type, f64Type, f64Type, f64Type}); + registerFunc("aer_apply_u3", aerApplyU3Type); // @aer_apply_cx(i8* noundef, i64 noundef, i64 noundef) -> void const auto aerApplyCXType = LLVMFunctionType::get( voidType, {aerStateType, i64Type, i64Type});