Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

develop sampling output view #46

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions gui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
"@mui/icons-material": "^5.15.17",
"@mui/material": "^5.15.17",
"monaco-editor": "^0.48.0",
"plotly.js": "^2.33.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^8",
"react-plotly.js": "^2.6.0",
"react-router-dom": "^6.17.0",
"react-syntax-highlighter": "^15.5.0",
"react-visibility-sensor": "^5.1.1",
"rehype-mathjax": "^6.0.0",
"rehype-raw": "^6.1.1",
"remark-gfm": "^4.0.0",
Expand All @@ -33,6 +36,7 @@
"@types/node": "^20.12.11",
"@types/react": "^18.2.15",
"@types/react-dom": "^18.2.7",
"@types/react-plotly.js": "^2.6.3",
"@types/react-syntax-highlighter": "^15.5.13",
"@typescript-eslint/eslint-plugin": "^6.0.0",
"@typescript-eslint/parser": "^6.0.0",
Expand Down
59 changes: 59 additions & 0 deletions gui/src/app/SamplerOutputView/HistsView.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { FunctionComponent, useMemo } from "react";
import SequenceHistogramWidget from "./SequenceHistogramWidget";

type HistsViewProps = {
width: number,
height: number,
draws: number[][],
paramNames: string[]
drawChainIds: number[]
}

const HistsView: FunctionComponent<HistsViewProps> = ({ width, height, draws, paramNames, drawChainIds }) => {
const paramNamesResorted = useMemo(() => {
// put the names that don't end with __ first
const names = paramNames.filter((name) => !name.endsWith('__'));
const namesWithSuffix = paramNames.filter((name) => name.endsWith('__'));
return [...names, ...namesWithSuffix];
}, [paramNames]);
return (
<div style={{position: 'absolute', width, height, overflowY: 'auto', display: 'flex', flexWrap: 'wrap'}}>
{
paramNamesResorted.map((paramName) => (
<SequenceHist
key={paramName}
width={300}
height={300}
variableName={paramName}
columnIndex={paramNames.indexOf(paramName)}
draws={draws}
drawChainIds={drawChainIds}
/>
))
}
</div>
)
}

type SequenceHistProps = {
width: number,
height: number,
variableName: string,
draws: number[][]
columnIndex: number
drawChainIds: number[]
}

const SequenceHist: FunctionComponent<SequenceHistProps> = ({ width, height, variableName, draws, columnIndex }) => {
return (
<SequenceHistogramWidget
histData={draws[columnIndex]}
title={variableName}
variableName={variableName}
width={width}
height={height}
/>
)
}

export default HistsView;
116 changes: 61 additions & 55 deletions gui/src/app/SamplerOutputView/SamplerOutputView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import { FunctionComponent, useCallback, useMemo, useState } from "react"
import StanSampler from "../StanSampler/StanSampler"
import { useSamplerOutput } from "../StanSampler/useStanSampler"
import TabWidget from "../TabWidget/TabWidget"
import TracePlotsView from "./TracePlotsView"
import SummaryView from "./SummaryView"
import HistsView from "./HistsView"

type SamplerOutputViewProps = {
width: number
Expand All @@ -12,9 +15,9 @@ type SamplerOutputViewProps = {
}

const SamplerOutputView: FunctionComponent<SamplerOutputViewProps> = ({width, height, sampler}) => {
const {draws, paramNames} = useSamplerOutput(sampler)
const {draws, paramNames, numChains, computeTimeSec} = useSamplerOutput(sampler)

if (!draws || !paramNames) return (
if (!draws || !paramNames || !numChains) return (
<span />
)
return (
Expand All @@ -23,6 +26,8 @@ const SamplerOutputView: FunctionComponent<SamplerOutputViewProps> = ({width, he
height={height}
draws={draws}
paramNames={paramNames}
numChains={numChains}
computeTimeSec={computeTimeSec}
/>
)
}
Expand All @@ -31,7 +36,9 @@ type DrawsDisplayProps = {
width: number,
height: number,
draws: number[][],
numChains: number,
paramNames: string[]
computeTimeSec: number | undefined
}

const tabs = [
Expand All @@ -46,22 +53,33 @@ const tabs = [
label: 'Draws',
title: 'Draws view',
closeable: false
},
{
id: 'hists',
label: 'Histograms',
title: 'Histograms view',
closeable: false
},
{
id: 'traceplots',
label: 'Trace Plots',
title: 'Trace Plots view',
closeable: false
}
]

const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, draws, paramNames }) => {
const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, draws, paramNames, numChains, computeTimeSec }) => {

const [currentTabId, setCurrentTabId] = useState('summary');

const means: { [k: string]: number } = {};
const drawChainIds = useMemo(() => {
return [...new Array(draws[0].length).keys()].map(i => 1 + Math.floor(i / draws[0].length * numChains));
}, [draws, numChains]);

for (const [i, element] of paramNames.entries()) {
let sum = 0;
for (const draw of draws[i]) {
sum += draw;
}
means[element] = sum / draws[i].length;
}
const drawNumbers: number[] = useMemo(() => {
const numDrawsPerChain = Math.floor(draws[0].length / numChains);
return [...new Array(draws[0].length).keys()].map(i => 1 + (i % numDrawsPerChain));
}, [draws, numChains]);

return (
<TabWidget
Expand All @@ -72,74 +90,58 @@ const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, dra
setCurrentTabId={setCurrentTabId}
>
<SummaryView
width={0}
height={0}
draws={draws}
paramNames={paramNames}
drawChainIds={drawChainIds}
computeTimeSec={computeTimeSec}
/>
<DrawsView
width={0}
height={0}
draws={draws}
paramNames={paramNames}
drawChainIds={drawChainIds}
drawNumbers={drawNumbers}
/>
<HistsView
width={0}
height={0}
draws={draws}
paramNames={paramNames}
drawChainIds={drawChainIds}
/>
<TracePlotsView
width={0}
height={0}
draws={draws}
paramNames={paramNames}
drawChainIds={drawChainIds}
/>
</TabWidget>
)
}

type SummaryViewProps = {
draws: number[][],
paramNames: string[]
}

const SummaryView: FunctionComponent<SummaryViewProps> = ({ draws, paramNames }) => {
const {means} = useMemo(() => {
const means: { [k: string]: number } = {};
for (const [i, element] of paramNames.entries()) {
let sum = 0;
for (const draw of draws[i]) {
sum += draw;
}
means[element] = sum / draws[i].length;
}
return {means};
}, [draws, paramNames]);

return (
<table className="scientific-table">
<thead>
<tr>
<th>Parameter</th>
<th>Mean</th>
</tr>
</thead>
<tbody>
{Object.entries(means).map(([name, mean]) => (
<tr key={name}>
<td>{name}</td>
<td>{mean}</td>
</tr>
))}
</tbody>
</table>
)
}

type DrawsViewProps = {
width: number
height: number
draws: number[][],
paramNames: string[]
drawChainIds: number[]
drawNumbers: number[]
}

const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, paramNames }) => {
const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, paramNames, drawChainIds, drawNumbers }) => {
const [abbreviatedToNumRows, setAbbreviatedToNumRows] = useState<number | undefined>(300);
const draws2 = useMemo(() => {
if (abbreviatedToNumRows === undefined) return draws;
return draws.map(draw => draw.slice(0, abbreviatedToNumRows));
}, [draws, abbreviatedToNumRows]);
const handleExportToCsv = useCallback(() => {
const csvText = prepareCsvText(draws, paramNames);
const csvText = prepareCsvText(draws, paramNames, drawChainIds, drawNumbers);
downloadTextFile(csvText, 'draws.csv');
}, [draws, paramNames]);
}, [draws, paramNames, drawChainIds, drawNumbers]);
return (
<div style={{position: 'absolute', width, height, overflow: 'auto'}}>
<SmallIconButton
Expand All @@ -150,6 +152,8 @@ const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, pa
<table className="draws-table">
<thead>
<tr>
<th key="chain">Chain</th>
<th key="draw">Draw</th>
{
paramNames.map((name, i) => (
<th key={i}>{name}</th>
Expand All @@ -161,6 +165,8 @@ const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, pa
{
draws2[0].map((_, i) => (
<tr key={i}>
<td>{drawChainIds[i]}</td>
<td>{drawNumbers[i]}</td>
{
draws.map((draw, j) => (
<td key={j}>{draw[i]}</td>
Expand All @@ -184,11 +190,11 @@ const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, pa
)
}

const prepareCsvText = (draws: number[][], paramNames: string[]) => {
const prepareCsvText = (draws: number[][], paramNames: string[], drawChainIds: number[], drawNumbers: number[]) => {
const lines = draws[0].map((_, i) => {
return paramNames.map((_, j) => draws[j][i]).join(',')
return [`${drawChainIds[i]}`, `${drawNumbers[i]}`, ...paramNames.map((_, j) => draws[j][i])].join(',')
})
return [paramNames.join(','), ...lines].join('\n')
return [['Chain', 'Draw', ...paramNames].join(','), ...lines].join('\n')
}

const downloadTextFile = (text: string, filename: string) => {
Expand Down
42 changes: 42 additions & 0 deletions gui/src/app/SamplerOutputView/SequenceHistogramWidget.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import React, { FunctionComponent, Suspense, useMemo } from "react";

type Props = {
histData: number[]
title: string
variableName: string
width: number
height: number
}

const Plot = React.lazy(() => (import('react-plotly.js')))

const SequenceHistogramWidget: FunctionComponent<Props> = ({ histData, title, width, height, variableName }) => {
const data = useMemo(() => (
{
x: histData,
type: 'histogram',
nbinsx: Math.ceil(1.5 * Math.sqrt(histData.length)),
marker: {color: '#505060'},
histnorm: 'probability'
} as any // had to do it this way because ts was not recognizing nbinsx
), [histData])
return (
<div style={{ position: 'relative', width, height }}>
<Suspense fallback={<div>Loading plotly</div>}>
<Plot
data={[data]}
layout={{
width: width,
height,
title: {text: title, font: {size: 12}},
xaxis: {title: variableName},
yaxis: {title: 'Count'},
margin: {r: 0}
}}
/>
</Suspense>
</div>
)
}

export default SequenceHistogramWidget
Loading