Reverse adjoints for SDEs only work with `TrackerAdjoint()`, and only on CPU 🐞
Training large (e.g., neural) SDEs on GPUs fails. The only working option is `TrackerAdjoint()`, and it currently works only on CPU.
None of the continuous adjoint methods, e.g. `InterpolatingAdjoint()` or `BacksolveAdjoint()`, work on either CPU or GPU.
I suspect the problem with the continuous methods is the shape of the noise during the backwards solve (see the sketch below).
W.r.t. `TrackerAdjoint()` on GPUs, something is transferred to the CPU during the backwards pass. This also happens for ODEs, by the way.
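To isolate the continuous-adjoint failure from the GPU question, a minimal CPU-only probe might look like the following. This is a sketch, not verified: the drift/diffusion functions and sizes here are made up, but it exercises `BacksolveAdjoint()` through the same `Zygote.gradient` path as the full MWE.

```julia
# Minimal CPU probe for the continuous SDE adjoints (hypothetical repro sketch).
using StochasticDiffEq, SciMLSensitivity, Zygote

f(u, p, t) = p .* u          # drift
g(u, p, t) = 0.1f0 .* u      # diffusion returns the same shape as u => diagonal noise
u0 = rand(Float32, 4)
p0 = rand(Float32, 4)
prob = SDEProblem{false}(f, g, u0, (0.0f0, 1.0f0), p0)

# If the noise shape in the backwards solve is the culprit, this should already error:
Zygote.gradient(p0) do p
    sol = solve(prob, EM(); dt=0.01f0, p=p, saveat=0.1f0,
                sensealg=BacksolveAdjoint())
    sum(abs2, Array(sol))
end
```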
MWE
```julia
using DifferentialEquations, Lux, ComponentArrays, Random, SciMLSensitivity, Zygote,
      BenchmarkTools, LuxCUDA, CUDA, Optimization, OptimizationOptimisers

dev = gpu_device()
sensealg = TrackerAdjoint()          # this works, but only on CPU

data = rand32(32, 100, 512) |> dev   # (features, timepoints, batch)
x₀ = rand32(32, 512) |> dev          # initial condition: (features, batch)
ts = range(0.0f0, 1.0f0, length=100)

drift = Dense(32, 32, tanh)
diffusion = Scale(32, sigmoid)
basic_tgrad(u, p, t) = zero(u)

struct NeuralSDE{D, F} <: Lux.AbstractExplicitContainerLayer{(:drift, :diffusion)}
    drift::D
    diffusion::F
    solver
    tspan
    sensealg
end

function (model::NeuralSDE)(x₀, ts, p, st)
    μ(u, p, t) = model.drift(u, p.drift, st.drift)[1]
    σ(u, p, t) = model.diffusion(u, p.diffusion, st.diffusion)[1]
    func = SDEFunction{false}(μ, σ; tgrad=basic_tgrad)
    prob = SDEProblem{false}(func, x₀, model.tspan, p)
    sol = solve(prob, model.solver; saveat=ts, dt=0.01f0, sensealg=model.sensealg)
    return permutedims(cat(sol.u..., dims=3), (1, 3, 2))  # (features, timepoints, batch)
end

function loss!(p, data)
    pred = model(x₀, ts, p, st)
    l = sum(abs2, data .- pred)
    return l, st, pred
end

rng = Random.default_rng()
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
p, st = Lux.setup(rng, model)
p = p |> ComponentArray{Float32} |> dev

adtype = AutoZygote()
optf = OptimizationFunction((p, _) -> loss!(p, data), adtype)
optproblem = OptimizationProblem(optf, p)
result = Optimization.solve(optproblem, ADAMW(5e-4), maxiters=10)
```

I am using the latest releases for the packages and Julia 1.10.4.
I cannot debug this until later this month since I don't have an NVIDIA GPU on me. But likely the issue is that a discrete adjoint is just better here, since a continuous adjoint requires nested AD in order to handle the Itô correction. This would be much simpler in the Stratonovich sense, because for an Itô SDE the adjoint rule needs to compute the derivative natively in the Stratonovich form and then convert it, which requires AD of the AD. For this reason, a lot of adjoints can have some issue with the nesting.
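For the scalar / diagonal-noise case, the conversion behind that nesting is the standard Itô–Stratonovich drift correction:

```math
\mathrm{d}X = f(X)\,\mathrm{d}t + g(X)\,\mathrm{d}W
\quad\Longleftrightarrow\quad
\mathrm{d}X = \left(f(X) - \tfrac{1}{2}\,g(X)\,\frac{\partial g}{\partial x}(X)\right)\mathrm{d}t + g(X)\circ\mathrm{d}W
```

The corrected drift already contains a derivative of `g`, so the adjoint then has to differentiate through that term, i.e. differentiate a derivative.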
But there have been many advancements here. First of all, continuous nesting using Enzyme VJPs should now work for this? So that's worth a try. Secondly, discrete adjoints of non-stiff ODE solvers just landed with Enzyme directly on the solver (SciML/OrdinaryDiffEq.jl#2282), so we could do a similar thing for SDEs to support discrete Enzyme adjoints and add an `EnzymeAdjoint` option. That would likely do better than Tracker, though it's somewhat dependent on EnzymeAD/Enzyme.jl#2077 getting solved in order for it to work on v1.11.
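Concretely, trying the Enzyme VJPs just means swapping the `sensealg` in the MWE above; a sketch (untested here, which is the point):

```julia
# Continuous adjoint with Enzyme-computed vector-Jacobian products:
sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())
# or the backsolve variant:
# sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP())
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
```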
Are there any leads that I can follow to attempt to fix it myself?
The most helpful thing would likely be to get discrete adjoints with Enzyme working in StochasticDiffEq.jl. I just opened #1148 with the details of the steps there; see "Supporting EnzymeAdjoint for SDEs". If you can help on that part, you'll get the gold you seek.
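For orientation, the end state would let the MWE above switch to something like the following (hypothetical: `EnzymeAdjoint` is only the option proposed in the comments above and does not exist yet):

```julia
# Hypothetical: `EnzymeAdjoint` is proposed in #1148, not implemented yet.
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), EnzymeAdjoint())
```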