-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Free CuArrays in the reverse pass #1340
Conversation
Could we combine this with JuliaDiff/ChainRulesCore.jl#592 ? |
It's possible. I think that means having two distinct structs, ZygoteRuleConfig and ZygoteOnceRuleConfig or something. At present, BTW, most of these |
Or introduce another type parameter like
Do you mean the |
With something like this
a big ResNet gradient [used to]
But I don't think that fits CR's mechanism; the current struct is We could also think about changing it to |
Couldn't it be done like |
Oh right, that ought to work. Current status is that some arrays are freed too early (e.g. with Metalhead's ResNet, at addact(relu)) but it's hard to isolate. Still happens if I disable all thunks. In Zygote's tests, some failures due to too-early |
This adds:
Context
to indicate that the pullback will never be called twice -- set to true forgradient
, false forjacobian
y=f(x)
in the forward pass hasfinalize(y)
in the reverse. This increases the largest size of Flux model which can run on a given GPU.Applying such modifications everywhere led to many errors, some from rules like
y = x .+ false
which returny === x
under Zygote. So they now require a separate macro@adjoint_final
.At present this modification is applied to all CR
rrule
s. This is probably unsafe and we should revert 2524163 . Unclear how best to opt-in within ChainRules. Xref JuliaDiff/ChainRulesCore.jl#592 about the idea of a flag, but not entirely sure that's the right approach.Explicit finalising won't work well with thunks. Which doesn't matter at all yet, but might after #966.
It also does not work with second derivatives, hence is disabled. Other uses of the context flag (like testing
only_once(cfg)
& then over-writing some array) probably also need to be disabled.Needs FluxML/ZygoteRules.jl#23 so CI will fail. Locally, one failure, one failure to fail: