Numerical improvements to correlation bijectors #313
`@grad_from_chainrules` can't handle multi-output functions, see JuliaDiff/ReverseDiff.jl#221. In this case it can AD through the primal just fine.
Also use consistent notation with inverse transform
This reverts commit bd6ff3d.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This will be really nice @sethaxen :)
I've added a few comments. It also seems as if the chain rule is somehow not type stable? 😕
@@ -268,7 +268,6 @@ end
@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_upper(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_lower(x::TrackedMatrix)
@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector)
Also, is this intentional?
Yes, the problem is that neither `@grad` nor `@grad_from_chainrules` supports multi-output functions (JuliaDiff/ReverseDiff.jl#221), so we cannot use this macro. At the same time, everything in the function should be AD-able by ReverseDiff, so I just removed the rule.
However, we have the same problem with Tracker. `Tracker.@grad` does not seem to support multi-output functions, and I'm still working out how to AD through the primal (I have an idea for a fix).
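For readers unfamiliar with the issue, here is a minimal sketch of what a hand-written pullback for a two-output function has to do. The function `square_and_logsum` and its pullback are hypothetical, written in plain Julia without ReverseDiff, Tracker, or ChainRules; they are not the rules from this PR. The point is only that the pullback receives one cotangent per output and must combine them into a single cotangent for the input.

```julia
# Hypothetical two-output function (not part of Bijectors.jl):
# returns the elementwise square of `x` and the sum of `log.(x)`,
# together with a hand-written pullback.
function square_and_logsum(x::AbstractVector{<:Real})
    y = x .^ 2
    s = sum(log, x)
    function pullback(cotangents)
        # One cotangent per output; split them by indexing.
        dy = cotangents[1]
        ds = cotangents[2]
        # Gradient of dot(dy, x .^ 2) + ds * sum(log, x) with respect to x.
        return 2 .* x .* dy .+ ds ./ x
    end
    return (y, s), pullback
end

x = [1.0, 2.0]
(y, s), pullback = square_and_logsum(x)
dx = pullback(([1.0, 1.0], 1.0))
```

A macro that assumes a single output has no obvious way to represent the `(dy, ds)` pair, which is roughly the limitation discussed in JuliaDiff/ReverseDiff.jl#221.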
If you have an idea how to get this working for ReverseDiff, let me know. It would be great to use the manual pullback.
Nope, my Tracker idea did not work.
Nevermind, it works!
Lovely:)
Actually, how did you achieve it?
The changes are in this commit: a2eac95. In Tracker, the cotangent of multi-output functions ends up being a `TrackedTuple`, which doesn't support iteration, so I instead use indexing to split the tuple. I also made `pd_from_lower` and `pd_from_upper` use the same tricks as `cholesky_upper` and `cholesky_lower`.
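The iteration-versus-indexing point can be illustrated without Tracker. In the sketch below, `OpaquePair` is a hypothetical stand-in for a tuple-like wrapper that defines `getindex` but not `iterate`; it is not Tracker's `TrackedTuple`, and the helper names are made up for illustration. Destructuring (`a, b = p`) goes through `iterate`, so it fails, while explicit indexing works.

```julia
# Hypothetical wrapper that supports indexing but not iteration.
struct OpaquePair{T<:Tuple}
    data::T
end
Base.getindex(p::OpaquePair, i::Int) = p.data[i]

# Destructuring relies on `iterate`, which `OpaquePair` does not define.
function split_by_destructuring(p)
    a, b = p
    return (a, b)
end

# Indexing only needs `getindex`.
split_by_indexing(p) = (p[1], p[2])

p = OpaquePair((1.0, 2.0))
parts = split_by_indexing(p)

# Record whether destructuring fails with a MethodError.
destructuring_failed = try
    split_by_destructuring(p)
    false
catch err
    err isa MethodError
end
```

Under these assumptions, replacing `a, b = p` with `p[1]` and `p[2]` inside a pullback is enough to avoid the missing `iterate` method.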
Uuuuh I didn't know that that was the reason why it was an issue! Dopey:)
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Remaining errors seem to be ones introduced in #304 and unrelated to this PR.
Uhmm that's strange o.O Don't understand why this wasn't failing in the original PR. Ooor it might be because it hit the cholesky error and thus didn't run the interface tests on 1.6. Should be a quick Compat.jl inclusion though; lemme have a check.
Lovely stuff:) Happy with merging as soon as tests pass (which should be a quick merge with master after #314 )
It seems #314 does not address the issue on Julia 1.6.
I don't think #314 was completed before merging. Its CI was still failing with a similar error.
@sethaxen I just pushed the fix directly to this branch. Let's see if CI succeeds now:)
Damn, even this doesn't work because Really sorry about this @sethaxen ; this bug was hidden behind an unrelated numerical issue that caused this particular test to never be run on 1.6. But the cause of this shouldn't have been merged. I'll just disable those tests on this PR and then we'll have to fix it in a separate PR.
Seems like it worked! EDIT: and, no problem, @torfjelde !
Lovely! Feel free to hit the big green button:)
Sadly, I am not an "authorized user," and the button is gray.
Want me to do it then? Also happy to give you authorization, given your involvement in Bijectors.jl, if you want:)
Sure to both!
Done:) Wonderful stuff @sethaxen ; thanks!
This PR implements the numerical suggestions in #301. It does not make any of the suggested renaming changes, which will be left for a future PR.