From d4f892f8851cc4f0d1c4f110ee8dfbde587839ed Mon Sep 17 00:00:00 2001 From: Maximilian Blesch Date: Thu, 13 Jun 2024 18:45:09 +0200 Subject: [PATCH] Fix default values to old behavior. (#13) --- .pre-commit-config.yaml | 35 +++++++++++-------- src/upper_envelope/__init__.py | 6 ++-- .../fues_jax/check_and_scan_funcs.py | 1 + src/upper_envelope/fues_jax/fues_jax.py | 22 +++++++----- src/upper_envelope/fues_numba/fues_numba.py | 5 ++- tests/test_fues_jax.py | 11 +++--- tests/test_fues_numba.py | 9 +++-- tests/utils/fast_upper_envelope_org.py | 4 +-- tests/utils/upper_envelope_fedor.py | 25 ++++++------- 9 files changed, 64 insertions(+), 54 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 966bdae..0fe60e0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,15 +6,15 @@ repos: - id: check-useless-excludes # - id: identity # Prints all files passed to pre-commits. Debugging. - repo: https://github.com/adrienverge/yamllint.git - rev: v1.32.0 + rev: v1.35.1 hooks: - id: yamllint - repo: https://github.com/lyz-code/yamlfix - rev: 1.13.0 + rev: 1.16.0 hooks: - id: yamlfix - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: check-added-large-files args: @@ -41,24 +41,31 @@ repos: - id: python-no-log-warn - id: python-use-type-annotations - id: text-unicode-replacement-char - - repo: https://github.com/asottile/reorder-python-imports - rev: v3.10.0 + - repo: https://github.com/pycqa/isort + rev: 5.13.2 hooks: - - id: reorder-python-imports + - id: isort + name: isort args: - - --py37-plus + - --profile=black + # - repo: https://github.com/asottile/reorder-python-imports + # rev: v3.13.0 + # hooks: + # - id: reorder-python-imports + # args: + # - --py37-plus - repo: https://github.com/asottile/setup-cfg-fmt - rev: v2.4.0 + rev: v2.5.0 hooks: - id: setup-cfg-fmt - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 24.4.2 hooks: - id: black language_version: python3.10 exclude: tests/utils/fast_upper_envelope_org.py - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.282 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.7 hooks: - id: ruff # exclude: | @@ -79,12 +86,12 @@ repos: - --blank exclude: tests/utils/fast_upper_envelope_org.py - repo: https://github.com/nbQA-dev/nbQA - rev: 1.7.0 + rev: 1.8.5 hooks: - id: nbqa-black - id: nbqa-ruff - repo: https://github.com/executablebooks/mdformat - rev: 0.7.16 + rev: 0.7.17 hooks: - id: mdformat additional_dependencies: @@ -95,7 +102,7 @@ repos: - '88' files: (README\.md) - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.3.0 hooks: - id: codespell additional_dependencies: diff --git a/src/upper_envelope/__init__.py b/src/upper_envelope/__init__.py index 0ae40f1..16055e5 100644 --- a/src/upper_envelope/__init__.py +++ b/src/upper_envelope/__init__.py @@ -1,4 +1,2 @@ -from upper_envelope.fues_jax.fues_jax import fues_jax -from upper_envelope.fues_jax.fues_jax import fues_jax_unconstrained -from upper_envelope.fues_numba.fues_numba import fues_numba -from upper_envelope.fues_numba.fues_numba import fues_numba_unconstrained +from upper_envelope.fues_jax.fues_jax import fues_jax, fues_jax_unconstrained +from upper_envelope.fues_numba.fues_numba import fues_numba, fues_numba_unconstrained diff --git a/src/upper_envelope/fues_jax/check_and_scan_funcs.py b/src/upper_envelope/fues_jax/check_and_scan_funcs.py index e7c5e3f..77fc653 100644 --- a/src/upper_envelope/fues_jax/check_and_scan_funcs.py +++ b/src/upper_envelope/fues_jax/check_and_scan_funcs.py @@ -3,6 +3,7 @@ import jax from jax import numpy as jnp + from upper_envelope.math_funcs import calc_gradient diff --git a/src/upper_envelope/fues_jax/fues_jax.py b/src/upper_envelope/fues_jax/fues_jax.py index 927aabc..44abb7c 100644 --- a/src/upper_envelope/fues_jax/fues_jax.py +++ b/src/upper_envelope/fues_jax/fues_jax.py @@ -5,22 +5,19 @@ https://dx.doi.org/10.2139/ssrn.4181302 """ + from functools import partial -from typing import Callable -from typing import Dict -from typing import Optional -from typing import Tuple +from typing import Callable, Dict, Optional, Tuple import jax import jax.numpy as jnp import numpy as np from jax import vmap + from upper_envelope.fues_jax.check_and_scan_funcs import ( determine_cases_and_conduct_necessary_scans, ) -from upper_envelope.math_funcs import ( - calc_intersection_and_extrapolate_policy, -) +from upper_envelope.math_funcs import calc_intersection_and_extrapolate_policy @partial( @@ -109,6 +106,13 @@ def fues_jax( else n_constrained_points_to_add ) + # Set default value of final grid size to 1.2 times current if not defined + n_final_wealth_grid = ( + int(1.2 * endog_grid.shape[0]) + if n_final_wealth_grid is None + else n_final_wealth_grid + ) + # Check if a non-concave region coincides with the credit constrained region. # This happens when there is a non-monotonicity in the endogenous wealth grid # that goes below the first point (the minimal wealth, below it is optimal to @@ -205,7 +209,9 @@ def fues_jax_unconstrained( """ # Set default value of final grid size to 1.2 times current if not defined n_final_wealth_grid = ( - int(1.2 * (len(policy))) if n_final_wealth_grid is None else n_final_wealth_grid + int(1.2 * endog_grid.shape[0]) + if n_final_wealth_grid is None + else n_final_wealth_grid ) idx_sort = jnp.argsort(endog_grid) diff --git a/src/upper_envelope/fues_numba/fues_numba.py b/src/upper_envelope/fues_numba/fues_numba.py index 92ad6a7..7a0ed91 100644 --- a/src/upper_envelope/fues_numba/fues_numba.py +++ b/src/upper_envelope/fues_numba/fues_numba.py @@ -5,9 +5,8 @@ https://dx.doi.org/10.2139/ssrn.4181302 """ -from typing import Callable -from typing import Optional -from typing import Tuple + +from typing import Callable, Optional, Tuple import numpy as np from numba import njit diff --git a/tests/test_fues_jax.py b/tests/test_fues_jax.py index 764e454..869fa49 100644 --- a/tests/test_fues_jax.py +++ b/tests/test_fues_jax.py @@ -1,4 +1,5 @@ """Test the JAX implementation of the fast upper envelope scan.""" + from pathlib import Path from typing import Dict @@ -6,13 +7,15 @@ import jax.numpy as jnp import numpy as np import pytest -import upper_envelope as upenv from numpy.testing import assert_array_almost_equal as aaae -from upper_envelope.fues_jax.check_and_scan_funcs import back_and_forward_scan_wrapper -from tests.utils.interpolation import interpolate_policy_and_value_on_wealth_grid -from tests.utils.interpolation import linear_interpolation_with_extrapolation +import upper_envelope as upenv +from tests.utils.interpolation import ( + interpolate_policy_and_value_on_wealth_grid, + linear_interpolation_with_extrapolation, +) from tests.utils.upper_envelope_fedor import upper_envelope +from upper_envelope.fues_jax.check_and_scan_funcs import back_and_forward_scan_wrapper jax.config.update("jax_enable_x64", True) diff --git a/tests/test_fues_numba.py b/tests/test_fues_numba.py index 6b4ea3f..98189cf 100644 --- a/tests/test_fues_numba.py +++ b/tests/test_fues_numba.py @@ -1,14 +1,17 @@ """Test the numba implementation of the fast upper envelope scan.""" + from pathlib import Path import numpy as np import pytest -import upper_envelope as upenv from numpy.testing import assert_array_almost_equal as aaae +import upper_envelope as upenv from tests.utils.fast_upper_envelope_org import fast_upper_envelope_wrapper_org -from tests.utils.interpolation import interpolate_single_policy_and_value_on_wealth_grid -from tests.utils.interpolation import linear_interpolation_with_extrapolation +from tests.utils.interpolation import ( + interpolate_single_policy_and_value_on_wealth_grid, + linear_interpolation_with_extrapolation, +) from tests.utils.upper_envelope_fedor import upper_envelope # Obtain the test directory of the package. diff --git a/tests/utils/fast_upper_envelope_org.py b/tests/utils/fast_upper_envelope_org.py index b37098c..6db12ad 100644 --- a/tests/utils/fast_upper_envelope_org.py +++ b/tests/utils/fast_upper_envelope_org.py @@ -5,9 +5,7 @@ https://dx.doi.org/10.2139/ssrn.4181302 """ -from typing import Callable -from typing import Optional -from typing import Tuple +from typing import Callable, Optional, Tuple import numpy as np diff --git a/tests/utils/upper_envelope_fedor.py b/tests/utils/upper_envelope_fedor.py index 47b96c9..9895e76 100644 --- a/tests/utils/upper_envelope_fedor.py +++ b/tests/utils/upper_envelope_fedor.py @@ -4,15 +4,12 @@ https://github.com/fediskhakov/dcegm/blob/master/model_retirement.m """ -from typing import Callable -from typing import Dict -from typing import List -from typing import Tuple + +from typing import Callable, Dict, List, Tuple import numpy as np from scipy.optimize import brenth as root - EPS = 2.2204e-16 @@ -317,16 +314,14 @@ def compute_upper_envelope( values_all_segments = np.empty((len(segments), 1)) for segment in range(len(segments)): - values_all_segments[ - segment - ] = _linear_interpolation_with_inserting_missing_values( - x=segments[segment][0], - y=segments[segment][1], - x_new=np.array([intersect_point]), - missing_value=-np.inf, - )[ - 0 - ] + values_all_segments[segment] = ( + _linear_interpolation_with_inserting_missing_values( + x=segments[segment][0], + y=segments[segment][1], + x_new=np.array([intersect_point]), + missing_value=-np.inf, + )[0] + ) index_max_value_intersect = np.where( values_all_segments == values_all_segments.max(axis=0)