Add r-hat calculation for expensive experiments (#26)
* Add r-hat diagnostic for the benchmark runs

* Add convergence checks to the experiment with misspecified models

* Use four chains

* Refactor code.

* Improve DPI

* Filter out non-converged runs (see the sketch after the changed-files summary below)
pawel-czyz authored May 28, 2024
1 parent 6517db6 commit 97adcce
Showing 2 changed files with 131 additions and 30 deletions.
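The bullets above all serve one pattern, which both changed files implement: pull per-site diagnostics out of a fitted NumPyro MCMC object, reduce them to a single worst-case r-hat and effective sample size per run, and use those numbers to flag non-converged runs. A condensed sketch of that check (the helper name here is ours; misspecified.smk calls its version generate_summary):

```python
import numpy as np
from numpyro.diagnostics import summary


def convergence_report(mcmc) -> dict:
    """Worst-case convergence diagnostics over all sample sites and dimensions.

    Assumes `mcmc` is a fitted numpyro.infer.MCMC object run with several
    chains (four, after this commit)."""
    # site -> array of shape (n_chains, n_draws, ...); summary() needs the
    # chain dimension to compute r-hat across chains.
    samples = mcmc.get_samples(group_by_chain=True)
    summ = summary(samples)  # site -> {"mean", "std", ..., "n_eff", "r_hat"}
    return {
        "min_n_eff": min(float(np.min(d["n_eff"])) for d in summ.values()),
        "max_r_hat": max(float(np.max(d["r_hat"])) for d in summ.values()),
    }
```

misspecified.smk treats a run as converged only when the reported max_r_hat stays below 1.02; benchmark.smk stores the same two numbers per run in additional_info and reports the worst-case values when assembling results.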
workflows/benchmark.smk (27 changes: 24 additions & 3 deletions)

@@ -10,7 +10,7 @@ import matplotlib.transforms as mtransforms
 matplotlib.use("Agg")
 
 import numpy as np
-
+from numpyro.diagnostics import summary
 
 import labelshift.algorithms.api as algo
 import labelshift.experiments.api as exp
@@ -22,7 +22,7 @@ ESTIMATORS = {
     "BBS": algo.BlackBoxShiftEstimator(),
     "CC": algo.ClassifyAndCount(),
     "RIR": algo.InvariantRatioEstimator(restricted=True),
-    "BAY": algo.DiscreteCategoricalMeanEstimator(),
+    "BAY": algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4)),
 }
 ESTIMATOR_COLORS = {
     "BBS": "orangered",
@@ -146,6 +146,7 @@ _k_vals = [2, 3, 5, 7, 9]
 _quality = [0.55, 0.65, 0.75, 0.85, 0.95]
 _quality_prime = [0.45, 0.55, 0.65, 0.75, 0.80, 0.85, 0.90, 0.95]
 
+
 BENCHMARKS = {
     "change_prevalence": BenchmarkSettings(
         param_name="Prevalence $\\pi'_1$",
@@ -179,6 +180,7 @@ BENCHMARKS = {
     ),
 }
 
+
 def get_data_setting(benchmark: str, param: int | str) -> DataSetting:
     return BENCHMARKS[str(benchmark)].settings[int(param)]
 
@@ -234,6 +236,14 @@ rule apply_estimator:
             elapsed_time = timer.check()
             run_ok = True
             additional_info = {}
+
+            if hasattr(estimator, "get_mcmc"):
+                samples = estimator.get_mcmc().get_samples(group_by_chain=True)
+                summ = summary(samples)
+                n_eff_list = [np.min(d["n_eff"]) for d in summ.values()]
+                r_hat_list = [np.max(d["r_hat"]) for d in summ.values()]
+                additional_info = additional_info | {"min_n_eff": min(n_eff_list), "max_r_hat": max(r_hat_list)}
+
         except Exception as e:
             elapsed_time = float("nan")
             estimate = np.full_like(data.n_y_labeled, fill_value=float("nan"))
@@ -267,9 +277,13 @@ def _get_paths_to_be_assembled(wildcards):
 rule assemble_results:
     output:
         csv = "results/benchmark-{benchmark}-metric-{metric}.csv",
-        err = "results/status/benchmark-{benchmark}-metric-{metric}.txt"
+        err = "results/status/benchmark-{benchmark}-metric-{metric}.txt",
+        convergence = "results/convergence/benchmark-{benchmark}-metric-{metric}.txt",
     input: _get_paths_to_be_assembled
     run:
+        max_r_hat = -1e9
+        min_n_eff = 1e9
+
         results = []
         for pth in input:
             res = joblib.load(pth)
@@ -285,6 +299,10 @@ rule assemble_results:
             }
             results.append(nice)
 
+            if "max_r_hat" in res.additional_info:
+                max_r_hat = max(max_r_hat, res.additional_info["max_r_hat"])
+                min_n_eff = min(min_n_eff, res.additional_info["min_n_eff"])
+
         results = pd.DataFrame(results)
 
         df_ok = results[results["run_ok"]]
@@ -298,6 +316,9 @@ rule assemble_results:
         df_ok = df_ok.drop(columns=["run_ok", "additional_info"])
         df_ok.to_csv(str(output.csv), index=False)
 
+        with open(output.convergence, "w") as f:
+            f.write(f"Max r_hat: {max_r_hat}\n")
+            f.write(f"Min n_eff: {min_n_eff}\n")
 
 
 def plot_results(ax, df, plot_std: bool = True, alpha: float = 0.5):
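For reference (not part of the diff): the r_hat that numpyro reports is the potential scale reduction factor. In its classic Gelman-Rubin form, for chains of $n$ draws each with within-chain variance $W$ and between-chain variance $B$,

$$\hat{R} = \sqrt{\frac{\tfrac{n-1}{n}\,W + \tfrac{1}{n}\,B}{W}},$$

so $\hat{R} \to 1$ as the chains mix (numpyro computes a split-chain variant, which also catches non-stationarity within a single chain). The 1.02 threshold used in misspecified.smk below is therefore a fairly strict convergence requirement.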
workflows/misspecified.smk (134 changes: 107 additions & 27 deletions)

@@ -3,7 +3,7 @@
 # ----------------------------------------------------------------------------------
 from dataclasses import dataclass
 import numpy as np
-import pandas as pd
+import json
 import joblib
 
 import matplotlib
@@ -14,7 +14,7 @@ matplotlib.use("agg")
 import jax
 import numpyro
 import numpyro.distributions as dist
-
+from numpyro.diagnostics import summary
 
 import labelshift.algorithms.api as algo
 from labelshift.datasets.discrete_categorical import SummaryStatistic
@@ -57,18 +57,18 @@ N_POINTS = [100, 1000, 10_000]
 PI_LABELED = 0.5
 PI_UNLABELED = 0.2
 
-N_MCMC_WARMUP = 500
-N_MCMC_SAMPLES = 1000
+N_MCMC_WARMUP = 1500
+N_MCMC_SAMPLES = 2000
+N_MCMC_CHAINS = 4
 
 
 COVERAGES = np.arange(0.05, 0.96, 0.05)
 
 
 rule all:
-    input: expand("plots/{n_points}.pdf", n_points=N_POINTS)
-
-# rule all:
-#     input: expand("figures/{setting}-{seed}.pdf", setting=SETTINGS.keys(), seed=SEEDS)
+    input:
+        plots = expand("plots/{n_points}.pdf", n_points=N_POINTS),
+        convergence = "convergence_overall.json",
 
 
 rule generate_data:
@@ -82,7 +82,7 @@ rule generate_data:
 
 def gaussian_model(observed: Data, unobserved: np.ndarray):
     sigma = numpyro.sample('sigma', dist.HalfCauchy(np.ones(2)))
-    mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 1))
+    mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 3))
 
     pi = numpyro.sample(algo.DiscreteCategoricalMeanEstimator.P_TEST_Y, dist.Dirichlet(np.ones(2)))
 
@@ -101,7 +101,7 @@ def gaussian_model(observed: Data, unobserved: np.ndarray):
 def student_model(observed: Data, unobserved: np.ndarray):
     df = numpyro.sample('df', dist.Gamma(np.ones(2), np.ones(2)))
     sigma = numpyro.sample('sigma', dist.HalfCauchy(np.ones(2)))
-    mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 1))
+    mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 3))
 
     pi = numpyro.sample(algo.DiscreteCategoricalMeanEstimator.P_TEST_Y, dist.Dirichlet(np.ones(2)))
 
@@ -117,36 +117,55 @@ def student_model(observed: Data, unobserved: np.ndarray):
     numpyro.sample('x', mixture, obs=unobserved)
 
+
+def generate_summary(samples):
+    summ = summary(samples)
+    n_eff_list = [float(np.min(d["n_eff"])) for d in summ.values()]
+    r_hat_list = [float(np.max(d["r_hat"])) for d in summ.values()]
+    return {"min_n_eff": min(n_eff_list), "max_r_hat": max(r_hat_list)}
+
 rule run_gaussian_mcmc:
     input: "data/{n_points}/{seed}.npy"
-    output: "samples/{n_points}/Gaussian/{seed}.npy"
+    output:
+        samples = "samples/{n_points}/Gaussian/{seed}.npy",
+        convergence = "convergence/{n_points}/Gaussian/{seed}.joblib",
     run:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
             numpyro.infer.NUTS(gaussian_model),
             num_warmup=N_MCMC_WARMUP,
             num_samples=N_MCMC_SAMPLES,
+            num_chains=N_MCMC_CHAINS,
         )
         rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
         mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
         samples = mcmc.get_samples()
-        joblib.dump(samples, str(output))
+        joblib.dump(samples, output.samples)
+
+        summ = generate_summary(mcmc.get_samples(group_by_chain=True))
+        joblib.dump(summ, output.convergence)
 
 
 rule run_student_mcmc:
     input: "data/{n_points}/{seed}.npy"
-    output: "samples/{n_points}/Student/{seed}.npy"
-    run:
+    output:
+        samples = "samples/{n_points}/Student/{seed}.npy",
+        convergence = "convergence/{n_points}/Student/{seed}.joblib",
+    run:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
             numpyro.infer.NUTS(student_model),
             num_warmup=N_MCMC_WARMUP,
             num_samples=N_MCMC_SAMPLES,
+            num_chains=N_MCMC_CHAINS,
         )
         rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
         mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
         samples = mcmc.get_samples()
-        joblib.dump(samples, str(output))
+        joblib.dump(samples, output.samples)
+
+        summ = generate_summary(mcmc.get_samples(group_by_chain=True))
+        joblib.dump(summ, output.convergence)
 
 
 
 def _calculate_bins(n: int):
@@ -169,15 +188,24 @@ def generate_summary_statistic(
 
 rule run_discrete_mcmc:
     input: "data/{n_points}/{seed}.npy"
-    output: "samples/{n_points}/Discrete-{n_bins}/{seed}.npy"
+    output:
+        samples = "samples/{n_points}/Discrete-{n_bins}/{seed}.npy",
+        convergence = "convergence/{n_points}/Discrete-{n_bins}/{seed}.joblib",
     run:
        data_labeled, data_unlabeled = joblib.load(str(input))
        estimator = algo.DiscreteCategoricalMeanEstimator(
            seed=int(wildcards.seed) + 101,
-            params=algo.SamplingParams(warmup=N_MCMC_WARMUP, samples=N_MCMC_SAMPLES),
+            params=algo.SamplingParams(
+                warmup=N_MCMC_WARMUP,
+                samples=N_MCMC_SAMPLES,
+                chains=N_MCMC_CHAINS,
+            ),
        )
        samples = estimator.sample_posterior(generate_summary_statistic(data_labeled, data_unlabeled.xs, int(wildcards.n_bins)))
-        joblib.dump(samples, str(output))
+        joblib.dump(samples, output.samples)
+
+        summ = generate_summary(estimator.get_mcmc().get_samples(group_by_chain=True))
+        joblib.dump(summ, output.convergence)
 
 
 def calculate_hdi(arr, prob: float) -> tuple[float, float]:
@@ -196,12 +224,17 @@ def calculate_hdi(arr, prob: float) -> tuple[float, float]:
 
 
 rule contains_ground_truth:
-    input: "samples/{n_points}/{algorithm}/{seed}.npy"
+    input:
+        samples = "samples/{n_points}/{algorithm}/{seed}.npy",
+        convergence = "convergence/{n_points}/{algorithm}/{seed}.joblib",
     output: "contains/{n_points}/{algorithm}/{seed}.joblib"
     run:
-        samples = joblib.load(str(input))
+        samples = joblib.load(input.samples)
+        convergence = joblib.load(input.convergence)
+        run_ok = True if convergence["max_r_hat"] < 1.02 else False
+
         pi_samples = samples[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y][:, 1]
 
         results = []
         intervals = []
         for coverage in COVERAGES:
@@ -212,7 +245,7 @@ rule contains_ground_truth:
 
         results = np.asarray(results, dtype=float)
         intervals = np.asarray(intervals, dtype=float)
-        joblib.dump((results, intervals), str(output))
+        joblib.dump((results, intervals, run_ok), str(output))
 
 
 def _input_paths_calculate_coverages(wildcards):
@@ -221,15 +254,62 @@ def _input_paths_calculate_coverages(wildcards):
 
 rule calculate_coverages:
     input: _input_paths_calculate_coverages
-    output: "coverages/{n_points}/{algorithm}.npy"
+    output:
+        coverages = "coverages/{n_points}/{algorithm}.npy",
+        excluded_runs = "excluded/{n_points}-{algorithm}.json"
     run:
         results = []
+
+        ok_runs = 0
+        excluded_runs = 0
         for pth in input:
-            res, _ = joblib.load(pth)
-            results.append(res)
+            res, _, run_ok = joblib.load(pth)
+            if run_ok:
+                results.append(res)
+                ok_runs += 1
+            else:
+                excluded_runs += 1
 
         results = np.asarray(results)
         coverages = results.mean(axis=0)
-        np.save(str(output), coverages)
+        np.save(output.coverages, coverages)
+
+        with open(output.excluded_runs, "w") as fh:
+            json.dump({"excluded_runs": excluded_runs, "ok_runs": ok_runs}, fh)
+
+def _input_paths_summarize_convergence(wildcards):
+    return [f"convergence/{wildcards.n_points}/{wildcards.algorithm}/{seed}.joblib" for seed in SEEDS]
+
+
+rule summarize_convergence:
+    input: _input_paths_summarize_convergence
+    output: "convergence/{n_points}/{algorithm}.json"
+    run:
+        min_n_effs = []
+        max_r_hats = []
+        for pth in input:
+            res = joblib.load(pth)
+            min_n_effs.append(res["min_n_eff"])
+            max_r_hats.append(res["max_r_hat"])
+
+        with open(str(output), "w") as fh:
+            json.dump({"min_n_eff": min(min_n_effs), "max_r_hat": max(max_r_hats)}, fh)
+
+
+rule summarize_convergence_overall:
+    input: expand("convergence/{n_points}/{algorithm}.json", n_points=N_POINTS, algorithm=["Gaussian", "Student", "Discrete-5", "Discrete-10"])
+    output: "convergence_overall.json"
+    run:
+        min_n_effs = []
+        max_r_hats = []
+        for pth in input:
+            with open(pth) as fh:
+                res = json.load(fh)
+            min_n_effs.append(res["min_n_eff"])
+            max_r_hats.append(res["max_r_hat"])
+
+        with open(str(output), "w") as fh:
+            json.dump({"min_n_eff": min(min_n_effs), "max_r_hat": max(max_r_hats)}, fh)
+
 rule plot_coverage:
     input:
@@ -243,7 +323,7 @@ rule plot_coverage:
         sample_discrete10 = "samples/{n_points}/Discrete-10/1.npy",
     output: "plots/{n_points}.pdf"
     run:
-        fig, axs = subplots_from_axsize(axsize=(2, 1), wspace=[0.2, 0.3, 0.6], dpi=150, left=0.2, top=0.3, right=1.8)
+        fig, axs = subplots_from_axsize(axsize=(2, 1), wspace=[0.2, 0.3, 0.6], dpi=400, left=0.2, top=0.3, right=1.8)
         axs = axs.ravel()
 
         # Conditional distributions P(X|Y)
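A note on calculate_hdi, whose body is collapsed in this view: its signature suggests a highest-density-interval routine over posterior draws. An illustrative sketch consistent with that signature, and explicitly not the repository's actual code, could look like this:

```python
import numpy as np


def calculate_hdi(arr: np.ndarray, prob: float) -> tuple[float, float]:
    # Hypothetical implementation (the committed one is hidden above): among
    # all intervals containing a `prob` fraction of the sorted draws, return
    # the narrowest one, i.e. the highest-density interval for a unimodal
    # posterior.
    sorted_arr = np.sort(np.asarray(arr, dtype=float))
    n = sorted_arr.size
    window = max(1, int(np.ceil(prob * n)))  # number of draws to cover
    widths = sorted_arr[window - 1:] - sorted_arr[: n - window + 1]
    start = int(np.argmin(widths))
    return float(sorted_arr[start]), float(sorted_arr[start + window - 1])
```

The contains_ground_truth rule then checks, for each nominal coverage level in COVERAGES, whether the true prevalence falls inside calculate_hdi(pi_samples, coverage), and after this commit only converged runs (max r-hat below 1.02) enter the coverage averages.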
