Skip to content

Commit

Permalink
Save aer state to stack
Browse files Browse the repository at this point in the history
  • Loading branch information
ibm-wakizaka committed Aug 2, 2023
1 parent f092545 commit 8460cfd
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 30 deletions.
81 changes: 56 additions & 25 deletions target_simulator/Conversion/QUIRToAer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ namespace qssc::targets::simulator::conversion {

namespace {

class AerStateWrapper {
public:
AerStateWrapper(Value mem) : mem(mem) {}

Value access(OpBuilder& builder) const {
return builder.create<LLVM::LoadOp>(
builder.getUnknownLoc(), mem, /*alignment=*/8);
}

private:
Value mem;
};

std::map<std::string, LLVM::LLVMFuncOp> aerFuncTable;

// TODO: Take care of errors, nan and inf.
Expand Down Expand Up @@ -95,10 +108,11 @@ class ValueTable {
AngleTable ValueTable::angleTable;
QubitTable ValueTable::qubitTable;

void buildQubitTable(ModuleOp moduleOp, Value aerState) {
void buildQubitTable(ModuleOp moduleOp, AerStateWrapper wrapper) {
ValueTable::clearQubitTable();
OpBuilder builder(moduleOp);


moduleOp.walk([&](quir::DeclareQubitOp declOp){
const int width = declOp.getType().dyn_cast<quir::QubitType>().getWidth();
if(width != 1) throw std::runtime_error(""); // TODO
Expand All @@ -108,6 +122,7 @@ void buildQubitTable(ModuleOp moduleOp, Value aerState) {
builder.getI64Type(), width);
auto constOp = builder.create<arith::ConstantOp>(
declOp->getLoc(), builder.getI64Type(), widthAttr);
auto aerState = wrapper.access(builder);
auto alloc = builder.create<LLVM::CallOp>(
declOp->getLoc(),
aerFuncTable.at("aer_allocate_qubits"),
Expand All @@ -123,6 +138,7 @@ void buildQubitTable(ModuleOp moduleOp, Value aerState) {
if(!qubitTable.empty()) {
auto declOp = qubitTable.rbegin()->second.getDefiningOp();
builder.setInsertionPointAfter(declOp);
auto aerState = wrapper.access(builder);
builder.create<LLVM::CallOp>(
declOp->getLoc(), aerFuncTable.at("aer_state_initialize"), aerState);
}
Expand All @@ -132,7 +148,7 @@ void buildQubitTable(ModuleOp moduleOp, Value aerState) {

// Assume qcs.init is called before all quir.declare_qubit operations
struct QCSInitConversionPat : public OpConversionPattern<qcs::SystemInitOp> {
explicit QCSInitConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, Value aerState)
explicit QCSInitConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, AerStateWrapper aerState)
: OpConversionPattern(typeConverter, ctx, /* benefit= */1), aerState(aerState)
{}

Expand All @@ -157,21 +173,22 @@ struct QCSInitConversionPat : public OpConversionPattern<qcs::SystemInitOp> {
// aer_state_configure(state, "method", "statevector")
// aer_state_configure(state, "device", "CPU")
// aer_state_configure(state, "precision", "double")
auto state = aerState.access(rewriter);
rewriter.create<LLVM::CallOp>(
initOp->getLoc(), aerFuncTable.at("aer_state_configure"),
ValueRange{aerState, globals["method"], globals["statevector"]});
ValueRange{state, globals["method"], globals["statevector"]});
rewriter.create<LLVM::CallOp>(
initOp->getLoc(), aerFuncTable.at("aer_state_configure"),
ValueRange{aerState, globals["device"], globals["CPU"]});
ValueRange{state, globals["device"], globals["CPU"]});
rewriter.create<LLVM::CallOp>(
initOp->getLoc(), aerFuncTable.at("aer_state_configure"),
ValueRange{aerState, globals["precision"], globals["double"]});
ValueRange{state, globals["precision"], globals["double"]});
rewriter.eraseOp(initOp);
return success();
}

private:
Value aerState;
AerStateWrapper aerState;
};

struct RemoveQCSShotInitConversionPat : public OpConversionPattern<qcs::ShotInitOp> {
Expand All @@ -188,7 +205,7 @@ struct RemoveQCSShotInitConversionPat : public OpConversionPattern<qcs::ShotInit
};

struct FinalizeConversionPat : public OpConversionPattern<qcs::SystemFinalizeOp> {
explicit FinalizeConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, Value aerState)
explicit FinalizeConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, AerStateWrapper aerState)
: OpConversionPattern(typeConverter, ctx, /* benefit= */1),
aerState(aerState)
{}
Expand All @@ -200,17 +217,17 @@ struct FinalizeConversionPat : public OpConversionPattern<qcs::SystemFinalizeOp>
rewriter.setInsertionPointAfter(finOp);
rewriter.create<LLVM::CallOp>(rewriter.getUnknownLoc(),
aerFuncTable.at("aer_state_finalize"),
aerState);
aerState.access(rewriter));
rewriter.eraseOp(finOp);
return success();
}

private:
Value aerState;
AerStateWrapper aerState;
};

struct BuiltinUopConversionPat : public OpConversionPattern<quir::Builtin_UOp> {
explicit BuiltinUopConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, Value aerState)
explicit BuiltinUopConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, AerStateWrapper aerState)
: OpConversionPattern(typeConverter, ctx, /*benefit=*/1),
aerState(aerState)
{}
Expand All @@ -223,7 +240,7 @@ struct BuiltinUopConversionPat : public OpConversionPattern<quir::Builtin_UOp> {
assert(op.getOperands().size() == 4);

std::vector<Value> args;
args.emplace_back(aerState);
args.emplace_back(aerState.access(rewriter));
{ // qubit
const auto qID = quir::lookupQubitId(op.target());
assert(qID);
Expand All @@ -241,11 +258,11 @@ struct BuiltinUopConversionPat : public OpConversionPattern<quir::Builtin_UOp> {
}

private:
Value aerState;
AerStateWrapper aerState;
};

struct BuiltinCXConversionPat : public OpConversionPattern<quir::BuiltinCXOp> {
explicit BuiltinCXConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, Value aerState)
explicit BuiltinCXConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, AerStateWrapper aerState)
: OpConversionPattern(typeConverter, ctx, /*benefit=*/1),
aerState(aerState)
{}
Expand All @@ -259,22 +276,24 @@ struct BuiltinCXConversionPat : public OpConversionPattern<quir::BuiltinCXOp> {
if(qID1 && qID2) {
auto q1 = ValueTable::lookupQubit(*qID1);
auto q2 = ValueTable::lookupQubit(*qID2);
auto state = aerState.access(rewriter);
rewriter.create<LLVM::CallOp>(
op->getLoc(),
aerFuncTable.at("aer_apply_cx"),
ValueRange{aerState, q1, q2});
ValueRange{state, q1, q2});
rewriter.eraseOp(op);
return success();
}
return failure();
}

private:
Value aerState;
AerStateWrapper aerState;
};

struct MeasureOpConversionPat : public OpConversionPattern<quir::MeasureOp> {
explicit MeasureOpConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, Value aerState)
explicit MeasureOpConversionPat(
MLIRContext *ctx, TypeConverter &typeConverter, AerStateWrapper aerState)
: OpConversionPattern(typeConverter, ctx, /*benefit=*/1),
aerState(aerState)
{}
Expand All @@ -298,18 +317,19 @@ struct MeasureOpConversionPat : public OpConversionPattern<quir::MeasureOp> {
LLVM::LLVMPointerType::get(i64Type),
arrSizeOp, align);
rewriter.create<LLVM::StoreOp>(op->getLoc(), qubit, qubitArr);
auto state = aerState.access(rewriter);
auto meas = rewriter.create<LLVM::CallOp>(
op->getLoc(),
aerFuncTable.at("aer_apply_measure"),
ValueRange{aerState, qubitArr.getResult(), arrSizeOp});
ValueRange{state, qubitArr.getResult(), arrSizeOp});
rewriter.replaceOp(op, meas.getResult(0));
return success();
}
return failure();
}

private:
Value aerState;
AerStateWrapper aerState;
};

struct AngleConversionPat : public OpConversionPattern<quir::ConstantOp> {
Expand Down Expand Up @@ -353,7 +373,7 @@ struct RemoveConversionPat : public OpConversionPattern<Op> {
// TDOO
// Probably I should implement this pattern as another pass
struct FunctionConversionPat : public OpConversionPattern<mlir::FuncOp> {
explicit FunctionConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, Value aerState)
explicit FunctionConversionPat(MLIRContext *ctx, TypeConverter &typeConverter, AerStateWrapper aerState)
: OpConversionPattern(typeConverter, ctx, 1), aerState(aerState)
{}

Expand All @@ -365,7 +385,7 @@ struct FunctionConversionPat : public OpConversionPattern<mlir::FuncOp> {
}

private:
Value aerState;
AerStateWrapper aerState;
};


Expand Down Expand Up @@ -414,16 +434,27 @@ void QUIRToAERPass::runOnOperation(SimulatorSystem &system) {

// Aer initialization
declareAerFunctions(moduleOp);
auto aerState = [&]() -> Value {
auto aerState = [&]() -> AerStateWrapper {
auto mainFunc = mlir::quir::getMainFunction(moduleOp);
OpBuilder builder(mainFunc);

auto mainBody = &mainFunc->getRegion(0).getBlocks().front();
builder.setInsertionPointToStart(mainBody);
return builder.create<LLVM::CallOp>(
builder.getUnknownLoc(),
aerFuncTable.at("aer_state"),
ValueRange{}).getResult(0);
const auto i8Type = IntegerType::get(context, 8);
const auto i64Type = IntegerType::get(context, 64);
auto arrSizeOp = builder.create<arith::ConstantOp>(
builder.getUnknownLoc(), i64Type,
builder.getIntegerAttr(i64Type, 1));
auto alloca = builder.create<LLVM::AllocaOp>(
builder.getUnknownLoc(),
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(i8Type)),
arrSizeOp, /*alignment=*/8); // alignment must be 8
auto call = builder.create<LLVM::CallOp>(
builder.getUnknownLoc(),
aerFuncTable.at("aer_state"),
ValueRange{}).getResult(0);
builder.create<LLVM::StoreOp>(builder.getUnknownLoc(), call, alloca);
return AerStateWrapper(alloca);
}();
// Must build qubit table before applying any conversion
buildQubitTable(moduleOp, aerState);
Expand Down
17 changes: 12 additions & 5 deletions target_simulator/test/python_lib/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
InputType,
OutputType,
CompileOptions,
QSSCompilationFailure,
)
from qss_compiler.exceptions import QSSCompilationFailure

compiler_extra_args = ["--enable-circuits=false"]


def check_mlir_string(mlir):
assert isinstance(mlir, str)
assert "module" in mlir
Expand Down Expand Up @@ -81,7 +82,9 @@ def test_compile_str_to_qem(mock_config_file, example_qasm3_str, check_payload):
check_payload(payload_filelike)


def test_compile_file_to_qem_file(example_qasm3_tmpfile, mock_config_file, tmp_path, check_payload):
def test_compile_file_to_qem_file(
example_qasm3_tmpfile, mock_config_file, tmp_path, check_payload
):
"""Test that we can compile a file input via the interface compile_file
to a QEM payload into a file"""
tmpfile = tmp_path / "payload.qem"
Expand All @@ -105,7 +108,9 @@ def test_compile_file_to_qem_file(example_qasm3_tmpfile, mock_config_file, tmp_p
check_payload(payload)


def test_compile_str_to_qem_file(mock_config_file, tmp_path, example_qasm3_str, check_payload):
def test_compile_str_to_qem_file(
mock_config_file, tmp_path, example_qasm3_str, check_payload
):
"""Test that we can compile an OpenQASM3 string via the interface
compile_file to a QEM payload in an output file"""
tmpfile = tmp_path / "payload.qem"
Expand Down Expand Up @@ -175,7 +180,7 @@ def test_compile_options(mock_config_file, example_qasm3_str):
config_path=mock_config_file,
shot_delay=100,
num_shots=10000,
extra_args= compiler_extra_args + ["--pass-statistics"],
extra_args=compiler_extra_args + ["--pass-statistics"],
)

mlir = compile_str(example_qasm3_str, compile_options=compile_options)
Expand Down Expand Up @@ -223,7 +228,9 @@ async def test_async_compile_str(mock_config_file, example_qasm3_str, check_payl


@pytest.mark.asyncio
async def test_async_compile_file(example_qasm3_tmpfile, mock_config_file, check_payload):
async def test_async_compile_file(
example_qasm3_tmpfile, mock_config_file, check_payload
):
"""Test that async wrapper produces correct output and does not block the even loop."""
async_compile = compile_file_async(
example_qasm3_tmpfile,
Expand Down

0 comments on commit 8460cfd

Please sign in to comment.