Skip to content

Commit

Permalink
Basic support for calling Pathfinder
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed May 16, 2024
1 parent 70e820c commit 450e7a6
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ COPY make/js /app/tinystan/make/js
# Build a test model
RUN cd tinystan && \
echo 'include make/js' >> Makefile && \
emmake make test_models/bernoulli/bernoulli.js -j2 && \
emmake make test_models/bernoulli/bernoulli.js -j4 && \
emstrip test_models/bernoulli/bernoulli.wasm

RUN pip install fastapi
Expand Down
5 changes: 2 additions & 3 deletions docker/make/local
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ LDLIBS_TBB ?= -ltbb

# could also uses -fexceptions which is more compatible, but slower
CXXFLAGS+=-fwasm-exceptions
CXXFLAGS+=-g

LDFLAGS+=-sMODULARIZE -sEXPORT_NAME=createModule -sEXPORT_ES6 -sENVIRONMENT=web
LDFLAGS+=-sMODULARIZE -sEXPORT_NAME=createModule -sEXPORT_ES6 -sENVIRONMENT=web -sINCOMING_MODULE_JS_API=print,printErr
LDFLAGS+=-sEXIT_RUNTIME=1 -sALLOW_MEMORY_GROWTH=1
# Functions we want. Can add more, with a prepended _, from tinystan.h
EXPORTS=_malloc,_free,_tinystan_api_version,_tinystan_create_model,_tinystan_destroy_error,_tinystan_destroy_model,_tinystan_get_error_message,_tinystan_get_error_type,_tinystan_model_num_free_params,_tinystan_model_param_names,_tinystan_sample,_tinystan_separator_char,_tinystan_stan_version
EXPORTS=_malloc,_free,_tinystan_api_version,_tinystan_create_model,_tinystan_destroy_error,_tinystan_destroy_model,_tinystan_get_error_message,_tinystan_get_error_type,_tinystan_model_num_free_params,_tinystan_model_param_names,_tinystan_sample,_tinystan_pathfinder,_tinystan_separator_char,_tinystan_stan_version
LDFLAGS+=-sEXPORTED_FUNCTIONS=$(EXPORTS) -sEXPORTED_RUNTIME_METHODS=stringToUTF8,getValue,UTF8ToString,lengthBytesUTF8

2 changes: 1 addition & 1 deletion gui/src/app/StanSampler/StanSampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class StanSampler {
this.#onStatusChangedCallbacks.forEach(cb => cb())
break;
}
case Replies.SampleReturn: {
case Replies.StanReturn: {
if (e.data.error) {
this.#errorMessage = e.data.error;
this.#status = 'failed';
Expand Down
23 changes: 19 additions & 4 deletions gui/src/app/tinystan/Worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ import StanModel from ".";
export enum Requests {
Load = "load",
Sample = "sample",
Pathfinder = "pathfinder",
}

export enum Replies {
ModelLoaded = "modelLoaded",
SampleReturn = "sampleReturn",
StanReturn = "stanReturn",
Progress = "progress",
}

Expand Down Expand Up @@ -62,15 +63,29 @@ onmessage = function (e) {
}
case Requests.Sample: {
if (!model) {
postMessage({ purpose: Replies.SampleReturn, error: "Model not loaded yet!" })
postMessage({ purpose: Replies.StanReturn, error: "Model not loaded yet!" })
return;
}
try {
const { paramNames, draws } = model.sample(e.data.sampleConfig);
// TODO? use an ArrayBuffer so we can transfer without serialization cost
postMessage({ purpose: Replies.SampleReturn, draws, paramNames, error: null });
postMessage({ purpose: Replies.StanReturn, draws, paramNames, error: null });
} catch (e: any) {
postMessage({ purpose: Replies.SampleReturn, error: e.toString() })
postMessage({ purpose: Replies.StanReturn, error: e.toString() })
}
break;
}
case Requests.Pathfinder: {
if (!model) {
postMessage({ purpose: Replies.StanReturn, error: "Model not loaded yet!" })
return;
}
try {
const { draws, paramNames } = model.pathfinder(e.data.pathfinderConfig);
// TODO? use an ArrayBuffer so we can transfer without serialization cost
postMessage({ purpose: Replies.StanReturn, draws, paramNames, error: null });
} catch (e: any) {
postMessage({ purpose: Replies.StanReturn, error: e.toString() })
}
break;
}
Expand Down
183 changes: 176 additions & 7 deletions gui/src/app/tinystan/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ interface WasmModule {
term_buffer: number, window: number, save_warmup: number, stepsize: number, stepsize_jitter: number,
max_depth: number, refresh: number, num_threads: number, out: ptr, out_size: number, metric_out: ptr,
err_ptr: ptr): number;
// prettier-ignore
_tinystan_pathfinder(model: model_ptr, num_paths: number, inits: cstr, seed: number, id: number,
init_radius: number, num_draws: number, max_history_size: number, init_alpha: number, tol_obj: number,
tol_rel_obj: number, tol_grad: number, tol_rel_grad: number, tol_param: number, num_iterations: number,
num_elbo_draws: number, num_multi_draws: number, calculate_lp: number, psis_resample: number,
refresh: number, num_threads: number, out: ptr, out_size: number, err_ptr: ptr): number;
_tinystan_get_error_message(err_ptr: error_ptr): cstr;
_tinystan_get_error_type(err_ptr: error_ptr): number;
_tinystan_destroy_error(err_ptr: error_ptr): void;
Expand All @@ -38,6 +44,8 @@ interface WasmModule {

const NULL = 0 as ptr;

const PTR_SIZE = 4;

const HMC_SAMPLER_VARIABLES = [
"lp__",
"accept_stat__",
Expand All @@ -54,6 +62,8 @@ export enum HMCMetric {
DIAGONAL = 2,
}

const PATHFINDER_VARIABLES = ["lp_approx__", "lp__"];

export type PrintCallback = (s: string) => void;

export type StanDraws = {
Expand Down Expand Up @@ -119,6 +129,59 @@ const defaultSamplerParams: SamplerParams = {
num_threads: -1,
};

interface LBFGSConfig {
max_history_size: number;
init_alpha: number;
tol_obj: number;
tol_rel_obj: number;
tol_grad: number;
tol_rel_grad: number;
tol_param: number;
num_iterations: number;
}

interface PathfinderUniqueParams {
data: string | StanVariableInputs;
num_paths: number;
inits: string | StanVariableInputs | string[] | StanVariableInputs[];
seed: number | null;
id: number;
init_radius: number;
num_draws: number;
num_elbo_draws: number;
num_multi_draws: number;
calculate_lp: boolean;
psis_resample: boolean;
refresh: number;
num_threads: number;
}

export type PathfinderParams = LBFGSConfig & PathfinderUniqueParams;

const defaultPathfinderParams: PathfinderParams = {
data: "",
num_paths: 4,
inits: "",
seed: null,
id: 1,
init_radius: 2.0,
num_draws: 1000,
max_history_size: 5,
init_alpha: 0.001,
tol_obj: 1e-12,
tol_rel_obj: 1e4,
tol_grad: 1e-8,
tol_rel_grad: 1e7,
tol_param: 1e-8,
num_iterations: 1000,
num_elbo_draws: 25,
num_multi_draws: 1000,
calculate_lp: true,
psis_resample: true,
refresh: 100,
num_threads: -1,
};

export default class StanModel {
private m: WasmModule;
private printCallback: PrintCallback | null;
Expand Down Expand Up @@ -190,7 +253,7 @@ export default class StanModel {
f: (model: model_ptr, deferredFree: (p: ptr | cstr) => void) => T,
): T {
const data_ptr = this.encodeString(string_safe_jsonify(data));
const err_ptr = this.m._malloc(4);
const err_ptr = this.m._malloc(PTR_SIZE);
const model = this.m._tinystan_create_model(data_ptr, seed, err_ptr);
this.m._free(data_ptr);

Expand Down Expand Up @@ -248,10 +311,7 @@ export default class StanModel {
throw new Error("num_samples must be at least 1");
}

let seed_ = seed;
if (seed_ === null) {
seed_ = Math.floor(Math.random() * (2 ^ 32));
}
const seed_ = seed !== null ? seed : Math.floor(Math.random() * (2 ^ 32));

return this.withModel(data, seed_, (model, deferredFree) => {
// Get the parameter names
Expand All @@ -264,7 +324,7 @@ export default class StanModel {

const free_params = this.m._tinystan_model_num_free_params(model);
if (free_params === 0) {
throw new Error("No parameters to sample");
throw new Error("Model has no parameters to sample.");
}

// TODO: allow init_inv_metric to be specified
Expand Down Expand Up @@ -297,7 +357,7 @@ export default class StanModel {
const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT);
deferredFree(out_ptr);

const err_ptr = this.m._malloc(4);
const err_ptr = this.m._malloc(PTR_SIZE);
deferredFree(err_ptr);

// Sample from the model
Expand Down Expand Up @@ -386,6 +446,115 @@ export default class StanModel {
});
}

public pathfinder(p: Partial<PathfinderParams>): StanDraws {
const {
data,
num_paths,
inits,
seed,
id,
init_radius,
num_draws,
max_history_size,
init_alpha,
tol_obj,
tol_rel_obj,
tol_grad,
tol_rel_grad,
tol_param,
num_iterations,
num_elbo_draws,
num_multi_draws,
calculate_lp,
psis_resample,
refresh,
num_threads,
} = { ...defaultPathfinderParams, ...p };

if (num_paths < 1) {
throw new Error("num_paths must be at least 1");
}
if (num_draws < 1) {
throw new Error("num_draws must be at least 1");
}
if (num_multi_draws < 1) {
throw new Error("num_multi_draws must be at least 1");
}

const output_rows =
calculate_lp && psis_resample ? num_multi_draws : num_draws * num_paths;

const seed_ = seed !== null ? seed : Math.floor(Math.random() * (2 ^ 32));

return this.withModel(data, seed_, (model, deferredFree) => {
const rawParamNames = this.m.UTF8ToString(
this.m._tinystan_model_param_names(model),
);
const paramNames = PATHFINDER_VARIABLES.concat(rawParamNames.split(","));

const n_params = paramNames.length;

const free_params = this.m._tinystan_model_num_free_params(model);
if (free_params === 0) {
throw new Error("Model has no parameters.");
}

const inits_ptr = this.encodeInits(inits);
deferredFree(inits_ptr);

const n_out = output_rows * n_params;
const out = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT);
deferredFree(out);
const err_ptr = this.m._malloc(PTR_SIZE);
deferredFree(err_ptr);

const result = this.m._tinystan_pathfinder(
model,
num_paths,
inits_ptr,
seed_ || 0,
id,
init_radius,
num_draws,
max_history_size,
init_alpha,
tol_obj,
tol_rel_obj,
tol_grad,
tol_rel_grad,
tol_param,
num_iterations,
num_elbo_draws,
num_multi_draws,
calculate_lp ? 1 : 0,
psis_resample ? 1 : 0,
refresh,
num_threads,
out,
n_out,
err_ptr,
);

if (result != 0) {
this.handleError(err_ptr);
}

const out_buffer = this.m.HEAPF64.subarray(
out / Float64Array.BYTES_PER_ELEMENT,
out / Float64Array.BYTES_PER_ELEMENT + n_out,
);

const draws: number[][] = Array.from({ length: n_params }, (_, i) =>
Array.from(
{ length: output_rows },
(_, j) => out_buffer[i + n_params * j],
),
);

return { paramNames, draws };
});
}

public stanVersion(): string {
const major = this.m._malloc(4);
const minor = this.m._malloc(4);
Expand Down

0 comments on commit 450e7a6

Please sign in to comment.