diff --git a/gui/src/app/RunPanel/RunPanel.tsx b/gui/src/app/RunPanel/RunPanel.tsx index 6029f837..686c082c 100644 --- a/gui/src/app/RunPanel/RunPanel.tsx +++ b/gui/src/app/RunPanel/RunPanel.tsx @@ -6,7 +6,6 @@ import { FunctionComponent, useCallback } from 'react'; import StanSampler from '../StanSampler/StanSampler'; import { useSamplerProgress, useSamplerStatus } from '../StanSampler/useStanSampler'; -import { defaultSamplerParams } from '../tinystan'; import { Progress } from '../tinystan/Worker'; type RunPanelProps = { @@ -25,7 +24,7 @@ const RunPanel: FunctionComponent = ({ width, height, sampler, da const handleRun = useCallback(async () => { if (!sampler) return; - sampler.sample({...defaultSamplerParams, data, num_chains: numChains}) + sampler.sample({ data, num_chains: numChains}) }, [sampler, data]); const cancelRun = useCallback(() => { diff --git a/gui/src/app/StanSampler/StanSampler.ts b/gui/src/app/StanSampler/StanSampler.ts index 93cd8c74..e679c8cb 100644 --- a/gui/src/app/StanSampler/StanSampler.ts +++ b/gui/src/app/StanSampler/StanSampler.ts @@ -58,7 +58,7 @@ class StanSampler { } this.#worker.postMessage({ purpose: Requests.Load, url: this.compiledUrl }); } - sample(sampleConfig: SamplerParams) { + sample(sampleConfig: Partial) { if (!this.#worker) return if (this.#status === '') { console.warn('Model not loaded yet') diff --git a/gui/src/app/tinystan/index.ts b/gui/src/app/tinystan/index.ts index 1aa2a008..053ba895 100644 --- a/gui/src/app/tinystan/index.ts +++ b/gui/src/app/tinystan/index.ts @@ -19,7 +19,7 @@ interface WasmModule { _tinystan_separator_char(): number; // prettier-ignore _tinystan_sample(model: model_ptr, num_chains: number, inits: cstr, seed: number, id: number, - init_radius: number, num_warmup: number, num_samples: number, metric: number, init_inv_metric: cstr, + init_radius: number, num_warmup: number, num_samples: number, metric: number, init_inv_metric: ptr, adapt: number, delta: number, gamma: number, kappa: number, t0: number, init_buffer: number, term_buffer: number, window: number, save_warmup: number, stepsize: number, stepsize_jitter: number, max_depth: number, refresh: number, num_threads: number, out: ptr, out_size: number, metric_out: ptr, @@ -37,8 +37,6 @@ interface WasmModule { } const NULL = 0 as ptr; -const NULLSTR = 0 as cstr; - const HMC_SAMPLER_VARIABLES = [ "lp__", @@ -63,15 +61,20 @@ export type StanDraws = { draws: number[][]; }; +export type StanVariableInputs = Record; + export interface SamplerParams { - data: string | object; + data: string | StanVariableInputs; num_chains: number; + inits: string | StanVariableInputs | string[] | StanVariableInputs[]; seed: number | null; id: number; init_radius: number; num_warmup: number; num_samples: number; metric: HMCMetric; + save_metric: boolean; + init_inv_metric: number[] | number[][] | number[][][] | null; adapt: boolean; delta: number; gamma: number; @@ -88,15 +91,18 @@ export interface SamplerParams { num_threads: number; } -export const defaultSamplerParams: SamplerParams = { +const defaultSamplerParams: SamplerParams = { data: "", num_chains: 4, + inits: "", seed: null, id: 1, init_radius: 2.0, num_warmup: 1000, num_samples: 1000, metric: HMCMetric.DIAGONAL, + save_metric: false, + init_inv_metric: null, // currently unused adapt: true, delta: 0.8, gamma: 0.05, @@ -109,17 +115,20 @@ export const defaultSamplerParams: SamplerParams = { stepsize: 1.0, stepsize_jitter: 0.0, max_depth: 10, - refresh: 100, // this is how often it prints out progress + refresh: 100, num_threads: -1, }; export default class StanModel { private m: WasmModule; private printCallback: PrintCallback | null; + // used to send multiple JSON values in one string + private sep: string; private constructor(m: WasmModule, pc: PrintCallback | null) { this.m = m; this.printCallback = pc; + this.sep = String.fromCharCode(m._tinystan_separator_char()); } public static async load( @@ -151,17 +160,23 @@ export default class StanModel { throw new Error(err_msg); } + private encodeInits(inits: string | StanVariableInputs | string[] | StanVariableInputs[]): cstr { + if (Array.isArray(inits)) { + return this.encodeString( + inits.map(i => string_safe_jsonify(i)).join(this.sep), + ); + } else { + return this.encodeString(string_safe_jsonify(inits)); + } + } + private withModel( - data: string | object, + data: string | StanVariableInputs, seed: number, f: (model: model_ptr) => T, ): T { + const data_ptr = this.encodeString(string_safe_jsonify(data)); const err_ptr = this.m._malloc(4); - if (typeof data === "object") { - data = JSON.stringify(data); - } - - const data_ptr = this.encodeString(data); const model = this.m._tinystan_create_model(data_ptr, seed, err_ptr); this.m._free(data_ptr); @@ -176,20 +191,18 @@ export default class StanModel { } } - // Supports most of the TinyStan API except for - // - inits - // - init inv metric - // - save_metric public sample(p: Partial): StanDraws { const { data, num_chains, + inits, seed, id, init_radius, num_warmup, num_samples, metric, + save_metric, adapt, delta, gamma, @@ -226,29 +239,54 @@ export default class StanModel { const rawParamNames = this.m.UTF8ToString( this.m._tinystan_model_param_names(model), ); - const paramNames = HMC_SAMPLER_VARIABLES.concat(rawParamNames.split(',')); + const paramNames = HMC_SAMPLER_VARIABLES.concat(rawParamNames.split(",")); const n_params = paramNames.length; + const free_params = this.m._tinystan_model_num_free_params(model); + if (free_params === 0) { + throw new Error("No parameters to sample"); + } + + // TODO: allow init_inv_metric to be specified + const init_inv_metric_ptr = NULL; + + let metric_out = NULL; + if (save_metric) { + if (metric === HMCMetric.DENSE) + metric_out = this.m._malloc( + num_chains * + free_params * + free_params * + Float64Array.BYTES_PER_ELEMENT, + ); + else + metric_out = this.m._malloc( + num_chains * free_params * Float64Array.BYTES_PER_ELEMENT, + ); + } + // Allocate memory for the output const n_draws = num_chains * (save_warmup ? num_samples + num_warmup : num_samples); const n_out = n_draws * n_params; const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT); + const inits_ptr = this.encodeInits(inits); + // Sample from the model const err_ptr = this.m._malloc(4); const result = this.m._tinystan_sample( model, num_chains, - NULLSTR, // inits + inits_ptr, seed_ || 0, id, init_radius, num_warmup, num_samples, metric.valueOf(), - NULLSTR, // init inv metric + init_inv_metric_ptr, adapt ? 1 : 0, delta, gamma, @@ -265,7 +303,7 @@ export default class StanModel { num_threads, out_ptr, n_out, - NULL, + metric_out, err_ptr, ); @@ -273,6 +311,7 @@ export default class StanModel { this.handleError(err_ptr); } this.m._free(err_ptr); + this.m._free(inits_ptr); const out_buffer = this.m.HEAPF64.subarray( out_ptr / Float64Array.BYTES_PER_ELEMENT, @@ -280,17 +319,51 @@ export default class StanModel { ); // copy out parameters of interest - const draws: number[][] = Array.from({ length: n_params }, () => []); - for (let i = 0; i < n_draws; i++) { - for (let j = 0; j < n_params; j++) { - const elm = out_buffer[i * n_params + j]; - draws[j][i] = elm; - } - } + const draws: number[][] = Array.from({ length: n_params }, (_, i) => + Array.from({ length: n_draws }, (_, j) => out_buffer[i + n_params * j]), + ); + // Clean up this.m._free(out_ptr); - return { paramNames, draws }; + let metric_array: number[][] | number[][][] | null = null; + + if (save_metric) { + if (metric === HMCMetric.DENSE) { + const metric_buffer = this.m.HEAPF64.subarray( + metric_out / Float64Array.BYTES_PER_ELEMENT, + metric_out / Float64Array.BYTES_PER_ELEMENT + + num_chains * free_params * free_params, + ); + + metric_array = Array.from({ length: num_chains }, (_, i) => + Array.from({ length: free_params }, (_, j) => + Array.from( + { length: free_params }, + (_, k) => + metric_buffer[ + i * free_params * free_params + j * free_params + k + ], + ), + ), + ); + } else { + const metric_buffer = this.m.HEAPF64.subarray( + metric_out / Float64Array.BYTES_PER_ELEMENT, + metric_out / Float64Array.BYTES_PER_ELEMENT + + num_chains * free_params, + ); + metric_array = Array.from({ length: num_chains }, (_, i) => + Array.from( + { length: free_params }, + (_, j) => metric_buffer[i * free_params + j], + ), + ); + } + } + this.m._free(metric_out); + + return { paramNames, draws, metric: metric_array }; }); } @@ -311,3 +384,11 @@ export default class StanModel { return version; } } + +const string_safe_jsonify = (obj: string | unknown): string => { + if (typeof obj === "string") { + return obj; + } else { + return JSON.stringify(obj); + } +};