yunan-l changed the title from "Does ZygoteVJP() support training Neural ODE with Discretecallback?" to "Does ZygoteVJP() support training Neural ODE with Discretecallback on GPU?" on Aug 23, 2024.
Right, ZygoteVJP won't work here: it needs to be ReverseDiffVJP or EnzymeVJP. On GPU, it would need to be EnzymeVJP. I don't quite know how much GPU coverage Enzyme has now, but it should be getting close. @wsmoses @avik-pal? It would be good to have an example to work through with this.
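A minimal CPU sketch of the suggested swap, assuming OrdinaryDiffEq, SciMLSensitivity, and Zygote are installed. The decay model `decay!`, the callback that halves the state at t = 1.0, and the loss are hypothetical stand-ins, not the neural ODE from the original post. The adjoint keeps `InterpolatingAdjoint(checkpointing = true)` from the question but replaces `ZygoteVJP()` with `ReverseDiffVJP()`. Per the reply above, a GPU run would use `EnzymeVJP()` instead, and its GPU coverage was still uncertain at the time.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Hypothetical toy model standing in for a neural ODE: du/dt = -p[1] * u
function decay!(du, u, p, t)
    du .= -p[1] .* u
end

u0 = [1.0, 2.0]
p = [0.5]
prob = ODEProblem(decay!, u0, (0.0, 2.0), p)

# Hypothetical DiscreteCallback: halve the state exactly at t = 1.0
condition(u, t, integrator) = t == 1.0
affect!(integrator) = (integrator.u .*= 0.5)
cb = DiscreteCallback(condition, affect!)

# Same adjoint as in the question, with ReverseDiffVJP in place of ZygoteVJP
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(), checkpointing = true)

function loss(p)
    # tstops forces the integrator to step onto t = 1.0 so the callback fires
    sol = solve(prob, Tsit5(); p = p, callback = cb, tstops = [1.0],
                saveat = 0.1, sensealg = sensealg)
    return sum(abs2, Array(sol))
end

# Zygote drives the reverse pass; SciMLSensitivity supplies the adjoint for solve
grad = Zygote.gradient(loss, p)[1]
println(grad)
```

The only change from the configuration in the question is the `autojacvec` argument.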
There are a few minor outstanding items, such as LuxDL/LuxLib.jl#148 and JuliaGPU/CUDA.jl#2471, which still need to be implemented and merged, respectively; otherwise, things should generally work for CUDA.
As a side note: even once LuxDL/LuxLib.jl#148 is merged, LuxDL/Lux.jl#744 also has to be merged (at least 2-3 weeks away) before the LuxLib fixes are available to end users.
Hi, I tried to train a Neural ODE with a DiscreteCallback using
sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
on GPU, but it threw an error. So ZygoteVJP() doesn't support DiscreteCallback on GPU, right?