
Does ZygoteVJP() support training Neural ODE with DiscreteCallback on GPU? #1093

Open

yunan-l opened this issue Aug 23, 2024 · 3 comments

yunan-l commented Aug 23, 2024

Hi, I tried to train a Neural ODE with a DiscreteCallback using sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true) on GPU, but it errored with:

Only `ReverseDiffVJP` and `EnzymeVJP` are currently compatible with continuous adjoint sensitivity methods for hybrid DEs. Please select `ReverseDiffVJP` or `EnzymeVJP` as `autojacvec`.

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _setup_reverse_callbacks(cb::DiscreteCallback{DiffEqCallbacks.var"#109#113"{Vector{Float32}}, SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#111#115"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#112#116"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}, affect::SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#111#115"{typeof(affect!)}, Nothing, Int64}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, dgdu::Function, dgdp::Nothing, loss_ref::Base.RefValue{Int64}, terminated::Bool)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/PstNN/src/callback_tracking.jl:244
  [3] _setup_reverse_callbacks(cb::DiscreteCallback{DiffEqCallbacks.var"#109#113"{Vector{Float32}}, SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#111#115"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#112#116"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, dgdu::Function, dgdp::Nothing, loss_ref::Base.RefValue{Int64}, terminated::Bool)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/PstNN/src/callback_tracking.jl:219

So ZygoteVJP() doesn't support DiscreteCallback on GPU, right?
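For reference, a minimal sketch of the kind of setup that triggers this error. The model, callback, and tolerances here are hypothetical (reconstructed from the layer shapes in the stack trace, a 4 => 16 => 16 => 4 network), not the exact original code:

```julia
using OrdinaryDiffEq, SciMLSensitivity, DiffEqCallbacks
using Lux, LuxCUDA, ComponentArrays, Zygote, Random

rng = Random.default_rng()
dev = gpu_device()

# Hypothetical neural network as the ODE right-hand side
model = Chain(Dense(4 => 16, tanh), Dense(16 => 16, tanh), Dense(16 => 4))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps) |> dev

f(u, p, t) = first(model(u, p, st))

# Hypothetical DiscreteCallback perturbing the state at a fixed time
condition(u, t, integrator) = t in (0.5f0,)
affect!(integrator) = (integrator.u .+= 0.1f0)
cb = DiscreteCallback(condition, affect!)

u0 = rand(rng, Float32, 4) |> dev
prob = ODEProblem(f, u0, (0.0f0, 1.0f0), ps)

# This combination raises the error above: ZygoteVJP is not
# compatible with continuous adjoints for hybrid (callback) DEs.
loss(p) = sum(solve(prob, Tsit5(); p = p, callback = cb, tstops = [0.5f0],
    sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP(),
                                    checkpointing = true)))
Zygote.gradient(loss, ps)
```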

@yunan-l yunan-l changed the title Does ZygoteVJP() support training Neural ODE with Discretecallback? Does ZygoteVJP() support training Neural ODE with Discretecallback on GPU? Aug 23, 2024
ChrisRackauckas (Member) commented:

Yes, it needs to be ReverseDiffVJP or EnzymeVJP. For GPU, then, it would need to be EnzymeVJP. I don't quite know how much coverage Enzyme has on GPU now, but it should be getting close, @wsmoses @avik-pal? It would be good to have an example to work through with this.
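To illustrate the suggested change, swapping the VJP choice could look like the sketch below, assuming a hypothetical ODEProblem `prob`, callback `cb`, and parameters `ps` are already defined (names are placeholders, not from the original code):

```julia
using OrdinaryDiffEq, SciMLSensitivity

# CPU: ReverseDiffVJP supports hybrid DEs with callbacks;
# `ReverseDiffVJP(true)` enables tape compilation.
sensealg_cpu = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true),
                                    checkpointing = true)

# GPU: EnzymeVJP is the compatible choice for callback-containing problems.
sensealg_gpu = InterpolatingAdjoint(autojacvec = EnzymeVJP(),
                                    checkpointing = true)

# Hypothetical usage, with `prob`, `cb`, and `ps` defined elsewhere:
# sol = solve(prob, Tsit5(); p = ps, callback = cb, tstops = [0.5f0],
#             sensealg = sensealg_gpu)
```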

wsmoses commented Sep 2, 2024

There are some minor things like LuxDL/LuxLib.jl#148 and JuliaGPU/CUDA.jl#2471, which need to get implemented and merged, respectively; but otherwise, for CUDA, things should generally work.

avik-pal (Member) commented Sep 2, 2024

Just as a side note: even once LuxDL/LuxLib.jl#148 is merged, I need LuxDL/Lux.jl#744 (at least 2-3 weeks) to be merged before the LuxLib fixes are available to end users.
