Skip to content

Commit

Permalink
Merge pull request #451 from RocketPy-Team/bug/function-input-validation
Browse files Browse the repository at this point in the history
BUG: User input checks added for Function class
  • Loading branch information
phmbressan authored Nov 17, 2023
2 parents f2be825 + 848c860 commit 00d1362
Show file tree
Hide file tree
Showing 2 changed files with 340 additions and 28 deletions.
278 changes: 250 additions & 28 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import warnings
from inspect import signature
from collections.abc import Iterable
from pathlib import Path

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -74,18 +75,22 @@ def __init__(
-------
None
"""
# initialize variables to avoid errors when being called by other methods
self.get_value_opt = None
self.__polynomial_coefficients__ = None
self.__akima_coefficients__ = None
self.__spline_coefficients__ = None

# Set input and output
if inputs is None:
inputs = ["Scalar"]
if outputs is None:
outputs = ["Scalar"]

inputs, outputs, interpolation, extrapolation = self._check_user_input(
source, inputs, outputs, interpolation, extrapolation
)

# initialize variables to avoid errors when being called by other methods
self.get_value_opt = None
self.__polynomial_coefficients__ = None
self.__akima_coefficients__ = None
self.__spline_coefficients__ = None

# store variables
self.set_inputs(inputs)
self.set_outputs(outputs)
Expand Down Expand Up @@ -148,6 +153,13 @@ def set_source(self, source):
-------
self : Function
"""
_ = self._check_user_input(
source,
self.__inputs__,
self.__outputs__,
self.__interpolation__,
self.__extrapolation__,
)
# If the source is a Function
if isinstance(source, Function):
source = source.get_source()
Expand Down Expand Up @@ -198,17 +210,13 @@ def source_function(_):
# Check to see if dimensions match incoming data set
new_total_dim = len(source[0, :])
old_total_dim = self.__dom_dim__ + self.__img_dim__
d_v = self.__inputs__ == ["Scalar"] and self.__outputs__ == ["Scalar"]

# If they don't, update default values or throw error
if new_total_dim != old_total_dim:
if d_v:
# Update dimensions and inputs
self.__dom_dim__ = new_total_dim - 1
self.__inputs__ = self.__dom_dim__ * self.__inputs__
else:
# User has made a mistake inputting inputs and outputs
print("Error in input and output dimensions!")
return None
# Update dimensions and inputs
self.__dom_dim__ = new_total_dim - 1
self.__inputs__ = self.__dom_dim__ * self.__inputs__

# Do things if domDim is 1
if self.__dom_dim__ == 1:
source = source[source[:, 0].argsort()]
Expand Down Expand Up @@ -751,22 +759,87 @@ def get_value(self, *args):
Returns
-------
ans : scalar, list
Value of the Function at the specified point(s).
Examples
--------
>>> from rocketpy import Function
Testing with callable source (1 dimension):
>>> f = Function(lambda x: x**2)
>>> f.get_value(2)
4
>>> f.get_value(2.5)
6.25
>>> f.get_value([1, 2, 3])
[1, 4, 9]
>>> f.get_value([1, 2.5, 4.0])
[1, 6.25, 16.0]
Testing with callable source (2 dimensions):
>>> f2 = Function(lambda x, y: x**2 + y**2)
>>> f2.get_value(1, 2)
5
>>> f2.get_value([1, 2, 3], [1, 2, 3])
[2, 8, 18]
>>> f2.get_value([5], [5])
[50]
Testing with ndarray source (1 dimension):
>>> f3 = Function(
... [(0, 0), (1, 1), (1.5, 2.25), (2, 4), (2.5, 6.25), (3, 9), (4, 16)]
... )
>>> f3.get_value(2)
4.0
>>> f3.get_value(2.5)
6.25
>>> f3.get_value([1, 2, 3])
[1.0, 4.0, 9.0]
>>> f3.get_value([1, 2.5, 4.0])
[1.0, 6.25, 16.0]
Testing with ndarray source (2 dimensions):
>>> f4 = Function(
... [(0, 0, 0), (1, 1, 1), (1, 2, 2), (2, 4, 8), (3, 9, 27)]
... )
>>> f4.get_value(1, 1)
1.0
>>> f4.get_value(2, 4)
8.0
>>> abs(f4.get_value(1, 1.5) - 1.5) < 1e-2 # the interpolation is not perfect
True
>>> f4.get_value(3, 9)
27.0
"""
if len(args) != self.__dom_dim__:
raise ValueError(
f"This Function takes {self.__dom_dim__} arguments, {len(args)} given."
)

# Return value for Function of function type
if callable(self.source):
if len(args) == 1 and isinstance(args[0], (list, tuple)):
if isinstance(args[0][0], (tuple, list)):
return [self.source(*arg) for arg in args[0]]
else:
return [self.source(arg) for arg in args[0]]
elif len(args) == 1 and isinstance(args[0], np.ndarray):
return self.source(args[0])
# if the function is 1-D:
if self.__dom_dim__ == 1:
# if the args is a simple number (int or float)
if isinstance(args[0], (int, float)):
return self.source(args[0])
# if the arguments are iterable, we map and return a list
if isinstance(args[0], Iterable):
return list(map(self.source, args[0]))

# if the function is n-D:
else:
return self.source(*args)
# if each arg is a simple number (int or float)
if all(isinstance(arg, (int, float)) for arg in args):
return self.source(*args)
# if each arg is iterable, we map and return a list
if all(isinstance(arg, Iterable) for arg in args):
return [self.source(*arg) for arg in zip(*args)]

# Returns value for shepard interpolation
elif self.__interpolation__ == "shepard":
if isinstance(args[0], (list, tuple)):
x = list(args[0])
if all(isinstance(arg, Iterable) for arg in args):
x = list(np.column_stack(args))
else:
x = [[float(x) for x in list(args)]]
ans = x
Expand Down Expand Up @@ -1256,12 +1329,15 @@ def plot_2d(
x = np.linspace(lower[0], upper[0], samples[0])
y = np.linspace(lower[1], upper[1], samples[1])
mesh_x, mesh_y = np.meshgrid(x, y)
mesh_x_flat, mesh_y_flat = mesh_x.flatten(), mesh_y.flatten()
mesh = [[mesh_x_flat[i], mesh_y_flat[i]] for i in range(len(mesh_x_flat))]

# Evaluate function at all mesh nodes and convert it to matrix
z = np.array(self.get_value(mesh)).reshape(mesh_x.shape)
z = np.array(self.get_value(mesh_x.flatten(), mesh_y.flatten())).reshape(
mesh_x.shape
)
z_min, z_max = z.min(), z.max()
color_map = plt.cm.get_cmap(cmap)
norm = plt.Normalize(z_min, z_max)

# Plot function
if disp_type == "surface":
surf = axes.plot_surface(
Expand Down Expand Up @@ -2663,6 +2739,152 @@ def compose(self, func, extrapolate=False):
extrapolation=self.__extrapolation__,
)

@staticmethod
def _check_user_input(
source,
inputs,
outputs,
interpolation,
extrapolation,
):
"""
Validates and processes the user input parameters for creating or
modifying a Function object. This function ensures the inputs, outputs,
interpolation, and extrapolation parameters are compatible with the
given source. It converts the source to a numpy array if necessary, sets
default values and raises warnings or errors for incompatible or
ill-defined parameters.
Parameters
----------
source : list, np.ndarray, or callable
The source data or Function object. If a list or ndarray, it should
contain numeric data. If a Function, its inputs and outputs are
checked against the provided inputs and outputs.
inputs : list of str or None
The names of the input variables. If None, defaults are generated
based on the dimensionality of the source.
outputs : str or list of str
The name(s) of the output variable(s). If a list is provided, it
must have a single element.
interpolation : str or None
The method of interpolation to be used. For multidimensional sources
it defaults to 'shepard' if not provided.
extrapolation : str or None
The method of extrapolation to be used. For multidimensional sources
it defaults to 'natural' if not provided.
Returns
-------
tuple
A tuple containing the processed inputs, outputs, interpolation, and
extrapolation parameters.
Raises
------
ValueError
If the dimensionality of the source does not match the combined
dimensions of inputs and outputs. If the outputs list has more than
one element.
TypeError
If the source is not a list, np.ndarray, or Function object.
Warning
If inputs or outputs do not match for a Function source, or if
defaults are used for inputs, interpolation,and extrapolation for a
multidimensional source.
Examples
--------
>>> from rocketpy import Function
>>> source = np.array([(1, 1), (2, 4), (3, 9)])
>>> inputs = "x"
>>> outputs = ["y"]
>>> interpolation = 'linear'
>>> extrapolation = 'zero'
>>> inputs, outputs, interpolation, extrapolation = Function._check_user_input(
... source, inputs, outputs, interpolation, extrapolation
... )
>>> inputs
['x']
>>> outputs
['y']
>>> interpolation
'linear'
>>> extrapolation
'zero'
"""
# check output type and dimensions
if isinstance(outputs, str):
outputs = [outputs]
if isinstance(inputs, str):
inputs = [inputs]

elif len(outputs) > 1:
raise ValueError(
"Output must either be a string or have dimension 1, "
+ f"it currently has dimension ({len(outputs)})."
)

# check source for data type
# if list or ndarray, check for dimensions, interpolation and extrapolation
if isinstance(source, (list, np.ndarray)):
# this will also trigger an error if the source is not a list of
# numbers or if the array is not homogeneous
source = np.array(source, dtype=np.float64)

# check dimensions
source_dim = source.shape[1]

# check interpolation and extrapolation
if source_dim > 2:
# check for inputs and outputs
if inputs == ["Scalar"]:
inputs = [f"Input {i+1}" for i in range(source_dim - 1)]
warnings.warn(
f"Inputs not set, defaulting to {inputs} for "
+ "multidimensional functions.",
)

if interpolation not in [None, "shepard"]:
interpolation = "shepard"
warnings.warn(
(
"Interpolation method for multidimensional functions is set"
"to 'shepard', currently other methods are not supported."
),
)

if extrapolation is None:
extrapolation = "natural"
warnings.warn(
"Extrapolation not set, defaulting to 'natural' "
+ "for multidimensional functions.",
)

# check input dimensions
in_out_dim = len(inputs) + len(outputs)
if source_dim != in_out_dim:
raise ValueError(
"Source dimension ({source_dim}) does not match input "
+ f"and output dimension ({in_out_dim})."
)

# if function, check for inputs and outputs
if isinstance(source, Function):
# check inputs
if inputs is not None and inputs != source.get_inputs():
warnings.warn(
f"Inputs do not match source inputs, setting inputs to {inputs}.",
)

# check outputs
if outputs is not None and outputs != source.get_outputs():
warnings.warn(
f"Outputs do not match source outputs, setting outputs to {outputs}.",
)

return inputs, outputs, interpolation, extrapolation


class PiecewiseFunction(Function):
"""Class for creating piecewise functions. These kind of functions are
Expand Down
Loading

0 comments on commit 00d1362

Please sign in to comment.