Skip to content

Commit

Permalink
Revert revert of 32-to-64-bit update (#1456)
Browse files Browse the repository at this point in the history
  • Loading branch information
cqc-alec authored Jun 18, 2024
1 parent 36a4e5c commit d545141
Show file tree
Hide file tree
Showing 19 changed files with 99 additions and 66 deletions.
4 changes: 2 additions & 2 deletions pytket/binders/include/UnitRegister.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ conventions defined here:
registers are up to _TKET_REG_WIDTH wide in bits and are interpreted as
equivalent to the C++ type _tket_uint_t
*/
#define _TKET_REG_WIDTH 32
typedef uint32_t _tket_uint_t;
#define _TKET_REG_WIDTH 64
typedef uint64_t _tket_uint_t;

template <typename T>
class UnitRegister {
Expand Down
2 changes: 1 addition & 1 deletion pytket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def package(self):
cmake.install()

def requirements(self):
self.requires("tket/1.3.10@tket/stable")
self.requires("tket/1.3.11@tket/stable")
self.requires("tklog/0.3.3@tket/stable")
self.requires("tkrng/0.3.3@tket/stable")
self.requires("tkassert/0.3.4@tket/stable")
Expand Down
6 changes: 6 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

Unreleased
----------

* Support classical transforms and predicates, and QASM registers, with up to 64
bits. Add an attribute to the pytket module to assert this.

1.29.2 (June 2024)
------------------

Expand Down
6 changes: 6 additions & 0 deletions pytket/pytket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@
config.write_file(pytket_config_file)

__path__ = __import__("pkgutil").extend_path(__path__, __name__)

"""Flag indicating 64-bit support.
If True, classical transforms and predicates, and QASM registers, with up to 64
bits are supported."""
bit_width_64 = True
2 changes: 1 addition & 1 deletion pytket/pytket/_tket/unit_id.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,4 @@ _DEBUG_ONE_REG_PREFIX: str = 'tk_DEBUG_ONE_REG'
_DEBUG_ZERO_REG_PREFIX: str = 'tk_DEBUG_ZERO_REG'
_TEMP_BIT_NAME: str = 'tk_SCRATCH_BIT'
_TEMP_BIT_REG_BASE: str = 'tk_SCRATCH_BITREG'
_TEMP_REG_SIZE: int = 32
_TEMP_REG_SIZE: int = 64
2 changes: 1 addition & 1 deletion pytket/pytket/circuit/add_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _add_condition(
target_bits = pred_exp.to_list()

minval = 0
maxval = (1 << 32) - 1
maxval = (1 << 64) - 1
if isinstance(condition, RegLt):
maxval = pred_val - 1
elif isinstance(condition, RegGt):
Expand Down
11 changes: 11 additions & 0 deletions pytket/pytket/qasm/qasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,18 @@ def _retrieve_registers(


def _parse_range(minval: int, maxval: int, maxwidth: int) -> Tuple[str, int]:
if maxwidth > 64:
raise NotImplementedError("Register width exceeds maximum of 64.")

REGMAX = (1 << maxwidth) - 1

if minval > REGMAX:
raise NotImplementedError("Range's lower bound exceeds register capacity.")
elif minval > maxval:
raise NotImplementedError("Range's lower bound exceeds upper bound.")
elif maxval > REGMAX:
maxval = REGMAX

if minval == maxval:
return ("==", minval)
elif minval == 0:
Expand Down
26 changes: 13 additions & 13 deletions pytket/tests/classical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

from pytket.passes import DecomposeClassicalExp, FlattenRegisters

from strategies import reg_name_regex, binary_digits, uint32 # type: ignore
from strategies import reg_name_regex, binary_digits, uint32, uint64 # type: ignore

curr_file_path = Path(__file__).resolve().parent

Expand Down Expand Up @@ -840,7 +840,7 @@ def primitive_reg_logic_exps(
RegGeq,
),
):
const_compare = draw(uint32)
const_compare = draw(uint64)
args.append(const_compare)
else:
args.append(draw(bit_regs))
Expand All @@ -854,8 +854,8 @@ def primitive_reg_logic_exps(
@given(
reg_exp=primitive_reg_logic_exps(),
constants=strategies.tuples(
uint32,
uint32,
uint64,
uint64,
),
)
def test_reg_exp(reg_exp: RegLogicExp, constants: Tuple[int, int]) -> None:
Expand Down Expand Up @@ -929,7 +929,7 @@ def composite_bit_logic_exps(
def composite_reg_logic_exps(
draw: DrawType,
regs: SearchStrategy[BitRegister] = bit_register(),
constants: SearchStrategy[int] = uint32,
constants: SearchStrategy[int] = uint64,
operators: SearchStrategy[Callable] = strategies.sampled_from(
[
operator.and_,
Expand Down Expand Up @@ -979,7 +979,7 @@ def reg_const_predicates(
operators: SearchStrategy[
Callable[[Union[RegLogicExp, BitRegister], int], PredicateExp]
] = strategies.sampled_from([reg_eq, reg_neq, reg_lt, reg_gt, reg_leq, reg_geq]),
constants: SearchStrategy[int] = uint32,
constants: SearchStrategy[int] = uint64,
) -> PredicateExp:
return draw(operators)(draw(exp), draw(constants)) # type: ignore

Expand Down Expand Up @@ -1131,10 +1131,10 @@ def test_decomposition_known() -> None:
)
check_serialization_roundtrip(circ)

temp_bits = BitRegister(_TEMP_BIT_NAME, 32)
temp_bits = BitRegister(_TEMP_BIT_NAME, 64)

def temp_reg(i: int) -> BitRegister:
return BitRegister(f"{_TEMP_BIT_REG_BASE}_{i}", 32)
return BitRegister(f"{_TEMP_BIT_REG_BASE}_{i}", 64)

for b in (temp_bits[i] for i in range(0, 10)):
conditioned_circ.add_bit(b)
Expand Down Expand Up @@ -1170,13 +1170,13 @@ def temp_reg(i: int) -> BitRegister:
conditioned_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5])
conditioned_circ.Y(qreg[4], condition_bits=[temp_bits[5]], condition_value=0)
conditioned_circ.add_c_range_predicate(
4, 4294967295, registers_lists[3], temp_bits[6]
4, 18446744073709551615, registers_lists[3], temp_bits[6]
)
conditioned_circ.Z(qreg[5], condition_bits=[temp_bits[6]], condition_value=1)
conditioned_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7])
conditioned_circ.S(qreg[6], condition_bits=[temp_bits[7]], condition_value=1)
conditioned_circ.add_c_range_predicate(
3, 4294967295, registers_lists[5], temp_bits[8]
3, 18446744073709551615, registers_lists[5], temp_bits[8]
)
conditioned_circ.T(qreg[7], condition_bits=[temp_bits[8]], condition_value=1)

Expand All @@ -1196,7 +1196,7 @@ def temp_reg(i: int) -> BitRegister:
decomposed_circ.add_bit(b)

decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_0", 3))
decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_1", 32))
decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_1", 64))

decomposed_circ.H(qreg[0], condition_bits=[bits[0]], condition_value=1)
decomposed_circ.X(qreg[0], condition_bits=[bits[1]], condition_value=1)
Expand All @@ -1211,11 +1211,11 @@ def temp_reg(i: int) -> BitRegister:
decomposed_circ.add_c_range_predicate(0, 5, registers_lists[1], temp_bits[4])
decomposed_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5])
decomposed_circ.add_c_range_predicate(
4, 4294967295, registers_lists[3], temp_bits[6]
4, 18446744073709551615, registers_lists[3], temp_bits[6]
)
decomposed_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7])
decomposed_circ.add_c_range_predicate(
3, 4294967295, registers_lists[5], temp_bits[8]
3, 18446744073709551615, registers_lists[5], temp_bits[8]
)

decomposed_circ.add_c_xor(bits[5], bits[6], temp_bits[2])
Expand Down
8 changes: 8 additions & 0 deletions pytket/tests/qasm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,14 @@ def test_const_condition() -> None:
)


def test_range_with_maxwidth() -> None:
c = Circuit(1)
a = c.add_c_register("a", 8)
c.X(0, condition=reg_geq(a, 1))
qasm = circuit_to_qasm_str(c, header="hqslib1", maxwidth=63)
assert "if(a>=1) x q[0];" in qasm


if __name__ == "__main__":
test_qasm_correct()
test_qasm_qubit()
Expand Down
1 change: 1 addition & 0 deletions pytket/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

binary_digits = st.sampled_from((0, 1))
uint32 = st.integers(min_value=1, max_value=1 << 32 - 1)
uint64 = st.integers(min_value=1, max_value=1 << 64 - 1)
reg_name_regex = re.compile("[a-z][a-zA-Z0-9_]*")


Expand Down
10 changes: 5 additions & 5 deletions schemas/circuit_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@
},
"lower": {
"type": "integer",
"maximum": 4294967295,
"description": "The inclusive minimum of the RangePredicate as a uint32."
"maximum": 18446744073709551615,
"description": "The inclusive minimum of the RangePredicate as a uint64."
},
"upper": {
"type": "integer",
"maximum": 4294967295,
"description": "The inclusive maximum of the RangePredicate as a uint32."
"maximum": 18446744073709551615,
"description": "The inclusive maximum of the RangePredicate as a uint64."
}
},
"required": [
Expand Down Expand Up @@ -1179,4 +1179,4 @@
"additionalProperties": false
}
}
}
}
2 changes: 1 addition & 1 deletion tket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TketConan(ConanFile):
name = "tket"
version = "1.3.10"
version = "1.3.11"
package_type = "library"
license = "Apache 2"
homepage = "https://github.com/CQCL/tket"
Expand Down
22 changes: 11 additions & 11 deletions tket/include/tket/Ops/ClassicalOps.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,18 @@ class ClassicalTransformOp : public ClassicalEvalOp {
* @param values table of binary-encoded values
* @param name name of operation
*
* @pre n <= 32
* @pre n <= 64
*/
ClassicalTransformOp(
unsigned n, const std::vector<uint32_t> &values,
unsigned n, const std::vector<uint64_t> &values,
const std::string &name = "ClassicalTransform");

std::vector<bool> eval(const std::vector<bool> &x) const override;

std::vector<uint32_t> get_values() const { return values_; }
std::vector<uint64_t> get_values() const { return values_; }

private:
const std::vector<uint32_t> values_;
const std::vector<uint64_t> values_;
};

/**
Expand Down Expand Up @@ -341,15 +341,15 @@ class RangePredicateOp : public PredicateOp {
* @param b upper bound in little-endian encoding
*/
RangePredicateOp(
unsigned n, uint32_t a = 0,
uint32_t b = std::numeric_limits<uint32_t>::max())
unsigned n, uint64_t a = 0,
uint64_t b = std::numeric_limits<uint64_t>::max())
: PredicateOp(OpType::RangePredicate, n, "RangePredicate"), a(a), b(b) {}

std::string get_name(bool latex) const override;

uint32_t upper() const { return b; }
uint64_t upper() const { return b; }

uint32_t lower() const { return a; }
uint64_t lower() const { return a; }

std::vector<bool> eval(const std::vector<bool> &x) const override;

Expand All @@ -359,8 +359,8 @@ class RangePredicateOp : public PredicateOp {
bool is_equal(const Op &other) const override;

private:
uint32_t a;
uint32_t b;
uint64_t a;
uint64_t b;
};

/**
Expand All @@ -378,7 +378,7 @@ class ExplicitPredicateOp : public PredicateOp {
* @param values table of values
* @param name name of operation
*
* @pre n <= 32
* @pre n <= 64
*/
ExplicitPredicateOp(
unsigned n, const std::vector<bool> &values,
Expand Down
4 changes: 2 additions & 2 deletions tket/include/tket/Utils/HelperFunctions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ bimap_to_map(MapT& bm) {
}

/**
* Reverse bits 0,1,...,w-1 of the number v, assuming v < 2^w and w <= 32.
* Reverse bits 0,1,...,w-1 of the number v, assuming v < 2^w and w <= 64.
*/
uint32_t reverse_bits(uint32_t v, unsigned w);
uint64_t reverse_bits(uint64_t v, unsigned w);

/**
* @brief
Expand Down
Loading

0 comments on commit d545141

Please sign in to comment.