From 82ce578013a51397a9c35eb5eb5f5925f19dff7a Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 28 Oct 2024 20:45:45 +0000 Subject: [PATCH 1/4] Python run script skeleton --- gui/src/app/Scripting/Runtime/cmdstan.py | 7 + gui/src/app/Scripting/Runtime/load_args.py | 25 ++ .../app/Scripting/Runtime/makePyRuntime.ts | 65 +++++ gui/src/app/Scripting/Runtime/preamble.py | 9 + gui/src/app/Scripting/Runtime/run_analysis.py | 18 ++ gui/src/app/Scripting/Runtime/run_data.py | 5 + gui/src/app/Scripting/Runtime/sample.py | 7 + gui/test/app/Scripting/makePyRuntime.test.ts | 224 ++++++++++++++++++ 8 files changed, 360 insertions(+) create mode 100644 gui/src/app/Scripting/Runtime/cmdstan.py create mode 100644 gui/src/app/Scripting/Runtime/load_args.py create mode 100644 gui/src/app/Scripting/Runtime/makePyRuntime.ts create mode 100644 gui/src/app/Scripting/Runtime/preamble.py create mode 100644 gui/src/app/Scripting/Runtime/run_analysis.py create mode 100644 gui/src/app/Scripting/Runtime/run_data.py create mode 100644 gui/src/app/Scripting/Runtime/sample.py create mode 100644 gui/test/app/Scripting/makePyRuntime.test.ts diff --git a/gui/src/app/Scripting/Runtime/cmdstan.py b/gui/src/app/Scripting/Runtime/cmdstan.py new file mode 100644 index 0000000..4d23c58 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/cmdstan.py @@ -0,0 +1,7 @@ +try: + cmdstanpy.cmdstan_path() +except Exception: + if args.install_cmdstan: + cmdstanpy.install_cmdstan() + else: + raise ValueError("cmdstan not found, use --install-cmdstan to install") diff --git a/gui/src/app/Scripting/Runtime/load_args.py b/gui/src/app/Scripting/Runtime/load_args.py new file mode 100644 index 0000000..bde5d8c --- /dev/null +++ b/gui/src/app/Scripting/Runtime/load_args.py @@ -0,0 +1,25 @@ +def rename_sampling_options(k): + """ + convert between names used in + Stan-Playground and CmdStanPy + """ + + if k == "init_radius": + return "inits" + if k == "num_warmup": + return "iter_warmup" + if k == "num_samples": + return 
"iter_sampling" + if k == "num_chains": + return "chains" + + raise ValueError(f"Unknown sampling option: {k}") + + +if os.path.isfile(os.path.join(HERE, "sampling_opts.json")): + print("loading sampling_opts.json") + with open(os.path.join(HERE, "sampling_opts.json")) as f: + s = json.load(f) + sampling_opts = {rename_sampling_options(k): v for k, v in s.items()} +else: + sampling_opts = {} diff --git a/gui/src/app/Scripting/Runtime/makePyRuntime.ts b/gui/src/app/Scripting/Runtime/makePyRuntime.ts new file mode 100644 index 0000000..1faf81c --- /dev/null +++ b/gui/src/app/Scripting/Runtime/makePyRuntime.ts @@ -0,0 +1,65 @@ +import { ProjectDataModel } from "@SpCore/ProjectDataModel"; + +import spPreamble from "./preamble.py?raw"; +import spRunData from "./run_data.py?raw"; +import spLoadConfig from "./load_args.py?raw"; +import spCmdStan from "./cmdstan.py?raw"; +import spRunSampling from "./sample.py?raw"; +import spDrawsScript from "../pyodide/sp_load_draws.py?raw"; +import spRunAnalysis from "./run_analysis.py?raw"; + +const indent = (s: string) => { + return s + .trim() + .split("\n") + .map((x) => " " + x) + .join("\n"); +}; + +const makePyRuntimeScript = (project: ProjectDataModel) => { + const hasDataJson = project.dataFileContent.length > 0; + const hasDataPy = project.dataPyFileContent.length > 0; + const hasAnalysisPy = project.analysisPyFileContent.length > 0; + + let script = `TITLE=${JSON.stringify(project.meta.title)}\n` + spPreamble; + + // arguments + script += `argparser.add_argument("--install-cmdstan", action="store_true", help="Install cmdstan if it is missing")\n`; + if (hasDataJson && hasDataPy) { + script += `argparser.add_argument("--ignore-saved-data", action="store_true", help="Ignore saved data.json files")\n`; + } + script += `args, _ = argparser.parse_known_args()\n\n`; + + // data + if (hasDataJson && hasDataPy) { + script += `if args.ignore_saved_data:\n`; + script += indent(spRunData); + script += `\nelse:\n`; + script += ` 
print("Loading data from data.json, pass --ignore-saved-data to run data.py instead")\n`; + script += ` data = os.path.join(HERE, 'data.json')\n\n`; + } else if (hasDataJson) { + script += `data = os.path.join(HERE, 'data.json')\n\n`; + } else if (hasDataPy) { + script += spRunData; + script += `\n`; + } + + // running sampler + script += spLoadConfig; + script += `\n`; + script += spCmdStan; + script += `\n`; + script += spRunSampling; + + // analysis + if (hasAnalysisPy) { + script += `\n`; + script += spDrawsScript; + script += `\n`; + script += spRunAnalysis; + } + + return script; +}; + +export default makePyRuntimeScript; diff --git a/gui/src/app/Scripting/Runtime/preamble.py b/gui/src/app/Scripting/Runtime/preamble.py new file mode 100644 index 0000000..3eaa8fd --- /dev/null +++ b/gui/src/app/Scripting/Runtime/preamble.py @@ -0,0 +1,9 @@ +import argparse +import json +import os + +import cmdstanpy + +HERE = os.path.dirname(os.path.abspath(__file__)) + +argparser = argparse.ArgumentParser(prog=f"Stan-Playground: {TITLE}") diff --git a/gui/src/app/Scripting/Runtime/run_analysis.py b/gui/src/app/Scripting/Runtime/run_analysis.py new file mode 100644 index 0000000..757d6d2 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/run_analysis.py @@ -0,0 +1,18 @@ +import matplotlib.pyplot as plt + +print("executing analysis.py") + +sp_data = { + "draws": fit.draws(concat_chains=True), + "paramNames": fit.metadata.cmdstan_config["raw_header"].split(","), + "numChains": fit.chains, +} + +draws = sp_load_draws(sp_data) +del sp_data + +with open(os.path.join(HERE, "analysis.py")) as f: + exec(f.read()) + +if len(plt.gcf().get_children()) > 1: + plt.show() diff --git a/gui/src/app/Scripting/Runtime/run_data.py b/gui/src/app/Scripting/Runtime/run_data.py new file mode 100644 index 0000000..e6e65a6 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/run_data.py @@ -0,0 +1,5 @@ +print("executing data.py") +with open(os.path.join(HERE, "data.py")) as f: + exec(f.read()) +if "data" 
not in locals(): + raise ValueError("data variable not defined in data.py") diff --git a/gui/src/app/Scripting/Runtime/sample.py b/gui/src/app/Scripting/Runtime/sample.py new file mode 100644 index 0000000..54d23f0 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/sample.py @@ -0,0 +1,7 @@ +print("compiling model") +model = cmdstanpy.CmdStanModel(stan_file=os.path.join(HERE, "main.stan")) + +print("sampling") +fit = model.sample(data=data, **sampling_opts) + +print(fit.summary()) diff --git a/gui/test/app/Scripting/makePyRuntime.test.ts b/gui/test/app/Scripting/makePyRuntime.test.ts new file mode 100644 index 0000000..ec0178a --- /dev/null +++ b/gui/test/app/Scripting/makePyRuntime.test.ts @@ -0,0 +1,224 @@ +import { + initialDataModel, + ProjectDataModel, + ProjectKnownFiles, +} from "@SpCore/ProjectDataModel"; +import makePyRuntimeScript from "@SpScripting/Runtime/makePyRuntime"; +import { describe, expect, test } from "vitest"; + +const testDataModel: ProjectDataModel = structuredClone(initialDataModel); +Object.values(ProjectKnownFiles).forEach((f) => { + testDataModel[f] = JSON.stringify(f); +}); +testDataModel.meta.title = "my title"; + +const full = `TITLE="my title" +import argparse +import json +import os + +import cmdstanpy + +HERE = os.path.dirname(os.path.abspath(__file__)) + +argparser = argparse.ArgumentParser(prog=f"Stan-Playground: {TITLE}") +argparser.add_argument("--install-cmdstan", action="store_true", help="Install cmdstan if it is missing") +argparser.add_argument("--ignore-saved-data", action="store_true", help="Ignore saved data.json files") +args, _ = argparser.parse_known_args() + +if args.ignore_saved_data: + print("executing data.py") + with open(os.path.join(HERE, "data.py")) as f: + exec(f.read()) + if "data" not in locals(): + raise ValueError("data variable not defined in data.py") +else: + print("Loading data from data.json, pass --ignore-saved-data to run data.py instead") + data = os.path.join(HERE, 'data.json') + +def 
rename_sampling_options(k): + """ + convert between names used in + Stan-Playground and CmdStanPy + """ + + if k == "init_radius": + return "inits" + if k == "num_warmup": + return "iter_warmup" + if k == "num_samples": + return "iter_sampling" + if k == "num_chains": + return "chains" + + raise ValueError(f"Unknown sampling option: {k}") + + +if os.path.isfile(os.path.join(HERE, "sampling_opts.json")): + print("loading sampling_opts.json") + with open(os.path.join(HERE, "sampling_opts.json")) as f: + s = json.load(f) + sampling_opts = {rename_sampling_options(k): v for k, v in s.items()} +else: + sampling_opts = {} + +try: + cmdstanpy.cmdstan_path() +except Exception: + if args.install_cmdstan: + cmdstanpy.install_cmdstan() + else: + raise ValueError("cmdstan not found, use --install-cmdstan to install") + +print("compiling model") +model = cmdstanpy.CmdStanModel(stan_file=os.path.join(HERE, "main.stan")) + +print("sampling") +fit = model.sample(data=data, **sampling_opts) + +print(fit.summary()) + +# Used in pyodideWorker for running analysis.py + +from typing import TYPE_CHECKING, List, TypedDict + +import numpy as np +import pandas as pd +import stanio + +# We don't import this unconditionaly because +# we only install it when the user's script needs it +if TYPE_CHECKING: + from arviz import InferenceData + + +class SpData(TypedDict): + draws: List[List[float]] + paramNames: List[str] + numChains: int + + +class DrawsObject: + def __init__(self, sp_data: SpData): + + self._all_parameter_names: List[str] = sp_data["paramNames"] + + self._params = stanio.parse_header(",".join(self._all_parameter_names)) + + self._num_chains: int = sp_data["numChains"] + + # draws come in as num_params by (num_chains * num_draws) + self._draws = ( + np.array(sp_data["draws"]) + .transpose() + .reshape(self._num_chains, -1, len(self._all_parameter_names)) + ) + + def __repr__(self) -> str: + return f"""SpDraws with {self._num_chains} chains, {self._draws.shape[1]} draws, and 
{self._draws.shape[2]} parameters. + Methods: + - as_dataframe(): return a pandas DataFrame of the draws. + - as_numpy(): return a numpy array indexed by (chain, draw, parameter) + - as_arviz(): return an arviz InferenceData object + - get(pname: str): return a numpy array of the parameter values for the given parameter name""" + + def as_dataframe(self) -> pd.DataFrame: + # The first column is the chain id + # The second column is the draw number + # The remaining columns are the parameter values + + (num_chains, num_draws, num_params) = self._draws.shape + flattened = self._draws.reshape(-1, num_params) + + chain_ids = np.repeat(np.arange(1, num_chains + 1), num_draws) + draw_numbers = np.tile(np.arange(1, num_draws + 1), num_chains) + + data = np.column_stack((chain_ids, draw_numbers, flattened)) + + df = pd.DataFrame(data, columns=["chain", "draw"] + self._all_parameter_names) + return df + + def as_numpy(self) -> np.ndarray: + return np.array(self._draws) + + def get(self, pname: str) -> np.ndarray: + if pname not in self._params: + raise ValueError(f"Parameter {pname} not found") + return self._params[pname].extract_reshape(self._draws) + + def to_arviz(self) -> "InferenceData": + import arviz as az + + return az.from_dict( + posterior={pname: self.get(pname) for pname in self.parameter_names}, + ) + + @property + def parameter_names(self) -> List[str]: + return list(self._params.keys()) + + @property + def raw_parameter_names(self) -> List[str]: + return list(self._all_parameter_names) + + +def sp_load_draws(sp_data: SpData) -> DrawsObject: + return DrawsObject(sp_data) + +import matplotlib.pyplot as plt + +print("executing analysis.py") + +sp_data = { + "draws": fit.draws(concat_chains=True), + "paramNames": fit.metadata.cmdstan_config["raw_header"].split(","), + "numChains": fit.chains, +} + +draws = sp_load_draws(sp_data) +del sp_data + +with open(os.path.join(HERE, "analysis.py")) as f: + exec(f.read()) + +if len(plt.gcf().get_children()) > 1: + 
plt.show() +`; + +describe("Python runtime", () => { + // these serve as "golden" tests, just to make sure the output is as expected + + test("Export full", () => { + const runPy = makePyRuntimeScript(testDataModel); + expect(runPy).toEqual(full); + }); + + test("Export without data", () => { + const noData = { + ...testDataModel, + dataFileContent: "", + dataPyFileContent: "", + }; + const runPy = makePyRuntimeScript(noData); + + // we expect the same output minus the data loading part + const lines = full.split("\n"); + const dataless = + lines.slice(0, 11).join("\n") + + "\n" + + lines[12] + + "\n" + + lines.slice(23).join("\n"); + expect(runPy).toEqual(dataless); + }); + + test("Export without analysis", () => { + const noAnalysis = { ...testDataModel, analysisPyFileContent: "" }; + const runPy = makePyRuntimeScript(noAnalysis); + + // we expect the same output, truncated after the sampling part + const analysisless = full.split("\n").slice(0, 65).join("\n") + "\n"; + + expect(runPy).toEqual(analysisless); + }); +}); From 73c8030f6b261d6a473fc40ac250be7c9cc09ac0 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 28 Oct 2024 20:55:28 +0000 Subject: [PATCH 2/4] Bad initial download UI for testing --- gui/src/app/Project/ProjectSerialization.ts | 22 +-------------- .../app/pages/HomePage/SaveProjectWindow.tsx | 28 +++++++++++++++++-- gui/src/app/util/serializeAsZip.ts | 21 ++++++++++++++ 3 files changed, 47 insertions(+), 24 deletions(-) create mode 100644 gui/src/app/util/serializeAsZip.ts diff --git a/gui/src/app/Project/ProjectSerialization.ts b/gui/src/app/Project/ProjectSerialization.ts index c9cfd10..2169b53 100644 --- a/gui/src/app/Project/ProjectSerialization.ts +++ b/gui/src/app/Project/ProjectSerialization.ts @@ -49,26 +49,6 @@ export const deserializeProjectFromLocalStorage = ( } }; -export const serializeAsZip = async ( - data: ProjectDataModel, -): Promise<[Blob, string]> => { - const fileManifest = mapModelToFileManifest(data); - const folderName = 
replaceSpacesWithUnderscores(data.meta.title); - const zip = new JSZip(); - const folder = zip.folder(folderName); - if (!folder) { - throw new Error("Error creating folder in zip file"); - } - Object.entries(fileManifest).forEach(([name, content]) => { - if (content.trim() !== "") { - folder.file(name, content); - } - }); - const zipBlob = await zip.generateAsync({ type: "blob" }); - - return [zipBlob, folderName]; -}; - export const parseFile = (fileBuffer: ArrayBuffer) => { const content = new TextDecoder().decode(fileBuffer); return content; @@ -98,7 +78,7 @@ export const deserializeZipToFiles = async (zipBuffer: ArrayBuffer) => { const content = await file.async("arraybuffer"); const decoded = new TextDecoder().decode(content); files[basename] = decoded; - } else { + } else if (!["run.R", "run.py"].includes(basename)) { throw new Error( `Unrecognized file in zip: ${file.name} (basename ${basename})`, ); diff --git a/gui/src/app/pages/HomePage/SaveProjectWindow.tsx b/gui/src/app/pages/HomePage/SaveProjectWindow.tsx index 24ad937..c50ef53 100644 --- a/gui/src/app/pages/HomePage/SaveProjectWindow.tsx +++ b/gui/src/app/pages/HomePage/SaveProjectWindow.tsx @@ -7,16 +7,18 @@ import { AlternatingTableRow } from "@SpComponents/StyledTables"; import { FileRegistry, mapModelToFileManifest } from "@SpCore/FileMapping"; import { ProjectContext } from "@SpCore/ProjectContextProvider"; import { triggerDownload } from "@SpUtil/triggerDownload"; -import Button from "@mui/material/Button"; +import makePyRuntimeScript from "@SpScripting/Runtime/makePyRuntime"; import loadFilesFromGist from "@SpCore/gists/loadFilesFromGist"; -import { serializeAsZip } from "@SpCore/ProjectSerialization"; import saveAsGitHubGist, { createPatchForUpdatingGist, updateGitHubGist, } from "@SpCore/gists/saveAsGitHubGist"; +import Button from "@mui/material/Button"; import TextField from "@mui/material/TextField"; import TableRow from "@mui/material/TableRow"; import Link from "@mui/material/Link"; 
+import { replaceSpacesWithUnderscores } from "@SpUtil/replaceSpaces"; +import { serializeAsZip } from "@SpUtil/serializeAsZip"; type SaveProjectWindowProps = { onClose: () => void; @@ -31,6 +33,8 @@ const SaveProjectWindow: FunctionComponent = ({ const [exportingToGist, setExportingToGist] = useState(false); const [updatingExistingGist, setUpdatingExistingGist] = useState(false); + const [includeRunPy, setIncludeRunPy] = useState(false); + return (
@@ -64,6 +68,18 @@ const SaveProjectWindow: FunctionComponent = ({ ), )} + + + Include a run.py file for use with CmdStanPy? + + + setIncludeRunPy(e.target.checked)} + /> + + @@ -72,7 +88,13 @@ const SaveProjectWindow: FunctionComponent = ({