Skip to content

Commit

Permalink
Extract format functionality from Model class
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Mar 18, 2024
1 parent b420952 commit 4f28871
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 68 deletions.
3 changes: 2 additions & 1 deletion cmdstanpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _cleanup_tmpdir() -> None:


from ._version import __version__ # noqa
from .compilation import compile_stan_file
from .compilation import compile_stan_file, format_stan_file
from .install_cmdstan import rebuild_cmdstan
from .model import CmdStanModel
from .stanfit import (
Expand Down Expand Up @@ -50,6 +50,7 @@ def _cleanup_tmpdir() -> None:
'set_make_env',
'install_cmdstan',
'compile_stan_file',
'format_stan_file',
'CmdStanMCMC',
'CmdStanMLE',
'CmdStanGQ',
Expand Down
103 changes: 102 additions & 1 deletion cmdstanpy/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
import shutil
import subprocess
from copy import copy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union

from cmdstanpy.utils import get_logger
from cmdstanpy.utils.cmdstan import EXTENSION, cmdstan_path
from cmdstanpy.utils.cmdstan import (
EXTENSION,
cmdstan_path,
cmdstan_version,
cmdstan_version_before,
)
from cmdstanpy.utils.command import do_command
from cmdstanpy.utils.filesystem import SanitizedOrTmpFilePath

Expand Down Expand Up @@ -476,3 +482,98 @@ def compile_stan_file(
f"Failed to compile Stan model '{src}'. " f"Console:\n{console}"
)
return str(exe_target)


def format_stan_file(
stan_file: Union[str, os.PathLike],
*,
overwrite_file: bool = False,
canonicalize: Union[bool, str, Iterable[str]] = False,
max_line_length: int = 78,
backup: bool = True,
stanc_options: Optional[Dict[str, Any]] = None,
) -> None:
"""
Run stanc's auto-formatter on the model code. Either saves directly
back to the file or prints for inspection
:param stan_file: Path to Stan program file.
:param overwrite_file: If True, save the updated code to disk, rather
than printing it. By default False
:param canonicalize: Whether or not the compiler should 'canonicalize'
the Stan model, removing things like deprecated syntax. Default is
False. If True, all canonicalizations are run. If it is a list of
strings, those options are passed to stanc (new in Stan 2.29)
:param max_line_length: Set the wrapping point for the formatter. The
default value is 78, which wraps most lines by the 80th character.
:param backup: If True, create a stanfile.bak backup before
writing to the file. Only disable this if you're sure you have other
copies of the file or are using a version control system like Git.
:param stanc_options: Additional options to pass to the stanc compiler.
"""
stan_file = Path(stan_file).resolve()

if not stan_file.exists():
raise ValueError(f'File does not exist: {stan_file}')

try:
cmd = (
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
# handle include-paths, allow-undefined etc
+ CompilerOptions(stanc_options=stanc_options).compose_stanc(None)
+ [str(stan_file)]
)

if canonicalize:
if cmdstan_version_before(2, 29):
if isinstance(canonicalize, bool):
cmd.append('--print-canonical')
else:
raise ValueError(
"Invalid arguments passed for current CmdStan"
+ " version({})\n".format(
cmdstan_version() or "Unknown"
)
+ "--canonicalize requires 2.29 or higher"
)
else:
if isinstance(canonicalize, str):
cmd.append('--canonicalize=' + canonicalize)
elif isinstance(canonicalize, Iterable):
cmd.append('--canonicalize=' + ','.join(canonicalize))
else:
cmd.append('--print-canonical')

# before 2.29, having both --print-canonical
# and --auto-format printed twice
if not (cmdstan_version_before(2, 29) and canonicalize):
cmd.append('--auto-format')

if not cmdstan_version_before(2, 29):
cmd.append(f'--max-line-length={max_line_length}')
elif max_line_length != 78:
raise ValueError(
"Invalid arguments passed for current CmdStan version"
+ " ({})\n".format(cmdstan_version() or "Unknown")
+ "--max-line-length requires 2.29 or higher"
)

out = subprocess.run(cmd, capture_output=True, text=True, check=True)
if out.stderr:
get_logger().warning(out.stderr)
result = out.stdout
if overwrite_file:
if result:
if backup:
shutil.copyfile(
stan_file,
str(stan_file)
+ '.bak-'
+ datetime.now().strftime("%Y%m%d%H%M%S"),
)
stan_file.write_text(result)
else:
print(result)

except (ValueError, RuntimeError) as e:
raise RuntimeError("Stanc formatting failed") from e
84 changes: 18 additions & 66 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import threading
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from io import StringIO
from multiprocessing import cpu_count
from typing import (
Expand Down Expand Up @@ -57,9 +56,7 @@
from_csv,
)
from cmdstanpy.utils import (
EXTENSION,
cmdstan_path,
cmdstan_version,
cmdstan_version_before,
do_command,
get_logger,
Expand Down Expand Up @@ -320,6 +317,7 @@ def src_info(self) -> Dict[str, Any]:
return {}
return compilation.src_info(str(self.stan_file), self._compiler_options)

# TODO(2.0) remove
def format(
self,
overwrite_file: bool = False,
Expand All @@ -329,6 +327,8 @@ def format(
backup: bool = True,
) -> None:
"""
Deprecated: Use :func:`cmdstanpy.format_stan_file()` instead.
Run stanc's auto-formatter on the model code. Either saves directly
back to the file or prints for inspection
Expand All @@ -345,72 +345,24 @@ def format(
writing to the file. Only disable this if you're sure you have other
copies of the file or are using a version control system like Git.
"""
if self.stan_file is None or not os.path.isfile(self.stan_file):
raise ValueError("No Stan file found for this module")
try:
cmd = (
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
# handle include-paths, allow-undefined etc
+ self._compiler_options.compose_stanc(None)
+ [str(self.stan_file)]
)

if canonicalize:
if cmdstan_version_before(2, 29):
if isinstance(canonicalize, bool):
cmd.append('--print-canonical')
else:
raise ValueError(
"Invalid arguments passed for current CmdStan"
+ " version({})\n".format(
cmdstan_version() or "Unknown"
)
+ "--canonicalize requires 2.29 or higher"
)
else:
if isinstance(canonicalize, str):
cmd.append('--canonicalize=' + canonicalize)
elif isinstance(canonicalize, Iterable):
cmd.append('--canonicalize=' + ','.join(canonicalize))
else:
cmd.append('--print-canonical')

# before 2.29, having both --print-canonical
# and --auto-format printed twice
if not (cmdstan_version_before(2, 29) and canonicalize):
cmd.append('--auto-format')

if not cmdstan_version_before(2, 29):
cmd.append(f'--max-line-length={max_line_length}')
elif max_line_length != 78:
raise ValueError(
"Invalid arguments passed for current CmdStan version"
+ " ({})\n".format(cmdstan_version() or "Unknown")
+ "--max-line-length requires 2.29 or higher"
)
get_logger().warning(
"CmdStanModel.format() is deprecated and will be "
"removed in the next major version.\n"
"Use cmdstanpy.format_stan_file() instead."
)

out = subprocess.run(
cmd, capture_output=True, text=True, check=True
)
if out.stderr:
get_logger().warning(out.stderr)
result = out.stdout
if overwrite_file:
if result:
if backup:
shutil.copyfile(
self.stan_file,
str(self.stan_file)
+ '.bak-'
+ datetime.now().strftime("%Y%m%d%H%M%S"),
)
with open(self.stan_file, 'w') as file_handle:
file_handle.write(result)
else:
print(result)
if self.stan_file is None:
raise ValueError("No Stan file found for this module")

except (ValueError, RuntimeError) as e:
raise RuntimeError("Stanc formatting failed") from e
compilation.format_stan_file(
self.stan_file,
overwrite_file=overwrite_file,
max_line_length=max_line_length,
canonicalize=canonicalize,
backup=backup,
stanc_options=self.stanc_options,
)

@property
def stanc_options(self) -> Dict[str, Union[bool, int, str]]:
Expand Down

0 comments on commit 4f28871

Please sign in to comment.