Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

frontend: sampling output view (take 3) #47

Merged
merged 7 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions gui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
"@mui/icons-material": "^5.15.17",
"@mui/material": "^5.15.17",
"monaco-editor": "^0.48.0",
"plotly.js": "^2.33.0",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[note] it may be worth using an alternative plotly distribution for size if we only need certain kinds of plots. For stan-web-demo this made a pretty significant difference in overall bundle size.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out, I didn't know it was available.

Not an easy change right now since I am using react-plotly (npm package) which seems to only work with the full plotly.js. But I'm sure there's a workaround.

I'll note that since we are doing lazy loading, the plotly.js doesn't contribute to the main bundle on initial page load.

Shall we defer for a separate issue when we can figure out how to work around limitations of react-plotly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we can open a ticket to try it later. I know react-plotly supports it, but I’m not sure how this interacts with the lazy loading, and as you say the lazy loading helps make it less necessary

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For information, I just tried using a smaller distribution and ran into some bundling problems, so yeah I think we should do this in a different ticket.

"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^8",
"react-plotly.js": "^2.6.0",
"react-router-dom": "^6.17.0",
"react-syntax-highlighter": "^15.5.0",
"react-visibility-sensor": "^5.1.1",
"rehype-mathjax": "^6.0.0",
"rehype-raw": "^6.1.1",
"remark-gfm": "^4.0.0",
Expand All @@ -35,6 +38,7 @@
"@types/node": "^20.12.11",
"@types/react": "^18.2.15",
"@types/react-dom": "^18.2.7",
"@types/react-plotly.js": "^2.6.3",
"@types/react-syntax-highlighter": "^15.5.13",
"@typescript-eslint/eslint-plugin": "^6.0.0",
"@typescript-eslint/parser": "^6.0.0",
Expand Down
59 changes: 59 additions & 0 deletions gui/src/app/SamplerOutputView/HistsView.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { FunctionComponent, useMemo } from "react";
import SequenceHistogramWidget from "./SequenceHistogramWidget";

type HistsViewProps = {
width: number,
height: number,
draws: number[][],
paramNames: string[]
drawChainIds: number[]
}

const HistsView: FunctionComponent<HistsViewProps> = ({ width, height, draws, paramNames, drawChainIds }) => {
const paramNamesResorted = useMemo(() => {
// put the names that don't end with __ first
const names = paramNames.filter((name) => !name.endsWith('__'));
const namesWithSuffix = paramNames.filter((name) => name.endsWith('__'));
return [...names, ...namesWithSuffix];
}, [paramNames]);
return (
<div style={{position: 'absolute', width, height, overflowY: 'auto', display: 'flex', flexWrap: 'wrap'}}>
{
paramNamesResorted.map((paramName) => (
<SequenceHist
key={paramName}
width={300}
height={300}
variableName={paramName}
columnIndex={paramNames.indexOf(paramName)}
draws={draws}
drawChainIds={drawChainIds}
/>
))
}
</div>
)
}

type SequenceHistProps = {
width: number,
height: number,
variableName: string,
draws: number[][]
columnIndex: number
drawChainIds: number[]
}

const SequenceHist: FunctionComponent<SequenceHistProps> = ({ width, height, variableName, draws, columnIndex }) => {
return (
<SequenceHistogramWidget
histData={draws[columnIndex]}
title={variableName}
variableName={variableName}
width={width}
height={height}
/>
)
}

export default HistsView;
116 changes: 61 additions & 55 deletions gui/src/app/SamplerOutputView/SamplerOutputView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import { FunctionComponent, useCallback, useMemo, useState } from "react"
import StanSampler from "../StanSampler/StanSampler"
import { useSamplerOutput } from "../StanSampler/useStanSampler"
import TabWidget from "../TabWidget/TabWidget"
import TracePlotsView from "./TracePlotsView"
import SummaryView from "./SummaryView"
import HistsView from "./HistsView"

type SamplerOutputViewProps = {
width: number
Expand All @@ -12,9 +15,9 @@ type SamplerOutputViewProps = {
}

const SamplerOutputView: FunctionComponent<SamplerOutputViewProps> = ({width, height, sampler}) => {
const {draws, paramNames} = useSamplerOutput(sampler)
const {draws, paramNames, numChains, computeTimeSec} = useSamplerOutput(sampler)

if (!draws || !paramNames) return (
if (!draws || !paramNames || !numChains) return (
<span />
)
return (
Expand All @@ -23,6 +26,8 @@ const SamplerOutputView: FunctionComponent<SamplerOutputViewProps> = ({width, he
height={height}
draws={draws}
paramNames={paramNames}
numChains={numChains}
computeTimeSec={computeTimeSec}
/>
)
}
Expand All @@ -31,7 +36,9 @@ type DrawsDisplayProps = {
width: number,
height: number,
draws: number[][],
numChains: number,
paramNames: string[]
computeTimeSec: number | undefined
}

const tabs = [
Expand All @@ -46,22 +53,33 @@ const tabs = [
label: 'Draws',
title: 'Draws view',
closeable: false
},
{
id: 'hists',
label: 'Histograms',
title: 'Histograms view',
closeable: false
},
{
id: 'traceplots',
label: 'Trace Plots',
title: 'Trace Plots view',
closeable: false
}
]

const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, draws, paramNames }) => {
const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, draws, paramNames, numChains, computeTimeSec }) => {

const [currentTabId, setCurrentTabId] = useState('summary');

const means: { [k: string]: number } = {};
const drawChainIds = useMemo(() => {
return [...new Array(draws[0].length).keys()].map(i => 1 + Math.floor(i / draws[0].length * numChains));
}, [draws, numChains]);

for (const [i, element] of paramNames.entries()) {
let sum = 0;
for (const draw of draws[i]) {
sum += draw;
}
means[element] = sum / draws[i].length;
}
const drawNumbers: number[] = useMemo(() => {
const numDrawsPerChain = Math.floor(draws[0].length / numChains);
return [...new Array(draws[0].length).keys()].map(i => 1 + (i % numDrawsPerChain));
}, [draws, numChains]);

return (
<TabWidget
Expand All @@ -72,74 +90,58 @@ const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, dra
setCurrentTabId={setCurrentTabId}
>
<SummaryView
width={0}
height={0}
draws={draws}
paramNames={paramNames}
drawChainIds={drawChainIds}
computeTimeSec={computeTimeSec}
/>
<DrawsView
width={0}
height={0}
draws={draws}
paramNames={paramNames}
drawChainIds={drawChainIds}
drawNumbers={drawNumbers}
/>
<HistsView
width={0}
height={0}
draws={draws}
paramNames={paramNames}
drawChainIds={drawChainIds}
/>
<TracePlotsView
width={0}
height={0}
draws={draws}
paramNames={paramNames}
drawChainIds={drawChainIds}
/>
</TabWidget>
)
}

type SummaryViewProps = {
draws: number[][],
paramNames: string[]
}

const SummaryView: FunctionComponent<SummaryViewProps> = ({ draws, paramNames }) => {
const {means} = useMemo(() => {
const means: { [k: string]: number } = {};
for (const [i, element] of paramNames.entries()) {
let sum = 0;
for (const draw of draws[i]) {
sum += draw;
}
means[element] = sum / draws[i].length;
}
return {means};
}, [draws, paramNames]);

return (
<table className="scientific-table">
<thead>
<tr>
<th>Parameter</th>
<th>Mean</th>
</tr>
</thead>
<tbody>
{Object.entries(means).map(([name, mean]) => (
<tr key={name}>
<td>{name}</td>
<td>{mean}</td>
</tr>
))}
</tbody>
</table>
)
}

type DrawsViewProps = {
width: number
height: number
draws: number[][],
paramNames: string[]
drawChainIds: number[]
drawNumbers: number[]
}

const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, paramNames }) => {
const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, paramNames, drawChainIds, drawNumbers }) => {
const [abbreviatedToNumRows, setAbbreviatedToNumRows] = useState<number | undefined>(300);
const draws2 = useMemo(() => {
if (abbreviatedToNumRows === undefined) return draws;
return draws.map(draw => draw.slice(0, abbreviatedToNumRows));
}, [draws, abbreviatedToNumRows]);
const handleExportToCsv = useCallback(() => {
const csvText = prepareCsvText(draws, paramNames);
const csvText = prepareCsvText(draws, paramNames, drawChainIds, drawNumbers);
downloadTextFile(csvText, 'draws.csv');
}, [draws, paramNames]);
}, [draws, paramNames, drawChainIds, drawNumbers]);
return (
<div style={{position: 'absolute', width, height, overflow: 'auto'}}>
<SmallIconButton
Expand All @@ -150,6 +152,8 @@ const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, pa
<table className="draws-table">
<thead>
<tr>
<th key="chain">Chain</th>
<th key="draw">Draw</th>
{
paramNames.map((name, i) => (
<th key={i}>{name}</th>
Expand All @@ -161,6 +165,8 @@ const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, pa
{
draws2[0].map((_, i) => (
<tr key={i}>
<td>{drawChainIds[i]}</td>
<td>{drawNumbers[i]}</td>
{
draws.map((draw, j) => (
<td key={j}>{draw[i]}</td>
Expand All @@ -184,11 +190,11 @@ const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, pa
)
}

const prepareCsvText = (draws: number[][], paramNames: string[]) => {
const prepareCsvText = (draws: number[][], paramNames: string[], drawChainIds: number[], drawNumbers: number[]) => {
const lines = draws[0].map((_, i) => {
return paramNames.map((_, j) => draws[j][i]).join(',')
return [`${drawChainIds[i]}`, `${drawNumbers[i]}`, ...paramNames.map((_, j) => draws[j][i])].join(',')
})
return [paramNames.join(','), ...lines].join('\n')
return [['Chain', 'Draw', ...paramNames].join(','), ...lines].join('\n')
}

const downloadTextFile = (text: string, filename: string) => {
Expand Down
39 changes: 39 additions & 0 deletions gui/src/app/SamplerOutputView/SequenceHistogramWidget.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { FunctionComponent, useMemo } from "react";
import LazyPlotlyPlot from "../components/LazyPlotlyPlot";

type Props = {
histData: number[]
title: string
variableName: string
width: number
height: number
}

const SequenceHistogramWidget: FunctionComponent<Props> = ({ histData, title, width, height, variableName }) => {
const data = useMemo(() => (
{
x: histData,
type: 'histogram',
nbinsx: Math.ceil(1.5 * Math.sqrt(histData.length)),
marker: {color: '#505060'},
histnorm: 'probability'
} as any // had to do it this way because ts was not recognizing nbinsx
), [histData])
return (
<div style={{ position: 'relative', width, height }}>
<LazyPlotlyPlot
data={[data]}
layout={{
width: width,
height,
title: {text: title, font: {size: 12}},
xaxis: {title: variableName},
yaxis: {title: 'Count'},
margin: {r: 0}
}}
/>
</div>
)
}

export default SequenceHistogramWidget
57 changes: 57 additions & 0 deletions gui/src/app/SamplerOutputView/SequencePlotWidget.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { FunctionComponent, useMemo } from "react";
import LazyPlotlyPlot from "../components/LazyPlotlyPlot";

export type PlotSequence = {
label: string
data: number[]
color: string
}

type Props = {
plotSequences: PlotSequence[]
variableName: string
highlightDrawRange?: [number, number]
width: number
height: number
}

const SequencePlotWidget: FunctionComponent<Props> = ({ plotSequences, variableName, highlightDrawRange, width, height }) => {
const shapes = useMemo(() => (
(highlightDrawRange ? (
[{type: 'rect', x0: highlightDrawRange[0], x1: highlightDrawRange[1], y0: 0, y1: 1, yref: 'paper', fillcolor: 'yellow', opacity: 0.1}]
) : []) as any
), [highlightDrawRange])
const data: any[] = useMemo(() => (
plotSequences.map(ps => (
{
x: [...new Array(ps.data.length).keys()].map(i => (i + 1)),
y: ps.data,
type: 'scatter',
mode: 'lines+markers',
marker: {color: ps.color}
}
))
), [plotSequences])
const layout = useMemo(() => ({
width: width,
height,
title: '',
yaxis: {title: variableName},
xaxis: {title: 'draw'},
shapes,
margin: {
t: 30, b: 40, r: 0
},
showlegend: false
}), [width, height, variableName, shapes])
return (
<div style={{ position: 'relative', width, height }}>
<LazyPlotlyPlot
data={data}
layout={layout}
/>
</div>
)
}

export default SequencePlotWidget
Loading