Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Add new input type and new operator support to quantum oracle generator #813

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 135 additions & 7 deletions samples/qir/oracle-generator/oracle-generator/read_qir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ namespace detail
}
}

// for now, the sample only supports Boolean functions (i.e., all
// input and output types are either of type `Bool` or a tuple of
// type `Bool`)
// for now, the sample only supports Boolean and Integer functions (i.e., all
// input and output types are either of type `Bool`, `Int` or a tuple of
// type `Bool` `Int`)
if (!analyze_function_signature(function))
{
fmt::print("[e] function signature not supported: inputs must be Bool and return type must be Bool or "
Expand Down Expand Up @@ -256,10 +256,138 @@ namespace detail
}
break;

// Signed remainder operation
case llvm::Instruction::SRem:
{
// Get the previous instruction
const llvm::Instruction* prevInst = inst.getPrevNode();
if (prevInst) {
// Check the opcode of the previous instruction
unsigned int prevOpcode = prevInst->getOpcode();

if (prevOpcode == llvm::Instruction::Add)
{
// Perform modular addition inplace
auto const* op0 = prevInst -> getOperand(0u);
auto const* op1 = prevInst -> getOperand(1u);
auto const* ty0 = op0->getType();
auto const* ty1 = op1->getType();
auto const* op2 = inst.getOperand(1u);

value_signals[&inst] = value_signals[op0];

// Check if the operand is a constant integer
if (const llvm::ConstantInt* constantInt = llvm::dyn_cast<llvm::ConstantInt>(op2)) {
// Get the unsigned value of the constant
llvm::APInt intValue = constantInt->getValue();
uint64_t value = intValue.getZExtValue();

// The op2 value is converted to uint64_t in the 'value' variable
modular_adder_inplace(ntk, value_signals[&inst], value_signals[op1], value);
} else {
// Handle the case when op2 is not a constant integer
fmt::print("op2 is not a constant integer\n");
std::abort();
}


}

else if (prevOpcode == llvm::Instruction::Mul) {
// Get the operands from the previous instruction
auto const* op0 = prevInst->getOperand(0u);
auto const* op1 = prevInst->getOperand(1u);

// Get the operands from the current instruction
auto const* op2 = inst.getOperand(1u);
auto const* ty0 = op0->getType();
auto const* ty1 = op1->getType();
auto const* ty2 = op2->getType();

if (ty0->isIntegerTy(64) && ty1->isIntegerTy(64) && ty2->isIntegerTy(64))
{
auto signal0 = get_signal(ntk, op0);
auto signal1 = get_signal(ntk, op1);
const auto size = std::max(signal0.size(), signal1.size());
value_signals[&inst] = value_signals[op0];

if (const llvm::ConstantInt* constantInt = llvm::dyn_cast<llvm::ConstantInt>(op2))
{
// Get the unsigned value of the constant
llvm::APInt intValue = constantInt->getValue();
uint64_t value = intValue.getZExtValue();

// Now you have the op2 value converted to uint64_t in the 'value' variable
// You can use it as needed in your code
modular_multiplication_inplace(ntk, value_signals[&inst], value_signals[op1], value);
}
else
{
// Handle the case when op2 is not a constant integer
fmt::print("op2 is not a constant integer\n");
std::abort();
}
}
}

else
{
fmt::print("Unsupported previous opcode: {}\n", prevOpcode);
std::abort();
}
}
else {
fmt::print("No previous instruction found\n");
std::abort();
}
}
break;

// Multiplication operation
case llvm::Instruction::Mul:
{
auto const* op0 = inst.getOperand(0u);
auto const* op1 = inst.getOperand(1u);
auto const* ty0 = op0->getType();
auto const* ty1 = op1->getType();

if (ty0->isIntegerTy(64) && ty1->isIntegerTy(64))
{
auto signal0 = get_signal(ntk, op0);
auto signal1 = get_signal(ntk, op1);
const auto size = std::max(signal0.size(), signal1.size());
value_signals[&inst] = value_signals[inst.getOperand(0u)];
modular_multiplication_inplace(ntk, value_signals[&inst], value_signals[inst.getOperand(1u)], 11);
}
else
{
fmt::print("Not Implemented");
std::abort();
}

}
break;

// Addition operation
case llvm::Instruction::Add:
{
auto const* op0 = inst.getOperand(0u);
auto const* op1 = inst.getOperand(1u);
auto const* ty0 = op0->getType();
auto const* ty1 = op1->getType();
value_signals[&inst] = value_signals[inst.getOperand(0u)];
modular_adder_inplace(ntk, value_signals[&inst], value_signals[inst.getOperand(1u)]);
break;

if (ty0->isIntegerTy(64) && ty1->isIntegerTy(64))
{
modular_adder_inplace(ntk, value_signals[&inst], value_signals[inst.getOperand(1u)], 11);

}
else
{
modular_adder_inplace(ntk, value_signals[&inst], value_signals[inst.getOperand(1u)]);
}
}
break;

case llvm::Instruction::Br:
{
Expand Down Expand Up @@ -474,7 +602,7 @@ namespace detail
/* input type */
for (const auto& arg : f.args())
{
if (arg.getType()->isIntegerTy(1u))
if (arg.getType()->isIntegerTy(64) || arg.getType()->isIntegerTy(1u))
{
continue;
}
Expand All @@ -484,7 +612,7 @@ namespace detail
/* output type */
auto const* retTy = f.getReturnType();

return retTy->isIntegerTy(1u) || is_valid_tuple_pointer_type(retTy);
return retTy->isIntegerTy(64) || retTy->isIntegerTy(1u) || is_valid_tuple_pointer_type(retTy);
}

private:
Expand Down
Loading