Skip to content

Commit

Permalink
Apply black formatter
Browse files Browse the repository at this point in the history
Apply black formatter to python files touched in the rebase. Took the
simple approach of just accepting the `dev` changes always, which were
not linted.
  • Loading branch information
TimothyWillard committed Nov 4, 2024
1 parent 9a2d0eb commit 092069a
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 46 deletions.
4 changes: 1 addition & 3 deletions examples/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def test_sample_2pop_modifiers_combined_deprecated():
def test_simple_usa_statelevel_deprecated():
os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel")
runner = CliRunner()
result = runner.invoke(
_click_simulate, ["-n", "1", "-c", "simple_usa_statelevel.yml"]
)
result = runner.invoke(_click_simulate, ["-n", "1", "-c", "simple_usa_statelevel.yml"])
print(result.output) # useful for debug
print(result.exit_code) # useful for debug
print(result.exception) # useful for debug
Expand Down
3 changes: 2 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
# add some basic commands to the CLI
@cli.command(params=[config_files_argument] + list(config_file_options.values()))
@pass_context
def patch(ctx : Context = mock_context, **kwargs) -> None:
def patch(ctx: Context = mock_context, **kwargs) -> None:
"""Merge configuration files"""
parse_config_files(config, ctx, **kwargs)
print(config.dump())


if __name__ == "__main__":
cli()
17 changes: 12 additions & 5 deletions flepimop/gempyor_pkg/src/gempyor/compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,20 +892,24 @@ def list_recursive_convert_to_string(thing):
return [list_recursive_convert_to_string(x) for x in thing]
return str(thing)


@cli.group()
@pass_context
def compartments(ctx: Context):
"""Commands for working with FlepiMoP compartments"""
pass


@compartments.command(params=[config_files_argument] + list(config_file_options.values()))
@pass_context
def plot(ctx : Context, **kwargs):
def plot(ctx: Context, **kwargs):
"""Plot compartments"""
parse_config_files(config, ctx, **kwargs)
assert config["compartments"].exists()
assert config["seir"].exists()
comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"])
comp = Compartments(
seir_config=config["seir"], compartments_config=config["compartments"]
)

# TODO: this should be a command like build compartments.
(
Expand All @@ -920,12 +924,14 @@ def plot(ctx : Context, **kwargs):

@compartments.command(params=[config_files_argument] + list(config_file_options.values()))
@pass_context
def export(ctx : Context, **kwargs):
def export(ctx: Context, **kwargs):
"""Export compartments"""
parse_config_files(config, ctx, **kwargs)
assert config["compartments"].exists()
assert config["seir"].exists()
comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"])
comp = Compartments(
seir_config=config["seir"], compartments_config=config["compartments"]
)
(
unique_strings,
transition_array,
Expand All @@ -935,4 +941,5 @@ def export(ctx : Context, **kwargs):
comp.toFile("compartments_file.csv", "transitions_file.csv", write_parquet=False)
print("wrote files 'compartments_file.csv', 'transitions_file.csv' ")

cli.add_command(compartments)

cli.add_command(compartments)
5 changes: 4 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from functools import partial
from gempyor import compartments


def read_yaml(file_path: str) -> dict:
with open(file_path, "r") as stream:
config = yaml.safe_load(stream)
Expand All @@ -23,11 +24,13 @@ def allowed_values(v, values):
assert v in values
return v


# def parse_value(cls, values):
# value = values.get('value')
# parsed_val = compartments.Compartments.parse_parameter_strings_to_numpy_arrays_v2(value)
# return parsed_val



class SubpopSetupConfig(BaseModel):
geodata: str
mobility: Optional[str]
Expand Down
15 changes: 5 additions & 10 deletions flepimop/gempyor_pkg/src/gempyor/shared_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def cli(ctx: click.Context) -> None:
"""Flexible Epidemic Modeling Platform (FlepiMoP) Command Line Interface"""
pass


# click decorator to handle configuration file(s) as arguments
# use as `@argument_config_files` before a cli command definition
config_files_argument = click.Argument(
Expand Down Expand Up @@ -162,9 +163,7 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
additional_doc = "\n\tCommand Line Interface arguments:\n"
for param in params:
paraminfo = param.to_info_dict()
additional_doc += (
f"\n\t{paraminfo['name']}: {paraminfo['type']['param_type']}"
)
additional_doc += f"\n\t{paraminfo['name']}: {paraminfo['type']['param_type']}"

if func.__doc__ is None:
func.__doc__ = ""
Expand Down Expand Up @@ -203,9 +202,7 @@ def parse_config_files(

def _parse_option(param: click.Parameter, value: Any) -> Any:
"""internal parser to autobox values"""
if (param.multiple or param.nargs == -1) and not isinstance(
value, (list, tuple)
):
if (param.multiple or param.nargs == -1) and not isinstance(value, (list, tuple)):
value = [value]
return param.type_cast_value(ctx, value)

Expand Down Expand Up @@ -244,9 +241,7 @@ def _parse_option(param: click.Parameter, value: Any) -> Any:
tmp = confuse.Configuration("tmp")
tmp.set_file(config_file)
if intersect := set(tmp.keys()) & set(cfg.keys()):
warnings.warn(
f"Configuration files contain overlapping keys: {intersect}."
)
warnings.warn(f"Configuration files contain overlapping keys: {intersect}.")
cfg.set_file(config_file)
cfg["config_src"] = [str(k) for k in config_src]

Expand All @@ -268,5 +263,5 @@ def _parse_option(param: click.Parameter, value: Any) -> Any:
if (value := kwargs.get(option)) is not None:
# auto box the value if the option expects a multiple
cfg[option] = _parse_option(config_file_options[option], value)

return cfg
58 changes: 45 additions & 13 deletions flepimop/gempyor_pkg/src/gempyor/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,18 @@
from click import Context, pass_context

from . import seir, outcomes, model_info, utils
from .shared_cli import config_files_argument, config_file_options, parse_config_files, cli, click_helpstring, mock_context
from .shared_cli import (
config_files_argument,
config_file_options,
parse_config_files,
cli,
click_helpstring,
mock_context,
)

# from .profile import profile_options


# @profile_options
# @profile()
def simulate(
Expand All @@ -184,7 +192,7 @@ def simulate(
write_parquet: bool = True,
first_sim_index: int = 1,
stoch_traj_flag: bool = False,
verbose : bool = True,
verbose: bool = True,
) -> int:
"""
Forward simulate a model using gempyor.
Expand Down Expand Up @@ -216,14 +224,26 @@ def simulate(
cfg = config_filepath

scenarios_combinations = [
[s, d] for s in (cfg["seir_modifiers"]["scenarios"].as_str_seq() if cfg["seir_modifiers"].exists() else [None])
for d in (cfg["outcome_modifiers"]["scenarios"].as_str_seq() if cfg["outcome_modifiers"].exists() else [None])]

[s, d]
for s in (
cfg["seir_modifiers"]["scenarios"].as_str_seq()
if cfg["seir_modifiers"].exists()
else [None]
)
for d in (
cfg["outcome_modifiers"]["scenarios"].as_str_seq()
if cfg["outcome_modifiers"].exists()
else [None]
)
]

if verbose:
print("Combination of modifiers scenarios to be run: ")
print(scenarios_combinations)
for seir_modifiers_scenario, outcome_modifiers_scenario in scenarios_combinations:
print(f"seir_modifier: {seir_modifiers_scenario}, outcomes_modifier: {outcome_modifiers_scenario}")
print(
f"seir_modifier: {seir_modifiers_scenario}, outcomes_modifier: {outcome_modifiers_scenario}"
)

nchains = cfg["nslots"].as_number()

Expand Down Expand Up @@ -265,17 +285,25 @@ def simulate(
if cfg["seir"].exists():
seir.run_parallel_SEIR(modinf, config=cfg, n_jobs=cfg["jobs"].get(int))
if cfg["outcomes"].exists():
outcomes.run_parallel_outcomes(sim_id2write=cfg["first_sim_index"].get(int), modinf=modinf, nslots=nchains, n_jobs=cfg["jobs"].get(int))
outcomes.run_parallel_outcomes(
sim_id2write=cfg["first_sim_index"].get(int),
modinf=modinf,
nslots=nchains,
n_jobs=cfg["jobs"].get(int),
)
if verbose:
print(
f">>> {seir_modifiers_scenario}_{outcome_modifiers_scenario} completed in {time.monotonic() - start:.1f} seconds"
)

return 0

@cli.command(name="simulate", params=[config_files_argument] + list(config_file_options.values()))

@cli.command(
name="simulate", params=[config_files_argument] + list(config_file_options.values())
)
@pass_context
def _click_simulate(ctx : Context, **kwargs) -> int:
def _click_simulate(ctx: Context, **kwargs) -> int:
"""Forward simulate a model using gempyor."""
cfg = parse_config_files(utils.config, ctx, **kwargs)
return simulate(cfg)
Expand All @@ -285,16 +313,20 @@ def _click_simulate(ctx : Context, **kwargs) -> int:

import subprocess

def _deprecated_simulate(argv : list[str] = []) -> int:

def _deprecated_simulate(argv: list[str] = []) -> int:
if not argv:
argv = sys.argv[1:]
clickcmd = ' '.join(['flepimop', 'simulate'] + argv)
warnings.warn(f"This command is deprecated, use the CLI instead: `{clickcmd}`", DeprecationWarning)
clickcmd = " ".join(["flepimop", "simulate"] + argv)
warnings.warn(
f"This command is deprecated, use the CLI instead: `{clickcmd}`", DeprecationWarning
)
return subprocess.run(clickcmd, shell=True).returncode


if __name__ == "__main__":
argv = sys.argv[1:]
clickcmd = ' '.join(['flepimop', 'simulate'] + argv)
clickcmd = " ".join(["flepimop", "simulate"] + argv)
warnings.warn(f"Use the CLI instead: `{clickcmd}`", DeprecationWarning)
_deprecated_simulate(argv)

Expand Down
7 changes: 4 additions & 3 deletions flepimop/gempyor_pkg/src/gempyor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def mock_empty_config() -> confuse.Configuration:
"""
return confuse.Configuration("flepiMoPMock", read=False)


def create_confuse_config_from_file(
data_file: Path,
) -> confuse.Configuration:
Expand All @@ -74,6 +75,7 @@ def create_confuse_config_from_file(
cv.set_file(data_file)
return cv


def create_confuse_configview_from_dict(
data: dict[str, Any], name: None | str = None
) -> confuse.ConfigView:
Expand Down Expand Up @@ -133,9 +135,8 @@ def create_confuse_configview_from_dict(
cv = cv[name] if name is not None else cv
return cv

def create_confuse_config_from_dict(
data: dict[str, Any]
) -> confuse.Configuration:

def create_confuse_config_from_dict(data: dict[str, Any]) -> confuse.Configuration:
"""
Create a Configuration from a dictionary for unit testing confuse parameters.
Expand Down
24 changes: 16 additions & 8 deletions flepimop/gempyor_pkg/tests/shared_cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_config_sample_2pop():
result = runner.invoke(_click_simulate, ["config_sample_2pop.yml"])
assert result.exit_code == 0


def test_config_sample_2pop_deprecated():
os.chdir(tutorialpath)
runner = CliRunner()
Expand All @@ -42,29 +43,36 @@ def test_sample_2pop_modifiers():
assert result.exit_code == 0


def test_sample_2pop_modifiers_combined(tmp_path : Path):
def test_sample_2pop_modifiers_combined(tmp_path: Path):
os.chdir(tutorialpath)
tmp_cfg1 = tmp_path / "patch_modifiers.yml"
tmp_cfg2 = tmp_path / "nopatch_modifiers.yml"
runner = CliRunner()

result = runner.invoke(patch, ["config_sample_2pop.yml",

result = runner.invoke(
patch,
[
"config_sample_2pop.yml",
"config_sample_2pop_outcomes_part.yml",
"config_sample_2pop_modifiers_part.yml"])
"config_sample_2pop_modifiers_part.yml",
],
)
assert result.exit_code == 0
with open(tmp_cfg1, "w") as f:
f.write(result.output)

result = runner.invoke(patch, ["config_sample_2pop_modifiers.yml"])
assert result.exit_code == 0
with open(tmp_cfg2, "w") as f:
f.write(result.output)



tmpconfig1 = create_confuse_config_from_file(str(tmp_cfg1)).flatten()
tmpconfig2 = create_confuse_config_from_file(str(tmp_cfg2)).flatten()

assert { k: v for k, v in tmpconfig1.items() if k != "config_src" } == { k: v for k, v in tmpconfig2.items() if k != "config_src" }
assert {k: v for k, v in tmpconfig1.items() if k != "config_src"} == {
k: v for k, v in tmpconfig2.items() if k != "config_src"
}


def test_simple_usa_statelevel_more_deprecated():
os.chdir(tutorialpath + "/../simple_usa_statelevel")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import pathlib
from typing import Any

Expand Down Expand Up @@ -159,7 +158,7 @@ def test_multifile_config_collision(
tmpconfigfile1 = config_file(tmp_path, testdict1, "config1.yaml")
tmpconfigfile2 = config_file(tmp_path, testdict2, "config2.yaml")
mockconfig = mock_empty_config()
with pytest.warns(UserWarning, match=r'foo'):
with pytest.warns(UserWarning, match=r"foo"):
parse_config_files(mockconfig, config_files=[tmpconfigfile1, tmpconfigfile2])
for k, v in (testdict1 | testdict2).items():
assert mockconfig[k].get(v) == v
Expand Down

0 comments on commit 092069a

Please sign in to comment.