Skip to content

Commit

Permalink
Merge pull request #18 from flatironinstitute/sampling-opts
Browse files Browse the repository at this point in the history
Add sampling options panel and update sampler to use new options
  • Loading branch information
magland authored May 17, 2024
2 parents 70e820c + aa9643d commit 3fbc83c
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 23 deletions.
26 changes: 15 additions & 11 deletions gui/src/app/RunPanel/RunPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import LinearProgress, { LinearProgressProps } from '@mui/material/LinearProgres
import Typography from '@mui/material/Typography';
import { FunctionComponent, useCallback } from 'react';

import StanSampler from '../StanSampler/StanSampler';
import StanSampler, { SamplingOpts } from '../StanSampler/StanSampler';
import { useSamplerProgress, useSamplerStatus } from '../StanSampler/useStanSampler';
import { Progress } from '../tinystan/Worker';

Expand All @@ -14,18 +14,17 @@ type RunPanelProps = {
sampler?: StanSampler;
data: any | undefined
dataIsSaved: boolean
samplingOpts: SamplingOpts
};

const numChains = 4;

const RunPanel: FunctionComponent<RunPanelProps> = ({ width, height, sampler, data, dataIsSaved }) => {
const RunPanel: FunctionComponent<RunPanelProps> = ({ width, height, sampler, data, dataIsSaved, samplingOpts }) => {
const {status: runStatus, errorMessage} = useSamplerStatus(sampler)
const progress = useSamplerProgress(sampler)

const handleRun = useCallback(async () => {
if (!sampler) return;
sampler.sample({ data, num_chains: numChains})
}, [sampler, data]);
sampler.sample(data, samplingOpts)
}, [sampler, data, samplingOpts]);

const cancelRun = useCallback(() => {
if (!sampler) return;
Expand All @@ -49,8 +48,9 @@ const RunPanel: FunctionComponent<RunPanelProps> = ({ width, height, sampler, da
<div style={{ position: 'absolute', width, height, overflowY: 'auto' }}>
<div style={{ padding: 5 }}>
<div>
<button onClick={handleRun} disabled={runStatus === 'sampling' || runStatus === 'loading'}>Run</button>
<button onClick={cancelRun} disabled={runStatus !== 'sampling'}>Cancel</button>
<button onClick={handleRun} disabled={runStatus === 'sampling' || runStatus === 'loading'}>run sampling</button>&nbsp;
{runStatus === 'sampling' && <button onClick={cancelRun} disabled={runStatus !== 'sampling'}>cancel</button>}
<hr />
{
runStatus === 'loading' && (
<div>
Expand All @@ -62,7 +62,10 @@ const RunPanel: FunctionComponent<RunPanelProps> = ({ width, height, sampler, da
runStatus === 'sampling' && (
<div>
Sampling
<SamplingProgressComponent report={progress} />
<SamplingProgressComponent
report={progress}
numChains={samplingOpts.num_chains}
/>
</div>
)
}
Expand All @@ -88,14 +91,15 @@ const RunPanel: FunctionComponent<RunPanelProps> = ({ width, height, sampler, da

type SamplingProgressComponentProps = {
report: Progress | undefined
numChains: number
}

const SamplingProgressComponent: FunctionComponent<SamplingProgressComponentProps> = ({ report }) => {
const SamplingProgressComponent: FunctionComponent<SamplingProgressComponentProps> = ({ report, numChains }) => {
if (!report) return <span />
const progress = (report.iteration + ((report.chain - 1) * report.totalIterations)) / (report.totalIterations * numChains) * 100;
return (
<>
<div style={{ width: "60%" }}>
<div style={{ width: "45%" }}>
<LinearProgressWithLabel
sx={{
height: 10,
Expand Down
159 changes: 159 additions & 0 deletions gui/src/app/SamplingOptsPanel/SamplingOptsPanel.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import { FunctionComponent, useCallback } from "react";
import { SamplingOpts, defaultSamplingOpts } from "../StanSampler/StanSampler";
import { Hyperlink } from "@fi-sci/misc";
import { Grid } from "@mui/material";

type SamplingOptsPanelProps = {
samplingOpts: SamplingOpts;
setSamplingOpts?: (opts: SamplingOpts) => void;
}

// From Brian
// The following would be nice to have control over:

// [ ]Number of chains (default 4, in practice I think this being a slider from 1 to 8 is probably reasonable, but in theory it can be unbounded)
// Number of warmup iterations (default 1000, can be [0,
// ))
// Number of sampling iterations (default 1000, can be [1,
// ))
// Initializations. This is tricky, can either be 1 json object or a list of num_chains JSON objects.
// "Initialization radius" (parameters not given an initial value directly are drawn from uniform(-R, R) on the unconstrained scale) (default 2.0, can be [0,
// ), but in practice limiting to say 10 in the UI is probably fine)
// Seed. Any uint32.


const sp1 = 0.5;
const sp2 = 1;

const SamplingOptsPanel: FunctionComponent<SamplingOptsPanelProps> = ({ samplingOpts, setSamplingOpts }) => {
const num_chains = samplingOpts.num_chains;
const readOnly = setSamplingOpts === undefined;
const handleReset = useCallback(() => {
setSamplingOpts && setSamplingOpts(defaultSamplingOpts)
}, [setSamplingOpts])
return (
<div>
<Grid container spacing={sp1}>
<Grid container item xs={12} spacing={sp2} title="Number of sampling chains">
<Grid item xs={6}>
# chains
</Grid>
<Grid item xs={6}>
<NumberEdit
value={num_chains}
onChange={(value) => setSamplingOpts && setSamplingOpts({ ...samplingOpts, num_chains: value as number })}
min={1}
max={8}
readOnly={readOnly}
type="int"
/>
</Grid>
</Grid>
<Grid container item xs={12} spacing={sp2} title="Number of warmup draws per chain">
<Grid item xs={6}>
# warmup
</Grid>
<Grid item xs={6}>
<NumberEdit
value={samplingOpts.num_warmup}
onChange={(value) => setSamplingOpts && setSamplingOpts({ ...samplingOpts, num_warmup: value as number })}
min={0}
readOnly={readOnly}
type="int"
/>
</Grid>
</Grid>
<Grid container item xs={12} spacing={sp2} title="Number of regular draws per chain">
<Grid item xs={6}>
# samples
</Grid>
<Grid item xs={6}>
<NumberEdit
value={samplingOpts.num_samples}
onChange={(value) => setSamplingOpts && setSamplingOpts({ ...samplingOpts, num_samples: value as number })}
min={1}
readOnly={readOnly}
type="int"
/>
</Grid>
</Grid>
<Grid container item xs={12} spacing={sp2} title="Radius of the hypercube from which initial values for the model parameters are drawn">
<Grid item xs={6}>
init radius
</Grid>
<Grid item xs={6}>
<NumberEdit
value={samplingOpts.init_radius}
onChange={(value) => setSamplingOpts && setSamplingOpts({ ...samplingOpts, init_radius: value as number })}
min={0}
readOnly={readOnly}
type="float"
/>
</Grid>
</Grid>
<Grid container item xs={12} spacing={sp2} title="Random seed for the sampler. Leave blank (not 0) for a random seed.">
<Grid item xs={6}>
seed
</Grid>
<Grid item xs={6}>
<NumberEdit
value={samplingOpts.seed}
onChange={(value) => setSamplingOpts && setSamplingOpts({ ...samplingOpts, seed: value })}
min={0}
readOnly={readOnly}
type="int"
/>
</Grid>
</Grid>
</Grid>
<div style={{position: 'relative', height: 5}} />
<div>
<Hyperlink onClick={handleReset} color="gray">reset</Hyperlink>
</div>
</div>
)
}

type NumberEditProps = {
value: number | undefined;
onChange: (value: number | undefined) => void;
min: number;
max?: number;
readOnly: boolean;
type: 'int' | 'float' | 'intOrUndefined';
}

const NumberEdit: FunctionComponent<NumberEditProps> = ({ value, onChange, min, max, readOnly, type }) => {
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
let newValue: number | undefined;
switch (type) {
case 'int':
newValue = parseInt(e.target.value);
break;
case 'float':
newValue = parseFloat(e.target.value);
break;
case 'intOrUndefined':
newValue = parseInt(e.target.value);
if (isNaN(newValue)) {
newValue = undefined;
}
break;
}
onChange(newValue);
};

return (
<input
type="number"
value={value === undefined ? "" : value}
onChange={handleChange}
min={min}
max={max}
readOnly={readOnly}
style={{ width: "4em" }}
/>
)
}

export default SamplingOptsPanel;
38 changes: 37 additions & 1 deletion gui/src/app/StanSampler/StanSampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@ import StanWorker from '../tinystan/Worker?worker';

export type StanSamplerStatus = '' | 'loading' | 'loaded' | 'sampling' | 'completed' | 'failed';

export type SamplingOpts = {
num_chains: number
num_warmup: number
num_samples: number
init_radius: number
seed: number | undefined
}

export const defaultSamplingOpts: SamplingOpts = {
num_chains: 4,
num_warmup: 1000,
num_samples: 1000,
init_radius: 2.0,
seed: undefined
}

class StanSampler {
#worker: Worker | undefined;
#status: StanSamplerStatus = '';
Expand Down Expand Up @@ -58,7 +74,17 @@ class StanSampler {
}
this.#worker.postMessage({ purpose: Requests.Load, url: this.compiledUrl });
}
sample(sampleConfig: Partial<SamplerParams>) {
sample(data: any, samplingOpts: SamplingOpts) {
const refresh = calculateReasonableRefreshRate(samplingOpts);
const sampleConfig: Partial<SamplerParams> = {
data,
num_chains: samplingOpts.num_chains,
num_warmup: samplingOpts.num_warmup,
num_samples: samplingOpts.num_samples,
init_radius: samplingOpts.init_radius,
seed: samplingOpts.seed !== undefined ? samplingOpts.seed : null,
refresh
}
if (!this.#worker) return
if (this.#status === '') {
console.warn('Model not loaded yet')
Expand Down Expand Up @@ -109,4 +135,14 @@ class StanSampler {
}
}

const calculateReasonableRefreshRate = (samplingOpts: SamplingOpts) => {
const totalSamples = (samplingOpts.num_samples + samplingOpts.num_warmup) * samplingOpts.num_chains;

const onePercent = Math.floor(totalSamples / 100);

const nearestMultipleOfTen = Math.round(onePercent / 10) * 10;

return Math.max(10, nearestMultipleOfTen);
}

export default StanSampler
Loading

0 comments on commit 3fbc83c

Please sign in to comment.