
Abstract single state runner #182

Merged · 24 commits from abstract-single-state-runner into main · Jun 28, 2024

Conversation

arik-shurygin (Collaborator):

Introducing the abstract_azure_runner, a lightweight framework that abstracts away a lot of the annoying parts of running an Azure job.

The AbstractAzureRunner class contains an __init__ and four methods, only one of which needs to be overridden by the user:

  • __init__(azure_output_dir): sets up Azure dual logging so users are able to see stdout and stderr on their Azure node while a .txt version of the logs is also saved to the output directory.
  • save_inference_posteriors(inferer, save_filename): saves inference chains from inferer in a standard format for future visualization.
  • save_inference_timelines(inferer, timeline_filename, particles_saved): picks some number (particles_saved) of particles from the inference process, runs those particles, and saves key metrics such as vaccination and predicted/observed hospitalization, among many others.
  • save_static_run_timelines(parameters, sol, timeline_filename): saves the same metrics as save_inference_timelines, but for a static run of the model.
  • process_state(self, state): filled in by the user; this is where you slot in your existing Azure code. Users are encouraged to call the other functions above from process_state to easily save model outputs, so visualizer scripts can pull the Azure output down after the run finishes (see the sketch after this list).
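
To make the intended workflow concrete, here is a minimal sketch of a user subclass. The class name, method names, and module path come from this PR; the SingleStateRunner name, the build_inferer helper, the output filenames, and the output directory are illustrative assumptions rather than the exact API.

```python
# A minimal sketch, not the exact API: build_inferer, the filenames, and the
# output directory below are assumed for illustration.
from mechanistic_model.abstract_azure_runner import AbstractAzureRunner


class SingleStateRunner(AbstractAzureRunner):
    def process_state(self, state):
        # slot in your existing Azure code here, e.g. build and fit a model:
        # inferer = build_inferer(state)   # hypothetical helper
        # inferer.infer(...)
        # then call the provided save_* helpers so visualizer scripts can pull
        # standardized outputs down after the run finishes:
        # self.save_inference_posteriors(inferer, "checkpoint.json")
        # self.save_inference_timelines(inferer, "azure_visualizer_timeline.csv", 10)
        pass


# __init__(azure_output_dir) also sets up the dual logging on the node
runner = SingleStateRunner(azure_output_dir="/output/example_job/")
runner.process_state("US")
```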

On a tiny sidenote, I added a useful type hint called SEIC_compartments to describe the tuple structure that commonly represents our 4-compartment model. This is why the PR touches so many files, though most of those changes are quite minor.
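
For reference, here is a sketch of what that alias could look like; the exact definition and array type used in the PR may differ.

```python
# Sketch only: the PR's actual alias may use a different array type.
import jax

# a 4-tuple of arrays, one entry per compartment of the SEIC model
SEIC_compartments = tuple[jax.Array, jax.Array, jax.Array, jax.Array]
```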

kokbent (Collaborator) left a review:

See specific comments

"HOSP_PATH": "/input/data/hospital_220213_220108.csv",
"VAX_MODEL_DATA": "/input/data/spline_fits.csv",
"VAX_MODEL_NUM_KNOTS": 18,
"HOSPITALIZATION_DATA_PATH": "/input/data/hospital_220213_220108.csv",
kokbent (Collaborator):

We haven't actually used this before... wonder if we should just remove it

arik-shurygin (Author):

Will keep for now because we will eventually want to implement this rather than hardcoding it in the runner like we did during SMH.

kokbent (Collaborator):

I don't know. It's up to the person designing the inferer (likelihood and whatnot) to decide what's being passed into it... Just thought it being there is more confusing than useful.

```
        )
        df["seasonality_coef"] = seasonality_timeline
        # save external introductions timeline of shape (num_days_predicted, num_strains)
        # sum across age groups since we do % of each age bin anyways
```

kokbent (Collaborator):

We did do age-specific introductions though... but it's probably not a big issue; we don't need age-specific visualization for now...

arik-shurygin (Author):

Generally we only introduce our working-age population; if we start introducing multiple age groups, that plot would be useful. For now this works though.

df["total_infection_incidence"] = np.insert(
infection_incidence, [0], [0]
)
strain_proportions, _ = utils.get_timeline_from_solution_with_command(
kokbent (Collaborator):

If I'm not wrong, the method for strain prevalence in this function is quite slow; it would be great to change it to a faster one (see https://github.com/cdcent/cfa-scenarios-model/blob/dbb3f615e2e6f7fbeebc9a67b5693b8cd8a9c6a9/exp/fifty_state_2202_2307_3strain/postaz_process.py#L359).

arik-shurygin (Author):

Referenced in a separate issue to be addressed later.
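
(For context, a fully vectorized calculation along these lines avoids per-day Python loops; this is only a sketch under assumed array shapes, not the implementation referenced above.)

```python
import numpy as np


def strain_proportions_by_day(incidence: np.ndarray) -> np.ndarray:
    """Sketch: incidence has assumed shape (num_days, num_age_groups, num_strains);
    returns (num_days, num_strains) rows that sum to 1 (0 on days with no infections)."""
    by_strain = incidence.sum(axis=1).astype(float)   # collapse age groups
    totals = by_strain.sum(axis=1, keepdims=True)     # total infections per day
    return np.divide(
        by_strain, totals, out=np.zeros_like(by_strain), where=totals > 0
    )
```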

```
        None
        """
        # if inference complete, convert jnp/np arrays to list, then json dump
        if inferer.infer_complete:
```

kokbent (Collaborator):

Not sure if we should raise a warning or error here if it's not complete. Most likely something has gone haywire already at this point 🤔

arik-shurygin (Author), Jun 28, 2024:

We don't currently raise a warning; we just do nothing if they are trying to save inference posteriors before fitting. I will modify it to warn though.
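
A minimal sketch of the guard being discussed (the warning text is an assumption; inferer.infer_complete and the save_inference_posteriors signature come from the diff above):

```python
import warnings


def save_inference_posteriors(self, inferer, save_filename) -> None:
    # warn, rather than silently return, when fitting has not been run yet
    if not inferer.infer_complete:
        warnings.warn(
            "attempted to save inference posteriors before inference completed; nothing was saved"
        )
        return
    # ... otherwise convert jnp/np arrays to lists and json dump to save_filename ...
```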

(A further comment on mechanistic_model/abstract_azure_runner.py was marked outdated and resolved.)
```
            timeline_filename,
        )
        all_particles_df = pd.DataFrame()
        for particle in range(particles_saved):
```

kokbent (Collaborator):

This reads like you can only run particles 0, 1, 2, 3, ... up to particles_saved, instead of a random set of particles?

arik-shurygin (Author):

You are right, fixing this to pick random particles
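
For example, a sketch of drawing a random, non-repeating subset of particles within each chain (the dimension sizes below are assumptions):

```python
import numpy as np

num_chains, num_particles = 4, 1000   # assumed sampler dimensions
particles_saved = 10
rng = np.random.default_rng(seed=8675309)

# pick `particles_saved` distinct random particles within each chain
chain_particle_pairs = [
    (chain, particle)
    for chain in range(num_chains)
    for particle in rng.choice(num_particles, size=particles_saved, replace=False)
]
```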

```diff
@@ -971,7 +971,9 @@ def drop_sample_chains(samples: dict, dropped_chain_vals: list):
     return filtered_dict


-def flatten_list_parameters(samples):
+def flatten_list_parameters(
```

kokbent (Collaborator):

I don't know the sorcery here, rubber stamping this for now lol

arik-shurygin (Author):

It looks nasty, but it basically flattens things that use numpyro.plate into 2 dimensions (chain, particle).
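
For readers following along, a rough sketch of the idea (not the repository's exact code): parameters sampled inside numpyro.plate carry extra trailing dimensions beyond (chain, particle), and the helper splits each of those into its own key so every entry ends up two-dimensional.

```python
import numpy as np


def flatten_list_parameters(samples: dict) -> dict:
    """Sketch: split plated parameters of shape (chain, sample, *plate_dims)
    into one (chain, sample) array per plate index."""
    flat = {}
    for name, arr in samples.items():
        arr = np.asarray(arr)
        if arr.ndim > 2:
            # collapse all plate dimensions, then give each element its own key
            reshaped = arr.reshape(arr.shape[0], arr.shape[1], -1)
            for i in range(reshaped.shape[-1]):
                flat[f"{name}_{i}"] = reshaped[..., i]
        else:
            flat[name] = arr
    return flat
```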

arik-shurygin requested a review from kokbent on June 28, 2024 at 19:34.
```
    ):
        # select this random particle for each of our chains
        chain_particle_pairs = [
            (particle, chain)
```

kokbent (Collaborator):

(chain, particle) I think

arik-shurygin (Author):

You are right, sorry; fixed. Thanks for the catch.

kokbent (Collaborator) left a review:

lglglgtm

arik-shurygin merged commit 3b2fcab into main on Jun 28, 2024; 2 checks passed. The abstract-single-state-runner branch was deleted on June 28, 2024 at 20:10.