diff --git a/atomica/optimization.py b/atomica/optimization.py index f43319d0..a173be90 100644 --- a/atomica/optimization.py +++ b/atomica/optimization.py @@ -17,7 +17,7 @@ import sciris as sc from .cascade import get_cascade_vals -from .model import Model, Link +from .model import Model, Link, run_model from .parameters import ParameterSet from .programs import ProgramSet, ProgramInstructions from .results import Result @@ -1074,7 +1074,7 @@ def constrain_instructions(self, instructions: ProgramInstructions, hard_constra x0 = sc.odict() # Order matters here lb = [] ub = [] - progs = hard_constraints["programs"][t] # Programs eligible for constraining at this time + progs = sorted(hard_constraints["programs"][t]) # Programs eligible for constraining at this time for prog in progs: if prog not in instructions.alloc: @@ -1364,7 +1364,66 @@ def _objective_fcn(x, pickled_model, optimization, hard_constraints: list, basel return obj_val -def optimize(project, optimization, parset: ParameterSet, progset: ProgramSet, instructions: ProgramInstructions, x0=None, xmin=None, xmax=None, hard_constraints=None, baselines=None, optim_args: dict = None): +def _calc_skippable_year(optimization, framework): + allowed_adjustments = (SpendingAdjustment, ) + allowed_measurables = (MinimizeMeasurable, MaximizeMeasurable, AtMostMeasurable, AtLeastMeasurable, IncreaseByMeasurable, DecreaseByMeasurable, MaximizeCascadeStage, MaximizeCascadeConversionRate) + # type not isinstance because subclasses might have different behaviour + if any(type(adjustment) not in allowed_adjustments for adjustment in optimization.adjustments): + print(f'Could not skip start because we have unknown adjustments: {[adjustment.name for adjustment in optimization.adjustments if type(adjustment) not in allowed_adjustments]}') + return False + if any(type(measurable) not in allowed_measurables for measurable in optimization.measurables): + print(f'Could not skip start because we have non-default measurables: {[measurable.measurable_name for measurable in optimization.measurables if type(measurable) not in allowed_measurables]}') + return False + if any(framework.pars.at[par_name, "is derivative"] == "y" for par_name in list(framework.pars.index)): + print('Could not skip start because we have derivative pars:', [par_name for par_name in list(framework.pars.index) if framework.pars.at[par_name, "is derivative"] == "y"]) + return False + + skip_start_year_adj = min(min(sc.promotetoarray(spending_adjustment.t)) for spending_adjustment in optimization.adjustments) + skip_start_year_meas = min(min(sc.promotetoarray(measurable.t)) for measurable in optimization.measurables) + skip_start_year = min(skip_start_year_adj, skip_start_year_meas) + + if not np.isfinite(skip_start_year): + skip_start_year = False + + return skip_start_year + + +def _make_skip_model(optimization, project, parset, progset, instructions, skip_start): + if skip_start == False: + return None + + possible_skip_year = _calc_skippable_year(optimization, project.framework) + + if type(skip_start) == bool: # skip_start == True, because checked False above + if not possible_skip_year: + raise Exception('skip_start == True, but do not know which year it is possible to skip until, please set skip_start = year you want to skip to') + skip_start = possible_skip_year + + if skip_start is None and not possible_skip_year: # We wanted to skip if possible but it is not known to be possible + return None + + if skip_start is None: # We have a valid possible_skip_year to fill in to skip_start + skip_start = possible_skip_year + + # skip_start is now a float or int hopefully + assert float(skip_start) > project.settings.sim_start, f'skip_start ({skip_start}) should be a number larger than the original starting year ({project.settings.sim_start})' + skip_start = float(skip_start) + + if skip_start > possible_skip_year: + logger.warning(f"skip_start was set to {skip_start} but the calculated safe year to start at is {possible_skip_year} - so the optimization might not give accurate results") + + logger.info(f"Skipping until year {skip_start}") + + tmp_settings, tmp_parset = sc.dcp(project.settings), sc.dcp(parset) + + res = run_model(settings=tmp_settings, framework=project.framework, parset=tmp_parset, progset=progset, program_instructions=instructions) + tmp_parset.set_initialization(res, year=skip_start) + tmp_settings.update_time_vector(start=skip_start) + + model = Model(tmp_settings, project.framework, tmp_parset, progset, instructions) + return model + +def optimize(project, optimization, parset: ParameterSet, progset: ProgramSet, instructions: ProgramInstructions, x0=None, xmin=None, xmax=None, hard_constraints=None, baselines=None, optim_args: dict = None, skip_start=None): """ Main user entry point for optimization @@ -1386,13 +1445,19 @@ def optimize(project, optimization, parset: ParameterSet, progset: ProgramSet, i :param hard_constraints: Not for manual use - override hard constraints :param baselines: Not for manual use - override Measurable baseline values (for relative Measurables) :param optim_args: Pass a dictionary of keyword arguments to pass to the optimization algorithm (set in ``optimization.method``) + :param skip_start: True or False or a year to start the optimization runs at (everything before that year will be kept the same) :return: A :class:`ProgramInstructions` instance representing optimal instructions """ assert optimization.method in ["asd", "pso", "hyperopt"] - model = Model(project.settings, project.framework, parset, progset, instructions) + model = None + if skip_start is None or skip_start: + model = _make_skip_model(optimization, project, parset, progset, instructions, skip_start) + if model is None: + model = Model(project.settings, project.framework, parset, progset, instructions) + pickled_model = pickle.dumps(model) # Unpickling effectively makes a deep copy, so this _should_ be faster initialization = optimization.get_initialization(progset, model.program_instructions) @@ -1526,6 +1591,12 @@ def constrain_sum_bounded(x: np.array, s: float, lb: np.array, ub: np.array) -> """ tolerance = 1e-6 + sort_inds = np.argsort(x) + orig_inds = np.argsort(sort_inds) + x = np.array(x) [sort_inds] + lb = np.array(lb)[sort_inds] + ub = np.array(ub)[sort_inds] + # Normalize values x0_scaled = x / (x.sum() or 1) # Normalize the initial values, unless they sum to 0 (i.e., they are all zero) lb_scaled = lb / s @@ -1534,7 +1605,7 @@ def constrain_sum_bounded(x: np.array, s: float, lb: np.array, ub: np.array) -> # First, check if the constraint is already satisfied just by multiplicative rescaling # The final check for x0_scaled.sum()==1 catches the case where all of the input values are 0 if np.all((x0_scaled >= lb_scaled) & (x0_scaled <= ub_scaled)) and np.isclose(x0_scaled.sum(), 1): - return x0_scaled * s + return x0_scaled[orig_inds] * s # If not, we need to actually run the constrained optimization bounds = [(lower, upper) for lower, upper in zip(lb_scaled, ub_scaled)] @@ -1557,4 +1628,6 @@ def jacfcn(x): # Enforce upper/lower bound constraints to prevent numerically exceeding them sol = np.minimum(np.maximum(res["x"], lb_scaled), ub_scaled) * s assert np.isclose(sol.sum(), s), f"FAILED as {sol} has a total of {sol.sum()} which is not sufficiently close to the target value {s}" + + sol = sol[orig_inds] return sol diff --git a/tests/test_tox_optimization.py b/tests/test_tox_optimization.py index d2943628..a7620f95 100644 --- a/tests/test_tox_optimization.py +++ b/tests/test_tox_optimization.py @@ -7,6 +7,7 @@ import sciris as sc import atomica as at import logging +import pickle logger = logging.getLogger() @@ -670,6 +671,38 @@ def test_package_all_fixed(): # assert np.isclose(optimized_result.model.program_instructions.alloc["Treatment 1"].get(2020), 10) +## Test that skipping until the first adjustable or measurable gives the exact same solution as when we don't skip +def test_skip_start(): + seed = 1 + + measures = ['dx', 'tx', 'sus'] + + for test in ['hypertension']: # at.demos.options: + P = at.demo(which=test, do_run=False) + P.update_settings(sim_end=2030.0) + + instructions = at.ProgramInstructions(alloc=P.progset(), start_year=2020) # Instructions for default spending + + adjustments = list() + for prog in P.progset().programs.keys(): + adjustments.append(at.SpendingAdjustment(prog, 2020, "rel", 0.0, 2.0)) + + for measure in measures: + if measure in P.framework.comps.index: break + if measure not in P.framework.comps.index: raise Exception(f'Could not find a measurable for project {P.name} comps {measure} {P.framework.comps.index}') + measurables = at.MaximizeMeasurable(measure, [2019, np.inf]) + + constraints = at.TotalSpendConstraint() # Cap total spending in all years + + optimization = at.Optimization(name="default", adjustments=adjustments, measurables=measurables, constraints=constraints) # Evaluate from 2020 to end of simulation + + instructions_start = at.optimize(P, optimization, parset=P.parset(), progset=P.progset(), instructions=instructions, skip_start=False, optim_args={'randseed':seed}) + instructions_skip = at.optimize(P, optimization, parset=P.parset(), progset=P.progset(), instructions=instructions, skip_start=None, optim_args={'randseed':seed}) + + assert pickle.dumps(instructions_start) == pickle.dumps(instructions_skip), f'Optimization skipping part of the years did not work: {test}' + print(f'Correctly gave the same optimization when skipping through until the first measurable or adjustable: {test}\n') + + if __name__ == "__main__": test_standard() test_unresolvable() @@ -690,3 +723,4 @@ def test_package_all_fixed(): test_package_variable() test_package_fixed_prop() test_package_all_fixed() + test_skip_start() \ No newline at end of file diff --git a/tests/text_tox_timed_initialization.py b/tests/text_tox_timed_initialization.py index 66782372..0d764e21 100644 --- a/tests/text_tox_timed_initialization.py +++ b/tests/text_tox_timed_initialization.py @@ -71,5 +71,52 @@ def set_initialization_basic(F, D, year, y_factor=True): d = at.PlotData([res1, res2, res3, res4], ["sus", "inf", "rec"]) at.plot_series(d) +def test_halfway_run(): + P = at.demo("tb", do_run=False) + + start, middle, end = P.settings.sim_start, 2019.0, 2030.0 + prog_start = 2018 + dt = P.settings.sim_dt + + P.settings.update_time_vector(start=start, end=end, dt=dt) + middle_ind = sc.findlast(P.settings.tvec, middle) + + parset1 = P.parset() + progset = P.progset() + + # res = P.run_sim(parset=parset1, progset=progset) + # parset1.set_initialization(res, year=middle) + # parset1 now setup to test + + instructions = at.ProgramInstructions(start_year=prog_start, alloc=progset) + kwargs = dict(parset=parset1, progset=progset, progset_instructions=instructions) + + res_orig = P.run_sim(**kwargs) + + parset1.set_initialization(res_orig, year=middle) + P.settings.update_time_vector(start=middle, end=end, dt=dt) + + res_half = P.run_sim(**kwargs) + + all_equal = True + + for pop in parset1.pop_names: + comps = res_orig.comp_names(pop) + characs = res_orig.charac_names(pop) + pars = res_orig.par_names(pop) + + equal = {out: all(res_orig.get_variable(out, pop)[0].vals[middle_ind:] == res_half.get_variable(out, pop)[0].vals) + for out in comps + characs + pars} + + # print([(out, sum(res_half.get_variable(out, pop)[0].vals)) for out in comps + characs + pars]) + + all_equal = all_equal and all(equal.values()) + if not all_equal: + raise Exception(f'Project "{P.name}" pop "{pop}" start, middle, end, dt: {start, middle, end, dt}, equal: {equal}') + + print(f'Project: {P.name}: Success running from halfway through!') + + if __name__ == '__main__': - test_timed_initialization() \ No newline at end of file + test_timed_initialization() + test_halfway_run() \ No newline at end of file