diff --git a/gui/src/app/Project/ProjectSerialization.ts b/gui/src/app/Project/ProjectSerialization.ts index c9cfd10..d5a18a2 100644 --- a/gui/src/app/Project/ProjectSerialization.ts +++ b/gui/src/app/Project/ProjectSerialization.ts @@ -4,7 +4,6 @@ import { FileRegistry, ProjectFileMap, mapFileContentsToModel, - mapModelToFileManifest, } from "@SpCore/FileMapping"; import { ProjectDataModel, @@ -16,7 +15,6 @@ import { parseSamplingOpts, persistStateToEphemera, } from "@SpCore/ProjectDataModel"; -import { replaceSpacesWithUnderscores } from "@SpUtil/replaceSpaces"; import JSZip from "jszip"; export const serializeProjectToLocalStorage = ( @@ -49,26 +47,6 @@ export const deserializeProjectFromLocalStorage = ( } }; -export const serializeAsZip = async ( - data: ProjectDataModel, -): Promise<[Blob, string]> => { - const fileManifest = mapModelToFileManifest(data); - const folderName = replaceSpacesWithUnderscores(data.meta.title); - const zip = new JSZip(); - const folder = zip.folder(folderName); - if (!folder) { - throw new Error("Error creating folder in zip file"); - } - Object.entries(fileManifest).forEach(([name, content]) => { - if (content.trim() !== "") { - folder.file(name, content); - } - }); - const zipBlob = await zip.generateAsync({ type: "blob" }); - - return [zipBlob, folderName]; -}; - export const parseFile = (fileBuffer: ArrayBuffer) => { const content = new TextDecoder().decode(fileBuffer); return content; @@ -98,7 +76,7 @@ export const deserializeZipToFiles = async (zipBuffer: ArrayBuffer) => { const content = await file.async("arraybuffer"); const decoded = new TextDecoder().decode(content); files[basename] = decoded; - } else { + } else if (!["run.R", "run.py"].includes(basename)) { throw new Error( `Unrecognized file in zip: ${file.name} (basename ${basename})`, ); diff --git a/gui/src/app/Scripting/Runtime/cmdstan.py b/gui/src/app/Scripting/Runtime/cmdstan.py new file mode 100644 index 0000000..4d23c58 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/cmdstan.py @@ -0,0 +1,7 @@ +try: + cmdstanpy.cmdstan_path() +except Exception: + if args.install_cmdstan: + cmdstanpy.install_cmdstan() + else: + raise ValueError("cmdstan not found, use --install-cmdstan to install") diff --git a/gui/src/app/Scripting/Runtime/load_args.py b/gui/src/app/Scripting/Runtime/load_args.py new file mode 100644 index 0000000..9d1bb47 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/load_args.py @@ -0,0 +1,27 @@ +def rename_sampling_options(k): + """ + convert between names used in + Stan-Playground and CmdStanPy + """ + + if k == "init_radius": + return "inits" + if k == "num_warmup": + return "iter_warmup" + if k == "num_samples": + return "iter_sampling" + if k == "num_chains": + return "chains" + if k == "seed": + return "seed" + + raise ValueError(f"Unknown sampling option: {k}") + + +if os.path.isfile(os.path.join(HERE, "sampling_opts.json")): + print("loading sampling_opts.json") + with open(os.path.join(HERE, "sampling_opts.json")) as f: + s = json.load(f) + sampling_opts = {rename_sampling_options(k): v for k, v in s.items()} +else: + sampling_opts = {} diff --git a/gui/src/app/Scripting/Runtime/makePyRuntime.ts b/gui/src/app/Scripting/Runtime/makePyRuntime.ts new file mode 100644 index 0000000..a164300 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/makePyRuntime.ts @@ -0,0 +1,65 @@ +import { ProjectDataModel } from "@SpCore/ProjectDataModel"; + +import spPreamble from "./preamble.py?raw"; +import spRunData from "./run_data.py?raw"; +import spLoadConfig from "./load_args.py?raw"; +import spCmdStan from "./cmdstan.py?raw"; +import spRunSampling from "./sample.py?raw"; +import spDrawsScript from "../pyodide/sp_load_draws.py?raw"; +import spRunAnalysis from "./run_analysis.py?raw"; + +const indent = (s: string) => { + return s + .trim() + .split("\n") + .map((x) => " " + x) + .join("\n"); +}; + +const makePyRuntimeScript = (project: ProjectDataModel) => { + const hasDataJson = project.dataFileContent.length > 0; + const hasDataPy = project.dataPyFileContent.length > 0; + const hasAnalysisPy = project.analysisPyFileContent.length > 0; + + let script = `TITLE=${JSON.stringify(project.meta.title)}\n` + spPreamble; + + // arguments + script += `argparser.add_argument("--install-cmdstan", action="store_true", help="Install cmdstan if it is missing")\n`; + if (hasDataJson && hasDataPy) { + script += `argparser.add_argument("--ignore-saved-data", action="store_true", help="Ignore saved data.json files")\n`; + } + script += `args, _ = argparser.parse_known_args()\n\n`; + + // data + if (hasDataJson && hasDataPy) { + script += `if args.ignore_saved_data:\n`; + script += indent(spRunData); + script += `\nelse:\n`; + script += ` print("Loading data from data.json, pass --ignore-saved-data to run data.py instead")\n`; + script += ` data = os.path.join(HERE, 'data.json')\n\n`; + } else if (hasDataJson) { + script += `data = os.path.join(HERE, 'data.json')\n\n`; + } else if (hasDataPy) { + script += spRunData; + script += `\n`; + } + + // running sampler + script += spLoadConfig; + script += `\n`; + script += spCmdStan; + script += `\n`; + script += spRunSampling; + + // analysis + if (hasAnalysisPy) { + script += `\n`; + script += spDrawsScript; + script += `\n`; + script += spRunAnalysis; + } + + return script; +}; + +export default makePyRuntimeScript; diff --git a/gui/src/app/Scripting/Runtime/preamble.py b/gui/src/app/Scripting/Runtime/preamble.py new file mode 100644 index 0000000..3eaa8fd --- /dev/null +++ b/gui/src/app/Scripting/Runtime/preamble.py @@ -0,0 +1,9 @@ +import argparse +import json +import os + +import cmdstanpy + +HERE = os.path.dirname(os.path.abspath(__file__)) + +argparser = argparse.ArgumentParser(prog=f"Stan-Playground: {TITLE}") diff --git a/gui/src/app/Scripting/Runtime/run_analysis.py b/gui/src/app/Scripting/Runtime/run_analysis.py new file mode 100644 index 0000000..3ccfbc0 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/run_analysis.py @@ -0,0 +1,18 @@ +import matplotlib.pyplot as plt + +print("executing analysis.py") + +sp_data = { + "draws": fit.draws(concat_chains=True).transpose(), + "paramNames": fit.metadata.cmdstan_config["raw_header"].split(","), + "numChains": fit.chains, +} + +draws = sp_load_draws(sp_data) +del sp_data + +with open(os.path.join(HERE, "analysis.py")) as f: + exec(f.read()) + +if len(plt.gcf().get_children()) > 1: + plt.show() diff --git a/gui/src/app/Scripting/Runtime/run_data.py b/gui/src/app/Scripting/Runtime/run_data.py new file mode 100644 index 0000000..e6e65a6 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/run_data.py @@ -0,0 +1,5 @@ +print("executing data.py") +with open(os.path.join(HERE, "data.py")) as f: + exec(f.read()) +if "data" not in locals(): + raise ValueError("data variable not defined in data.py") diff --git a/gui/src/app/Scripting/Runtime/sample.py b/gui/src/app/Scripting/Runtime/sample.py new file mode 100644 index 0000000..54d23f0 --- /dev/null +++ b/gui/src/app/Scripting/Runtime/sample.py @@ -0,0 +1,7 @@ +print("compiling model") +model = cmdstanpy.CmdStanModel(stan_file=os.path.join(HERE, "main.stan")) + +print("sampling") +fit = model.sample(data=data, **sampling_opts) + +print(fit.summary()) diff --git a/gui/src/app/pages/HomePage/SaveProjectWindow.tsx b/gui/src/app/pages/HomePage/SaveProjectWindow.tsx index 24ad937..c50ef53 100644 --- a/gui/src/app/pages/HomePage/SaveProjectWindow.tsx +++ b/gui/src/app/pages/HomePage/SaveProjectWindow.tsx @@ -7,16 +7,18 @@ import { AlternatingTableRow } from "@SpComponents/StyledTables"; import { FileRegistry, mapModelToFileManifest } from "@SpCore/FileMapping"; import { ProjectContext } from "@SpCore/ProjectContextProvider"; import { triggerDownload } from "@SpUtil/triggerDownload"; -import Button from "@mui/material/Button"; +import makePyRuntimeScript from "@SpScripting/Runtime/makePyRuntime"; import loadFilesFromGist from "@SpCore/gists/loadFilesFromGist"; -import { serializeAsZip } from "@SpCore/ProjectSerialization"; import saveAsGitHubGist, { createPatchForUpdatingGist, updateGitHubGist, } from "@SpCore/gists/saveAsGitHubGist"; +import Button from "@mui/material/Button"; import TextField from "@mui/material/TextField"; import TableRow from "@mui/material/TableRow"; import Link from "@mui/material/Link"; +import { replaceSpacesWithUnderscores } from "@SpUtil/replaceSpaces"; +import { serializeAsZip } from "@SpUtil/serializeAsZip"; type SaveProjectWindowProps = { onClose: () => void; @@ -31,6 +33,8 @@ const SaveProjectWindow: FunctionComponent = ({ const [exportingToGist, setExportingToGist] = useState(false); const [updatingExistingGist, setUpdatingExistingGist] = useState(false); + const [includeRunPy, setIncludeRunPy] = useState(false); + return (
@@ -64,6 +68,18 @@ const SaveProjectWindow: FunctionComponent = ({ ), )} + + + Include a run.py file for use with CmdStanPy? + + + setIncludeRunPy(e.target.checked)} + /> + + @@ -72,7 +88,13 @@ const SaveProjectWindow: FunctionComponent = ({