Support for Go-Nogo and deadlines with missing data #358

Merged

Commits (39)
95ba1ce
added procedures to process data according to missing_data and deadli…
digicosmos86 Feb 20, 2024
3dcd497
update github actions, add support for Python 3.12
digicosmos86 Feb 20, 2024
422dc88
revert back python 3.12 support
digicosmos86 Feb 20, 2024
7bdb137
revert back python 3.12 support
digicosmos86 Feb 20, 2024
b359241
limit version of jax
digicosmos86 Feb 20, 2024
cbe62a5
limit version of jax
digicosmos86 Feb 20, 2024
3895c6b
list onnx-runtime as a dev dependency
digicosmos86 Feb 20, 2024
c69a5a0
Update github action
digicosmos86 Feb 21, 2024
54a99e9
update config and default to make room for custom response variables
digicosmos86 Feb 21, 2024
4165eac
update tests for configs
digicosmos86 Feb 21, 2024
d0cae6f
update files to support custom response variables
digicosmos86 Feb 21, 2024
b63c14d
finalize support for custom response variables
digicosmos86 Feb 21, 2024
9007918
update dependencies
digicosmos86 Feb 21, 2024
4e1246e
refactor make_model_distribution
digicosmos86 Feb 21, 2024
1a6a17d
add back in accidentally commented out code
digicosmos86 Feb 22, 2024
80a202d
fixed errors in graphing
digicosmos86 Feb 22, 2024
854c1dd
correct Enum assignment
digicosmos86 Feb 26, 2024
427c519
better handling of data passed in
digicosmos86 Feb 29, 2024
1cf5a60
add more supported onnx ops
digicosmos86 Feb 29, 2024
a4979b9
better support for corner cases in likelihood makers
digicosmos86 Feb 29, 2024
377f60b
add test onnx files
digicosmos86 Feb 29, 2024
0e960b2
assemble functions for dual networks
digicosmos86 Feb 29, 2024
4ff3ad3
tests for dual networks
digicosmos86 Feb 29, 2024
70b1aca
set decimal to 4 in assert_array_almost_equal
digicosmos86 Feb 29, 2024
accc57b
Add missing data networks to hssm
digicosmos86 Feb 29, 2024
0b97363
Update tests so they pass
digicosmos86 Feb 29, 2024
bd6bc02
fixed one bug
digicosmos86 Feb 29, 2024
ec54df0
remove one debug line
digicosmos86 Mar 4, 2024
a7305e5
Remove two debug prints
digicosmos86 Mar 4, 2024
b4db5f6
Simplify function and class creation to make it DRY
digicosmos86 Mar 4, 2024
8dcb771
Update tests
digicosmos86 Mar 4, 2024
7528bae
add pytest-xdist to enable testing cases in parallel
digicosmos86 Mar 5, 2024
f7e5634
update action configs to enable parallel testing
digicosmos86 Mar 5, 2024
d158733
Rearrange datasets with missing to enable computation
digicosmos86 Mar 5, 2024
d53c2e6
Fix bugs in distribution_utils
digicosmos86 Mar 5, 2024
dbdb21c
Update tests
digicosmos86 Mar 5, 2024
c6f4bc0
Clean up the code to make parameters more consistent
digicosmos86 Mar 6, 2024
07988a6
Update tests
digicosmos86 Mar 6, 2024
cb47100
update documentation for using HSSM with PyMC
digicosmos86 Mar 6, 2024
4 changes: 4 additions & 0 deletions src/hssm/distribution_utils/__init__.py
@@ -2,18 +2,22 @@

from ..utils import download_hf
from .dist import (
assemble_callables,
make_blackbox_op,
make_distribution,
make_family,
make_likelihood_callable,
make_missing_data_callable,
make_ssm_rv,
)

__all__ = [
"assemble_callables",
"download_hf",
"make_blackbox_op",
"make_distribution",
"make_likelihood_callable",
"make_missing_data_callable",
"make_family",
"make_ssm_rv",
]
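
The newly exported pair here is `make_missing_data_callable` and `assemble_callables`. A minimal import sketch (assuming a build of `hssm` that includes this branch):

```python
# Sketch: pulling in the updated public API from this diff.
from hssm.distribution_utils import (
    assemble_callables,
    make_likelihood_callable,
    make_missing_data_callable,
)
```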
82 changes: 76 additions & 6 deletions src/hssm/distribution_utils/dist.py
@@ -566,12 +566,6 @@ def make_likelihood_callable(
+ "but did not provide `pytensor` or `jax` as backend."
)

- if params_is_reg is None:
- raise ValueError(
- "You set `loglik_kind` to `approx_differentiable` "
- + "but did not provide `params_is_reg`."
- )

if isinstance(loglik, (str, PathLike)):
if not Path(loglik).exists():
loglik = download_hf(str(loglik))
@@ -582,6 +576,11 @@
lan_logp_pt = make_pytensor_logp(onnx_model, data_dim)
return lan_logp_pt
if backend == "jax":
if params_is_reg is None:
raise ValueError(
"You set `loglik_kind` to `approx_differentiable` "
+ "and `backend` to `jax` but did not provide `params_is_reg`."
)
logp, logp_grad, logp_nojit = make_jax_logp_funcs_from_onnx(
onnx_model,
params_is_reg,
@@ -594,3 +593,74 @@
return lan_logp_jax

raise ValueError("Incorrect likelihood specification.")


def make_missing_data_callable(
loglik: pytensor.graph.Op | Callable | PathLike | str,
is_cpn_only: bool,
backend: Literal["pytensor", "jax", "cython", "other"] | None = "jax",
params_is_reg: list[bool] | None = None,
) -> pytensor.graph.Op | Callable:
"""Make a secondary network for the likelihood function.

Please refer to the documentation of `make_likelihood_callable` for more.
"""
return make_likelihood_callable(
loglik, "approx_differentiable", backend, params_is_reg, 0 if is_cpn_only else 1
) # Just assume that the missing data network is always approx_differentiable
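
The wrapper pins `loglik_kind` to `"approx_differentiable"` and chooses the data dimension: a CPN (choice-probability network) takes no trial data, while any other missing-data network takes one extra column (e.g. a deadline). A sketch of what the forwarding amounts to (file name and flags hypothetical):

```python
# For a CPN, the forwarded data dimension is 0; otherwise it is 1.
cpn_logp = make_missing_data_callable(
    "cpn.onnx", is_cpn_only=True, backend="jax", params_is_reg=[False] * 4
)
# ...equivalent to:
# make_likelihood_callable("cpn.onnx", "approx_differentiable", "jax",
#                          [False] * 4, 0)
```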


def assemble_callables(
callable: pytensor.graph.Op | Callable,
missing_data_callable: pytensor.graph.Op | Callable,
is_cpn_only: bool,
) -> Callable:
"""Assemble the likelihood callables into a single callable.

Assembles the likelihood callables into a single callable.

Parameters
----------
callable
The callable for the likelihood function.
missing_data_callable
The callable for the secondary network for the likelihood function.
is_cpn_only
Whether the missing data model is a CPN only model, in which case we do not
apply any data to the missing data model.
"""

def likelihood_callable(data, *dist_params):
"""Compute the log-likelihoood of the model."""
# Assuming the first column of the data is always rt
data = pt.as_tensor_variable(data)
dist_params = [pt.as_tensor_variable(param) for param in dist_params]

missing_mask = pt.eq(data[:, 0], -999.0)
observed_mask = pt.bitwise_not(missing_mask)

observed_data = data[observed_mask, :]

dist_params_observed = [
param if param.ndim == 0 else param[observed_mask] for param in dist_params
]

logp_observed = callable(observed_data[:, :-1], *dist_params_observed)

dist_params_missing = [
param if param.ndim == 0 else param[missing_mask] for param in dist_params
]

if is_cpn_only:
logp_missing = missing_data_callable(*dist_params_missing)
else:
missing_data = data[missing_mask, -1:]
logp_missing = missing_data_callable(missing_data, *dist_params_missing)

logp = pt.empty_like(missing_mask, dtype=pytensor.config.floatX)
logp = pt.set_subtensor(logp[observed_mask], logp_observed)
logp = pt.set_subtensor(logp[missing_mask], logp_missing)

return logp

return likelihood_callable
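
The function's core is a two-way split on the missingness code in the first data column: observed rows go to the main likelihood (without the last data column), missing rows go to the secondary network, and the two log-density vectors are scattered back into one output. A NumPy analogue of that logic, as a sketch (the real callable builds a symbolic PyTensor graph; `-999.0` as the missing-data code comes from the diff above):

```python
import numpy as np

def assemble_numpy(logp_obs, logp_miss, data, *params, is_cpn_only=False):
    """NumPy sketch of `likelihood_callable` above."""
    missing = data[:, 0] == -999.0   # rt == -999.0 marks a missing trial
    observed = ~missing
    # Scalar parameters are shared; trial-wise parameters are split by row.
    p_obs = [p if np.ndim(p) == 0 else p[observed] for p in params]
    p_miss = [p if np.ndim(p) == 0 else p[missing] for p in params]
    out = np.empty(data.shape[0])
    out[observed] = logp_obs(data[observed, :-1], *p_obs)
    out[missing] = (
        logp_miss(*p_miss) if is_cpn_only
        else logp_miss(data[missing, -1:], *p_miss)
    )
    return out
```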
96 changes: 63 additions & 33 deletions src/hssm/distribution_utils/onnx/onnx.py
@@ -12,7 +12,7 @@
import onnx
import pytensor
import pytensor.tensor as pt
- from jax import jit, vjp, vmap
from jax import grad, jit, vjp, vmap
from numpy.typing import ArrayLike
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify
@@ -76,38 +76,44 @@ def logp_no_data(*dist_params: float) -> float:
# Makes a matrix to feed to the LAN model
input_vector = jnp.array(dist_params)

- return interpret_onnx(loaded_model.graph, input_vector)[0].squeeze()
result = interpret_onnx(loaded_model.graph, input_vector)[0]
return result.squeeze() if any(params_is_reg) else result

# The vectorization of the logp function
- vmap_logp_no_data = vmap(
- logp_no_data,
- in_axes=[0 if is_regression else None for is_regression in params_is_reg],
- )

- def vjp_vmap_logp_no_data(
- *dist_params: list[float | ArrayLike], gz: ArrayLike
- ) -> list[ArrayLike]:
- """Compute the VJP of the log-likelihood function.
-
- Parameters
- ----------
- data
- A two-column numpy array with response time and response.
- dist_params
- A list of parameters used in the likelihood computation.
- gz
- The value of vmap_logp at which the VJP is evaluated, typically is just
- vmap_logp(data, *dist_params)
-
- Returns
- -------
- list[ArrayLike]
- The VJP of the log-likelihood function computed at gz.
- """
- _, vjp_fn = vjp(vmap_logp_no_data, *dist_params)
- return vjp_fn(gz)[1:]
if any(params_is_reg):
vmap_logp_no_data = vmap(
logp_no_data,
in_axes=[
0 if is_regression else None for is_regression in params_is_reg
],
)

- return jit(vmap_logp_no_data), jit(vjp_vmap_logp_no_data), vmap_logp_no_data
def vjp_vmap_logp_no_data(
*dist_params: list[float | ArrayLike], gz: ArrayLike
) -> list[ArrayLike]:
"""Compute the VJP of the log-likelihood function.

Parameters
----------
data
A two-column numpy array with response time and response.
dist_params
A list of parameters used in the likelihood computation.
gz
The value of vmap_logp at which the VJP is evaluated, typically is
just vmap_logp(data, *dist_params)

Returns
-------
list[ArrayLike]
The VJP of the log-likelihood function computed at gz.
"""
_, vjp_fn = vjp(vmap_logp_no_data, *dist_params)
return vjp_fn(gz)[1:]

return jit(vmap_logp_no_data), jit(vjp_vmap_logp_no_data), vmap_logp_no_data

return jit(logp_no_data), jit(grad(logp_no_data)), logp_no_data
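
The restructuring gives the no-data network two return paths: if no parameter varies by trial, the logp is a plain scalar function and `jax.grad` replaces the hand-rolled VJP; otherwise the function is vmapped over the trial axis. A minimal illustration with a stand-in forward pass:

```python
import jax
import jax.numpy as jnp

def f(a, b):                        # stand-in for the ONNX forward pass
    return jnp.tanh(a) + b

# Scalar path (params_is_reg all False): plain grad suffices.
grad_f = jax.jit(jax.grad(f))
print(grad_f(0.5, 1.0))             # 1 - tanh(0.5)**2

# Vectorized path: `a` varies by trial (axis 0), `b` is shared (None).
vmap_f = jax.vmap(f, in_axes=[0, None])
print(vmap_f(jnp.arange(3.0), 1.0))
```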

def logp(data: np.ndarray, *dist_params: float) -> float:
"""Compute the log-likelihood.
@@ -383,7 +389,7 @@ def grad(self, inputs, output_gradients):
inputs
The same as the inputs produced in `make_node`.
output_gradients
- Holds the results of the perform `perform` method.
Holds the results of the `perform` method.

Notes
-----
@@ -475,10 +481,34 @@ def make_pytensor_logp(
onnx.load(str(model)) if isinstance(model, (str, PathLike)) else model
)

if data_dim == 0:

def logp_no_data(*dist_params: list[float | ArrayLike]) -> ArrayLike:
# Specify input layer of MLP
dist_params_tensors = [
pt.as_tensor_variable(param) for param in dist_params # type: ignore
]
n_rows = pt.max(
[
1 if param.ndim == 0 else param.shape[0] # type: ignore
for param in dist_params_tensors
]
)
inputs = pt.empty((n_rows, len(dist_params)))
for i, dist_param in enumerate(dist_params):
inputs = pt.set_subtensor(
inputs[:, i],
dist_param,
)

# Returns elementwise log-likelihoods
return pt.squeeze(pt_interpret_onnx(loaded_model.graph, inputs)[0])

return logp_no_data

def logp(data: np.ndarray, *dist_params: list[float | ArrayLike]) -> ArrayLike:
# Specify input layer of MLP
data = data.reshape((-1, data_dim)) if data_dim > 1 else data
- inputs = pt.zeros((data.shape[0], len(dist_params) + data_dim))
inputs = pt.empty((data.shape[0], (len(dist_params) + data_dim)))
for i, dist_param in enumerate(dist_params):
inputs = pt.set_subtensor(
inputs[:, i],
…
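
The new `logp_no_data` branch in `make_pytensor_logp` assembles the network input by broadcasting each parameter into its own column, with the longest (trial-wise) parameter setting the row count. The same trick in isolation, as a runnable sketch:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.dscalar("a")                # shared scalar parameter
b = pt.dvector("b")                # trial-wise vector parameter
n_rows = pt.max([1, b.shape[0]])   # longest parameter decides the row count
X = pt.empty((n_rows, 2), dtype="float64")
X = pt.set_subtensor(X[:, 0], a)   # a scalar broadcasts down its column
X = pt.set_subtensor(X[:, 1], b)
f = pytensor.function([a, b], X)
print(f(0.5, np.arange(3.0)))      # [[0.5, 0.], [0.5, 1.], [0.5, 2.]]
```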
10 changes: 9 additions & 1 deletion src/hssm/distribution_utils/onnx/onnx2pt.py
@@ -12,6 +12,11 @@
from .onnx2xla import _asarray, attribute_handlers


def onnx_add(a, b, axis=None, broadcast=True):
"""Numpy-backed implementation of ONNX Add op."""
return [pt.add(a, b)]


def pytensor_gemm(
a, b, c=0.0, alpha=1.0, beta=1.0, transA=0, transB=0
): # pylint: disable=C0103
@@ -26,13 +31,16 @@ def pytensor_gemm(


pt_onnx_ops = {
"Add": pt.add,
"Add": lambda a, b: onnx_add(a, b),
"Constant": lambda value: [value],
"MatMul": lambda x, y: [pt.dot(x, y)],
"Relu": lambda x: [pt.math.max(x, 0)],
"Reshape": lambda x, shape: [pt.reshape(x, shape)],
"Tanh": lambda x: [pt.tanh(x)],
"Gemm": pytensor_gemm,
"Neg": lambda x: [-x],
"Exp": lambda x: [pt.exp(x)],
"Log": lambda x: [pt.log(x)],
}


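
For context, a table like `pt_onnx_ops` is consumed by walking the ONNX graph node by node, looking up each node's `op_type`, and threading the returned output lists through. Roughly (a sketch, not the actual `pt_interpret_onnx`):

```python
import pytensor.tensor as pt

pt_ops = {  # a two-entry slice of the table above
    "Tanh": lambda x: [pt.tanh(x)],
    "Exp": lambda x: [pt.exp(x)],
}

def eval_node(op_type, *inputs):
    # Every handler returns a list: one entry per ONNX node output.
    return pt_ops[op_type](*inputs)

x = pt.dvector("x")
(h,) = eval_node("Tanh", x)
(y,) = eval_node("Exp", h)
```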
10 changes: 3 additions & 7 deletions src/hssm/distribution_utils/onnx/onnx2xla.py
@@ -36,7 +36,6 @@
"""

import jax.numpy as jnp
- import numpy as np
import onnx
from jax import lax
from onnx import numpy_helper
@@ -99,12 +98,6 @@ def onnx_conv(

def onnx_add(a, b, axis=None, broadcast=True):
"""Numpy-backed implementation of ONNX Add op."""
- if broadcast:
- axis = (a.dim - b.ndim) if axis is None else axis % a.ndim
- assert a.shape[axis:][: b.ndim] == b.shape
- b_shape = np.ones(a.ndim, dtype="int64")
- b_shape[axis : axis + b.ndim] = b.shape
- b = jnp.reshape(b, b_shape)
return [a + b]
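
The deleted lines implemented ONNX's pre-opset-7 explicit-axis broadcast. From opset 7 on, `Add` uses NumPy-style multidirectional broadcasting, which `jnp` already provides, so `a + b` is enough (assuming the models HSSM loads target opset 7 or later). A quick check of that assumption:

```python
import jax.numpy as jnp

a = jnp.ones((4, 3))
b = jnp.arange(3.0)     # shape (3,) broadcasts against (4, 3), NumPy-style
print((a + b).shape)    # (4, 3)
```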


@@ -132,6 +125,9 @@ def onnx_gemm(
# Added by HSSM developers
"Tanh": lambda x: [jnp.tanh(x)],
"Gemm": onnx_gemm,
"Neg": lambda x: [-x],
"Exp": lambda x: [jnp.exp(x)],
"Log": lambda x: [jnp.log(x)],
}
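
The three added handlers (`Neg`, `Exp`, `Log`) compose the same way the interpreter threads node outputs; log(exp(-x)) round-trips to -x, which makes a cheap sanity check:

```python
import jax.numpy as jnp

ops = {  # the three new entries, copied from above
    "Neg": lambda x: [-x],
    "Exp": lambda x: [jnp.exp(x)],
    "Log": lambda x: [jnp.log(x)],
}
(y,) = ops["Neg"](jnp.array([0.1, 0.5]))
(y,) = ops["Exp"](y)
(y,) = ops["Log"](y)
print(y)                # [-0.1, -0.5]
```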


Expand Down
Loading
Loading