Skip to content

Commit

Permalink
Merge pull request #5 from EQuS/qarray
Browse files Browse the repository at this point in the history
Adopting Qarray's!
  • Loading branch information
Phionx authored Mar 4, 2024
2 parents 8911b98 + aa4b2d9 commit 288e133
Show file tree
Hide file tree
Showing 14 changed files with 1,338 additions and 450 deletions.
2 changes: 1 addition & 1 deletion jaxquantum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os

from .utils import *
from .quantum import *
from .core import *


with open(
Expand Down
8 changes: 8 additions & 0 deletions jaxquantum/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Quantum Tooling"""

from .operators import *
from .conversions import *
from .visualization import *
from .solvers import *
from .qarray import *
from .settings import SETTINGS
81 changes: 81 additions & 0 deletions jaxquantum/core/conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
Converting between different object types.
"""

from jax import config, Array
from qutip import Qobj
from typing import Optional
import jax.numpy as jnp
import numpy as np


from jaxquantum.core.qarray import Qarray, DIMS_TYPE


config.update("jax_enable_x64", True)

# Convert between QuTiP and JAX
# ===============================================================
def qt2jqt(qt_obj, dtype=jnp.complex128):
"""QuTiP state -> Qarray.
Args:
qt_obj: QuTiP state.
dtype: JAX dtype.
Returns:
Qarray.
"""
if isinstance(qt_obj, Qarray) or qt_obj is None:
return qt_obj
return Qarray.create(jnp.array(qt_obj, dtype=dtype), dims=qt_obj.dims)


def jqt2qt(jqt_obj):
"""Qarray -> QuTiP state.
Args:
jqt_obj: Qarray.
dims: QuTiP dims.
Returns:
QuTiP state.
"""
if isinstance(jqt_obj, Qobj) or jqt_obj is None:
return jqt_obj

return Qobj(np.array(jqt_obj.data), dims=jqt_obj.dims)

def jnp2jqt(arr: Array, dims: Optional[DIMS_TYPE] = None):
"""JAX array -> QuTiP state.
Args:
jnp_obj: JAX array.
Returns:
QuTiP state.
"""
return Qarray.create(arr, dims=dims)


def jnps2jqts(arrs: Array, dims: Optional[DIMS_TYPE] = None):
"""JAX array -> QuTiP state.
Args:
jnp_obj: JAX array.
Returns:
QuTiP state.
"""
return [Qarray.create(arr, dims=dims) for arr in arrs]

def jqts2jnps(qarrs: Qarray):
"""QuTiP state -> JAX array.
Args:
qt_obj: QuTiP state.
Returns:
JAX array.
"""
return jnp.array([qarr.data for qarr in qarrs])
151 changes: 151 additions & 0 deletions jaxquantum/core/operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
""" States. """

from jax import config

import jax.numpy as jnp
from jax.nn import one_hot

from jaxquantum.core.qarray import Qarray

config.update("jax_enable_x64", True)



def sigmax() -> Qarray:
"""σx
Returns:
σx Pauli Operator
"""
return Qarray.create(jnp.array([[0.0, 1.0], [1.0, 0.0]]))


def sigmay() -> Qarray:
"""σy
Returns:
σy Pauli Operator
"""
return Qarray.create(jnp.array([[0.0, -1.0j], [1.0j, 0.0]]))


def sigmaz() -> Qarray:
"""σz
Returns:
σz Pauli Operator
"""
return Qarray.create(jnp.array([[1.0, 0.0], [0.0, -1.0]]))


def hadamard() -> Qarray:
"""H
Returns:
H: Hadamard gate
"""
return Qarray.create(jnp.array([[1, 1], [1, -1]]) / jnp.sqrt(2))

def sigmam() -> Qarray:
"""σ-
Returns:
σ- Pauli Operator
"""
return Qarray.create(jnp.array([[0.0, 0.0], [1.0, 0.0]]))


def sigmap() -> Qarray:
"""σ+
Returns:
σ+ Pauli Operator
"""
return Qarray.create(jnp.array([[0.0, 1.0], [0.0, 0.0]]))


def destroy(N) -> Qarray:
"""annihilation operator
Args:
N: Hilbert space size
Returns:
annilation operator in Hilber Space of size N
"""
return Qarray.create(jnp.diag(jnp.sqrt(jnp.arange(1, N)), k=1))


def create(N) -> Qarray:
"""creation operator
Args:
N: Hilbert space size
Returns:
creation operator in Hilber Space of size N
"""
return Qarray.create(jnp.diag(jnp.sqrt(jnp.arange(1, N)), k=-1))


def num(N) -> Qarray:
"""Number operator
Args:
N: Hilbert Space size
Returns:
number operator in Hilber Space of size N
"""
return Qarray.create(jnp.diag(jnp.arange(N)))


def identity(*args, **kwargs) -> Qarray:
"""Identity matrix.
Returns:
Identity matrix.
"""
return Qarray.create(jnp.eye(*args, **kwargs))


def displace(N, α) -> Qarray:
"""Displacement operator
Args:
N: Hilbert Space Size
α: Phase space displacement
Returns:
Displace operator D(α)
"""
a = destroy(N)
return (α * a.dag() - jnp.conj(α) * a).expm()


# States ---------------------------------------------------------------------

def basis(N, k):
"""Creates a |k> (i.e. fock state) ket in a specified Hilbert Space.
Args:
N: Hilbert space dimension
k: fock number
Returns:
Fock State |k>
"""
return Qarray.create(one_hot(k, N).reshape(N, 1))


def coherent(N, α) -> Qarray:
"""Coherent state.
Args:
N: Hilbert Space Size.
α: coherent state amplitude.
Return:
Coherent state |α⟩.
"""
return displace(N, α) @ basis(N, 0)
Loading

0 comments on commit 288e133

Please sign in to comment.