First attempt at Quantile probability plot #319

Merged · 15 commits · Nov 15, 2023
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -5,12 +5,12 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.0.278
+rev: v0.1.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/psf/black
-rev: 23.7.0
+rev: 23.10.1
hooks:
- id: black-jupyter
args:
@@ -29,7 +29,7 @@ repos:
build|
dist"""
- repo: https://github.com/pre-commit/mirrors-mypy
-rev: v1.4.1 # Use the sha / tag you want to point at
+rev: v1.6.1 # Use the sha / tag you want to point at
hooks:
- id: mypy
args: [--no-strict-optional, --ignore-missing-imports]
488 changes: 243 additions & 245 deletions docs/tutorials/plotting.ipynb

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions pyproject.toml
@@ -30,17 +30,18 @@ bambi = "^0.12.0"
numpyro = "^0.12.1"
hddm-wfpt = "^0.1.1"
seaborn = "^0.13.0"
+xhistogram = "^0.3.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
-black = { extras = ["jupyter"], version = "^23.7.0" }
-mypy = "^1.4.1"
+black = { extras = ["jupyter"], version = "^23.10.1" }
+mypy = "^1.6.1"
pre-commit = "^2.20.0"
jupyterlab = "^4.0.2"
ipykernel = "^6.16.0"
ipywidgets = "^8.0.3"
graphviz = "^0.20.1"
-ruff = "^0.0.272"
+ruff = "^0.1.3"
mkdocs = "^1.4.3"
mkdocs-material = "^9.1.17"
mkdocstrings-python = "^1.1.2"
@@ -159,6 +160,8 @@ ignore = [
"PLR2004",
# Consider `elif` instead of `else` then `if` to remove indentation level
"PLR5501",
+# Ignore "Use `float` instead of `int | float`."
+"PYI041",
# Allow importing from parent modules
"TID252",
]
12 changes: 9 additions & 3 deletions src/hssm/datasets.py
@@ -5,14 +5,20 @@
"""

import os
-from collections import namedtuple
-from typing import Optional, Union
+from typing import NamedTuple, Optional, Union

import pandas as pd

base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

-FileMetadata = namedtuple("FileMetadata", ["filename", "path", "description"])
+class FileMetadata(NamedTuple):
+    """Typing for dataset metadata."""
+
+    filename: str
+    path: str
+    description: str


DATASETS = {
"cavanagh_theta": FileMetadata(
8 changes: 7 additions & 1 deletion src/hssm/defaults.py
@@ -123,7 +123,13 @@ class DefaultConfig(TypedDict):
"approx_differentiable": {
"loglik": "ddm_sdv.onnx",
"backend": "jax",
"default_priors": {},
"default_priors": {
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
},
},
"bounds": {
"v": (-3.0, 3.0),
"a": (0.3, 2.5),
19 changes: 12 additions & 7 deletions src/hssm/distribution_utils/dist.py
@@ -7,7 +7,7 @@

import logging
from os import PathLike
-from typing import Any, Callable, Iterable, Type
+from typing import Any, Callable, Type

import bambi as bmb
import numpy as np
@@ -279,19 +279,24 @@ def rng_fn(
)
out_shape = sims_out.shape[:-1]
replace = rng.binomial(n=1, p=p_outlier, size=out_shape).astype(bool)
-replace = np.stack([replace, replace], axis=-1)
-n_draws = np.prod(out_shape)
+replace_n = int(np.sum(replace, axis=None))
+if replace_n == 0:
+    return sims_out
+replace_shape = (*out_shape[:-1], replace_n)
+replace_mask = np.stack([replace, replace], axis=-1)
+n_draws = np.prod(replace_shape)
lapse_rt = pm.draw(
get_distribution_from_prior(cls._lapse).dist(**cls._lapse.args),
n_draws,
random_seed=rng,
-).reshape(out_shape)
-lapse_response = rng.binomial(n=1, p=0.5, size=out_shape)
+).reshape(replace_shape)
+lapse_response = rng.binomial(n=1, p=0.5, size=replace_shape)
Collaborator: Here I would go with np.random.choice(n_choices, ...) to accommodate models that have more than 2 choices. If n_choices == 2, we change 0 --> -1; otherwise keep the choices as they are.

Collaborator Author: Can we table this for now, since it might take a bit more engineering to figure out where that n_choices comes from? ssms would need it to figure out how many responses to generate, right?

Collaborator: Yeah. That should actually be part of the model configs, so it should be easily accessible. (I can double-check, but I think it is for ssm-simulators.)

Collaborator: Also, outputs from ssms.simulator() have the 'metadata' key, which also has that.

Collaborator Author: I mean, we do have a step that figures out the number of responses from the data, but we might want the user to be a bit more explicit in non-binary cases. In any case, we might need a separate PR to address this (a sketch of the multi-choice idea follows this diff hunk).

lapse_response = np.where(lapse_response == 1, 1, -1)
lapse_output = np.stack(
[lapse_rt, lapse_response],
axis=-1,
)
-np.putmask(sims_out, replace, lapse_output)
+np.putmask(sims_out, replace_mask, lapse_output)

return sims_out
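A minimal sketch of the multi-choice lapse-response idea raised in the review thread above, assuming n_choices can be obtained from the model config or from the 'metadata' key returned by ssms.simulator(); the concrete values below are hypothetical and not part of this PR:

import numpy as np

rng = np.random.default_rng()
n_choices = 2         # assumption: would come from the model config / ssms metadata
replace_shape = (5,)  # assumption: shape of the trials flagged for lapse replacement

# Draw lapse responses uniformly over the available choices.
lapse_response = rng.choice(n_choices, size=replace_shape)

# In the binary case, recode {0, 1} to the conventional {-1, 1};
# with more than two choices, keep the simulator's coding as is.
if n_choices == 2:
    lapse_response = np.where(lapse_response == 1, 1, -1)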

@@ -379,7 +384,7 @@ def dist(cls, **kwargs): # pylint: disable=arguments-renamed

def logp(data, *dist_params): # pylint: disable=E0213
num_params = len(list_params)
-extra_fields: Iterable[np.ndarray] = []
+extra_fields = []

if num_params < len(dist_params):
extra_fields = dist_params[num_params:]
2 changes: 1 addition & 1 deletion src/hssm/distribution_utils/onnx/onnx2pt.py
@@ -52,7 +52,7 @@ def pt_interpret_onnx(graph, *args):
"""
vals = dict(
{n.name: a for n, a in zip(graph.input, args)},
-**{n.name: _asarray(n) for n in graph.initializer}
+**{n.name: _asarray(n) for n in graph.initializer},
)
for node in graph.node:
args = (vals[name] for name in node.input)
2 changes: 1 addition & 1 deletion src/hssm/distribution_utils/onnx/onnx2xla.py
@@ -151,7 +151,7 @@ def interpret_onnx(graph, *args):
"""
vals = dict(
{n.name: a for n, a in zip(graph.input, args)},
-**{n.name: _asarray(n) for n in graph.initializer}
+**{n.name: _asarray(n) for n in graph.initializer},
)
for node in graph.node:
args = (vals[name] for name in node.input)
3 changes: 2 additions & 1 deletion src/hssm/plotting/__init__.py
@@ -1,5 +1,6 @@
"""Plotting functionalities for HSSM."""

from .posterior_predictive import plot_posterior_predictive
+from .quantile_probability import plot_quantile_probability

__all__ = ["plot_posterior_predictive"]
__all__ = ["plot_posterior_predictive", "plot_quantile_probability"]