Skip to content

Commit

Permalink
added warning filter
Browse files Browse the repository at this point in the history
  • Loading branch information
Phionx committed Jun 3, 2024
1 parent e89eb77 commit 6381fa0
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions jaxquantum/core/solvers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Solvers"""

from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController
from functools import partial
from jax import jit, vmap, Array
from typing import Callable, List, Optional

from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController
import diffrax

from jax import jit, vmap, Array
import jax.numpy as jnp
import warnings




from jaxquantum.core.qarray import Qarray
Expand Down Expand Up @@ -84,18 +85,20 @@ def solve(ρ0, f, t_list, args, solver_options):
stepsize_controller = PIDController(rtol=1e-6, atol=1e-6)

# solve!
sol = diffeqsolve(
term,
solver,
t0=t_list[0],
t1=t_list[-1],
dt0=t_list[1] - t_list[0],
y0=ρ0,
saveat=saveat,
stepsize_controller=stepsize_controller,
args=args,
max_steps=solver_options.get("max_steps", 100_000),
)
with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning) # NOTE: suppresses complex dtype warning in diffrax
sol = diffeqsolve(
term,
solver,
t0=t_list[0],
t1=t_list[-1],
dt0=t_list[1] - t_list[0],
y0=ρ0,
saveat=saveat,
stepsize_controller=stepsize_controller,
args=args,
max_steps=solver_options.get("max_steps", 100_000),
)

return sol

Expand Down

0 comments on commit 6381fa0

Please sign in to comment.