Skip to content

Commit

Permalink
Merge pull request #14 from flatironinstitute/sync-tinystan.ts
Browse files Browse the repository at this point in the history
Update tinystan.ts to support initialization
  • Loading branch information
WardBrian authored May 16, 2024
2 parents de1cc70 + 3833122 commit 92c4cca
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 31 deletions.
3 changes: 1 addition & 2 deletions gui/src/app/RunPanel/RunPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { FunctionComponent, useCallback } from 'react';

import StanSampler from '../StanSampler/StanSampler';
import { useSamplerProgress, useSamplerStatus } from '../StanSampler/useStanSampler';
import { defaultSamplerParams } from '../tinystan';
import { Progress } from '../tinystan/Worker';

type RunPanelProps = {
Expand All @@ -25,7 +24,7 @@ const RunPanel: FunctionComponent<RunPanelProps> = ({ width, height, sampler, da

const handleRun = useCallback(async () => {
if (!sampler) return;
sampler.sample({...defaultSamplerParams, data, num_chains: numChains})
sampler.sample({ data, num_chains: numChains})
}, [sampler, data]);

const cancelRun = useCallback(() => {
Expand Down
2 changes: 1 addition & 1 deletion gui/src/app/StanSampler/StanSampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class StanSampler {
}
this.#worker.postMessage({ purpose: Requests.Load, url: this.compiledUrl });
}
sample(sampleConfig: SamplerParams) {
sample(sampleConfig: Partial<SamplerParams>) {
if (!this.#worker) return
if (this.#status === '') {
console.warn('Model not loaded yet')
Expand Down
137 changes: 109 additions & 28 deletions gui/src/app/tinystan/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ interface WasmModule {
_tinystan_separator_char(): number;
// prettier-ignore
_tinystan_sample(model: model_ptr, num_chains: number, inits: cstr, seed: number, id: number,
init_radius: number, num_warmup: number, num_samples: number, metric: number, init_inv_metric: cstr,
init_radius: number, num_warmup: number, num_samples: number, metric: number, init_inv_metric: ptr,
adapt: number, delta: number, gamma: number, kappa: number, t0: number, init_buffer: number,
term_buffer: number, window: number, save_warmup: number, stepsize: number, stepsize_jitter: number,
max_depth: number, refresh: number, num_threads: number, out: ptr, out_size: number, metric_out: ptr,
Expand All @@ -37,8 +37,6 @@ interface WasmModule {
}

const NULL = 0 as ptr;
const NULLSTR = 0 as cstr;


const HMC_SAMPLER_VARIABLES = [
"lp__",
Expand All @@ -63,15 +61,20 @@ export type StanDraws = {
draws: number[][];
};

export type StanVariableInputs = Record<string, unknown>;

export interface SamplerParams {
data: string | object;
data: string | StanVariableInputs;
num_chains: number;
inits: string | StanVariableInputs | string[] | StanVariableInputs[];
seed: number | null;
id: number;
init_radius: number;
num_warmup: number;
num_samples: number;
metric: HMCMetric;
save_metric: boolean;
init_inv_metric: number[] | number[][] | number[][][] | null;
adapt: boolean;
delta: number;
gamma: number;
Expand All @@ -88,15 +91,18 @@ export interface SamplerParams {
num_threads: number;
}

export const defaultSamplerParams: SamplerParams = {
const defaultSamplerParams: SamplerParams = {
data: "",
num_chains: 4,
inits: "",
seed: null,
id: 1,
init_radius: 2.0,
num_warmup: 1000,
num_samples: 1000,
metric: HMCMetric.DIAGONAL,
save_metric: false,
init_inv_metric: null, // currently unused
adapt: true,
delta: 0.8,
gamma: 0.05,
Expand All @@ -109,17 +115,20 @@ export const defaultSamplerParams: SamplerParams = {
stepsize: 1.0,
stepsize_jitter: 0.0,
max_depth: 10,
refresh: 100, // this is how often it prints out progress
refresh: 100,
num_threads: -1,
};

export default class StanModel {
private m: WasmModule;
private printCallback: PrintCallback | null;
// used to send multiple JSON values in one string
private sep: string;

private constructor(m: WasmModule, pc: PrintCallback | null) {
this.m = m;
this.printCallback = pc;
this.sep = String.fromCharCode(m._tinystan_separator_char());
}

public static async load(
Expand Down Expand Up @@ -151,17 +160,23 @@ export default class StanModel {
throw new Error(err_msg);
}

private encodeInits(inits: string | StanVariableInputs | string[] | StanVariableInputs[]): cstr {
if (Array.isArray(inits)) {
return this.encodeString(
inits.map(i => string_safe_jsonify(i)).join(this.sep),
);
} else {
return this.encodeString(string_safe_jsonify(inits));
}
}

private withModel<T>(
data: string | object,
data: string | StanVariableInputs,
seed: number,
f: (model: model_ptr) => T,
): T {
const data_ptr = this.encodeString(string_safe_jsonify(data));
const err_ptr = this.m._malloc(4);
if (typeof data === "object") {
data = JSON.stringify(data);
}

const data_ptr = this.encodeString(data);
const model = this.m._tinystan_create_model(data_ptr, seed, err_ptr);
this.m._free(data_ptr);

Expand All @@ -176,20 +191,18 @@ export default class StanModel {
}
}

// Supports most of the TinyStan API except for
// - inits
// - init inv metric
// - save_metric
public sample(p: Partial<SamplerParams>): StanDraws {
const {
data,
num_chains,
inits,
seed,
id,
init_radius,
num_warmup,
num_samples,
metric,
save_metric,
adapt,
delta,
gamma,
Expand Down Expand Up @@ -226,29 +239,54 @@ export default class StanModel {
const rawParamNames = this.m.UTF8ToString(
this.m._tinystan_model_param_names(model),
);
const paramNames = HMC_SAMPLER_VARIABLES.concat(rawParamNames.split(','));
const paramNames = HMC_SAMPLER_VARIABLES.concat(rawParamNames.split(","));

const n_params = paramNames.length;

const free_params = this.m._tinystan_model_num_free_params(model);
if (free_params === 0) {
throw new Error("No parameters to sample");
}

// TODO: allow init_inv_metric to be specified
const init_inv_metric_ptr = NULL;

let metric_out = NULL;
if (save_metric) {
if (metric === HMCMetric.DENSE)
metric_out = this.m._malloc(
num_chains *
free_params *
free_params *
Float64Array.BYTES_PER_ELEMENT,
);
else
metric_out = this.m._malloc(
num_chains * free_params * Float64Array.BYTES_PER_ELEMENT,
);
}

// Allocate memory for the output
const n_draws =
num_chains * (save_warmup ? num_samples + num_warmup : num_samples);
const n_out = n_draws * n_params;
const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT);

const inits_ptr = this.encodeInits(inits);

// Sample from the model
const err_ptr = this.m._malloc(4);
const result = this.m._tinystan_sample(
model,
num_chains,
NULLSTR, // inits
inits_ptr,
seed_ || 0,
id,
init_radius,
num_warmup,
num_samples,
metric.valueOf(),
NULLSTR, // init inv metric
init_inv_metric_ptr,
adapt ? 1 : 0,
delta,
gamma,
Expand All @@ -265,32 +303,67 @@ export default class StanModel {
num_threads,
out_ptr,
n_out,
NULL,
metric_out,
err_ptr,
);

if (result != 0) {
this.handleError(err_ptr);
}
this.m._free(err_ptr);
this.m._free(inits_ptr);

const out_buffer = this.m.HEAPF64.subarray(
out_ptr / Float64Array.BYTES_PER_ELEMENT,
out_ptr / Float64Array.BYTES_PER_ELEMENT + n_out,
);

// copy out parameters of interest
const draws: number[][] = Array.from({ length: n_params }, () => []);
for (let i = 0; i < n_draws; i++) {
for (let j = 0; j < n_params; j++) {
const elm = out_buffer[i * n_params + j];
draws[j][i] = elm;
}
}
const draws: number[][] = Array.from({ length: n_params }, (_, i) =>
Array.from({ length: n_draws }, (_, j) => out_buffer[i + n_params * j]),
);

// Clean up
this.m._free(out_ptr);

return { paramNames, draws };
let metric_array: number[][] | number[][][] | null = null;

if (save_metric) {
if (metric === HMCMetric.DENSE) {
const metric_buffer = this.m.HEAPF64.subarray(
metric_out / Float64Array.BYTES_PER_ELEMENT,
metric_out / Float64Array.BYTES_PER_ELEMENT +
num_chains * free_params * free_params,
);

metric_array = Array.from({ length: num_chains }, (_, i) =>
Array.from({ length: free_params }, (_, j) =>
Array.from(
{ length: free_params },
(_, k) =>
metric_buffer[
i * free_params * free_params + j * free_params + k
],
),
),
);
} else {
const metric_buffer = this.m.HEAPF64.subarray(
metric_out / Float64Array.BYTES_PER_ELEMENT,
metric_out / Float64Array.BYTES_PER_ELEMENT +
num_chains * free_params,
);
metric_array = Array.from({ length: num_chains }, (_, i) =>
Array.from(
{ length: free_params },
(_, j) => metric_buffer[i * free_params + j],
),
);
}
}
this.m._free(metric_out);

return { paramNames, draws, metric: metric_array };
});
}

Expand All @@ -311,3 +384,11 @@ export default class StanModel {
return version;
}
}

const string_safe_jsonify = (obj: string | unknown): string => {
if (typeof obj === "string") {
return obj;
} else {
return JSON.stringify(obj);
}
};

0 comments on commit 92c4cca

Please sign in to comment.