Skip to content

Commit

Permalink
Merge pull request #90 from flatironinstitute/include-sampling-opts-i…
Browse files Browse the repository at this point in the history
…n-csv-zip

include sampling_opts.json in exported csv zip with draws
  • Loading branch information
WardBrian authored Jun 26, 2024
2 parents 0013c9d + ee1fb1d commit 7b7ba5c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
18 changes: 12 additions & 6 deletions gui/src/app/SamplerOutputView/SamplerOutputView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { SmallIconButton } from "@fi-sci/misc"
import { Download } from "@mui/icons-material"
import JSZip from 'jszip'
import { FunctionComponent, useCallback, useMemo, useState } from "react"
import StanSampler from "../StanSampler/StanSampler"
import StanSampler, { SamplingOpts } from "../StanSampler/StanSampler"
import { useSamplerOutput } from "../StanSampler/useStanSampler"
import TabWidget from "../TabWidget/TabWidget"
import { triggerDownload } from "../util/triggerDownload"
Expand Down Expand Up @@ -30,6 +30,7 @@ const SamplerOutputView: FunctionComponent<SamplerOutputViewProps> = ({width, he
paramNames={paramNames}
numChains={numChains}
computeTimeSec={computeTimeSec}
samplingOpts={sampler.samplingOpts}
/>
)
}
Expand All @@ -41,6 +42,7 @@ type DrawsDisplayProps = {
numChains: number,
paramNames: string[]
computeTimeSec: number | undefined
samplingOpts: SamplingOpts // for including in exported zip
}

const tabs = [
Expand Down Expand Up @@ -70,7 +72,7 @@ const tabs = [
}
]

const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, draws, paramNames, numChains, computeTimeSec }) => {
const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, draws, paramNames, numChains, computeTimeSec, samplingOpts }) => {

const [currentTabId, setCurrentTabId] = useState('summary');

Expand Down Expand Up @@ -106,6 +108,7 @@ const DrawsDisplay: FunctionComponent<DrawsDisplayProps> = ({ width, height, dra
paramNames={paramNames}
drawChainIds={drawChainIds}
drawNumbers={drawNumbers}
samplingOpts={samplingOpts}
/>
<HistsView
width={0}
Expand All @@ -132,9 +135,10 @@ type DrawsViewProps = {
paramNames: string[]
drawChainIds: number[]
drawNumbers: number[]
samplingOpts: SamplingOpts // for including in exported zip
}

const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, paramNames, drawChainIds, drawNumbers }) => {
const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, paramNames, drawChainIds, drawNumbers, samplingOpts }) => {
const [abbreviatedToNumRows, setAbbreviatedToNumRows] = useState<number | undefined>(300);
const draws2 = useMemo(() => {
if (abbreviatedToNumRows === undefined) return draws;
Expand All @@ -147,15 +151,15 @@ const DrawsView: FunctionComponent<DrawsViewProps> = ({ width, height, draws, pa
const handleExportToMultipleCsvs = useCallback(async () => {
const uniqueChainIds = Array.from(new Set(drawChainIds));
const csvTexts = prepareMultipleCsvsText(draws, paramNames, drawChainIds, uniqueChainIds);
const blob = await createZipBlobForMultipleCsvs(csvTexts, uniqueChainIds);
const blob = await createZipBlobForMultipleCsvs(csvTexts, uniqueChainIds, samplingOpts);
const fileName = 'SP-draws.zip';
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = fileName;
a.click();
URL.revokeObjectURL(url);
}, [draws, paramNames, drawChainIds]);
}, [draws, paramNames, drawChainIds, samplingOpts]);
return (
<div style={{position: 'absolute', width, height, overflow: 'auto'}}>
<div>
Expand Down Expand Up @@ -237,14 +241,16 @@ const prepareMultipleCsvsText = (draws: number[][], paramNames: string[], drawCh
})
}

const createZipBlobForMultipleCsvs = async (csvTexts: string[], uniqueChainIds: number[]) => {
const createZipBlobForMultipleCsvs = async (csvTexts: string[], uniqueChainIds: number[], samplingOpts: SamplingOpts) => {
const zip = new JSZip();
// put them all in a folder called 'draws'
const folder = zip.folder('draws');
if (!folder) throw new Error('Failed to create folder');
csvTexts.forEach((text, i) => {
folder.file(`chain_${uniqueChainIds[i]}.csv`, text);
});
const samplingOptsText = JSON.stringify(samplingOpts, null, 2);
folder.file('sampling_opts.json', samplingOptsText);
const blob = await zip.generateAsync({type: 'blob'});
return blob;
}
Expand Down
10 changes: 5 additions & 5 deletions gui/src/app/StanSampler/StanSampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class StanSampler {
#draws: number[][] = [];
#computeTimeSec: number | undefined = undefined;
#paramNames: string[] = [];
#numChains: number = 0;
#samplingStartTimeSec: number = 0;
#samplingOpts: SamplingOpts = defaultSamplingOpts; // the sampling options used in the last sample call

private constructor(private compiledUrl: string) {
this._initialize()
Expand Down Expand Up @@ -98,7 +98,6 @@ class StanSampler {
console.warn('Number of chains not specified')
return
}
this.#numChains = sampleConfig.num_chains;
if (this.#status === 'sampling') {
console.warn('Already sampling')
return
Expand All @@ -107,6 +106,7 @@ class StanSampler {
console.warn('Model not loaded yet')
return
}
this.#samplingOpts = samplingOpts;
this.#draws = [];
this.#paramNames = [];
this.#worker
Expand Down Expand Up @@ -137,9 +137,6 @@ class StanSampler {
get paramNames() {
return this.#paramNames;
}
get numChains() {
return this.#numChains;
}
get status() {
return this.#status;
}
Expand All @@ -149,6 +146,9 @@ class StanSampler {
get computeTimeSec() {
return this.#computeTimeSec;
}
get samplingOpts() {
return this.#samplingOpts;
}
}

const calculateReasonableRefreshRate = (samplingOpts: SamplingOpts) => {
Expand Down
2 changes: 1 addition & 1 deletion gui/src/app/StanSampler/useStanSampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export const useSamplerOutput = (sampler: StanSampler | undefined) => {
if (sampler.status === 'completed') {
setDraws(sampler.draws);
setParamNames(sampler.paramNames);
setNumChains(sampler.numChains);
setNumChains(sampler.samplingOpts.num_chains);
setComputeTimeSec(sampler.computeTimeSec);
}
else {
Expand Down

0 comments on commit 7b7ba5c

Please sign in to comment.