Skip to content

Commit

Permalink
added ket, bra, dm checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Phionx committed Jun 4, 2024
1 parent 3d35aff commit 16b900b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
13 changes: 13 additions & 0 deletions jaxquantum/core/qarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ def ptrace(self, indx):
def is_dm(self):
return self.qtype == Qtypes.oper

def to_ket(self):
return to_ket(self)


def keep_only_diag_elements(self):
return keep_only_diag_elements(self)

Expand Down Expand Up @@ -412,6 +416,15 @@ def keep_only_diag_elements(qarr: Qarray) -> Qarray:
data = jnp.diag(jnp.diag(qarr.data))
return Qarray.create(data, dims=dims)

def to_ket(qarr: Qarray) -> Qarray:
if qarr.qtype == Qtypes.ket:
return qarr
elif qarr.qtype == Qtypes.bra:
return qarr.dag()
else:
raise ValueError("Can only get ket from a ket or bra.")


# More quantum specific -----------------------------------------------------

def ptrace(qarr: Qarray, indx) -> Qarray:
Expand Down
21 changes: 19 additions & 2 deletions jaxquantum/core/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import jax.numpy as jnp
import warnings
import tqdm
import logging


from jaxquantum.core.qarray import Qarray
from jaxquantum.core.qarray import Qarray, Qtypes
from jaxquantum.core.conversions import jnps2jqts, jqts2jnps


Expand Down Expand Up @@ -134,9 +135,18 @@ def mesolve(
Returns:
list of states
"""

c_ops = c_ops or []

if len(c_ops) == 0 and ρ0.qtype != Qtypes.oper:
logging.warning(
"Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
)

ρ0 = ρ0.to_dm()
dims = ρ0.dims
ρ0 = jnp.asarray(ρ0.data) + 0.0j
c_ops = c_ops or []

c_ops = jnp.asarray([c_op.data for c_op in c_ops]) + 0.0j
H0 = jnp.asarray(H0.data) + 0.0j if H0 is not None else None
solver_options = solver_options or {}
Expand Down Expand Up @@ -191,6 +201,13 @@ def sesolve(
list of states
"""

if ψ.qtype == Qtypes.oper:
raise ValueError(
"Please use `jqt.mesolve` for initial state inputs in density matrix form."
)

ψ = ψ.to_ket()

dims = ψ.dims

ψ = jnp.asarray(ψ.data) + 0.0j
Expand Down
29 changes: 10 additions & 19 deletions tutorials/experimental/3-kerr-cat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,7 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shanj/miniconda3/envs/jax-new/lib/python3.12/site-packages/qutip/__init__.py:66: UserWarning: The new version of Cython, (>= 3.0.0) is not supported.\n",
" warnings.warn(\n"
]
}
],
"outputs": [],
"source": [
"from jax import jit, vmap, config, device_put\n",
"import jax\n",
Expand All @@ -33,13 +24,13 @@
"\n",
"config.update(\"jax_enable_x64\", True)\n",
"\n",
"gpu_device = jax.devices('gpu')[0]\n",
"# gpu_device = jax.devices('gpu')[0]\n",
"cpu_device = jax.devices('cpu')[0]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -71,35 +62,35 @@
" # Simulation controls ----\n",
" ts = jnp.linspace(0,1e3,101) # [ns]\n",
" \n",
" states = jqt.sesolve(initial_state, ts, Ht=Ht) \n",
" states = jqt.mesolve(initial_state, ts, Ht=Ht, c_ops=[]) \n",
" # states = jqt.sesolve(initial_state, ts, Ht=Ht) \n",
"\n",
" return states"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/shanj/miniconda3/envs/jax-new/lib/python3.12/site-packages/equinox/_jit.py:49: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.\n",
" out = fun(*args, **kwargs)\n"
"100% |\u001b[35m██████████\u001b[0m| [00:04<00:00, 23.98%/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.38 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
"5.09 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
]
}
],
"source": [
"with jax.default_device(cpu_device):\n",
" %timeit -n1 -r1 states = simulate(jqt.basis(20, 0))"
" %timeit -n1 -r1 states = simulate(jqt.basis(20, 0).to_dm())"
]
},
{
Expand Down Expand Up @@ -198,7 +189,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 16b900b

Please sign in to comment.