Skip to content

Commit

Permalink
added shaped_data and fixed errors in ptrace
Browse files Browse the repository at this point in the history
  • Loading branch information
Phionx committed Mar 18, 2024
1 parent 7f24daf commit 1aa5532
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
29 changes: 27 additions & 2 deletions jaxquantum/core/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
Converting between different object types.
"""

from numbers import Number
from jax import config, Array
from qutip import Qobj
from typing import Optional
from typing import Optional, Union, List
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -46,15 +47,37 @@ def jqt2qt(jqt_obj):

return Qobj(np.array(jqt_obj.data), dims=jqt_obj.dims)

def jnp2jqt(arr: Array, dims: Optional[DIMS_TYPE] = None):

def extract_dims(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None):
"""Extract dims from a JAX array or Qarray.
Args:
arr: JAX array or Qarray.
dims: Qarray dims.
Returns:
Qarray dims.
"""
if isinstance(dims[0], Number):
is_op = len(arr.shape) == 2 and arr.shape[0] == arr.shape[1]
if is_op:
dims = [dims, dims]
else:
dims = [dims, [1] * len(dims)] # defaults to ket
return dims


def jnp2jqt(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None):
"""JAX array -> QuTiP state.
Args:
jnp_obj: JAX array.
dims: Qarray dims.
Returns:
QuTiP state.
"""
dims = extract_dims(arr, dims)
return Qarray.create(arr, dims=dims)


Expand All @@ -67,6 +90,8 @@ def jnps2jqts(arrs: Array, dims: Optional[DIMS_TYPE] = None):
Returns:
QuTiP state.
"""

dims = extract_dims(arrs[0], dims)
return [Qarray.create(arr, dims=dims) for arr in arrs]

def jqts2jnps(qarrs: Qarray):
Expand Down
15 changes: 9 additions & 6 deletions jaxquantum/core/qarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ def dims(self):
def data(self):
return self._data

@property
def shaped_data(self):
return self._data.reshape(self.dims[0] + self.dims[1])

def _str_header(self):
out = ", ".join([
Expand Down Expand Up @@ -334,7 +337,7 @@ def tensor(*args, **kwargs) -> Qarray:
dims[1] += arg.dims[1]
return Qarray.create(data, dims=dims)

def tr(qarr: Qarray, **kwargs) -> Qarray:
def tr(qarr: Qarray, **kwargs) -> jnp.complex128:
"""Full trace.
Args:
Expand All @@ -343,7 +346,7 @@ def tr(qarr: Qarray, **kwargs) -> Qarray:
Returns:
Full trace.
"""
return jnp.trace(qarr.data, **kwargs)
return trace(qarr, **kwargs)


def expm_data(data: Array, **kwargs) -> Array:
Expand Down Expand Up @@ -425,13 +428,10 @@ def ptrace(qarr: Qarray, indx) -> Qarray:
"""

qarr = ket2dm(qarr)
rho = qarr.data
rho = qarr.shaped_data
dims = qarr.dims

Nq = len(dims[0])
dims2 = jnp.concatenate(jnp.array(dims))

rho = rho.reshape(dims2)

indxs = [indx, indx + Nq]
for j in range(Nq):
Expand All @@ -446,6 +446,9 @@ def ptrace(qarr: Qarray, indx) -> Qarray:

return Qarray.create(rho)

def trace(qarr: Qarray, **kwargs) -> Qarray:
return jnp.trace(qarr.data, **kwargs)

def dag(qarr: Qarray) -> Qarray:
"""Conjugate transpose.
Expand Down

0 comments on commit 1aa5532

Please sign in to comment.