diff --git a/gui/src/app/SamplerOutputView/SummaryView.tsx b/gui/src/app/SamplerOutputView/SummaryView.tsx index 784e4ada..1c3e573f 100644 --- a/gui/src/app/SamplerOutputView/SummaryView.tsx +++ b/gui/src/app/SamplerOutputView/SummaryView.tsx @@ -1,5 +1,6 @@ import { FunctionComponent, useMemo } from "react" import { computeMean, computePercentile, computeStdDev } from "./util" +import { compute_effective_sample_size, compute_split_potential_scale_reduction } from "./stan_stats/stan_stats" type SummaryViewProps = { width: number @@ -16,11 +17,11 @@ const columns = [ label: 'Mean', title: 'Mean value of the parameter' }, - /*future: { + { key: 'mcse', label: 'MCSE', title: 'Monte Carlo Standard Error: Standard deviation of the parameter divided by the square root of the effective sample size' - },*/ + }, { key: 'stdDev', label: 'StdDev', @@ -41,21 +42,21 @@ const columns = [ label: '95%', title: '95th percentile of the parameter' }, - /*future: { + { key: 'nEff', label: 'N_Eff', - title: 'Effective sample size: A crude measure of the effective sample size (uses ess_imse)' - },*/ - /*future: { + title: 'Effective sample size: A crude measure of the effective sample size' + }, + { key: 'nEff/s', label: 'N_Eff/s', title: 'Effective sample size per second of compute time' - },*/ - /*future: { + }, + { key: 'rHat', label: 'R_hat', title: 'Potential scale reduction factor on split chains (at convergence, R_hat=1)' - }*/ + } ] type TableRow = { @@ -63,24 +64,21 @@ type TableRow = { values: number[] } -const SummaryView: FunctionComponent = ({ width, height, draws, paramNames }) => { - // will be used in the future: - // const uniqueChainIds = useMemo(() => (Array.from(new Set(drawChainIds)).sort()), [drawChainIds]); - // note: computeTimeSec will be used in the future - +const SummaryView: FunctionComponent = ({ width, height, draws, paramNames, drawChainIds, computeTimeSec }) => { const rows = useMemo(() => { const rows: TableRow[] = []; for (const pname of paramNames) { const pDraws = draws[paramNames.indexOf(pname)]; const pDrawsSorted = [...pDraws].sort((a, b) => a - b); + const ess = computeEss(pDraws, drawChainIds); + const rhat = computeRhat(pDraws, drawChainIds); const stdDev = computeStdDev(pDraws); const values = columns.map((column) => { if (column.key === 'mean') { return computeMean(pDraws); } else if (column.key === 'mcse') { - // placeholder for mcse - throw new Error('Not implemented'); + return stdDev / Math.sqrt(ess); } else if (column.key === 'stdDev') { return stdDev; @@ -95,16 +93,13 @@ const SummaryView: FunctionComponent = ({ width, height, draws return computePercentile(pDrawsSorted, 0.95); } else if (column.key === 'nEff') { - // placeholder for nEff - throw new Error('Not implemented'); + return ess; } else if (column.key === 'nEff/s') { - // placeholder for nEff/s - throw new Error('Not implemented'); + return computeTimeSec ? ess / computeTimeSec : NaN; } else if (column.key === 'rHat') { - // placeholder for rHat - throw new Error('Not implemented'); + return rhat; } else { return NaN; @@ -116,7 +111,7 @@ const SummaryView: FunctionComponent = ({ width, height, draws }) } return rows; - }, [paramNames, draws]); + }, [draws, paramNames, drawChainIds, computeTimeSec]); return (
@@ -157,6 +152,30 @@ const SummaryView: FunctionComponent = ({ width, height, draws ) } +const drawsByChain = (draws: number[], chainIds: number[]): number[][] => { + // Group draws by chain for use in computing ESS and Rhat + const uniqueChainIds = Array.from(new Set(chainIds)).sort(); + const drawsByChain: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []); + for (let i = 0; i < draws.length; i++) { + const chainId = chainIds[i]; + const chainIndex = uniqueChainIds.indexOf(chainId); + drawsByChain[chainIndex].push(draws[i]); + } + return drawsByChain; +} + +const computeEss = (x: number[], chainIds: number[]) => { + const draws = drawsByChain(x, chainIds); + const ess = compute_effective_sample_size(draws); + return ess; +} + +const computeRhat = (x: number[], chainIds: number[]) => { + const draws = drawsByChain(x, chainIds); + const rhat = compute_split_potential_scale_reduction(draws); + return rhat; +} + // Example of Stan output... // Inference for Stan model: bernoulli_model // 1 chains: each with iter=(1000); warmup=(0); thin=(1); 1000 iterations saved. diff --git a/gui/src/app/SamplerOutputView/stan_stats/fft.ts b/gui/src/app/SamplerOutputView/stan_stats/fft.ts new file mode 100644 index 00000000..822e0db0 --- /dev/null +++ b/gui/src/app/SamplerOutputView/stan_stats/fft.ts @@ -0,0 +1,225 @@ +/* eslint-disable @typescript-eslint/no-inferrable-types */ +/* eslint-disable @typescript-eslint/no-unused-vars */ +/* eslint-disable prefer-const */ +/* + * Free FFT and convolution (TypeScript) + * + * Copyright (c) 2022 Project Nayuki. (MIT License) + * https://www.nayuki.io/page/free-small-fft-in-multiple-languages + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * - The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * - The Software is provided "as is", without warranty of any kind, express or + * implied, including but not limited to the warranties of merchantability, + * fitness for a particular purpose and noninfringement. In no event shall the + * authors or copyright holders be liable for any claim, damages or other + * liability, whether in an action of contract, tort or otherwise, arising from, + * out of or in connection with the Software or the use or other dealings in the + * Software. + */ + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This is a wrapper function. + */ +export function transform(real: Array|Float64Array, imag: Array|Float64Array): void { + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + if (n == 0) + return; + else if ((n & (n - 1)) == 0) // Is power of 2 + transformRadix2(real, imag); + else // More complicated algorithm for arbitrary sizes + transformBluestein(real, imag); +} + + +/* + * Computes the inverse discrete Fourier transform (IDFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This is a wrapper function. This transform does not perform scaling, so the inverse is not a true inverse. + */ +export function inverseTransform(real: Array|Float64Array, imag: Array|Float64Array): void { + transform(imag, real); +} + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector's length must be a power of 2. Uses the Cooley-Tukey decimation-in-time radix-2 algorithm. + */ +function transformRadix2(real: Array|Float64Array, imag: Array|Float64Array): void { + // Length variables + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + if (n == 1) // Trivial transform + return; + let levels: number = -1; + for (let i = 0; i < 32; i++) { + if (1 << i == n) + levels = i; // Equal to log2(n) + } + if (levels == -1) + throw new RangeError("Length is not a power of 2"); + + // Trigonometric tables + let cosTable = new Array(n / 2); + let sinTable = new Array(n / 2); + for (let i = 0; i < n / 2; i++) { + cosTable[i] = Math.cos(2 * Math.PI * i / n); + sinTable[i] = Math.sin(2 * Math.PI * i / n); + } + + // Bit-reversed addressing permutation + for (let i = 0; i < n; i++) { + const j: number = reverseBits(i, levels); + if (j > i) { + let temp: number = real[i]; + real[i] = real[j]; + real[j] = temp; + temp = imag[i]; + imag[i] = imag[j]; + imag[j] = temp; + } + } + + // Cooley-Tukey decimation-in-time radix-2 FFT + for (let size = 2; size <= n; size *= 2) { + const halfsize: number = size / 2; + const tablestep: number = n / size; + for (let i = 0; i < n; i += size) { + for (let j = i, k = 0; j < i + halfsize; j++, k += tablestep) { + const l: number = j + halfsize; + const tpre: number = real[l] * cosTable[k] + imag[l] * sinTable[k]; + const tpim: number = -real[l] * sinTable[k] + imag[l] * cosTable[k]; + real[l] = real[j] - tpre; + imag[l] = imag[j] - tpim; + real[j] += tpre; + imag[j] += tpim; + } + } + } + + // Returns the integer whose value is the reverse of the lowest 'width' bits of the integer 'val'. + function reverseBits(val: number, width: number): number { + let result: number = 0; + for (let i = 0; i < width; i++) { + result = (result << 1) | (val & 1); + val >>>= 1; + } + return result; + } +} + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This requires the convolution function, which in turn requires the radix-2 FFT function. + * Uses Bluestein's chirp z-transform algorithm. + */ +function transformBluestein(real: Array|Float64Array, imag: Array|Float64Array): void { + // Find a power-of-2 convolution length m such that m >= n * 2 + 1 + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + let m: number = 1; + while (m < n * 2 + 1) + m *= 2; + + // Trigonometric tables + let cosTable = new Array(n); + let sinTable = new Array(n); + for (let i = 0; i < n; i++) { + const j: number = i * i % (n * 2); // This is more accurate than j = i * i + cosTable[i] = Math.cos(Math.PI * j / n); + sinTable[i] = Math.sin(Math.PI * j / n); + } + + // Temporary vectors and preprocessing + let areal: Array = newArrayOfZeros(m); + let aimag: Array = newArrayOfZeros(m); + for (let i = 0; i < n; i++) { + areal[i] = real[i] * cosTable[i] + imag[i] * sinTable[i]; + aimag[i] = -real[i] * sinTable[i] + imag[i] * cosTable[i]; + } + let breal: Array = newArrayOfZeros(m); + let bimag: Array = newArrayOfZeros(m); + breal[0] = cosTable[0]; + bimag[0] = sinTable[0]; + for (let i = 1; i < n; i++) { + breal[i] = breal[m - i] = cosTable[i]; + bimag[i] = bimag[m - i] = sinTable[i]; + } + + // Convolution + let creal = new Array(m); + let cimag = new Array(m); + convolveComplex(areal, aimag, breal, bimag, creal, cimag); + + // Postprocessing + for (let i = 0; i < n; i++) { + real[i] = creal[i] * cosTable[i] + cimag[i] * sinTable[i]; + imag[i] = -creal[i] * sinTable[i] + cimag[i] * cosTable[i]; + } +} + + +/* + * Computes the circular convolution of the given real vectors. Each vector's length must be the same. + */ +// function convolveReal(xvec: Array|Float64Array, yvec: Array|Float64Array, outvec: Array|Float64Array): void { +// const n: number = xvec.length; +// if (n != yvec.length || n != outvec.length) +// throw new RangeError("Mismatched lengths"); +// convolveComplex(xvec, newArrayOfZeros(n), yvec, newArrayOfZeros(n), outvec, newArrayOfZeros(n)); +// } + + +/* + * Computes the circular convolution of the given complex vectors. Each vector's length must be the same. + */ +function convolveComplex( + xreal: Array|Float64Array, ximag: Array|Float64Array, + yreal: Array|Float64Array, yimag: Array|Float64Array, + outreal: Array|Float64Array, outimag: Array|Float64Array): void { + + const n: number = xreal.length; + if (n != ximag.length || n != yreal.length || n != yimag.length + || n != outreal.length || n != outimag.length) + throw new RangeError("Mismatched lengths"); + + xreal = xreal.slice(); + ximag = ximag.slice(); + yreal = yreal.slice(); + yimag = yimag.slice(); + transform(xreal, ximag); + transform(yreal, yimag); + + for (let i = 0; i < n; i++) { + const temp: number = xreal[i] * yreal[i] - ximag[i] * yimag[i]; + ximag[i] = ximag[i] * yreal[i] + xreal[i] * yimag[i]; + xreal[i] = temp; + } + inverseTransform(xreal, ximag); + + for (let i = 0; i < n; i++) { // Scaling (because this FFT implementation omits it) + outreal[i] = xreal[i] / n; + outimag[i] = ximag[i] / n; + } +} + + +function newArrayOfZeros(n: number): Array { + let result: Array = []; + for (let i = 0; i < n; i++) + result.push(0); + return result; +} \ No newline at end of file diff --git a/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts b/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts new file mode 100644 index 00000000..cba20558 --- /dev/null +++ b/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts @@ -0,0 +1,318 @@ +/* +Translated from +https://github.com/stan-dev/stan/blob/develop/src/stan/analyze/mcmc/compute_effective_sample_size.hpp +and +https://github.com/stan-dev/stan/blob/develop/src/stan/analyze/mcmc/autocovariance.hpp +and +https://github.com/stan-dev/stan/blob/develop/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp +*/ + +import { transform as inPlaceFftTransform, inverseTransform as inPlaceInverseFftTransform } from "./fft"; + +/** + * Computes the effective sample size (ESS) for the specified + * parameter across all kept samples. The value returned is the + * minimum of ESS and the number_total_draws * + * log10(number_total_draws). + * + * See more details in Stan reference manual section "Effective + * Sample Size". http://mc-stan.org/users/documentation + * + * Current implementation assumes draws are stored in contiguous + * blocks of memory. Chains are trimmed from the back to match the + * length of the shortest chain. Note that the effective sample size + * can not be estimated with less than four draws. + * + * draws: arrays of draws for each chain + * returns: effective sample size for the specified parameter + */ +export function compute_effective_sample_size(draws: number[][]): number { + const num_chains = draws.length; + + // use the minimum number of draws across all chains + let num_draws = draws[0].length; + for (let chain = 1; chain < num_chains; ++chain) { + num_draws = Math.min(num_draws, draws[chain].length); + } + + if (num_draws < 4) { + // we don't have enough draws to compute ESS + return NaN; + } + + // check if chains are constant; all equal to first draw's value + let are_all_const = false; + const init_draw = new Array(num_chains).fill(0); + for (let chain_idx = 0; chain_idx < num_chains; chain_idx++) { + const draw = draws[chain_idx]; + for (let n = 0; n < num_draws; n++) { + if (!isFinite(draw[n])) { + // we can't compute ESS if there are non-finite values + return NaN; + } + } + + init_draw[chain_idx] = draw[0]; + + const precision = 1e-12; + if (draw.every(d => Math.abs(d - draw[0]) < precision)) { + are_all_const = true; + } + } + + if (are_all_const) { + // If all chains are constant then return NaN + // if they all equal the same constant value + const precision = 1e-12; + if (init_draw.every(d => Math.abs(d - init_draw[0]) < precision)) { + return NaN; + } + } + + // acov: autocovariance for each chain + const acov = new Array(num_chains).fill(0).map(() => new Array(num_draws).fill(0)); + // chain_mean: mean of each chain + const chain_mean = new Array(num_chains).fill(0); + // chain_var: sample variance of each chain + const chain_var = new Array(num_chains).fill(0); + for (let chain = 0; chain < num_chains; ++chain) { + const draw = draws[chain]; + acov[chain] = autocovariance(draw); + chain_mean[chain] = compute_mean(draw); + chain_var[chain] = acov[chain][0] * num_draws / (num_draws - 1); + } + + // mean_var: mean of the chain variances + const mean_var = compute_mean(chain_var); + + let var_plus = mean_var * (num_draws - 1) / num_draws; + if (num_chains > 1) { + var_plus += compute_sample_variance(chain_mean); + } + + const rho_hat_s = new Array(num_draws).fill(0); + const acov_s = new Array(num_chains).fill(0); + for (let chain = 0; chain < num_chains; ++chain) { + acov_s[chain] = acov[chain][1]; + } + let rho_hat_even = 1.0; + rho_hat_s[0] = rho_hat_even; + let rho_hat_odd = 1 - (mean_var - compute_mean(acov_s)) / var_plus; + rho_hat_s[1] = rho_hat_odd; + + // Convert raw autocovariance estimators into Geyer's initial + // positive sequence. Loop only until num_draws - 4 to + // leave the last pair of autocorrelations as a bias term that + // reduces variance in the case of antithetical chains. + let s = 1; + while (s < (num_draws - 4) && (rho_hat_even + rho_hat_odd) > 0) { + for (let chain = 0; chain < num_chains; ++chain) { + acov_s[chain] = acov[chain][s + 1]; + } + rho_hat_even = 1 - (mean_var - compute_mean(acov_s)) / var_plus; + for (let chain = 0; chain < num_chains; ++chain) { + acov_s[chain] = acov[chain][s + 2]; + } + rho_hat_odd = 1 - (mean_var - compute_mean(acov_s)) / var_plus; + if ((rho_hat_even + rho_hat_odd) >= 0) { + rho_hat_s[s + 1] = rho_hat_even; + rho_hat_s[s + 2] = rho_hat_odd; + } + s += 2; + } + + const max_s = s; + // this is used in the improved estimate, which reduces variance + // in antithetic case -- see tau_hat below + if (rho_hat_even > 0) { + rho_hat_s[max_s + 1] = rho_hat_even; + } + + // Convert Geyer's initial positive sequence into an initial + // monotone sequence + for (let s = 1; s <= max_s - 3; s += 2) { + if (rho_hat_s[s + 1] + rho_hat_s[s + 2] > rho_hat_s[s - 1] + rho_hat_s[s]) { + rho_hat_s[s + 1] = (rho_hat_s[s - 1] + rho_hat_s[s]) / 2; + rho_hat_s[s + 2] = rho_hat_s[s + 1]; + } + } + + const num_total_draws = num_chains * num_draws; + // Geyer's truncated estimator for the asymptotic variance + // Improved estimate reduces variance in antithetic case + const tau_hat = -1 + 2 * compute_sum(rho_hat_s.slice(0, max_s)) + rho_hat_s[max_s + 1]; + return Math.min(num_total_draws / tau_hat, num_total_draws * Math.log10(num_total_draws)); +} + +function compute_sum(arr: number[]): number { + return arr.reduce((a, b) => a + b, 0); +} + +function compute_mean(arr: number[]): number { + return compute_sum(arr) / arr.length; +} + +function compute_population_variance(arr: number[]): number { + const mean = compute_mean(arr); + return compute_mean(arr.map(d => (d - mean) ** 2)); +} + +function compute_sample_variance(arr: number[]): number { + const mean = compute_mean(arr); + return compute_sum(arr.map(d => (d - mean) ** 2)) / (arr.length - 1); +} + +function autocorrelation(y: number[]): number[] { + const N = y.length; + const M = fftNextGoodSize(N); + const Mt2 = 2 * M; + + // centered_signal = y-mean(y) followed by N zeros + const centered_signal = new Array(Mt2).fill(0); + const y_mean = compute_mean(y); + for (let n = 0; n < N; n++) { + centered_signal[n] = y[n] - y_mean; + } + + const freqvec: [number, number][] = forwardFFT(centered_signal); + for (let i = 0; i < freqvec.length; i++) { + freqvec[i] = [freqvec[i][0] ** 2 + freqvec[i][1] ** 2, 0]; + } + + const ac_tmp = inverseFFT(freqvec); + + // use "biased" estimate as recommended by Geyer (1992) + const ac = new Array(N).fill(0); + for (let n = 0; n < N; n++) { + ac[n] = ac_tmp[n][0] / (N * N * 2); + } + const ac0 = ac[0]; + for (let n = 0; n < N; n++) { + ac[n] /= ac0; + } + + return ac; +} + +function autocovariance(y: number[]): number[] { + const acov = autocorrelation(y); + const variance = compute_population_variance(y); + return acov.map(v => v * variance); +} + +function fftNextGoodSize(n: number): number { + const isGoodSize = (n: number) => { + while (n % 2 === 0) { + n /= 2; + } + while (n % 3 === 0) { + n /= 3; + } + while (n % 5 === 0) { + n /= 5; + } + return n === 1; + } + while (!isGoodSize(n)) { + n++; + } + return n; +} + +function forwardFFT(signal: number[]): [number, number][] { + const realPart = [...signal]; + const imagPart = new Array(signal.length).fill(0); + inPlaceInverseFftTransform(realPart, imagPart); + return realPart.map((v, i) => [v, imagPart[i]]); +} + +function inverseFFT(freqvec: [number, number][]): [number, number][] { + const realPart = freqvec.map(v => v[0]); + const imagPart = freqvec.map(v => v[1]); + inPlaceFftTransform(realPart, imagPart); + return realPart.map((v, i) => [v, imagPart[i]]); +} + +const split_chains = (draws: number[][]) => { + const num_chains = draws.length; + let num_draws = draws[0].length; + for (let chain = 1; chain < num_chains; ++chain) { + num_draws = Math.min(num_draws, draws[chain].length); + } + + const half = num_draws / 2.0; + // when N is odd, the (N+1)/2th draw is ignored + const end_first_half = Math.floor(half); + const start_second_half = Math.ceil(half); + const split_draws = new Array(2 * num_chains); + for (let n = 0; n < num_chains; ++n) { + split_draws[2 * n] = draws[n].slice(0, end_first_half); + split_draws[2 * n + 1] = draws[n].slice(start_second_half); + } + + return split_draws; +} + +export const compute_split_effective_sample_size = (draws: number[][]) => { + const split_draws = split_chains(draws); + + return compute_effective_sample_size(split_draws); +} + +export function compute_potential_scale_reduction(draws: number[][]): number { + const num_chains = draws.length; + let num_draws = draws[0].length; + for (let chain = 1; chain < num_chains; ++chain) { + num_draws = Math.min(num_draws, draws[chain].length); + } + + // check if chains are constant; all equal to first draw's value + let are_all_const = false; + const init_draw = new Array(num_chains).fill(0); + for (let chain_idx = 0; chain_idx < num_chains; chain_idx++) { + const draw = draws[chain_idx]; + for (let n = 0; n < num_draws; n++) { + if (!isFinite(draw[n])) { + return NaN; + } + } + + init_draw[chain_idx] = draw[0]; + + const precision = 1e-12; + if (draw.every(d => Math.abs(d - draw[0]) < precision)) { + are_all_const = true; + } + } + + if (are_all_const) { + // If all chains are constant then return NaN + // if they all equal the same constant value + const precision = 1e-12; + if (init_draw.every(d => Math.abs(d - init_draw[0]) < precision)) { + return NaN; + } + } + + // chain_mean: mean of each chain + const chain_mean = new Array(num_chains).fill(0); + // chain_var: sample variance of each chain + const chain_var = new Array(num_chains).fill(0); + + for (let chain = 0; chain < num_chains; ++chain) { + const draw = draws[chain]; + chain_mean[chain] = compute_mean(draw); + chain_var[chain] = compute_sample_variance(draw); + } + + const var_between = num_draws * compute_sample_variance(chain_mean); + const var_within = compute_mean(chain_var); + + return Math.sqrt((var_between / var_within + num_draws - 1) / num_draws); +} + +export function compute_split_potential_scale_reduction(draws: number[][]): number { + const split_draws = split_chains(draws); + + return compute_potential_scale_reduction(split_draws); +}