From ec711167beef09cacf20c9d47da7edb660f24f0d Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 7 Jun 2024 12:16:37 -0400 Subject: [PATCH 1/7] implement ess stats --- gui/src/app/SamplerOutputView/SummaryView.tsx | 73 ++++-- gui/src/app/SamplerOutputView/advanced/ess.ts | 142 +++++++++++ gui/src/app/SamplerOutputView/advanced/fft.ts | 225 ++++++++++++++++++ .../app/SamplerOutputView/advanced/rhat.ts | 32 +++ 4 files changed, 451 insertions(+), 21 deletions(-) create mode 100644 gui/src/app/SamplerOutputView/advanced/ess.ts create mode 100644 gui/src/app/SamplerOutputView/advanced/fft.ts create mode 100644 gui/src/app/SamplerOutputView/advanced/rhat.ts diff --git a/gui/src/app/SamplerOutputView/SummaryView.tsx b/gui/src/app/SamplerOutputView/SummaryView.tsx index 784e4ada..b84a9a5c 100644 --- a/gui/src/app/SamplerOutputView/SummaryView.tsx +++ b/gui/src/app/SamplerOutputView/SummaryView.tsx @@ -1,5 +1,7 @@ import { FunctionComponent, useMemo } from "react" +import { ess } from "./advanced/ess" import { computeMean, computePercentile, computeStdDev } from "./util" +import rhat from "./advanced/rhat" type SummaryViewProps = { width: number @@ -16,11 +18,11 @@ const columns = [ label: 'Mean', title: 'Mean value of the parameter' }, - /*future: { + { key: 'mcse', label: 'MCSE', title: 'Monte Carlo Standard Error: Standard deviation of the parameter divided by the square root of the effective sample size' - },*/ + }, { key: 'stdDev', label: 'StdDev', @@ -41,21 +43,21 @@ const columns = [ label: '95%', title: '95th percentile of the parameter' }, - /*future: { + { key: 'nEff', label: 'N_Eff', title: 'Effective sample size: A crude measure of the effective sample size (uses ess_imse)' - },*/ - /*future: { + }, + { key: 'nEff/s', label: 'N_Eff/s', title: 'Effective sample size per second of compute time' - },*/ - /*future: { + }, + { key: 'rHat', label: 'R_hat', title: 'Potential scale reduction factor on split chains (at convergence, R_hat=1)' - }*/ + } ] type TableRow = { @@ -63,24 +65,22 @@ type TableRow = { values: number[] } -const SummaryView: FunctionComponent = ({ width, height, draws, paramNames }) => { - // will be used in the future: - // const uniqueChainIds = useMemo(() => (Array.from(new Set(drawChainIds)).sort()), [drawChainIds]); - // note: computeTimeSec will be used in the future +const SummaryView: FunctionComponent = ({ width, height, draws, paramNames, drawChainIds, computeTimeSec }) => { + const uniqueChainIds = useMemo(() => (Array.from(new Set(drawChainIds)).sort()), [drawChainIds]); const rows = useMemo(() => { const rows: TableRow[] = []; for (const pname of paramNames) { const pDraws = draws[paramNames.indexOf(pname)]; const pDrawsSorted = [...pDraws].sort((a, b) => a - b); + const ess = computeEss(pDraws, drawChainIds); const stdDev = computeStdDev(pDraws); const values = columns.map((column) => { if (column.key === 'mean') { return computeMean(pDraws); } else if (column.key === 'mcse') { - // placeholder for mcse - throw new Error('Not implemented'); + return stdDev / Math.sqrt(ess); } else if (column.key === 'stdDev') { return stdDev; @@ -95,16 +95,16 @@ const SummaryView: FunctionComponent = ({ width, height, draws return computePercentile(pDrawsSorted, 0.95); } else if (column.key === 'nEff') { - // placeholder for nEff - throw new Error('Not implemented'); + return ess; } else if (column.key === 'nEff/s') { - // placeholder for nEff/s - throw new Error('Not implemented'); + return computeTimeSec ? ess / computeTimeSec : NaN; } else if (column.key === 'rHat') { - // placeholder for rHat - throw new Error('Not implemented'); + const counts = computeChainCounts(drawChainIds, uniqueChainIds); + const means = computeChainMeans(pDraws, drawChainIds, uniqueChainIds); + const stdevs = computeChainStdDevs(pDraws, drawChainIds, uniqueChainIds); + return rhat({ counts, means, stdevs }); } else { return NaN; @@ -116,7 +116,7 @@ const SummaryView: FunctionComponent = ({ width, height, draws }) } return rows; - }, [paramNames, draws]); + }, [paramNames, draws, drawChainIds, uniqueChainIds, computeTimeSec]); return (
@@ -157,6 +157,37 @@ const SummaryView: FunctionComponent = ({ width, height, draws ) } +const computeEss = (x: number[], chainIds: number[]) => { + const uniqueChainIds = Array.from(new Set(chainIds)).sort(); + let sumEss = 0; + for (const chainId of uniqueChainIds) { + const chainX = x.filter((_, i) => chainIds[i] === chainId); + const {essValue} = ess(chainX); + sumEss += essValue; + } + return sumEss; +} + +const computeChainCounts = (chainIds: number[], uniqueChainIds: number[]) => { + return uniqueChainIds.map((chainId) => { + return chainIds.filter((id) => id === chainId).length; + }); +} + +const computeChainMeans = (x: number[], chainIds: number[], uniqueChainIds: number[]) => { + return uniqueChainIds.map((chainId) => { + const chainX = x.filter((_, i) => chainIds[i] === chainId); + return computeMean(chainX); + }); +} + +const computeChainStdDevs = (x: number[], chainIds: number[], uniqueChainIds: number[]) => { + return uniqueChainIds.map((chainId) => { + const chainX = x.filter((_, i) => chainIds[i] === chainId); + return computeStdDev(chainX); + }); +} + // Example of Stan output... // Inference for Stan model: bernoulli_model // 1 chains: each with iter=(1000); warmup=(0); thin=(1); 1000 iterations saved. diff --git a/gui/src/app/SamplerOutputView/advanced/ess.ts b/gui/src/app/SamplerOutputView/advanced/ess.ts new file mode 100644 index 00000000..1ebe1553 --- /dev/null +++ b/gui/src/app/SamplerOutputView/advanced/ess.ts @@ -0,0 +1,142 @@ +// See: https://github.com/flatironinstitute/bayes-kit/blob/main/bayes_kit/ess.py + +import { inverseTransform, transform as transformFft } from "./fft" + +// def autocorr_fft(chain: VectorType) -> VectorType: +// """ +// Return sample autocorrelations at all lags for the specified sequence. +// Algorithmically, this function calls a fast Fourier transform (FFT). +// Parameters: +// chain: sequence whose autocorrelation is returned +// Returns: +// autocorrelation estimates at all lags for the specified sequence +// """ +// size = 2 ** np.ceil(np.log2(2 * len(chain) - 1)).astype("int") +// var = np.var(chain) +// ndata = chain - np.mean(chain) +// fft = np.fft.fft(ndata, size) +// pwr = np.abs(fft) ** 2 +// N = len(ndata) +// acorr = np.fft.ifft(pwr).real / var / N +// return acorr + +export function autocorr_fft(chain: number[], n: number): number[] { + const size = Math.round(Math.pow(2, Math.ceil(Math.log2(2 * chain.length - 1)))) + const variance = computeVariance(chain) + if (variance === undefined) return [] + const mean = computeMean(chain) + const ndata = chain.map(x => (x - (mean || 0))) + while (ndata.length < size) { + ndata.push(0) + } + const ndataFftReal = [...ndata] + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const ndataFftImag = ndata.map(_x => (0)) + transformFft(ndataFftReal, ndataFftImag) + const pwr = ndataFftReal.map((r, i) => (r * r + ndataFftImag[i] * ndataFftImag[i])) + const N = ndata.length + const acorrReal = [...pwr] + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const acorrImag = pwr.map(_ => (0)) + inverseTransform(acorrReal, acorrImag) // doesn't include scaling + return acorrReal.slice(0, n).map(x => (x / variance / N / chain.length)) +} + +export function autocorr_slow(chain: number[], n: number): number[] { + // todo: use FFT + + const mu = chain.length ? sum(chain) / chain.length : 0 + const chain_ctr = chain.map(a => (a - mu)) + const N = chain_ctr.length + + ////////////////////////////////////////////////////////////// + // acorrN = np.correlate(chain_ctr, chain_ctr, "full")[N - 1 :] + let acorrN: number[] = [] + for (let i = 0; i < n; i++) { + let aa = 0 + for (let j = 0; j < N - i; j++) { + aa += chain_ctr[j] * chain_ctr[j + i] + } + acorrN.push(aa) + } + ////////////////////////////////////////////////////////////// + + // normalize so that acorrN[0] = 1 + const a0 = acorrN[0] + acorrN = acorrN.map(a => (a / a0)) + + return acorrN +} + +export function first_neg_pair_start(chain: number[]): number { + const N = chain.length + let n = 0 + while (n + 1 < N) { + if (chain[n] + chain[n + 1] < 0) { + return n + } + n = n + 1 + } + return N +} + +export function ess_ipse(chain: number[]): number { + if (chain.length < 4) { + console.warn('ess requires chain.length >=4') + return 0 + } + + // for verifying we get the same answer with both methods + // console.log('test autocor_slow', autocorr_slow([1, 2, 3, 4, 0, 0, 0], 5)) + // console.log('test autocor_fft', autocorr_fft([1, 2, 3, 4, 0, 0, 0], 5)) + + // const acor = autocorr_slow(chain, chain.length) + const acor = autocorr_fft(chain, chain.length) + const n = first_neg_pair_start(acor) + const sigma_sq_hat = acor[0] + 2 * sum(acor.slice(1, n)) + const ess = chain.length / sigma_sq_hat + return ess +} + +export function ess_imse(chain: number[]): {essValue: number, acor: number[]} { + if (chain.length < 4) { + console.warn('ess requires chain.length >=4') + return {essValue: 0, acor: []} + } + // const acor = autocorr_slow(chain, chain.length) + const acor = autocorr_fft(chain, chain.length) + const n = first_neg_pair_start(acor) + let prev_min = 1 + let accum = 0 + let i = 1 + while (i + 1 < n) { + prev_min = Math.min(prev_min, acor[i] + acor[i + 1]) + accum = accum + prev_min + i = i + 2 + } + + const sigma_sq_hat = acor[0] + 2 * accum + const essValue = chain.length / sigma_sq_hat + return {essValue, acor} +} + +export function ess(chain: number[]) { + // use ess_imse for now + return ess_imse(chain) +} + +function sum(x: number[]) { + return x.reduce((a, b) => (a + b), 0) +} + +function computeVariance(x: number[]) { + const mu = computeMean(x) + if (mu === undefined) return undefined + return sum(x.map(a => ( + (a - mu) * (a - mu) + ))) / x.length +} + +function computeMean(x: number[]) { + return x.length ? sum(x) / x.length : undefined +} \ No newline at end of file diff --git a/gui/src/app/SamplerOutputView/advanced/fft.ts b/gui/src/app/SamplerOutputView/advanced/fft.ts new file mode 100644 index 00000000..822e0db0 --- /dev/null +++ b/gui/src/app/SamplerOutputView/advanced/fft.ts @@ -0,0 +1,225 @@ +/* eslint-disable @typescript-eslint/no-inferrable-types */ +/* eslint-disable @typescript-eslint/no-unused-vars */ +/* eslint-disable prefer-const */ +/* + * Free FFT and convolution (TypeScript) + * + * Copyright (c) 2022 Project Nayuki. (MIT License) + * https://www.nayuki.io/page/free-small-fft-in-multiple-languages + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * - The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * - The Software is provided "as is", without warranty of any kind, express or + * implied, including but not limited to the warranties of merchantability, + * fitness for a particular purpose and noninfringement. In no event shall the + * authors or copyright holders be liable for any claim, damages or other + * liability, whether in an action of contract, tort or otherwise, arising from, + * out of or in connection with the Software or the use or other dealings in the + * Software. + */ + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This is a wrapper function. + */ +export function transform(real: Array|Float64Array, imag: Array|Float64Array): void { + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + if (n == 0) + return; + else if ((n & (n - 1)) == 0) // Is power of 2 + transformRadix2(real, imag); + else // More complicated algorithm for arbitrary sizes + transformBluestein(real, imag); +} + + +/* + * Computes the inverse discrete Fourier transform (IDFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This is a wrapper function. This transform does not perform scaling, so the inverse is not a true inverse. + */ +export function inverseTransform(real: Array|Float64Array, imag: Array|Float64Array): void { + transform(imag, real); +} + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector's length must be a power of 2. Uses the Cooley-Tukey decimation-in-time radix-2 algorithm. + */ +function transformRadix2(real: Array|Float64Array, imag: Array|Float64Array): void { + // Length variables + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + if (n == 1) // Trivial transform + return; + let levels: number = -1; + for (let i = 0; i < 32; i++) { + if (1 << i == n) + levels = i; // Equal to log2(n) + } + if (levels == -1) + throw new RangeError("Length is not a power of 2"); + + // Trigonometric tables + let cosTable = new Array(n / 2); + let sinTable = new Array(n / 2); + for (let i = 0; i < n / 2; i++) { + cosTable[i] = Math.cos(2 * Math.PI * i / n); + sinTable[i] = Math.sin(2 * Math.PI * i / n); + } + + // Bit-reversed addressing permutation + for (let i = 0; i < n; i++) { + const j: number = reverseBits(i, levels); + if (j > i) { + let temp: number = real[i]; + real[i] = real[j]; + real[j] = temp; + temp = imag[i]; + imag[i] = imag[j]; + imag[j] = temp; + } + } + + // Cooley-Tukey decimation-in-time radix-2 FFT + for (let size = 2; size <= n; size *= 2) { + const halfsize: number = size / 2; + const tablestep: number = n / size; + for (let i = 0; i < n; i += size) { + for (let j = i, k = 0; j < i + halfsize; j++, k += tablestep) { + const l: number = j + halfsize; + const tpre: number = real[l] * cosTable[k] + imag[l] * sinTable[k]; + const tpim: number = -real[l] * sinTable[k] + imag[l] * cosTable[k]; + real[l] = real[j] - tpre; + imag[l] = imag[j] - tpim; + real[j] += tpre; + imag[j] += tpim; + } + } + } + + // Returns the integer whose value is the reverse of the lowest 'width' bits of the integer 'val'. + function reverseBits(val: number, width: number): number { + let result: number = 0; + for (let i = 0; i < width; i++) { + result = (result << 1) | (val & 1); + val >>>= 1; + } + return result; + } +} + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This requires the convolution function, which in turn requires the radix-2 FFT function. + * Uses Bluestein's chirp z-transform algorithm. + */ +function transformBluestein(real: Array|Float64Array, imag: Array|Float64Array): void { + // Find a power-of-2 convolution length m such that m >= n * 2 + 1 + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + let m: number = 1; + while (m < n * 2 + 1) + m *= 2; + + // Trigonometric tables + let cosTable = new Array(n); + let sinTable = new Array(n); + for (let i = 0; i < n; i++) { + const j: number = i * i % (n * 2); // This is more accurate than j = i * i + cosTable[i] = Math.cos(Math.PI * j / n); + sinTable[i] = Math.sin(Math.PI * j / n); + } + + // Temporary vectors and preprocessing + let areal: Array = newArrayOfZeros(m); + let aimag: Array = newArrayOfZeros(m); + for (let i = 0; i < n; i++) { + areal[i] = real[i] * cosTable[i] + imag[i] * sinTable[i]; + aimag[i] = -real[i] * sinTable[i] + imag[i] * cosTable[i]; + } + let breal: Array = newArrayOfZeros(m); + let bimag: Array = newArrayOfZeros(m); + breal[0] = cosTable[0]; + bimag[0] = sinTable[0]; + for (let i = 1; i < n; i++) { + breal[i] = breal[m - i] = cosTable[i]; + bimag[i] = bimag[m - i] = sinTable[i]; + } + + // Convolution + let creal = new Array(m); + let cimag = new Array(m); + convolveComplex(areal, aimag, breal, bimag, creal, cimag); + + // Postprocessing + for (let i = 0; i < n; i++) { + real[i] = creal[i] * cosTable[i] + cimag[i] * sinTable[i]; + imag[i] = -creal[i] * sinTable[i] + cimag[i] * cosTable[i]; + } +} + + +/* + * Computes the circular convolution of the given real vectors. Each vector's length must be the same. + */ +// function convolveReal(xvec: Array|Float64Array, yvec: Array|Float64Array, outvec: Array|Float64Array): void { +// const n: number = xvec.length; +// if (n != yvec.length || n != outvec.length) +// throw new RangeError("Mismatched lengths"); +// convolveComplex(xvec, newArrayOfZeros(n), yvec, newArrayOfZeros(n), outvec, newArrayOfZeros(n)); +// } + + +/* + * Computes the circular convolution of the given complex vectors. Each vector's length must be the same. + */ +function convolveComplex( + xreal: Array|Float64Array, ximag: Array|Float64Array, + yreal: Array|Float64Array, yimag: Array|Float64Array, + outreal: Array|Float64Array, outimag: Array|Float64Array): void { + + const n: number = xreal.length; + if (n != ximag.length || n != yreal.length || n != yimag.length + || n != outreal.length || n != outimag.length) + throw new RangeError("Mismatched lengths"); + + xreal = xreal.slice(); + ximag = ximag.slice(); + yreal = yreal.slice(); + yimag = yimag.slice(); + transform(xreal, ximag); + transform(yreal, yimag); + + for (let i = 0; i < n; i++) { + const temp: number = xreal[i] * yreal[i] - ximag[i] * yimag[i]; + ximag[i] = ximag[i] * yreal[i] + xreal[i] * yimag[i]; + xreal[i] = temp; + } + inverseTransform(xreal, ximag); + + for (let i = 0; i < n; i++) { // Scaling (because this FFT implementation omits it) + outreal[i] = xreal[i] / n; + outimag[i] = ximag[i] / n; + } +} + + +function newArrayOfZeros(n: number): Array { + let result: Array = []; + for (let i = 0; i < n; i++) + result.push(0); + return result; +} \ No newline at end of file diff --git a/gui/src/app/SamplerOutputView/advanced/rhat.ts b/gui/src/app/SamplerOutputView/advanced/rhat.ts new file mode 100644 index 00000000..4c1a0a64 --- /dev/null +++ b/gui/src/app/SamplerOutputView/advanced/rhat.ts @@ -0,0 +1,32 @@ +import { computeMean, computeStdDev } from "../util" + +export default function rhat(o: {counts: (number | undefined)[], means: (number | undefined)[], stdevs: (number | undefined)[]}) { + // chain_lengths = [len(chain) for chain in chains] + // mean_chain_length = np.mean(chain_lengths) + // means = [np.mean(chain) for chain in chains] + // vars = [np.var(chain, ddof=1) for chain in chains] + // r_hat: np.float64 = np.sqrt( + // (mean_chain_length - 1) / mean_chain_length + np.var(means, ddof=1) / np.mean(vars) + // ) + const { counts, means, stdevs } = o + if (counts.indexOf(undefined) >= 0) return NaN + if (means.indexOf(undefined) >= 0) return NaN + if (stdevs.indexOf(undefined) >= 0) return NaN + const cc = counts as number[] + const mm = means as number[] + const ss = stdevs as number[] + if (cc.length <= 1) return NaN + for (const count of cc) { + if (count <= 1) return NaN + } + const mean_chain_length = computeMean(cc) + if (mean_chain_length === undefined) return NaN + const vars = ss.map((s, i) => (s * s * cc[i] / (cc[i] - 1))) + const stdevMeans = computeStdDev(mm) + if (stdevMeans === undefined) return NaN + const varMeans = stdevMeans * stdevMeans * cc.length / (cc.length - 1) + const meanVars = computeMean(vars) + if (meanVars === undefined) return NaN + const r_hat = Math.sqrt((mean_chain_length - 1) / mean_chain_length + varMeans / meanVars) + return r_hat +} \ No newline at end of file From bdaad30cc3ee7269f35558bb72b3d12ca73ec37a Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Thu, 13 Jun 2024 19:07:59 -0400 Subject: [PATCH 2/7] try port ess compute from stan c++ --- gui/src/app/SamplerOutputView/SummaryView.tsx | 40 ++- .../compute_effective_sample_size.ts | 228 ++++++++++++++++++ .../ess_computation_from_stan/fft.ts | 225 +++++++++++++++++ 3 files changed, 484 insertions(+), 9 deletions(-) create mode 100644 gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts create mode 100644 gui/src/app/SamplerOutputView/ess_computation_from_stan/fft.ts diff --git a/gui/src/app/SamplerOutputView/SummaryView.tsx b/gui/src/app/SamplerOutputView/SummaryView.tsx index b84a9a5c..ee2e8f41 100644 --- a/gui/src/app/SamplerOutputView/SummaryView.tsx +++ b/gui/src/app/SamplerOutputView/SummaryView.tsx @@ -2,6 +2,7 @@ import { FunctionComponent, useMemo } from "react" import { ess } from "./advanced/ess" import { computeMean, computePercentile, computeStdDev } from "./util" import rhat from "./advanced/rhat" +import compute_effective_sample_size from "./ess_computation_from_stan/compute_effective_sample_size" type SummaryViewProps = { width: number @@ -44,9 +45,14 @@ const columns = [ title: '95th percentile of the parameter' }, { - key: 'nEff', - label: 'N_Eff', - title: 'Effective sample size: A crude measure of the effective sample size (uses ess_imse)' + key: 'nEff1', + label: 'N_Eff1', + title: 'Effective sample size: A crude measure of the effective sample size (ported from bayes_kit untested)' + }, + { + key: 'nEff2', + label: 'N_Eff2', + title: 'Effective sample size: A crude measure of the effective sample size (ported from stan C++ untested)' }, { key: 'nEff/s', @@ -73,14 +79,15 @@ const SummaryView: FunctionComponent = ({ width, height, draws for (const pname of paramNames) { const pDraws = draws[paramNames.indexOf(pname)]; const pDrawsSorted = [...pDraws].sort((a, b) => a - b); - const ess = computeEss(pDraws, drawChainIds); + const ess1 = computeEss1(pDraws, drawChainIds); + const ess2 = computeEss2(pDraws, drawChainIds); const stdDev = computeStdDev(pDraws); const values = columns.map((column) => { if (column.key === 'mean') { return computeMean(pDraws); } else if (column.key === 'mcse') { - return stdDev / Math.sqrt(ess); + return stdDev / Math.sqrt(ess1); } else if (column.key === 'stdDev') { return stdDev; @@ -94,11 +101,14 @@ const SummaryView: FunctionComponent = ({ width, height, draws else if (column.key === '95%') { return computePercentile(pDrawsSorted, 0.95); } - else if (column.key === 'nEff') { - return ess; + else if (column.key === 'nEff1') { + return ess1; + } + else if (column.key === 'nEff2') { + return ess2; } else if (column.key === 'nEff/s') { - return computeTimeSec ? ess / computeTimeSec : NaN; + return computeTimeSec ? ess1 / computeTimeSec : NaN; } else if (column.key === 'rHat') { const counts = computeChainCounts(drawChainIds, uniqueChainIds); @@ -157,7 +167,7 @@ const SummaryView: FunctionComponent = ({ width, height, draws ) } -const computeEss = (x: number[], chainIds: number[]) => { +const computeEss1 = (x: number[], chainIds: number[]) => { const uniqueChainIds = Array.from(new Set(chainIds)).sort(); let sumEss = 0; for (const chainId of uniqueChainIds) { @@ -168,6 +178,18 @@ const computeEss = (x: number[], chainIds: number[]) => { return sumEss; } +const computeEss2 = (x: number[], chainIds: number[]) => { + const uniqueChainIds = Array.from(new Set(chainIds)).sort(); + const draws: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []); + for (let i = 0; i < x.length; i++) { + const chainId = chainIds[i]; + const chainIndex = uniqueChainIds.indexOf(chainId); + draws[chainIndex].push(x[i]); + } + const ess = compute_effective_sample_size(draws); + return ess; +} + const computeChainCounts = (chainIds: number[], uniqueChainIds: number[]) => { return uniqueChainIds.map((chainId) => { return chainIds.filter((id) => id === chainId).length; diff --git a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts b/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts new file mode 100644 index 00000000..41c0a4da --- /dev/null +++ b/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts @@ -0,0 +1,228 @@ +/* +Translated from +https://github.com/stan-dev/stan/blob/develop/src/stan/analyze/mcmc/compute_effective_sample_size.hpp +and +https://github.com/stan-dev/stan/blob/develop/src/stan/analyze/mcmc/autocovariance.hpp +*/ + +import { transform as inPlaceFftTransform, inverseTransform as inPlaceInverseFftTransform } from "../advanced/fft"; + +/** + * Computes the effective sample size (ESS) for the specified + * parameter across all kept samples. The value returned is the + * minimum of ESS and the number_total_draws * + * log10(number_total_draws). + * + * See more details in Stan reference manual section "Effective + * Sample Size". http://mc-stan.org/users/documentation + * + * Current implementation assumes draws are stored in contiguous + * blocks of memory. Chains are trimmed from the back to match the + * length of the shortest chain. Note that the effective sample size + * can not be estimated with less than four draws. + * + * draws: arrays of draws for each chain + * returns: effective sample size for the specified parameter + */ +function compute_effective_sample_size(draws: number[][]): number { + const num_chains = draws.length; + + // use the minimum number of draws across all chains + let num_draws = draws[0].length; + for (let chain = 1; chain < num_chains; ++chain) { + num_draws = Math.min(num_draws, draws[chain].length); + } + + if (num_draws < 4) { + // we don't have enough draws to compute ESS + return NaN; + } + + // check if chains are constant; all equal to first draw's value + let are_all_const = false; + const init_draw = new Array(num_chains).fill(0); + for (let chain_idx = 0; chain_idx < num_chains; chain_idx++) { + const draw = draws[chain_idx]; + for (let n = 0; n < num_draws; n++) { + if (!isFinite(draw[n])) { + // we can't compute ESS if there are non-finite values + return NaN; + } + } + + init_draw[chain_idx] = draw[0]; + + const precision = 1e-12; + if (draw.every(d => Math.abs(d - draw[0]) < precision)) { + are_all_const = true; + } + } + + if (are_all_const) { + // If all chains are constant then return NaN + // if they all equal the same constant value + const precision = 1e-12; + if (init_draw.every(d => Math.abs(d - init_draw[0]) < precision)) { + return NaN; + } + } + + const acov = new Array(num_chains).fill(0).map(() => new Array(num_draws).fill(0)); + const chain_mean = new Array(num_chains).fill(0); // mean of each chain + const chain_var = new Array(num_chains).fill(0); // variance of each chain + for (let chain = 0; chain < num_chains; ++chain) { + const draw = draws[chain]; + // the autocovariance is computed for each chain + acov[chain] = autocovariance(draw); + chain_mean[chain] = compute_mean(draw); + chain_var[chain] = acov[chain][0] * num_draws / (num_draws - 1); + } + + const mean_var = compute_mean(chain_var); + let var_plus = mean_var * (num_draws - 1) / num_draws; + if (num_chains > 1) { + var_plus += compute_variance(chain_mean); + } + const rho_hat_s = new Array(num_draws).fill(0); + const acov_s = new Array(num_chains).fill(0); + for (let chain = 0; chain < num_chains; ++chain) { + acov_s[chain] = acov[chain][1]; + } + let rho_hat_even = 1.0; + rho_hat_s[0] = rho_hat_even; + let rho_hat_odd = 1 - (mean_var - compute_mean(acov_s)) / var_plus; + rho_hat_s[1] = rho_hat_odd; + + // Convert raw autocovariance estimators into Geyer's initial + // positive sequence. Loop only until num_draws - 4 to + // leave the last pair of autocorrelations as a bias term that + // reduces variance in the case of antithetical chains. + let s = 1; + while (s < (num_draws - 4) && (rho_hat_even + rho_hat_odd) > 0) { + for (let chain = 0; chain < num_chains; ++chain) { + acov_s[chain] = acov[chain][s + 1]; + } + rho_hat_even = 1 - (mean_var - compute_mean(acov_s)) / var_plus; + for (let chain = 0; chain < num_chains; ++chain) { + acov_s[chain] = acov[chain][s + 2]; + } + rho_hat_odd = 1 - (mean_var - compute_mean(acov_s)) / var_plus; + if ((rho_hat_even + rho_hat_odd) >= 0) { + rho_hat_s[s + 1] = rho_hat_even; + rho_hat_s[s + 2] = rho_hat_odd; + } + s += 2; + } + + const max_s = s; + // this is used in the improved estimate, which reduces variance + // in antithetic case -- see tau_hat below + if (rho_hat_even > 0) { + rho_hat_s[max_s + 1] = rho_hat_even; + } + + // Convert Geyer's initial positive sequence into an initial + // monotone sequence + for (let s = 1; s <= max_s - 3; s += 2) { + if (rho_hat_s[s + 1] + rho_hat_s[s + 2] > rho_hat_s[s - 1] + rho_hat_s[s]) { + rho_hat_s[s + 1] = (rho_hat_s[s - 1] + rho_hat_s[s]) / 2; + rho_hat_s[s + 2] = rho_hat_s[s + 1]; + } + } + + const num_total_draws = num_chains * num_draws; + // Geyer's truncated estimator for the asymptotic variance + // Improved estimate reduces variance in antithetic case + const tau_hat = -1 + 2 * compute_sum(rho_hat_s.slice(0, max_s)) + rho_hat_s[max_s + 1]; + return Math.min(num_total_draws / tau_hat, num_total_draws * Math.log10(num_total_draws)); +} + +function compute_sum(arr: number[]): number { + return arr.reduce((a, b) => a + b, 0); +} + +function compute_mean(arr: number[]): number { + return compute_sum(arr) / arr.length; +} + +function compute_variance(arr: number[]): number { + // QUESTION: is this the correct formula for variance? + const mean = compute_mean(arr); + return compute_mean(arr.map(d => (d - mean) ** 2)); +} + +function autocorrelation(y: number[]): number[] { + const N = y.length; + const M = fftNextGoodSize(N); + const Mt2 = 2 * M; + + // centered_signal = y-mean(y) followed by N zeros + const centered_signal = new Array(Mt2).fill(0); + const y_mean = compute_mean(y); + for (let n = 0; n < N; n++) { + centered_signal[n] = y[n] - y_mean; + } + + const freqvec: [number, number][] = forwardFFT(centered_signal); + for (let i = 0; i < freqvec.length; i++) { + freqvec[i] = [freqvec[i][0] ** 2 + freqvec[i][1] ** 2, 0]; + } + + const ac_tmp = inverseFFT(freqvec); + + // use "biased" estimate as recommended by Geyer (1992) + const ac = new Array(N).fill(0); + for (let n = 0; n < N; n++) { + ac[n] = ac_tmp[n][0] / (N * N * 2); + } + const ac0 = ac[0]; + for (let n = 0; n < N; n++) { + ac[n] /= ac0; + } + + return ac; +} + +function autocovariance(y: number[]): number[] { + const acov = autocorrelation(y); + const variance = compute_variance(y); + for (let n = 0; n < y.length; n++) { + acov[n] *= variance; + } + return acov; +} + +function fftNextGoodSize(n: number): number { + const isGoodSize = (n: number) => { + while (n % 2 === 0) { + n /= 2; + } + while (n % 3 === 0) { + n /= 3; + } + while (n % 5 === 0) { + n /= 5; + } + return n === 1; + } + while (!isGoodSize(n)) { + n++; + } + return n; +} + +function forwardFFT(signal: number[]): [number, number][] { + const realPart = [...signal]; + const imagPart = new Array(signal.length).fill(0); + inPlaceFftTransform(realPart, imagPart); + return realPart.map((v, i) => [v, imagPart[i]]); +} + +function inverseFFT(freqvec: [number, number][]): [number, number][] { + const realPart = freqvec.map(v => v[0]); + const imagPart = freqvec.map(v => v[1]); + inPlaceInverseFftTransform(realPart, imagPart); + return realPart.map((v, i) => [v, imagPart[i]]); +} + +export default compute_effective_sample_size; \ No newline at end of file diff --git a/gui/src/app/SamplerOutputView/ess_computation_from_stan/fft.ts b/gui/src/app/SamplerOutputView/ess_computation_from_stan/fft.ts new file mode 100644 index 00000000..822e0db0 --- /dev/null +++ b/gui/src/app/SamplerOutputView/ess_computation_from_stan/fft.ts @@ -0,0 +1,225 @@ +/* eslint-disable @typescript-eslint/no-inferrable-types */ +/* eslint-disable @typescript-eslint/no-unused-vars */ +/* eslint-disable prefer-const */ +/* + * Free FFT and convolution (TypeScript) + * + * Copyright (c) 2022 Project Nayuki. (MIT License) + * https://www.nayuki.io/page/free-small-fft-in-multiple-languages + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * - The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * - The Software is provided "as is", without warranty of any kind, express or + * implied, including but not limited to the warranties of merchantability, + * fitness for a particular purpose and noninfringement. In no event shall the + * authors or copyright holders be liable for any claim, damages or other + * liability, whether in an action of contract, tort or otherwise, arising from, + * out of or in connection with the Software or the use or other dealings in the + * Software. + */ + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This is a wrapper function. + */ +export function transform(real: Array|Float64Array, imag: Array|Float64Array): void { + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + if (n == 0) + return; + else if ((n & (n - 1)) == 0) // Is power of 2 + transformRadix2(real, imag); + else // More complicated algorithm for arbitrary sizes + transformBluestein(real, imag); +} + + +/* + * Computes the inverse discrete Fourier transform (IDFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This is a wrapper function. This transform does not perform scaling, so the inverse is not a true inverse. + */ +export function inverseTransform(real: Array|Float64Array, imag: Array|Float64Array): void { + transform(imag, real); +} + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector's length must be a power of 2. Uses the Cooley-Tukey decimation-in-time radix-2 algorithm. + */ +function transformRadix2(real: Array|Float64Array, imag: Array|Float64Array): void { + // Length variables + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + if (n == 1) // Trivial transform + return; + let levels: number = -1; + for (let i = 0; i < 32; i++) { + if (1 << i == n) + levels = i; // Equal to log2(n) + } + if (levels == -1) + throw new RangeError("Length is not a power of 2"); + + // Trigonometric tables + let cosTable = new Array(n / 2); + let sinTable = new Array(n / 2); + for (let i = 0; i < n / 2; i++) { + cosTable[i] = Math.cos(2 * Math.PI * i / n); + sinTable[i] = Math.sin(2 * Math.PI * i / n); + } + + // Bit-reversed addressing permutation + for (let i = 0; i < n; i++) { + const j: number = reverseBits(i, levels); + if (j > i) { + let temp: number = real[i]; + real[i] = real[j]; + real[j] = temp; + temp = imag[i]; + imag[i] = imag[j]; + imag[j] = temp; + } + } + + // Cooley-Tukey decimation-in-time radix-2 FFT + for (let size = 2; size <= n; size *= 2) { + const halfsize: number = size / 2; + const tablestep: number = n / size; + for (let i = 0; i < n; i += size) { + for (let j = i, k = 0; j < i + halfsize; j++, k += tablestep) { + const l: number = j + halfsize; + const tpre: number = real[l] * cosTable[k] + imag[l] * sinTable[k]; + const tpim: number = -real[l] * sinTable[k] + imag[l] * cosTable[k]; + real[l] = real[j] - tpre; + imag[l] = imag[j] - tpim; + real[j] += tpre; + imag[j] += tpim; + } + } + } + + // Returns the integer whose value is the reverse of the lowest 'width' bits of the integer 'val'. + function reverseBits(val: number, width: number): number { + let result: number = 0; + for (let i = 0; i < width; i++) { + result = (result << 1) | (val & 1); + val >>>= 1; + } + return result; + } +} + + +/* + * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. + * The vector can have any length. This requires the convolution function, which in turn requires the radix-2 FFT function. + * Uses Bluestein's chirp z-transform algorithm. + */ +function transformBluestein(real: Array|Float64Array, imag: Array|Float64Array): void { + // Find a power-of-2 convolution length m such that m >= n * 2 + 1 + const n: number = real.length; + if (n != imag.length) + throw new RangeError("Mismatched lengths"); + let m: number = 1; + while (m < n * 2 + 1) + m *= 2; + + // Trigonometric tables + let cosTable = new Array(n); + let sinTable = new Array(n); + for (let i = 0; i < n; i++) { + const j: number = i * i % (n * 2); // This is more accurate than j = i * i + cosTable[i] = Math.cos(Math.PI * j / n); + sinTable[i] = Math.sin(Math.PI * j / n); + } + + // Temporary vectors and preprocessing + let areal: Array = newArrayOfZeros(m); + let aimag: Array = newArrayOfZeros(m); + for (let i = 0; i < n; i++) { + areal[i] = real[i] * cosTable[i] + imag[i] * sinTable[i]; + aimag[i] = -real[i] * sinTable[i] + imag[i] * cosTable[i]; + } + let breal: Array = newArrayOfZeros(m); + let bimag: Array = newArrayOfZeros(m); + breal[0] = cosTable[0]; + bimag[0] = sinTable[0]; + for (let i = 1; i < n; i++) { + breal[i] = breal[m - i] = cosTable[i]; + bimag[i] = bimag[m - i] = sinTable[i]; + } + + // Convolution + let creal = new Array(m); + let cimag = new Array(m); + convolveComplex(areal, aimag, breal, bimag, creal, cimag); + + // Postprocessing + for (let i = 0; i < n; i++) { + real[i] = creal[i] * cosTable[i] + cimag[i] * sinTable[i]; + imag[i] = -creal[i] * sinTable[i] + cimag[i] * cosTable[i]; + } +} + + +/* + * Computes the circular convolution of the given real vectors. Each vector's length must be the same. + */ +// function convolveReal(xvec: Array|Float64Array, yvec: Array|Float64Array, outvec: Array|Float64Array): void { +// const n: number = xvec.length; +// if (n != yvec.length || n != outvec.length) +// throw new RangeError("Mismatched lengths"); +// convolveComplex(xvec, newArrayOfZeros(n), yvec, newArrayOfZeros(n), outvec, newArrayOfZeros(n)); +// } + + +/* + * Computes the circular convolution of the given complex vectors. Each vector's length must be the same. + */ +function convolveComplex( + xreal: Array|Float64Array, ximag: Array|Float64Array, + yreal: Array|Float64Array, yimag: Array|Float64Array, + outreal: Array|Float64Array, outimag: Array|Float64Array): void { + + const n: number = xreal.length; + if (n != ximag.length || n != yreal.length || n != yimag.length + || n != outreal.length || n != outimag.length) + throw new RangeError("Mismatched lengths"); + + xreal = xreal.slice(); + ximag = ximag.slice(); + yreal = yreal.slice(); + yimag = yimag.slice(); + transform(xreal, ximag); + transform(yreal, yimag); + + for (let i = 0; i < n; i++) { + const temp: number = xreal[i] * yreal[i] - ximag[i] * yimag[i]; + ximag[i] = ximag[i] * yreal[i] + xreal[i] * yimag[i]; + xreal[i] = temp; + } + inverseTransform(xreal, ximag); + + for (let i = 0; i < n; i++) { // Scaling (because this FFT implementation omits it) + outreal[i] = xreal[i] / n; + outimag[i] = ximag[i] / n; + } +} + + +function newArrayOfZeros(n: number): Array { + let result: Array = []; + for (let i = 0; i < n; i++) + result.push(0); + return result; +} \ No newline at end of file From 9db52ddfa55953e381cfb9be2c4effa24851b13f Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 14 Jun 2024 06:58:46 -0400 Subject: [PATCH 3/7] adjust ess calc --- .../compute_effective_sample_size.ts | 50 +++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts b/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts index 41c0a4da..9adfc673 100644 --- a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts +++ b/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts @@ -67,22 +67,27 @@ function compute_effective_sample_size(draws: number[][]): number { } } + // acov: autocovariance for each chain const acov = new Array(num_chains).fill(0).map(() => new Array(num_draws).fill(0)); - const chain_mean = new Array(num_chains).fill(0); // mean of each chain - const chain_var = new Array(num_chains).fill(0); // variance of each chain + // chain_mean: mean of each chain + const chain_mean = new Array(num_chains).fill(0); + // chain_var: variance of each chain + const chain_var = new Array(num_chains).fill(0); for (let chain = 0; chain < num_chains; ++chain) { const draw = draws[chain]; - // the autocovariance is computed for each chain acov[chain] = autocovariance(draw); chain_mean[chain] = compute_mean(draw); chain_var[chain] = acov[chain][0] * num_draws / (num_draws - 1); } + // mean_var: mean of the chain variances const mean_var = compute_mean(chain_var); + let var_plus = mean_var * (num_draws - 1) / num_draws; if (num_chains > 1) { var_plus += compute_variance(chain_mean); } + const rho_hat_s = new Array(num_draws).fill(0); const acov_s = new Array(num_chains).fill(0); for (let chain = 0; chain < num_chains; ++chain) { @@ -186,10 +191,7 @@ function autocorrelation(y: number[]): number[] { function autocovariance(y: number[]): number[] { const acov = autocorrelation(y); const variance = compute_variance(y); - for (let n = 0; n < y.length; n++) { - acov[n] *= variance; - } - return acov; + return acov.map(v => v * variance); } function fftNextGoodSize(n: number): number { @@ -214,15 +216,45 @@ function fftNextGoodSize(n: number): number { function forwardFFT(signal: number[]): [number, number][] { const realPart = [...signal]; const imagPart = new Array(signal.length).fill(0); - inPlaceFftTransform(realPart, imagPart); + inPlaceInverseFftTransform(realPart, imagPart); return realPart.map((v, i) => [v, imagPart[i]]); } function inverseFFT(freqvec: [number, number][]): [number, number][] { const realPart = freqvec.map(v => v[0]); const imagPart = freqvec.map(v => v[1]); - inPlaceInverseFftTransform(realPart, imagPart); + inPlaceFftTransform(realPart, imagPart); return realPart.map((v, i) => [v, imagPart[i]]); } +const split_chains = (draws: number[][]) => { + const num_chains = draws.length; + let num_draws = draws[0].length; + for (let chain = 1; chain < num_chains; ++chain) { + num_draws = Math.min(num_draws, draws[chain].length); + } + + const half = num_draws / 2; + const half_draws = Math.ceil(half); + const split_draws = new Array(2 * num_chains); + for (let n = 0; n < num_chains; ++n) { + split_draws[2 * n] = draws[n].slice(0, half_draws); + split_draws[2 * n + 1] = draws[n].slice(half_draws); + } + + return split_draws; +} + +export const compute_split_effective_sample_size = (draws: number[][]) => { + const num_chains = draws.length; + let num_draws = draws[0].length; + for (let chain = 1; chain < num_chains; ++chain) { + num_draws = Math.min(num_draws, draws[chain].length); + } + + const split_draws = split_chains(draws); + + return compute_effective_sample_size(split_draws); +} + export default compute_effective_sample_size; \ No newline at end of file From 65c7969b5786dfb65b6fd4f777013ad98f21dc13 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Jun 2024 14:14:15 +0000 Subject: [PATCH 4/7] Fix ESS calculation (sample vs population variance issue) --- .../compute_effective_sample_size.ts | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts b/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts index 9adfc673..155aa12c 100644 --- a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts +++ b/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts @@ -85,7 +85,7 @@ function compute_effective_sample_size(draws: number[][]): number { let var_plus = mean_var * (num_draws - 1) / num_draws; if (num_chains > 1) { - var_plus += compute_variance(chain_mean); + var_plus += compute_sample_variance(chain_mean); } const rho_hat_s = new Array(num_draws).fill(0); @@ -150,12 +150,16 @@ function compute_mean(arr: number[]): number { return compute_sum(arr) / arr.length; } -function compute_variance(arr: number[]): number { - // QUESTION: is this the correct formula for variance? +function compute_population_variance(arr: number[]): number { const mean = compute_mean(arr); return compute_mean(arr.map(d => (d - mean) ** 2)); } +function compute_sample_variance(arr: number[]): number { + const mean = compute_mean(arr); + return compute_sum(arr.map(d => (d - mean) ** 2)) / (arr.length - 1); +} + function autocorrelation(y: number[]): number[] { const N = y.length; const M = fftNextGoodSize(N); @@ -190,7 +194,7 @@ function autocorrelation(y: number[]): number[] { function autocovariance(y: number[]): number[] { const acov = autocorrelation(y); - const variance = compute_variance(y); + const variance = compute_population_variance(y); return acov.map(v => v * variance); } @@ -257,4 +261,4 @@ export const compute_split_effective_sample_size = (draws: number[][]) => { return compute_effective_sample_size(split_draws); } -export default compute_effective_sample_size; \ No newline at end of file +export default compute_effective_sample_size; From 7e319ffb14ce56253f77f0456231450716e5eddc Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 14 Jun 2024 10:48:45 -0400 Subject: [PATCH 5/7] implement Rhat and rearrange stan stats --- gui/src/app/SamplerOutputView/SummaryView.tsx | 80 ++----- gui/src/app/SamplerOutputView/advanced/ess.ts | 142 ----------- .../app/SamplerOutputView/advanced/rhat.ts | 32 --- .../ess_computation_from_stan/fft.ts | 225 ------------------ .../{advanced => stan_stats}/fft.ts | 0 .../stan_stats.ts} | 71 +++++- 6 files changed, 90 insertions(+), 460 deletions(-) delete mode 100644 gui/src/app/SamplerOutputView/advanced/ess.ts delete mode 100644 gui/src/app/SamplerOutputView/advanced/rhat.ts delete mode 100644 gui/src/app/SamplerOutputView/ess_computation_from_stan/fft.ts rename gui/src/app/SamplerOutputView/{advanced => stan_stats}/fft.ts (100%) rename gui/src/app/SamplerOutputView/{ess_computation_from_stan/compute_effective_sample_size.ts => stan_stats/stan_stats.ts} (78%) diff --git a/gui/src/app/SamplerOutputView/SummaryView.tsx b/gui/src/app/SamplerOutputView/SummaryView.tsx index ee2e8f41..832fa128 100644 --- a/gui/src/app/SamplerOutputView/SummaryView.tsx +++ b/gui/src/app/SamplerOutputView/SummaryView.tsx @@ -1,8 +1,6 @@ import { FunctionComponent, useMemo } from "react" -import { ess } from "./advanced/ess" import { computeMean, computePercentile, computeStdDev } from "./util" -import rhat from "./advanced/rhat" -import compute_effective_sample_size from "./ess_computation_from_stan/compute_effective_sample_size" +import { compute_effective_sample_size, compute_split_potential_scale_reduction } from "./stan_stats/stan_stats" type SummaryViewProps = { width: number @@ -45,14 +43,9 @@ const columns = [ title: '95th percentile of the parameter' }, { - key: 'nEff1', - label: 'N_Eff1', - title: 'Effective sample size: A crude measure of the effective sample size (ported from bayes_kit untested)' - }, - { - key: 'nEff2', - label: 'N_Eff2', - title: 'Effective sample size: A crude measure of the effective sample size (ported from stan C++ untested)' + key: 'nEff', + label: 'N_Eff', + title: 'Effective sample size: A crude measure of the effective sample size' }, { key: 'nEff/s', @@ -72,22 +65,20 @@ type TableRow = { } const SummaryView: FunctionComponent = ({ width, height, draws, paramNames, drawChainIds, computeTimeSec }) => { - const uniqueChainIds = useMemo(() => (Array.from(new Set(drawChainIds)).sort()), [drawChainIds]); - const rows = useMemo(() => { const rows: TableRow[] = []; for (const pname of paramNames) { const pDraws = draws[paramNames.indexOf(pname)]; const pDrawsSorted = [...pDraws].sort((a, b) => a - b); - const ess1 = computeEss1(pDraws, drawChainIds); - const ess2 = computeEss2(pDraws, drawChainIds); + const ess = computeEss(pDraws, drawChainIds); + const rhat = computeRhat(pDraws, drawChainIds); const stdDev = computeStdDev(pDraws); const values = columns.map((column) => { if (column.key === 'mean') { return computeMean(pDraws); } else if (column.key === 'mcse') { - return stdDev / Math.sqrt(ess1); + return stdDev / Math.sqrt(ess); } else if (column.key === 'stdDev') { return stdDev; @@ -101,20 +92,14 @@ const SummaryView: FunctionComponent = ({ width, height, draws else if (column.key === '95%') { return computePercentile(pDrawsSorted, 0.95); } - else if (column.key === 'nEff1') { - return ess1; - } - else if (column.key === 'nEff2') { - return ess2; + else if (column.key === 'nEff') { + return ess; } else if (column.key === 'nEff/s') { - return computeTimeSec ? ess1 / computeTimeSec : NaN; + return computeTimeSec ? ess / computeTimeSec : NaN; } else if (column.key === 'rHat') { - const counts = computeChainCounts(drawChainIds, uniqueChainIds); - const means = computeChainMeans(pDraws, drawChainIds, uniqueChainIds); - const stdevs = computeChainStdDevs(pDraws, drawChainIds, uniqueChainIds); - return rhat({ counts, means, stdevs }); + return rhat; } else { return NaN; @@ -126,7 +111,7 @@ const SummaryView: FunctionComponent = ({ width, height, draws }) } return rows; - }, [paramNames, draws, drawChainIds, uniqueChainIds, computeTimeSec]); + }, [draws, paramNames, drawChainIds, computeTimeSec]); return (
@@ -167,18 +152,7 @@ const SummaryView: FunctionComponent = ({ width, height, draws ) } -const computeEss1 = (x: number[], chainIds: number[]) => { - const uniqueChainIds = Array.from(new Set(chainIds)).sort(); - let sumEss = 0; - for (const chainId of uniqueChainIds) { - const chainX = x.filter((_, i) => chainIds[i] === chainId); - const {essValue} = ess(chainX); - sumEss += essValue; - } - return sumEss; -} - -const computeEss2 = (x: number[], chainIds: number[]) => { +const computeEss = (x: number[], chainIds: number[]) => { const uniqueChainIds = Array.from(new Set(chainIds)).sort(); const draws: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []); for (let i = 0; i < x.length; i++) { @@ -190,24 +164,16 @@ const computeEss2 = (x: number[], chainIds: number[]) => { return ess; } -const computeChainCounts = (chainIds: number[], uniqueChainIds: number[]) => { - return uniqueChainIds.map((chainId) => { - return chainIds.filter((id) => id === chainId).length; - }); -} - -const computeChainMeans = (x: number[], chainIds: number[], uniqueChainIds: number[]) => { - return uniqueChainIds.map((chainId) => { - const chainX = x.filter((_, i) => chainIds[i] === chainId); - return computeMean(chainX); - }); -} - -const computeChainStdDevs = (x: number[], chainIds: number[], uniqueChainIds: number[]) => { - return uniqueChainIds.map((chainId) => { - const chainX = x.filter((_, i) => chainIds[i] === chainId); - return computeStdDev(chainX); - }); +const computeRhat = (x: number[], chainIds: number[]) => { + const uniqueChainIds = Array.from(new Set(chainIds)).sort(); + const draws: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []); + for (let i = 0; i < x.length; i++) { + const chainId = chainIds[i]; + const chainIndex = uniqueChainIds.indexOf(chainId); + draws[chainIndex].push(x[i]); + } + const rhat = compute_split_potential_scale_reduction(draws); + return rhat; } // Example of Stan output... diff --git a/gui/src/app/SamplerOutputView/advanced/ess.ts b/gui/src/app/SamplerOutputView/advanced/ess.ts deleted file mode 100644 index 1ebe1553..00000000 --- a/gui/src/app/SamplerOutputView/advanced/ess.ts +++ /dev/null @@ -1,142 +0,0 @@ -// See: https://github.com/flatironinstitute/bayes-kit/blob/main/bayes_kit/ess.py - -import { inverseTransform, transform as transformFft } from "./fft" - -// def autocorr_fft(chain: VectorType) -> VectorType: -// """ -// Return sample autocorrelations at all lags for the specified sequence. -// Algorithmically, this function calls a fast Fourier transform (FFT). -// Parameters: -// chain: sequence whose autocorrelation is returned -// Returns: -// autocorrelation estimates at all lags for the specified sequence -// """ -// size = 2 ** np.ceil(np.log2(2 * len(chain) - 1)).astype("int") -// var = np.var(chain) -// ndata = chain - np.mean(chain) -// fft = np.fft.fft(ndata, size) -// pwr = np.abs(fft) ** 2 -// N = len(ndata) -// acorr = np.fft.ifft(pwr).real / var / N -// return acorr - -export function autocorr_fft(chain: number[], n: number): number[] { - const size = Math.round(Math.pow(2, Math.ceil(Math.log2(2 * chain.length - 1)))) - const variance = computeVariance(chain) - if (variance === undefined) return [] - const mean = computeMean(chain) - const ndata = chain.map(x => (x - (mean || 0))) - while (ndata.length < size) { - ndata.push(0) - } - const ndataFftReal = [...ndata] - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const ndataFftImag = ndata.map(_x => (0)) - transformFft(ndataFftReal, ndataFftImag) - const pwr = ndataFftReal.map((r, i) => (r * r + ndataFftImag[i] * ndataFftImag[i])) - const N = ndata.length - const acorrReal = [...pwr] - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const acorrImag = pwr.map(_ => (0)) - inverseTransform(acorrReal, acorrImag) // doesn't include scaling - return acorrReal.slice(0, n).map(x => (x / variance / N / chain.length)) -} - -export function autocorr_slow(chain: number[], n: number): number[] { - // todo: use FFT - - const mu = chain.length ? sum(chain) / chain.length : 0 - const chain_ctr = chain.map(a => (a - mu)) - const N = chain_ctr.length - - ////////////////////////////////////////////////////////////// - // acorrN = np.correlate(chain_ctr, chain_ctr, "full")[N - 1 :] - let acorrN: number[] = [] - for (let i = 0; i < n; i++) { - let aa = 0 - for (let j = 0; j < N - i; j++) { - aa += chain_ctr[j] * chain_ctr[j + i] - } - acorrN.push(aa) - } - ////////////////////////////////////////////////////////////// - - // normalize so that acorrN[0] = 1 - const a0 = acorrN[0] - acorrN = acorrN.map(a => (a / a0)) - - return acorrN -} - -export function first_neg_pair_start(chain: number[]): number { - const N = chain.length - let n = 0 - while (n + 1 < N) { - if (chain[n] + chain[n + 1] < 0) { - return n - } - n = n + 1 - } - return N -} - -export function ess_ipse(chain: number[]): number { - if (chain.length < 4) { - console.warn('ess requires chain.length >=4') - return 0 - } - - // for verifying we get the same answer with both methods - // console.log('test autocor_slow', autocorr_slow([1, 2, 3, 4, 0, 0, 0], 5)) - // console.log('test autocor_fft', autocorr_fft([1, 2, 3, 4, 0, 0, 0], 5)) - - // const acor = autocorr_slow(chain, chain.length) - const acor = autocorr_fft(chain, chain.length) - const n = first_neg_pair_start(acor) - const sigma_sq_hat = acor[0] + 2 * sum(acor.slice(1, n)) - const ess = chain.length / sigma_sq_hat - return ess -} - -export function ess_imse(chain: number[]): {essValue: number, acor: number[]} { - if (chain.length < 4) { - console.warn('ess requires chain.length >=4') - return {essValue: 0, acor: []} - } - // const acor = autocorr_slow(chain, chain.length) - const acor = autocorr_fft(chain, chain.length) - const n = first_neg_pair_start(acor) - let prev_min = 1 - let accum = 0 - let i = 1 - while (i + 1 < n) { - prev_min = Math.min(prev_min, acor[i] + acor[i + 1]) - accum = accum + prev_min - i = i + 2 - } - - const sigma_sq_hat = acor[0] + 2 * accum - const essValue = chain.length / sigma_sq_hat - return {essValue, acor} -} - -export function ess(chain: number[]) { - // use ess_imse for now - return ess_imse(chain) -} - -function sum(x: number[]) { - return x.reduce((a, b) => (a + b), 0) -} - -function computeVariance(x: number[]) { - const mu = computeMean(x) - if (mu === undefined) return undefined - return sum(x.map(a => ( - (a - mu) * (a - mu) - ))) / x.length -} - -function computeMean(x: number[]) { - return x.length ? sum(x) / x.length : undefined -} \ No newline at end of file diff --git a/gui/src/app/SamplerOutputView/advanced/rhat.ts b/gui/src/app/SamplerOutputView/advanced/rhat.ts deleted file mode 100644 index 4c1a0a64..00000000 --- a/gui/src/app/SamplerOutputView/advanced/rhat.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { computeMean, computeStdDev } from "../util" - -export default function rhat(o: {counts: (number | undefined)[], means: (number | undefined)[], stdevs: (number | undefined)[]}) { - // chain_lengths = [len(chain) for chain in chains] - // mean_chain_length = np.mean(chain_lengths) - // means = [np.mean(chain) for chain in chains] - // vars = [np.var(chain, ddof=1) for chain in chains] - // r_hat: np.float64 = np.sqrt( - // (mean_chain_length - 1) / mean_chain_length + np.var(means, ddof=1) / np.mean(vars) - // ) - const { counts, means, stdevs } = o - if (counts.indexOf(undefined) >= 0) return NaN - if (means.indexOf(undefined) >= 0) return NaN - if (stdevs.indexOf(undefined) >= 0) return NaN - const cc = counts as number[] - const mm = means as number[] - const ss = stdevs as number[] - if (cc.length <= 1) return NaN - for (const count of cc) { - if (count <= 1) return NaN - } - const mean_chain_length = computeMean(cc) - if (mean_chain_length === undefined) return NaN - const vars = ss.map((s, i) => (s * s * cc[i] / (cc[i] - 1))) - const stdevMeans = computeStdDev(mm) - if (stdevMeans === undefined) return NaN - const varMeans = stdevMeans * stdevMeans * cc.length / (cc.length - 1) - const meanVars = computeMean(vars) - if (meanVars === undefined) return NaN - const r_hat = Math.sqrt((mean_chain_length - 1) / mean_chain_length + varMeans / meanVars) - return r_hat -} \ No newline at end of file diff --git a/gui/src/app/SamplerOutputView/ess_computation_from_stan/fft.ts b/gui/src/app/SamplerOutputView/ess_computation_from_stan/fft.ts deleted file mode 100644 index 822e0db0..00000000 --- a/gui/src/app/SamplerOutputView/ess_computation_from_stan/fft.ts +++ /dev/null @@ -1,225 +0,0 @@ -/* eslint-disable @typescript-eslint/no-inferrable-types */ -/* eslint-disable @typescript-eslint/no-unused-vars */ -/* eslint-disable prefer-const */ -/* - * Free FFT and convolution (TypeScript) - * - * Copyright (c) 2022 Project Nayuki. (MIT License) - * https://www.nayuki.io/page/free-small-fft-in-multiple-languages - * - * Permission is hereby granted, free of charge, to any person obtaining a copy of - * this software and associated documentation files (the "Software"), to deal in - * the Software without restriction, including without limitation the rights to - * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of - * the Software, and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - The Software is provided "as is", without warranty of any kind, express or - * implied, including but not limited to the warranties of merchantability, - * fitness for a particular purpose and noninfringement. In no event shall the - * authors or copyright holders be liable for any claim, damages or other - * liability, whether in an action of contract, tort or otherwise, arising from, - * out of or in connection with the Software or the use or other dealings in the - * Software. - */ - - -/* - * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. - * The vector can have any length. This is a wrapper function. - */ -export function transform(real: Array|Float64Array, imag: Array|Float64Array): void { - const n: number = real.length; - if (n != imag.length) - throw new RangeError("Mismatched lengths"); - if (n == 0) - return; - else if ((n & (n - 1)) == 0) // Is power of 2 - transformRadix2(real, imag); - else // More complicated algorithm for arbitrary sizes - transformBluestein(real, imag); -} - - -/* - * Computes the inverse discrete Fourier transform (IDFT) of the given complex vector, storing the result back into the vector. - * The vector can have any length. This is a wrapper function. This transform does not perform scaling, so the inverse is not a true inverse. - */ -export function inverseTransform(real: Array|Float64Array, imag: Array|Float64Array): void { - transform(imag, real); -} - - -/* - * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. - * The vector's length must be a power of 2. Uses the Cooley-Tukey decimation-in-time radix-2 algorithm. - */ -function transformRadix2(real: Array|Float64Array, imag: Array|Float64Array): void { - // Length variables - const n: number = real.length; - if (n != imag.length) - throw new RangeError("Mismatched lengths"); - if (n == 1) // Trivial transform - return; - let levels: number = -1; - for (let i = 0; i < 32; i++) { - if (1 << i == n) - levels = i; // Equal to log2(n) - } - if (levels == -1) - throw new RangeError("Length is not a power of 2"); - - // Trigonometric tables - let cosTable = new Array(n / 2); - let sinTable = new Array(n / 2); - for (let i = 0; i < n / 2; i++) { - cosTable[i] = Math.cos(2 * Math.PI * i / n); - sinTable[i] = Math.sin(2 * Math.PI * i / n); - } - - // Bit-reversed addressing permutation - for (let i = 0; i < n; i++) { - const j: number = reverseBits(i, levels); - if (j > i) { - let temp: number = real[i]; - real[i] = real[j]; - real[j] = temp; - temp = imag[i]; - imag[i] = imag[j]; - imag[j] = temp; - } - } - - // Cooley-Tukey decimation-in-time radix-2 FFT - for (let size = 2; size <= n; size *= 2) { - const halfsize: number = size / 2; - const tablestep: number = n / size; - for (let i = 0; i < n; i += size) { - for (let j = i, k = 0; j < i + halfsize; j++, k += tablestep) { - const l: number = j + halfsize; - const tpre: number = real[l] * cosTable[k] + imag[l] * sinTable[k]; - const tpim: number = -real[l] * sinTable[k] + imag[l] * cosTable[k]; - real[l] = real[j] - tpre; - imag[l] = imag[j] - tpim; - real[j] += tpre; - imag[j] += tpim; - } - } - } - - // Returns the integer whose value is the reverse of the lowest 'width' bits of the integer 'val'. - function reverseBits(val: number, width: number): number { - let result: number = 0; - for (let i = 0; i < width; i++) { - result = (result << 1) | (val & 1); - val >>>= 1; - } - return result; - } -} - - -/* - * Computes the discrete Fourier transform (DFT) of the given complex vector, storing the result back into the vector. - * The vector can have any length. This requires the convolution function, which in turn requires the radix-2 FFT function. - * Uses Bluestein's chirp z-transform algorithm. - */ -function transformBluestein(real: Array|Float64Array, imag: Array|Float64Array): void { - // Find a power-of-2 convolution length m such that m >= n * 2 + 1 - const n: number = real.length; - if (n != imag.length) - throw new RangeError("Mismatched lengths"); - let m: number = 1; - while (m < n * 2 + 1) - m *= 2; - - // Trigonometric tables - let cosTable = new Array(n); - let sinTable = new Array(n); - for (let i = 0; i < n; i++) { - const j: number = i * i % (n * 2); // This is more accurate than j = i * i - cosTable[i] = Math.cos(Math.PI * j / n); - sinTable[i] = Math.sin(Math.PI * j / n); - } - - // Temporary vectors and preprocessing - let areal: Array = newArrayOfZeros(m); - let aimag: Array = newArrayOfZeros(m); - for (let i = 0; i < n; i++) { - areal[i] = real[i] * cosTable[i] + imag[i] * sinTable[i]; - aimag[i] = -real[i] * sinTable[i] + imag[i] * cosTable[i]; - } - let breal: Array = newArrayOfZeros(m); - let bimag: Array = newArrayOfZeros(m); - breal[0] = cosTable[0]; - bimag[0] = sinTable[0]; - for (let i = 1; i < n; i++) { - breal[i] = breal[m - i] = cosTable[i]; - bimag[i] = bimag[m - i] = sinTable[i]; - } - - // Convolution - let creal = new Array(m); - let cimag = new Array(m); - convolveComplex(areal, aimag, breal, bimag, creal, cimag); - - // Postprocessing - for (let i = 0; i < n; i++) { - real[i] = creal[i] * cosTable[i] + cimag[i] * sinTable[i]; - imag[i] = -creal[i] * sinTable[i] + cimag[i] * cosTable[i]; - } -} - - -/* - * Computes the circular convolution of the given real vectors. Each vector's length must be the same. - */ -// function convolveReal(xvec: Array|Float64Array, yvec: Array|Float64Array, outvec: Array|Float64Array): void { -// const n: number = xvec.length; -// if (n != yvec.length || n != outvec.length) -// throw new RangeError("Mismatched lengths"); -// convolveComplex(xvec, newArrayOfZeros(n), yvec, newArrayOfZeros(n), outvec, newArrayOfZeros(n)); -// } - - -/* - * Computes the circular convolution of the given complex vectors. Each vector's length must be the same. - */ -function convolveComplex( - xreal: Array|Float64Array, ximag: Array|Float64Array, - yreal: Array|Float64Array, yimag: Array|Float64Array, - outreal: Array|Float64Array, outimag: Array|Float64Array): void { - - const n: number = xreal.length; - if (n != ximag.length || n != yreal.length || n != yimag.length - || n != outreal.length || n != outimag.length) - throw new RangeError("Mismatched lengths"); - - xreal = xreal.slice(); - ximag = ximag.slice(); - yreal = yreal.slice(); - yimag = yimag.slice(); - transform(xreal, ximag); - transform(yreal, yimag); - - for (let i = 0; i < n; i++) { - const temp: number = xreal[i] * yreal[i] - ximag[i] * yimag[i]; - ximag[i] = ximag[i] * yreal[i] + xreal[i] * yimag[i]; - xreal[i] = temp; - } - inverseTransform(xreal, ximag); - - for (let i = 0; i < n; i++) { // Scaling (because this FFT implementation omits it) - outreal[i] = xreal[i] / n; - outimag[i] = ximag[i] / n; - } -} - - -function newArrayOfZeros(n: number): Array { - let result: Array = []; - for (let i = 0; i < n; i++) - result.push(0); - return result; -} \ No newline at end of file diff --git a/gui/src/app/SamplerOutputView/advanced/fft.ts b/gui/src/app/SamplerOutputView/stan_stats/fft.ts similarity index 100% rename from gui/src/app/SamplerOutputView/advanced/fft.ts rename to gui/src/app/SamplerOutputView/stan_stats/fft.ts diff --git a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts b/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts similarity index 78% rename from gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts rename to gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts index 155aa12c..6efdd814 100644 --- a/gui/src/app/SamplerOutputView/ess_computation_from_stan/compute_effective_sample_size.ts +++ b/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts @@ -3,9 +3,11 @@ Translated from https://github.com/stan-dev/stan/blob/develop/src/stan/analyze/mcmc/compute_effective_sample_size.hpp and https://github.com/stan-dev/stan/blob/develop/src/stan/analyze/mcmc/autocovariance.hpp +and +https://github.com/stan-dev/stan/blob/develop/src/stan/analyze/mcmc/compute_potential_scale_reduction.hpp */ -import { transform as inPlaceFftTransform, inverseTransform as inPlaceInverseFftTransform } from "../advanced/fft"; +import { transform as inPlaceFftTransform, inverseTransform as inPlaceInverseFftTransform } from "./fft"; /** * Computes the effective sample size (ESS) for the specified @@ -24,7 +26,7 @@ import { transform as inPlaceFftTransform, inverseTransform as inPlaceInverseFft * draws: arrays of draws for each chain * returns: effective sample size for the specified parameter */ -function compute_effective_sample_size(draws: number[][]): number { +export function compute_effective_sample_size(draws: number[][]): number { const num_chains = draws.length; // use the minimum number of draws across all chains @@ -71,7 +73,7 @@ function compute_effective_sample_size(draws: number[][]): number { const acov = new Array(num_chains).fill(0).map(() => new Array(num_draws).fill(0)); // chain_mean: mean of each chain const chain_mean = new Array(num_chains).fill(0); - // chain_var: variance of each chain + // chain_var: sample variance of each chain const chain_var = new Array(num_chains).fill(0); for (let chain = 0; chain < num_chains; ++chain) { const draw = draws[chain]; @@ -261,4 +263,65 @@ export const compute_split_effective_sample_size = (draws: number[][]) => { return compute_effective_sample_size(split_draws); } -export default compute_effective_sample_size; +export function compute_potential_scale_reduction(draws: number[][]): number { + const num_chains = draws.length; + let num_draws = draws[0].length; + for (let chain = 1; chain < num_chains; ++chain) { + num_draws = Math.min(num_draws, draws[chain].length); + } + + // check if chains are constant; all equal to first draw's value + let are_all_const = false; + const init_draw = new Array(num_chains).fill(0); + for (let chain_idx = 0; chain_idx < num_chains; chain_idx++) { + const draw = draws[chain_idx]; + for (let n = 0; n < num_draws; n++) { + if (!isFinite(draw[n])) { + return NaN; + } + } + + init_draw[chain_idx] = draw[0]; + + const precision = 1e-12; + if (draw.every(d => Math.abs(d - draw[0]) < precision)) { + are_all_const = true; + } + } + + if (are_all_const) { + // If all chains are constant then return NaN + // if they all equal the same constant value + const precision = 1e-12; + if (init_draw.every(d => Math.abs(d - init_draw[0]) < precision)) { + return NaN; + } + } + + // chain_mean: mean of each chain + const chain_mean = new Array(num_chains).fill(0); + // chain_var: sample variance of each chain + const chain_var = new Array(num_chains).fill(0); + for (let chain = 0; chain < num_chains; ++chain) { + const draw = draws[chain]; + chain_mean[chain] = compute_mean(draw); + chain_var[chain] = compute_sample_variance(draw); + } + + const var_between = num_draws * compute_sample_variance(chain_mean); + const var_within = compute_mean(chain_var); + + return Math.sqrt((var_between / var_within + num_draws - 1) / num_draws); +} + +export function compute_split_potential_scale_reduction(draws: number[][]): number { + const num_chains = draws.length; + let num_draws = draws[0].length; + for (let chain = 1; chain < num_chains; ++chain) { + num_draws = Math.min(num_draws, draws[chain].length); + } + + const split_draws = split_chains(draws); + + return compute_potential_scale_reduction(split_draws); +} \ No newline at end of file From 2dd1875b891b8d8edb7ced86547b3cc464a25d92 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Jun 2024 16:55:35 +0000 Subject: [PATCH 6/7] Implement split_chains such that it has the same behavior as Stan for odd numbered draws --- .../stan_stats/stan_stats.ts | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts b/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts index 6efdd814..cba20558 100644 --- a/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts +++ b/gui/src/app/SamplerOutputView/stan_stats/stan_stats.ts @@ -240,24 +240,20 @@ const split_chains = (draws: number[][]) => { num_draws = Math.min(num_draws, draws[chain].length); } - const half = num_draws / 2; - const half_draws = Math.ceil(half); + const half = num_draws / 2.0; + // when N is odd, the (N+1)/2th draw is ignored + const end_first_half = Math.floor(half); + const start_second_half = Math.ceil(half); const split_draws = new Array(2 * num_chains); for (let n = 0; n < num_chains; ++n) { - split_draws[2 * n] = draws[n].slice(0, half_draws); - split_draws[2 * n + 1] = draws[n].slice(half_draws); + split_draws[2 * n] = draws[n].slice(0, end_first_half); + split_draws[2 * n + 1] = draws[n].slice(start_second_half); } return split_draws; } export const compute_split_effective_sample_size = (draws: number[][]) => { - const num_chains = draws.length; - let num_draws = draws[0].length; - for (let chain = 1; chain < num_chains; ++chain) { - num_draws = Math.min(num_draws, draws[chain].length); - } - const split_draws = split_chains(draws); return compute_effective_sample_size(split_draws); @@ -302,6 +298,7 @@ export function compute_potential_scale_reduction(draws: number[][]): number { const chain_mean = new Array(num_chains).fill(0); // chain_var: sample variance of each chain const chain_var = new Array(num_chains).fill(0); + for (let chain = 0; chain < num_chains; ++chain) { const draw = draws[chain]; chain_mean[chain] = compute_mean(draw); @@ -315,13 +312,7 @@ export function compute_potential_scale_reduction(draws: number[][]): number { } export function compute_split_potential_scale_reduction(draws: number[][]): number { - const num_chains = draws.length; - let num_draws = draws[0].length; - for (let chain = 1; chain < num_chains; ++chain) { - num_draws = Math.min(num_draws, draws[chain].length); - } - const split_draws = split_chains(draws); return compute_potential_scale_reduction(split_draws); -} \ No newline at end of file +} From e3d634060dc3ec09c4c113abf9e44c59255953aa Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Fri, 14 Jun 2024 17:01:19 -0400 Subject: [PATCH 7/7] helper function drawsByChain --- gui/src/app/SamplerOutputView/SummaryView.tsx | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/gui/src/app/SamplerOutputView/SummaryView.tsx b/gui/src/app/SamplerOutputView/SummaryView.tsx index 832fa128..1c3e573f 100644 --- a/gui/src/app/SamplerOutputView/SummaryView.tsx +++ b/gui/src/app/SamplerOutputView/SummaryView.tsx @@ -152,26 +152,26 @@ const SummaryView: FunctionComponent = ({ width, height, draws ) } -const computeEss = (x: number[], chainIds: number[]) => { +const drawsByChain = (draws: number[], chainIds: number[]): number[][] => { + // Group draws by chain for use in computing ESS and Rhat const uniqueChainIds = Array.from(new Set(chainIds)).sort(); - const draws: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []); - for (let i = 0; i < x.length; i++) { + const drawsByChain: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []); + for (let i = 0; i < draws.length; i++) { const chainId = chainIds[i]; const chainIndex = uniqueChainIds.indexOf(chainId); - draws[chainIndex].push(x[i]); + drawsByChain[chainIndex].push(draws[i]); } + return drawsByChain; +} + +const computeEss = (x: number[], chainIds: number[]) => { + const draws = drawsByChain(x, chainIds); const ess = compute_effective_sample_size(draws); return ess; } const computeRhat = (x: number[], chainIds: number[]) => { - const uniqueChainIds = Array.from(new Set(chainIds)).sort(); - const draws: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []); - for (let i = 0; i < x.length; i++) { - const chainId = chainIds[i]; - const chainIndex = uniqueChainIds.indexOf(chainId); - draws[chainIndex].push(x[i]); - } + const draws = drawsByChain(x, chainIds); const rhat = compute_split_potential_scale_reduction(draws); return rhat; }