diff --git a/docs/source/Tutorials/two_period_model_tutorial.ipynb b/docs/source/Tutorials/two_period_model_tutorial.ipynb index e58b14bf..27cc8f28 100644 --- a/docs/source/Tutorials/two_period_model_tutorial.ipynb +++ b/docs/source/Tutorials/two_period_model_tutorial.ipynb @@ -296,7 +296,7 @@ "source": [ "#### State space functions\n", "\n", - "Next we define state space functions ```create_state_space``` and ```get_state_specific_choice_set```. They can be directly imported the dcegm package, but we display them here, in order to explain how they work." + "Next we define state space functions ```create_state_space``` and ```state_specific_choice_set```. They can be directly imported from the dcegm package, but we display them here, in order to explain how they work." ] }, { @@ -402,7 +402,7 @@ "metadata": {}, "outputs": [], "source": [ - "def get_state_specific_choice_set(\n", + "def state_specific_choice_set(\n", " state: np.ndarray,\n", " state_space: np.ndarray, # noqa: U100\n", " indexer: np.ndarray,\n", @@ -460,7 +460,7 @@ } ], "source": [ - "choice_set_5 = get_state_specific_choice_set(_state_space[4], _state_space, _indexer)\n", + "choice_set_5 = state_specific_choice_set(_state_space[4], _state_space, _indexer)\n", "choice_set_5" ] }, @@ -479,7 +479,7 @@ "source": [ "state_space_functions = {\n", " \"create_state_space\": create_state_space,\n", - " \"get_state_specific_choice_set\": get_state_specific_choice_set,\n", + " \"state_specific_choice_set\": state_specific_choice_set,\n", "}" ] }, diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py index 538a5f66..9beeb68b 100644 --- a/src/dcegm/final_periods.py +++ b/src/dcegm/final_periods.py @@ -18,7 +18,7 @@ def solve_last_two_periods( income_shock_weights: jnp.ndarray, exog_grids: Dict[str, jnp.ndarray], model_funcs: Dict[str, Callable], - batch_info, + last_two_period_batch_info, value_solved, policy_solved, endog_grid_solved, @@ -39,7 +39,7 @@ def solve_last_two_periods( compute_utility 
(callable): User supplied utility function. compute_marginal_utility (callable): User supplied marginal utility function. - batch_info (dict): Dictionary containing information about the batch + last_two_period_batch_info (dict): Dictionary containing information about the batch size and the state space. value_solved (np.ndarray): 3d array of shape (n_states, n_grid_wealth, n_income_shocks) of the value function for @@ -56,9 +56,15 @@ def solve_last_two_periods( value_interp_final_period, marginal_utility_final_last_period, ) = solve_final_period( - idx_state_choices_final_period=batch_info["idx_state_choices_final_period"], - idx_parent_states_final_period=batch_info["idxs_parent_states_final_period"], - state_choice_mat_final_period=batch_info["state_choice_mat_final_period"], + idx_state_choices_final_period=last_two_period_batch_info[ + "idx_state_choices_final_period" + ], + idx_parent_states_final_period=last_two_period_batch_info[ + "idxs_parent_states_final_period" + ], + state_choice_mat_final_period=last_two_period_batch_info[ + "state_choice_mat_final_period" + ], cont_grids_next_period=cont_grids_next_period, exog_grids=exog_grids, params=params, @@ -72,9 +78,13 @@ def solve_last_two_periods( endog_grid, policy, value = solve_for_interpolated_values( value_interpolated=value_interp_final_period, marginal_utility_interpolated=marginal_utility_final_last_period, - state_choice_mat=batch_info["state_choice_mat_second_last_period"], - child_state_idxs=batch_info["child_states_second_last_period"], - states_to_choices_child_states=batch_info["state_to_choices_final_period"], + state_choice_mat=last_two_period_batch_info[ + "state_choice_mat_second_last_period" + ], + child_state_idxs=last_two_period_batch_info["child_states_second_last_period"], + states_to_choices_child_states=last_two_period_batch_info[ + "state_to_choices_final_period" + ], params=params, taste_shock_scale=taste_shock_scale, income_shock_weights=income_shock_weights, @@ -83,7 +93,7 @@ def 
solve_last_two_periods( has_second_continuous_state=has_second_continuous_state, ) - idx_second_last = batch_info["idx_state_choices_second_last_period"] + idx_second_last = last_two_period_batch_info["idx_state_choices_second_last_period"] value_solved = value_solved.at[idx_second_last, ...].set(value) policy_solved = policy_solved.at[idx_second_last, ...].set(policy) diff --git a/src/dcegm/interface.py b/src/dcegm/interface.py index d8148e7d..d16266dc 100644 --- a/src/dcegm/interface.py +++ b/src/dcegm/interface.py @@ -5,6 +5,11 @@ interp_policy_on_wealth, interp_value_on_wealth, ) +from dcegm.interpolation.interp2d import ( + interp2d_policy_and_value_on_wealth_and_regular_grid, + interp2d_policy_on_wealth_and_regular_grid, + interp2d_value_on_wealth_and_regular_grid, +) def policy_and_value_for_state_choice_vec( @@ -47,14 +52,13 @@ def policy_and_value_for_state_choice_vec( def value_for_state_choice_vec( - state_choice_vec, - wealth, - map_state_choice_to_index, - discrete_states_names, endog_grid_solved, value_solved, - compute_utility, params, + model, + state_choice_vec, + wealth, + second_continous=None, ): """Get policy and value for a given state and choice vector. @@ -67,20 +71,36 @@ def value_for_state_choice_vec( Tuple[float, float]: Policy and value for the given state and choice vector. 
""" + map_state_choice_to_index = model["model_structure"]["map_state_choice_to_index"] + discrete_states_names = model["model_structure"]["discrete_states_names"] + compute_utility = model["model_funcs"]["compute_utility"] + state_choice_tuple = tuple( state_choice_vec[st] for st in discrete_states_names + ["choice"] ) state_choice_index = map_state_choice_to_index[state_choice_tuple] - value = interp_value_on_wealth( - wealth=wealth, - endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - value=jnp.take(value_solved, state_choice_index, axis=0), - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - ) + if second_continous is None: + value = interp_value_on_wealth( + wealth=wealth, + endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), + value=jnp.take(value_solved, state_choice_index, axis=0), + compute_utility=compute_utility, + state_choice_vec=state_choice_vec, + params=params, + ) + else: + value = interp2d_value_on_wealth_and_regular_grid( + regular_grid=model["options"]["exog_grids"]["second_continuous"], + wealth_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), + value_grid=jnp.take(value_solved, state_choice_index, axis=0), + regular_point_to_interp=second_continous, + wealth_point_to_interp=wealth, + compute_utility=compute_utility, + state_choice_vec=state_choice_vec, + params=params, + ) return value diff --git a/src/dcegm/law_of_motion.py b/src/dcegm/law_of_motion.py index e7ade826..de4f2014 100644 --- a/src/dcegm/law_of_motion.py +++ b/src/dcegm/law_of_motion.py @@ -14,7 +14,7 @@ def calc_cont_grids_next_period( discrete_states_beginning_of_period=state_space_dict, continuous_grid=exog_grids["second_continuous"], params=params, - compute_continuous_state=model_funcs["update_continuous_state"], + compute_continuous_state=model_funcs["next_period_continuous_state"], ) # Extra dimension for continuous state diff --git a/src/dcegm/pre_processing/batches.py 
b/src/dcegm/pre_processing/batches.py deleted file mode 100644 index f5170a46..00000000 --- a/src/dcegm/pre_processing/batches.py +++ /dev/null @@ -1,458 +0,0 @@ -import numpy as np - - -def create_batches_and_information( - model_structure, - options, -): - """Batches are used instead of periods to have chunks of equal sized state choices. - The batch information dictionary contains the following arrays reflecting the. - - steps in the backward induction: - - batches_state_choice_idx: The state choice indexes in each batch to be solved. - To solve the state choices in the egm step, we have to look at the child states - and the corresponding state choice indexes in the child states. For that we save - the following: - - child_state_choice_idxs_to_interp: The state choice indexes in we need to - interpolate the wealth on. - - child_states_idxs: The parent state indexes of the child states, i.e. the - child states themself. We calculate the resources at the beginning of - period before the backwards induction with the budget equation for each - saving and income shock grid point. - - Note: These two index arrays containing indexes on the whole - state/state-choice space. - - Once we have the interpolated in all possible child state-choice states, - we rearange them to an array with row as states and columns as choices to - aggregate over the choices. This is saved in: - - - child_state_choices_to_aggr_choice: The state choice indexes in the child - states to aggregate over. Note these are relative indexes indexing to the - batch arrays from the step before. - Now we have for each child state a value/marginal utility with the index arrays - above and what is missing is the mapping for the exogenous/stochastic processes. - This is saved via: - - child_states_to_integrate_exog: The state choice indexes in the child states - to integrate over the exogenous processes. This is a relative index to the - batch arrays from the step before. 
- - """ - - n_periods = options["state_space"]["n_periods"] - state_choice_space = model_structure["state_choice_space"] - - out_of_bounds_state_choice_idx = state_choice_space.shape[0] + 1 - - state_space = model_structure["state_space"] - discrete_states_names = model_structure["discrete_states_names"] - - map_state_choice_to_parent_state = model_structure[ - "map_state_choice_to_parent_state" - ] - map_state_choice_to_child_states = model_structure[ - "map_state_choice_to_child_states" - ] - map_state_choice_to_index = model_structure["map_state_choice_to_index"] - - if n_periods == 2: - # In the case of a two period model, we just need the information of the last - # two periods - batch_info = { - "two_period_model": True, - "n_state_choices": state_choice_space.shape[0], - } - batch_info = add_last_two_period_information( - n_periods=n_periods, - state_choice_space=state_choice_space, - map_state_choice_to_parent_state=map_state_choice_to_parent_state, - map_state_choice_to_child_states=map_state_choice_to_child_states, - map_state_choice_to_index=map_state_choice_to_index, - discrete_states_names=discrete_states_names, - state_space=state_space, - batch_info=batch_info, - ) - - return batch_info - - ( - batches_list, - child_state_choice_idxs_to_interp_list, - child_state_choices_to_aggr_choice_list, - child_states_to_integrate_exog_list, - ) = determine_optimal_batch_size( - state_choice_space=state_choice_space, - n_periods=n_periods, - map_state_choice_to_child_states=map_state_choice_to_child_states, - map_state_choice_to_index=map_state_choice_to_index, - state_space=state_space, - out_of_bounds_state_choice_idx=out_of_bounds_state_choice_idx, - ) - - if len(batches_list) == 1: - # This is the case for a three period model. 
- batches_cover_all = True - else: - # In the case of more periods we determine if the last two batches have equal - # size - batches_cover_all = len(batches_list[-1]) == len(batches_list[-2]) - - if not batches_cover_all: - # In the case batches don't cover everything, we have to solve the last batch - # separately. Delete the last element from the relevant lists and save it in - # an extra dictionary - last_batch = batches_list[-1] - last_child_states_to_integrate_exog = child_states_to_integrate_exog_list[-1] - last_idx_to_aggregate_choice = child_state_choices_to_aggr_choice_list[-1] - last_child_state_idx_interp = child_state_choice_idxs_to_interp_list[-1] - - last_state_choices = { - key: state_choice_space[:, i][last_batch] - for i, key in enumerate(discrete_states_names + ["choice"]) - } - last_state_choices_childs = { - key: state_choice_space[:, i][last_child_state_idx_interp] - for i, key in enumerate(discrete_states_names + ["choice"]) - } - last_parent_state_idx_of_state_choice = map_state_choice_to_parent_state[ - last_child_state_idx_interp - ] - - last_batch_info = { - "state_choice_idx": last_batch, - "state_choices": last_state_choices, - "child_states_to_integrate_exog": last_child_states_to_integrate_exog, - # Child state infos. 
- "child_state_choices_to_aggr_choice": last_idx_to_aggregate_choice, - "child_state_choice_idxs_to_interp": last_child_state_idx_interp, - "child_states_idxs": last_parent_state_idx_of_state_choice, - "state_choices_childs": last_state_choices_childs, - } - batches_list = batches_list[:-1] - child_states_to_integrate_exog_list = child_states_to_integrate_exog_list[:-1] - child_state_choices_to_aggr_choice_list = ( - child_state_choices_to_aggr_choice_list[:-1] - ) - child_state_choice_idxs_to_interp_list = child_state_choice_idxs_to_interp_list[ - :-1 - ] - - # First convert batch information - batch_array = np.array(batches_list) - child_states_to_integrate_exog = np.array(child_states_to_integrate_exog_list) - - state_choices_batches = { - key: state_choice_space[:, i][batch_array] - for i, key in enumerate(discrete_states_names + ["choice"]) - } - - # Now create the child state arrays. As these can have different shapes than the - # batches, we have to extend them: - max_child_state_index_batch = np.max(child_states_to_integrate_exog, axis=(1, 2)) - ( - child_state_choice_idxs_to_interp, - child_state_choices_to_aggr_choice, - ) = extend_child_state_choices_to_aggregate_choices( - idx_to_aggregate_choice=child_state_choices_to_aggr_choice_list, - max_child_state_index_batch=max_child_state_index_batch, - idx_to_interpolate=child_state_choice_idxs_to_interp_list, - out_of_bounds_state_choice_idx=out_of_bounds_state_choice_idx, - ) - parent_state_idx_of_state_choice = map_state_choice_to_parent_state[ - child_state_choice_idxs_to_interp - ] - state_choices_childs = { - key: state_choice_space[:, i][child_state_choice_idxs_to_interp] - for i, key in enumerate(discrete_states_names + ["choice"]) - } - - batch_info = { - # First two bools determining the structure of solution functions we call - "two_period_model": False, - "batches_cover_all": batches_cover_all, - # Now the batch array information. 
First the batch itself - "batches_state_choice_idx": batch_array, - "state_choices": state_choices_batches, - "child_states_to_integrate_exog": child_states_to_integrate_exog, - # Then the child states - "child_state_choices_to_aggr_choice": child_state_choices_to_aggr_choice, - "child_state_choice_idxs_to_interp": child_state_choice_idxs_to_interp, - "child_states_idxs": parent_state_idx_of_state_choice, - "state_choices_childs": state_choices_childs, - } - if not batches_cover_all: - batch_info["last_batch_info"] = last_batch_info - batch_info = add_last_two_period_information( - n_periods=n_periods, - state_choice_space=state_choice_space, - map_state_choice_to_parent_state=map_state_choice_to_parent_state, - map_state_choice_to_child_states=map_state_choice_to_child_states, - map_state_choice_to_index=map_state_choice_to_index, - discrete_states_names=discrete_states_names, - state_space=state_space, - batch_info=batch_info, - ) - - return batch_info - - -def extend_child_state_choices_to_aggregate_choices( - idx_to_aggregate_choice, - max_child_state_index_batch, - idx_to_interpolate, - out_of_bounds_state_choice_idx, -): - """In case of uneven batches, we need to extend the child state objects to cover the - same number of state choices in each batch. - - As this object has in each batch the shape of n_state_choices x n_ - - """ - # There can be also be an uneven number of child states across batches. The - # indexes recorded in state_choice_times_exog_child_state_idxs only contain - # the indexes up the length. So we can just fill up without of bounds indexes. - # We also test this here - max_n_state_unique_in_batches = list( - map(lambda x: x.shape[0], idx_to_aggregate_choice) - ) - - # We check for internal constincy. The size (i.e. the number of states) of the - # state_choice idx to aggregate choices in each state has to correspond to the - # maximum state index in child indexes we integrate over. 
- if not np.all( - np.equal( - np.array(max_n_state_unique_in_batches) - 1, max_child_state_index_batch - ) - ): - raise ValueError( - "\n\nInternal error in the batch creation \n\n. " - "Please contact developer." - ) - - # Now span an array with n_states times the maximum number of child states across - # all batches and the number of choices. Fill with invalid state choice index - n_batches = len(idx_to_aggregate_choice) - max_n_child_states = np.max(max_n_state_unique_in_batches) - n_choices = idx_to_aggregate_choice[0].shape[1] - child_state_choices_to_aggr_choice = np.full( - (n_batches, max_n_child_states, n_choices), - fill_value=out_of_bounds_state_choice_idx, - dtype=int, # what about this hard-coded int here? - ) - - for id_batch in range(n_batches): - child_state_choices_to_aggr_choice[ - id_batch, : max_n_state_unique_in_batches[id_batch], : - ] = idx_to_aggregate_choice[id_batch] - - # The second array are the state choice indexes in the child states. As child - # states can have different admissible state choices this can be different in - # each batch. We fill up with invalid numbers. 
- max_child_state_choices = np.max(list(map(len, idx_to_interpolate))) - dummy_state = idx_to_interpolate[0][0] - child_state_choice_idxs_to_interp = np.full( - (n_batches, max_child_state_choices), - fill_value=dummy_state, - dtype=int, - ) - for id_batch in range(n_batches): - child_state_choice_idxs_to_interp[ - id_batch, : len(idx_to_interpolate[id_batch]) - ] = idx_to_interpolate[id_batch] - - return child_state_choice_idxs_to_interp, child_state_choices_to_aggr_choice - - -def add_last_two_period_information( - n_periods, - state_choice_space, - map_state_choice_to_parent_state, - map_state_choice_to_child_states, - map_state_choice_to_index, - discrete_states_names, - state_space, - batch_info, -): - # Select state_choice idxs in final period - idx_state_choice_final_period = np.where(state_choice_space[:, 0] == n_periods - 1)[ - 0 - ] - # To solve the second last period, we need the child states in the last period - # and the corresponding matrix, where each row is a state with the state choice - # ids as entry in each choice - idx_states_final_period = np.where(state_space[:, 0] == n_periods - 1)[0] - states_final_period = state_space[idx_states_final_period] - # Now construct a tuple for indexing - n_state_vars = states_final_period.shape[1] - states_tuple = tuple(states_final_period[:, i] for i in range(n_state_vars)) - - # Now get the matrix we use for choice aggregation - state_to_choices_final_period = map_state_choice_to_index[states_tuple] - - # Reindex the state choices in the final period, to have them starting at 0. 
- min_val = int(np.min(idx_state_choice_final_period)) - state_to_choices_final_period -= min_val - - idx_state_choice_second_last_period = np.where( - state_choice_space[:, 0] == n_periods - 2 - )[0] - # Also normalize the state choice idxs - child_states_second_last_period = map_state_choice_to_child_states[ - idx_state_choice_second_last_period - ] - - min_val = int(np.min(idx_states_final_period)) - child_states_second_last_period -= min_val - - # Also add parent states in last period - parent_states_final_period = map_state_choice_to_parent_state[ - idx_state_choice_final_period - ] - - batch_info = { - **batch_info, - "idx_state_choices_final_period": idx_state_choice_final_period, - "idx_state_choices_second_last_period": idx_state_choice_second_last_period, - "idxs_parent_states_final_period": parent_states_final_period, - "state_to_choices_final_period": state_to_choices_final_period, - "child_states_second_last_period": child_states_second_last_period, - } - - # Also add state choice mat as dictionary for each of the two periods - for idx, period_name in [ - (idx_state_choice_final_period, "final"), - (idx_state_choice_second_last_period, "second_last"), - ]: - batch_info[f"state_choice_mat_{period_name}_period"] = { - key: state_choice_space[:, i][idx] - for i, key in enumerate(discrete_states_names + ["choice"]) - } - return batch_info - - -def determine_optimal_batch_size( - state_choice_space, - n_periods, - map_state_choice_to_child_states, - map_state_choice_to_index, - state_space, - out_of_bounds_state_choice_idx, -): - invalid_state_idx = np.iinfo(map_state_choice_to_index.dtype).max - - state_choice_space_wo_last_two = state_choice_space[ - state_choice_space[:, 0] < n_periods - 2 - ] - - # Filter out last period state_choice_ids - child_states_idx_backward = map_state_choice_to_child_states[ - state_choice_space[:, 0] < n_periods - 2 - ] - # Order by child index to solve state choices in the same child states together - 
sort_index_by_child_states = np.argsort(child_states_idx_backward[:, 0]) - - state_choice_index_raw = np.arange( - state_choice_space_wo_last_two.shape[0], dtype=int - ) - state_choice_index_back = np.take( - state_choice_index_raw, sort_index_by_child_states, axis=0 - ) - - n_state_vars = state_space.shape[1] - - size_last_period = state_choice_space[ - state_choice_space[:, 0] == state_choice_space_wo_last_two[-1, 0] - ].shape[0] - - batch_not_found = True - current_batch_size = size_last_period - need_to_reduce_batchsize = False - - while batch_not_found: - if need_to_reduce_batchsize: - current_batch_size = int(current_batch_size * 0.98) - need_to_reduce_batchsize = False - - # Split state choice indexes in - index_to_spilt = np.arange( - current_batch_size, - state_choice_index_back.shape[0], - current_batch_size, - ) - - batches_to_check = np.split( - np.flip(state_choice_index_back), - index_to_spilt, - ) - - child_states_to_integrate_exog = [] - child_state_choices_to_aggr_choice = [] - child_state_choice_idxs_to_interpolate = [] - - for i, batch in enumerate(batches_to_check): - # First get all child states and a mapping from the state-choice to the - # different child states due to exogenous change of states. - child_states_idxs = map_state_choice_to_child_states[batch] - unique_child_states, inverse_ids = np.unique( - child_states_idxs, return_index=False, return_inverse=True - ) - child_states_to_integrate_exog += [ - inverse_ids.reshape(child_states_idxs.shape) - ] - - # Next we use the child state indexes to get all unique child states and - # their corresponding state-choices. 
- child_states_batch = np.take(state_space, unique_child_states, axis=0) - child_states_tuple = tuple( - child_states_batch[:, i] for i in range(n_state_vars) - ) - unique_state_choice_idxs_childs = map_state_choice_to_index[ - child_states_tuple - ] - - # Now we create a mapping from the child-state choices back to the states - # with state-choices in columns for the choices - ( - unique_child_state_choice_idxs, - inverse_child_state_choice_ids, - ) = np.unique( - unique_state_choice_idxs_childs, return_index=False, return_inverse=True - ) - - # Treat invalid choices: - if unique_child_state_choice_idxs[-1] == invalid_state_idx: - unique_child_state_choice_idxs = unique_child_state_choice_idxs[:-1] - inverse_child_state_choice_ids[ - inverse_child_state_choice_ids - >= np.max(inverse_child_state_choice_ids) - ] = out_of_bounds_state_choice_idx - - # Save the mapping from child-state-choices to child-states - child_state_choices_to_aggr_choice += [ - inverse_child_state_choice_ids.reshape( - unique_state_choice_idxs_childs.shape - ) - ] - # And the list of the unique child states. - child_state_choice_idxs_to_interpolate += [unique_child_state_choice_idxs] - - # Now check if the smallest index of the child state choices is larger than - # the maximum index of the batch, i.e. 
if all state choice relevant to - # solve the current state choices of the batch are in previous batches - min_state_choice_idx = np.min(unique_child_state_choice_idxs) - if batch.max() >= min_state_choice_idx: - batch_not_found = True - need_to_reduce_batchsize = True - break - - print("The batch size of the backwards induction is ", current_batch_size) - - if not need_to_reduce_batchsize: - batch_not_found = False - - return ( - batches_to_check, - child_state_choice_idxs_to_interpolate, - child_state_choices_to_aggr_choice, - child_states_to_integrate_exog, - ) diff --git a/src/dcegm/pre_processing/batches/__init__.py b/src/dcegm/pre_processing/batches/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dcegm/pre_processing/batches/algo_batch_size.py b/src/dcegm/pre_processing/batches/algo_batch_size.py new file mode 100644 index 00000000..53ea46c8 --- /dev/null +++ b/src/dcegm/pre_processing/batches/algo_batch_size.py @@ -0,0 +1,130 @@ +import numpy as np + + +def determine_optimal_batch_size( + bool_state_choices_to_batch, + state_choice_space, + map_state_choice_to_child_states, + map_state_choice_to_index, + state_space, +): + # Get invalid state idx, by looking at the index mapping dtype + invalid_state_idx = np.iinfo(map_state_choice_to_index.dtype).max + # Get out of bound state choice idx, by taking the number of state choices + 1 + out_of_bounds_state_choice_idx = state_choice_space.shape[0] + 1 + + state_choice_space_to_batch = state_choice_space[bool_state_choices_to_batch] + + child_states_of_state_choices_to_batch = map_state_choice_to_child_states[ + bool_state_choices_to_batch + ] + # Order by child index to solve state choices in the same child states together + # Use first child state of the n_exog_states of each child states, because + # rows are the same in the child states mapping array + sort_index_by_child_states = np.argsort( + child_states_of_state_choices_to_batch[:, 0] + ) + + idx_state_choice_raw = 
np.where(bool_state_choices_to_batch)[0] + state_choice_index_back = np.take( + idx_state_choice_raw, sort_index_by_child_states, axis=0 + ) + + n_state_vars = state_space.shape[1] + + size_last_period = state_choice_space[ + state_choice_space[:, 0] == state_choice_space_to_batch[-1, 0] + ].shape[0] + + batch_not_found = True + current_batch_size = size_last_period + need_to_reduce_batchsize = False + + while batch_not_found: + if need_to_reduce_batchsize: + current_batch_size = int(current_batch_size * 0.98) + need_to_reduce_batchsize = False + + # Split state choice indexes in + index_to_spilt = np.arange( + current_batch_size, + state_choice_index_back.shape[0], + current_batch_size, + ) + + batches_to_check = np.split( + np.flip(state_choice_index_back), + index_to_spilt, + ) + + child_states_to_integrate_exog = [] + child_state_choices_to_aggr_choice = [] + child_state_choice_idxs_to_interpolate = [] + + for i, batch in enumerate(batches_to_check): + # First get all child states and a mapping from the state-choice to the + # different child states due to exogenous change of states. + child_states_idxs = map_state_choice_to_child_states[batch] + unique_child_states, inverse_ids = np.unique( + child_states_idxs, return_index=False, return_inverse=True + ) + child_states_to_integrate_exog += [ + inverse_ids.reshape(child_states_idxs.shape) + ] + + # Next we use the child state indexes to get all unique child states and + # their corresponding state-choices. 
+ child_states_batch = np.take(state_space, unique_child_states, axis=0) + child_states_tuple = tuple( + child_states_batch[:, i] for i in range(n_state_vars) + ) + unique_state_choice_idxs_childs = map_state_choice_to_index[ + child_states_tuple + ] + + # Now we create a mapping from the child-state choices back to the states + # with state-choices in columns for the choices + ( + unique_child_state_choice_idxs, + inverse_child_state_choice_ids, + ) = np.unique( + unique_state_choice_idxs_childs, return_index=False, return_inverse=True + ) + + # Treat invalid choices: + if unique_child_state_choice_idxs[-1] == invalid_state_idx: + unique_child_state_choice_idxs = unique_child_state_choice_idxs[:-1] + inverse_child_state_choice_ids[ + inverse_child_state_choice_ids + >= np.max(inverse_child_state_choice_ids) + ] = out_of_bounds_state_choice_idx + + # Save the mapping from child-state-choices to child-states + child_state_choices_to_aggr_choice += [ + inverse_child_state_choice_ids.reshape( + unique_state_choice_idxs_childs.shape + ) + ] + # And the list of the unique child states. + child_state_choice_idxs_to_interpolate += [unique_child_state_choice_idxs] + + # Now check if the smallest index of the child state choices is larger than + # the maximum index of the batch, i.e. 
if all state choice relevant to + # solve the current state choices of the batch are in previous batches + min_state_choice_idx = np.min(unique_child_state_choice_idxs) + if batch.max() >= min_state_choice_idx: + batch_not_found = True + need_to_reduce_batchsize = True + break + + print("The batch size of the backwards induction is ", current_batch_size) + + if not need_to_reduce_batchsize: + batch_not_found = False + + return ( + batches_to_check, + child_state_choice_idxs_to_interpolate, + child_state_choices_to_aggr_choice, + child_states_to_integrate_exog, + ) diff --git a/src/dcegm/pre_processing/batches/batch_creation.py b/src/dcegm/pre_processing/batches/batch_creation.py new file mode 100644 index 00000000..26c34021 --- /dev/null +++ b/src/dcegm/pre_processing/batches/batch_creation.py @@ -0,0 +1,135 @@ +from dcegm.pre_processing.batches.last_two_periods import ( + add_last_two_period_information, +) +from dcegm.pre_processing.batches.single_segment import create_single_segment_of_batches + + +def create_batches_and_information( + model_structure, + state_space_options, +): + """Batches are used instead of periods to have chunks of equal sized state choices. + The batch information dictionary contains the following arrays reflecting the. + + steps in the backward induction: + - batches_state_choice_idx: The state choice indexes in each batch to be solved. + To solve the state choices in the egm step, we have to look at the child states + and the corresponding state choice indexes in the child states. For that we save + the following: + - child_state_choice_idxs_to_interp: The state choice indexes in we need to + interpolate the wealth on. + - child_states_idxs: The parent state indexes of the child states, i.e. the + child states themself. We calculate the resources at the beginning of + period before the backwards induction with the budget equation for each + saving and income shock grid point. 
+ + Note: These two index arrays contain indexes on the whole + state/state-choice space. + + Once we have the interpolated values in all possible child state-choice states, + we rearrange them into an array with rows as states and columns as choices to + aggregate over the choices. This is saved in: + + - child_state_choices_to_aggr_choice: The state choice indexes in the child + states to aggregate over. Note these are relative indexes indexing to the + batch arrays from the step before. + Now we have for each child state a value/marginal utility with the index arrays + above and what is missing is the mapping for the exogenous/stochastic processes. + This is saved via: + - child_states_to_integrate_exog: The state choice indexes in the child states + to integrate over the exogenous processes. This is a relative index to the + batch arrays from the step before. + + """ + + n_periods = state_space_options["n_periods"] + + last_two_period_info = add_last_two_period_information( + n_periods=n_periods, + model_structure=model_structure, + ) + + if n_periods == 2: + # In the case of a two period model, we just need the information of the last + # two periods + batch_info = { + "two_period_model": True, + "last_two_period_info": last_two_period_info, + } + + return batch_info + + state_choice_space = model_structure["state_choice_space"] + bool_state_choices_to_batch = state_choice_space[:, 0] < n_periods - 2 + + if "min_period_batch_segments" not in state_space_options.keys(): + + single_batch_segment_info = create_single_segment_of_batches( + bool_state_choices_to_batch, model_structure + ) + segment_infos = { + "n_segments": 1, + "batches_info_segment_0": single_batch_segment_info, + } + + else: + + if isinstance(state_space_options["min_period_batch_segments"], int): + n_segments = 2 + min_periods_to_split = [state_space_options["min_period_batch_segments"]] + elif isinstance(state_space_options["min_period_batch_segments"], list): + n_segments = 
len(state_space_options["min_period_batch_segments"]) + 1 + min_periods_to_split = state_space_options["min_period_batch_segments"] + # Check if periods are increasing and at least two periods apart. + # Also that they are at least two periods smaller than n_periods - 2 + if not all( + min_periods_to_split[i] < min_periods_to_split[i + 1] + for i in range(len(min_periods_to_split) - 1) + ) or not all( + min_periods_to_split[i] < n_periods - 2 - 2 + for i in range(len(min_periods_to_split)) + ): + raise ValueError( + "The periods to split the batches have to be increasing and at least two periods apart." + ) + else: + raise ValueError("So far only int or list separation is supported.") + + segment_infos = { + "n_segments": n_segments, + } + + for id_segment in range(n_segments - 1): + + # Start from the end and assign segments, i.e. segment 0 starts at + # min_periods_to_split[-1] and ends at n_periods - 2 + period_to_split = min_periods_to_split[-id_segment - 1] + + split_cond = state_choice_space[:, 0] < period_to_split + bool_state_choices_segment = bool_state_choices_to_batch & (~split_cond) + + segment_batch_info = create_single_segment_of_batches( + bool_state_choices_segment, model_structure + ) + segment_infos[f"batches_info_segment_{id_segment}"] = segment_batch_info + + # Set the bools to False which have been batched already + bool_state_choices_to_batch = bool_state_choices_to_batch & split_cond + + last_segment_batch_info = create_single_segment_of_batches( + bool_state_choices_to_batch, model_structure + ) + + # We loop until n_segments - 2 and then add the last segment + segment_infos[f"batches_info_segment_{n_segments - 1}"] = ( + last_segment_batch_info + ) + + batch_info = { + # First two bools determining the structure of solution functions we call + "two_period_model": False, + **segment_infos, + "last_two_period_info": last_two_period_info, + } + + return batch_info diff --git a/src/dcegm/pre_processing/batches/last_two_periods.py 
def add_last_two_period_information(
    n_periods,
    model_structure,
):
    """Collect the index arrays needed to solve the last two periods.

    Args:
        n_periods (int): Number of periods of the model.
        model_structure (dict): Must contain "state_choice_space",
            "state_space", "discrete_states_names" (list),
            "map_state_choice_to_parent_state",
            "map_state_choice_to_child_states" and "map_state_choice_to_index".
            The first column of both spaces is read as the period.

    Returns:
        dict: With the keys

            - "idx_state_choices_final_period"
            - "idx_state_choices_second_last_period"
            - "idxs_parent_states_final_period"
            - "state_to_choices_final_period": state choice indexes per final
              period state, shifted by the smallest final period state choice
              index.
            - "child_states_second_last_period": child state indexes of the
              second last period state choices, shifted by the smallest final
              period state index.
            - "state_choice_mat_final_period" and
              "state_choice_mat_second_last_period": dictionaries with one
              array per discrete state name plus "choice".

    """
    state_choice_space = model_structure["state_choice_space"]
    state_space = model_structure["state_space"]
    column_names = model_structure["discrete_states_names"] + ["choice"]
    parent_state_of_state_choice = model_structure["map_state_choice_to_parent_state"]
    child_states_of_state_choice = model_structure["map_state_choice_to_child_states"]
    state_choice_indexer = model_structure["map_state_choice_to_index"]

    final_period = n_periods - 1
    period_of_state_choice = state_choice_space[:, 0]

    final_state_choice_idxs = np.where(period_of_state_choice == final_period)[0]
    second_last_state_choice_idxs = np.where(
        period_of_state_choice == final_period - 1
    )[0]

    # Look up, for every state of the final period, the row of state choice
    # indexes. This is the matrix used for the choice aggregation.
    final_state_idxs = np.where(state_space[:, 0] == final_period)[0]
    final_states = state_space[final_state_idxs]
    state_to_choices_final_period = state_choice_indexer[tuple(final_states.T)]
    # Let the final period state choices start at 0.
    state_to_choices_final_period -= int(np.min(final_state_choice_idxs))

    # Same normalization for the child states of the second last period, this
    # time relative to the first state of the final period.
    child_states_second_last_period = child_states_of_state_choice[
        second_last_state_choice_idxs
    ]
    child_states_second_last_period -= int(np.min(final_state_idxs))

    def state_choice_columns(row_idxs):
        # One array per state variable (and the choice) for the selected rows.
        return {
            name: state_choice_space[:, col][row_idxs]
            for col, name in enumerate(column_names)
        }

    return {
        "idx_state_choices_final_period": final_state_choice_idxs,
        "idx_state_choices_second_last_period": second_last_state_choice_idxs,
        "idxs_parent_states_final_period": parent_state_of_state_choice[
            final_state_choice_idxs
        ],
        "state_to_choices_final_period": state_to_choices_final_period,
        "child_states_second_last_period": child_states_second_last_period,
        "state_choice_mat_final_period": state_choice_columns(
            final_state_choice_idxs
        ),
        "state_choice_mat_second_last_period": state_choice_columns(
            second_last_state_choice_idxs
        ),
    }
def correct_for_uneven_last_batch(
    batches_list,
    child_states_to_integrate_exog_list,
    child_state_choices_to_aggr_choice_list,
    child_state_choice_idxs_to_interp_list,
    state_choice_space,
    map_state_choice_to_parent_state,
    discrete_states_names,
):
    """Split off the last batch if its length differs from the batch before it.

    With a single batch, or if the last two batches have the same length, all
    lists are returned unchanged together with ``True`` and ``None``.
    Otherwise the last element of every list is removed and its information is
    stored in a separate dictionary, so that it can be solved on its own.

    Returns:
        tuple: ``(batches_list, child_states_to_integrate_exog_list,
        child_state_choices_to_aggr_choice_list,
        child_state_choice_idxs_to_interp_list, batches_cover_all,
        last_batch_info)``.

    """
    # Only the last two batches are compared; a single batch always covers all.
    batches_cover_all = len(batches_list) == 1 or len(batches_list[-1]) == len(
        batches_list[-2]
    )

    if batches_cover_all:
        # Nothing to split off, so no extra information is needed.
        return (
            batches_list,
            child_states_to_integrate_exog_list,
            child_state_choices_to_aggr_choice_list,
            child_state_choice_idxs_to_interp_list,
            batches_cover_all,
            None,
        )

    column_names = discrete_states_names + ["choice"]
    uneven_batch = batches_list[-1]
    uneven_child_idxs_interp = child_state_choice_idxs_to_interp_list[-1]

    def state_choice_columns(row_idxs):
        # One array per state variable (and the choice) for the selected rows.
        return {
            name: state_choice_space[:, col][row_idxs]
            for col, name in enumerate(column_names)
        }

    last_batch_info = {
        "state_choice_idx": uneven_batch,
        "state_choices": state_choice_columns(uneven_batch),
        "child_states_to_integrate_exog": child_states_to_integrate_exog_list[-1],
        # Child state infos.
        "child_state_choices_to_aggr_choice": (
            child_state_choices_to_aggr_choice_list[-1]
        ),
        "child_state_choice_idxs_to_interp": uneven_child_idxs_interp,
        "child_states_idxs": map_state_choice_to_parent_state[
            uneven_child_idxs_interp
        ],
        "state_choices_childs": state_choice_columns(uneven_child_idxs_interp),
    }

    # The evenly sized batches are everything but the last element.
    return (
        batches_list[:-1],
        child_states_to_integrate_exog_list[:-1],
        child_state_choices_to_aggr_choice_list[:-1],
        child_state_choice_idxs_to_interp_list[:-1],
        batches_cover_all,
        last_batch_info,
    )
def extend_child_state_choices_to_aggregate_choices(
    idx_to_aggregate_choice,
    max_child_state_index_batch,
    idx_to_interpolate,
    out_of_bounds_state_choice_idx,
):
    """Pad the per-batch child state arrays to one common shape.

    Batches can differ in their number of child states and child state choices.
    To stack them along a leading batch axis, the shorter ones are filled up.

    Args:
        idx_to_aggregate_choice (list): One 2d array per batch with child
            states as rows and choices as columns.
        max_child_state_index_batch (np.ndarray): Per batch the largest child
            state index which is integrated over. Used as consistency check
            against the number of rows in ``idx_to_aggregate_choice``.
        idx_to_interpolate (list): One 1d array per batch with the state choice
            indexes to interpolate on.
        out_of_bounds_state_choice_idx (int): Fill value for the padded entries
            of the aggregation array.

    Returns:
        tuple: ``(child_state_choice_idxs_to_interp,
        child_state_choices_to_aggr_choice)`` as integer arrays of shape
        (n_batches, max_child_state_choices) and
        (n_batches, max_n_child_states, n_choices). The interpolation array is
        padded with the first entry of the first batch.

    Raises:
        ValueError: If the number of child states of a batch does not equal its
            maximum child state index plus one.

    """
    n_child_states_per_batch = [
        aggr_idxs.shape[0] for aggr_idxs in idx_to_aggregate_choice
    ]

    # Internal consistency: the number of rows (child states) of each
    # aggregation array has to match the largest child state index we
    # integrate over in that batch.
    sizes_are_consistent = np.all(
        np.equal(np.array(n_child_states_per_batch) - 1, max_child_state_index_batch)
    )
    if not sizes_are_consistent:
        raise ValueError(
            "\n\nInternal error in the batch creation \n\n. "
            "Please contact developer."
        )

    # Aggregation array: batches x maximum number of child states x choices,
    # initialized with the invalid state choice index.
    n_batches = len(idx_to_aggregate_choice)
    max_n_child_states = np.max(n_child_states_per_batch)
    n_choices = idx_to_aggregate_choice[0].shape[1]
    child_state_choices_to_aggr_choice = np.full(
        (n_batches, max_n_child_states, n_choices),
        fill_value=out_of_bounds_state_choice_idx,
        dtype=int,
    )
    for padded_aggr, aggr_idxs, n_child_states in zip(
        child_state_choices_to_aggr_choice,
        idx_to_aggregate_choice,
        n_child_states_per_batch,
    ):
        # padded_aggr is a view, so this writes into the stacked array.
        padded_aggr[:n_child_states, :] = aggr_idxs

    # Interpolation array: child states can have different admissible state
    # choices, so the length differs across batches. Pad with a dummy index
    # taken from the first batch.
    max_child_state_choices = np.max(
        [len(interp_idxs) for interp_idxs in idx_to_interpolate]
    )
    child_state_choice_idxs_to_interp = np.full(
        (n_batches, max_child_state_choices),
        fill_value=idx_to_interpolate[0][0],
        dtype=int,
    )
    for padded_interp, interp_idxs in zip(
        child_state_choice_idxs_to_interp, idx_to_interpolate
    ):
        padded_interp[: len(interp_idxs)] = interp_idxs

    return child_state_choice_idxs_to_interp, child_state_choices_to_aggr_choice
def check_options_and_set_defaults(options):
    """Check if options are valid and set defaults.

    The options dictionary is modified in place and returned. The following
    entries are added or normalized:

    - ``options["state_space"]["choices"]``: converted to a ``np.uint8`` array
      (default: a single choice with value 0).
    - ``options["model_params"]["n_choices"]``: number of choices, if missing.
    - ``options["n_wealth_grid"]``: length of the wealth grid.
    - ``options["tuning_params"]``: defaults for "extra_wealth_grid_factor"
      (0.2) and "n_constrained_points_to_add" (10% of the wealth grid), plus
      the derived "n_total_wealth_grid".
    - ``options["exog_grids"]``: copy of the continuous states in which a
      second continuous state is stored under the key "second_continuous". In
      that case ``options["second_continuous_state_name"]`` and
      ``options["tuning_params"]["n_second_continuous_grid"]`` are set as well.

    Args:
        options (dict): User options with at least a "state_space" dictionary
            (containing "n_periods" and "continuous_states" with a "wealth"
            grid) and a "model_params" dictionary.

    Returns:
        dict: The same ``options`` object with defaults filled in.

    Raises:
        ValueError: If a required entry is missing or has an invalid type or
            value, or if the tuning parameters contradict each other.

    """
    if not isinstance(options, dict):
        raise ValueError("Options must be a dictionary.")

    if "state_space" not in options:
        raise ValueError("Options must contain a state space dictionary.")
    state_space_options = options["state_space"]
    if not isinstance(state_space_options, dict):
        raise ValueError("State space must be a dictionary.")

    if "n_periods" not in state_space_options:
        raise ValueError("State space must contain the number of periods.")
    if not isinstance(state_space_options["n_periods"], int):
        raise ValueError("Number of periods must be an integer.")
    if not state_space_options["n_periods"] > 1:
        raise ValueError("Number of periods must be greater than 1.")

    if "choices" not in state_space_options:
        print("Choices not given. Assume only single choice with value 0")
        state_space_options["choices"] = np.array([0], dtype=np.uint8)

    choices = state_space_options["choices"]
    if isinstance(choices, int):
        choices = [choices]
    elif not isinstance(choices, (list, np.ndarray)):
        raise ValueError("Choices must be a list, an integer or a numpy array.")
    choices = np.asarray(choices)
    # A plain cast to uint8 would silently wrap values outside of [0, 255] and
    # thereby change the choice set. Reject them instead.
    if (
        choices.dtype.kind in "iuf"
        and choices.size > 0
        and (choices.min() < 0 or choices.max() > np.iinfo(np.uint8).max)
    ):
        raise ValueError(
            "Choices must be between 0 and 255, as they are stored as uint8."
        )
    state_space_options["choices"] = choices.astype(np.uint8)

    if "model_params" not in options:
        raise ValueError("Options must contain a model parameters dictionary.")
    model_params = options["model_params"]
    if not isinstance(model_params, dict):
        raise ValueError("Model parameters must be a dictionary.")

    if "n_choices" not in model_params:
        model_params["n_choices"] = len(state_space_options["choices"])

    # Report a missing wealth grid like every other invalid option instead of
    # leaking a bare KeyError/TypeError.
    try:
        continuous_states = state_space_options["continuous_states"]
        wealth_grid = continuous_states["wealth"]
    except (KeyError, TypeError) as error:
        raise ValueError(
            "State space must contain a continuous states dictionary with a "
            "wealth grid."
        ) from error

    n_savings_grid_points = len(wealth_grid)
    options["n_wealth_grid"] = n_savings_grid_points

    if "tuning_params" not in options:
        options["tuning_params"] = {}
    tuning_params = options["tuning_params"]

    extra_wealth_grid_factor = tuning_params.setdefault(
        "extra_wealth_grid_factor", 0.2
    )
    n_constrained_points_to_add = tuning_params.setdefault(
        "n_constrained_points_to_add", n_savings_grid_points // 10
    )

    # The extended grid has to be large enough to hold the points added in the
    # credit constrained region.
    if (
        n_savings_grid_points * (1 + extra_wealth_grid_factor)
        < n_savings_grid_points + n_constrained_points_to_add
    ):
        raise ValueError(
            f"""\n\n
            When preparing the tuning parameters for the upper
            envelope, we found the following contradicting parameters: \n
            The extra wealth grid factor of {extra_wealth_grid_factor} is too small
            to cover the {n_constrained_points_to_add} wealth points which are added in
            the credit constrained part of the wealth grid. \n\n"""
        )
    tuning_params["n_total_wealth_grid"] = int(
        n_savings_grid_points * (1 + extra_wealth_grid_factor)
    )

    exog_grids = continuous_states.copy()

    if len(continuous_states) == 2:
        # Exactly one state besides wealth exists in this case.
        second_continuous_state_name = next(
            key for key in continuous_states if key != "wealth"
        )
        options["second_continuous_state_name"] = second_continuous_state_name
        tuning_params["n_second_continuous_grid"] = len(
            continuous_states[second_continuous_state_name]
        )
        # Pop before assigning, so that a state which is literally named
        # "second_continuous" is not removed again.
        exog_grids["second_continuous"] = exog_grids.pop(
            second_continuous_state_name
        )

    options["exog_grids"] = exog_grids

    return options
state_space_options = options["state_space"] - model_params = options["model_params"] - - n_periods = state_space_options["n_periods"] - n_choices = len(state_space_options["choices"]) - - ( - add_endog_state_func, - endog_states_names, - n_endog_states, - sparsity_func, - ) = process_endog_state_specifications( - state_space_options=state_space_options, model_params=model_params - ) - - ( - exog_states_names, - exog_state_space, - ) = process_exog_model_specifications(state_space_options=state_space_options) - - states_names_without_exog = ["period", "lagged_choice"] + endog_states_names - - state_space_wo_exog_list = [] - is_feasible_list = [] - - for period in range(n_periods): - for endog_state_id in range(n_endog_states): - for lagged_choice in range(n_choices): - # Select the endogenous state combination - endog_states = add_endog_state_func(endog_state_id) - - # Create the state vector without the exogenous processes - state_without_exog = [period, lagged_choice] + endog_states - state_space_wo_exog_list += [state_without_exog] - - # Transform to dictionary to call sparsity function from user - state_dict_without_exog = { - states_names_without_exog[i]: state_value - for i, state_value in enumerate(state_without_exog) - } - - is_state_feasible = sparsity_func(**state_dict_without_exog) - is_feasible_list += [is_state_feasible] - - n_exog_states = exog_state_space.shape[0] - state_space_wo_exog = np.array(state_space_wo_exog_list) - state_space_wo_exog_full = np.repeat(state_space_wo_exog, n_exog_states, axis=0) - exog_state_space_full = np.tile(exog_state_space, (state_space_wo_exog.shape[0], 1)) - - state_space = np.concatenate( - (state_space_wo_exog_full, exog_state_space_full), axis=1 - ) - - state_space_df = pd.DataFrame( - state_space, columns=states_names_without_exog + exog_states_names - ) - is_feasible_array = np.array(is_feasible_list, dtype=bool) - - state_space_df["is_feasible"] = np.repeat(is_feasible_array, n_exog_states, axis=0) - - return 
state_space_df diff --git a/src/dcegm/pre_processing/model_functions.py b/src/dcegm/pre_processing/model_functions.py index cac4f7f9..e800d655 100644 --- a/src/dcegm/pre_processing/model_functions.py +++ b/src/dcegm/pre_processing/model_functions.py @@ -3,7 +3,9 @@ import jax.numpy as jnp from upper_envelope.fues_jax.fues_jax import fues_jax -from dcegm.pre_processing.exog_processes import create_exog_transition_function +from dcegm.pre_processing.model_structure.exogenous_processes import ( + create_exog_transition_function, +) from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options @@ -98,12 +100,16 @@ def process_model_functions( ) # Now state space functions - get_state_specific_choice_set, get_next_period_state, update_continuous_state = ( + state_specific_choice_set, next_period_endogenous_state, sparsity_condition = ( process_state_space_functions( state_space_functions, options, continuous_state_name ) ) + next_period_continuous_state = process_second_continuous_update_function( + continuous_state_name, state_space_functions, options + ) + # Budget equation compute_beginning_of_period_wealth = ( determine_function_arguments_and_partial_options( @@ -126,11 +132,12 @@ def process_model_functions( "compute_utility_final": compute_utility_final, "compute_marginal_utility_final": compute_marginal_utility_final, "compute_beginning_of_period_wealth": compute_beginning_of_period_wealth, - "update_continuous_state": update_continuous_state, + "next_period_continuous_state": next_period_continuous_state, + "sparsity_condition": sparsity_condition, "compute_exog_transition_vec": compute_exog_transition_vec, "processed_exog_funcs": processed_exog_funcs_dict, - "get_state_specific_choice_set": get_state_specific_choice_set, - "get_next_period_state": get_next_period_state, + "state_specific_choice_set": state_specific_choice_set, + "next_period_endogenous_state": next_period_endogenous_state, "compute_upper_envelope": 
compute_upper_envelope, } @@ -145,65 +152,74 @@ def process_state_space_functions( {} if state_space_functions is None else state_space_functions ) - if "get_state_specific_choice_set" not in state_space_functions: + if "state_specific_choice_set" not in state_space_functions: print( "State specific choice set not provided. Assume all choices are " "available in every state." ) - def get_state_specific_choice_set(**kwargs): + def state_specific_choice_set(**kwargs): return jnp.array(options["state_space"]["choices"]) else: - get_state_specific_choice_set = ( - determine_function_arguments_and_partial_options( - func=state_space_functions["get_state_specific_choice_set"], - options=options["model_params"], - continuous_state_name=continuous_state_name, - ) + state_specific_choice_set = determine_function_arguments_and_partial_options( + func=state_space_functions["state_specific_choice_set"], + options=options["model_params"], + continuous_state_name=continuous_state_name, ) - if "get_next_period_state" not in state_space_functions: + if "next_period_endogenous_state" not in state_space_functions: print( "Update function for state space not given. Assume states only change " "with an increase of the period and lagged choice." 
) - def get_next_period_state(**kwargs): + def next_period_endogenous_state(**kwargs): return {"period": kwargs["period"] + 1, "lagged_choice": kwargs["choice"]} else: - get_next_period_state = determine_function_arguments_and_partial_options( - func=state_space_functions["get_next_period_state"], + next_period_endogenous_state = determine_function_arguments_and_partial_options( + func=state_space_functions["next_period_endogenous_state"], options=options["model_params"], continuous_state_name=continuous_state_name, ) - if continuous_state_name is not None: - func_name = next( - ( - key - for key in state_space_functions - for name in [ - "continuous_state", - continuous_state_name, - ] - if f"get_next_period_{name}" in key - or f"get_next_{name}" in key - or f"update_{name}" in key - ), - None, + sparsity_condition = process_sparsity_condition(state_space_functions, options) + + return state_specific_choice_set, next_period_endogenous_state, sparsity_condition + + +def process_sparsity_condition(state_space_functions, options): + if "sparsity_condition" in state_space_functions.keys(): + sparsity_condition = determine_function_arguments_and_partial_options( + func=state_space_functions["sparsity_condition"], + options=options["model_params"], ) + # ToDo: Error if sparsity condition takes second continuous state as input + else: + print("Sparsity condition not provided. 
def process_second_continuous_update_function(
    continuous_state_name, state_space_functions, options
):
    """Prepare the user function which updates the second continuous state.

    The function is looked up in ``state_space_functions`` under the key
    ``f"next_period_{continuous_state_name}"`` and processed with the model
    parameters from ``options["model_params"]``.

    Returns:
        The processed update function, or ``None`` if ``continuous_state_name``
        is ``None``.

    """
    # Without a second continuous state there is nothing to update.
    if continuous_state_name is None:
        return None

    # NOTE(review): a missing key or ``state_space_functions=None`` surfaces
    # here as a plain KeyError/TypeError; consider a dedicated error message.
    user_update_func = state_space_functions[f"next_period_{continuous_state_name}"]

    return determine_function_arguments_and_partial_options(
        func=user_update_func,
        options=options["model_params"],
        continuous_state_name=continuous_state_name,
    )
def process_exog_model_specifications(state_space_options):
    """Read the exogenous process names and span their joint state space.

    Args:
        state_space_options (dict): May contain "exogenous_processes", a
            dictionary which holds for each process a dictionary with the
            entry "states".

    Returns:
        tuple: ``(exog_state_names, exog_state_space)``. Without exogenous
        processes a single dummy process "dummy_exog" with the only state 0 is
        returned.

    """
    if "exogenous_processes" not in state_space_options:
        # Dummy process, so there is always an exogenous dimension.
        return ["dummy_exog"], np.array([[0]], dtype=np.uint8)

    exog_processes = state_space_options["exogenous_processes"]
    exog_state_names = list(exog_processes)

    # Only the states of each process are needed to span the space.
    states_by_process = {name: exog_processes[name]["states"] for name in exog_state_names}
    exog_state_space = span_subspace(
        subdict_of_space=states_by_process,
        states_names=exog_state_names,
    )

    return exog_state_names, exog_state_space
dcegm.pre_processing.shared import create_array_with_smallest_int_dtype + + +def create_model_structure( + options, + model_funcs, +): + """Create dictionary of discrete state and state-choice objects for each period. + + Args: + options (Dict[str, int]): Options dictionary. + + Returns: + dict of np.ndarray: Dictionary containing period-specific + state and state-choice objects, with the following keys: + - "state_choice_mat" (np.ndarray) + - "idx_state_of_state_choice" (np.ndarray) + - "reshape_state_choice_vec_to_mat" (callable) + - "transform_between_state_and_state_choice_vec" (callable) + + """ + print("Starting state space creation") + state_space_objects = create_state_space( + state_space_options=options["state_space"], + sparsity_condition=model_funcs["sparsity_condition"], + debugging=False, + ) + print("State space created.\n") + print("Starting state-choice space creation and child state mapping.") + + state_choice_and_child_state_objects = ( + create_state_choice_space_and_child_state_mapping( + state_space_options=options["state_space"], + state_specific_choice_set=model_funcs["state_specific_choice_set"], + next_period_endogenous_state=model_funcs["next_period_endogenous_state"], + state_space_arrays=state_space_objects, + ) + ) + state_space_objects.pop("map_child_state_to_index") + + model_structure = { + **state_space_objects, + **state_choice_and_child_state_objects, + "choice_range": jnp.asarray(options["state_space"]["choices"]), + } + return jax.tree.map(create_array_with_smallest_int_dtype, model_structure) diff --git a/src/dcegm/pre_processing/model_structure/shared.py b/src/dcegm/pre_processing/model_structure/shared.py new file mode 100644 index 00000000..19cd912e --- /dev/null +++ b/src/dcegm/pre_processing/model_structure/shared.py @@ -0,0 +1,35 @@ +import numpy as np + +from dcegm.pre_processing.shared import get_smallest_int_type + + +def create_indexer_for_space(space): + """Create indexer for space.""" + + # Indexer has always 
unsigned data type with integers starting at zero + # Leave one additional value for the invalid number + data_type = get_smallest_int_type(space.shape[0] + 1) + max_value = np.iinfo(data_type).max + + max_var_values = np.max(space, axis=0) + + map_vars_to_index = np.full( + max_var_values + 1, fill_value=max_value, dtype=data_type + ) + index_tuple = tuple(space[:, i] for i in range(space.shape[1])) + + map_vars_to_index[index_tuple] = np.arange(space.shape[0], dtype=data_type) + + return map_vars_to_index, max_value + + +def span_subspace(subdict_of_space, states_names): + """Span subspace and read information from dictionary.""" + # Retrieve all state arrays from the dictionary + states = [np.array(subdict_of_space[name]) for name in states_names] + + # Use np.meshgrid to get all combinations, then reshape and stack them + grids = np.meshgrid(*states, indexing="ij") + space = np.column_stack([grid.ravel() for grid in grids]) + + return space diff --git a/src/dcegm/pre_processing/model_structure/state_choice_space.py b/src/dcegm/pre_processing/model_structure/state_choice_space.py new file mode 100644 index 00000000..21fcebc3 --- /dev/null +++ b/src/dcegm/pre_processing/model_structure/state_choice_space.py @@ -0,0 +1,302 @@ +import numpy as np + +from dcegm.pre_processing.model_structure.shared import create_indexer_for_space +from dcegm.pre_processing.shared import get_smallest_int_type + + +def create_state_choice_space_and_child_state_mapping( + state_space_options, + state_specific_choice_set, + next_period_endogenous_state, + state_space_arrays, +): + """Create state choice space of all feasible state-choice combinations. + + Also conditional on any realization of exogenous processes. + + Args: + state_space (np.ndarray): 2d array of shape (n_states, n_state_vars + 1) + which serves as a collection of all possible states. By convention, + the first column must contain the period and the last column the + exogenous state. 
Any other state variables are in between. + E.g. if the two state variables are period and lagged choice and all choices + are admissible in each period, the shape of the state space array is + (n_periods * n_choices, 3). + map_state_to_state_space_index (np.ndarray): Indexer array that maps + a period-specific state vector to the respective index positions in the + state space. + The shape of this object is quite complicated. For each state variable it + has the number of potential states as rows, i.e. + (n_potential_states_state_var_1, n_potential_states_state_var_2, ....). + state_specific_choice_set (Callable): User-supplied function that returns + the set of feasible choices for a given state. + + Returns: + tuple: + + - state_choice_space(np.ndarray): 2d array of shape + (n_feasible_state_choice_combs, n_state_and_exog_variables + 1) containing + the space of all feasible state-choice combinations. By convention, + the second to last column contains the exogenous process. + The last column always contains the choice to be made at the end of the + period (which is not a state variable). + - map_state_choice_vec_to_parent_state (np.ndarray): 1d array of shape + (n_states * n_feasible_choices,) that maps from any vector of state-choice + combinations to the respective parent state. + - reshape_state_choice_vec_to_mat (np.ndarray): 2d array of shape + (n_states, n_feasible_choices). For each parent state, this array can be + used to reshape the vector of feasible state-choice combinations + to a matrix of lagged and current choice combinations of + shape (n_choices, n_choices). + - transform_between_state_and_state_choice_space (np.ndarray): 2d boolean + array of shape (n_states, n_states * n_feasible_choices) indicating which + state belongs to which state-choice combination in the entire state and + state choice space. The array is used to + (i) contract state-choice level arrays to the state level by summing + over state-choice combinations. 
+ (ii) to expand state level arrays to the state-choice level. + + """ + + states_names_without_exog = state_space_arrays["state_names_without_exog"] + exog_state_names = state_space_arrays["exog_states_names"] + exog_state_space = state_space_arrays["exog_state_space"] + map_child_state_to_index = state_space_arrays["map_child_state_to_index"] + map_state_to_index = state_space_arrays["map_state_to_index"] + state_space = state_space_arrays["state_space"] + + n_states, n_state_and_exog_variables = state_space.shape + n_exog_states, n_exog_vars = exog_state_space.shape + n_choices = len(state_space_options["choices"]) + discrete_states_names = states_names_without_exog + exog_state_names + n_periods = state_space_options["n_periods"] + + dtype_exog_state_space = get_smallest_int_type(n_exog_states) + + # Get dtype and maxint for choices + dtype_choices = get_smallest_int_type(n_choices) + # Get dtype and max int for state space + state_space_dtype = state_space.dtype + + if np.iinfo(state_space_dtype).max > np.iinfo(dtype_choices).max: + state_choice_space_dtype = state_space_dtype + else: + state_choice_space_dtype = dtype_choices + + state_choice_space_raw = np.zeros( + (n_states * n_choices, n_state_and_exog_variables + 1), + dtype=state_choice_space_dtype, + ) + + state_space_indexer_dtype = map_child_state_to_index.dtype + invalid_indexer_idx = np.iinfo(state_space_indexer_dtype).max + + map_state_choice_to_parent_state = np.zeros( + (n_states * n_choices), dtype=state_space_indexer_dtype + ) + + map_state_choice_to_child_states = np.full( + (n_states * n_choices, n_exog_states), + fill_value=invalid_indexer_idx, + dtype=state_space_indexer_dtype, + ) + + exog_states_tuple = tuple(exog_state_space[:, i] for i in range(n_exog_vars)) + + idx = 0 + for state_vec in state_space: + state_idx = map_state_to_index[tuple(state_vec)] + + # Full state dictionary + this_period_state = { + key: state_vec[i] for i, key in enumerate(discrete_states_names) + } + + 
feasible_choice_set = state_specific_choice_set( + **this_period_state, + ) + + for choice in feasible_choice_set: + state_choice_space_raw[idx, :-1] = state_vec + state_choice_space_raw[idx, -1] = choice + + map_state_choice_to_parent_state[idx] = state_idx + + if state_vec[0] < n_periods - 1: + + endog_state_update = next_period_endogenous_state( + **this_period_state, choice=choice + ) + + check_endog_update_function( + endog_state_update, this_period_state, choice, exog_state_names + ) + + next_period_state = this_period_state.copy() + next_period_state.update(endog_state_update) + + next_period_state_tuple_wo_exog = tuple( + np.full( + n_exog_states, + fill_value=next_period_state[key], + dtype=dtype_exog_state_space, + ) + for key in states_names_without_exog + ) + + states_next_tuple = next_period_state_tuple_wo_exog + exog_states_tuple + + try: + child_idxs = map_child_state_to_index[states_next_tuple] + except: + raise IndexError( + f"\n\n The state \n\n{endog_state_update}\n\n is a child state of " + f"the state-choice combination \n\n{this_period_state}\n\n with choice: " + f"{choice}.\n\n The state variables are out of bounds for the defined state space " + f"Please check the possible state values in the state space definition." + ) + + invalid_child_state_idxs = np.where(child_idxs == invalid_indexer_idx)[ + 0 + ] + if len(invalid_child_state_idxs) > 0: + invalid_child_state_example = np.array(states_next_tuple).T[ + invalid_child_state_idxs[0] + ] + invalid_child_state_dict = { + key: invalid_child_state_example[i] + for i, key in enumerate(discrete_states_names) + } + raise IndexError( + f"\n\n The state \n\n{invalid_child_state_dict}\n\n is a child state of " + f"the state \n\n{this_period_state}\n\n with choice: {choice}.\n\n " + f"It is also declared invalid by the sparsity condition. 
Please " + f"remember, that if a state is invalid because it can't be reached by the deterministic" + f"update of states, this has to be reflected in the state space function next_period_endogenous_state." + f"If its exogenous state realization is invalid, this state has to be proxied to another state" + f"by the sparsity condition." + ) + + map_state_choice_to_child_states[idx, :] = child_idxs + + idx += 1 + + # Select only needed rows of arrays + state_choice_space = state_choice_space_raw[:idx] + map_state_choice_to_parent_state = map_state_choice_to_parent_state[:idx] + map_state_choice_to_child_states = map_state_choice_to_child_states[:idx, :] + + map_state_choice_to_index, _ = create_indexer_for_space(state_choice_space) + + state_choice_space_dict = { + key: state_choice_space[:, i] + for i, key in enumerate(discrete_states_names + ["choice"]) + } + + test_child_state_mapping( + state_space_options=state_space_options, + state_choice_space=state_choice_space, + state_space=state_space, + map_state_choice_to_child_states=map_state_choice_to_child_states, + discrete_states_names=discrete_states_names, + ) + + dict_of_state_choice_space_objects = { + "state_choice_space": state_choice_space, + "state_choice_space_dict": state_choice_space_dict, + "map_state_choice_to_index": map_state_choice_to_index, + "map_state_choice_to_parent_state": map_state_choice_to_parent_state, + "map_state_choice_to_child_states": map_state_choice_to_child_states, + } + + return dict_of_state_choice_space_objects + + +def test_child_state_mapping( + state_space_options, + state_choice_space, + state_space, + map_state_choice_to_child_states, + discrete_states_names, +): + """Test state space objects for consistency.""" + n_periods = state_space_options["n_periods"] + state_choices_idxs_wo_last = np.where(state_choice_space[:, 0] < n_periods - 1)[0] + + # Check if all feasible state choice combinations have a valid child state + idxs_child_states = 
map_state_choice_to_child_states[state_choices_idxs_wo_last, :] + + # Get dtype and max int for state space indexer + state_space_indexer_dtype = map_state_choice_to_child_states.dtype + invalid_state_space_idx = np.iinfo(state_space_indexer_dtype).max + + if np.any(idxs_child_states == invalid_state_space_idx): + # Get row axis of child states that are invalid + invalid_child_states = np.unique( + np.where(idxs_child_states == invalid_state_space_idx)[0] + ) + invalid_state_choices_example = state_choice_space[invalid_child_states[0]] + example_dict = { + key: invalid_state_choices_example[i] + for i, key in enumerate(discrete_states_names) + } + example_dict["choice"] = invalid_state_choices_example[-1] + raise ValueError( + f"\n\n\n\n Some state-choice combinations have invalid child " + f"states. Please update accordingly the deterministic law of motion or" + f"the proxy function." + f"\n \n An example of a combination of state and choice with " + f"invalid child states is: \n \n" + f"{example_dict} \n \n" + ) + + # Check if all states are a child states except the ones in the first period + idxs_states_except_first = np.where(state_space[:, 0] > 0)[0] + idxs_states_except_first_in_child_states = np.isin( + idxs_states_except_first, idxs_child_states + ) + if not np.all(idxs_states_except_first_in_child_states): + not_child_state_idxs = idxs_states_except_first[ + ~idxs_states_except_first_in_child_states + ] + not_child_state_example = state_space[not_child_state_idxs[0]] + example_dict = { + key: not_child_state_example[i] + for i, key in enumerate(discrete_states_names) + } + raise ValueError( + f"\n\n\n\n Some states are not child states of any state-choice " + f"combination or stochastic transition. Please revisit the sparsity condition. 
\n \n" + f"An example of a state that is not a child state is: \n \n" + f"{example_dict} \n \n" + ) + + +def check_endog_update_function( + endog_state_update, this_period_state, choice, exog_state_names +): + """Conduct several checks on the endogenous state update function.""" + if endog_state_update["period"] != this_period_state["period"] + 1: + raise ValueError( + f"\n\n The update function does not return the correct next period count." + f"An example of this update happens with the state choice combination: \n\n" + f"{this_period_state} \n\n" + ) + + if endog_state_update["lagged_choice"] != choice: + raise ValueError( + f"\n\n The update function does not return the correct lagged choice for a given choice." + f"An example of this update happens with the state choice combination: \n\n" + f"{this_period_state} \n\n" + ) + + # Check if exogenous state is updated. This is forbidden. + for exog_state_name in exog_state_names: + if exog_state_name in endog_state_update.keys(): + raise ValueError( + f"\n\n The exogenous state {exog_state_name} is also updated (or just returned)" + f"for in the endogenous update function. You can use the proxy function to implement" + f"a custom update rule, i.e. redirecting the exogenous process." 
+ f"An example of this update happens with the state choice combination: \n\n" + f"{this_period_state} \n\n" + ) diff --git a/src/dcegm/pre_processing/model_structure/state_space.py b/src/dcegm/pre_processing/model_structure/state_space.py new file mode 100644 index 00000000..8a895f0c --- /dev/null +++ b/src/dcegm/pre_processing/model_structure/state_space.py @@ -0,0 +1,270 @@ +"""Functions for creating internal state space objects.""" + +import numpy as np +import pandas as pd + +from dcegm.pre_processing.model_structure.endogenous_states import ( + process_endog_state_specifications, +) +from dcegm.pre_processing.model_structure.exogenous_processes import ( + process_exog_model_specifications, +) +from dcegm.pre_processing.model_structure.shared import create_indexer_for_space +from dcegm.pre_processing.shared import create_array_with_smallest_int_dtype + + +def create_state_space(state_space_options, sparsity_condition, debugging=False): + """Create state space object and indexer. + + We need to add the convention for the state space objects. + + Args: + options (dict): Options dictionary. + + Returns: + Dict: + + - state_vars (list): List of state variables. + - state_space (np.ndarray): 2d array of shape (n_states, n_state_variables + 1) + which serves as a collection of all possible states. By convention, + the first column must contain the period and the last column the + exogenous processes. Any other state variables are in between. + E.g. if the two state variables are period and lagged choice and all choices + are admissible in each period, the shape of the state space array is + (n_periods * n_choices, 3). + - map_state_to_index (np.ndarray): Indexer array that maps states to indexes. + The shape of this object is quite complicated. For each state variable it + has the number of possible states as rows, i.e. + (n_poss_states_state_var_1, n_poss_states_state_var_2, ....). 
+ + """ + n_periods = state_space_options["n_periods"] + n_choices = len(state_space_options["choices"]) + + ( + endog_state_space, + endog_states_names, + ) = process_endog_state_specifications(state_space_options=state_space_options) + state_names_without_exog = ["period", "lagged_choice"] + endog_states_names + + ( + exog_states_names, + exog_state_space_raw, + ) = process_exog_model_specifications(state_space_options=state_space_options) + discrete_states_names = state_names_without_exog + exog_states_names + + n_exog_states = exog_state_space_raw.shape[0] + + state_space_list = [] + list_of_states_proxied_from = [] + list_of_states_proxied_to = [] + proxies_exist = False + + # For debugging we create some additional containers + full_state_space_list = [] + proxy_list = [] + valid_list = [] + + for period in range(n_periods): + for endog_state_id in range(endog_state_space.shape[0]): + for lagged_choice in range(n_choices): + # Select the endogenous state, if present + if len(endog_states_names) == 0: + endog_states = [] + else: + endog_states = list(endog_state_space[endog_state_id]) + + for exog_state_id in range(n_exog_states): + exog_states = exog_state_space_raw[exog_state_id, :] + + # Create the state vector + state = [period, lagged_choice] + endog_states + list(exog_states) + + full_state_space_list += [state] + + # Transform to dictionary to call sparsity function from user + state_dict = { + discrete_states_names[i]: state_value + for i, state_value in enumerate(state) + } + + # Check if the state is valid by calling the sparsity function + sparsity_output = sparsity_condition(**state_dict) + + # The sparsity condition can either return a boolean indicating if the state + # is valid or not, or a dictionary which contains the valid state which is used + # instead as a child state for other states. 
If a state is invalid because of the + # exogenous state component, the user must specify a valid state to use instead, as + # we assume a state choice combination has n_exog_states children. + # We do check later if the user correctly specified the proxy state. Here we just check + # the format of the output. To simplify this specification the user can also return the same + # state as used as input. Then the state is just valid. This allows to easier define a proxy + # state for a whole set of states. + if isinstance(sparsity_output, dict): + # Check if dictionary keys are the same + if set(sparsity_output.keys()) != set(discrete_states_names): + raise ValueError( + f" The state \n\n{sparsity_output}\n\n returned by the sparsity condition " + f"does not have the correct format. The dictionary keys should be the same as " + f"the discrete state names defined in the state space options. These are" + f": \n\n{discrete_states_names}\n\n." + ) + + # Check if each value is integer or array with dtype int + for key, value in sparsity_output.items(): + if isinstance(value, int) or np.issubdtype( + value.dtype, np.integer + ): + pass + else: + raise ValueError( + f"The value of the key {key} in the state \n\n{sparsity_output}\n\n" + f"returned by the sparsity condition is not of integer type." 
+ ) + + # Now check if the state is actually the same as the input state + is_same_state = True + for key, value in sparsity_output.items(): + same_value = state_dict[key] == value + is_same_state &= same_value + + if is_same_state: + state_is_valid = True + proxy_list += [False] + else: + proxy_list += [True] + state_is_valid = False + proxies_exist = True + list_of_states_proxied_from += [state] + state_list_proxied_to = [ + sparsity_output[key] for key in discrete_states_names + ] + list_of_states_proxied_to += [state_list_proxied_to] + elif isinstance(sparsity_output, bool): + state_is_valid = sparsity_output + proxy_list += [False] + else: + raise ValueError( + f"The sparsity condition for the state \n\n{state_dict}\n\n" + f"returned an output of the wrong type. It should return either a boolean" + f"or a dictionary." + ) + + valid_list += [state_is_valid] + if state_is_valid: + state_space_list += [state] + + state_space_raw = np.array(state_space_list) + state_space = create_array_with_smallest_int_dtype(state_space_raw) + map_state_to_index, invalid_index = create_indexer_for_space(state_space) + + if proxies_exist: + # If proxies exist we create a different indexer, to map + # the child states of state choices later to proxied states + map_state_to_index_with_proxies = create_indexer_inclucing_proxies( + map_state_to_index, + list_of_states_proxied_from, + list_of_states_proxied_to, + discrete_states_names, + invalid_index, + ) + map_child_state_to_index = map_state_to_index_with_proxies + else: + map_child_state_to_index = map_state_to_index + + state_space_dict = { + key: create_array_with_smallest_int_dtype(state_space[:, i]) + for i, key in enumerate(discrete_states_names) + } + + exog_state_space = create_array_with_smallest_int_dtype(exog_state_space_raw) + + dict_of_state_space_objects = { + "state_space": state_space, + "state_space_dict": state_space_dict, + "map_state_to_index": map_state_to_index, + "map_child_state_to_index": 
map_child_state_to_index, + "exog_state_space": exog_state_space, + "exog_states_names": exog_states_names, + "state_names_without_exog": state_names_without_exog, + "discrete_states_names": discrete_states_names, + } + + # If debugging is called we create a dataframe with detailed information on + # full state space + if debugging: + state_space_full = np.array(full_state_space_list) + debug_df = pd.DataFrame(data=state_space_full, columns=discrete_states_names) + debug_df["is_valid"] = valid_list + debug_df["is_proxied"] = proxy_list + + if proxies_exist: + array_of_states_proxied_to = np.array(list_of_states_proxied_to) + tuple_of_states_proxied_from = tuple( + array_of_states_proxied_to[:, i] + for i in range(array_of_states_proxied_to.shape[1]) + ) + full_indexer, _ = create_indexer_for_space(state_space_full) + idxs_proxied_to = full_indexer[tuple_of_states_proxied_from] + debug_df["idxs_proxied_to"] = -9999 + debug_df.loc[debug_df["is_proxied"], "idxs_proxied_to"] = idxs_proxied_to + + return debug_df + + return dict_of_state_space_objects + + +def create_indexer_inclucing_proxies( + map_state_to_index, + list_of_states_proxied_from, + list_of_states_proxied_to, + discrete_state_names, + invalid_index, +): + """Create an indexer that includes the index of proxied invalid states.""" + array_of_states_proxied_from = np.array(list_of_states_proxied_from) + array_of_states_proxied_to = np.array(list_of_states_proxied_to) + + tuple_of_states_proxied_from = tuple( + array_of_states_proxied_from[:, i] + for i in range(array_of_states_proxied_from.shape[1]) + ) + tuple_of_states_proxied_to = tuple( + array_of_states_proxied_to[:, i] + for i in range(array_of_states_proxied_to.shape[1]) + ) + index_proxy_to = map_state_to_index[tuple_of_states_proxied_to] + invalid_proxy_idxs = np.where(index_proxy_to == invalid_index)[0] + if len(invalid_proxy_idxs) > 0: + example_state_proxy_to = array_of_states_proxied_to[invalid_proxy_idxs[0]] + invalid_state_dict_to = { + 
state_name: example_state_proxy_to[i] + for i, state_name in enumerate(discrete_state_names) + } + example_state_proxy_from = array_of_states_proxied_from[invalid_proxy_idxs[0]] + invalid_state_dict_from = { + state_name: example_state_proxy_from[i] + for i, state_name in enumerate(discrete_state_names) + } + + import sys + + RED = "\033[31m" # ANSI code for red + RESET = "\033[0m" # ANSI code to reset color + try: + raise ValueError( + f"\n\nThe state " + f"\n\n{pd.Series(invalid_state_dict_to).to_string()}\n\n" + f"is used as a proxy state for the state:" + f"\n\n{pd.Series(invalid_state_dict_from).to_string()}\n\n" + f"However, the proxy state is also declared invalid by " + "the sparsity condition. This is not allowed. The proxy state must be valid." + ) + except ValueError as e: + print(f"\n\n{RED}State space error:{RESET} {e}", file=sys.stderr) + sys.exit(1) # Exit without showing the traceback + + map_state_to_index_with_proxies = map_state_to_index.copy() + + map_state_to_index_with_proxies[tuple_of_states_proxied_from] = index_proxy_to + return map_state_to_index_with_proxies diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py index b6021d28..3046deb0 100644 --- a/src/dcegm/pre_processing/setup_model.py +++ b/src/dcegm/pre_processing/setup_model.py @@ -2,16 +2,19 @@ from typing import Callable, Dict import jax -import jax.numpy as jnp - -from dcegm.pre_processing.batches import create_batches_and_information -from dcegm.pre_processing.exog_processes import create_exog_state_mapping -from dcegm.pre_processing.model_functions import process_model_functions -from dcegm.pre_processing.state_space import ( - check_options_and_set_defaults, - create_array_with_smallest_int_dtype, - create_discrete_state_space_and_choice_objects, + +from dcegm.pre_processing.batches.batch_creation import create_batches_and_information +from dcegm.pre_processing.check_options import check_options_and_set_defaults +from 
dcegm.pre_processing.model_functions import ( + process_model_functions, + process_sparsity_condition, +) +from dcegm.pre_processing.model_structure.exogenous_processes import ( + create_exog_state_mapping, ) +from dcegm.pre_processing.model_structure.model_structure import create_model_structure +from dcegm.pre_processing.model_structure.state_space import create_state_space +from dcegm.pre_processing.shared import create_array_with_smallest_int_dtype def setup_model( @@ -20,6 +23,7 @@ def setup_model( utility_functions_final_period: Dict[str, Callable], budget_constraint: Callable, state_space_functions: Dict[str, Callable] = None, + debug_output: str = None, ): """Set up the model for dcegm. @@ -46,6 +50,10 @@ def setup_model( budget_constraint (Callable): User supplied budget constraint. """ + debug_output = process_debug_string(debug_output, state_space_functions, options) + if debug_output is not None: + return debug_output + options = check_options_and_set_defaults(options) model_funcs = process_model_functions( @@ -56,7 +64,7 @@ def setup_model( budget_constraint=budget_constraint, ) - model_structure = create_discrete_state_space_and_choice_objects( + model_structure = create_model_structure( options=options, model_funcs=model_funcs, ) @@ -66,13 +74,17 @@ def setup_model( model_structure["exog_states_names"], ) + print("State, state-choice and child state mapping created.\n") + print("Start creating batches for the model.") + batch_info = create_batches_and_information( model_structure=model_structure, - options=options, + state_space_options=options["state_space"], ) # Delete large array which is not needed model_structure.pop("map_state_choice_to_child_states") + print("Model setup complete.\n") return { "options": options, "model_funcs": model_funcs, @@ -96,7 +108,6 @@ def setup_and_save_model( than recreating the model from scratch. 
""" - model = setup_model( options=options, state_space_functions=state_space_functions, @@ -142,3 +153,16 @@ def load_and_setup_model( ) return model + + +def process_debug_string(debug_output, state_space_functions, options): + if debug_output is not None: + if debug_output == "state_space_df": + sparsity_condition = process_sparsity_condition( + state_space_functions, options + ) + return create_state_space( + options["state_space"], sparsity_condition, debugging=True + ) + else: + raise ValueError("The requested debug output is not implemented.") diff --git a/src/dcegm/pre_processing/shared.py b/src/dcegm/pre_processing/shared.py index af085d4d..01cbd0eb 100644 --- a/src/dcegm/pre_processing/shared.py +++ b/src/dcegm/pre_processing/shared.py @@ -2,6 +2,9 @@ import inspect from functools import partial +import numpy as np +from jax import numpy as jnp + def determine_function_arguments_and_partial_options( func, options, continuous_state_name=None @@ -34,3 +37,22 @@ def partial_options_and_update_signature(func, signature, options): signature = signature - {"options"} return func, signature + + +def create_array_with_smallest_int_dtype(arr): + """Return array with the smallest unsigned integer dtype.""" + if isinstance(arr, (np.ndarray, jnp.ndarray)) and np.issubdtype( + arr.dtype, np.integer + ): + return arr.astype(get_smallest_int_type(arr.max())) + + return arr + + +def get_smallest_int_type(n_values): + """Return the smallest unsigned integer type that can hold n_values.""" + uint_types = [np.uint8, np.uint16, np.uint32, np.uint64] + + for dtype in uint_types: + if np.iinfo(dtype).max >= n_values: + return dtype diff --git a/src/dcegm/pre_processing/state_space.py b/src/dcegm/pre_processing/state_space.py deleted file mode 100644 index fd2dc866..00000000 --- a/src/dcegm/pre_processing/state_space.py +++ /dev/null @@ -1,630 +0,0 @@ -"""Functions for creating internal state space objects.""" - -import jax -import jax.numpy as jnp -import numpy as np - -from 
dcegm.pre_processing.shared import determine_function_arguments_and_partial_options - - -def create_discrete_state_space_and_choice_objects( - options, - model_funcs, -): - """Create dictionary of discrete state and state-choice objects for each period. - - Args: - options (Dict[str, int]): Options dictionary. - - Returns: - dict of np.ndarray: Dictionary containing period-specific - state and state-choice objects, with the following keys: - - "state_choice_mat" (np.ndarray) - - "idx_state_of_state_choice" (np.ndarray) - - "reshape_state_choice_vec_to_mat" (callable) - - "transform_between_state_and_state_choice_vec" (callable) - - """ - - ( - state_space, - state_space_dict, - map_state_to_index, - states_names_without_exog, - exog_states_names, - exog_state_space, - ) = create_state_space(options) - - ( - state_choice_space, - map_state_choice_to_index, - map_state_choice_to_parent_state, - map_state_choice_to_child_states, - ) = create_state_choice_space( - state_space_options=options["state_space"], - state_space=state_space, - states_names_without_exog=states_names_without_exog, - exog_state_names=exog_states_names, - exog_state_space=exog_state_space, - map_state_to_index=map_state_to_index, - get_state_specific_choice_set=model_funcs["get_state_specific_choice_set"], - get_next_period_state=model_funcs["get_next_period_state"], - ) - - state_choice_space_dict = { - key: state_choice_space[:, i] - for i, key in enumerate( - states_names_without_exog + exog_states_names + ["choice"] - ) - } - - test_state_space_objects( - state_space_options=options["state_space"], - state_choice_space=state_choice_space, - map_state_choice_to_child_states=map_state_choice_to_child_states, - discrete_states_names=states_names_without_exog + exog_states_names, - ) - - model_structure = { - "state_space": state_space, - "choice_range": jnp.asarray(options["state_space"]["choices"]), - "state_space_dict": state_space_dict, - "map_state_to_index": map_state_to_index, - 
"exog_state_space": exog_state_space, - "states_names_without_exog": states_names_without_exog, - "exog_states_names": exog_states_names, - "discrete_states_names": states_names_without_exog + exog_states_names, - "state_choice_space": state_choice_space, - "state_choice_space_dict": state_choice_space_dict, - "map_state_choice_to_index": map_state_choice_to_index, - "map_state_choice_to_parent_state": map_state_choice_to_parent_state, - "map_state_choice_to_child_states": map_state_choice_to_child_states, - } - return jax.tree.map(create_array_with_smallest_int_dtype, model_structure) - - -def test_state_space_objects( - state_space_options, - state_choice_space, - map_state_choice_to_child_states, - discrete_states_names, -): - """Test state space objects for consistency.""" - n_periods = state_space_options["n_periods"] - state_choices_idxs_wo_last = np.where(state_choice_space[:, 0] < n_periods - 1)[0] - - # Check if all feasible state choice combinations have a valid child state - child_states = map_state_choice_to_child_states[state_choices_idxs_wo_last, :] - - # Get dtype and max int for state space indexer - state_space_indexer_dtype = map_state_choice_to_child_states.dtype - invalid_state_space_idx = np.iinfo(state_space_indexer_dtype).max - - if np.any(child_states == invalid_state_space_idx): - # Get row axis of child states that are invalid - invalid_child_states = np.unique( - np.where(child_states == invalid_state_space_idx)[0] - ) - invalid_state_choices_example = state_choice_space[invalid_child_states[0]] - example_dict = { - key: invalid_state_choices_example[i] - for i, key in enumerate(discrete_states_names) - } - example_dict["choice"] = invalid_state_choices_example[-1] - raise ValueError( - f"\n\n\n\n Some state-choice combinations have invalid child " - f"states. 
" - f"\n \n An example of a combination of state and choice with " - f"invalid child states is: \n \n" - f"{example_dict} \n \n" - ) - - -def create_state_space(options): - """Create state space object and indexer. - - We need to add the convention for the state space objects. - - Args: - options (dict): Options dictionary. - - Returns: - tuple: - - - state_vars (list): List of state variables. - - state_space (np.ndarray): 2d array of shape (n_states, n_state_variables + 1) - which serves as a collection of all possible states. By convention, - the first column must contain the period and the last column the - exogenous processes. Any other state variables are in between. - E.g. if the two state variables are period and lagged choice and all choices - are admissible in each period, the shape of the state space array is - (n_periods * n_choices, 3). - - map_state_to_index (np.ndarray): Indexer array that maps states to indexes. - The shape of this object is quite complicated. For each state variable it - has the number of possible states as rows, i.e. - (n_poss_states_state_var_1, n_poss_states_state_var_2, ....). 
- - """ - - state_space_options = options["state_space"] - model_params = options["model_params"] - - n_periods = state_space_options["n_periods"] - n_choices = len(state_space_options["choices"]) - - ( - add_endog_state_func, - endog_states_names, - n_endog_states, - sparsity_func, - ) = process_endog_state_specifications( - state_space_options=state_space_options, model_params=model_params - ) - - ( - exog_states_names, - exog_state_space_raw, - ) = process_exog_model_specifications(state_space_options=state_space_options) - states_names_without_exog = ["period", "lagged_choice"] + endog_states_names - - state_space_wo_exog_list = [] - - for period in range(n_periods): - for endog_state_id in range(n_endog_states): - for lagged_choice in range(n_choices): - # Select the endogenous state combination - endog_states = add_endog_state_func(endog_state_id) - - # Create the state vector without the exogenous processes - state_without_exog = [period, lagged_choice] + endog_states - - # Transform to dictionary to call sparsity function from user - state_dict_without_exog = { - states_names_without_exog[i]: state_value - for i, state_value in enumerate(state_without_exog) - } - - # Check if the state is valid by calling the sparsity function - is_state_valid = sparsity_func(**state_dict_without_exog) - if not is_state_valid: - continue - else: - state_space_wo_exog_list += [state_without_exog] - - n_exog_states = exog_state_space_raw.shape[0] - - state_space_wo_exog = np.array(state_space_wo_exog_list) - state_space_wo_exog_full = np.repeat(state_space_wo_exog, n_exog_states, axis=0) - exog_state_space_full = np.tile( - exog_state_space_raw, (state_space_wo_exog.shape[0], 1) - ) - state_space_raw = np.concatenate( - (state_space_wo_exog_full, exog_state_space_full), axis=1 - ) - - state_space = create_array_with_smallest_int_dtype(state_space_raw) - map_state_to_index = create_indexer_for_space(state_space) - - state_space_dict = { - key: 
create_array_with_smallest_int_dtype(state_space[:, i]) - for i, key in enumerate(states_names_without_exog + exog_states_names) - } - - exog_state_space = create_array_with_smallest_int_dtype(exog_state_space_raw) - - return ( - state_space, - state_space_dict, - map_state_to_index, - states_names_without_exog, - exog_states_names, - exog_state_space, - ) - - -def create_state_choice_space( - state_space_options, - state_space, - exog_state_space, - states_names_without_exog, - exog_state_names, - map_state_to_index, - get_state_specific_choice_set, - get_next_period_state, -): - """Create state choice space of all feasible state-choice combinations. - - Also conditional on any realization of exogenous processes. - - Args: - state_space (np.ndarray): 2d array of shape (n_states, n_state_vars + 1) - which serves as a collection of all possible states. By convention, - the first column must contain the period and the last column the - exogenous state. Any other state variables are in between. - E.g. if the two state variables are period and lagged choice and all choices - are admissible in each period, the shape of the state space array is - (n_periods * n_choices, 3). - map_state_to_state_space_index (np.ndarray): Indexer array that maps - a period-specific state vector to the respective index positions in the - state space. - The shape of this object is quite complicated. For each state variable it - has the number of potential states as rows, i.e. - (n_potential_states_state_var_1, n_potential_states_state_var_2, ....). - get_state_specific_choice_set (Callable): User-supplied function that returns - the set of feasible choices for a given state. - - Returns: - tuple: - - - state_choice_space(np.ndarray): 2d array of shape - (n_feasible_state_choice_combs, n_state_and_exog_variables + 1) containing - the space of all feasible state-choice combinations. By convention, - the second to last column contains the exogenous process. 
- The last column always contains the choice to be made at the end of the - period (which is not a state variable). - - map_state_choice_vec_to_parent_state (np.ndarray): 1d array of shape - (n_states * n_feasible_choices,) that maps from any vector of state-choice - combinations to the respective parent state. - - reshape_state_choice_vec_to_mat (np.ndarray): 2d array of shape - (n_states, n_feasible_choices). For each parent state, this array can be - used to reshape the vector of feasible state-choice combinations - to a matrix of lagged and current choice combinations of - shape (n_choices, n_choices). - - transform_between_state_and_state_choice_space (np.ndarray): 2d boolean - array of shape (n_states, n_states * n_feasible_choices) indicating which - state belongs to which state-choice combination in the entire state and - state choice space. The array is used to - (i) contract state-choice level arrays to the state level by summing - over state-choice combinations. - (ii) to expand state level arrays to the state-choice level. 
- - """ - n_states, n_state_and_exog_variables = state_space.shape - n_exog_states, n_exog_vars = exog_state_space.shape - n_choices = len(state_space_options["choices"]) - discrete_states_names = states_names_without_exog + exog_state_names - n_periods = state_space_options["n_periods"] - - dtype_exog_state_space = get_smallest_int_type(n_exog_states) - - # Get dtype and maxint for choices - dtype_choices = get_smallest_int_type(n_choices) - # Get dtype and max int for state space - state_space_dtype = state_space.dtype - - if np.iinfo(state_space_dtype).max > np.iinfo(dtype_choices).max: - state_choice_space_dtype = state_space_dtype - else: - state_choice_space_dtype = dtype_choices - - state_choice_space_raw = np.zeros( - (n_states * n_choices, n_state_and_exog_variables + 1), - dtype=state_choice_space_dtype, - ) - - state_space_indexer_dtype = map_state_to_index.dtype - invalid_indexer_idx = np.iinfo(state_space_indexer_dtype).max - - map_state_choice_to_parent_state = np.zeros( - (n_states * n_choices), dtype=state_space_indexer_dtype - ) - - map_state_choice_to_child_states = np.full( - (n_states * n_choices, n_exog_states), - fill_value=invalid_indexer_idx, - dtype=state_space_indexer_dtype, - ) - - exog_states_tuple = tuple(exog_state_space[:, i] for i in range(n_exog_vars)) - - idx = 0 - for state_vec in state_space: - state_idx = map_state_to_index[tuple(state_vec)] - - # Full state dictionary - state_dict = {key: state_vec[i] for i, key in enumerate(discrete_states_names)} - - feasible_choice_set = get_state_specific_choice_set( - **state_dict, - ) - - for choice in feasible_choice_set: - state_choice_space_raw[idx, :-1] = state_vec - state_choice_space_raw[idx, -1] = choice - - map_state_choice_to_parent_state[idx] = state_idx - - if state_vec[0] < n_periods - 1: - - # Current state without exog - state_dict_without_exog = { - key: state_dict[key] for key in states_names_without_exog - } - - endog_state_update = get_next_period_state( - 
**state_dict_without_exog, choice=choice - ) - - state_dict_without_exog.update(endog_state_update) - - states_next_tuple = ( - tuple( - np.full( - n_exog_states, - fill_value=state_dict_without_exog[key], - dtype=dtype_exog_state_space, - ) - for key in states_names_without_exog - ) - + exog_states_tuple - ) - - try: - child_idxs = map_state_to_index[states_next_tuple] - except: - raise IndexError( - f"\n\n The state \n\n{endog_state_update}\n\n is reached as a " - f"child state from an existing state, but does not exist for " - f"some values of the exogenous processes. Please check if it " - f"should not be reached or should exist by adapting the " - f"sparsity condition and/or the set of possible state values." - ) - - map_state_choice_to_child_states[idx, :] = child_idxs - - idx += 1 - - state_choice_space = state_choice_space_raw[:idx] - map_state_choice_to_index = create_indexer_for_space(state_choice_space) - - return ( - state_choice_space, - map_state_choice_to_index, - map_state_choice_to_parent_state[:idx], - map_state_choice_to_child_states[:idx, :], - ) - - -def process_exog_model_specifications(state_space_options): - if "exogenous_processes" in state_space_options: - exog_state_names = list(state_space_options["exogenous_processes"].keys()) - dict_of_only_states = { - key: state_space_options["exogenous_processes"][key]["states"] - for key in exog_state_names - } - - exog_state_space = span_subspace_and_read_information( - subdict_of_space=dict_of_only_states, - states_names=exog_state_names, - ) - else: - exog_state_names = ["dummy_exog"] - exog_state_space = np.array([[0]], dtype=np.uint8) - - return exog_state_names, exog_state_space - - -def span_subspace_and_read_information(subdict_of_space, states_names): - """Span subspace and read information from dictionary.""" - # Retrieve all state arrays from the dictionary - states = [np.array(subdict_of_space[name]) for name in states_names] - - # Use np.meshgrid to get all combinations, then reshape 
and stack them - grids = np.meshgrid(*states, indexing="ij") - space = np.column_stack([grid.ravel() for grid in grids]) - - return space - - -def process_endog_state_specifications(state_space_options, model_params): - """Get number of endog states which we loop over when creating the state space.""" - - # if "endogenous_states" in state_space_options: - # if ( - # "endogenous_states" in state_space_options - # and isinstance(state_space_options["endogenous_states"], dict) - # and state_space_options["endogenous_states"] - # ): - if state_space_options.get("endogenous_states"): - - endog_state_keys = state_space_options["endogenous_states"].keys() - - if "sparsity_condition" in state_space_options["endogenous_states"].keys(): - endog_states_names = list(set(endog_state_keys) - {"sparsity_condition"}) - sparsity_cond_specified = True - else: - sparsity_cond_specified = False - endog_states_names = list(endog_state_keys) - - endog_state_space = span_subspace_and_read_information( - subdict_of_space=state_space_options["endogenous_states"], - states_names=endog_states_names, - ) - n_endog_states = endog_state_space.shape[0] - - else: - endog_states_names = [] - n_endog_states = 1 - - endog_state_space = None - sparsity_cond_specified = False - - sparsity_func = select_sparsity_function( - sparsity_cond_specified=sparsity_cond_specified, - state_space_options=state_space_options, - model_params=model_params, - ) - - endog_states_add_func = create_endog_state_add_function(endog_state_space) - - return ( - endog_states_add_func, - endog_states_names, - n_endog_states, - sparsity_func, - ) - - -def select_sparsity_function( - sparsity_cond_specified, state_space_options, model_params -): - if sparsity_cond_specified: - sparsity_func = determine_function_arguments_and_partial_options( - func=state_space_options["endogenous_states"]["sparsity_condition"], - options=model_params, - ) - else: - - def sparsity_func(**kwargs): - return True - - return sparsity_func - - -def 
create_endog_state_add_function(endog_state_space): - if endog_state_space is None: - - def add_endog_states(id_endog_state): - return [] - - else: - - def add_endog_states(id_endog_state): - return list(endog_state_space[id_endog_state]) - - return add_endog_states - - -def create_indexer_for_space(space): - """Create indexer for space.""" - - # Indexer has always unsigned data type with integers starting at zero - # Leave one additional value for the invalid number - data_type = get_smallest_int_type(space.shape[0] + 1) - max_value = np.iinfo(data_type).max - - max_var_values = np.max(space, axis=0) - - map_vars_to_index = np.full( - max_var_values + 1, fill_value=max_value, dtype=data_type - ) - index_tuple = tuple(space[:, i] for i in range(space.shape[1])) - - map_vars_to_index[index_tuple] = np.arange(space.shape[0], dtype=data_type) - - return map_vars_to_index - - -def check_options_and_set_defaults(options): - """Check if options are valid and set defaults.""" - - if not isinstance(options, dict): - raise ValueError("Options must be a dictionary.") - - if "state_space" not in options: - raise ValueError("Options must contain a state space dictionary.") - - if not isinstance(options["state_space"], dict): - raise ValueError("State space must be a dictionary.") - - if "n_periods" not in options["state_space"]: - raise ValueError("State space must contain the number of periods.") - - if not isinstance(options["state_space"]["n_periods"], int): - raise ValueError("Number of periods must be an integer.") - - if "choices" not in options["state_space"]: - print("Choices not given. 
Assume only single choice with value 0") - options["state_space"]["choices"] = np.array([0], dtype=np.uint8) - - if "model_params" not in options: - raise ValueError("Options must contain a model parameters dictionary.") - - if not isinstance(options["model_params"], dict): - raise ValueError("Model parameters must be a dictionary.") - - if "n_choices" not in options["model_params"]: - options["model_params"]["n_choices"] = len(options["state_space"]["choices"]) - - n_savings_grid_points = len(options["state_space"]["continuous_states"]["wealth"]) - options["n_wealth_grid"] = n_savings_grid_points - - if "tuning_params" not in options: - options["tuning_params"] = {} - - options["tuning_params"]["extra_wealth_grid_factor"] = ( - options["tuning_params"]["extra_wealth_grid_factor"] - if "extra_wealth_grid_factor" in options["tuning_params"] - else 0.2 - ) - options["tuning_params"]["n_constrained_points_to_add"] = ( - options["tuning_params"]["n_constrained_points_to_add"] - if "n_constrained_points_to_add" in options["tuning_params"] - else n_savings_grid_points // 10 - ) - - if ( - n_savings_grid_points - * (1 + options["tuning_params"]["extra_wealth_grid_factor"]) - < n_savings_grid_points - + options["tuning_params"]["n_constrained_points_to_add"] - ): - raise ValueError( - f"""\n\n - When preparing the tuning parameters for the upper - envelope, we found the following contradicting parameters: \n - The extra wealth grid factor of {options["tuning_params"]["extra_wealth_grid_factor"]} is too small - to cover the {options["tuning_params"]["n_constrained_points_to_add"]} wealth points which are added in - the credit constrained part of the wealth grid. 
\n\n""" - ) - options["tuning_params"]["n_total_wealth_grid"] = int( - n_savings_grid_points - * (1 + options["tuning_params"]["extra_wealth_grid_factor"]) - ) - - exog_grids = options["state_space"]["continuous_states"].copy() - - if len(options["state_space"]["continuous_states"]) == 2: - second_continuous_state = next( - ( - {key: value} - for key, value in options["state_space"]["continuous_states"].items() - if key != "wealth" - ), - None, - ) - - second_continuous_state_name = list(second_continuous_state.keys())[0] - options["second_continuous_state_name"] = second_continuous_state_name - - options["tuning_params"]["n_second_continuous_grid"] = len( - second_continuous_state[second_continuous_state_name] - ) - - exog_grids["second_continuous"] = options["state_space"]["continuous_states"][ - second_continuous_state_name - ] - exog_grids.pop(second_continuous_state_name) - - options["exog_grids"] = exog_grids - - return options - - -def create_array_with_smallest_int_dtype(arr): - """Return array with the smallest unsigned integer dtype.""" - if isinstance(arr, (np.ndarray, jnp.ndarray)) and np.issubdtype( - arr.dtype, np.integer - ): - return arr.astype(get_smallest_int_type(arr.max())) - - return arr - - -def get_smallest_int_type(n_values): - """Return the smallest unsigned integer type that can hold n_values.""" - uint_types = [np.uint8, np.uint16, np.uint32, np.uint64] - - for dtype in uint_types: - if np.iinfo(dtype).max >= n_values: - return dtype diff --git a/src/dcegm/simulation/sim_utils.py b/src/dcegm/simulation/sim_utils.py index 3ee76161..9f31036a 100644 --- a/src/dcegm/simulation/sim_utils.py +++ b/src/dcegm/simulation/sim_utils.py @@ -149,7 +149,7 @@ def transition_to_next_period( discrete_endog_states_next_period = vmap( update_discrete_states_for_one_agent, in_axes=(None, 0, 0, None) # choice )( - compute_next_period_states["get_next_period_state"], + compute_next_period_states["next_period_endogenous_state"], 
discrete_states_beginning_of_period, choice, params, @@ -174,7 +174,7 @@ def transition_to_next_period( continuous_state_beginning_of_period=continuous_state_beginning_of_period, params=params, compute_continuous_state=compute_next_period_states[ - "update_continuous_state" + "next_period_continuous_state" ], ) @@ -228,7 +228,7 @@ def update_discrete_states_for_one_agent(update_func, state, choice, params): return update_func(**state, choice=choice, params=params) -def update_continuous_state_for_one_agent( +def next_period_continuous_state_for_one_agent( update_func, discrete_states, continuous_state, choice, params ): diff --git a/src/dcegm/simulation/simulate.py b/src/dcegm/simulation/simulate.py index 5a82f950..4d6632f1 100644 --- a/src/dcegm/simulation/simulate.py +++ b/src/dcegm/simulation/simulate.py @@ -78,8 +78,8 @@ def simulate_all_periods( model_funcs_sim = model_sim["model_funcs"] compute_next_period_states = { - "get_next_period_state": model_funcs_sim["get_next_period_state"], - "update_continuous_state": model_funcs_sim["update_continuous_state"], + "next_period_endogenous_state": model_funcs_sim["next_period_endogenous_state"], + "next_period_continuous_state": model_funcs_sim["next_period_continuous_state"], } simulate_body = partial( diff --git a/src/dcegm/solve.py b/src/dcegm/solve.py index 43996382..0ee15972 100644 --- a/src/dcegm/solve.py +++ b/src/dcegm/solve.py @@ -12,13 +12,13 @@ from dcegm.final_periods import solve_last_two_periods from dcegm.law_of_motion import calc_cont_grids_next_period from dcegm.numerical_integration import quadrature_legendre -from dcegm.pre_processing.params import process_params +from dcegm.pre_processing.check_params import process_params from dcegm.pre_processing.setup_model import setup_model from dcegm.solve_single_period import solve_single_period def solve_dcegm( - params: pd.DataFrame, + params: Dict, options: Dict, utility_functions: Dict[str, Callable], utility_functions_final_period: Dict[str, Callable], 
@@ -251,7 +251,7 @@ def backward_induction( income_shock_weights=income_shock_weights, exog_grids=exog_grids, model_funcs=model_funcs, - batch_info=batch_info, + last_two_period_batch_info=batch_info["last_two_period_info"], value_solved=value_solved, policy_solved=policy_solved, endog_grid_solved=endog_grid_solved, @@ -275,52 +275,55 @@ def partial_single_period(carry, xs): taste_shock_scale=taste_shock_scale, ) - carry_start = ( - value_solved, - policy_solved, - endog_grid_solved, - ) + for id_segment in range(batch_info["n_segments"]): + segment_info = batch_info[f"batches_info_segment_{id_segment}"] - final_carry, _ = jax.lax.scan( - f=partial_single_period, - init=carry_start, - xs=( - batch_info["batches_state_choice_idx"], - batch_info["child_state_choices_to_aggr_choice"], - batch_info["child_states_to_integrate_exog"], - batch_info["child_state_choice_idxs_to_interp"], - batch_info["child_states_idxs"], - batch_info["state_choices"], - batch_info["state_choices_childs"], - ), - ) + carry_start = ( + value_solved, + policy_solved, + endog_grid_solved, + ) - if not batch_info["batches_cover_all"]: - last_batch_info = batch_info["last_batch_info"] - extra_final_carry, () = partial_single_period( - carry=final_carry, + final_carry, _ = jax.lax.scan( + f=partial_single_period, + init=carry_start, xs=( - last_batch_info["state_choice_idx"], - last_batch_info["child_state_choices_to_aggr_choice"], - last_batch_info["child_states_to_integrate_exog"], - last_batch_info["child_state_choice_idxs_to_interp"], - last_batch_info["child_states_idxs"], - last_batch_info["state_choices"], - last_batch_info["state_choices_childs"], + segment_info["batches_state_choice_idx"], + segment_info["child_state_choices_to_aggr_choice"], + segment_info["child_states_to_integrate_exog"], + segment_info["child_state_choice_idxs_to_interp"], + segment_info["child_states_idxs"], + segment_info["state_choices"], + segment_info["state_choices_childs"], ), ) - ( - value_solved, - 
policy_solved, - endog_grid_solved, - ) = extra_final_carry - else: - ( - value_solved, - policy_solved, - endog_grid_solved, - ) = final_carry + if not segment_info["batches_cover_all"]: + last_batch_info = segment_info["last_batch_info"] + extra_final_carry, () = partial_single_period( + carry=final_carry, + xs=( + last_batch_info["state_choice_idx"], + last_batch_info["child_state_choices_to_aggr_choice"], + last_batch_info["child_states_to_integrate_exog"], + last_batch_info["child_state_choice_idxs_to_interp"], + last_batch_info["child_states_idxs"], + last_batch_info["state_choices"], + last_batch_info["state_choices_childs"], + ), + ) + + ( + value_solved, + policy_solved, + endog_grid_solved, + ) = extra_final_carry + else: + ( + value_solved, + policy_solved, + endog_grid_solved, + ) = final_carry return ( value_solved, diff --git a/src/toy_models/cons_ret_model_dcegm_paper/state_space_objects.py b/src/toy_models/cons_ret_model_dcegm_paper/state_space_objects.py index 9f71b874..cdbbfafb 100644 --- a/src/toy_models/cons_ret_model_dcegm_paper/state_space_objects.py +++ b/src/toy_models/cons_ret_model_dcegm_paper/state_space_objects.py @@ -13,7 +13,7 @@ def create_state_space_function_dict(): """ return { - "get_state_specific_choice_set": get_state_specific_feasible_choice_set, + "state_specific_choice_set": get_state_specific_feasible_choice_set, } diff --git a/src/toy_models/cons_ret_model_with_cont_exp/state_space_objects.py b/src/toy_models/cons_ret_model_with_cont_exp/state_space_objects.py index aa521564..d3387613 100644 --- a/src/toy_models/cons_ret_model_with_cont_exp/state_space_objects.py +++ b/src/toy_models/cons_ret_model_with_cont_exp/state_space_objects.py @@ -11,12 +11,12 @@ def create_state_space_function_dict(): """ return { - "get_state_specific_choice_set": get_state_specific_feasible_choice_set, - "update_continuous_state": get_next_period_experience, + "state_specific_choice_set": get_state_specific_feasible_choice_set, + 
"next_period_experience": next_period_experience, } -def get_next_period_experience(period, lagged_choice, experience, options): +def next_period_experience(period, lagged_choice, experience, options): max_experience_period = period + options["max_init_experience"] return (1 / max_experience_period) * ( diff --git a/src/toy_models/cons_ret_model_with_exp/state_space_objects.py b/src/toy_models/cons_ret_model_with_exp/state_space_objects.py index c506a9c0..f274291b 100644 --- a/src/toy_models/cons_ret_model_with_exp/state_space_objects.py +++ b/src/toy_models/cons_ret_model_with_exp/state_space_objects.py @@ -11,12 +11,13 @@ def create_state_space_function_dict(): """ return { - "get_state_specific_choice_set": get_state_specific_feasible_choice_set, - "get_next_period_state": get_next_period_state, + "state_specific_choice_set": get_state_specific_feasible_choice_set, + "next_period_endogenous_state": next_period_endogenous_state, + "sparsity_condition": sparsity_condition, } -def get_next_period_state(period, choice, experience): +def next_period_endogenous_state(period, choice, experience): """Update state with experience.""" next_state = {} @@ -36,13 +37,19 @@ def sparsity_condition( max_exp_period = period + options["max_init_experience"] max_total_experience = options["n_periods"] + options["max_init_experience"] + # Experience must be smaller than the maximum experience in a period if max_exp_period < experience: return False + # Experience must be smaller than the maximum total experience elif max_total_experience <= experience: return False + # If experience is the maximum experience in a period, you must have been working last period elif (experience == max_exp_period) & (lagged_choice == 1): return False - elif (lagged_choice == 0) & (experience == 0): + # As retirement is absorbing, if you have been working last period + # your experience must be at least as big as the period as you + # had to been working all periods before + elif (lagged_choice == 0) & 
(experience < period): return False else: return True diff --git a/src/toy_models/load_example_model.py b/src/toy_models/load_example_model.py index 61cdfe15..819da8c4 100644 --- a/src/toy_models/load_example_model.py +++ b/src/toy_models/load_example_model.py @@ -19,7 +19,6 @@ def load_example_models(model_name): "utility_functions": crm_paper.create_utility_function_dict(), "final_period_utility_functions": crm_paper.create_final_period_utility_function_dict(), "budget_constraint": crm_exp.budget_constraint_exp, - "sparsity_condition": crm_exp.sparsity_condition, } elif model_name == "with_cont_exp": diff --git a/tests/conftest.py b/tests/conftest.py index 6001d7eb..6f100b17 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,20 +58,20 @@ def pytest_sessionfinish(session, exitstatus): # noqa: ARG001 @pytest.fixture(scope="session") -def load_example_model(): - def load_options_and_params(model): +def load_replication_params_and_specs(): + def load_options_and_params(model_name): """Return parameters and options of an example model.""" - params = pd.read_csv( - REPLICATION_TEST_RESOURCES_DIR / f"{model}" / "params.csv", - index_col=["category", "name"], + params = yaml.safe_load( + ( + REPLICATION_TEST_RESOURCES_DIR / f"{model_name}" / "params.yaml" + ).read_text() ) - params = ( - params.reset_index()[["name", "value"]].set_index("name")["value"].to_dict() + model_specs = yaml.safe_load( + ( + REPLICATION_TEST_RESOURCES_DIR / f"{model_name}" / "options.yaml" + ).read_text() ) - options = yaml.safe_load( - (REPLICATION_TEST_RESOURCES_DIR / f"{model}" / "options.yaml").read_text() - ) - return params, options + return params, model_specs return load_options_and_params diff --git a/tests/resources/replication_tests/deaton/params.csv b/tests/resources/replication_tests/deaton/params.csv deleted file mode 100644 index 54006547..00000000 --- a/tests/resources/replication_tests/deaton/params.csv +++ /dev/null @@ -1,14 +0,0 @@ -category,name,value,comment 
-beta,beta,0.95,discount factor -delta,delta,0,disutility of work -utility_function,rho,1,CRRA coefficient -wage,constant,0.75,age-independent labor income -wage,exp,0.04,return to experience -wage,exp_squared,-0.0004,return to experience squared -shocks,sigma,0.25,shock on labor income sigma parameter/standard deviation -shocks,lambda,2.2204e-16,taste shock (scale) parameter -assets,interest_rate,0.05,interest rate on capital -assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation) -assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation) -assets,max_wealth,75,maximum level of wealth -assets,consumption_floor,0.0,consumption floor/retirement safety net (only relevant in the dc-egm retirement model) diff --git a/tests/resources/replication_tests/deaton/params.yaml b/tests/resources/replication_tests/deaton/params.yaml new file mode 100644 index 00000000..1a237b0a --- /dev/null +++ b/tests/resources/replication_tests/deaton/params.yaml @@ -0,0 +1,15 @@ +--- +beta: 0.95 # discount factor +delta: 0 # disutility of work +rho: 1 # CRRA coefficient +constant: 0.75 # age-independent labor income +exp: 0.04 # return to experience +exp_squared: -0.0004 # return to experience squared +sigma: 0.25 # shock on labor income sigma parameter/standard deviation +lambda: 2.2204e-16 # taste shock (scale) parameter +interest_rate: 0.05 # interest rate on capital +initial_wealth_low: 0 # lowest level of initial wealth (relevant for simulation) +initial_wealth_high: 30 # highest level of initial wealth (relevant for simulation) +max_wealth: 75 # maximum level of wealth +# consumption floor/retirement safety net (only relevant in the dc-egm retirement model) +consumption_floor: 0.0 diff --git a/tests/resources/replication_tests/retirement_no_taste_shocks/params.csv b/tests/resources/replication_tests/retirement_no_taste_shocks/params.csv deleted file mode 100644 index 1d90dabd..00000000 --- 
a/tests/resources/replication_tests/retirement_no_taste_shocks/params.csv +++ /dev/null @@ -1,14 +0,0 @@ -category,name,value,comment -beta,beta,0.95,discount factor -delta,delta,0.35,disutility of work -utility_function,rho,1.95,CRRA coefficient -wage,constant,0.75,age-independent labor income -wage,exp,0.04,return to experience -wage,exp_squared,-0.0002,return to experience squared -shocks,sigma,0.00,shock on labor income sigma parameter/standard deviation -shocks,lambda,2.2204e-16,taste shock (scale) parameter -assets,interest_rate,0.05,interest rate on capital -assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation) -assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation) -assets,max_wealth,50,maximum level of wealth -assets,consumption_floor,0.001,consumption floor/retirement safety net (only relevant in the dc-egm retirement model) diff --git a/tests/resources/replication_tests/retirement_no_taste_shocks/params.yaml b/tests/resources/replication_tests/retirement_no_taste_shocks/params.yaml new file mode 100644 index 00000000..3f2c16a1 --- /dev/null +++ b/tests/resources/replication_tests/retirement_no_taste_shocks/params.yaml @@ -0,0 +1,15 @@ +--- +beta: 0.95 # discount factor +delta: 0.35 # disutility of work +rho: 1.95 # CRRA coefficient +constant: 0.75 # age-independent labor income +exp: 0.04 # return to experience +exp_squared: -0.0002 # return to experience squared +sigma: 0.00 # shock on labor income sigma parameter/standard deviation +lambda: 2.2204e-16 # taste shock (scale) parameter +interest_rate: 0.05 # interest rate on capital +initial_wealth_low: 0 # lowest level of initial wealth (relevant for simulation) +initial_wealth_high: 30 # highest level of initial wealth (relevant for simulation) +max_wealth: 50 # maximum level of wealth +# consumption floor/retirement safety net (only relevant in the dc-egm retirement model) +consumption_floor: 0.001 diff --git 
a/tests/resources/replication_tests/retirement_taste_shocks/params.csv b/tests/resources/replication_tests/retirement_taste_shocks/params.csv deleted file mode 100644 index 897e8856..00000000 --- a/tests/resources/replication_tests/retirement_taste_shocks/params.csv +++ /dev/null @@ -1,14 +0,0 @@ -category,name,value,comment -beta,beta,0.9523809523809523,discount factor -delta,delta,0.35,disutility of work -utility_function,rho,1.95,CRRA coefficient -wage,constant,0.75,age-independent labor income -wage,exp,0.04,return to experience -wage,exp_squared,-0.0002,return to experience squared -shocks,sigma,0.35,shock on labor income sigma parameter/standard deviation -shocks,lambda,0.2,taste shock (scale) parameter -assets,interest_rate,0.05,interest rate on capital -assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation) -assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation) -assets,max_wealth,50,maximum level of wealth -assets,consumption_floor,0.001,consumption floor/retirement safety net (only relevant in the dc-egm retirement model) diff --git a/tests/resources/replication_tests/retirement_taste_shocks/params.yaml b/tests/resources/replication_tests/retirement_taste_shocks/params.yaml new file mode 100644 index 00000000..0c5d78cd --- /dev/null +++ b/tests/resources/replication_tests/retirement_taste_shocks/params.yaml @@ -0,0 +1,15 @@ +--- +beta: 0.9523809523809523 # discount factor +delta: 0.35 # disutility of work +rho: 1.95 # CRRA coefficient +constant: 0.75 # age-independent labor income +exp: 0.04 # return to experience +exp_squared: -0.0002 # return to experience squared +sigma: 0.35 # shock on labor income sigma parameter/standard deviation +lambda: 0.2 # taste shock (scale) parameter +interest_rate: 0.05 # interest rate on capital +initial_wealth_low: 0 # lowest level of initial wealth (relevant for simulation) +initial_wealth_high: 30 # highest level of initial wealth (relevant for simulation) 
+max_wealth: 50 # maximum level of wealth +# consumption floor/retirement safety net (only relevant in the dc-egm retirement model) +consumption_floor: 0.001 diff --git a/tests/sandbox/jax_timeit_large_toy_model.ipynb b/tests/sandbox/jax_timeit_large_toy_model.ipynb index 1be115b9..ccf151b3 100644 --- a/tests/sandbox/jax_timeit_large_toy_model.ipynb +++ b/tests/sandbox/jax_timeit_large_toy_model.ipynb @@ -150,8 +150,8 @@ "state_space_functions = {\n", " # \"create_state_space\": _create_state_space_custom,\n", " \"create_state_space\": create_state_space,\n", - " \"get_state_specific_choice_set\": get_state_specific_feasible_choice_set,\n", - " \"get_next_period_state\": update_state,\n", + " \"state_specific_choice_set\": get_state_specific_feasible_choice_set,\n", + " \"next_period_endogenous_state\": update_state,\n", "}" ] }, diff --git a/tests/test_biased_sim.py b/tests/test_biased_sim.py index 0779013a..c6c6dd8c 100644 --- a/tests/test_biased_sim.py +++ b/tests/test_biased_sim.py @@ -68,8 +68,8 @@ def state_space_options(): return state_space_options -def test_sim_and_sol_model(state_space_options, load_example_model): - params, model_params = load_example_model("retirement_taste_shocks") +def test_sim_and_sol_model(state_space_options, load_replication_params_and_specs): + params, model_specs = load_replication_params_and_specs("retirement_taste_shocks") params["married_util"] = 0.5 model_funcs = load_example_models("dcegm_paper") @@ -78,7 +78,7 @@ def test_sim_and_sol_model(state_space_options, load_example_model): options_sol = { "state_space": state_space_options["solution"], - "model_params": model_params, + "model_params": model_specs, } model_sol = setup_model( @@ -94,7 +94,7 @@ def test_sim_and_sol_model(state_space_options, load_example_model): options_sim = { "state_space": state_space_options["simulation"], - "model_params": model_params, + "model_params": model_specs, } marriage_trans_mat = jnp.array([[0.3, 0.7], [0.1, 0.9]]) 
options_sim["model_params"]["marriage_trans_mat"] = marriage_trans_mat diff --git a/tests/test_changing_choice_set.py b/tests/test_changing_choice_set.py index fbb74b0d..3c0a576a 100644 --- a/tests/test_changing_choice_set.py +++ b/tests/test_changing_choice_set.py @@ -47,12 +47,19 @@ def sparsity_condition(period, lagged_choice, experience, options): if period == 0 and lagged_choice != 0: return False # Starting from second we check if choice was in last periods full choice set - if period > 0 and lagged_choice not in choice_set(period - 1, 1): + elif (period > 0) and lagged_choice not in choice_set(period - 1, 1): return False # Filter states with too high experience - if (experience > period) or (experience > options["max_experience"]): + elif (experience > period) or (experience > options["max_experience"]): return False - return True + # If experience is 0 you can not have been working last period + elif (experience == 0) and (lagged_choice == 1): + return False + # If experience is equal to period you must have been working last period (periods larger than 0) + elif (experience == period) and (period > 0) and (lagged_choice != 1): + return False + else: + return True @pytest.fixture @@ -98,7 +105,6 @@ def test_model(): "choices": np.arange(3), "endogenous_states": { "experience": np.arange(5), - "sparsity_condition": sparsity_condition, }, "continuous_states": { "wealth": np.linspace(0, 500, 100), @@ -136,8 +142,9 @@ def next_period_state(period, choice, experience): def state_space_functions(): """Return dict with state space functions.""" out = { - "get_state_specific_choice_set": choice_set, - "get_next_period_state": next_period_state, + "state_specific_choice_set": choice_set, + "next_period_endogenous_state": next_period_state, + "sparsity_condition": sparsity_condition, } return out diff --git a/tests/test_discrete_versus_continuous_experience.py b/tests/test_discrete_versus_continuous_experience.py index e41d335d..32ebacac 100644 --- 
a/tests/test_discrete_versus_continuous_experience.py +++ b/tests/test_discrete_versus_continuous_experience.py @@ -65,7 +65,6 @@ def test_setup(): ), "endogenous_states": { "experience": np.arange(N_PERIODS + MAX_INIT_EXPERIENCE), - "sparsity_condition": model_funcs_discr_exp["sparsity_condition"], }, "continuous_states": { "wealth": jnp.linspace( @@ -185,17 +184,16 @@ def test_replication_discrete_versus_continuous_experience( state_choice_cont_dict["dummy_exog"], state_choice_cont_dict["choice"], ] - state_specific_choice_set = model_disc["model_funcs"][ - "get_state_specific_choice_set" - ](**state_choice_disc_dict) + state_specific_choice_set = model_disc["model_funcs"]["state_specific_choice_set"]( + **state_choice_disc_dict + ) choice_valid = choice in state_specific_choice_set - sparsity_condition = load_example_models("with_exp")["sparsity_condition"] + sparsity_condition = model_disc["model_funcs"]["sparsity_condition"] state_valid = sparsity_condition( - period, - experience, - lagged_choice, - model_disc["options"]["model_params"], + period=period, + experience=experience, + lagged_choice=lagged_choice, ) if state_valid & choice_valid: diff --git a/tests/test_exog_processes.py b/tests/test_exog_processes.py index d035b9ce..01293cc0 100644 --- a/tests/test_exog_processes.py +++ b/tests/test_exog_processes.py @@ -7,12 +7,13 @@ import pytest from numpy.testing import assert_almost_equal as aaae -from dcegm.pre_processing.exog_processes import create_exog_state_mapping +from dcegm.pre_processing.check_options import check_options_and_set_defaults from dcegm.pre_processing.model_functions import process_model_functions -from dcegm.pre_processing.state_space import ( - check_options_and_set_defaults, - create_discrete_state_space_and_choice_objects, +from dcegm.pre_processing.model_structure.exogenous_processes import ( + create_exog_state_mapping, ) +from dcegm.pre_processing.model_structure.model_structure import create_model_structure +from 
dcegm.pre_processing.setup_model import setup_model from toy_models.cons_ret_model_dcegm_paper.budget_constraint import budget_constraint from toy_models.cons_ret_model_dcegm_paper.state_space_objects import ( create_state_space_function_dict, @@ -148,17 +149,16 @@ def test_exog_processes( } options = check_options_and_set_defaults(options) - model_funcs = process_model_functions( + + model = setup_model( options, state_space_functions=create_state_space_function_dict(), utility_functions=create_utility_function_dict(), utility_functions_final_period=create_final_period_utility_function_dict(), budget_constraint=budget_constraint, ) - model_structure = create_discrete_state_space_and_choice_objects( - options=options, - model_funcs=model_funcs, - ) + model_funcs = model["model_funcs"] + model_structure = model["model_structure"] exog_state_mapping = create_exog_state_mapping( model_structure["exog_state_space"].astype(np.int16), diff --git a/tests/test_law_of_motion.py b/tests/test_law_of_motion.py index 2cd2bf28..716128db 100644 --- a/tests/test_law_of_motion.py +++ b/tests/test_law_of_motion.py @@ -11,7 +11,7 @@ from scipy.stats import norm from dcegm.law_of_motion import calculate_continuous_state -from dcegm.pre_processing.params import process_params +from dcegm.pre_processing.check_params import process_params from toy_models.cons_ret_model_dcegm_paper.budget_constraint import budget_constraint # ===================================================================================== @@ -75,7 +75,7 @@ def _transform_lagged_choice_to_working_hours(lagged_choice): return not_working * 0 + part_time * 2000 + full_time * 3000 -def _update_continuous_state(period, lagged_choice, continuous_state, params): +def _next_period_continuous_state(period, lagged_choice, continuous_state, params): working_hours = _transform_lagged_choice_to_working_hours(lagged_choice) @@ -100,9 +100,14 @@ def _update_continuous_state(period, lagged_choice, continuous_state, params): "model, 
period, labor_choice, max_wealth, n_grid_points", TEST_CASES ) def test_get_beginning_of_period_wealth( - model, period, labor_choice, max_wealth, n_grid_points, load_example_model + model, + period, + labor_choice, + max_wealth, + n_grid_points, + load_replication_params_and_specs, ): - params, options = load_example_model(f"{model}") + params, options = load_replication_params_and_specs(f"{model}") params["part_time"] = -1 params = process_params(params) @@ -153,13 +158,13 @@ def test_get_beginning_of_period_wealth( "model, max_wealth, n_grid_points", TEST_CASES_SECOND_CONTINUOUS ) def test_wealth_and_second_continuous_state( - model, max_wealth, n_grid_points, load_example_model + model, max_wealth, n_grid_points, load_replication_params_and_specs ): # parametrize over number of experience points n_exp_points = 10 - params, options = load_example_model(f"{model}") + params, options = load_replication_params_and_specs(f"{model}") options["working_hours_max"] = 3000 params["part_time"] = -1 @@ -178,7 +183,7 @@ def test_wealth_and_second_continuous_state( } update_experience_vectorized = vmap( - lambda period, lagged_choice: _update_continuous_state( + lambda period, lagged_choice: _next_period_continuous_state( period, lagged_choice, experience_grid, options ) ) @@ -187,7 +192,7 @@ def test_wealth_and_second_continuous_state( ) exp_next = calculate_continuous_state( - child_state_dict, experience_grid, params, _update_continuous_state + child_state_dict, experience_grid, params, _next_period_continuous_state ) aaae(exp_next, experience_next) diff --git a/tests/test_pre_processing.py b/tests/test_pre_processing.py index 2263c5d4..b201c2bf 100644 --- a/tests/test_pre_processing.py +++ b/tests/test_pre_processing.py @@ -3,15 +3,15 @@ import pytest from jax import vmap +from dcegm.pre_processing.check_options import check_options_and_set_defaults +from dcegm.pre_processing.check_params import process_params from dcegm.pre_processing.model_functions import 
process_model_functions -from dcegm.pre_processing.params import process_params from dcegm.pre_processing.setup_model import ( load_and_setup_model, setup_and_save_model, setup_model, ) from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options -from dcegm.pre_processing.state_space import check_options_and_set_defaults from toy_models.cons_ret_model_dcegm_paper.budget_constraint import budget_constraint from toy_models.cons_ret_model_dcegm_paper.state_space_objects import ( create_state_space_function_dict, @@ -45,8 +45,8 @@ def util_wrap(state_dict, params, util_func): return util_func(**state_dict, params=params) -def test_wrap_function(load_example_model): - params, _raw_options = load_example_model("deaton") +def test_wrap_function(load_replication_params_and_specs): + params, _raw_options = load_replication_params_and_specs("deaton") options = {} options["model_params"] = _raw_options @@ -101,9 +101,9 @@ def test_wrap_function(load_example_model): ) def test_missing_parameter( model_name, - load_example_model, + load_replication_params_and_specs, ): - params, _ = load_example_model(f"{model_name}") + params, _ = load_replication_params_and_specs(f"{model_name}") params.pop("interest_rate") params.pop("sigma") @@ -129,10 +129,10 @@ def test_missing_parameter( ) def test_load_and_save_model( model_name, - load_example_model, + load_replication_params_and_specs, ): options = {} - _params, _raw_options = load_example_model(f"{model_name}") + _params, _raw_options = load_replication_params_and_specs(f"{model_name}") options["model_params"] = _raw_options options["model_params"]["n_choices"] = _raw_options["n_discrete_choices"] @@ -252,7 +252,7 @@ def test_second_continuous_state(period, lagged_choice, continuous_state): params = {} state_space_functions = create_state_space_function_dict() - state_space_functions["get_next_period_experience"] = get_next_experience + state_space_functions["next_period_experience"] = get_next_experience 
options = check_options_and_set_defaults(options) @@ -264,9 +264,9 @@ def test_second_continuous_state(period, lagged_choice, continuous_state): budget_constraint=budget_constraint, ) - update_continuous_state = model_funcs["update_continuous_state"] + next_period_continuous_state = model_funcs["next_period_continuous_state"] - got = update_continuous_state( + got = next_period_continuous_state( period=period, lagged_choice=lagged_choice, continuous_state=continuous_state, diff --git a/tests/test_replication.py b/tests/test_replication.py index 93e10728..2a38fc1e 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -11,23 +11,15 @@ interpolate_policy_and_value_on_wealth_grid, linear_interpolation_with_extrapolation, ) -from toy_models.cons_ret_model_dcegm_paper.budget_constraint import budget_constraint -from toy_models.cons_ret_model_dcegm_paper.state_space_objects import ( - create_state_space_function_dict, -) -from toy_models.cons_ret_model_dcegm_paper.utility_functions import ( - create_final_period_utility_function_dict, - create_utility_function_dict, -) from toy_models.cons_ret_model_dcegm_paper.utility_functions_log_crra import ( utiility_log_crra, utiility_log_crra_final_consume_all, ) +from toy_models.load_example_model import load_example_models # Obtain the test directory of the package TEST_DIR = Path(__file__).parent -# Directory with additional resources for the testing harness REPLICATION_TEST_RESOURCES_DIR = TEST_DIR / "resources" / "replication_tests" @@ -39,18 +31,15 @@ "deaton", ], ) -def test_benchmark_models( - model_name, - load_example_model, -): +def test_benchmark_models(model_name, load_replication_params_and_specs): + params, model_specs = load_replication_params_and_specs(model_name) options = {} - params, _raw_options = load_example_model(f"{model_name}") - options["model_params"] = _raw_options - options["model_params"]["n_choices"] = _raw_options["n_discrete_choices"] + options["model_params"] = model_specs + 
options["model_params"]["n_choices"] = model_specs["n_discrete_choices"] options["state_space"] = { "n_periods": 25, - "choices": [i for i in range(_raw_options["n_discrete_choices"])], + "choices": [i for i in range(model_specs["n_discrete_choices"])], "continuous_states": { "wealth": jnp.linspace( 0, @@ -60,31 +49,30 @@ def test_benchmark_models( }, } - utility_functions = create_utility_function_dict() - utility_functions_final_period = create_final_period_utility_function_dict() + model_funcs = load_example_models("dcegm_paper") if model_name == "deaton": - state_space_functions = None - utility_functions["utility"] = utiility_log_crra - utility_functions_final_period["utility"] = utiility_log_crra_final_consume_all - else: - state_space_functions = create_state_space_function_dict() + model_funcs["state_space_functions"] = None + model_funcs["utility_functions"]["utility"] = utiility_log_crra + model_funcs["final_period_utility_functions"][ + "utility" + ] = utiility_log_crra_final_consume_all model = setup_model( options=options, - state_space_functions=state_space_functions, - utility_functions=utility_functions, - utility_functions_final_period=utility_functions_final_period, - budget_constraint=budget_constraint, + state_space_functions=model_funcs["state_space_functions"], + utility_functions=model_funcs["utility_functions"], + utility_functions_final_period=model_funcs["final_period_utility_functions"], + budget_constraint=model_funcs["budget_constraint"], ) - value, policy, endog_grid, *_ = solve_dcegm( - params, - options, - state_space_functions=state_space_functions, - utility_functions=utility_functions, - utility_functions_final_period=utility_functions_final_period, - budget_constraint=budget_constraint, + value, policy, endog_grid = solve_dcegm( + params=params, + options=options, + state_space_functions=model_funcs["state_space_functions"], + utility_functions=model_funcs["utility_functions"], + 
utility_functions_final_period=model_funcs["final_period_utility_functions"], + budget_constraint=model_funcs["budget_constraint"], ) policy_expected = pickle.load( diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 2f906226..6708ae05 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -16,7 +16,7 @@ simulate_single_period, ) from toy_models.cons_ret_model_with_cont_exp.state_space_objects import ( - get_next_period_experience, + next_period_experience, ) @@ -87,7 +87,7 @@ def test_simulate_lax_scan(model_setup): map_state_choice_to_index = model_structure["map_state_choice_to_index"] exog_state_mapping = model_funcs["exog_state_mapping"] - get_next_period_state = model_funcs["get_next_period_state"] + next_period_endogenous_state = model_funcs["next_period_endogenous_state"] value = model_setup["value"] policy = model_setup["policy"] @@ -120,7 +120,9 @@ def test_simulate_lax_scan(model_setup): "compute_beginning_of_period_wealth" ], exog_state_mapping=exog_state_mapping, - compute_next_period_states={"get_next_period_state": get_next_period_state}, + compute_next_period_states={ + "next_period_endogenous_state": next_period_endogenous_state + }, ) # a) lax.scan @@ -215,9 +217,9 @@ def test_simulate_second_continuous_choice(model_setup): model["options"]["state_space"]["continuous_states"]["experience"] = jnp.linspace( 0, 1, 6 ) - model["model_funcs"]["update_continuous_state"] = ( + model["model_funcs"]["next_period_continuous_state"] = ( determine_function_arguments_and_partial_options( - func=get_next_period_experience, + func=next_period_experience, options=model["options"]["model_params"], continuous_state_name="experience", ) diff --git a/tests/test_simulate_continuous_state.py b/tests/test_simulate_continuous_state.py index eba4720c..662dc513 100644 --- a/tests/test_simulate_continuous_state.py +++ b/tests/test_simulate_continuous_state.py @@ -58,7 +58,6 @@ def test_setup(): ), "endogenous_states": { "experience": 
np.arange(N_PERIODS + MAX_INIT_EXPERIENCE), - "sparsity_condition": model_funcs_discr_exp["sparsity_condition"], }, "continuous_states": { "wealth": jnp.linspace( diff --git a/tests/test_sparse_exog_and_batch_sep.py b/tests/test_sparse_exog_and_batch_sep.py new file mode 100644 index 00000000..d898605e --- /dev/null +++ b/tests/test_sparse_exog_and_batch_sep.py @@ -0,0 +1,178 @@ +import pickle +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import pytest +from numpy.testing import assert_array_almost_equal as aaae + +from dcegm.pre_processing.setup_model import setup_model +from dcegm.solve import get_solve_func_for_model, solve_dcegm +from toy_models.load_example_model import load_example_models + + +def utility_with_exog(consumption, health, partner, params): + utility_consumption = consumption ** (1 - params["rho"]) / (1 - params["rho"]) + utility_health = (1 - health) * params["health_disutil"] + utility_partner = partner * params["partner_util"] + return utility_consumption - utility_health + utility_partner + + +def health_transition(period, health, params): + prob_good_health = ( + health * params["good_to_good"] + (1 - health) * params["bad_to_good"] + ) + # After period 20 you always transition to bad + prob_good_health = jax.lax.select(period < 20, prob_good_health, 0.0) + return jnp.array([1 - prob_good_health, prob_good_health]) + + +def partner_transition(period, partner, params): + prob_married = (1 - partner) * params["single_to_married"] + partner * params[ + "married_to_married" + ] + # After period 15 you always transition to married + prob_married = jax.lax.select(period < 15, prob_married, 1.0) + return jnp.array([1 - prob_married, prob_married]) + + +def sparsity_condition(period, lagged_choice, health, education, partner): + # If period is larger than 15 you can not be single + + if period < 20: + if period > 14 and partner == 0: + return { + "period": period, + "lagged_choice": 
lagged_choice, + "education": education, + "health": health, + "partner": 1, + } + else: + return True + else: + if (health == 0) & (partner == 1): + return True + else: + return { + "period": period, + "lagged_choice": lagged_choice, + "education": education, + "health": 0, + "partner": 1, + } + + +def test_benchmark_models(load_replication_params_and_specs): + params, model_specs = load_replication_params_and_specs("retirement_taste_shocks") + params_update = { + "health_disutil": 0.1, + "good_to_good": 0.8, + "bad_to_good": 0.1, + "single_to_married": 0.1, + "married_to_married": 0.9, + } + params = {**params, **params_update} + + options = {} + + options["model_params"] = model_specs + options["model_params"]["n_choices"] = model_specs["n_discrete_choices"] + options["state_space"] = { + "n_periods": 25, + "choices": np.arange(2, dtype=int), + "endogenous_states": { + "education": np.arange(2, dtype=int), + }, + "exogenous_processes": { + "health": { + "states": np.arange(2, dtype=int), + "transition": health_transition, + }, + "partner": { + "states": np.arange(2, dtype=int), + "transition": partner_transition, + }, + }, + "continuous_states": { + "wealth": jnp.linspace( + 0, + options["model_params"]["max_wealth"], + options["model_params"]["n_grid_points"], + ) + }, + } + + model_funcs = load_example_models("dcegm_paper") + + model_full = setup_model( + options=options, + state_space_functions=model_funcs["state_space_functions"], + utility_functions=model_funcs["utility_functions"], + utility_functions_final_period=model_funcs["final_period_utility_functions"], + budget_constraint=model_funcs["budget_constraint"], + ) + + model_funcs_sparse = model_funcs.copy() + model_funcs_sparse["sparsity_condition"] = sparsity_condition + + model_sparse = setup_model( + options=options, + state_space_functions=model_funcs_sparse["state_space_functions"], + utility_functions=model_funcs["utility_functions"], + 
utility_functions_final_period=model_funcs["final_period_utility_functions"], + budget_constraint=model_funcs["budget_constraint"], + ) + + value_full, policy_full, endog_grid_full = get_solve_func_for_model(model_full)( + params + ) + value_sparse, policy_sparse, endog_grid_sparse = get_solve_func_for_model( + model_sparse + )(params) + + state_choices_sparse = model_sparse["model_structure"]["state_choice_space"] + state_choice_space_tuple_sparse = tuple( + state_choices_sparse[:, i] for i in range(state_choices_sparse.shape[1]) + ) + full_idxs = model_full["model_structure"]["map_state_choice_to_index"][ + state_choice_space_tuple_sparse + ] + + aaae(endog_grid_full[full_idxs], endog_grid_sparse) + aaae(value_full[full_idxs], value_sparse) + aaae(policy_full[full_idxs], policy_sparse) + + options_sep_once = options.copy() + options_sep_once["state_space"]["min_period_batch_segments"] = 20 + + value_sep_1, policy_sep_1, endog_grid_sep_1 = solve_dcegm( + params=params, + options=options_sep_once, + state_space_functions=model_funcs_sparse["state_space_functions"], + utility_functions=model_funcs["utility_functions"], + utility_functions_final_period=model_funcs["final_period_utility_functions"], + budget_constraint=model_funcs["budget_constraint"], + ) + + aaae(endog_grid_full[full_idxs], endog_grid_sep_1) + aaae(value_full[full_idxs], value_sep_1) + aaae(policy_full[full_idxs], policy_sep_1) + + options_sep_twice = options.copy() + options_sep_twice["state_space"]["min_period_batch_segments"] = [15, 20] + + value_sep_2, policy_sep_2, endog_grid_sep_2 = solve_dcegm( + params=params, + options=options_sep_twice, + state_space_functions=model_funcs_sparse["state_space_functions"], + utility_functions=model_funcs["utility_functions"], + utility_functions_final_period=model_funcs["final_period_utility_functions"], + budget_constraint=model_funcs["budget_constraint"], + ) + + aaae(endog_grid_full[full_idxs], endog_grid_sep_2) + aaae(value_full[full_idxs], value_sep_2) 
+ aaae(policy_full[full_idxs], policy_sep_2) diff --git a/tests/test_state_space.py b/tests/test_state_space.py index 25972dd1..29dd73c0 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -4,17 +4,18 @@ import numpy as np import pytest -from dcegm.pre_processing.debugging import inspect_state_space -from dcegm.pre_processing.state_space import create_state_space +from dcegm.pre_processing.model_functions import process_sparsity_condition +from dcegm.pre_processing.model_structure.state_space import create_state_space +from dcegm.pre_processing.setup_model import setup_model from toy_models.cons_ret_model_dcegm_paper.state_space_objects import ( get_state_specific_feasible_choice_set, ) @pytest.fixture() -def options(load_example_model): +def options(load_replication_params_and_specs): """Return options dictionary.""" - _, _raw_options = load_example_model("retirement_no_taste_shocks") + _, _raw_options = load_replication_params_and_specs("retirement_no_taste_shocks") _raw_options["n_choices"] = 2 options = {} @@ -212,7 +213,6 @@ def test_state_space(): "experience": np.arange(n_periods, dtype=int), "policy_state": np.arange(36, dtype=int), "retirement_age_id": np.arange(10, dtype=int), - "sparsity_condition": sparsity_condition, }, "continuous_states": {"wealth": np.linspace(0, 50, 100)}, }, @@ -226,23 +226,29 @@ def test_state_space(): }, } + state_space_functions = { + "sparsity_condition": sparsity_condition, + } + + processed_sparsity_condition = process_sparsity_condition( + options=options_sparse, state_space_functions=state_space_functions + ) + state_space_test, _ = create_state_space_test(options_sparse["model_params"]) - ( - state_space, - state_space_dict, - map_state_to_index, - states_names_without_exog, - exog_states_names, - exog_state_space, - ) = create_state_space(options=options_sparse) + dict_of_state_space_objects = create_state_space( + state_space_options=options_sparse["state_space"], + 
sparsity_condition=processed_sparsity_condition, + ) + + state_space = dict_of_state_space_objects["state_space"] + discrete_states_names = dict_of_state_space_objects["discrete_states_names"] # The dcegm package create the state vector in the order of the dictionary keys. # How these are ordered is not clear ex ante. state_space_sums_test = state_space_test.sum(axis=0) state_space_sums = state_space.sum(axis=0) state_space_sum_dict = { - key: state_space_sums[i] - for i, key in enumerate(states_names_without_exog + exog_states_names) + key: state_space_sums[i] for i, key in enumerate(discrete_states_names) } np.testing.assert_allclose(state_space_sum_dict["period"], state_space_sums_test[0]) @@ -263,8 +269,15 @@ def test_state_space(): ) ### Now test the inspection function. - state_space_df = inspect_state_space(options=options_sparse) - admissible_df = state_space_df[state_space_df["is_feasible"]] + state_space_df = setup_model( + options=options_sparse, + utility_functions=None, + utility_functions_final_period=None, + budget_constraint=None, + state_space_functions=state_space_functions, + debug_output="state_space_df", + ) + admissible_df = state_space_df[state_space_df["is_valid"]] - for i, column in enumerate(states_names_without_exog + exog_states_names): + for i, column in enumerate(discrete_states_names): np.testing.assert_allclose(admissible_df[column].values, state_space[:, i]) diff --git a/tests/test_two_period_continuous_experience.py b/tests/test_two_period_continuous_experience.py index dc69aa3a..0e5b94a0 100644 --- a/tests/test_two_period_continuous_experience.py +++ b/tests/test_two_period_continuous_experience.py @@ -105,7 +105,7 @@ def marginal_utility_weighted( params, ): """Return the expected marginal utility for one realization of the wage shock.""" - exp_new = get_next_period_experience( + exp_new = next_period_experience( period=1, lagged_choice=lagged_choice, experience=experience, params=params ) @@ -218,7 +218,7 @@ def 
calc_stochastic_income( return jnp.exp(labor_income + wage_shock) -def get_next_period_experience(period, lagged_choice, experience, params): +def next_period_experience(period, lagged_choice, experience, params): return (1 / period) * ((period - 1) * experience + (lagged_choice == 0)) @@ -263,7 +263,7 @@ def create_test_inputs(): # ================================================================================= state_space_functions = { - "update_continuous_state": get_next_period_experience, + "next_period_experience": next_period_experience, } model = setup_model( @@ -281,7 +281,7 @@ def create_test_inputs(): taste_shock_scale, exog_grids_cont, model_funcs_cont, - batch_info_cont, + last_two_period_batch_info_cont, value_solved, policy_solved, endog_grid_solved, @@ -296,13 +296,15 @@ def create_test_inputs(): value_interp_final_period, marginal_utility_final_last_period, ) = solve_final_period( - idx_state_choices_final_period=batch_info_cont[ + idx_state_choices_final_period=last_two_period_batch_info_cont[ "idx_state_choices_final_period" ], - idx_parent_states_final_period=batch_info_cont[ + idx_parent_states_final_period=last_two_period_batch_info_cont[ "idxs_parent_states_final_period" ], - state_choice_mat_final_period=batch_info_cont["state_choice_mat_final_period"], + state_choice_mat_final_period=last_two_period_batch_info_cont[ + "state_choice_mat_final_period" + ], cont_grids_next_period=cont_grids_next_period, exog_grids=exog_grids_cont, params=params, @@ -316,9 +318,15 @@ def create_test_inputs(): endog_grid, policy, value_second_last = solve_for_interpolated_values( value_interpolated=value_interp_final_period, marginal_utility_interpolated=marginal_utility_final_last_period, - state_choice_mat=batch_info_cont["state_choice_mat_second_last_period"], - child_state_idxs=batch_info_cont["child_states_second_last_period"], - states_to_choices_child_states=batch_info_cont["state_to_choices_final_period"], + 
state_choice_mat=last_two_period_batch_info_cont[ + "state_choice_mat_second_last_period" + ], + child_state_idxs=last_two_period_batch_info_cont[ + "child_states_second_last_period" + ], + states_to_choices_child_states=last_two_period_batch_info_cont[ + "state_to_choices_final_period" + ], params=params, taste_shock_scale=taste_shock_scale, income_shock_weights=income_shock_weights, @@ -327,7 +335,9 @@ def create_test_inputs(): has_second_continuous_state=True, ) - idx_second_last = batch_info_cont["idx_state_choices_second_last_period"] + idx_second_last = last_two_period_batch_info_cont[ + "idx_state_choices_second_last_period" + ] value_solved = value_solved.at[idx_second_last, ...].set(value_second_last) policy_solved = policy_solved.at[idx_second_last, ...].set(policy) @@ -437,7 +447,7 @@ def test_euler_equation(wealth_idx, state_idx, create_test_inputs): def _get_solve_last_two_periods_args(model, params, has_second_continuous_state): options = model["options"] - batch_info = model["batch_info"] + batch_info_last_two_periods = model["batch_info"]["last_two_period_info"] exog_grids = options["exog_grids"] @@ -474,7 +484,7 @@ def _get_solve_last_two_periods_args(model, params, has_second_continuous_state) taste_shock_scale, exog_grids, model_funcs, - batch_info, + batch_info_last_two_periods, value_solved, policy_solved, endog_grid_solved, diff --git a/tests/test_utility_second_continuous.py b/tests/test_utility_second_continuous.py index b23987db..e31ddc14 100644 --- a/tests/test_utility_second_continuous.py +++ b/tests/test_utility_second_continuous.py @@ -226,7 +226,6 @@ def test_setup(): ), "endogenous_states": { "experience": np.arange(N_PERIODS + MAX_INIT_EXPERIENCE), - "sparsity_condition": model_funcs_discr_exp["sparsity_condition"], }, "continuous_states": { "wealth": jnp.linspace( @@ -366,17 +365,16 @@ def test_replication_discrete_versus_continuous_experience( state_choice_cont_dict["choice"], ] - state_specific_choice_set = 
model_disc["model_funcs"][ - "get_state_specific_choice_set" - ](**state_choice_disc_dict) + state_specific_choice_set = model_disc["model_funcs"]["state_specific_choice_set"]( + **state_choice_disc_dict + ) choice_valid = choice in state_specific_choice_set - sparsity_condition = load_example_models("with_exp")["sparsity_condition"] + sparsity_condition = model_disc["model_funcs"]["sparsity_condition"] state_valid = sparsity_condition( - period, - experience, - lagged_choice, - model_disc["options"]["model_params"], + period=period, + experience=experience, + lagged_choice=lagged_choice, ) # ================================================================================