diff --git a/jaxquantum/core/solvers.py b/jaxquantum/core/solvers.py index 424138f..2cdb69c 100644 --- a/jaxquantum/core/solvers.py +++ b/jaxquantum/core/solvers.py @@ -1,13 +1,14 @@ """Solvers""" +from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController from functools import partial +from jax import jit, vmap, Array from typing import Callable, List, Optional - -from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController import diffrax - -from jax import jit, vmap, Array import jax.numpy as jnp +import warnings + + from jaxquantum.core.qarray import Qarray @@ -84,18 +85,20 @@ def solve(ρ0, f, t_list, args, solver_options): stepsize_controller = PIDController(rtol=1e-6, atol=1e-6) # solve! - sol = diffeqsolve( - term, - solver, - t0=t_list[0], - t1=t_list[-1], - dt0=t_list[1] - t_list[0], - y0=ρ0, - saveat=saveat, - stepsize_controller=stepsize_controller, - args=args, - max_steps=solver_options.get("max_steps", 100_000), - ) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) # NOTE: suppresses complex dtype warning in diffrax + sol = diffeqsolve( + term, + solver, + t0=t_list[0], + t1=t_list[-1], + dt0=t_list[1] - t_list[0], + y0=ρ0, + saveat=saveat, + stepsize_controller=stepsize_controller, + args=args, + max_steps=solver_options.get("max_steps", 100_000), + ) return sol