Skip to content

Commit

Permalink
Adding copy_if_needed argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
ndilalla committed Aug 13, 2024
1 parent e3a13ba commit d2c9c24
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 10 deletions.
13 changes: 7 additions & 6 deletions astromodels/functions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from astromodels.utils.logging import setup_logger
from astromodels.utils.pretty_list import dict_to_list
from astromodels.utils.table import dict_to_table
from astromodels.utils.file_utils import copy_if_needed
from yaml.reader import ReaderError

log = setup_logger(__name__)
Expand Down Expand Up @@ -1479,7 +1480,7 @@ def __call__(self, x):

# Transform the input to an array of floats. If x is a single number, this will be an array of size 1

new_input = np.array(x, dtype=float, ndmin=1, copy=None)
new_input = np.array(x, dtype=float, ndmin=1, copy=copy_if_needed)

# Compute the function

Expand Down Expand Up @@ -1705,8 +1706,8 @@ def __call__(self, x, y, *args, **kwargs):

# Transform the input to an array of floats

new_x = np.array(x, dtype=float, ndmin=1, copy=None)
new_y = np.array(y, dtype=float, ndmin=1, copy=None)
new_x = np.array(x, dtype=float, ndmin=1, copy=copy_if_needed)
new_y = np.array(y, dtype=float, ndmin=1, copy=copy_if_needed)

# Compute the function

Expand Down Expand Up @@ -1876,9 +1877,9 @@ def __call__(self, x, y, z):
# This is either a single number or a list
# Transform the input to an array of floats

new_x = np.array(x, dtype=float, ndmin=1, copy=None)
new_y = np.array(y, dtype=float, ndmin=1, copy=None)
new_z = np.array(z, dtype=float, ndmin=1, copy=None)
new_x = np.array(x, dtype=float, ndmin=1, copy=copy_if_needed)
new_y = np.array(y, dtype=float, ndmin=1, copy=copy_if_needed)
new_z = np.array(z, dtype=float, ndmin=1, copy=copy_if_needed)

# Compute the function

Expand Down
3 changes: 2 additions & 1 deletion astromodels/functions/functions_1D/extinction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from astromodels.functions.function import Function1D, FunctionMeta
from astromodels.utils.logging import setup_logger
from astromodels.utils.file_utils import copy_if_needed

log = setup_logger(__name__)

Expand Down Expand Up @@ -89,7 +90,7 @@ def evaluate(self, x, e_bmv, rv, redshift):

if isinstance(x, astropy_units.Quantity):

_x = np.array(x.to("keV").value, ndmin=1, copy=None, dtype=float)
_x = np.array(x.to("keV").value, ndmin=1, copy=copy_if_needed, dtype=float)

_unit = astropy_units.cm**2
_y_unit = astropy_units.dimensionless_unscaled
Expand Down
3 changes: 2 additions & 1 deletion astromodels/functions/template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from astromodels.functions.function import Function1D, FunctionMeta
from astromodels.utils import get_user_data_path
from astromodels.utils.logging import setup_logger
from astromodels.utils.file_utils import copy_if_needed

log = setup_logger(__name__)

Expand Down Expand Up @@ -761,7 +762,7 @@ def _interpolate(self, energies, scale, parameters_values):
# the logarithm below will fail.

energies = np.array(
energies.to("keV").value, ndmin=1, copy=None, dtype=float
energies.to("keV").value, ndmin=1, copy=copy_if_needed, dtype=float
)

# Same for the scale
Expand Down
8 changes: 7 additions & 1 deletion astromodels/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@

import os
from pathlib import Path

import numpy as np
import pkg_resources

_custom_config_path = os.environ.get("ASTROMODELS_CONFIG")

copy_if_needed: Optional[bool]

if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
copy_if_needed = None
else:
copy_if_needed = False

def _get_data_file_path(data_file: str) -> Path:
"""
Expand Down
3 changes: 2 additions & 1 deletion astromodels/xspec/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def get_models(model_dat_path):
from astromodels.functions.function import FunctionMeta, Function1D
import numpy as np
import astropy.units as u
from astromodels.utils.file_utils import copy_if_needed
from astromodels.xspec import _xspec
import six
Expand Down Expand Up @@ -486,7 +487,7 @@ def evaluate(self, x, $PARAMETERS_NAMES$):
if isinstance(x, u.Quantity):
x = np.array(x.to('keV').value, ndmin=1, copy=None, dtype=float)
x = np.array(x.to('keV').value, ndmin=1, copy=copy_if_needed, dtype=float)
quantity = True
Expand Down

0 comments on commit d2c9c24

Please sign in to comment.