From 4e69b7e9a25adfed38b5571b2861d53f893c9ded Mon Sep 17 00:00:00 2001 From: Phionx Date: Mon, 26 Feb 2024 22:41:57 -0500 Subject: [PATCH] fixed dag , batch_dag --- jaxquantum/quantum/base.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/jaxquantum/quantum/base.py b/jaxquantum/quantum/base.py index d89ebec..a0254c6 100644 --- a/jaxquantum/quantum/base.py +++ b/jaxquantum/quantum/base.py @@ -78,7 +78,17 @@ def dag(op: jnp.ndarray) -> jnp.ndarray: Returns: conjugate transpose of op """ - # return jnp.conj(op).T + return jnp.conj(op).T + +def batch_dag(op: jnp.ndarray) -> jnp.ndarray: + """Conjugate transpose. + + Args: + op: operator + + Returns: + conjugate of op, and transposes last two axes + """ return jnp.moveaxis(jnp.conj(op), -1, -2) # transposes last two axes, good for batching