diff --git a/cmdstanpy/__init__.py b/cmdstanpy/__init__.py
index 815f1ab1..4e3c134e 100644
--- a/cmdstanpy/__init__.py
+++ b/cmdstanpy/__init__.py
@@ -22,7 +22,7 @@ def _cleanup_tmpdir() -> None:
 
 
 from ._version import __version__  # noqa
-from .compilation import compile_stan_file
+from .compilation import compile_stan_file, format_stan_file
 from .install_cmdstan import rebuild_cmdstan
 from .model import CmdStanModel
 from .stanfit import (
@@ -50,6 +50,7 @@ def _cleanup_tmpdir() -> None:
     'set_make_env',
     'install_cmdstan',
     'compile_stan_file',
+    'format_stan_file',
     'CmdStanMCMC',
     'CmdStanMLE',
     'CmdStanGQ',
diff --git a/cmdstanpy/compilation.py b/cmdstanpy/compilation.py
index e096c9dc..4c21585a 100644
--- a/cmdstanpy/compilation.py
+++ b/cmdstanpy/compilation.py
@@ -9,11 +9,17 @@
 import shutil
 import subprocess
 from copy import copy
+from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Union
 
 from cmdstanpy.utils import get_logger
-from cmdstanpy.utils.cmdstan import EXTENSION, cmdstan_path
+from cmdstanpy.utils.cmdstan import (
+    EXTENSION,
+    cmdstan_path,
+    cmdstan_version,
+    cmdstan_version_before,
+)
 from cmdstanpy.utils.command import do_command
 from cmdstanpy.utils.filesystem import SanitizedOrTmpFilePath
 
@@ -476,3 +482,111 @@ def compile_stan_file(
             f"Failed to compile Stan model '{src}'. " f"Console:\n{console}"
         )
     return str(exe_target)
+
+
+def format_stan_file(
+    stan_file: Union[str, os.PathLike],
+    *,
+    overwrite_file: bool = False,
+    canonicalize: Union[bool, str, Iterable[str]] = False,
+    max_line_length: int = 78,
+    backup: bool = True,
+    stanc_options: Optional[Dict[str, Any]] = None,
+) -> None:
+    """
+    Run stanc's auto-formatter on the model code. Either saves directly
+    back to the file or prints for inspection
+
+    :param stan_file: Path to Stan program file.
+    :param overwrite_file: If True, save the updated code to disk, rather
+        than printing it. By default False
+    :param canonicalize: Whether or not the compiler should 'canonicalize'
+        the Stan model, removing things like deprecated syntax. Default is
+        False. If True, all canonicalizations are run. If it is a list of
+        strings, those options are passed to stanc (new in Stan 2.29)
+    :param max_line_length: Set the wrapping point for the formatter. The
+        default value is 78, which wraps most lines by the 80th character.
+    :param backup: If True, copy the file to ``<name>.bak-<timestamp>`` before
+        writing to the file. Only disable this if you're sure you have other
+        copies of the file or are using a version control system like Git.
+    :param stanc_options: Additional options to pass to the stanc compiler.
+
+    :raises ValueError: If ``stan_file`` does not exist.
+    :raises RuntimeError: If an option is not supported by the installed
+        CmdStan version.
+    :raises subprocess.CalledProcessError: If stanc exits with an error;
+        its captured stderr is logged before the exception propagates.
+    """
+    stan_file = Path(stan_file).resolve()
+
+    if not stan_file.exists():
+        raise ValueError(f'File does not exist: {stan_file}')
+
+    try:
+        cmd = (
+            [os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
+            # handle include-paths, allow-undefined etc
+            + CompilerOptions(stanc_options=stanc_options).compose_stanc(None)
+            + [str(stan_file)]
+        )
+
+        if canonicalize:
+            if cmdstan_version_before(2, 29):
+                if isinstance(canonicalize, bool):
+                    cmd.append('--print-canonical')
+                else:
+                    raise ValueError(
+                        "Invalid arguments passed for current CmdStan"
+                        + " version({})\n".format(
+                            cmdstan_version() or "Unknown"
+                        )
+                        + "--canonicalize requires 2.29 or higher"
+                    )
+            else:
+                if isinstance(canonicalize, str):
+                    cmd.append('--canonicalize=' + canonicalize)
+                elif isinstance(canonicalize, Iterable):
+                    cmd.append('--canonicalize=' + ','.join(canonicalize))
+                else:
+                    cmd.append('--print-canonical')
+
+        # before 2.29, having both --print-canonical
+        # and --auto-format printed twice
+        if not (cmdstan_version_before(2, 29) and canonicalize):
+            cmd.append('--auto-format')
+
+        if not cmdstan_version_before(2, 29):
+            cmd.append(f'--max-line-length={max_line_length}')
+        elif max_line_length != 78:
+            raise ValueError(
+                "Invalid arguments passed for current CmdStan version"
+                + " ({})\n".format(cmdstan_version() or "Unknown")
+                + "--max-line-length requires 2.29 or higher"
+            )
+
+        out = subprocess.run(cmd, capture_output=True, text=True, check=True)
+        if out.stderr:
+            get_logger().warning(out.stderr)
+        result = out.stdout
+        if overwrite_file:
+            if result:
+                if backup:
+                    shutil.copyfile(
+                        stan_file,
+                        str(stan_file)
+                        + '.bak-'
+                        + datetime.now().strftime("%Y%m%d%H%M%S"),
+                    )
+                stan_file.write_text(result)
+        else:
+            print(result)
+
+    except subprocess.CalledProcessError as e:
+        # check=True raises CalledProcessError, which is neither ValueError
+        # nor RuntimeError, so the handler below never sees it. Log the
+        # stderr captured from stanc so its diagnostics are not lost.
+        if e.stderr:
+            get_logger().error(e.stderr)
+        raise
+    except (ValueError, RuntimeError) as e:
+        raise RuntimeError("Stanc formatting failed") from e
diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py
index 17be3b5e..0063d2cd 100644
--- a/cmdstanpy/model.py
+++ b/cmdstanpy/model.py
@@ -10,7 +10,6 @@
 import threading
 from collections import OrderedDict
 from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime
 from io import StringIO
 from multiprocessing import cpu_count
 from typing import (
@@ -57,9 +56,7 @@
     from_csv,
 )
 from cmdstanpy.utils import (
-    EXTENSION,
     cmdstan_path,
-    cmdstan_version,
     cmdstan_version_before,
     do_command,
     get_logger,
@@ -320,6 +317,7 @@ def src_info(self) -> Dict[str, Any]:
             return {}
         return compilation.src_info(str(self.stan_file), self._compiler_options)
 
+    # TODO(2.0) remove
     def format(
         self,
         overwrite_file: bool = False,
@@ -329,6 +327,8 @@ def format(
         backup: bool = True,
     ) -> None:
         """
+        Deprecated: Use :func:`cmdstanpy.format_stan_file()` instead.
+
         Run stanc's auto-formatter on the model code. Either saves directly
         back to the file or prints for inspection
 
@@ -345,72 +345,24 @@ def format(
             writing to the file. Only disable this if you're sure you have other
             copies of the file or are using a version control system like Git.
         """
-        if self.stan_file is None or not os.path.isfile(self.stan_file):
-            raise ValueError("No Stan file found for this module")
-        try:
-            cmd = (
-                [os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
-                # handle include-paths, allow-undefined etc
-                + self._compiler_options.compose_stanc(None)
-                + [str(self.stan_file)]
-            )
 
-            if canonicalize:
-                if cmdstan_version_before(2, 29):
-                    if isinstance(canonicalize, bool):
-                        cmd.append('--print-canonical')
-                    else:
-                        raise ValueError(
-                            "Invalid arguments passed for current CmdStan"
-                            + " version({})\n".format(
-                                cmdstan_version() or "Unknown"
-                            )
-                            + "--canonicalize requires 2.29 or higher"
-                        )
-                else:
-                    if isinstance(canonicalize, str):
-                        cmd.append('--canonicalize=' + canonicalize)
-                    elif isinstance(canonicalize, Iterable):
-                        cmd.append('--canonicalize=' + ','.join(canonicalize))
-                    else:
-                        cmd.append('--print-canonical')
-
-            # before 2.29, having both --print-canonical
-            # and --auto-format printed twice
-            if not (cmdstan_version_before(2, 29) and canonicalize):
-                cmd.append('--auto-format')
-
-            if not cmdstan_version_before(2, 29):
-                cmd.append(f'--max-line-length={max_line_length}')
-            elif max_line_length != 78:
-                raise ValueError(
-                    "Invalid arguments passed for current CmdStan version"
-                    + " ({})\n".format(cmdstan_version() or "Unknown")
-                    + "--max-line-length requires 2.29 or higher"
-                )
+        get_logger().warning(
+            "CmdStanModel.format() is deprecated and will be "
+            "removed in the next major version.\n"
+            "Use cmdstanpy.format_stan_file() instead."
+        )
 
-            out = subprocess.run(
-                cmd, capture_output=True, text=True, check=True
-            )
-            if out.stderr:
-                get_logger().warning(out.stderr)
-            result = out.stdout
-            if overwrite_file:
-                if result:
-                    if backup:
-                        shutil.copyfile(
-                            self.stan_file,
-                            str(self.stan_file)
-                            + '.bak-'
-                            + datetime.now().strftime("%Y%m%d%H%M%S"),
-                        )
-                    with open(self.stan_file, 'w') as file_handle:
-                        file_handle.write(result)
-            else:
-                print(result)
+        if self.stan_file is None:
+            raise ValueError("No Stan file found for this module")
 
-        except (ValueError, RuntimeError) as e:
-            raise RuntimeError("Stanc formatting failed") from e
+        compilation.format_stan_file(
+            self.stan_file,
+            overwrite_file=overwrite_file,
+            max_line_length=max_line_length,
+            canonicalize=canonicalize,
+            backup=backup,
+            stanc_options=self.stanc_options,
+        )
 
     @property
     def stanc_options(self) -> Dict[str, Union[bool, int, str]]: