Add prior sensitivity check #25

Merged
merged 14 commits into from
May 27, 2024
5 changes: 1 addition & 4 deletions .github/workflows/build.yml
@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.11"]

steps:
- uses: actions/checkout@v2
@@ -30,9 +30,6 @@ jobs:
- name: Check code with flake8
run: |
flake8
- name: Check docstring coverage with interrogate
run: |
interrogate
- name: Check whether black has been used
run: |
black --check tests
4 changes: 0 additions & 4 deletions .pre-commit-config.yaml
@@ -7,8 +7,4 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
- repo: https://github.com/econchick/interrogate
rev: 1.5.0
hooks:
- id: interrogate

44 changes: 0 additions & 44 deletions CONTRIBUTING.md

This file was deleted.

1 change: 0 additions & 1 deletion Makefile
@@ -2,7 +2,6 @@ test:
flake8
pytype labelshift
pytype tests
interrogate
pytest

install:
40 changes: 20 additions & 20 deletions README.md
@@ -1,30 +1,30 @@
# Label Shift
[![Project Status: Concept – Minimal or no implementation has been done yet, or the repository is only intended to be a limited example, demo, or proof-of-concept.](https://www.repostatus.org/badges/latest/concept.svg)](https://www.repostatus.org/#concept)
[![Venue](https://img.shields.io/badge/venue-TMLR_2024-darkblue)](https://openreview.net/forum?id=Ft4kHrOawZ)
[![build](https://github.com/pawel-czyz/labelshift/actions/workflows/build.yml/badge.svg)](https://github.com/pawel-czyz/labelshift/actions/workflows/build.yml)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

Python library for *quantification* (estimating the class prevalence in an unlabeled data set) under the prior probability shift assumption.
# Bayesian quantification with black-box estimators

This module is created with two purposes in mind:
- easily apply state-of-the-art quantification algorithms to the real problems,
- benchmark novel quantification algorithms against others.
*Quantification* is the problem of estimating the label prevalence from an unlabeled data set. In this repository we provide the code associated with our manuscript, which can be used to reproduce the experiments.

It is compatible with any classifier using any machine learning framework.
## Installation

The code inside was used to run the experiments in [our preprint](https://arxiv.org/abs/2302.09159), which can be cited as:
```
@misc{https://doi.org/10.48550/arxiv.2302.09159,
doi = {10.48550/ARXIV.2302.09159},
url = {https://arxiv.org/abs/2302.09159},
author = {Ziegler, Albert and Czyż, Paweł},
title = {Bayesian Quantification with Black-Box Estimators},
publisher = {arXiv},
year = {2023}
}
We recommend using [Micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html) to set up a new Python 3.11 environment.
The package can then be installed with:

```bash
$ pip install -e .
```

## Installation
Currently the module is in early development stage and is not ready to be installed. It does not have proper documentation either. We hope to change it soon – thank you for your patience!
To reproduce the experiments, install [Snakemake](https://snakemake.readthedocs.io/en/stable/) by following its installation instructions. Then install the additional dependencies:

## Contributions
Contributions are very welcome! Please, check our [Contribution guide](CONTRIBUTING.md).
```bash
$ pip install -r requirements.txt
```

The experiments can be reproduced by running:

```bash
$ snakemake -c4 -s workflows/WORKFLOW_NAME.smk
```

1 change: 1 addition & 0 deletions labelshift/adjustments.py
@@ -1,4 +1,5 @@
"""Predictions adjustments."""

import numpy as np
from numpy.typing import ArrayLike

6 changes: 5 additions & 1 deletion labelshift/algorithms/api.py
@@ -4,7 +4,11 @@

>>> import labelshift.algorithms.api as algo
"""
from labelshift.algorithms.bayesian_discrete import DiscreteCategoricalMeanEstimator, SamplingParams

from labelshift.algorithms.bayesian_discrete import (
DiscreteCategoricalMeanEstimator,
SamplingParams,
)
from labelshift.algorithms.bbse import BlackBoxShiftEstimator
from labelshift.algorithms.classify_and_count import ClassifyAndCount
from labelshift.algorithms.ratio_estimator import InvariantRatioEstimator
54 changes: 40 additions & 14 deletions labelshift/algorithms/bayesian_discrete.py
@@ -1,4 +1,5 @@
"""Categorical discrete Bayesian model for quantification."""

import numpy as np
import jax.numpy as jnp
import numpyro
@@ -7,8 +8,6 @@
import pydantic
from typing import Optional

from numpy.typing import ArrayLike

import labelshift.interfaces.point_estimators as pe


@@ -17,6 +16,7 @@ class SamplingParams(pydantic.BaseModel):

warmup: pydantic.PositiveInt = pydantic.Field(default=500)
samples: pydantic.PositiveInt = pydantic.Field(default=1000)
chains: pydantic.PositiveInt = pydantic.Field(default=1)


P_TRAIN_Y: str = "P_train(Y)"
@@ -25,50 +25,70 @@ class SamplingParams(pydantic.BaseModel):
P_C_COND_Y: str = "P(C|Y)"


def model(summary_statistic):
def model(summary_statistic, alpha: float = 1.0):
n_y_labeled = summary_statistic.n_y_labeled
n_y_and_c_labeled = summary_statistic.n_y_and_c_labeled
n_c_unlabeled = summary_statistic.n_c_unlabeled
K = len(n_c_unlabeled)
L = len(n_y_labeled)

pi = numpyro.sample(P_TRAIN_Y, dist.Dirichlet(jnp.ones(L)))
pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(jnp.ones(L)))
p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(jnp.ones(K).repeat(L).reshape(L, K)))
pi = numpyro.sample(P_TRAIN_Y, dist.Dirichlet(alpha * jnp.ones(L)))
pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(alpha * jnp.ones(L)))
p_c_cond_y = numpyro.sample(
P_C_COND_Y, dist.Dirichlet(alpha * jnp.ones(K).repeat(L).reshape(L, K))
)

N_y = numpyro.sample(
"N_y", dist.Multinomial(jnp.sum(n_y_labeled), pi), obs=n_y_labeled
)

N_y = numpyro.sample('N_y', dist.Multinomial(jnp.sum(n_y_labeled), pi), obs=n_y_labeled)

with numpyro.plate('plate', L):
numpyro.sample('F_yc', dist.Multinomial(N_y, p_c_cond_y), obs=n_y_and_c_labeled)
with numpyro.plate("plate", L):
numpyro.sample("F_yc", dist.Multinomial(N_y, p_c_cond_y), obs=n_y_and_c_labeled)

p_c = numpyro.deterministic(P_TEST_C, jnp.einsum("yc,y->c", p_c_cond_y, pi_))
numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled)
numpyro.sample(
"N_c", dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled
)
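For readers of this hunk: the `P_TEST_C` line contracts the confusion matrix `P(C|Y)` with the test prevalence `P_test(Y)`, i.e. P_test(C = c) = Σ_y P(C = c | Y = y) · P_test(Y = y), which is what the `"yc,y->c"` einsum expresses. Below is a minimal plain-Python sketch of that contraction. The function name and the numbers are hypothetical illustrations, not code or values from this repository.

```python
def marginal_prediction_probs(p_c_cond_y, p_test_y):
    """Return P_test(C) given P(C|Y) (one row per label y) and P_test(Y).

    Mirrors the contraction "yc,y->c": sum over y of P(C=c|Y=y) * P_test(Y=y).
    """
    n_outputs = len(p_c_cond_y[0])
    return [
        sum(p_test_y[y] * p_c_cond_y[y][c] for y in range(len(p_test_y)))
        for c in range(n_outputs)
    ]


# Hypothetical two-class example: rows are true labels, columns are predictions.
p_c_cond_y = [[0.9, 0.1], [0.2, 0.8]]
p_test_y = [0.3, 0.7]

p_test_c = marginal_prediction_probs(p_c_cond_y, p_test_y)
print(p_test_c)
```

Because each row of `p_c_cond_y` and the vector `p_test_y` sum to one, the result is again a probability vector.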


class DiscreteCategoricalMeanEstimator(pe.SummaryStatisticPrevalenceEstimator):
"""A version of Bayesian quantification which finds the mean solution.

Note that it runs the MCMC sampler in the backend.
"""

P_TRAIN_Y = P_TRAIN_Y
P_TEST_Y = P_TEST_Y
P_TEST_C = P_TEST_C
P_C_COND_Y = P_C_COND_Y

def __init__(self, params: Optional[SamplingParams] = None, seed: int = 42) -> None:
def __init__(
self,
params: Optional[SamplingParams] = None,
seed: int = 42,
alpha: float = 1.0,
) -> None:
if params is None:
params = SamplingParams()
self._params = params
self._seed = seed
self._mcmc = None

if alpha <= 0:
raise ValueError("Concentration parameter alpha has to be positive.")
self._alpha = alpha

def sample_posterior(self, /, statistic: pe.SummaryStatistic):
"""Returns the samples from the MCMC sampler."""
mcmc = numpyro.infer.MCMC(
numpyro.infer.NUTS(model),
num_warmup=self._params.warmup,
num_samples=self._params.samples)
num_samples=self._params.samples,
num_chains=self._params.chains,
)
rng_key = jax.random.PRNGKey(self._seed)
mcmc.run(rng_key, summary_statistic=statistic)
mcmc.run(rng_key, summary_statistic=statistic, alpha=self._alpha)
self._mcmc = mcmc
return mcmc.get_samples()

def estimate_from_summary_statistic(
@@ -77,3 +97,9 @@ def estimate_from_summary_statistic(
"""Returns the mean prediction."""
samples = self.sample_posterior(statistic)[P_TEST_Y]
return np.array(samples.mean(axis=0))

def get_mcmc(self):
"""Returns the MCMC object."""
if self._mcmc is None:
raise ValueError("Run `sample_posterior` to obtain MCMC samples first.")
return self._mcmc
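
The new `alpha` argument scales the symmetric Dirichlet priors in `model`, so a prior sensitivity check amounts to re-running the estimator for several `alpha` values and comparing the posteriors. As a rough illustration of what `alpha` controls, the standard-library sketch below computes the per-component variance of a symmetric Dirichlet(alpha, ..., alpha) prior and draws one sample from it. The helper names are hypothetical, and it does not call `labelshift` or NumPyro.

```python
import random


def symmetric_dirichlet_variance(n_classes, alpha):
    """Variance of one component of Dirichlet(alpha, ..., alpha).

    Each component has mean 1/L and variance (1/L)(1 - 1/L) / (L * alpha + 1).
    """
    mean = 1.0 / n_classes
    return mean * (1.0 - mean) / (n_classes * alpha + 1.0)


def sample_symmetric_dirichlet(n_classes, alpha, rng):
    """Draw one sample by normalizing independent Gamma(alpha, 1) variables."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(n_classes)]
    total = sum(draws)
    return [d / total for d in draws]


rng = random.Random(42)
n_classes = 3

# Larger alpha concentrates the prior around the uniform prevalence vector.
variances = {a: symmetric_dirichlet_variance(n_classes, a) for a in (0.1, 1.0, 10.0)}
example_draw = sample_symmetric_dirichlet(n_classes, 1.0, rng)

for a, var in variances.items():
    print(f"alpha={a}: prior variance per component = {var:.4f}")
```

With `alpha = 1.0` the prior is uniform over the simplex, matching the previous hard-coded `jnp.ones(L)` behaviour. Smaller values favour sparse prevalence vectors and larger values favour near-uniform ones.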
1 change: 1 addition & 0 deletions labelshift/algorithms/bbse.py
@@ -5,6 +5,7 @@
Detecting and Correcting for Label Shift with Black Box Predictors
https://arxiv.org/pdf/1802.03916.pdf
"""

from typing import Optional

import numpy as np
1 change: 1 addition & 0 deletions labelshift/algorithms/classify_and_count.py
@@ -1,4 +1,5 @@
"""Classify and Count algorithm."""

import numpy as np
from numpy.typing import ArrayLike

1 change: 1 addition & 0 deletions labelshift/algorithms/expectation_maximization.py
@@ -1,4 +1,5 @@
"""Expectation Maximization algorithm."""

import warnings
from typing import Optional
import numpy as np
1 change: 1 addition & 0 deletions labelshift/algorithms/ratio_estimator.py
@@ -41,6 +41,7 @@
``H_hat[l, k] = G_hat[l, k] = E_labeled[ g(X)[k] | Y = l] \\in R^{L x (K-1)}.``

"""

from typing import Optional, Tuple

import numpy as np
1 change: 1 addition & 0 deletions labelshift/algorithms/validate.py
@@ -1,4 +1,5 @@
"""Preprocessing and validation methods."""

from typing import Tuple
import numpy as np
from numpy.typing import ArrayLike
4 changes: 3 additions & 1 deletion labelshift/datasets/discrete_categorical.py
@@ -1,4 +1,5 @@
"""Discrete categorical sampler."""

import dataclasses
import math
from typing import Tuple, Any, Union, Optional
@@ -14,12 +15,13 @@
@dataclasses.dataclass
class SummaryMultinomialStatistic:
"""

Attributes:
n_y: shape (L,)
n_c: shape (K,)
n_y_and_c: shape (L, K)
"""

n_y: np.ndarray
n_c: np.ndarray
n_y_and_c: np.ndarray
1 change: 1 addition & 0 deletions labelshift/datasets/gaussian_mixture.py
@@ -1,5 +1,6 @@
"""Model used for working with exact probabilities
in the Gaussian mixture model."""

from typing import Protocol

import numpy as np
3 changes: 2 additions & 1 deletion labelshift/datasets/split.py
@@ -1,4 +1,5 @@
"""Utilities for working with NumPy datasets."""

import dataclasses
from typing import List, Protocol

@@ -72,7 +73,7 @@ def split_dataset(

if set(np.unique(dataset.target)) != set(range(n_labels)):
raise ValueError(
f"Labels must be 0-indexed integers: {dataset.target_names} != "
f"Labels must be 0-indexed integers: {dataset.target} != "
f"{set(range(n_labels))}."
)
if {
3 changes: 1 addition & 2 deletions labelshift/experiments/api.py
@@ -1,8 +1,8 @@
"""The experimental utilities."""

from typing import TypeVar, Optional

from labelshift.experiments.timer import Timer
from labelshift.experiments.names import generate_name

_T = TypeVar("_T")

@@ -17,6 +17,5 @@ def calculate_value(*, overwrite: Optional[_T], default: _T) -> _T:
__all__ = [
"Timer",
"calculate_value",
"generate_name",
"calculate_value",
]
12 changes: 0 additions & 12 deletions labelshift/experiments/names.py

This file was deleted.

1 change: 1 addition & 0 deletions labelshift/experiments/timer.py
@@ -1,4 +1,5 @@
"""Creates a Timer class, a convenient thing to measure the elapsed time."""

import time


1 change: 1 addition & 0 deletions labelshift/interfaces/point_estimators.py
@@ -1,5 +1,6 @@
"""Protocols for point estimators for P_test(Y),
which may have access to different data modalities."""

import dataclasses
from typing import Protocol

1 change: 1 addition & 0 deletions labelshift/partition.py
@@ -1,4 +1,5 @@
"""Partition of the real line into intervals."""

from typing import List, Sequence, Tuple

import numpy as np
1 change: 1 addition & 0 deletions labelshift/probability.py
@@ -1,4 +1,5 @@
"""Common NumPy utilities for dealing with probabilities."""

import numpy as np
from numpy.typing import ArrayLike
