Reverse-mode AD for SDEs #1084

Open

elgazzarr opened this issue Jul 22, 2024 · 3 comments

@elgazzarr
Reverse-mode adjoints for SDEs only work with `TrackerAdjoint()`, and only on CPU. 🐞

Training large (e.g., neural) SDEs on GPUs fails. The only working configuration is `TrackerAdjoint()`, and it currently works only on CPU.
None of the continuous adjoint methods, e.g. `InterpolatingAdjoint()` or `BacksolveAdjoint()`, work on either CPU or GPU.

  • I suspect the problem with the continuous methods is the shape of the noise during the backward solve.
  • With `TrackerAdjoint()` on GPUs, something is transferred to the CPU during the backward pass; the stack trace below shows `accum!` receiving a CPU `Matrix{Float32}` gradient for a `CuArray`-backed tracked value. This also happens for ODEs, by the way.

MWE

using DifferentialEquations, Lux, ComponentArrays, Random, SciMLSensitivity, Zygote,
      BenchmarkTools, LuxCUDA, CUDA, Optimization, OptimizationOptimisers

dev = gpu_device()
sensealg = TrackerAdjoint()  # this currently works only on CPU

# Synthetic data: 32 features × 100 time points × 512 samples
data = rand32(32, 100, 512) |> dev
x₀ = rand32(32, 512) |> dev
ts = range(0.0f0, 1.0f0, length = 100)
drift = Dense(32, 32, tanh)
diffusion = Scale(32, sigmoid)

basic_tgrad(u, p, t) = zero(u)

struct NeuralSDE{D, F} <: Lux.AbstractExplicitContainerLayer{(:drift, :diffusion)}
    drift::D
    diffusion::F
    solver
    tspan
    sensealg
end

function (model::NeuralSDE)(x₀, ts, p, st)
    μ(u, p, t) = model.drift(u, p.drift, st.drift)[1]
    σ(u, p, t) = model.diffusion(u, p.diffusion, st.diffusion)[1]
    func = SDEFunction{false}(μ, σ; tgrad = basic_tgrad)
    prob = SDEProblem{false}(func, x₀, model.tspan, p)
    sol = solve(prob, model.solver; saveat = ts, dt = 0.01f0, sensealg = model.sensealg)
    # Stack the saved states into a (state × time × sample) array
    return permutedims(cat(sol.u..., dims = 3), (1, 3, 2))
end

function loss!(p, data)
    pred = model(x₀, ts, p, st)
    l = sum(abs2, data .- pred)
    return l, st, pred
end

rng = Random.default_rng()
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)
p, st = Lux.setup(rng, model)
p = p |> ComponentArray{Float32} |> dev

adtype = AutoZygote()
optf = OptimizationFunction((p, _) -> loss!(p, data), adtype)
optproblem = OptimizationProblem(optf, p)
result = Optimization.solve(optproblem, ADAMW(5e-4), maxiters = 10)
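
For reference, the only combination that trains (per the report above) is the discrete Tracker adjoint on CPU. A hedged sketch of that swap, assuming `cpu_device()` comes from the same device utilities that provide `gpu_device()`:

dev = cpu_device()           # keep data and parameters on the CPU
sensealg = TrackerAdjoint()  # discrete adjoint through the solver via Tracker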

Error & Stacktrace

ERROR: LoadError: GPU compilation of MethodInstance for (::GPUArrays.var"#35#37")(::CUDA.CuKernelContext, ::CuDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
      .x is of type Matrix{Float32} which is not isbits.


Stacktrace:
    [1] check_invocation(job::GPUCompiler.CompilerJob)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/validation.jl:92
    [2] macro expansion
      @ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:128 [inlined]
    [3] macro expansion
      @ ~/.julia/packages/TimerOutputs/Lw5SP/src/TimerOutput.jl:253 [inlined]
    [4] 
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:126
    [5] 
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:111
    [6] compile
      @ ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:103 [inlined]
    [7] #1145
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:254 [inlined]
    [8] JuliaContext(f::CUDA.var"#1145#1148"{GPUCompiler.CompilerJob{}}; kwargs::@Kwargs{})
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:52
    [9] JuliaContext(f::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/driver.jl:42
   [10] compile(job::GPUCompiler.CompilerJob)
      @ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/compilation.jl:253
   [11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:237
   [12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/Y4hSX/src/execution.jl:151
   [13] macro expansion
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:369 [inlined]
   [14] macro expansion
      @ ./lock.jl:267 [inlined]
   [15] cufunction(f::GPUArrays.var"#35#37", tt::Type{Tuple{…}}; kwargs::@Kwargs{})
      @ CUDA ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:364
   [16] cufunction
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:361 [inlined]
   [17] macro expansion
      @ ~/.julia/packages/CUDA/Tl08O/src/compiler/execution.jl:112 [inlined]
   [18] #launch_heuristic#1204
      @ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:17 [inlined]
   [19] launch_heuristic
      @ ~/.julia/packages/CUDA/Tl08O/src/gpuarrays.jl:15 [inlined]
   [20] _copyto!
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:78 [inlined]
   [21] copyto!
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:44 [inlined]
   [22] copy
      @ ~/.julia/packages/GPUArrays/8Y80U/src/host/broadcast.jl:29 [inlined]
   [23] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{…}, Nothing, typeof(+), Tuple{…}})
      @ Base.Broadcast ./broadcast.jl:903
   [24] accum!(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/params.jl:46
   [25] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:134
   [26] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [27] #64
      @ ./tuple.jl:628 [inlined]
   [28] BottomRF
      @ ./reduce.jl:86 [inlined]
   [29] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
      @ Base ./reduce.jl:58
   [30] foldl_impl
      @ ./reduce.jl:48 [inlined]
   [31] mapfoldl_impl
      @ ./reduce.jl:44 [inlined]
   [32] mapfoldl
      @ ./reduce.jl:175 [inlined]
   [33] foldl
      @ ./reduce.jl:198 [inlined]
   [34] foreach
      @ ./tuple.jl:628 [inlined]
   [35] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#583#584"{…}, Tuple{…}}, Δ::Vector{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
   [36] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 1, CUDA.DeviceMemory}}, Δ::Vector{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
   [37] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [38] #64
      @ ./tuple.jl:628 [inlined]
   [39] BottomRF
      @ ./reduce.jl:86 [inlined]
   [40] _foldl_impl
      @ ./reduce.jl:58 [inlined]
   [41] foldl_impl
      @ ./reduce.jl:48 [inlined]
   [42] mapfoldl_impl(f::typeof(identity), op::Base.var"#64#65"{}, nt::Nothing, itr::Base.Iterators.Zip{…})
      @ Base ./reduce.jl:44
   [43] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{…}, Tuple{…}}}; init::Nothing)
      @ Base ./reduce.jl:175
   [44] mapfoldl
      @ ./reduce.jl:175 [inlined]
   [45] foldl
      @ ./reduce.jl:198 [inlined]
   [46] foreach(::Function, ::Tuple{Tracker.Tracked{…}, Tracker.Tracked{…}}, ::Tuple{Vector{…}, Vector{…}})
      @ Base ./tuple.jl:628
   [47] back_(g::Tracker.Grads, c::Tracker.Call{Tracker.var"#552#555"{…}, Tuple{…}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
   [48] back(g::Tracker.Grads, x::Tracker.Tracked{CuArray{Float32, 2, CUDA.DeviceMemory}}, Δ::Matrix{Float32})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
   [49] #710
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:128 [inlined]
   [50] #64
      @ ./tuple.jl:628 [inlined]
   [51] BottomRF
      @ ./reduce.jl:86 [inlined]
   [52] _foldl_impl(op::Base.BottomRF{Base.var"#64#65"{…}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{…}})
      @ Base ./reduce.jl:58
--- the last 12 lines are repeated 98 more times ---
 [1229] foldl_impl
      @ ./reduce.jl:48 [inlined]
 [1230] mapfoldl_impl
      @ ./reduce.jl:44 [inlined]
 [1231] mapfoldl
      @ ./reduce.jl:175 [inlined]
 [1232] foldl
      @ ./reduce.jl:198 [inlined]
 [1233] foreach
      @ ./tuple.jl:628 [inlined]
 [1234] back_(g::Tracker.Grads, c::Tracker.Call{…}, Δ::RODESolution{…})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:128
 [1235] back(g::Tracker.Grads, x::Tracker.Tracked{…}, Δ::RODESolution{…})
      @ Tracker ~/.julia/packages/Tracker/NYUWw/src/back.jl:140
 [1236] #712
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:155 [inlined]
 [1237] #715
      @ ~/.julia/packages/Tracker/NYUWw/src/back.jl:164 [inlined]
 [1238] (::SciMLSensitivity.var"#tracker_adjoint_backpass#368"{})(ybar::RODESolution{…})
      @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4hOeN/src/concrete_solve.jl:1319
 [1239] ZBack
      @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [1240] (::Zygote.var"#kw_zpullback#53"{})(dy::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [1241] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1242] (::Zygote.var"#2169#back#293"{})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [1243] #solve#51
      @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
 [1244] (::Zygote.Pullback{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1245] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1246] (::Zygote.var"#2169#back#293"{})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [1247] solve
      @ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
 [1248] (::Zygote.Pullback{…})(Δ::RODESolution{…})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1249] NeuralSDE
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:31 [inlined]
 [1250] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::CuArray{Float32, 3, CUDA.DeviceMemory})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1251] loss!
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:36 [inlined]
 [1252] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1253] #39
      @ ~/code/NeuroDynamics.jl/examples/mwe.jl:48 [inlined]
 [1254] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1255] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1256] OptimizationFunction
      @ ~/.julia/packages/SciMLBase/rR75x/src/scimlfunctions.jl:3763 [inlined]
 [1257] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1258] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1259] #37
      @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:94 [inlined]
 [1260] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1261] #291
      @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [1262] #2169#back
      @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [1263] #39
      @ ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97 [inlined]
 [1264] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [1265] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [1266] gradient(f::Function, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{…}}})
      @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [1267] (::OptimizationZygoteExt.var"#38#56"{})(::ComponentVector{…}, ::ComponentVector{…})
      @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/mGHPN/ext/OptimizationZygoteExt.jl:97
 [1268] macro expansion
      @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [1269] macro expansion
      @ ~/.julia/packages/Optimization/fPKIF/src/utils.jl:32 [inlined]
 [1270] __solve(cache::OptimizationCache{…})
      @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [1271] solve!(cache::OptimizationCache{…})
      @ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:188
 [1272] solve(::OptimizationProblem{…}, ::OptimiserChain{…}; kwargs::@Kwargs{})
      @ SciMLBase ~/.julia/packages/SciMLBase/rR75x/src/solve.jl:96
in expression starting at /home/artiintel/ahmelg/code/NeuroDynamics.jl/examples/mwe.jl:50
Some type information was truncated. Use `show(err)` to see complete types.

I am using the latest releases of these packages and Julia 1.10.4.

@ChrisRackauckas
Member

I think this is related to the ComponentArrays issue we just found. @avik-pal is looking into it.

@elgazzarr
Author

Hi, are there any updates on this?
Or any leads that I can follow to attempt to fix it myself?

@ChrisRackauckas
Member

I cannot debug this until later this month since I don't have an NVIDIA GPU on me. But the likely answer is that a discrete adjoint is simply better here, since a continuous adjoint requires nested AD in order to handle the Itô correction. This would be much simpler in the Stratonovich sense, because for Itô SDEs the adjoint rule needs to compute the derivative natively in the Stratonovich form and then convert it, which requires AD of the AD. For this reason, a lot of adjoints can have issues with the nesting.
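
For context, the standard (scalar) Itô-to-Stratonovich conversion makes the nesting visible: the Itô SDE $dX_t = \mu(X_t)\,dt + \sigma(X_t)\,dW_t$ is equivalent to the Stratonovich SDE

$$dX_t = \left(\mu(X_t) - \tfrac{1}{2}\,\sigma(X_t)\,\partial_x\sigma(X_t)\right)dt + \sigma(X_t)\circ dW_t,$$

so the converted drift already contains one derivative of $\sigma$, and reverse-mode differentiating it means differentiating through a derivative, i.e., nested AD.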

But there have been many advancements here. First of all, continuous adjoints using Enzyme VJPs should now work for this, so that's worth a try. Secondly, discrete adjoints of non-stiff ODE solvers with Enzyme applied directly to the solver just landed (SciML/OrdinaryDiffEq.jl#2282), and we could do a similar thing for SDEs to support discrete Enzyme adjoints and add an `EnzymeAdjoint` option. That would likely do better than Tracker, though it is somewhat dependent on EnzymeAD/Enzyme.jl#2077 being solved in order for it to work on v1.11.
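
A minimal sketch of that first suggestion against the MWE above, using SciMLSensitivity's `EnzymeVJP` vector-Jacobian product choice (hedged; untested on this MWE):

# Continuous adjoint with Enzyme VJPs, swapped into the MWE:
sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())
# or the backsolve variant:
# sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP())
model = NeuralSDE(drift, diffusion, EM(), (0.0f0, 1.0f0), sensealg)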

Or any leads that I can follow to attempt to fix it myself?

The most helpful thing would likely be to get discrete adjoints with Enzyme working in StochasticDiffEq.jl. I just opened #1148 with the details of the steps; see "Supporting EnzymeAdjoint for SDEs". If you can help on that part, you'll get the gold you seek.
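
To make the target concrete: the discrete-adjoint approach amounts to Enzyme reverse-differentiating the solver's stepping loop. A toy, hedged illustration with a hand-rolled Euler–Maruyama loop over pre-sampled noise (not the StochasticDiffEq internals):

using Enzyme

# Scalar Euler–Maruyama for du = p[1]*u*dt + p[2]*u*dW, returning a loss
function em_loss(p, u0, dW, dt)
    u = u0
    for i in eachindex(dW)
        u = u + p[1] * u * dt + p[2] * u * dW[i]
    end
    return abs2(u)
end

dt = 0.01f0
dW = sqrt(dt) .* randn(Float32, 100)   # pre-sampled Brownian increments
p  = [0.5f0, 0.1f0]
dp = Enzyme.make_zero(p)               # gradient buffer, same shape as p

# Reverse-mode through the whole loop; dp accumulates ∂loss/∂p
Enzyme.autodiff(Reverse, em_loss, Active,
                Duplicated(p, dp), Const(1.0f0), Const(dW), Const(dt))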
