Skip to content

Commit

Permalink
added propagator support
Browse files Browse the repository at this point in the history
  • Loading branch information
Phionx committed Jul 30, 2024
1 parent 831d97e commit f274a23
Show file tree
Hide file tree
Showing 7 changed files with 393 additions and 33 deletions.
22 changes: 21 additions & 1 deletion jaxquantum/core/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np


from jaxquantum.core.qarray import Qarray, DIMS_TYPE
from jaxquantum.core.qarray import Qarray, DIMS_TYPE, Qtypes


config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -48,6 +48,26 @@ def jqt2qt(jqt_obj):
return Qobj(np.array(jqt_obj.data), dims=jqt_obj.dims)


def op2jqts(op: Qarray, cols=True):
"""QuTiP operator -> JAX array.
Args:
op: QuTiP operator.
Returns:
JAX array.
"""
if op.qtype != Qtypes.oper:
raise ValueError("Input must be a QuTiP operator.")

space_dims = op.space_dims
ones = [1] * len(space_dims)

if cols:
return [Qarray.create(op.data[i,:], dims=[space_dims, ones]) for i in range(op.data.shape[0])]
else:
return [Qarray.create(op.data[:,i][jnp.newaxis,...], dims=[ones, space_dims]) for i in range(op.data.shape[1])]

def extract_dims(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None):
"""Extract dims from a JAX array or Qarray.
Expand Down
14 changes: 14 additions & 0 deletions jaxquantum/core/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" States. """

from jax import config
from math import prod

import jax.numpy as jnp
from jax.nn import one_hot
Expand Down Expand Up @@ -108,6 +109,19 @@ def identity(*args, **kwargs) -> Qarray:
"""
return Qarray.create(jnp.eye(*args, **kwargs))

def identity_like(A) -> Qarray:
"""Identity matrix with the same shape as A.
Args:
A: Matrix.
Returns:
Identity matrix with the same shape as A.
"""
space_dims = A.space_dims
total_dim = prod(space_dims)
return Qarray.create(jnp.eye(total_dim, total_dim), dims=[space_dims, space_dims])


def displace(N, α) -> Qarray:
"""Displacement operator
Expand Down
10 changes: 10 additions & 0 deletions jaxquantum/core/qarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,16 @@ def dtype(self):
@property
def dims(self):
return self._qdims.dims

@property
def space_dims(self):
if self.qtype in [Qtypes.oper, Qtypes.ket]:
return self.dims[0]
elif self.qtype == Qtypes.bra:
return self.dims[1]
else:
raise ValueError("Unsupported qtype.")


@property
def data(self):
Expand Down
193 changes: 169 additions & 24 deletions jaxquantum/core/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, TqdmProgressMeter, NoProgressMeter
from functools import partial
from flax import struct
from jax import jit, vmap, Array
from typing import Callable, List, Optional, Dict
from jax import vmap, Array
from typing import Callable, List, Optional, Dict, Union
import diffrax
import jax.numpy as jnp
import jax.scipy as jsp
import warnings
import tqdm
import logging



from jaxquantum.core.qarray import Qarray, Qtypes
from jaxquantum.core.conversions import jnps2jqts, jqts2jnps

from jaxquantum.core.conversions import jnps2jqts, jqts2jnps, jnp2jqt
from jaxquantum.utils.utils import robust_isscalar


# ----
Expand All @@ -33,7 +34,6 @@ def create(cls, progress_meter: bool = True, solver: str = "Tsit5", max_steps: i
return cls(progress_meter, solver, max_steps)


@jit
def calc_expect(op: Qarray, states: List[Qarray]) -> Array:
"""Calculate expectation value of an operator given a list of states.
Expand Down Expand Up @@ -126,10 +126,6 @@ def solve(ρ0, f, t_list, args, solver_options: Optional[SolverOptions] = None):

return sol

@partial(
jit,
static_argnums=(4,5),
)
def mesolve(
ρ0: Qarray,
t_list: Array,
Expand Down Expand Up @@ -161,10 +157,50 @@ def mesolve(

ρ0 = ρ0.to_dm()
dims = ρ0.dims
ρ0 = jnp.asarray(ρ0.data) + 0.0j
ρ0 = ρ0.data

c_ops = [c_op.data for c_op in c_ops]
H0 = jnp.asarray(H0.data) if H0 is not None else None
Ht_data = lambda t: Ht(t).data if Ht is not None else None

ys = mesolve_data(ρ0, t_list, c_ops, H0, Ht_data, solver_options)

c_ops = jnp.asarray([c_op.data for c_op in c_ops]) + 0.0j
H0 = jnp.asarray(H0.data) + 0.0j if H0 is not None else None
return jnps2jqts(ys, dims=dims)


def mesolve_data(
ρ0: Array,
t_list: Array,
c_ops: Optional[List[Array]] = None,
H0: Optional[Array] = None,
Ht: Optional[Callable[[float], Array]] = None,
solver_options: Optional[SolverOptions] = None
):
"""Quantum Master Equation solver.
Args:
ρ0: initial state, must be a density matrix. For statevector evolution, please use sesolve.
t_list: time list
c_ops: list of collapse operators
H0: time independent Hamiltonian. If H0 is not None, it will override Ht.
Ht: time dependent Hamiltonian function.
solver_options: SolverOptions with solver options
Returns:
list of states
"""

c_ops = c_ops or []

if len(c_ops) == 0:
logging.warning(
"Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
)

ρ0 = ρ0 + 0.0j

c_ops = jnp.asarray([c_op for c_op in c_ops]) + 0.0j
H0 = H0 + 0.0j if H0 is not None else None

def f(
t: float,
Expand All @@ -177,7 +213,7 @@ def f(
if H0_val is not None:
H = H0_val # use H0 if given
else:
H = Ht(t).data # type: ignore
H = Ht(t) # type: ignore
H = H + 0.0j

rho_dot = -1j * (H @ rho - rho @ H)
Expand All @@ -190,12 +226,8 @@ def f(

sol = solve(ρ0, f, t_list, [H0, c_ops], solver_options=solver_options)

return jnps2jqts(sol.ys, dims=dims)
return sol.ys

@partial(
jit,
static_argnums=(3,4),
)
def sesolve(
ψ: Qarray,
t_list: Array,
Expand Down Expand Up @@ -225,10 +257,37 @@ def sesolve(

dims = ψ.dims

ψ = jnp.asarray(ψ.data) + 0.0j
H0 = jnp.asarray(H0.data) + 0.0j if H0 is not None else None
solver_options = solver_options or {}
ψ = ψ.data
H0 = H0.data if H0 is not None else None
Ht_data = lambda t: Ht(t).data if Ht is not None else None

ys = sesolve_data(ψ, t_list, H0, Ht_data, solver_options)

return jnps2jqts(ys, dims=dims)

def sesolve_data(
ψ: Array,
t_list: Array,
H0: Optional[Array] = None,
Ht: Optional[Callable[[float], Array]] = None,
solver_options: Optional[SolverOptions] = None,
):
"""Schrödinger Equation solver.
Args:
ψ: initial statevector
t_list: time list
H0: time independent Hamiltonian. If H0 is not None, it will override Ht.
Ht: time dependent Hamiltonian function.
solver_options: SolverOptions with solver options
Returns:
list of states
"""

ψ = ψ + 0.0j
H0 = H0 + 0.0j if H0 is not None else None
solver_options = solver_options or {}

def f(
t: float,
Expand All @@ -240,7 +299,7 @@ def f(
if H0_val is not None:
H = H0_val # use H0 if given
else:
H = Ht(t).data # type: ignore
H = Ht(t) # type: ignore
# print("H", H.shape)
# print("psit", ψₜ.shape)
ψₜ_dot = -1j * (H @ ψₜ)
Expand All @@ -249,8 +308,94 @@ def f(


sol = solve(ψ, f, t_list, [H0], solver_options=solver_options)
return sol.ys

# ----


# propagators
# ----


def propagator(
H: Union[Qarray, Callable[[float], Qarray]],
t: Union[float, Array],
solver_options=None
):
""" Generate the propagator for a time dependent Hamiltonian.
Args:
H (Qarray or callable):
A Qarray static Hamiltonian OR
a function that takes a time argument and returns a Hamiltonian.
ts (float or Array):
A single time point or
an Array of time points.
Returns:
Qarray or List[Qarray]:
The propagator for the Hamiltonian at time t.
OR a list of propagators for the Hamiltonian at each time in t.
"""

t_is_scalar = robust_isscalar(t)

if isinstance(H, Qarray):
dims = H.dims
if t_is_scalar:
return jnp2jqt(propagator_0_data(H.data,t), dims=dims)
else:
f = lambda t: propagator_0_data(H.data,t)
return jnps2jqts(vmap(f)(t), dims)
else:
dims = H(0.0).dims
H_data = lambda t: H(t).data
if t_is_scalar:
return jnp2jqt(
propagator_t_data(H_data, t, solver_options=solver_options),
dims=dims
)
else:
f = lambda t: propagator_t_data(H_data, t, solver_options=solver_options)
return jnps2jqts(vmap(f)(t), dims)

def propagator_0_data(
H0: Array,
t: float
):
""" Generate the propagator for a time independent Hamiltonian.
Args:
H0 (Qarray): The Hamiltonian.
return jnps2jqts(sol.ys, dims=dims)
Returns:
Qarray: The propagator for the time independent Hamiltonian.
"""
return jsp.linalg.expm(-1j * H0 * t)

def propagator_t_data(
Ht: Callable[[float], Array],
t: float,
solver_options=None
):
""" Generate the propagator for a time dependent Hamiltonian.
Args:
t (float): The final time of the propagator.
Warning: Do not send in t. In this case, just do exp(-1j*Ht(0.0)).
Ht (callable): A function that takes a time argument and returns a Hamiltonian.
solver_options (dict): Options to pass to the solver.
# ----
Returns:
Qarray: The propagator for the time dependent Hamiltonian for the time range [0, t_final].
"""
ts = jnp.linspace(0,t,2)
N = Ht(0).shape[0]
basis_states = jnp.eye(N)

def propogate_state(initial_state):
return sesolve_data(initial_state, ts, Ht=Ht, solver_options=solver_options)[1]

U_prop = vmap(propogate_state)(basis_states)
return U_prop
6 changes: 6 additions & 0 deletions jaxquantum/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,9 @@ def conj_transpose_iso_matrix(A):
Ar = A[:N//2,:N//2].T
Ai = A[N//2:,:N//2].T
return jnp.block([[Ar, Ai],[-Ai,Ar]])

def robust_isscalar(val):
is_scalar = isinstance(val, Number) or jnp.isscalar(val)
if isinstance(val, jnp.ndarray):
is_scalar = len(val.shape) == 0
return is_scalar
14 changes: 6 additions & 8 deletions tutorials/1-single-qubit-rabi-qarray.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit f274a23

Please sign in to comment.