From f4f40b1f417df8ba81ad62c39e91660c81bb9aac Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 7 Nov 2023 00:04:26 +0000 Subject: [PATCH] add sequential testing --- examples/gibbs/main.jl | 5 ++++ src/MCMCTesting.jl | 7 +++++- src/seqtest.jl | 54 ++++++++++++++++++++++++++++++++++++++++++ src/twosample.jl | 4 ++-- 4 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 src/seqtest.jl diff --git a/examples/gibbs/main.jl b/examples/gibbs/main.jl index 83e214f..ce340a8 100644 --- a/examples/gibbs/main.jl +++ b/examples/gibbs/main.jl @@ -102,4 +102,9 @@ function main() mcmctest(test, TestSubject(model, GibbsRandScan())) |> display mcmctest(test, TestSubject(model, GibbsRandScanWrongMean())) |> display mcmctest(test, TestSubject(model, GibbsRandScanWrongVar())) |> display + + test = TwoSampleGibbsTest(100, 100, 100) + seqmcmctest(test, TestSubject(model, GibbsRandScan()), 0.001, 32) |> display + seqmcmctest(test, TestSubject(model, GibbsRandScanWrongMean()), 0.001, 32) |> display + seqmcmctest(test, TestSubject(model, GibbsRandScanWrongVar()), 0.001, 32) |> display end diff --git a/src/MCMCTesting.jl b/src/MCMCTesting.jl index 2ef029f..00f3ebd 100644 --- a/src/MCMCTesting.jl +++ b/src/MCMCTesting.jl @@ -9,11 +9,13 @@ export sample_predictive, sample_joint, sample_markov_chain, - mcmctest + mcmctest, + seqmcmctest using Random using HypothesisTests using ProgressMeter +using MultipleTesting function sample_joint end function sample_predictive end @@ -25,6 +27,9 @@ struct TestSubject{M, K} kernel::K end +abstract type AbstractMCMCTest end + include("twosample.jl") +include("seqtest.jl") end diff --git a/src/seqtest.jl b/src/seqtest.jl new file mode 100644 index 0000000..03f9d95 --- /dev/null +++ b/src/seqtest.jl @@ -0,0 +1,54 @@ + +function seqmcmctest( + test ::AbstractMCMCTest, + subject ::TestSubject, + false_rejection_rate::Real, + samplesize ::Int, + max_iter ::Real = 3, + samplesize_increase ::Real = 2.; + show_progress = true, + pvalue_adjustment::MultipleTesting.PValueAdjustment = MultipleTesting.Bonferroni(), + kwargs...) + α = false_rejection_rate + k = max_iter + β = α / k + γ = β^(1/k) + Δ = samplesize_increase + + for i = 1:k + prog = ProgressMeter.Progress( + samplesize; + barlen = 31, + showspeed = true, + enabled = show_progress + ) + pvals_all = mapreduce(hcat, 1:samplesize) do n + pval = mcmctest(test, subject; kwargs...) + next!(prog, + showvalues = [ + (:test_iteration, i), + (:pvalue_sampling, "$(n)/$(samplesize)") + ]) + pval + end + + pvals_adjusted = mapreduce(vcat, eachcol(pvals_all)) do pvals_paramwise + adjust(Vector(pvals_paramwise), pvalue_adjustment) + end + + q = minimum(pvals_adjusted)*length(pvals_adjusted) + + if q ≤ β + return false + elseif q > γ + β + break + end + + β /= γ + + if i == 1 + samplesize = ceil(Int, samplesize*Δ) + end + end + true +end diff --git a/src/twosample.jl b/src/twosample.jl index ae07544..d5953c6 100644 --- a/src/twosample.jl +++ b/src/twosample.jl @@ -8,13 +8,13 @@ function markovchain_multiple_transition( θ end -struct TwoSampleTest +struct TwoSampleTest <: AbstractMCMCTest n_control ::Int n_treatment ::Int n_mcmc_steps::Int end -struct TwoSampleGibbsTest +struct TwoSampleGibbsTest <: AbstractMCMCTest n_control ::Int n_treatment ::Int n_mcmc_steps::Int