Fix nan gradients in analytical likelihood #468
Looks good; this is mostly about iterating conceptually, not about code quality.
LOGP_LB,
tt = negative_rt * epsilon + (1 - negative_rt) * rt

p = pt.maximum(ftt01w(tt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))
Quick note: it seems like we are only passing k_terms here, not actually computing it. I think we agreed to do it that way back in another iteration of trying to fix issues with this likelihood, and I think that's fine, but in this case we should make the default a bit higher than 7.
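To make the k_terms concern concrete, here is a NumPy sketch (not the PR's ftt01w) of the two standard series for the Wiener first-passage density in normalized time (drift 0, boundary separation 1, relative start point w; t stands for rt / a**2), as given by Navarro & Fuss (2009). The function names are hypothetical. At a moderate normalized time the two series agree, but at a very short one a fixed budget of 7 large-time terms is far from converged and can even return a negative density, whose log is NaN.

```python
import numpy as np


def fpt_large_time(t, w, n_terms):
    """Large-time series: pi * sum_{k=1}^{K} k * exp(-k^2 pi^2 t / 2) * sin(k pi w)."""
    k = np.arange(1, n_terms + 1)
    terms = k * np.exp(-(k**2) * np.pi**2 * t / 2.0) * np.sin(k * np.pi * w)
    return np.pi * terms.sum()


def fpt_small_time(t, w, n_terms):
    """Small-time series: (2 pi t^3)^(-1/2) * sum_{k=-K}^{K} (w + 2k) exp(-(w + 2k)^2 / (2t))."""
    k = np.arange(-n_terms, n_terms + 1)
    terms = (w + 2 * k) * np.exp(-((w + 2 * k) ** 2) / (2.0 * t))
    return terms.sum() / np.sqrt(2.0 * np.pi * t**3)


# Moderate normalized time: both series converge to the same value.
print(fpt_large_time(0.1, 0.5, 50), fpt_small_time(0.1, 0.5, 10))

# Very short normalized time: 7 large-time terms are not enough, and the
# truncated sum comes out negative (so np.log of it would be NaN).
print(fpt_large_time(0.001, 0.5, 7))
```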
Just playing around here; I'm not actually changing this.
- (v_flipped**2 * rt / 2.0)
- 2.0 * pt.log(a),
- (v_flipped**2 * tt / 2.0)
- 2.0 * pt.log(pt.maximum(epsilon, a))
Reflecting on this a bit, I think this maximum business is actually corrupting the gradients, so we should just restrict a > epsilon a priori (essentially via the prior?).
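The gradient concern can be read off the derivative: d/da of -2 * log(max(epsilon, a)) is -2/a above epsilon and exactly 0 below it, so the sampler gets no signal pushing a back up. As a sketch of the alternative, assume a hypothetical reparameterization a = epsilon + exp(log_a_raw) on an unconstrained log_a_raw (PyMC applies a similar log transform automatically to priors with positive support). It keeps a > epsilon by construction, and its gradient never vanishes:

```python
import numpy as np

EPSILON = 1e-15  # same value as the epsilon default in logp_ddm


def grad_clipped(a, eps=EPSILON):
    """d/da of -2 * log(max(eps, a)): exactly zero wherever a <= eps."""
    a = np.asarray(a, dtype=float)
    # np.maximum in the denominator avoids dividing by zero in the unused branch.
    return np.where(a > eps, -2.0 / np.maximum(a, eps), 0.0)


def grad_reparam(log_a_raw, eps=EPSILON):
    """d/d(log_a_raw) of -2 * log(a) with the hypothetical a = eps + exp(log_a_raw)."""
    log_a_raw = np.asarray(log_a_raw, dtype=float)
    a = eps + np.exp(log_a_raw)
    return -2.0 * np.exp(log_a_raw) / a


print(grad_clipped([0.0, 1e-20, 1.0]))    # zero gradient for the first two entries
print(grad_reparam([-50.0, 0.0, 3.0]))    # small but nonzero everywhere
```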
On the other hand, apart from initialization (which 1. our strategies should already avoid and 2. we can generally influence), a should basically never come close to 0, so this should basically never be the culprit...
But this did help a bit, for some reason...
- (v_flipped**2 * rt / 2.0)
- 2.0 * pt.log(a),
- (v_flipped**2 * tt / 2.0)
- 2.0 * pt.log(pt.maximum(epsilon, a))
)

checked_logp = check_parameters(logp, a >= 0, msg="a >= 0")
In the spirit of the above, this check could be a > 0, but honestly we should never really get there.
Same as above
@@ -220,7 +199,7 @@ def logp_ddm(
z: float,
t: float,
err: float = 1e-15,
k_terms: int = 20,
k_terms: int = 7,
epsilon: float = 1e-15,
I don't know what was used for testing or what is used as the actual value during inference, but I guess it is this default? The epsilon for the rt part should rather be on the order of 1e-3, or even 1e-2. If we are reusing the same epsilon in multiple places, we should probably separate them out.
I was playing around. It seems that changing k_terms to 7 did not improve speed or computation.
src/hssm/likelihoods/analytical.py
@@ -262,15 +241,17 @@ def logp_ddm(
z_flipped = pt.switch(flip, 1 - z, z) # transform z if x is upper-bound response
rt = rt - t

p = pt.maximum(ftt01w(rt, a, z_flipped, err, k_terms), pt.exp(LOGP_LB))
negative_rt = rt <= epsilon
OK, reflecting on this a bit, the logic we want should probably look something like:
- flag all rts lower than epsilon
- go through with ftt01w
- then set the logp of all flagged rts to LOGP_LB
This should actually cut the gradient for problematic rts. Potentially we put this in a logp_ddm_2 and compare results/gradients.
Alternatively, if any rt breaches epsilon, send logp directly to -inf (this is probably not preferable).
We were doing this. I think the problem is that the gradient is computed anyway, and the over/underflow was still happening.
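The reply above describes the classic problem with masking only the output: in reverse mode the derivative of the masked branch is still evaluated at the raw input, and 0 * nan (or 0 * inf) is nan. A common remedy, sometimes called the "double where" trick, is to also replace problematic inputs with a safe value before the expensive function runs, which is what the tt = negative_rt * epsilon + (1 - negative_rt) * rt line in this PR appears to aim for. The NumPy sketch below imitates reverse mode by hand on a simplified stand-in for log(ftt01w) (leading small-time term, constants dropped, w = 0.5 assumed); LOGP_LB_SKETCH is a hypothetical value, not the module's LOGP_LB.

```python
import numpy as np

LOGP_LB_SKETCH = -66.1  # hypothetical lower bound for this sketch only


def log_density(tt, w=0.5):
    # Stand-in for log(ftt01w): leading small-time term, constants dropped.
    return -1.5 * np.log(tt) - w**2 / (2.0 * tt)


def dlog_density(tt, w=0.5):
    # Hand-written derivative of log_density with respect to tt.
    return -1.5 / tt + w**2 / (2.0 * tt**2)


epsilon = 1e-3
rt = np.array([0.0, 0.3])  # rt - t; the first entry sits exactly at 0
flagged = rt <= epsilon

# Reverse mode through where(flagged, LB, log_density(x)) sends an upstream
# gradient of 0 to log_density for flagged entries and 1 for the rest.
upstream = np.where(flagged, 0.0, 1.0)

with np.errstate(divide="ignore", invalid="ignore"):
    # Output-only masking: the local derivative at rt = 0 is -inf + inf = nan,
    # and 0 * nan is still nan.
    naive_grad = upstream * dlog_density(rt)

# Input masking as well: swap in a safe value before evaluating anything.
tt = np.where(flagged, epsilon, rt)
logp = np.where(flagged, LOGP_LB_SKETCH, log_density(tt))
safe_grad = upstream * dlog_density(tt)

print(naive_grad, safe_grad, logp)
```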
@digicosmos86 is this stale for now? There doesn't seem to be a solution for really small RTs in the denominator, which can blow up.
Ah well, maybe this is related to my "RT hack" I proposed as an interim solution for cases where *t* is low. (I attributed that to the possibility that the sampler would end up proposing values of t that hit the lower bound and lead to unstable gradients, but I know that PyMC is supposed to deal with such boundaries smoothly under the hood, so maybe the issue is just the small RTs in the denominator.) In that case the RT hack would still work: under the hood, before fitting, just add a constant value to all RTs (say 0.5), which should only shift the t parameter, and then report t_new = t - 0.5.
(In reply to Paul Xu, Tue, Jul 9, 2024, 8:34 AM)
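The bookkeeping for the RT hack can be sketched as follows; the helper names and RT_SHIFT are hypothetical. Because logp_ddm uses rt only through rt - t, shifting every RT by c corresponds to fitting t' = t + c, so posterior draws of t' are mapped back by subtracting c. If the prior on t is not shifted as well, the shifted model differs from the original only through that prior.

```python
import numpy as np

RT_SHIFT = 0.5  # hypothetical constant; 0.5 is the value floated above


def shift_rts(rt, shift=RT_SHIFT):
    """Add the same constant to every RT before fitting."""
    return np.asarray(rt, dtype=float) + shift


def unshift_t(t_samples, shift=RT_SHIFT):
    """Map posterior draws of t (fitted on shifted RTs) back to the original scale."""
    return np.asarray(t_samples, dtype=float) - shift


rt = np.array([0.21, 0.35, 0.80])
t = 0.2
# The likelihood only sees rt - t, which is unchanged when t is shifted too.
print(shift_rts(rt) - (t + RT_SHIFT), rt - t)
print(unshift_t([0.7, 0.75]))
```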
@frankmj I ran a few more tests and the RT hack did do the trick. It might be hard for us to implement this trick in our code, though, mostly because people use ArviZ functions instead of the convenience functions we provide, which could give us some control over the output. We could note this trick somewhere in our documentation so that users can implement it themselves and keep full control.
Great. Maybe one simple solution would be to just add a link function with t = t' - const?
(In reply to Paul Xu, Tue, Jul 9, 2024, 9:47 AM)
@frankmj That's a great idea! I also noticed that the RT-hack only worked when |
A few things here I might end up picking up myself.
_a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1
_b = 2 + pt.sqrt(-2 * rt * pt.log(2 * pt.sqrt(2 * np.pi * rt) * err))
_c = pt.sqrt(rt) + 1
The fundamental operation is pt.sqrt(rt). It's better to compute it first and reuse the result rather than computing it again.
For numerical stability, it's better to group the constant factor C = 2 * pt.sqrt(2 * np.pi) * err and compare each element of sqrt_rt = pt.sqrt(rt) against 1/C.
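A NumPy sketch of both suggestions for the small-time term count (reuse sqrt_rt, fold the constant into C), following the Navarro & Fuss (2009) rule that the snippet implements. It also substitutes a harmless log argument in the entries where the bound is not used, so those entries never produce a NaN. The function name is hypothetical, and rt is assumed strictly positive.

```python
import numpy as np


def n_terms_small_time(rt, err):
    """Number of small-time terms (Navarro & Fuss, 2009 rule); assumes rt > 0."""
    rt = np.asarray(rt, dtype=float)
    sqrt_rt = np.sqrt(rt)                 # fundamental operation, computed once
    C = 2.0 * np.sqrt(2.0 * np.pi) * err  # constant factor grouped up front
    mask = sqrt_rt < 1.0 / C              # same test as 2 * sqrt(2*pi*rt) * err < 1
    # Where mask is False the log argument is >= 1 and the sqrt argument would be
    # negative; substitute a safe value so those (unused) entries stay finite.
    log_arg = np.where(mask, C * sqrt_rt, 0.5)
    bound = 2.0 + np.sqrt(-2.0 * rt * np.log(log_arg))
    bound = np.maximum(bound, sqrt_rt + 1.0)
    return np.where(mask, bound, 2.0)


print(n_terms_small_time([0.01, 0.1, 0.5, 1.0, 3.0], 1e-7))
```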
Sure! Feel free to change this
ks = 2 + pt.sqrt(-2 * rt * pt.log(2 * np.sqrt(2 * np.pi * rt) * err))
ks = pt.max(pt.stack([ks, pt.sqrt(rt) + 1]), axis=0)
ks = pt.switch(2 * pt.sqrt(2 * np.pi * rt) * err < 1, ks, 2)
_a = 2 * pt.sqrt(2 * np.pi * rt) * err < 1
What would be a better name for this boolean array? Maybe mask or sieve?
Should pt.lt be used here, as is done elsewhere in this PR?
It's actually equivalent, but I was just playing around.
_b = 2 + pt.sqrt(-2 * rt * pt.log(2 * pt.sqrt(2 * np.pi * rt) * err))
_c = pt.sqrt(rt) + 1
_d = pt.max(pt.stack([_b, _c]), axis=0)
ks = _a * _d + (1 - _a) * 2
Because _a is boolean, I think it's better to treat it as such and use pt.switch:
ks = _a * _d + (1 - _a) * 2
ks = pt.switch(mask, _d, 2) # having renamed `_a` to `mask`, for example
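Why selecting usually beats arithmetic masking in the forward pass: in the entries where _a is False, the bound inside _d takes the square root of a negative number and is NaN, and 0 * nan is still nan. A NumPy sketch with illustrative values (np.where stands in for pt.switch):

```python
import numpy as np

mask = np.array([True, False])
# Entry 1 is where _a is False: there 2 + sqrt(-2 * rt * log(...)) takes the
# sqrt of a negative number, so _d is nan in that entry.
d = np.array([3.0, np.nan])

mask_f = mask.astype(float)
arith = mask_f * d + (1.0 - mask_f) * 2.0  # 0.0 * nan is nan: the NaN leaks through
switched = np.where(mask, d, 2.0)          # selects 2.0 and never uses the NaN

print(arith, switched)
```

This only fixes the forward value; the gradient of the unselected branch can still be NaN unless its inputs are made safe too. It also does not reproduce the pt.switch errors reported elsewhere in this PR.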
Please see comment below
_b = 1.0 / (np.pi * pt.sqrt(rt))
_c = pt.sqrt(-2 * pt.log(np.pi * rt * err) / (np.pi**2 * rt))
_d = pt.max(pt.stack([_b, _c]), axis=0)
kl = _a * _b + (1 - _a) * _b
_c and _d are not used. Should _d be used in the second term instead of _b? Otherwise kl will just be _b.
kl = _a * _b + (1 - _a) * _b
kl = pt.switch(mask, _b, _d)
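For reference, a NumPy sketch of the large-time term count as given by Navarro & Fuss (2009): the maximum of the two bounds (the role of _d) is used where pi * rt * err < 1, and the boundary term 1 / (pi * sqrt(rt)) (the role of _b) alone otherwise. Which branch goes first in pt.switch therefore depends on what the mask tests; the snippet above doesn't show how `_a` is defined for kl, so this is a sketch under that assumption. The function name is hypothetical, and rt > 0 is assumed.

```python
import numpy as np


def n_terms_large_time(rt, err):
    """Number of large-time terms (Navarro & Fuss, 2009 rule); assumes rt > 0."""
    rt = np.asarray(rt, dtype=float)
    boundary = 1.0 / (np.pi * np.sqrt(rt))  # plays the role of _b
    mask = np.pi * rt * err < 1.0           # error threshold low enough
    # Safe log argument where the bound is not used, so no NaN appears there.
    log_arg = np.where(mask, np.pi * rt * err, 0.5)
    bound = np.sqrt(-2.0 * np.log(log_arg) / (np.pi**2 * rt))  # plays the role of _c
    return np.where(mask, np.maximum(bound, boundary), boundary)


print(n_terms_large_time([0.5], 1e-7), n_terms_large_time([1.0], 1.0))
```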
Please see comment below
logp = pt.where(
rt <= epsilon,
LOGP_LB,
tt = negative_rt * epsilon + (1 - negative_rt) * rt
tt = negative_rt * epsilon + (1 - negative_rt) * rt
tt = pt.switch(negative_rt, epsilon, rt)
This is actually done on purpose. pt.switch can cause some weird errors.
+ (
(a * z_flipped * sv) ** 2
- 2 * a * v_flipped * z_flipped
- (v_flipped**2) * rt
- (v_flipped**2) * tt
)
/ (2 * (sv**2) * rt + 2)
- 0.5 * pt.log(sv**2 * rt + 1)
- 2 * pt.log(a),
/ (2 * (sv**2) * tt + 2)
- 0.5 * pt.log(sv**2 * tt + 1)
- 2 * pt.log(pt.maximum(epsilon, a)),
Evaluate this separately and give it a meaningful name.
We are probably not going to keep this one. I just tried it to see whether keeping the argument of the log positive gets us anywhere. It seems to help a bit, but this is not the culprit.
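Following the suggestion to name the pieces, here is a NumPy sketch of the same expression with each term evaluated separately. It assumes the cut-off part of the expression before "+ (" is log(p), that v and z are already flipped for the response, and that rt is the decision time rt - t. All names are hypothetical. With sv = 0 it reduces to the no-variability form log(p) - v*a*z - v**2*rt/2 - 2*log(a).

```python
import numpy as np


def logp_ddm_sv_sketch(log_p, rt, a, v, z, sv):
    """Hypothetical helper: sv-integrated DDM log-likelihood with named terms.

    log_p is the log of the normalized first-passage density (the ftt01w part);
    v and z are assumed already flipped; rt is the decision time rt - t.
    """
    variance_scale = sv**2 * rt + 1.0  # so the denominator 2*sv**2*rt + 2 = 2 * variance_scale
    drift_term = ((a * z * sv) ** 2 - 2.0 * a * v * z - v**2 * rt) / (2.0 * variance_scale)
    variance_term = -0.5 * np.log(variance_scale)
    scale_term = -2.0 * np.log(a)
    return log_p + drift_term + variance_term + scale_term


print(logp_ddm_sv_sketch(-1.0, 0.7, 1.5, 0.8, 0.4, 0.3))
```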
Co-authored-by: Carlos Paniagua <[email protected]>
@cpaniaguam Thanks for the suggestions! I committed all of them excluding those involving … Please feel free to take this further. This PR wasn't final; it was just a placeholder for some of my experiments.
@digicosmos86 should we use this PR to switch to float64 overall? Also, the latest state of affairs with all the changes in this PR is that it's still breaking, right?
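A quick NumPy illustration of why single precision is fragile here. The smallest positive float32 is about 1.4e-45 (roughly exp(-103.3)), so exponentiating a log-density below that underflows to 0, and its log is -inf. If LOGP_LB were below about -103, pt.exp(LOGP_LB) would hit this in float32. In PyTensor, the precision of new float variables is controlled by pytensor.config.floatX. The value -110 below is illustrative, not the module's LOGP_LB.

```python
import numpy as np

x32 = np.exp(np.float32(-110.0))  # below float32's smallest subnormal: underflows to 0
x64 = np.exp(np.float64(-110.0))  # about 1.7e-48, representable in float64

with np.errstate(divide="ignore"):
    log32 = np.log(x32)           # log(0) = -inf
log64 = np.log(x64)               # recovers about -110

print(x32, log32, x64, log64)
```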
You are correct. It is still broken. This PR is kind of my mess though. I'd rather start a new one and just switch out all the |
@digicosmos86 I am good with that approach. |
Since this is still in the works, I am going to convert it to a draft PR |
@digicosmos86 to be closed now that the other PR is up? |
No description provided.