From 8bceb5a426ac1df3696e37ce9841b055fda6dfd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Fri, 19 Apr 2024 17:50:16 +0200 Subject: [PATCH 01/14] Improve single-cell plot --- workflows/single_cell.smk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workflows/single_cell.smk b/workflows/single_cell.smk index cca2ac3..9bd019c 100644 --- a/workflows/single_cell.smk +++ b/workflows/single_cell.smk @@ -250,7 +250,7 @@ rule manuscript_plot: def plot(ax, props): x_axis = np.arange(len(props["cell_types"])) - ax.set_xticks(x_axis, labels=props["cell_types"], rotation=70) + ax.set_xticks(x_axis, labels=props["cell_types"], rotation=63) ax.set_ylabel("Proportion") # Ground-truth From 88603eb551c35e157175bdbbccaf7e937cbc9e1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Fri, 19 Apr 2024 17:50:32 +0200 Subject: [PATCH 02/14] Add labels to benchmark subpanels. --- workflows/benchmark.smk | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/workflows/benchmark.smk b/workflows/benchmark.smk index f42faac..391b5e1 100644 --- a/workflows/benchmark.smk +++ b/workflows/benchmark.smk @@ -6,6 +6,7 @@ import joblib import pandas as pd import matplotlib import matplotlib.pyplot as plt +import matplotlib.transforms as mtransforms matplotlib.use("Agg") import numpy as np @@ -349,6 +350,10 @@ rule plot_results_rule: fig.tight_layout() fig.savefig(str(output)) +def label_ax(fig, ax, label): + trans = mtransforms.ScaledTranslation(11/72, -1/72, fig.dpi_scale_trans) + ax.text(0.0, 1.0, f"{label}.", transform=ax.transAxes + trans, fontsize='medium', verticalalignment='top') + rule plot_large_plot: output: "plots/summary-{metric}.pdf" @@ -374,6 +379,7 @@ rule plot_large_plot: data = pd.read_csv(input.prevalence, index_col=False) plot_results(ax, data) ax.set_xlabel("Prevalence $\\pi'_1$") + label_ax(fig, ax, "a") # Unlabeled data set size ax = axs[0, 1] @@ -381,12 +387,15 @@ rule plot_large_plot: plot_results(ax, data) ax.set_xscale("log", base=10) ax.set_xlabel("Unlabeled sample size $N'$") + label_ax(fig, ax, "b") + # Classifier quality ax = axs[0, 2] data = pd.read_csv(input.quality, index_col=False) plot_results(ax, data) ax.set_xlabel("Classifier quality $q$") + label_ax(fig, ax, "c") ax.legend(frameon=False, bbox_to_anchor=(1.05, 1), loc="upper left") @@ -396,6 +405,7 @@ rule plot_large_plot: plot_results(ax, data) ax.set_xlabel("Classifier outputs $K$") ax.set_xticks([3, 5, 7, 9]) + label_ax(fig, ax, "d") # Change L = K ax = axs[1, 1] @@ -403,6 +413,7 @@ rule plot_large_plot: plot_results(ax, data) ax.set_xlabel("Number of labels $L=K$") ax.set_xticks([3, 5, 7, 9]) + label_ax(fig, ax, "e") # Change misspecification ax = axs[1, 2] @@ -410,6 +421,7 @@ rule plot_large_plot: plot_results(ax, data) ax.set_xlabel("Misspecified quality $q'$") ax.axvline(0.85, color="black", linestyle="--") + label_ax(fig, ax, "f") fig.tight_layout() fig.savefig(str(output)) From 542551075670755266cdca4136d7780eba3ea448 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Fri, 19 Apr 2024 19:58:45 +0200 Subject: [PATCH 03/14] Report Rhat for the single-cell experiment --- labelshift/algorithms/bayesian_discrete.py | 13 ++++++++++++- workflows/single_cell.smk | 12 +++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/labelshift/algorithms/bayesian_discrete.py b/labelshift/algorithms/bayesian_discrete.py index 22a1add..ed82fa1 100644 --- a/labelshift/algorithms/bayesian_discrete.py +++ b/labelshift/algorithms/bayesian_discrete.py @@ -17,6 +17,7 @@ class SamplingParams(pydantic.BaseModel): warmup: pydantic.PositiveInt = pydantic.Field(default=500) samples: pydantic.PositiveInt = pydantic.Field(default=1000) + chains: pydantic.PositiveInt = pydantic.Field(default=1) P_TRAIN_Y: str = "P_train(Y)" @@ -60,15 +61,19 @@ def __init__(self, params: Optional[SamplingParams] = None, seed: int = 42) -> N params = SamplingParams() self._params = params self._seed = seed + self._mcmc = None def sample_posterior(self, /, statistic: pe.SummaryStatistic): """Returns the samples from the MCMC sampler.""" mcmc = numpyro.infer.MCMC( numpyro.infer.NUTS(model), num_warmup=self._params.warmup, - num_samples=self._params.samples) + num_samples=self._params.samples, + num_chains=self._params.chains, + ) rng_key = jax.random.PRNGKey(self._seed) mcmc.run(rng_key, summary_statistic=statistic) + self._mcmc = mcmc return mcmc.get_samples() def estimate_from_summary_statistic( @@ -77,3 +82,9 @@ def estimate_from_summary_statistic( """Returns the mean prediction.""" samples = self.sample_posterior(statistic)[P_TEST_Y] return np.array(samples.mean(axis=0)) + + def get_mcmc(self): + """Returns the MCMC object.""" + if self._mcmc is None: + raise ValueError("Run `sample_posterior` to obtain MCMC samples first.") + return self._mcmc diff --git a/workflows/single_cell.smk b/workflows/single_cell.smk index 9bd019c..1cb3462 100644 --- a/workflows/single_cell.smk +++ b/workflows/single_cell.smk @@ -38,7 +38,7 @@ # ISSN 2211-1247, https://doi.org/10.1016/j.celrep.2017.10.030. # # ------------------------------------------------------------------------------------------------------- - +from contextlib import redirect_stdout import numpy as np import pandas as pd import scanpy as sc @@ -292,7 +292,8 @@ rule estimate_proportions: input: "data/{name}_data.h5ad" output: estimates = "proportions/{name}.npz", - error_log = "proportions/{name}.log" + error_log = "proportions/{name}.log", + convergence = "convergence/{name}.txt" run: # Load the data adata = sc.read_h5ad(input[0]) @@ -332,7 +333,12 @@ rule estimate_proportions: n_c_unlabeled=n_c_unlabeled, ) - posterior = algo.DiscreteCategoricalMeanEstimator().sample_posterior(statistic)[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y] + mcmc_algo = algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4)) + posterior = mcmc_algo.sample_posterior(statistic)[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y] + with open(output.convergence, "w") as fh: + with redirect_stdout(fh): + mcmc_algo.get_mcmc().print_summary() + failed_counter = 0 From 7a1e61631d9287dae7e9ca21e87f52da3a2610e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Fri, 19 Apr 2024 20:12:09 +0200 Subject: [PATCH 04/14] Report Rhat for the nearly-nonidentifiable experiment. --- workflows/nearly_nonidentifiable.smk | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/workflows/nearly_nonidentifiable.smk b/workflows/nearly_nonidentifiable.smk index ded9b73..5545cda 100644 --- a/workflows/nearly_nonidentifiable.smk +++ b/workflows/nearly_nonidentifiable.smk @@ -1,6 +1,7 @@ # --------------------------------------------------- # - Experiment with a nearly non-identifiable model - # --------------------------------------------------- +from contextlib import redirect_stdout from dataclasses import dataclass import numpy as np import matplotlib.pyplot as plt @@ -72,12 +73,19 @@ rule generate_data: rule run_mcmc: input: "data/{setting}-{seed}.joblib" - output: "samples/MCMC/{setting}-{seed}.npy" + output: + array = "samples/MCMC/{setting}-{seed}.npy", + convergence = "samples/MCMC/convergence/{setting}-{seed}.txt" run: data = joblib.load(str(input)) - estimator = algo.DiscreteCategoricalMeanEstimator() + + estimator = algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4)) samples = np.asarray(estimator.sample_posterior(data)[estimator.P_TEST_Y]) - np.save(str(output), samples) + with open(output.convergence, "w") as fh: + with redirect_stdout(fh): + estimator.get_mcmc().print_summary() + + np.save(output.array, samples) def _bootstrap(rng, stat: dc.SummaryStatistic) -> dc.SummaryStatistic: From 92cc65d9964b96b1adc7ec3ed3ed6fc22004cc50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 20 Apr 2024 00:10:54 +0200 Subject: [PATCH 05/14] Add prior sensitivity experiment --- labelshift/algorithms/bayesian_discrete.py | 16 +- workflows/prior_sensitivity.py | 258 +++++++++++++++++++++ 2 files changed, 268 insertions(+), 6 deletions(-) create mode 100644 workflows/prior_sensitivity.py diff --git a/labelshift/algorithms/bayesian_discrete.py b/labelshift/algorithms/bayesian_discrete.py index ed82fa1..13df6be 100644 --- a/labelshift/algorithms/bayesian_discrete.py +++ b/labelshift/algorithms/bayesian_discrete.py @@ -26,16 +26,16 @@ class SamplingParams(pydantic.BaseModel): P_C_COND_Y: str = "P(C|Y)" -def model(summary_statistic): +def model(summary_statistic, alpha: float = 1.0): n_y_labeled = summary_statistic.n_y_labeled n_y_and_c_labeled = summary_statistic.n_y_and_c_labeled n_c_unlabeled = summary_statistic.n_c_unlabeled K = len(n_c_unlabeled) L = len(n_y_labeled) - pi = numpyro.sample(P_TRAIN_Y, dist.Dirichlet(jnp.ones(L))) - pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(jnp.ones(L))) - p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(jnp.ones(K).repeat(L).reshape(L, K))) + pi = numpyro.sample(P_TRAIN_Y, dist.Dirichlet(alpha * jnp.ones(L))) + pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(alpha * jnp.ones(L))) + p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(alpha * jnp.ones(K).repeat(L).reshape(L, K))) N_y = numpyro.sample('N_y', dist.Multinomial(jnp.sum(n_y_labeled), pi), obs=n_y_labeled) @@ -56,13 +56,17 @@ class DiscreteCategoricalMeanEstimator(pe.SummaryStatisticPrevalenceEstimator): P_TEST_C = P_TEST_C P_C_COND_Y = P_C_COND_Y - def __init__(self, params: Optional[SamplingParams] = None, seed: int = 42) -> None: + def __init__(self, params: Optional[SamplingParams] = None, seed: int = 42, alpha: float = 1.0) -> None: if params is None: params = SamplingParams() self._params = params self._seed = seed self._mcmc = None + if alpha <= 0: + raise ValueError("Concentration parameter alpha has to be positive.") + self._alpha = alpha + def sample_posterior(self, /, statistic: pe.SummaryStatistic): """Returns the samples from the MCMC sampler.""" mcmc = numpyro.infer.MCMC( @@ -72,7 +76,7 @@ def sample_posterior(self, /, statistic: pe.SummaryStatistic): num_chains=self._params.chains, ) rng_key = jax.random.PRNGKey(self._seed) - mcmc.run(rng_key, summary_statistic=statistic) + mcmc.run(rng_key, summary_statistic=statistic, alpha=self._alpha) self._mcmc = mcmc return mcmc.get_samples() diff --git a/workflows/prior_sensitivity.py b/workflows/prior_sensitivity.py new file mode 100644 index 0000000..7ae90fe --- /dev/null +++ b/workflows/prior_sensitivity.py @@ -0,0 +1,258 @@ +# -------------------------------------------------------------------- +# --- Prior sensitivity check for binary quantification problems --- +# -------------------------------------------------------------------- +import matplotlib +import matplotlib.pyplot as plt +matplotlib.use("Agg") + +import json +from contextlib import redirect_stdout +from dataclasses import dataclass +import joblib +import numpy as np +import pandas as pd + + +import labelshift.algorithms.api as algo +import labelshift.experiments.api as exp +import labelshift.datasets.discrete_categorical as dc + +workdir: "generated/prior_sensitivity" + + + +@dataclass +class DataSetting: + scalar_p_y_labeled: float + scalar_p_y_unlabeled: float + + quality_labeled: float + quality_unlabeled: float + + n_y: int + n_c: int + + n_labeled: int + n_unlabeled: int + + @property + def p_y_labeled(self) -> np.ndarray: + return dc.almost_eye(self.n_y, self.n_y, diagonal=self.scalar_p_y_labeled)[0, :] + + @property + def p_y_unlabeled(self) -> np.ndarray: + return dc.almost_eye(self.n_y, self.n_y, diagonal=self.scalar_p_y_unlabeled)[0, :] + + @property + def p_c_cond_y_labeled(self) -> np.ndarray: + return dc.almost_eye( + y=self.n_y, + c=self.n_c, + diagonal=self.quality_labeled, + ) + + @property + def p_c_cond_y_unlabeled(self) -> np.ndarray: + return dc.almost_eye( + y=self.n_y, + c=self.n_c, + diagonal=self.quality_unlabeled, + ) + + def create_sampler(self) -> dc.DiscreteSampler: + return dc.discrete_sampler_factory( + p_y_labeled=self.p_y_labeled, + p_y_unlabeled=self.p_y_unlabeled, + p_c_cond_y_labeled=self.p_c_cond_y_labeled, + p_c_cond_y_unlabeled=self.p_c_cond_y_unlabeled, + ) + + +def generate_data_setting( + n_labeled: int = 1000, + n_unlabeled: int = 500, + quality: float = 0.85, + quality_unlabeled: float | None = None, + L: int = 5, + K: int | None = None, + prevalence_labeled: float | None = None, + prevalence_unlabeled: float | None = 0.7, +) -> DataSetting: + n_y = L + n_c = exp.calculate_value(overwrite=K, default=n_y) + + quality_unlabeled = exp.calculate_value( + overwrite=quality_unlabeled, default=quality + ) + + p_y_labeled = exp.calculate_value( + overwrite=prevalence_labeled, default=1 / n_y + ) + p_y_unlabeled = exp.calculate_value( + overwrite=prevalence_unlabeled, default=1 / n_y + ) + + return DataSetting( + scalar_p_y_labeled=p_y_labeled, + scalar_p_y_unlabeled=p_y_unlabeled, + quality_labeled=quality, + quality_unlabeled=quality_unlabeled, + n_y=n_y, + n_c=n_c, + n_labeled=n_labeled, + n_unlabeled=n_unlabeled, + ) + +ALPHA_SMALL = 0.1 +ALPHA_MEDIUM = 1.0 +ALPHA_LARGE = 10.0 + + +MODELS = { + str(ALPHA_SMALL): algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4), alpha=ALPHA_SMALL), + str(ALPHA_MEDIUM): algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4), alpha=ALPHA_MEDIUM), + str(ALPHA_LARGE): algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4), alpha=ALPHA_LARGE), +} +COLORS = { + str(ALPHA_SMALL): "darkblue", + str(ALPHA_MEDIUM): "purple", + str(ALPHA_LARGE): "goldenrod", +} + +N_SMALL = 50 +N_MEDIUM = 500 +N_LARGE = 5_000 + +DATA_SETTINGS = { + str(N_SMALL): generate_data_setting(n_labeled=N_SMALL, n_unlabeled=N_SMALL, L=2, K=2), + str(N_MEDIUM): generate_data_setting(n_labeled=N_MEDIUM, n_unlabeled=N_MEDIUM, L=2, K=2), + str(N_LARGE): generate_data_setting(n_labeled=N_LARGE, n_unlabeled=N_LARGE, L=2, K=2), +} + + +def get_data_setting(data_setting: str) -> DataSetting: + return DATA_SETTINGS[data_setting] + + +rule all: + input: "prior_sensitivity.pdf", "convergence_stats.json" + + +rule plot: + input: + expand("posterior_samples/{data_setting}/model-{model}/1.joblib", data_setting=DATA_SETTINGS.keys(), model=MODELS.keys()) + output: "prior_sensitivity.pdf" + run: + data_sets = {} + for path in input: + samples = joblib.load(path) + setting = samples["data_setting"] + model = samples["model"] + data_sets[(setting, model)] = samples[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y][:, 0] + + fig, axs = plt.subplots(1, 3, sharex=True, sharey=False, figsize=(6, 1.8), dpi=350) + + for ax in axs: + ax.spines[["top", "left", "right"]].set_visible(False) + ax.set_yticks([]) + ax.set_xlabel("$\\pi_1'$") + ax.axvline(0.7, linestyle="--", linewidth=1, c="k", label="$\\pi_1^*$") + + bins = np.linspace(0, 1, 30) + + def plot_posterior(ax, samples, color, label=None): + ax.hist(samples, bins=bins, histtype="step", color=color, alpha=0.8) + ax.axvline(np.mean(samples), color=color, linewidth=1, label=label) + + for data_setting, ax in zip(DATA_SETTINGS.keys(), axs): + ax.set_title(f"$N=N'={data_setting}$") + + for model in MODELS.keys(): + samples = data_sets[(data_setting, model)] + plot_posterior(ax, samples, color=COLORS[model], label=f"$\\alpha={model}$") + + ax = axs.ravel()[-1] + ax.legend(frameon=False, bbox_to_anchor=(1.05, 1.)) + + fig.tight_layout() + fig.savefig(str(output)) + +rule generate_data: + output: "data/{data_setting}/{seed}.joblib" + run: + data_setting = get_data_setting(wildcards.data_setting) + sampler = data_setting.create_sampler() + + summary_statistic = sampler.sample_summary_statistic( + n_labeled=data_setting.n_labeled, + n_unlabeled=data_setting.n_unlabeled, + seed=int(wildcards.seed), + ) + joblib.dump(summary_statistic, str(output)) + + +rule apply_estimator: + input: "data/{data_setting}/{seed}.joblib" + output: + posterior_samples = "posterior_samples/{data_setting}/model-{model}/{seed}.joblib", + convergence = "convergence/{data_setting}/model-{model}/{seed}.txt" + run: + data = joblib.load(str(input)) + estimator = MODELS[wildcards.model] + + posterior_samples = estimator.sample_posterior(data) + posterior_samples["data_setting"] = wildcards.data_setting + posterior_samples["model"] = wildcards.model + + joblib.dump(posterior_samples, filename=output.posterior_samples) + + with open(output.convergence, "w") as fh: + with redirect_stdout(fh): + estimator.get_mcmc().print_summary() + + +def parse_text_to_dataframe(file_path): + # Read the entire file into a list of lines + with open(file_path) as file: + lines = file.readlines() + + # Find the start of the actual data (ignoring initial empty lines and headers) + start_index = 0 + while not lines[start_index].strip(): # This finds the first non-empty line + start_index += 1 + + # We assume the table ends where non-table data starts again, typically after an empty line + end_index = start_index + while end_index < len(lines) and lines[end_index].strip(): + end_index += 1 + + # Now extract only the relevant lines + data_lines = lines[start_index:end_index] + + # Use pandas to read these lines, considering whitespace as a separator + from io import StringIO + data_str = '\n'.join(data_lines) + dataframe = pd.read_csv(StringIO(data_str), sep=r'\s+', engine='python') + + return dataframe + +rule parse_convergence_txt_to_csv: + input: "convergence/{data_setting}/model-{model}/{seed}.txt" + output: "convergence-csv/{data_setting}/model-{model}/{seed}.csv" + run: + parse_text_to_dataframe(str(input)).to_csv(str(output), index=False) + +rule get_convergence_stats: + input: expand("convergence-csv/{data_setting}/model-{model}/1.csv", data_setting=DATA_SETTINGS, model=MODELS) + output: "convergence_stats.json" + run: + min_n_eff = 1e12 + max_r_hat = -100 + + for pth in input: + df = pd.read_csv(pth) + min_n_eff = min(min_n_eff, df["n_eff"].values.min()) + max_r_hat = max(max_r_hat, df["r_hat"].values.max()) + + with open(str(output), "w") as fp: + json.dump(obj={"r_hat": max_r_hat, "n_eff": min_n_eff}, fp=fp) From 7e327a6df0fe3cd60f05d7b032c7140006cd6edb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 10:33:42 +0200 Subject: [PATCH 06/14] Improve DPI --- workflows/nearly_nonidentifiable.smk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workflows/nearly_nonidentifiable.smk b/workflows/nearly_nonidentifiable.smk index 5545cda..e8b70c7 100644 --- a/workflows/nearly_nonidentifiable.smk +++ b/workflows/nearly_nonidentifiable.smk @@ -147,7 +147,7 @@ rule plot: bbs = "samples/bootstrap-BBS/{setting}-{seed}.npy" output: "figures/{setting}-{seed}.pdf" run: - fig, axs = plt.subplots(1, 4, figsize=(8, 2), dpi=130, sharex=True, sharey=True) + fig, axs = plt.subplots(1, 4, figsize=(8, 2), dpi=400, sharex=True, sharey=True) ax = axs[0] ax.set_ylabel("Prevalence") From 3121707f2f66d80578a9dd769c3b076dd434905a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 10:43:27 +0200 Subject: [PATCH 07/14] Adjust build and README --- .github/workflows/build.yml | 5 +---- README.md | 40 ++++++++++++++++++------------------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b605277..84f2398 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.11"] steps: - uses: actions/checkout@v2 @@ -30,9 +30,6 @@ jobs: - name: Check code with flake8 run: | flake8 - - name: Check docstring coverage with interrogate - run: | - interrogate - name: Check whether black has been used run: | black --check tests diff --git a/README.md b/README.md index 6a3a324..2c35280 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,30 @@ -# Label Shift +[![Project Status: Concept – Minimal or no implementation has been done yet, or the repository is only intended to be a limited example, demo, or proof-of-concept.](https://www.repostatus.org/badges/latest/concept.svg)](https://www.repostatus.org/#concept) +[![Venue](https://img.shields.io/badge/venue-TMLR_2024-darkblue)](https://openreview.net/forum?id=Ft4kHrOawZ) +[![build](https://github.com/pawel-czyz/labelshift/actions/workflows/build.yml/badge.svg)](https://github.com/pawel-czyz/labelshift/actions/workflows/build.yml) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -Python library for *quantification* (estimating the class prevalence in an unlabeled data set) under the prior probability shift assumption. +# Bayesian quantification with black-box estimators -This module is created with two purposes in mind: - - easily apply state-of-the-art quantification algorithms to the real problems, - - benchmark novel quantification algorithms against others. +*Quantification* is the problem of estimating the label prevalence from an unlabeled data set. In this repository we provide the code associated with our manuscript, which can be used to reproduce the experiments. -It is compatible with any classifier using any machine learning framework. +## Installation -The code inside was used to run the experiments in [our preprint](https://arxiv.org/abs/2302.09159), which can be cited as: -``` -@misc{https://doi.org/10.48550/arxiv.2302.09159, - doi = {10.48550/ARXIV.2302.09159}, - url = {https://arxiv.org/abs/2302.09159}, - author = {Ziegler, Albert and Czyż, Paweł}, - title = {Bayesian Quantification with Black-Box Estimators}, - publisher = {arXiv}, - year = {2023} -} +We recommend using [Micromamba](https://mamba.readthedocs.io/en/latest/user_guide/micromamba.html) to set a new Python 3.11 environment. +Then, the package can be installed with: + +```bash +$ pip install -e . ``` -## Installation -Currently the module is in early development stage and is not ready to be installed. It does not have proper documentation either. We hope to change it soon – thank you for your patience! +To reproduce the experiments, install [Snakemake](https://snakemake.readthedocs.io/en/stable/) using the instructions provided. Then, install additional dependencies: -## Contributions -Contributions are very welcome! Please, check our [Contribution guide](CONTRIBUTING.md). +```bash +$ pip install -r requirements.txt +``` +The experiments can be reproduced by running: +```bash +$ snakemake -c4 -s workflows/WORKFLOW_NAME.smk +``` From 515bf3e5d1c76b471d07e5b4a45952f369d1d69a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 10:47:10 +0200 Subject: [PATCH 08/14] Check code --- labelshift/adjustments.py | 1 + labelshift/algorithms/api.py | 6 +++- labelshift/algorithms/bayesian_discrete.py | 29 +++++++++++++------ labelshift/algorithms/bbse.py | 1 + labelshift/algorithms/classify_and_count.py | 1 + .../algorithms/expectation_maximization.py | 1 + labelshift/algorithms/ratio_estimator.py | 1 + labelshift/algorithms/validate.py | 1 + labelshift/datasets/discrete_categorical.py | 4 ++- labelshift/datasets/gaussian_mixture.py | 1 + labelshift/datasets/split.py | 1 + labelshift/experiments/api.py | 1 + labelshift/experiments/names.py | 1 + labelshift/experiments/timer.py | 1 + labelshift/interfaces/point_estimators.py | 1 + labelshift/partition.py | 1 + labelshift/probability.py | 1 + labelshift/recalibrate.py | 1 + labelshift/scoring.py | 1 + labelshift/summary_statistic.py | 1 + ...r_sensitivity.py => prior_sensitivity.smk} | 0 21 files changed, 45 insertions(+), 11 deletions(-) rename workflows/{prior_sensitivity.py => prior_sensitivity.smk} (100%) diff --git a/labelshift/adjustments.py b/labelshift/adjustments.py index 2c3922f..5693d9e 100755 --- a/labelshift/adjustments.py +++ b/labelshift/adjustments.py @@ -1,4 +1,5 @@ """Predictions adjustments.""" + import numpy as np from numpy.typing import ArrayLike diff --git a/labelshift/algorithms/api.py b/labelshift/algorithms/api.py index c55e6ae..11ee80d 100644 --- a/labelshift/algorithms/api.py +++ b/labelshift/algorithms/api.py @@ -4,7 +4,11 @@ >>> import labelshift.algorithms.api as algo """ -from labelshift.algorithms.bayesian_discrete import DiscreteCategoricalMeanEstimator, SamplingParams + +from labelshift.algorithms.bayesian_discrete import ( + DiscreteCategoricalMeanEstimator, + SamplingParams, +) from labelshift.algorithms.bbse import BlackBoxShiftEstimator from labelshift.algorithms.classify_and_count import ClassifyAndCount from labelshift.algorithms.ratio_estimator import InvariantRatioEstimator diff --git a/labelshift/algorithms/bayesian_discrete.py b/labelshift/algorithms/bayesian_discrete.py index 13df6be..6d5ca6c 100644 --- a/labelshift/algorithms/bayesian_discrete.py +++ b/labelshift/algorithms/bayesian_discrete.py @@ -1,4 +1,5 @@ """Categorical discrete Bayesian model for quantification.""" + import numpy as np import jax.numpy as jnp import numpyro @@ -7,8 +8,6 @@ import pydantic from typing import Optional -from numpy.typing import ArrayLike - import labelshift.interfaces.point_estimators as pe @@ -35,15 +34,21 @@ def model(summary_statistic, alpha: float = 1.0): pi = numpyro.sample(P_TRAIN_Y, dist.Dirichlet(alpha * jnp.ones(L))) pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(alpha * jnp.ones(L))) - p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(alpha * jnp.ones(K).repeat(L).reshape(L, K))) + p_c_cond_y = numpyro.sample( + P_C_COND_Y, dist.Dirichlet(alpha * jnp.ones(K).repeat(L).reshape(L, K)) + ) - N_y = numpyro.sample('N_y', dist.Multinomial(jnp.sum(n_y_labeled), pi), obs=n_y_labeled) - - with numpyro.plate('plate', L): - numpyro.sample('F_yc', dist.Multinomial(N_y, p_c_cond_y), obs=n_y_and_c_labeled) + N_y = numpyro.sample( + "N_y", dist.Multinomial(jnp.sum(n_y_labeled), pi), obs=n_y_labeled + ) + + with numpyro.plate("plate", L): + numpyro.sample("F_yc", dist.Multinomial(N_y, p_c_cond_y), obs=n_y_and_c_labeled) p_c = numpyro.deterministic(P_TEST_C, jnp.einsum("yc,y->c", p_c_cond_y, pi_)) - numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled) + numpyro.sample( + "N_c", dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled + ) class DiscreteCategoricalMeanEstimator(pe.SummaryStatisticPrevalenceEstimator): @@ -51,12 +56,18 @@ class DiscreteCategoricalMeanEstimator(pe.SummaryStatisticPrevalenceEstimator): Note that it runs the MCMC sampler in the backend. """ + P_TRAIN_Y = P_TRAIN_Y P_TEST_Y = P_TEST_Y P_TEST_C = P_TEST_C P_C_COND_Y = P_C_COND_Y - def __init__(self, params: Optional[SamplingParams] = None, seed: int = 42, alpha: float = 1.0) -> None: + def __init__( + self, + params: Optional[SamplingParams] = None, + seed: int = 42, + alpha: float = 1.0, + ) -> None: if params is None: params = SamplingParams() self._params = params diff --git a/labelshift/algorithms/bbse.py b/labelshift/algorithms/bbse.py index 7a04068..a0c57d7 100644 --- a/labelshift/algorithms/bbse.py +++ b/labelshift/algorithms/bbse.py @@ -5,6 +5,7 @@ Detecting and Correcting for Label Shift with Black Box Predictors https://arxiv.org/pdf/1802.03916.pdf """ + from typing import Optional import numpy as np diff --git a/labelshift/algorithms/classify_and_count.py b/labelshift/algorithms/classify_and_count.py index d4de302..a3fa600 100644 --- a/labelshift/algorithms/classify_and_count.py +++ b/labelshift/algorithms/classify_and_count.py @@ -1,4 +1,5 @@ """Classify and Count algorithm.""" + import numpy as np from numpy.typing import ArrayLike diff --git a/labelshift/algorithms/expectation_maximization.py b/labelshift/algorithms/expectation_maximization.py index 6286b05..1295d59 100644 --- a/labelshift/algorithms/expectation_maximization.py +++ b/labelshift/algorithms/expectation_maximization.py @@ -1,4 +1,5 @@ """Expectation Maximization algorithm.""" + import warnings from typing import Optional import numpy as np diff --git a/labelshift/algorithms/ratio_estimator.py b/labelshift/algorithms/ratio_estimator.py index b77d6dc..2c3200a 100644 --- a/labelshift/algorithms/ratio_estimator.py +++ b/labelshift/algorithms/ratio_estimator.py @@ -41,6 +41,7 @@ ``H_hat[l, k] = G_hat[l, k] = E_labeled[ g(X)[k] | Y = l] \\in R^{L x (K-1)}.`` """ + from typing import Optional, Tuple import numpy as np diff --git a/labelshift/algorithms/validate.py b/labelshift/algorithms/validate.py index bf5ecce..fc0fb27 100755 --- a/labelshift/algorithms/validate.py +++ b/labelshift/algorithms/validate.py @@ -1,4 +1,5 @@ """Preprocessing and validation methods.""" + from typing import Tuple import numpy as np from numpy.typing import ArrayLike diff --git a/labelshift/datasets/discrete_categorical.py b/labelshift/datasets/discrete_categorical.py index 30ce5f5..d1917b8 100644 --- a/labelshift/datasets/discrete_categorical.py +++ b/labelshift/datasets/discrete_categorical.py @@ -1,4 +1,5 @@ """Discrete categorical sampler.""" + import dataclasses import math from typing import Tuple, Any, Union, Optional @@ -14,12 +15,13 @@ @dataclasses.dataclass class SummaryMultinomialStatistic: """ - + Attributes: n_y: shape (L,) n_c: shape (K,) n_y_and_c: shape (L, K) """ + n_y: np.ndarray n_c: np.ndarray n_y_and_c: np.ndarray diff --git a/labelshift/datasets/gaussian_mixture.py b/labelshift/datasets/gaussian_mixture.py index da9fadf..7278fdb 100644 --- a/labelshift/datasets/gaussian_mixture.py +++ b/labelshift/datasets/gaussian_mixture.py @@ -1,5 +1,6 @@ """Model used for working with exact probabilities in the Gaussian mixture model.""" + from typing import Protocol import numpy as np diff --git a/labelshift/datasets/split.py b/labelshift/datasets/split.py index ed61519..3bc33d9 100644 --- a/labelshift/datasets/split.py +++ b/labelshift/datasets/split.py @@ -1,4 +1,5 @@ """Utilities for working with NumPy datasets.""" + import dataclasses from typing import List, Protocol diff --git a/labelshift/experiments/api.py b/labelshift/experiments/api.py index 4dab1ad..b6e6fb1 100644 --- a/labelshift/experiments/api.py +++ b/labelshift/experiments/api.py @@ -1,4 +1,5 @@ """The experimental utilities.""" + from typing import TypeVar, Optional from labelshift.experiments.timer import Timer diff --git a/labelshift/experiments/names.py b/labelshift/experiments/names.py index 2d24ae6..432c0c9 100644 --- a/labelshift/experiments/names.py +++ b/labelshift/experiments/names.py @@ -1,4 +1,5 @@ """Utilities for dealing with filesystem IO.""" + import petname from datetime import datetime diff --git a/labelshift/experiments/timer.py b/labelshift/experiments/timer.py index 3cfe69f..e86f68c 100644 --- a/labelshift/experiments/timer.py +++ b/labelshift/experiments/timer.py @@ -1,4 +1,5 @@ """Creates a Timer class, a convenient thing to measure the elapsed time.""" + import time diff --git a/labelshift/interfaces/point_estimators.py b/labelshift/interfaces/point_estimators.py index c6ca08d..65d31ea 100644 --- a/labelshift/interfaces/point_estimators.py +++ b/labelshift/interfaces/point_estimators.py @@ -1,5 +1,6 @@ """Protocols for point estimators for P_test(Y), which may have access to different data modalities.""" + import dataclasses from typing import Protocol diff --git a/labelshift/partition.py b/labelshift/partition.py index 38cb61a..438f2c9 100644 --- a/labelshift/partition.py +++ b/labelshift/partition.py @@ -1,4 +1,5 @@ """Partition of the real line into intervals.""" + from typing import List, Sequence, Tuple import numpy as np diff --git a/labelshift/probability.py b/labelshift/probability.py index 014e660..5e7c4f5 100644 --- a/labelshift/probability.py +++ b/labelshift/probability.py @@ -1,4 +1,5 @@ """Common NumPy utilities for dealing with probabilities.""" + import numpy as np from numpy.typing import ArrayLike diff --git a/labelshift/recalibrate.py b/labelshift/recalibrate.py index d057691..c47446e 100644 --- a/labelshift/recalibrate.py +++ b/labelshift/recalibrate.py @@ -1,4 +1,5 @@ """Recalibration utilities under the prior probability shift assumption.""" + import numpy as np from numpy.typing import ArrayLike diff --git a/labelshift/scoring.py b/labelshift/scoring.py index 6129120..f43a625 100644 --- a/labelshift/scoring.py +++ b/labelshift/scoring.py @@ -5,6 +5,7 @@ A Review on Quantification Learning, ACM Computing Surveys, Vol. 50, No. 5. DOI: https://dl.acm.org/doi/10.1145/3117807 """ + from typing import cast, Protocol import numpy as np diff --git a/labelshift/summary_statistic.py b/labelshift/summary_statistic.py index 1613d59..2fe391d 100644 --- a/labelshift/summary_statistic.py +++ b/labelshift/summary_statistic.py @@ -1,4 +1,5 @@ """Used to calculate summary statistic of discrete data.""" + from typing import Sequence import numpy as np diff --git a/workflows/prior_sensitivity.py b/workflows/prior_sensitivity.smk similarity index 100% rename from workflows/prior_sensitivity.py rename to workflows/prior_sensitivity.smk From e5df39b25afef019e7ba3f31802173201e96b7da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 10:49:38 +0200 Subject: [PATCH 09/14] Run code formatter or --- tests/algorithms/test_bayesian_discrete.py | 1 + tests/algorithms/test_bbse.py | 1 + tests/algorithms/test_classify_and_count.py | 1 + tests/algorithms/test_expectation_maximization.py | 1 + tests/algorithms/test_ratio_estimator.py | 1 + tests/conftest.py | 1 + tests/datasets/test_discrete_categorical.py | 1 + tests/datasets/test_gaussian_mixture.py | 1 + tests/datasets/test_split.py | 1 + tests/test_adjustments.py | 1 + tests/test_partition.py | 1 + tests/test_probability.py | 1 + tests/test_recalibrate.py | 1 + tests/test_scoring.py | 1 + tests/test_summary_statistic.py | 1 + tests/test_timer.py | 1 + 16 files changed, 16 insertions(+) diff --git a/tests/algorithms/test_bayesian_discrete.py b/tests/algorithms/test_bayesian_discrete.py index e9e23ac..ab94e01 100644 --- a/tests/algorithms/test_bayesian_discrete.py +++ b/tests/algorithms/test_bayesian_discrete.py @@ -1,4 +1,5 @@ """Tests for the labelshift/algorithms/bayesian_discrete.py""" + import numpy as np import pytest diff --git a/tests/algorithms/test_bbse.py b/tests/algorithms/test_bbse.py index 0c9e507..654d128 100644 --- a/tests/algorithms/test_bbse.py +++ b/tests/algorithms/test_bbse.py @@ -1,4 +1,5 @@ """Tests for the BBSE submodule.""" + import pytest import numpy as np diff --git a/tests/algorithms/test_classify_and_count.py b/tests/algorithms/test_classify_and_count.py index c8e77f8..a7b63bd 100644 --- a/tests/algorithms/test_classify_and_count.py +++ b/tests/algorithms/test_classify_and_count.py @@ -1,4 +1,5 @@ """Classify and count algorithm.""" + import numpy as np import numpy.testing as nptest diff --git a/tests/algorithms/test_expectation_maximization.py b/tests/algorithms/test_expectation_maximization.py index e973d14..e274e7f 100644 --- a/tests/algorithms/test_expectation_maximization.py +++ b/tests/algorithms/test_expectation_maximization.py @@ -1,4 +1,5 @@ """Tests for Expectation Maximization.""" + import numpy as np import numpy.testing as nptest import pytest diff --git a/tests/algorithms/test_ratio_estimator.py b/tests/algorithms/test_ratio_estimator.py index a13a67a..4e7a90f 100644 --- a/tests/algorithms/test_ratio_estimator.py +++ b/tests/algorithms/test_ratio_estimator.py @@ -1,4 +1,5 @@ """Tests of the Invariant Ratio Estimator algorithm.""" + import numpy as np import pytest diff --git a/tests/conftest.py b/tests/conftest.py index 8a21395..badc486 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ """Fixtures for pytests.""" + import random from typing import Sequence diff --git a/tests/datasets/test_discrete_categorical.py b/tests/datasets/test_discrete_categorical.py index eb1f5d9..69d5cb8 100644 --- a/tests/datasets/test_discrete_categorical.py +++ b/tests/datasets/test_discrete_categorical.py @@ -1,4 +1,5 @@ """Tests for the discrete categorical sampler.""" + from typing import Sequence import numpy as np diff --git a/tests/datasets/test_gaussian_mixture.py b/tests/datasets/test_gaussian_mixture.py index 676a9af..a90f754 100644 --- a/tests/datasets/test_gaussian_mixture.py +++ b/tests/datasets/test_gaussian_mixture.py @@ -1,4 +1,5 @@ """Tests of the Gaussian mixture module.""" + import numpy as np import pytest diff --git a/tests/datasets/test_split.py b/tests/datasets/test_split.py index 18ac1fb..f5e6279 100644 --- a/tests/datasets/test_split.py +++ b/tests/datasets/test_split.py @@ -1,4 +1,5 @@ """Tests of `labelshift.datasets.split`.""" + import numpy as np import pytest from sklearn import datasets diff --git a/tests/test_adjustments.py b/tests/test_adjustments.py index 0446fa4..5585f6a 100644 --- a/tests/test_adjustments.py +++ b/tests/test_adjustments.py @@ -1,4 +1,5 @@ """Tests for adjustment submodule.""" + import numpy as np import numpy.testing as nptest import pytest diff --git a/tests/test_partition.py b/tests/test_partition.py index 6d44f44..8d9b62c 100644 --- a/tests/test_partition.py +++ b/tests/test_partition.py @@ -1,4 +1,5 @@ """Tests of the `partition` submodule.""" + import numpy as np import pytest from scipy import stats diff --git a/tests/test_probability.py b/tests/test_probability.py index a0f5d21..5875920 100644 --- a/tests/test_probability.py +++ b/tests/test_probability.py @@ -1,4 +1,5 @@ """Tests for the auxilary probability submodule.""" + import numpy.testing as nptest import pytest diff --git a/tests/test_recalibrate.py b/tests/test_recalibrate.py index 20c5bbe..b5c7239 100644 --- a/tests/test_recalibrate.py +++ b/tests/test_recalibrate.py @@ -1,4 +1,5 @@ """Tests for recalibration.""" + from typing import Tuple import numpy as np import numpy.testing as nptest diff --git a/tests/test_scoring.py b/tests/test_scoring.py index 0bc50b4..14a3c64 100644 --- a/tests/test_scoring.py +++ b/tests/test_scoring.py @@ -1,4 +1,5 @@ """Tests of the scoring submodule.""" + from typing import List import numpy as np diff --git a/tests/test_summary_statistic.py b/tests/test_summary_statistic.py index d7fb6e9..8f57c7d 100644 --- a/tests/test_summary_statistic.py +++ b/tests/test_summary_statistic.py @@ -1,4 +1,5 @@ """Tests for the module calculating the summary statistic in the discrete case.""" + import numpy as np import pytest diff --git a/tests/test_timer.py b/tests/test_timer.py index b572143..6c29d74 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -1,4 +1,5 @@ """Tests of the `timer` submodule.""" + import time import pytest From 4e68c447e65aae7038ecbf0adc8acb1b41533a8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 10:54:25 +0200 Subject: [PATCH 10/14] Fix type error --- labelshift/datasets/split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/labelshift/datasets/split.py b/labelshift/datasets/split.py index 3bc33d9..f2f64cc 100644 --- a/labelshift/datasets/split.py +++ b/labelshift/datasets/split.py @@ -73,7 +73,7 @@ def split_dataset( if set(np.unique(dataset.target)) != set(range(n_labels)): raise ValueError( - f"Labels must be 0-indexed integers: {dataset.target_names} != " + f"Labels must be 0-indexed integers: {dataset.target} != " f"{set(range(n_labels))}." ) if { From 2d41ed535202492fff9a4c44596fff734c6a7500 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 10:56:05 +0200 Subject: [PATCH 11/14] Removing contribution guide. --- CONTRIBUTING.md | 44 -------------------------------------------- 1 file changed, 44 deletions(-) delete mode 100644 CONTRIBUTING.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index 947c94c..0000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,44 +0,0 @@ -# Contribution guide -Thank you for your time! - -## Reporting a bug - -If you find a bug, please [submit a new issue](https://github.com/labelshift/labelshift/issues). - -To be able to reproduce a bug, we will usually need the following information: - - - Versions of Python packages used (in particular version of this library). - - A minimal code snippet allowing us to reproduce the bug. - - What is the desired behaviour in the reported case? - - What is the actual behaviour? - - -## Submitting a pull request - -**Do:** - - - Do use [Google Style Guide](https://google.github.io/styleguide/pyguide.html). We use [black](https://github.com/psf/black) for code formatting. - - Do write unit tests – 100% code coverage is a necessity. We use [pytest](https://docs.pytest.org/). - - Do write docstrings – 100% coverage is a necessity. We use [interrogate](https://pypi.org/project/interrogate/). - - Do write high-level documentation as examples and tutorials, illustrating introduced features. - - Do consider submitting a *draft* pull request with a description of proposed changes. - - Do check the [Development section](#development). - -**Don't:** - - - Don't include license information. This project is BSD-3 licensed and by submitting your pull request you implicitly and irrevocably agree to use this. - - Don't implement too many ideas in a single pull request. Multiple features should be implemented in separate pull requests. - -## Development -To install the repository in editable mode use: -``` -pip install -r requirements.txt # Install dev requirements -pip install -e . # Install the module in editable mode -pre-commit install # Install pre-commit hooks -``` -We suggest using a virtual environment for this. - -You can use `make` to run the required code quality checks. - - -Thank you a lot! From f6970d3fa398d319c09fc688db7fb4a44420df86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 11:01:12 +0200 Subject: [PATCH 12/14] Fix unit test --- tests/algorithms/test_bayesian_discrete.py | 23 +++++++--------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/tests/algorithms/test_bayesian_discrete.py b/tests/algorithms/test_bayesian_discrete.py index ab94e01..76e02a4 100644 --- a/tests/algorithms/test_bayesian_discrete.py +++ b/tests/algorithms/test_bayesian_discrete.py @@ -29,22 +29,13 @@ def test_right_values(n_labeled: int = 10_000, n_unlabeled: int = 10_000) -> Non n_labeled=n_labeled, n_unlabeled=n_unlabeled, seed=111 ) - params = bd.SamplingParams(chains=1, draws=100) + params = bd.SamplingParams(chains=1, warmup=100, draws=100) + estimator = bd.DiscreteCategoricalMeanEstimator(params) - model = bd.build_model( - n_y_and_c_labeled=statistic.n_y_and_c_labeled, - n_c_unlabeled=statistic.n_c_unlabeled, - ) - - inference_result = bd.sample_from_bayesian_discrete_model_posterior( - model=model, - sampling_params=params, - ) + estimator.estimate_from_summary_statistic(statistic) - def get_mean(key) -> np.ndarray: - """Returns the mean over all samples for variable `key`.""" - return np.asarray(inference_result.posterior.data_vars[key].mean(axis=(0, 1))) + samples = estimator.get_mcmc().get_samples() - assert get_mean(bd.P_TEST_Y) == pytest.approx(p_y_unlabeled, abs=0.02) - assert get_mean(bd.P_TRAIN_Y) == pytest.approx(p_y_labeled, abs=0.02) - assert get_mean(bd.P_C_COND_Y) == pytest.approx(p_c_cond_y, abs=0.02) + assert samples[bd.P_TEST_Y].mean(axis=0) == pytest.approx(p_y_unlabeled, abs=0.02) + assert samples[bd.P_TRAIN_Y].mean(axis=0) == pytest.approx(p_y_labeled, abs=0.02) + assert samples[bd.P_C_COND_Y].mean(axis=0) == pytest.approx(p_c_cond_y, abs=0.02) From bde18d8ada3a60fac75754fbbe9295bd6801afd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 11:15:52 +0200 Subject: [PATCH 13/14] Remove redundant utility. --- labelshift/experiments/api.py | 2 -- labelshift/experiments/names.py | 13 ------------- 2 files changed, 15 deletions(-) delete mode 100644 labelshift/experiments/names.py diff --git a/labelshift/experiments/api.py b/labelshift/experiments/api.py index b6e6fb1..f9441f5 100644 --- a/labelshift/experiments/api.py +++ b/labelshift/experiments/api.py @@ -3,7 +3,6 @@ from typing import TypeVar, Optional from labelshift.experiments.timer import Timer -from labelshift.experiments.names import generate_name _T = TypeVar("_T") @@ -18,6 +17,5 @@ def calculate_value(*, overwrite: Optional[_T], default: _T) -> _T: __all__ = [ "Timer", "calculate_value", - "generate_name", "calculate_value", ] diff --git a/labelshift/experiments/names.py b/labelshift/experiments/names.py deleted file mode 100644 index 432c0c9..0000000 --- a/labelshift/experiments/names.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Utilities for dealing with filesystem IO.""" - -import petname -from datetime import datetime - - -def generate_name() -> str: - """Generates a name with timestamp and a random part.""" - - now = datetime.now() # current date and time - date_time = now.strftime("%Y%m%d-%H%M%S") - suffix = petname.generate(separator="-", words=3) - return f"{date_time}-{suffix}" From 2f398dcf0be290ce8c3c7bb02ddd1a98453e17a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 11:18:22 +0200 Subject: [PATCH 14/14] Remove interrogate check --- .pre-commit-config.yaml | 4 ---- Makefile | 1 - pyproject.toml | 18 ------------------ requirements.txt | 1 - 4 files changed, 24 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c45f976..82172f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,8 +7,4 @@ repos: rev: 5.0.4 hooks: - id: flake8 -- repo: https://github.com/econchick/interrogate - rev: 1.5.0 - hooks: - - id: interrogate diff --git a/Makefile b/Makefile index 3650a14..58d66b4 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,6 @@ test: flake8 pytype labelshift pytype tests - interrogate pytest install: diff --git a/pyproject.toml b/pyproject.toml index 7588a42..63161a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,24 +2,6 @@ line-length = 88 target-version = ["py38"] - -[tool.interrogate] -ignore-init-method = true -ignore-init-module = false -ignore-magic = false -ignore-semiprivate = true -ignore-private = true -ignore-property-decorators = false -ignore-module = false -fail-under = 90 -exclude = ["setup.py", "docs", "build"] -ignore-regex = ["^get$", "^mock_.*", ".*BaseClass.*"] -verbose = 2 -quiet = false -whitelist-regex = [] -color = true - - [tool.pytest.ini_options] minversion = "6.0" addopts = "-ra -q --cov=labelshift -n auto" diff --git a/requirements.txt b/requirements.txt index 4446c8f..2d15ecd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,6 @@ subplots_from_axsize # Code quality tools black flake8 -interrogate pre-commit pytest pytest-cov