Skip to content

Commit

Permalink
implement Rhat and rearrange stan stats
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Jun 14, 2024
1 parent 65c7969 commit 7e319ff
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 460 deletions.
80 changes: 23 additions & 57 deletions gui/src/app/SamplerOutputView/SummaryView.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import { FunctionComponent, useMemo } from "react"
import { ess } from "./advanced/ess"
import { computeMean, computePercentile, computeStdDev } from "./util"
import rhat from "./advanced/rhat"
import compute_effective_sample_size from "./ess_computation_from_stan/compute_effective_sample_size"
import { compute_effective_sample_size, compute_split_potential_scale_reduction } from "./stan_stats/stan_stats"

type SummaryViewProps = {
width: number
Expand Down Expand Up @@ -45,14 +43,9 @@ const columns = [
title: '95th percentile of the parameter'
},
{
key: 'nEff1',
label: 'N_Eff1',
title: 'Effective sample size: A crude measure of the effective sample size (ported from bayes_kit untested)'
},
{
key: 'nEff2',
label: 'N_Eff2',
title: 'Effective sample size: A crude measure of the effective sample size (ported from stan C++ untested)'
key: 'nEff',
label: 'N_Eff',
title: 'Effective sample size: A crude measure of the effective sample size'
},
{
key: 'nEff/s',
Expand All @@ -72,22 +65,20 @@ type TableRow = {
}

const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws, paramNames, drawChainIds, computeTimeSec }) => {
const uniqueChainIds = useMemo(() => (Array.from(new Set(drawChainIds)).sort()), [drawChainIds]);

const rows = useMemo(() => {
const rows: TableRow[] = [];
for (const pname of paramNames) {
const pDraws = draws[paramNames.indexOf(pname)];
const pDrawsSorted = [...pDraws].sort((a, b) => a - b);
const ess1 = computeEss1(pDraws, drawChainIds);
const ess2 = computeEss2(pDraws, drawChainIds);
const ess = computeEss(pDraws, drawChainIds);
const rhat = computeRhat(pDraws, drawChainIds);
const stdDev = computeStdDev(pDraws);
const values = columns.map((column) => {
if (column.key === 'mean') {
return computeMean(pDraws);
}
else if (column.key === 'mcse') {
return stdDev / Math.sqrt(ess1);
return stdDev / Math.sqrt(ess);
}
else if (column.key === 'stdDev') {
return stdDev;
Expand All @@ -101,20 +92,14 @@ const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws
else if (column.key === '95%') {
return computePercentile(pDrawsSorted, 0.95);
}
else if (column.key === 'nEff1') {
return ess1;
}
else if (column.key === 'nEff2') {
return ess2;
else if (column.key === 'nEff') {
return ess;
}
else if (column.key === 'nEff/s') {
return computeTimeSec ? ess1 / computeTimeSec : NaN;
return computeTimeSec ? ess / computeTimeSec : NaN;
}
else if (column.key === 'rHat') {
const counts = computeChainCounts(drawChainIds, uniqueChainIds);
const means = computeChainMeans(pDraws, drawChainIds, uniqueChainIds);
const stdevs = computeChainStdDevs(pDraws, drawChainIds, uniqueChainIds);
return rhat({ counts, means, stdevs });
return rhat;
}
else {
return NaN;
Expand All @@ -126,7 +111,7 @@ const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws
})
}
return rows;
}, [paramNames, draws, drawChainIds, uniqueChainIds, computeTimeSec]);
}, [draws, paramNames, drawChainIds, computeTimeSec]);

return (
<div style={{position: 'absolute', width, height, overflowY: 'auto'}}>
Expand Down Expand Up @@ -167,18 +152,7 @@ const SummaryView: FunctionComponent<SummaryViewProps> = ({ width, height, draws
)
}

const computeEss1 = (x: number[], chainIds: number[]) => {
const uniqueChainIds = Array.from(new Set(chainIds)).sort();
let sumEss = 0;
for (const chainId of uniqueChainIds) {
const chainX = x.filter((_, i) => chainIds[i] === chainId);
const {essValue} = ess(chainX);
sumEss += essValue;
}
return sumEss;
}

const computeEss2 = (x: number[], chainIds: number[]) => {
const computeEss = (x: number[], chainIds: number[]) => {
const uniqueChainIds = Array.from(new Set(chainIds)).sort();
const draws: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []);
for (let i = 0; i < x.length; i++) {
Expand All @@ -190,24 +164,16 @@ const computeEss2 = (x: number[], chainIds: number[]) => {
return ess;
}

const computeChainCounts = (chainIds: number[], uniqueChainIds: number[]) => {
return uniqueChainIds.map((chainId) => {
return chainIds.filter((id) => id === chainId).length;
});
}

const computeChainMeans = (x: number[], chainIds: number[], uniqueChainIds: number[]) => {
return uniqueChainIds.map((chainId) => {
const chainX = x.filter((_, i) => chainIds[i] === chainId);
return computeMean(chainX);
});
}

const computeChainStdDevs = (x: number[], chainIds: number[], uniqueChainIds: number[]) => {
return uniqueChainIds.map((chainId) => {
const chainX = x.filter((_, i) => chainIds[i] === chainId);
return computeStdDev(chainX);
});
const computeRhat = (x: number[], chainIds: number[]) => {
const uniqueChainIds = Array.from(new Set(chainIds)).sort();
const draws: number[][] = new Array(uniqueChainIds.length).fill(0).map(() => []);
for (let i = 0; i < x.length; i++) {
const chainId = chainIds[i];
const chainIndex = uniqueChainIds.indexOf(chainId);
draws[chainIndex].push(x[i]);
}
const rhat = compute_split_potential_scale_reduction(draws);
return rhat;
}

// Example of Stan output...
Expand Down
142 changes: 0 additions & 142 deletions gui/src/app/SamplerOutputView/advanced/ess.ts

This file was deleted.

32 changes: 0 additions & 32 deletions gui/src/app/SamplerOutputView/advanced/rhat.ts

This file was deleted.

Loading

0 comments on commit 7e319ff

Please sign in to comment.