diff --git a/jaxquantum/core/conversions.py b/jaxquantum/core/conversions.py index 131ef99..ac7eb2c 100644 --- a/jaxquantum/core/conversions.py +++ b/jaxquantum/core/conversions.py @@ -2,9 +2,10 @@ Converting between different object types. """ +from numbers import Number from jax import config, Array from qutip import Qobj -from typing import Optional +from typing import Optional, Union, List import jax.numpy as jnp import numpy as np @@ -46,15 +47,37 @@ def jqt2qt(jqt_obj): return Qobj(np.array(jqt_obj.data), dims=jqt_obj.dims) -def jnp2jqt(arr: Array, dims: Optional[DIMS_TYPE] = None): + +def extract_dims(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None): + """Extract dims from a JAX array or Qarray. + + Args: + arr: JAX array or Qarray. + dims: Qarray dims. + + Returns: + Qarray dims. + """ + if isinstance(dims[0], Number): + is_op = len(arr.shape) == 2 and arr.shape[0] == arr.shape[1] + if is_op: + dims = [dims, dims] + else: + dims = [dims, [1] * len(dims)] # defaults to ket + return dims + + +def jnp2jqt(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None): """JAX array -> QuTiP state. Args: jnp_obj: JAX array. + dims: Qarray dims. Returns: QuTiP state. """ + dims = extract_dims(arr, dims) return Qarray.create(arr, dims=dims) @@ -67,6 +90,8 @@ def jnps2jqts(arrs: Array, dims: Optional[DIMS_TYPE] = None): Returns: QuTiP state. """ + + dims = extract_dims(arrs[0], dims) return [Qarray.create(arr, dims=dims) for arr in arrs] def jqts2jnps(qarrs: Qarray): diff --git a/jaxquantum/core/qarray.py b/jaxquantum/core/qarray.py index 96f119d..536b496 100644 --- a/jaxquantum/core/qarray.py +++ b/jaxquantum/core/qarray.py @@ -239,6 +239,9 @@ def dims(self): def data(self): return self._data + @property + def shaped_data(self): + return self._data.reshape(self.dims[0] + self.dims[1]) def _str_header(self): out = ", ".join([ @@ -334,7 +337,7 @@ def tensor(*args, **kwargs) -> Qarray: dims[1] += arg.dims[1] return Qarray.create(data, dims=dims) -def tr(qarr: Qarray, **kwargs) -> Qarray: +def tr(qarr: Qarray, **kwargs) -> jnp.complex128: """Full trace. Args: @@ -343,7 +346,7 @@ def tr(qarr: Qarray, **kwargs) -> Qarray: Returns: Full trace. """ - return jnp.trace(qarr.data, **kwargs) + return trace(qarr, **kwargs) def expm_data(data: Array, **kwargs) -> Array: @@ -425,13 +428,10 @@ def ptrace(qarr: Qarray, indx) -> Qarray: """ qarr = ket2dm(qarr) - rho = qarr.data + rho = qarr.shaped_data dims = qarr.dims Nq = len(dims[0]) - dims2 = jnp.concatenate(jnp.array(dims)) - - rho = rho.reshape(dims2) indxs = [indx, indx + Nq] for j in range(Nq): @@ -446,6 +446,9 @@ def ptrace(qarr: Qarray, indx) -> Qarray: return Qarray.create(rho) +def trace(qarr: Qarray, **kwargs) -> Qarray: + return jnp.trace(qarr.data, **kwargs) + def dag(qarr: Qarray) -> Qarray: """Conjugate transpose.