-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for Go-Nogo and deadlines with missing data #358
Support for Go-Nogo and deadlines with missing data #358
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely taking shape! :)
hddm-wfpt = "^0.1.1" | ||
seaborn = "^0.13.0" | ||
seaborn = "^0.13.2" | ||
pytensor = "<2.17.4" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we still held back by that? just as a sidenote
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, especially in non-conda environments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok let's follow gameplan (this is the wrong place to put it but nevertheless :)):
- add the new network types
- fix initialization issues
- get to conda
- finish RL piece
- web interface
src/hssm/distribution_utils/dist.py
Outdated
def make_likelihood_callable( | ||
loglik: pytensor.graph.Op | Callable | PathLike | str, | ||
loglik_kind: Literal["analytical", "approx_differentiable", "blackbox"], | ||
backend: Literal["pytensor", "jax", "cython", "other"] | None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there need to make 'cython' backend explicit at this level?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might want this. It is possible to implement analytical likelihoods in cython right? Our likelihood kinds might need some rethinking, since they are not mutually exclusive...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree that we may want to change the naming conventions entirely eventually.
For now, cython stuff falls under blackbox (since no gradients) and I am just questioning if it really falls under 'backend' category in any relevant sense.
It's just some pyhton callable in the end.
The VJP of the log-likelihood function computed at gz. | ||
""" | ||
_, vjp_fn = vjp(vmap_logp_no_data, *dist_params) | ||
return vjp_fn(gz)[1:] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just as a question. Can't immediately remember why the return here skips 0
index.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's the Jax Doc
If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun.
vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same number and shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.
Looks like we don't need the primals
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotta love how complicated this reads...
src/hssm/hssm.py
Outdated
@@ -219,9 +217,17 @@ def __init__( | |||
link_settings: Literal["log_logit"] | None = None, | |||
prior_settings: Literal["safe"] | None = None, | |||
extra_namespace: dict[str, Any] | None = None, | |||
missing_data: bool = False, | |||
deadline: bool = False, | |||
na_value: float = -999.0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe missing_data_value
here instead of na_value
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am fine with both. na_value
is shorter and is also a convention in pandas
. Let me know which one you prefer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have missing_data
kwarg, then probably go with missing_data_value
kwarg to make it consistent.
We can also allow both missing_data_value
and na_value
and pool internally, if you think it's worth the burden.
In general appreciate the convention argument, but locally I think going for missing_data
AND na_value
is confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good now.
Couldn't find any apparent issue.
Just left a comment on some doc-string nonsense :).
No description provided.