Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip start of optimization until the first measurable or adjustable #505

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions atomica/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import sciris as sc
from .cascade import get_cascade_vals
from .model import Model, Link
from .model import Model, Link, run_model
from .parameters import ParameterSet
from .programs import ProgramSet, ProgramInstructions
from .results import Result
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def constrain_instructions(self, instructions: ProgramInstructions, hard_constra
x0 = sc.odict() # Order matters here
lb = []
ub = []
progs = hard_constraints["programs"][t] # Programs eligible for constraining at this time
progs = sorted(hard_constraints["programs"][t]) # Programs eligible for constraining at this time

for prog in progs:
if prog not in instructions.alloc:
Expand Down Expand Up @@ -1364,7 +1364,66 @@ def _objective_fcn(x, pickled_model, optimization, hard_constraints: list, basel
return obj_val


def optimize(project, optimization, parset: ParameterSet, progset: ProgramSet, instructions: ProgramInstructions, x0=None, xmin=None, xmax=None, hard_constraints=None, baselines=None, optim_args: dict = None):
def _calc_skippable_year(optimization, framework):
allowed_adjustments = (SpendingAdjustment, )
allowed_measurables = (MinimizeMeasurable, MaximizeMeasurable, AtMostMeasurable, AtLeastMeasurable, IncreaseByMeasurable, DecreaseByMeasurable, MaximizeCascadeStage, MaximizeCascadeConversionRate)
# type not isinstance because subclasses might have different behaviour
if any(type(adjustment) not in allowed_adjustments for adjustment in optimization.adjustments):
print(f'Could not skip start because we have unknown adjustments: {[adjustment.name for adjustment in optimization.adjustments if type(adjustment) not in allowed_adjustments]}')
return False
if any(type(measurable) not in allowed_measurables for measurable in optimization.measurables):
print(f'Could not skip start because we have non-default measurables: {[measurable.measurable_name for measurable in optimization.measurables if type(measurable) not in allowed_measurables]}')
return False
if any(framework.pars.at[par_name, "is derivative"] == "y" for par_name in list(framework.pars.index)):
print('Could not skip start because we have derivative pars:', [par_name for par_name in list(framework.pars.index) if framework.pars.at[par_name, "is derivative"] == "y"])
return False

skip_start_year_adj = min(min(sc.promotetoarray(spending_adjustment.t)) for spending_adjustment in optimization.adjustments)
skip_start_year_meas = min(min(sc.promotetoarray(measurable.t)) for measurable in optimization.measurables)
skip_start_year = min(skip_start_year_adj, skip_start_year_meas)

if not np.isfinite(skip_start_year):
skip_start_year = False

return skip_start_year


def _make_skip_model(optimization, project, parset, progset, instructions, skip_start):
if skip_start == False:
return None

possible_skip_year = _calc_skippable_year(optimization, project.framework)

if type(skip_start) == bool: # skip_start == True, because checked False above
if not possible_skip_year:
raise Exception('skip_start == True, but do not know which year it is possible to skip until, please set skip_start = year you want to skip to')
skip_start = possible_skip_year

if skip_start is None and not possible_skip_year: # We wanted to skip if possible but it is not known to be possible
return None

if skip_start is None: # We have a valid possible_skip_year to fill in to skip_start
skip_start = possible_skip_year

# skip_start is now a float or int hopefully
assert float(skip_start) > project.settings.sim_start, f'skip_start ({skip_start}) should be a number larger than the original starting year ({project.settings.sim_start})'
skip_start = float(skip_start)

if skip_start > possible_skip_year:
logger.warning(f"skip_start was set to {skip_start} but the calculated safe year to start at is {possible_skip_year} - so the optimization might not give accurate results")

logger.info(f"Skipping until year {skip_start}")

tmp_settings, tmp_parset = sc.dcp(project.settings), sc.dcp(parset)

res = run_model(settings=tmp_settings, framework=project.framework, parset=tmp_parset, progset=progset, program_instructions=instructions)
tmp_parset.set_initialization(res, year=skip_start)
tmp_settings.update_time_vector(start=skip_start)

model = Model(tmp_settings, project.framework, tmp_parset, progset, instructions)
return model

def optimize(project, optimization, parset: ParameterSet, progset: ProgramSet, instructions: ProgramInstructions, x0=None, xmin=None, xmax=None, hard_constraints=None, baselines=None, optim_args: dict = None, skip_start=None):
"""
Main user entry point for optimization

Expand All @@ -1386,13 +1445,19 @@ def optimize(project, optimization, parset: ParameterSet, progset: ProgramSet, i
:param hard_constraints: Not for manual use - override hard constraints
:param baselines: Not for manual use - override Measurable baseline values (for relative Measurables)
:param optim_args: Pass a dictionary of keyword arguments to pass to the optimization algorithm (set in ``optimization.method``)
:param skip_start: True or False or a year to start the optimization runs at (everything before that year will be kept the same)
:return: A :class:`ProgramInstructions` instance representing optimal instructions

"""

assert optimization.method in ["asd", "pso", "hyperopt"]

model = Model(project.settings, project.framework, parset, progset, instructions)
model = None
if skip_start is None or skip_start:
model = _make_skip_model(optimization, project, parset, progset, instructions, skip_start)
if model is None:
model = Model(project.settings, project.framework, parset, progset, instructions)

pickled_model = pickle.dumps(model) # Unpickling effectively makes a deep copy, so this _should_ be faster

initialization = optimization.get_initialization(progset, model.program_instructions)
Expand Down Expand Up @@ -1526,6 +1591,12 @@ def constrain_sum_bounded(x: np.array, s: float, lb: np.array, ub: np.array) ->
"""
tolerance = 1e-6

sort_inds = np.argsort(x)
orig_inds = np.argsort(sort_inds)
x = np.array(x) [sort_inds]
lb = np.array(lb)[sort_inds]
ub = np.array(ub)[sort_inds]

# Normalize values
x0_scaled = x / (x.sum() or 1) # Normalize the initial values, unless they sum to 0 (i.e., they are all zero)
lb_scaled = lb / s
Expand All @@ -1534,7 +1605,7 @@ def constrain_sum_bounded(x: np.array, s: float, lb: np.array, ub: np.array) ->
# First, check if the constraint is already satisfied just by multiplicative rescaling
# The final check for x0_scaled.sum()==1 catches the case where all of the input values are 0
if np.all((x0_scaled >= lb_scaled) & (x0_scaled <= ub_scaled)) and np.isclose(x0_scaled.sum(), 1):
return x0_scaled * s
return x0_scaled[orig_inds] * s

# If not, we need to actually run the constrained optimization
bounds = [(lower, upper) for lower, upper in zip(lb_scaled, ub_scaled)]
Expand All @@ -1557,4 +1628,6 @@ def jacfcn(x):
# Enforce upper/lower bound constraints to prevent numerically exceeding them
sol = np.minimum(np.maximum(res["x"], lb_scaled), ub_scaled) * s
assert np.isclose(sol.sum(), s), f"FAILED as {sol} has a total of {sol.sum()} which is not sufficiently close to the target value {s}"

sol = sol[orig_inds]
return sol
34 changes: 34 additions & 0 deletions tests/test_tox_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sciris as sc
import atomica as at
import logging
import pickle

logger = logging.getLogger()

Expand Down Expand Up @@ -670,6 +671,38 @@ def test_package_all_fixed():
# assert np.isclose(optimized_result.model.program_instructions.alloc["Treatment 1"].get(2020), 10)


## Test that skipping until the first adjustable or measurable gives the exact same solution as when we don't skip
def test_skip_start():
seed = 1

measures = ['dx', 'tx', 'sus']

for test in ['hypertension']: # at.demos.options:
P = at.demo(which=test, do_run=False)
P.update_settings(sim_end=2030.0)

instructions = at.ProgramInstructions(alloc=P.progset(), start_year=2020) # Instructions for default spending

adjustments = list()
for prog in P.progset().programs.keys():
adjustments.append(at.SpendingAdjustment(prog, 2020, "rel", 0.0, 2.0))

for measure in measures:
if measure in P.framework.comps.index: break
if measure not in P.framework.comps.index: raise Exception(f'Could not find a measurable for project {P.name} comps {measure} {P.framework.comps.index}')
measurables = at.MaximizeMeasurable(measure, [2019, np.inf])

constraints = at.TotalSpendConstraint() # Cap total spending in all years

optimization = at.Optimization(name="default", adjustments=adjustments, measurables=measurables, constraints=constraints) # Evaluate from 2020 to end of simulation

instructions_start = at.optimize(P, optimization, parset=P.parset(), progset=P.progset(), instructions=instructions, skip_start=False, optim_args={'randseed':seed})
instructions_skip = at.optimize(P, optimization, parset=P.parset(), progset=P.progset(), instructions=instructions, skip_start=None, optim_args={'randseed':seed})

assert pickle.dumps(instructions_start) == pickle.dumps(instructions_skip), f'Optimization skipping part of the years did not work: {test}'
print(f'Correctly gave the same optimization when skipping through until the first measurable or adjustable: {test}\n')


if __name__ == "__main__":
test_standard()
test_unresolvable()
Expand All @@ -690,3 +723,4 @@ def test_package_all_fixed():
test_package_variable()
test_package_fixed_prop()
test_package_all_fixed()
test_skip_start()
49 changes: 48 additions & 1 deletion tests/text_tox_timed_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,52 @@ def set_initialization_basic(F, D, year, y_factor=True):
d = at.PlotData([res1, res2, res3, res4], ["sus", "inf", "rec"])
at.plot_series(d)

def test_halfway_run():
P = at.demo("tb", do_run=False)

start, middle, end = P.settings.sim_start, 2019.0, 2030.0
prog_start = 2018
dt = P.settings.sim_dt

P.settings.update_time_vector(start=start, end=end, dt=dt)
middle_ind = sc.findlast(P.settings.tvec, middle)

parset1 = P.parset()
progset = P.progset()

# res = P.run_sim(parset=parset1, progset=progset)
# parset1.set_initialization(res, year=middle)
# parset1 now setup to test

instructions = at.ProgramInstructions(start_year=prog_start, alloc=progset)
kwargs = dict(parset=parset1, progset=progset, progset_instructions=instructions)

res_orig = P.run_sim(**kwargs)

parset1.set_initialization(res_orig, year=middle)
P.settings.update_time_vector(start=middle, end=end, dt=dt)

res_half = P.run_sim(**kwargs)

all_equal = True

for pop in parset1.pop_names:
comps = res_orig.comp_names(pop)
characs = res_orig.charac_names(pop)
pars = res_orig.par_names(pop)

equal = {out: all(res_orig.get_variable(out, pop)[0].vals[middle_ind:] == res_half.get_variable(out, pop)[0].vals)
for out in comps + characs + pars}

# print([(out, sum(res_half.get_variable(out, pop)[0].vals)) for out in comps + characs + pars])

all_equal = all_equal and all(equal.values())
if not all_equal:
raise Exception(f'Project "{P.name}" pop "{pop}" start, middle, end, dt: {start, middle, end, dt}, equal: {equal}')

print(f'Project: {P.name}: Success running from halfway through!')


if __name__ == '__main__':
test_timed_initialization()
test_timed_initialization()
test_halfway_run()
Loading