Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend sparsity to exogenous processes #143

Closed
wants to merge 35 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
6b5145c
Implemented check that all states are child states except first period.
MaxBlesch Nov 29, 2024
4c02404
Fixing state spaces of tests.
MaxBlesch Nov 29, 2024
f2ef015
Exog processes in sparsity.
MaxBlesch Nov 29, 2024
7ff59a1
Dummy proxy should be working.
MaxBlesch Nov 29, 2024
8c8a02f
Refactor pre processing.
MaxBlesch Nov 29, 2024
7c3eb07
Large refactor of pre processing.
MaxBlesch Nov 29, 2024
305c866
Large refactor of pre processing.
MaxBlesch Nov 29, 2024
f368f03
Some more refactor.
MaxBlesch Nov 29, 2024
99f2e72
More refactor.
MaxBlesch Nov 29, 2024
d4c35b2
Use sparsity condition throughout.
MaxBlesch Dec 2, 2024
6471d31
Error messages.
MaxBlesch Dec 2, 2024
a80b0d2
Sparsity condition checks implemented.
MaxBlesch Dec 4, 2024
a73d143
Start with refactoring.
MaxBlesch Dec 4, 2024
c84d722
Algorithm for idx
MaxBlesch Dec 4, 2024
038e603
Extracted last two period batch information.
MaxBlesch Dec 4, 2024
00a067a
Further refactor.
MaxBlesch Dec 4, 2024
1396bec
Further batch refactor.
MaxBlesch Dec 4, 2024
1cd7590
Preliminary done with refactoring batches.
MaxBlesch Dec 4, 2024
5e9a8ff
Introduced segment.
MaxBlesch Dec 4, 2024
2c6a1f2
Introduced segments. Default of none is working. Now extending.
MaxBlesch Dec 4, 2024
aec1809
Added int split.
MaxBlesch Dec 4, 2024
90ba610
Draft on period split for batches.
MaxBlesch Dec 4, 2024
2139632
Transformed params to yaml.
MaxBlesch Dec 4, 2024
9fdf2a4
Debugging.
MaxBlesch Dec 4, 2024
d245efb
Wrote test. It works.
MaxBlesch Dec 4, 2024
6acfb0f
Comment.
MaxBlesch Dec 4, 2024
916fbae
Finished new debug mode.
MaxBlesch Dec 10, 2024
5aabaa2
New debugging infrastructure.
MaxBlesch Dec 13, 2024
f4b9769
Simplified endogenous state handling.
MaxBlesch Dec 13, 2024
b9bd1e9
Nice errors starting.
MaxBlesch Dec 16, 2024
23668fa
Rename model funcs.
MaxBlesch Dec 16, 2024
59a2a0d
Debugging with enw framework.
MaxBlesch Dec 16, 2024
1513cff
Adapted all tests.
MaxBlesch Dec 16, 2024
34f71e1
Better error.,
MaxBlesch Dec 18, 2024
14d40b8
SOme interface extension to 2d case.
MaxBlesch Jan 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/source/Tutorials/two_period_model_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@
"source": [
"#### State space functions\n",
"\n",
"Next we define state space functions ```create_state_space``` and ```get_state_specific_choice_set```. They can be directly imported the dcegm package, but we display them here, in order to explain how they work."
"Next we define state space functions ```create_state_space``` and ```state_specific_choice_set```. They can be directly imported the dcegm package, but we display them here, in order to explain how they work."
]
},
{
Expand Down Expand Up @@ -402,7 +402,7 @@
"metadata": {},
"outputs": [],
"source": [
"def get_state_specific_choice_set(\n",
"def state_specific_choice_set(\n",
" state: np.ndarray,\n",
" state_space: np.ndarray, # noqa: U100\n",
" indexer: np.ndarray,\n",
Expand Down Expand Up @@ -460,7 +460,7 @@
}
],
"source": [
"choice_set_5 = get_state_specific_choice_set(_state_space[4], _state_space, _indexer)\n",
"choice_set_5 = state_specific_choice_set(_state_space[4], _state_space, _indexer)\n",
"choice_set_5"
]
},
Expand All @@ -479,7 +479,7 @@
"source": [
"state_space_functions = {\n",
" \"create_state_space\": create_state_space,\n",
" \"get_state_specific_choice_set\": get_state_specific_choice_set,\n",
" \"state_specific_choice_set\": state_specific_choice_set,\n",
"}"
]
},
Expand Down
28 changes: 19 additions & 9 deletions src/dcegm/final_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def solve_last_two_periods(
income_shock_weights: jnp.ndarray,
exog_grids: Dict[str, jnp.ndarray],
model_funcs: Dict[str, Callable],
batch_info,
last_two_period_batch_info,
value_solved,
policy_solved,
endog_grid_solved,
Expand All @@ -39,7 +39,7 @@ def solve_last_two_periods(
compute_utility (callable): User supplied utility function.
compute_marginal_utility (callable): User supplied marginal utility
function.
batch_info (dict): Dictionary containing information about the batch
last_two_period_batch_info (dict): Dictionary containing information about the batch
size and the state space.
value_solved (np.ndarray): 3d array of shape
(n_states, n_grid_wealth, n_income_shocks) of the value function for
Expand All @@ -56,9 +56,15 @@ def solve_last_two_periods(
value_interp_final_period,
marginal_utility_final_last_period,
) = solve_final_period(
idx_state_choices_final_period=batch_info["idx_state_choices_final_period"],
idx_parent_states_final_period=batch_info["idxs_parent_states_final_period"],
state_choice_mat_final_period=batch_info["state_choice_mat_final_period"],
idx_state_choices_final_period=last_two_period_batch_info[
"idx_state_choices_final_period"
],
idx_parent_states_final_period=last_two_period_batch_info[
"idxs_parent_states_final_period"
],
state_choice_mat_final_period=last_two_period_batch_info[
"state_choice_mat_final_period"
],
cont_grids_next_period=cont_grids_next_period,
exog_grids=exog_grids,
params=params,
Expand All @@ -72,9 +78,13 @@ def solve_last_two_periods(
endog_grid, policy, value = solve_for_interpolated_values(
value_interpolated=value_interp_final_period,
marginal_utility_interpolated=marginal_utility_final_last_period,
state_choice_mat=batch_info["state_choice_mat_second_last_period"],
child_state_idxs=batch_info["child_states_second_last_period"],
states_to_choices_child_states=batch_info["state_to_choices_final_period"],
state_choice_mat=last_two_period_batch_info[
"state_choice_mat_second_last_period"
],
child_state_idxs=last_two_period_batch_info["child_states_second_last_period"],
states_to_choices_child_states=last_two_period_batch_info[
"state_to_choices_final_period"
],
params=params,
taste_shock_scale=taste_shock_scale,
income_shock_weights=income_shock_weights,
Expand All @@ -83,7 +93,7 @@ def solve_last_two_periods(
has_second_continuous_state=has_second_continuous_state,
)

idx_second_last = batch_info["idx_state_choices_second_last_period"]
idx_second_last = last_two_period_batch_info["idx_state_choices_second_last_period"]

value_solved = value_solved.at[idx_second_last, ...].set(value)
policy_solved = policy_solved.at[idx_second_last, ...].set(policy)
Expand Down
46 changes: 33 additions & 13 deletions src/dcegm/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
interp_policy_on_wealth,
interp_value_on_wealth,
)
from dcegm.interpolation.interp2d import (
interp2d_policy_and_value_on_wealth_and_regular_grid,
interp2d_policy_on_wealth_and_regular_grid,
interp2d_value_on_wealth_and_regular_grid,
)


def policy_and_value_for_state_choice_vec(
Expand Down Expand Up @@ -47,14 +52,13 @@ def policy_and_value_for_state_choice_vec(


def value_for_state_choice_vec(
state_choice_vec,
wealth,
map_state_choice_to_index,
discrete_states_names,
endog_grid_solved,
value_solved,
compute_utility,
params,
model,
state_choice_vec,
wealth,
second_continous=None,
):
"""Get policy and value for a given state and choice vector.

Expand All @@ -67,20 +71,36 @@ def value_for_state_choice_vec(
Tuple[float, float]: Policy and value for the given state and choice vector.

"""
map_state_choice_to_index = model["model_structure"]["map_state_choice_to_index"]
discrete_states_names = model["model_structure"]["discrete_states_names"]
compute_utility = model["model_funcs"]["compute_utility"]

state_choice_tuple = tuple(
state_choice_vec[st] for st in discrete_states_names + ["choice"]
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]

value = interp_value_on_wealth(
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
value=jnp.take(value_solved, state_choice_index, axis=0),
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
if second_continous is None:
value = interp_value_on_wealth(
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
value=jnp.take(value_solved, state_choice_index, axis=0),
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
else:
value = interp2d_value_on_wealth_and_regular_grid(
regular_grid=model["options"]["exog_grids"]["second_continuous"],
wealth_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
value_grid=jnp.take(value_solved, state_choice_index, axis=0),
regular_point_to_interp=second_continous,
wealth_point_to_interp=wealth,
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
return value


Expand Down
2 changes: 1 addition & 1 deletion src/dcegm/law_of_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def calc_cont_grids_next_period(
discrete_states_beginning_of_period=state_space_dict,
continuous_grid=exog_grids["second_continuous"],
params=params,
compute_continuous_state=model_funcs["update_continuous_state"],
compute_continuous_state=model_funcs["next_period_continuous_state"],
)

# Extra dimension for continuous state
Expand Down
Loading
Loading