Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Go-Nogo and deadlines with missing data #358

Merged

Conversation

digicosmos86
Copy link
Collaborator

No description provided.

@digicosmos86 digicosmos86 self-assigned this Feb 20, 2024
@digicosmos86 digicosmos86 linked an issue Feb 20, 2024 that may be closed by this pull request
Copy link
Collaborator

@AlexanderFengler AlexanderFengler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely taking shape! :)

hddm-wfpt = "^0.1.1"
seaborn = "^0.13.0"
seaborn = "^0.13.2"
pytensor = "<2.17.4"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we still held back by that? just as a sidenote

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, especially in non-conda environments

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok let's follow gameplan (this is the wrong place to put it but nevertheless :)):

  1. add the new network types
  2. fix initialization issues
  3. get to conda
  4. finish RL piece
  5. web interface

src/hssm/config.py Show resolved Hide resolved
src/hssm/defaults.py Outdated Show resolved Hide resolved
src/hssm/distribution_utils/blackbox.py Outdated Show resolved Hide resolved
def make_likelihood_callable(
loglik: pytensor.graph.Op | Callable | PathLike | str,
loglik_kind: Literal["analytical", "approx_differentiable", "blackbox"],
backend: Literal["pytensor", "jax", "cython", "other"] | None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there need to make 'cython' backend explicit at this level?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want this. It is possible to implement analytical likelihoods in cython right? Our likelihood kinds might need some rethinking, since they are not mutually exclusive...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree that we may want to change the naming conventions entirely eventually.
For now, cython stuff falls under blackbox (since no gradients) and I am just questioning if it really falls under 'backend' category in any relevant sense.
It's just some pyhton callable in the end.

The VJP of the log-likelihood function computed at gz.
"""
_, vjp_fn = vjp(vmap_logp_no_data, *dist_params)
return vjp_fn(gz)[1:]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just as a question. Can't immediately remember why the return here skips 0 index.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the Jax Doc

If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun.

vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same number and shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.

Looks like we don't need the primals

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotta love how complicated this reads...

src/hssm/distribution_utils/onnx/onnx.py Outdated Show resolved Hide resolved
src/hssm/distribution_utils/onnx/onnx.py Outdated Show resolved Hide resolved
src/hssm/hssm.py Outdated
@@ -219,9 +217,17 @@ def __init__(
link_settings: Literal["log_logit"] | None = None,
prior_settings: Literal["safe"] | None = None,
extra_namespace: dict[str, Any] | None = None,
missing_data: bool = False,
deadline: bool = False,
na_value: float = -999.0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe missing_data_value here instead of na_value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am fine with both. na_value is shorter and is also a convention in pandas. Let me know which one you prefer

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have missing_data kwarg, then probably go with missing_data_value kwarg to make it consistent.

We can also allow both missing_data_value and na_value and pool internally, if you think it's worth the burden.

In general appreciate the convention argument, but locally I think going for missing_data AND na_value is confusing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good!

Copy link
Collaborator

@AlexanderFengler AlexanderFengler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good now.
Couldn't find any apparent issue.
Just left a comment on some doc-string nonsense :).

src/hssm/distribution_utils/dist.py Show resolved Hide resolved
@digicosmos86 digicosmos86 merged commit f2bfa62 into main Mar 7, 2024
0 of 2 checks passed
@digicosmos86 digicosmos86 deleted the 357-support-for-networks-built-in-jax-and-multiple-networks branch March 7, 2024 18:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for networks built in JAX and multiple networks
2 participants