Skip to content

Commit

Permalink
MNT: final touches to Function Class
Browse files Browse the repository at this point in the history
- Fix variable naming "xinitial", I also checked this is nto used in
other classes
- update interpolation and extrapolation methods usage
- remove useless raises or warnings.
  • Loading branch information
Gui-FernandesBR committed Nov 19, 2023
1 parent 7ec3b32 commit ea57768
Showing 1 changed file with 45 additions and 39 deletions.
84 changes: 45 additions & 39 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def source_function(_):
source = source[source[:, 0].argsort()]

self.x_array = source[:, 0]
self.xinitial, self.xfinal = self.x_array[0], self.x_array[-1]
self.x_initial, self.x_final = self.x_array[0], self.x_array[-1]

self.y_array = source[:, 1]
self.y_initial, self.y_final = self.y_array[0], self.y_array[-1]
Expand All @@ -241,7 +241,7 @@ def source_function(_):
# Do things if function is multivariate
else:
self.x_array = source[:, 0]
self.xinitial, self.xfinal = self.x_array[0], self.x_array[-1]
self.x_initial, self.x_final = self.x_array[0], self.x_array[-1]

self.y_array = source[:, 1]
self.y_initial, self.y_final = self.y_array[0], self.y_array[-1]
Expand All @@ -251,30 +251,6 @@ def source_function(_):

# Finally set data source as source
self.source = source

# Update extrapolation method
if (
self.__extrapolation__ is None
or self.__extrapolation__ == "natural"
):
self.set_extrapolation("natural")
else:
raise ValueError(
"Multidimensional datasets only support natural extrapolation."
)

# Set default multidimensional interpolation if it hasn't
if (
self.__interpolation__ is None
or self.__interpolation__ == "shepard"
):
self.set_interpolation("shepard")
else:
raise ValueError(
"Multidimensional datasets only support shepard interpolation."
)

# Return self
return self

@cached_property
Expand Down Expand Up @@ -363,7 +339,7 @@ def set_get_value_opt(self):
# Retrieve general info
x_data = self.x_array
y_data = self.y_array
x_min, x_max = self.xinitial, self.xfinal
x_min, x_max = self.x_initial, self.x_final
if self.__extrapolation__ == "zero":
extrapolation = 0 # Extrapolation is zero
elif self.__extrapolation__ == "natural":
Expand Down Expand Up @@ -557,6 +533,7 @@ def set_discrete(
zs = np.array(self.get_value(mesh))
self.set_source(np.concatenate(([xs], [ys], [zs])).transpose())
self.__interpolation__ = "shepard"
self.__extrapolation__ = "natural"

Check warning on line 536 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L536

Added line #L536 was not covered by tests
return self

def set_discrete_based_on_model(
Expand Down Expand Up @@ -888,7 +865,7 @@ def get_value(self, *args):
x = np.array(args[0])
x_data = self.x_array
y_data = self.y_array
x_min, x_max = self.xinitial, self.xfinal
x_min, x_max = self.x_initial, self.x_final
coeffs = self.__polynomial_coefficients__
matrix = np.zeros((len(args[0]), coeffs.shape[0]))
for i in range(coeffs.shape[0]):
Expand All @@ -909,7 +886,7 @@ def get_value(self, *args):
x_data = self.x_array
y_data = self.y_array
x_intervals = np.searchsorted(x_data, x)
x_min, x_max = self.xinitial, self.xfinal
x_min, x_max = self.x_initial, self.x_final
if self.__interpolation__ == "spline":
coeffs = self.__spline_coefficients__
for i, _ in enumerate(x):
Expand Down Expand Up @@ -1229,7 +1206,7 @@ def plot_1d(
else:
# Determine boundaries
x_data = self.x_array
x_min, x_max = self.xinitial, self.xfinal
x_min, x_max = self.x_initial, self.x_final
lower = x_min if lower is None else lower
upper = x_max if upper is None else upper
# Plot data points if force_data = True
Expand Down Expand Up @@ -1530,7 +1507,7 @@ def __interpolate_polynomial__(self):
y = self.y_array
# Check if interpolation requires large numbers
if np.amax(x) ** degree > 1e308:
print(
warnings.warn(

Check warning on line 1510 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L1510

Added line #L1510 was not covered by tests
"Polynomial interpolation of too many points can't be done."
" Once the degree is too high, numbers get too large."
" The process becomes inefficient. Using spline instead."
Expand Down Expand Up @@ -2738,10 +2715,10 @@ def compose(self, func, extrapolate=False):
if isinstance(self.source, np.ndarray) and isinstance(func.source, np.ndarray):
# Perform bounds check for composition
if not extrapolate:
if func.min < self.xinitial and func.max > self.xfinal:
if func.min < self.x_initial and func.max > self.x_final:
raise ValueError(
f"Input Function image {func.min, func.max} must be within "
f"the domain of the Function {self.xinitial, self.xfinal}."
f"the domain of the Function {self.x_initial, self.x_final}."
)

return Function(
Expand All @@ -2763,10 +2740,10 @@ def compose(self, func, extrapolate=False):
@staticmethod
def _check_user_input(
source,
inputs,
outputs,
interpolation,
extrapolation,
inputs=None,
outputs=None,
interpolation=None,
extrapolation=None,
):
"""
Validates and processes the user input parameters for creating or
Expand Down Expand Up @@ -2857,7 +2834,36 @@ def _check_user_input(
source_dim = source.shape[1]

# check interpolation and extrapolation
if source_dim > 2: # (multiple dimensions)

## single dimension
if source_dim == 2:
# possible interpolation values: llinear, polynomial, akima and spline
if interpolation is None:
interpolation = "spline"
elif interpolation.lower() not in [
"spline",
"linear",
"polynomial",
"akima",
]:
warnings.warn(

Check warning on line 2849 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L2849

Added line #L2849 was not covered by tests
"Interpolation method for single dimensional functions was "
+ f"set to 'spline', the {interpolation} method is not supported."
)
interpolation = "spline"

Check warning on line 2853 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L2853

Added line #L2853 was not covered by tests

# possible extrapolation values: constant, natural, zero
if extrapolation is None:
extrapolation = "constant"
elif extrapolation.lower() not in ["constant", "natural", "zero"]:
warnings.warn(

Check warning on line 2859 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L2859

Added line #L2859 was not covered by tests
"Extrapolation method for single dimensional functions was "
+ f"set to 'constant', the {extrapolation} method is not supported."
)
extrapolation = "constant"

Check warning on line 2863 in rocketpy/mathutils/function.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/mathutils/function.py#L2863

Added line #L2863 was not covered by tests

## multiple dimensions
if source_dim > 2:
# check for inputs and outputs
if inputs == ["Scalar"]:
inputs = [f"Input {i+1}" for i in range(source_dim - 1)]
Expand Down Expand Up @@ -2918,7 +2924,7 @@ def __new__(
outputs: list of strings
A list of strings that represent the outputs of the function.
interpolation: str
The type of interpolation to use. The default value is 'akima'.
The type of interpolation to use. The default value is 'spline'.
extrapolation: str
The type of extrapolation to use. The default value is None.
datapoints: int
Expand Down

0 comments on commit ea57768

Please sign in to comment.