Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the option to download 'takeout scripts' #245

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 1 addition & 23 deletions gui/src/app/Project/ProjectSerialization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
FileRegistry,
ProjectFileMap,
mapFileContentsToModel,
mapModelToFileManifest,
} from "@SpCore/FileMapping";
import {
ProjectDataModel,
Expand All @@ -16,7 +15,6 @@ import {
parseSamplingOpts,
persistStateToEphemera,
} from "@SpCore/ProjectDataModel";
import { replaceSpacesWithUnderscores } from "@SpUtil/replaceSpaces";
import JSZip from "jszip";

export const serializeProjectToLocalStorage = (
Expand Down Expand Up @@ -49,26 +47,6 @@ export const deserializeProjectFromLocalStorage = (
}
};

export const serializeAsZip = async (
data: ProjectDataModel,
): Promise<[Blob, string]> => {
const fileManifest = mapModelToFileManifest(data);
const folderName = replaceSpacesWithUnderscores(data.meta.title);
const zip = new JSZip();
const folder = zip.folder(folderName);
if (!folder) {
throw new Error("Error creating folder in zip file");
}
Object.entries(fileManifest).forEach(([name, content]) => {
if (content.trim() !== "") {
folder.file(name, content);
}
});
const zipBlob = await zip.generateAsync({ type: "blob" });

return [zipBlob, folderName];
};

export const parseFile = (fileBuffer: ArrayBuffer) => {
const content = new TextDecoder().decode(fileBuffer);
return content;
Expand Down Expand Up @@ -98,7 +76,7 @@ export const deserializeZipToFiles = async (zipBuffer: ArrayBuffer) => {
const content = await file.async("arraybuffer");
const decoded = new TextDecoder().decode(content);
files[basename] = decoded;
} else {
} else if (!["run.R", "run.py"].includes(basename)) {
throw new Error(
`Unrecognized file in zip: ${file.name} (basename ${basename})`,
);
Expand Down
7 changes: 7 additions & 0 deletions gui/src/app/Scripting/Runtime/cmdstan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
try:
cmdstanpy.cmdstan_path()
except Exception:
if args.install_cmdstan:
cmdstanpy.install_cmdstan()
else:
raise ValueError("cmdstan not found, use --install-cmdstan to install")
27 changes: 27 additions & 0 deletions gui/src/app/Scripting/Runtime/load_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
def rename_sampling_options(k):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically this only renames one option at a time.

"""
convert between names used in
Stan-Playground and CmdStanPy
"""

if k == "init_radius":
return "inits"
if k == "num_warmup":
return "iter_warmup"
if k == "num_samples":
return "iter_sampling"
if k == "num_chains":
return "chains"
if k == "seed":
return "seed"
Comment on lines +7 to +16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like I want this to be a dictionary, but I think I actually just want Python to be Typescript.


raise ValueError(f"Unknown sampling option: {k}")


if os.path.isfile(os.path.join(HERE, "sampling_opts.json")):
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
print("loading sampling_opts.json")
with open(os.path.join(HERE, "sampling_opts.json")) as f:
s = json.load(f)
sampling_opts = {rename_sampling_options(k): v for k, v in s.items()}
else:
sampling_opts = {}
65 changes: 65 additions & 0 deletions gui/src/app/Scripting/Runtime/makePyRuntime.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { ProjectDataModel } from "@SpCore/ProjectDataModel";

import spPreamble from "./preamble.py?raw";
import spRunData from "./run_data.py?raw";
import spLoadConfig from "./load_args.py?raw";
import spCmdStan from "./cmdstan.py?raw";
import spRunSampling from "./sample.py?raw";
import spDrawsScript from "../pyodide/sp_load_draws.py?raw";
import spRunAnalysis from "./run_analysis.py?raw";

const indent = (s: string) => {
return s
.trim()
.split("\n")
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
.map((x) => " " + x)
.join("\n");
};

const makePyRuntimeScript = (project: ProjectDataModel) => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This kind of code is always going to be pretty ugly to write, but when you're ready to take this out of draft we can chat about whether there's anything to be done with it.

const hasDataJson = project.dataFileContent.length > 0;
const hasDataPy = project.dataPyFileContent.length > 0;
const hasAnalysisPy = project.analysisPyFileContent.length > 0;

let script = `TITLE=${JSON.stringify(project.meta.title)}\n` + spPreamble;

// arguments
script += `argparser.add_argument("--install-cmdstan", action="store_true", help="Install cmdstan if it is missing")\n`;
if (hasDataJson && hasDataPy) {
script += `argparser.add_argument("--ignore-saved-data", action="store_true", help="Ignore saved data.json files")\n`;
}
script += `args, _ = argparser.parse_known_args()\n\n`;

// data
if (hasDataJson && hasDataPy) {
script += `if args.ignore_saved_data:\n`;
script += indent(spRunData);
script += `\nelse:\n`;
script += ` print("Loading data from data.json, pass --ignore-saved-data to run data.py instead")\n`;
script += ` data = os.path.join(HERE, 'data.json')\n\n`;
} else if (hasDataJson) {
script += `data = os.path.join(HERE, 'data.json')\n\n`;
} else if (hasDataPy) {
script += spRunData;
script += `\n`;
}

// running sampler
script += spLoadConfig;
script += `\n`;
script += spCmdStan;
script += `\n`;
script += spRunSampling;

// analysis
if (hasAnalysisPy) {
script += `\n`;
script += spDrawsScript;
script += `\n`;
script += spRunAnalysis;
}

return script;
};

export default makePyRuntimeScript;
9 changes: 9 additions & 0 deletions gui/src/app/Scripting/Runtime/preamble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import argparse
import json
import os

import cmdstanpy

HERE = os.path.dirname(os.path.abspath(__file__))

argparser = argparse.ArgumentParser(prog=f"Stan-Playground: {TITLE}")
18 changes: 18 additions & 0 deletions gui/src/app/Scripting/Runtime/run_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import matplotlib.pyplot as plt

print("executing analysis.py")

sp_data = {
"draws": fit.draws(concat_chains=True).transpose(),
"paramNames": fit.metadata.cmdstan_config["raw_header"].split(","),
"numChains": fit.chains,
}

draws = sp_load_draws(sp_data)
del sp_data

with open(os.path.join(HERE, "analysis.py")) as f:
exec(f.read())

if len(plt.gcf().get_children()) > 1:
plt.show()
5 changes: 5 additions & 0 deletions gui/src/app/Scripting/Runtime/run_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
print("executing data.py")
with open(os.path.join(HERE, "data.py")) as f:
exec(f.read())
if "data" not in locals():
raise ValueError("data variable not defined in data.py")
7 changes: 7 additions & 0 deletions gui/src/app/Scripting/Runtime/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
print("compiling model")
model = cmdstanpy.CmdStanModel(stan_file=os.path.join(HERE, "main.stan"))

print("sampling")
fit = model.sample(data=data, **sampling_opts)

print(fit.summary())
28 changes: 25 additions & 3 deletions gui/src/app/pages/HomePage/SaveProjectWindow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@ import { AlternatingTableRow } from "@SpComponents/StyledTables";
import { FileRegistry, mapModelToFileManifest } from "@SpCore/FileMapping";
import { ProjectContext } from "@SpCore/ProjectContextProvider";
import { triggerDownload } from "@SpUtil/triggerDownload";
import Button from "@mui/material/Button";
import makePyRuntimeScript from "@SpScripting/Runtime/makePyRuntime";
import loadFilesFromGist from "@SpCore/gists/loadFilesFromGist";
import { serializeAsZip } from "@SpCore/ProjectSerialization";
import saveAsGitHubGist, {
createPatchForUpdatingGist,
updateGitHubGist,
} from "@SpCore/gists/saveAsGitHubGist";
import Button from "@mui/material/Button";
import TextField from "@mui/material/TextField";
import TableRow from "@mui/material/TableRow";
import Link from "@mui/material/Link";
import { replaceSpacesWithUnderscores } from "@SpUtil/replaceSpaces";
import { serializeAsZip } from "@SpUtil/serializeAsZip";

type SaveProjectWindowProps = {
onClose: () => void;
Expand All @@ -31,6 +33,8 @@ const SaveProjectWindow: FunctionComponent<SaveProjectWindowProps> = ({
const [exportingToGist, setExportingToGist] = useState(false);
const [updatingExistingGist, setUpdatingExistingGist] = useState(false);

const [includeRunPy, setIncludeRunPy] = useState(false);

return (
<div className="dialogWrapper">
<TableContainer>
Expand Down Expand Up @@ -64,6 +68,18 @@ const SaveProjectWindow: FunctionComponent<SaveProjectWindowProps> = ({
</AlternatingTableRow>
),
)}
<AlternatingTableRow hover>
<TableCell>
Include a run.py file for use with CmdStanPy?
</TableCell>
<TableCell>
<input
type="checkbox"
checked={includeRunPy}
onChange={(e) => setIncludeRunPy(e.target.checked)}
/>
</TableCell>
</AlternatingTableRow>
</TableBody>
</Table>
</TableContainer>
Expand All @@ -72,7 +88,13 @@ const SaveProjectWindow: FunctionComponent<SaveProjectWindowProps> = ({
<div>
<Button
onClick={async () => {
serializeAsZip(data).then(([zipBlob, name]) =>
const fileManifest: { [key: string]: string } =
mapModelToFileManifest(data);
const folderName = replaceSpacesWithUnderscores(data.meta.title);
if (includeRunPy) {
fileManifest["run.py"] = makePyRuntimeScript(data);
}
serializeAsZip(folderName, fileManifest).then(([zipBlob, name]) =>
triggerDownload(zipBlob, `SP-${name}.zip`, onClose),
);
}}
Expand Down
21 changes: 21 additions & 0 deletions gui/src/app/util/serializeAsZip.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import JSZip from "jszip";

export const serializeAsZip = async (
folderName: string,
files: { [key: string]: string } = {},
): Promise<[Blob, string]> => {
const zip = new JSZip();
const folder = zip.folder(folderName);
if (!folder) {
throw new Error("Error creating folder in zip file");
}

Object.entries(files).forEach(([name, content]) => {
if (content.trim() !== "") {
folder.file(name, content);
}
});
const zipBlob = await zip.generateAsync({ type: "blob" });

return [zipBlob, folderName];
};
Loading