Skip to content

Commit

Permalink
Fix default values to old behavior. (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch authored Jun 13, 2024
1 parent 7ce4132 commit d4f892f
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 54 deletions.
35 changes: 21 additions & 14 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ repos:
- id: check-useless-excludes
# - id: identity # Prints all files passed to pre-commits. Debugging.
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.32.0
rev: v1.35.1
hooks:
- id: yamllint
- repo: https://github.com/lyz-code/yamlfix
rev: 1.13.0
rev: 1.16.0
hooks:
- id: yamlfix
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: check-added-large-files
args:
Expand All @@ -41,24 +41,31 @@ repos:
- id: python-no-log-warn
- id: python-use-type-annotations
- id: text-unicode-replacement-char
- repo: https://github.com/asottile/reorder-python-imports
rev: v3.10.0
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: reorder-python-imports
- id: isort
name: isort
args:
- --py37-plus
- --profile=black
# - repo: https://github.com/asottile/reorder-python-imports
# rev: v3.13.0
# hooks:
# - id: reorder-python-imports
# args:
# - --py37-plus
- repo: https://github.com/asottile/setup-cfg-fmt
rev: v2.4.0
rev: v2.5.0
hooks:
- id: setup-cfg-fmt
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 24.4.2
hooks:
- id: black
language_version: python3.10
exclude: tests/utils/fast_upper_envelope_org.py
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.282
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
hooks:
- id: ruff
# exclude: |
Expand All @@ -79,12 +86,12 @@ repos:
- --blank
exclude: tests/utils/fast_upper_envelope_org.py
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
rev: 1.8.5
hooks:
- id: nbqa-black
- id: nbqa-ruff
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
rev: 0.7.17
hooks:
- id: mdformat
additional_dependencies:
Expand All @@ -95,7 +102,7 @@ repos:
- '88'
files: (README\.md)
- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: v2.3.0
hooks:
- id: codespell
additional_dependencies:
Expand Down
6 changes: 2 additions & 4 deletions src/upper_envelope/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
from upper_envelope.fues_jax.fues_jax import fues_jax
from upper_envelope.fues_jax.fues_jax import fues_jax_unconstrained
from upper_envelope.fues_numba.fues_numba import fues_numba
from upper_envelope.fues_numba.fues_numba import fues_numba_unconstrained
from upper_envelope.fues_jax.fues_jax import fues_jax, fues_jax_unconstrained
from upper_envelope.fues_numba.fues_numba import fues_numba, fues_numba_unconstrained
1 change: 1 addition & 0 deletions src/upper_envelope/fues_jax/check_and_scan_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import jax
from jax import numpy as jnp

from upper_envelope.math_funcs import calc_gradient


Expand Down
22 changes: 14 additions & 8 deletions src/upper_envelope/fues_jax/fues_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,19 @@
https://dx.doi.org/10.2139/ssrn.4181302
"""

from functools import partial
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap

from upper_envelope.fues_jax.check_and_scan_funcs import (
determine_cases_and_conduct_necessary_scans,
)
from upper_envelope.math_funcs import (
calc_intersection_and_extrapolate_policy,
)
from upper_envelope.math_funcs import calc_intersection_and_extrapolate_policy


@partial(
Expand Down Expand Up @@ -109,6 +106,13 @@ def fues_jax(
else n_constrained_points_to_add
)

# Set default value of final grid size to 1.2 times current if not defined
n_final_wealth_grid = (
int(1.2 * endog_grid.shape[0])
if n_final_wealth_grid is None
else n_final_wealth_grid
)

# Check if a non-concave region coincides with the credit constrained region.
# This happens when there is a non-monotonicity in the endogenous wealth grid
# that goes below the first point (the minimal wealth, below it is optimal to
Expand Down Expand Up @@ -205,7 +209,9 @@ def fues_jax_unconstrained(
"""
# Set default value of final grid size to 1.2 times current if not defined
n_final_wealth_grid = (
int(1.2 * (len(policy))) if n_final_wealth_grid is None else n_final_wealth_grid
int(1.2 * endog_grid.shape[0])
if n_final_wealth_grid is None
else n_final_wealth_grid
)

idx_sort = jnp.argsort(endog_grid)
Expand Down
5 changes: 2 additions & 3 deletions src/upper_envelope/fues_numba/fues_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
https://dx.doi.org/10.2139/ssrn.4181302
"""
from typing import Callable
from typing import Optional
from typing import Tuple

from typing import Callable, Optional, Tuple

import numpy as np
from numba import njit
Expand Down
11 changes: 7 additions & 4 deletions tests/test_fues_jax.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
"""Test the JAX implementation of the fast upper envelope scan."""

from pathlib import Path
from typing import Dict

import jax
import jax.numpy as jnp
import numpy as np
import pytest
import upper_envelope as upenv
from numpy.testing import assert_array_almost_equal as aaae
from upper_envelope.fues_jax.check_and_scan_funcs import back_and_forward_scan_wrapper

from tests.utils.interpolation import interpolate_policy_and_value_on_wealth_grid
from tests.utils.interpolation import linear_interpolation_with_extrapolation
import upper_envelope as upenv
from tests.utils.interpolation import (
interpolate_policy_and_value_on_wealth_grid,
linear_interpolation_with_extrapolation,
)
from tests.utils.upper_envelope_fedor import upper_envelope
from upper_envelope.fues_jax.check_and_scan_funcs import back_and_forward_scan_wrapper

jax.config.update("jax_enable_x64", True)

Expand Down
9 changes: 6 additions & 3 deletions tests/test_fues_numba.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Test the numba implementation of the fast upper envelope scan."""

from pathlib import Path

import numpy as np
import pytest
import upper_envelope as upenv
from numpy.testing import assert_array_almost_equal as aaae

import upper_envelope as upenv
from tests.utils.fast_upper_envelope_org import fast_upper_envelope_wrapper_org
from tests.utils.interpolation import interpolate_single_policy_and_value_on_wealth_grid
from tests.utils.interpolation import linear_interpolation_with_extrapolation
from tests.utils.interpolation import (
interpolate_single_policy_and_value_on_wealth_grid,
linear_interpolation_with_extrapolation,
)
from tests.utils.upper_envelope_fedor import upper_envelope

# Obtain the test directory of the package.
Expand Down
4 changes: 1 addition & 3 deletions tests/utils/fast_upper_envelope_org.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
https://dx.doi.org/10.2139/ssrn.4181302
"""
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import Callable, Optional, Tuple

import numpy as np

Expand Down
25 changes: 10 additions & 15 deletions tests/utils/upper_envelope_fedor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
https://github.com/fediskhakov/dcegm/blob/master/model_retirement.m
"""
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple

from typing import Callable, Dict, List, Tuple

import numpy as np
from scipy.optimize import brenth as root


EPS = 2.2204e-16


Expand Down Expand Up @@ -317,16 +314,14 @@ def compute_upper_envelope(

values_all_segments = np.empty((len(segments), 1))
for segment in range(len(segments)):
values_all_segments[
segment
] = _linear_interpolation_with_inserting_missing_values(
x=segments[segment][0],
y=segments[segment][1],
x_new=np.array([intersect_point]),
missing_value=-np.inf,
)[
0
]
values_all_segments[segment] = (
_linear_interpolation_with_inserting_missing_values(
x=segments[segment][0],
y=segments[segment][1],
x_new=np.array([intersect_point]),
missing_value=-np.inf,
)[0]
)

index_max_value_intersect = np.where(
values_all_segments == values_all_segments.max(axis=0)
Expand Down

0 comments on commit d4f892f

Please sign in to comment.