diff --git a/src/schroedinger.jl b/src/schroedinger.jl index ede39534..63978d09 100644 --- a/src/schroedinger.jl +++ b/src/schroedinger.jl @@ -56,7 +56,13 @@ end function schroedinger_dynamic(tspan, psi0::T, H::AbstractTimeDependentOperator; kwargs...) where {B,Bp,T<:Union{AbstractOperator{B,Bp},StateVector{B}}} - schroedinger_dynamic(tspan, psi0, schroedinger_dynamic_function(H); kwargs...) + promoted_tspan, psi0 = _promote_time_and_state(psi0, H, tspan) + if promoted_tspan !== tspan # promote H + promoted_H = TimeDependentSum(H.coefficients, H.static_op.operators; init_time=first(promoted_tspan)) + return schroedinger_dynamic(promoted_tspan, psi0, schroedinger_dynamic_function(promoted_H); kwargs...) + else + return schroedinger_dynamic(promoted_tspan, psi0, schroedinger_dynamic_function(H); kwargs...) + end end """ diff --git a/test/test_ForwardDiff.jl b/test/test_ForwardDiff.jl index 01d8135c..91b889fc 100644 --- a/test/test_ForwardDiff.jl +++ b/test/test_ForwardDiff.jl @@ -50,3 +50,26 @@ for u0 = (psi, psi', psi⊗psi') # test all methods of `rebuild` end end # testset + +@testset "ForwardDiff with schroedinger using TimeDependentSum" begin + +base=SpinBasis(1/2) +ψi = spinup(base) +ψt = spindown(base) +function Ftdop(q) + H = TimeDependentSum([q, abs2∘sinpi], [sigmaz(base), sigmax(base)]) + _, ψf = timeevolution.schroedinger_dynamic(range(0,1,2), ψi, H) + abs2(ψt'last(ψf)) +end +Ftdop(1.0) +@test ForwardDiff.derivative(Ftdop, 1.0) isa Any + +function Ftdop(q) + H = TimeDependentSum([1, abs2∘sinpi], [sigmaz(base), q*sigmax(base)]) + _, ψf = timeevolution.schroedinger_dynamic(range(0,1,2), ψi, H) + abs2(ψt'last(ψf)) +end +Ftdop(1.0) +@test ForwardDiff.derivative(Ftdop, 1.0) isa Any + +end # testset