Skip to content

Commit

Permalink
Call positional args with names
Browse files Browse the repository at this point in the history
* Per @pearsonca, @jcblemai suggestions call positional arguements with
  names.
* Restyle call to `get_seeding_data` in `gempyor.seir.onerun_SEIR`.
  • Loading branch information
TimothyWillard committed Jan 8, 2025
1 parent 692865f commit 085beca
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 29 deletions.
2 changes: 1 addition & 1 deletion flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
out_prefix=prefix,
)

seeding_data = modinf.get_seeding_data(100)
seeding_data = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf)

mobility_subpop_indices = modinf.mobility.indices
Expand Down
4 changes: 2 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_static_arguments(modinf: model_info.ModelInfo):
)

initial_conditions = modinf.initial_conditions.get_from_config(sim_id=0, modinf=modinf)
seeding_data, seeding_amounts = modinf.get_seeding_data(0)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=0)

# reduce them
parameters = modinf.parameters.parameters_reduce(p_draw, npi_seir)
Expand Down Expand Up @@ -673,7 +673,7 @@ def one_simulation(

with Timer("onerun_SEIR.seeding"):
seeding_data, seeding_amounts = self.modinf.get_seeding_data(
sim_id2load if load_ID else sim_id2write
sim_id=sim_id2load if load_ID else sim_id2write
)
if load_ID:
initial_conditions = self.modinf.initial_conditions.get_from_file(
Expand Down
20 changes: 10 additions & 10 deletions flepimop/gempyor_pkg/src/gempyor/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,10 @@ def get_filename(
self, ftype: str, sim_id: int, input: bool, extension_override: str = ""
):
return self.path_prefix / file_paths.create_file_name(
self.in_run_id if input else self.out_run_id,
self.in_prefix if input else self.out_prefix,
sim_id + self.first_sim_index - 1,
ftype,
run_id=self.in_run_id if input else self.out_run_id,
prefix=self.in_prefix if input else self.out_prefix,
index=sim_id + self.first_sim_index - 1,
ftype=ftype,
extension=extension_override if extension_override else self.extension,
inference_filepath_suffix=self.inference_filepath_suffix,
inference_filename_prefix=self.inference_filename_prefix,
Expand Down Expand Up @@ -368,12 +368,12 @@ def get_seeding_data(self, sim_id: int) -> tuple[nb.typed.Dict, npt.NDArray[np.n
`gempyor.seeding.Seeding.get_from_config`
"""
return self.seeding.get_from_config(
self.compartments,
self.subpop_struct,
self.n_days,
self.ti,
self.tf,
(
compartments=self.compartments,
subpop_struct=self.subpop_struct,
n_days=self.n_days,
ti=self.ti,
tf=self.tf,
input_filename=(
None
if self.seeding_config is None
else self.get_input_filename(
Expand Down
5 changes: 3 additions & 2 deletions flepimop/gempyor_pkg/src/gempyor/seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,13 @@ def onerun_SEIR(
initial_conditions = modinf.initial_conditions.get_from_file(
sim_id2load, modinf=modinf
)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id2load)
else:
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id2write, modinf=modinf
)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id2write)
seeding_data, seeding_amounts = modinf.get_seeding_data(
sim_id=sim_id2load if load_ID else sim_id2write
)

with Timer("onerun_SEIR.parameters"):
# Draw or load parameters
Expand Down
12 changes: 6 additions & 6 deletions flepimop/gempyor_pkg/tests/seir/test_seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def test_Seeding_draw_success(self):
s.seeding_config["method"] = "NoSeeding"

seeding_result = sic.get_from_config(
s.compartments,
s.subpop_struct,
s.n_days,
s.ti,
s.tf,
s.get_input_filename(
compartments=s.compartments,
subpop_struct=s.subpop_struct,
n_days=s.n_days,
ti=s.ti,
tf=s.tf,
input_filename=s.get_input_filename(
ftype=s.seeding_config["seeding_file_type"].get(),
sim_id=0,
extension_override="csv",
Expand Down
16 changes: 8 additions & 8 deletions flepimop/gempyor_pkg/tests/seir/test_seir.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_constant_population_legacy_integration():
)
integration_method = "legacy"

seeding_data, seeding_amounts = modinf.get_seeding_data(100)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id=100, modinf=modinf
)
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_constant_population_rk4jit_integration_fail():
)
modinf.seir_config["integration"]["method"] = "rk4.jit"

seeding_data, seeding_amounts = modinf.get_seeding_data(100)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id=100, modinf=modinf
)
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_constant_population_rk4jit_integration():
# s.integration_method = "rk4.jit"
assert modinf.seir_config["integration"]["method"].get() == "rk4"

seeding_data, seeding_amounts = modinf.get_seeding_data(100)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id=100, modinf=modinf
)
Expand Down Expand Up @@ -304,7 +304,7 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices():
out_prefix=prefix,
)

seeding_data, seeding_amounts = modinf.get_seeding_data(100)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id=100, modinf=modinf
)
Expand Down Expand Up @@ -413,7 +413,7 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices():
out_prefix=prefix,
)

seeding_data, seeding_amounts = modinf.get_seeding_data(100)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id=100, modinf=modinf
)
Expand Down Expand Up @@ -490,7 +490,7 @@ def test_steps_SEIR_no_spread():
out_prefix=prefix,
)

seeding_data, seeding_amounts = modinf.get_seeding_data(100)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id=100, modinf=modinf
)
Expand Down Expand Up @@ -767,7 +767,7 @@ def test_parallel_compartments_with_vacc():
out_prefix=prefix,
)

seeding_data, seeding_amounts = modinf.get_seeding_data(100)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id=100, modinf=modinf
)
Expand Down Expand Up @@ -861,7 +861,7 @@ def test_parallel_compartments_no_vacc():
out_prefix=prefix,
)

seeding_data, seeding_amounts = modinf.get_seeding_data(100)
seeding_data, seeding_amounts = modinf.get_seeding_data(sim_id=100)
initial_conditions = modinf.initial_conditions.get_from_config(
sim_id=100, modinf=modinf
)
Expand Down

0 comments on commit 085beca

Please sign in to comment.