Skip to content

Commit

Permalink
update representation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ziofil committed Nov 15, 2024
1 parent 5632af6 commit 5d84f9f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 118 deletions.
92 changes: 6 additions & 86 deletions mrmustard/physics/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from __future__ import annotations
from typing import Sequence
from enum import Enum

import numpy as np

Expand All @@ -32,41 +31,11 @@

from .ansatz import Ansatz, PolyExpAnsatz, ArrayAnsatz
from .triples import identity_Abc
from .wires import Wires
from .wires import Wires, ReprEnum

__all__ = ["Representation"]


class RepEnum(Enum):
r"""
An enum to represent what representation a wire is in.
"""

NONETYPE = 0
BARGMANN = 1
FOCK = 2
QUADRATURE = 3
PHASESPACE = 4

@classmethod
def from_ansatz(cls, ansatz: Ansatz):
r"""
Returns a ``RepEnum`` from an ``Ansatz``.
Args:
ansatz: The ansatz.
"""
if isinstance(ansatz, PolyExpAnsatz):
return cls(1)
elif isinstance(ansatz, ArrayAnsatz):
return cls(2)
else:
return cls(0)

def __repr__(self) -> str:
return self.name


class Representation:
r"""
A class for representations.
Expand All @@ -87,10 +56,7 @@ class Representation:
"""

def __init__(
self,
ansatz: Ansatz | None = None,
wires: Wires | Sequence[tuple[int]] | None = None,
idx_reps: dict | None = None,
self, ansatz: Ansatz | None = None, wires: Wires | Sequence[tuple[int]] | None = None
) -> None:
self._ansatz = ansatz

Expand Down Expand Up @@ -128,9 +94,6 @@ def __init__(
self._ansatz = ansatz.reorder(tuple(perm))

self._wires = wires
self._idx_reps = idx_reps or dict.fromkeys(
wires.indices, (RepEnum.from_ansatz(ansatz), None)
)

@property
def adjoint(self) -> Representation:
Expand All @@ -142,12 +105,7 @@ def adjoint(self) -> Representation:
kets = self.wires.ket.indices
ansatz = self.ansatz.reorder(kets + bras).conj if self.ansatz else None
wires = self.wires.adjoint
idx_reps = {}
for i, j in enumerate(kets):
idx_reps[i] = self._idx_reps[j]
for i, j in enumerate(bras):
idx_reps[i + len(kets)] = self._idx_reps[j]
return Representation(ansatz, wires, idx_reps)
return Representation(ansatz, wires)

@property
def ansatz(self) -> Ansatz | None:
Expand All @@ -168,16 +126,7 @@ def dual(self) -> Representation:
ob = self.wires.bra.output.indices
ansatz = self.ansatz.reorder(ib + ob + ik + ok).conj if self.ansatz else None
wires = self.wires.dual
idx_reps = {}
for i, j in enumerate(ib):
idx_reps[i] = self._idx_reps[j]
for i, j in enumerate(ob):
idx_reps[i + len(ib)] = self._idx_reps[j]
for i, j in enumerate(ik):
idx_reps[i + len(ib + ob)] = self._idx_reps[j]
for i, j in enumerate(ok):
idx_reps[i + len(ib + ob + ik)] = self._idx_reps[j]
return Representation(ansatz, wires, idx_reps)
return Representation(ansatz, wires)

@property
def wires(self) -> Wires | None:
Expand Down Expand Up @@ -301,37 +250,9 @@ def _matmul_indices(self, other: Representation) -> tuple[tuple[int, ...], tuple
idx_zconj += other.wires.ket.input[ket_modes].indices
return idx_z, idx_zconj

def _matmul_idx_reps(self, wires_result: Wires, other: Representation):
r"""
Returns the new representation mappings when contracting ``self`` and ``other``.
Args:
wires_result: The resulting wires after contraction.
other: The representation contracting with.
"""
idx_reps = {}
for id in wires_result.ids:
if id in other.wires.ids:
temp_rep = other
else:
temp_rep = self
for t in (0, 1, 2, 3, 4, 5):
try:
idx = temp_rep.wires.ids_index_dicts[t][id]
n_idx = wires_result.ids_index_dicts[t][id]
idx_reps[n_idx] = temp_rep._idx_reps[idx]
break
except KeyError:
continue
return idx_reps

def __eq__(self, other):
if isinstance(other, Representation):
return (
self.ansatz == other.ansatz
and self.wires == other.wires
and self._idx_reps == other._idx_reps
)
return self.ansatz == other.ansatz and self.wires == other.wires
return False

def __matmul__(self, other: Representation):
Expand All @@ -347,5 +268,4 @@ def __matmul__(self, other: Representation):

rep = self_ansatz[idx_z] @ other_ansatz[idx_zconj]
rep = rep.reorder(perm) if perm else rep
idx_reps = self._matmul_idx_reps(wires_result, other)
return Representation(rep, wires_result, idx_reps)
return Representation(rep, wires_result)
41 changes: 9 additions & 32 deletions tests/test_physics/test_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,56 +58,33 @@ def test_init(self, triple):
empty_rep = Representation()
assert empty_rep.ansatz is None
assert empty_rep.wires == Wires()
assert empty_rep._idx_reps == {}

ansatz = PolyExpAnsatz(*triple)
wires = Wires(set([0, 1]))
rep = Representation(ansatz, wires)
assert rep.ansatz == ansatz
assert rep.wires == wires
assert rep._idx_reps == dict.fromkeys(wires.indices, (RepEnum.from_ansatz(ansatz), None))

@pytest.mark.parametrize("triple", [Abc_n2])
def test_adjoint_idx_reps(self, triple):
ansatz = PolyExpAnsatz(*triple)
wires = Wires(modes_out_bra=set([0]), modes_out_ket=set([0]))
idx_reps = {0: (RepEnum.BARGMANN, None), 1: (RepEnum.QUADRATURE, 0.1)}
rep = Representation(ansatz, wires, idx_reps)
adj_rep = rep.adjoint
assert adj_rep._idx_reps == {
1: (RepEnum.BARGMANN, None),
0: (RepEnum.QUADRATURE, 0.1),
}

@pytest.mark.parametrize("triple", [Abc_n2])
def test_dual_idx_reps(self, triple):
ansatz = PolyExpAnsatz(*triple)
wires = Wires(modes_out_bra=set([0]), modes_in_bra=set([0]))
idx_reps = {0: (RepEnum.BARGMANN, None), 1: (RepEnum.QUADRATURE, 0.1)}
rep = Representation(ansatz, wires, idx_reps)
adj_rep = rep.dual
assert adj_rep._idx_reps == {
1: (RepEnum.BARGMANN, None),
0: (RepEnum.QUADRATURE, 0.1),
}

def test_matmul_btoq(self, d_gate_rep, btoq_rep):
q_dgate = d_gate_rep @ btoq_rep
assert q_dgate._idx_reps == {
0: (RepEnum.QUADRATURE, 0.2),
1: (RepEnum.BARGMANN, None),
}
for w in q_dgate.wires.input.wires:
assert w.repr == RepEnum.BARGMANN
for w in q_dgate.wires.output.wires:
assert w.repr == RepEnum.QUADRATURE
assert w.param == [0.2]

def test_to_bargmann(self, d_gate_rep):
d_fock = d_gate_rep.to_fock(shape=(4, 6))
d_barg = d_fock.to_bargmann()
assert d_fock.ansatz._original_abc_data == d_gate_rep.ansatz.triple
assert d_barg == d_gate_rep
assert all((k[0] == RepEnum.BARGMANN for k in d_barg._idx_reps.values()))
for w in d_barg.wires.wires:
assert w.repr == RepEnum.BARGMANN

def test_to_fock(self, d_gate_rep):
d_fock = d_gate_rep.to_fock(shape=(4, 6))
assert d_fock.ansatz == ArrayAnsatz(
math.hermite_renormalized(*displacement_gate_Abc(x=0.1, y=0.1), shape=(4, 6))
)
assert all((k[0] == RepEnum.FOCK for k in d_fock._idx_reps.values()))
for w in d_fock.wires.wires:
assert w.repr == RepEnum.FOCK

0 comments on commit 5d84f9f

Please sign in to comment.