diff --git a/gui/src/app/RunPanel/RunPanel.tsx b/gui/src/app/RunPanel/RunPanel.tsx index 686c082c..8427dafa 100644 --- a/gui/src/app/RunPanel/RunPanel.tsx +++ b/gui/src/app/RunPanel/RunPanel.tsx @@ -4,7 +4,7 @@ import LinearProgress, { LinearProgressProps } from '@mui/material/LinearProgres import Typography from '@mui/material/Typography'; import { FunctionComponent, useCallback } from 'react'; -import StanSampler from '../StanSampler/StanSampler'; +import StanSampler, { SamplingOpts } from '../StanSampler/StanSampler'; import { useSamplerProgress, useSamplerStatus } from '../StanSampler/useStanSampler'; import { Progress } from '../tinystan/Worker'; @@ -14,18 +14,17 @@ type RunPanelProps = { sampler?: StanSampler; data: any | undefined dataIsSaved: boolean + samplingOpts: SamplingOpts }; -const numChains = 4; - -const RunPanel: FunctionComponent = ({ width, height, sampler, data, dataIsSaved }) => { +const RunPanel: FunctionComponent = ({ width, height, sampler, data, dataIsSaved, samplingOpts }) => { const {status: runStatus, errorMessage} = useSamplerStatus(sampler) const progress = useSamplerProgress(sampler) const handleRun = useCallback(async () => { if (!sampler) return; - sampler.sample({ data, num_chains: numChains}) - }, [sampler, data]); + sampler.sample(data, samplingOpts) + }, [sampler, data, samplingOpts]); const cancelRun = useCallback(() => { if (!sampler) return; @@ -49,8 +48,9 @@ const RunPanel: FunctionComponent = ({ width, height, sampler, da
- - +   + {runStatus === 'sampling' && } +
{ runStatus === 'loading' && (
@@ -62,7 +62,10 @@ const RunPanel: FunctionComponent = ({ width, height, sampler, da runStatus === 'sampling' && (
Sampling - +
) } @@ -88,14 +91,15 @@ const RunPanel: FunctionComponent = ({ width, height, sampler, da type SamplingProgressComponentProps = { report: Progress | undefined + numChains: number } -const SamplingProgressComponent: FunctionComponent = ({ report }) => { +const SamplingProgressComponent: FunctionComponent = ({ report, numChains }) => { if (!report) return const progress = (report.iteration + ((report.chain - 1) * report.totalIterations)) / (report.totalIterations * numChains) * 100; return ( <> -
+
void; +} + +// From Brian +// The following would be nice to have control over: + +// [ ]Number of chains (default 4, in practice I think this being a slider from 1 to 8 is probably reasonable, but in theory it can be unbounded) +// Number of warmup iterations (default 1000, can be [0, +// )) +// Number of sampling iterations (default 1000, can be [1, +// )) +// Initializations. This is tricky, can either be 1 json object or a list of num_chains JSON objects. +// "Initialization radius" (parameters not given an initial value directly are drawn from uniform(-R, R) on the unconstrained scale) (default 2.0, can be [0, +// ), but in practice limiting to say 10 in the UI is probably fine) +// Seed. Any uint32. + + +const sp1 = 0.5; +const sp2 = 1; + +const SamplingOptsPanel: FunctionComponent = ({ samplingOpts, setSamplingOpts }) => { + const num_chains = samplingOpts.num_chains; + const readOnly = setSamplingOpts === undefined; + const handleReset = useCallback(() => { + setSamplingOpts && setSamplingOpts(defaultSamplingOpts) + }, [setSamplingOpts]) + return ( +
+ + + + # chains + + + setSamplingOpts && setSamplingOpts({ ...samplingOpts, num_chains: value as number })} + min={1} + max={8} + readOnly={readOnly} + type="int" + /> + + + + + # warmup + + + setSamplingOpts && setSamplingOpts({ ...samplingOpts, num_warmup: value as number })} + min={0} + readOnly={readOnly} + type="int" + /> + + + + + # samples + + + setSamplingOpts && setSamplingOpts({ ...samplingOpts, num_samples: value as number })} + min={1} + readOnly={readOnly} + type="int" + /> + + + + + init radius + + + setSamplingOpts && setSamplingOpts({ ...samplingOpts, init_radius: value as number })} + min={0} + readOnly={readOnly} + type="float" + /> + + + + + seed + + + setSamplingOpts && setSamplingOpts({ ...samplingOpts, seed: value })} + min={0} + readOnly={readOnly} + type="int" + /> + + + +
+
+ reset +
+
+ ) +} + +type NumberEditProps = { + value: number | undefined; + onChange: (value: number | undefined) => void; + min: number; + max?: number; + readOnly: boolean; + type: 'int' | 'float' | 'intOrUndefined'; +} + +const NumberEdit: FunctionComponent = ({ value, onChange, min, max, readOnly, type }) => { + const handleChange = (e: React.ChangeEvent) => { + let newValue: number | undefined; + switch (type) { + case 'int': + newValue = parseInt(e.target.value); + break; + case 'float': + newValue = parseFloat(e.target.value); + break; + case 'intOrUndefined': + newValue = parseInt(e.target.value); + if (isNaN(newValue)) { + newValue = undefined; + } + break; + } + onChange(newValue); + }; + + return ( + + ) +} + +export default SamplingOptsPanel; \ No newline at end of file diff --git a/gui/src/app/StanSampler/StanSampler.ts b/gui/src/app/StanSampler/StanSampler.ts index e679c8cb..8f59858e 100644 --- a/gui/src/app/StanSampler/StanSampler.ts +++ b/gui/src/app/StanSampler/StanSampler.ts @@ -4,6 +4,22 @@ import StanWorker from '../tinystan/Worker?worker'; export type StanSamplerStatus = '' | 'loading' | 'loaded' | 'sampling' | 'completed' | 'failed'; +export type SamplingOpts = { + num_chains: number + num_warmup: number + num_samples: number + init_radius: number + seed: number | undefined +} + +export const defaultSamplingOpts: SamplingOpts = { + num_chains: 4, + num_warmup: 1000, + num_samples: 1000, + init_radius: 2.0, + seed: undefined +} + class StanSampler { #worker: Worker | undefined; #status: StanSamplerStatus = ''; @@ -58,7 +74,17 @@ class StanSampler { } this.#worker.postMessage({ purpose: Requests.Load, url: this.compiledUrl }); } - sample(sampleConfig: Partial) { + sample(data: any, samplingOpts: SamplingOpts) { + const refresh = calculateReasonableRefreshRate(samplingOpts); + const sampleConfig: Partial = { + data, + num_chains: samplingOpts.num_chains, + num_warmup: samplingOpts.num_warmup, + num_samples: samplingOpts.num_samples, + init_radius: samplingOpts.init_radius, + seed: samplingOpts.seed !== undefined ? samplingOpts.seed : null, + refresh + } if (!this.#worker) return if (this.#status === '') { console.warn('Model not loaded yet') @@ -109,4 +135,14 @@ class StanSampler { } } +const calculateReasonableRefreshRate = (samplingOpts: SamplingOpts) => { + const totalSamples = (samplingOpts.num_samples + samplingOpts.num_warmup) * samplingOpts.num_chains; + + const onePercent = Math.floor(totalSamples / 100); + + const nearestMultipleOfTen = Math.round(onePercent / 10) * 10; + + return Math.max(10, nearestMultipleOfTen); +} + export default StanSampler diff --git a/gui/src/app/pages/HomePage/HomePage.tsx b/gui/src/app/pages/HomePage/HomePage.tsx index 6bcaeb2e..b2f2abb7 100644 --- a/gui/src/app/pages/HomePage/HomePage.tsx +++ b/gui/src/app/pages/HomePage/HomePage.tsx @@ -5,8 +5,10 @@ import DataFileEditor from "../../FileEditor/DataFileEditor"; import StanFileEditor from "../../FileEditor/StanFileEditor"; import RunPanel from "../../RunPanel/RunPanel"; import SamplerOutputView from "../../SamplerOutputView/SamplerOutputView"; -import useStanSampler from "../../StanSampler/useStanSampler"; +import useStanSampler, { useSamplerStatus } from "../../StanSampler/useStanSampler"; import examplesStanies, { Stanie, StanieMetaData } from "../../exampleStanies/exampleStanies"; +import SamplingOptsPanel from "../../SamplingOptsPanel/SamplingOptsPanel"; +import { SamplingOpts, defaultSamplingOpts } from "../../StanSampler/StanSampler"; type Props = { width: number @@ -16,6 +18,7 @@ type Props = { const defaultStanContent = '' const defaultDataContent = '' const defaultMetaContent = '{"title": "Untitled"}' +const defaultSamplingOptsContent = JSON.stringify(defaultSamplingOpts) const initialFileContent = localStorage.getItem('main.stan') || defaultStanContent @@ -23,7 +26,7 @@ const initialDataFileContent = localStorage.getItem('data.json') || defaultDataC const initialMetaContent = localStorage.getItem('meta.json') || defaultMetaContent - +const initialSamplingOptsContent = localStorage.getItem('samplingOpts.json') || defaultSamplingOptsContent const HomePage: FunctionComponent = ({ width, height }) => { const [fileContent, saveFileContent] = useState(initialFileContent) @@ -44,6 +47,18 @@ const HomePage: FunctionComponent = ({ width, height }) => { localStorage.setItem('data.json', dataFileContent) }, [dataFileContent]) + const [samplingOptsContent, setSamplingOptsContent] = useState(initialSamplingOptsContent) + useEffect(() => { + localStorage.setItem('samplingOpts.json', samplingOptsContent) + }, [samplingOptsContent]) + const samplingOpts = useMemo(() => ( + {...defaultSamplingOpts, ...JSON.parse(samplingOptsContent)} + ), [samplingOptsContent]) + + const setSamplingOpts = useCallback((opts: SamplingOpts) => { + setSamplingOptsContent(JSON.stringify(opts, null, 2)) + }, [setSamplingOptsContent]) + const [metaContent, setMetaContent] = useState(initialMetaContent) useEffect(() => { localStorage.setItem('meta.json', metaContent) @@ -130,6 +145,8 @@ const HomePage: FunctionComponent = ({ width, height }) => { editedDataFileContent={editedDataFileContent} setEditedDataFileContent={setEditedDataFileContent} compiledMainJsUrl={compiledMainJsUrl} + samplingOpts={samplingOpts} + setSamplingOpts={setSamplingOpts} />
@@ -145,9 +162,11 @@ type RightViewProps = { editedDataFileContent: string setEditedDataFileContent: (text: string) => void compiledMainJsUrl?: string + samplingOpts: SamplingOpts + setSamplingOpts: (opts: SamplingOpts) => void } -const RightView: FunctionComponent = ({ width, height, dataFileContent, saveDataFileContent, editedDataFileContent, setEditedDataFileContent, compiledMainJsUrl }) => { +const RightView: FunctionComponent = ({ width, height, dataFileContent, saveDataFileContent, editedDataFileContent, setEditedDataFileContent, compiledMainJsUrl, samplingOpts, setSamplingOpts }) => { return ( = ({ width, height, dataFileC compiledMainJsUrl={compiledMainJsUrl} dataFileContent={dataFileContent} dataIsSaved={dataFileContent === editedDataFileContent} + samplingOpts={samplingOpts} + setSamplingOpts={setSamplingOpts} /> ) @@ -182,9 +203,11 @@ type LowerRightViewProps = { compiledMainJsUrl?: string dataFileContent: string dataIsSaved: boolean + samplingOpts: SamplingOpts + setSamplingOpts: (opts: SamplingOpts) => void } -const LowerRightView: FunctionComponent = ({ width, height, compiledMainJsUrl, dataFileContent, dataIsSaved }) => { +const LowerRightView: FunctionComponent = ({ width, height, compiledMainJsUrl, dataFileContent, dataIsSaved, samplingOpts, setSamplingOpts }) => { const parsedData = useMemo(() => { try { return JSON.parse(dataFileContent) @@ -193,25 +216,34 @@ const LowerRightView: FunctionComponent = ({ width, height, return undefined } }, [dataFileContent]) - const runPanelHeight = 80 + const samplingOptsPanelHeight = 160 + const samplingOptsPanelWidth = Math.min(180, width / 2) const {sampler} = useStanSampler(compiledMainJsUrl) - + const {status: samplerStatus} = useSamplerStatus(sampler) + const isSampling = samplerStatus === 'sampling' return (
-
+
+ +
+
-
+
{sampler && } diff --git a/gui/src/app/tinystan/Worker.ts b/gui/src/app/tinystan/Worker.ts index c87706eb..a4549e2f 100644 --- a/gui/src/app/tinystan/Worker.ts +++ b/gui/src/app/tinystan/Worker.ts @@ -23,6 +23,12 @@ const parseProgress = (msg: string): Progress => { // Examples (note different spacing): // Chain [1] Iteration: 2000 / 2000 [100%] (Sampling) // Chain [2] Iteration: 800 / 2000 [ 40%] (Warmup) + + // But if there is only one chain, then + // the "Chain [x]" part is omitted. + if (msg.startsWith('Iteration:')) { + msg = 'Chain [1] ' + msg; + } const parts = msg.split(/\s+/); const chain = parseInt(parts[1].slice(1, -1)); const iteration = parseInt(parts[3]); @@ -33,7 +39,7 @@ const parseProgress = (msg: string): Progress => { } const progressPrintCallback = (msg: string) => { - if (!msg.startsWith('Chain')) { + if ((!msg.startsWith('Chain')) && (!msg.startsWith('Iteration:'))) { console.log(msg); return; }