Skip to content

Commit

Permalink
Refactor code to improve performance and
Browse files Browse the repository at this point in the history
readability and...
- Add more doctests in docstring examples
- Fix an error in the get_value method for n-d functions
- remove type hinting
- adjust docstrings
  • Loading branch information
Gui-FernandesBR committed Nov 13, 2023
1 parent f1eb70a commit 8ea0124
Showing 1 changed file with 130 additions and 44 deletions.
174 changes: 130 additions & 44 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings
from inspect import signature
from pathlib import Path
from typing import List, Union

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -21,12 +20,12 @@ class Function:

def __init__(
self,
source: Union[str, List, np.ndarray, float, int, callable],
inputs: Union[str, List] = ["Scalar"],
outputs: Union[str, List] = ["Scalar"],
interpolation: str = None,
extrapolation: str = None,
title: str = None,
source,
inputs=["Scalar"],
outputs=["Scalar"],
interpolation=None,
extrapolation=None,
title=None,
):
"""Convert source into a Function, to be used more naturally.
Set inputs, outputs, domain dimension, interpolation and extrapolation
Expand Down Expand Up @@ -143,7 +142,7 @@ def set_source(self, source):
-------
self : Function
"""
inputs, outputs, interpolation, extrapolation = self._check_user_input(
_, _, _, _ = self._check_user_input(
source,
self.__inputs__,
self.__outputs__,
Expand Down Expand Up @@ -757,6 +756,57 @@ def get_value(self, *args):
Returns
-------
ans : scalar, list
Value of the Function at the specified point(s).
Examples
--------
>>> from rocketpy import Function
Testing with callable source (1 dimension):
>>> f = Function(lambda x: x**2)
>>> f.get_value(2)
4
>>> f.get_value(2.5)
6.25
>>> f.get_value([1, 2, 3])
[1, 4, 9]
>>> f.get_value([1, 2.5, 4.0])
[1, 6.25, 16.0]
Testing with callable source (2 dimensions):
>>> f2 = Function(lambda x, y: x**2 + y**2)
>>> f2.get_value(1, 2)
5
>>> f2.get_value([1, 2, 3], [1, 2, 3])
[2, 8, 18]
>>> f2.get_value([5], [5])
[50]
Testing with ndarray source (1 dimension):
>>> f3 = Function(
... [(0, 0), (1, 1), (1.5, 2.25), (2, 4), (2.5, 6.25), (3, 9), (4, 16)]
... )
>>> f3.get_value(2)
4.0
>>> f3.get_value(2.5)
6.25
>>> f3.get_value([1, 2, 3])
[1.0, 4.0, 9.0]
>>> f3.get_value([1, 2.5, 4.0])
[1.0, 6.25, 16.0]
Testing with ndarray source (2 dimensions):
>>> f4 = Function(
... [(0, 0, 0), (1, 1, 1), (1, 2, 2), (2, 4, 8), (3, 9, 27)]
... )
>>> f4.get_value(1, 1)
1.0
>>> f4.get_value(2, 4)
8.0
>>> abs(f4.get_value(1, 1.5) - 1.5) < 1e-2 # the interpolation is not perfect
True
>>> f4.get_value(3, 9)
27.0
"""
if len(args) != self.__dom_dim__:
raise ValueError(
Expand All @@ -765,15 +815,24 @@ def get_value(self, *args):

# Return value for Function of function type
if callable(self.source):
if len(args) == 1 and isinstance(args[0], (list, tuple)):
if isinstance(args[0][0], (tuple, list)):
return [self.source(*arg) for arg in args[0]]
else:
return [self.source(arg) for arg in args[0]]
elif len(args) == 1 and isinstance(args[0], np.ndarray):
return self.source(args[0])
# if the function is 1-D:
if self.__dom_dim__ == 1:
# if the args is a simple number (int or float)
if isinstance(args[0], (int, float)):
return self.source(args[0])
# if the arguments are iterable, we map and return a list
if isinstance(args[0], (list, tuple, np.ndarray)):
return list(map(self.source, args[0]))

# if the function is n-D:
else:
return self.source(*args)
# if each arg is a simple number (int or float)
if all(isinstance(arg, (int, float)) for arg in args):
return self.source(*args)
# if each arg is iterable, we map and return a list
if all(isinstance(arg, (list, tuple, np.ndarray)) for arg in args):
return [self.source(*arg) for arg in zip(*args)]

# Returns value for shepard interpolation
elif self.__interpolation__ == "shepard":
if isinstance(args[0], (list, tuple)):
Expand Down Expand Up @@ -1230,8 +1289,8 @@ def plot2D(
x = np.linspace(lower[0], upper[0], samples[0])
y = np.linspace(lower[1], upper[1], samples[1])
mesh_x, mesh_y = np.meshgrid(x, y)
mesh_x_flat, mesh_y_flat = mesh_x.flatten(), mesh_y.flatten()
mesh = [[mesh_x_flat[i], mesh_y_flat[i]] for i in range(len(mesh_x_flat))]
mesh = np.column_stack((mesh_x.flatten(), mesh_y.flatten()))

# Evaluate function at all mesh nodes and convert it to matrix
z = np.array(self.get_value(mesh[:, 0], mesh[:, 1])).reshape(mesh_x.shape)
# Plot function
Expand Down Expand Up @@ -2634,30 +2693,31 @@ def _check_user_input(
extrapolation,
):
"""
Validates and processes the user input parameters for creating or modifying a
Function object. This function ensures the inputs, outputs, interpolation,
and extrapolation parameters are compatible with the given source. It converts
the source to a numpy array if necessary, sets default values and raises
warnings or errors for incompatible or ill-defined parameters.
Validates and processes the user input parameters for creating or
modifying a Function object. This function ensures the inputs, outputs,
interpolation, and extrapolation parameters are compatible with the
given source. It converts the source to a numpy array if necessary, sets
default values and raises warnings or errors for incompatible or
ill-defined parameters.
Parameters
----------
source : list, np.ndarray, or Function
The source data or Function object. If a list or ndarray, it should contain
numeric data. If a Function, its inputs and outputs are checked against the
provided inputs and outputs.
The source data or Function object. If a list or ndarray, it should
contain numeric data. If a Function, its inputs and outputs are
checked against the provided inputs and outputs.
inputs : list of str or None
The names of the input variables. If None, defaults are generated based on
the dimensionality of the source.
The names of the input variables. If None, defaults are generated
based on the dimensionality of the source.
outputs : str or list of str
The name(s) of the output variable(s). If a list is provided, it must have
a single element.
The name(s) of the output variable(s). If a list is provided, it
must have a single element.
interpolation : str or None
The method of interpolation to be used. For multidimensional sources,
defaults to 'shepard' if not provided.
The method of interpolation to be used. For multidimensional sources
it defaults to 'shepard' if not provided.
extrapolation : str or None
The method of extrapolation to be used. For multidimensional sources,
defaults to 'natural' if not provided.
The method of extrapolation to be used. For multidimensional sources
it defaults to 'natural' if not provided.
Returns
-------
Expand All @@ -2668,14 +2728,35 @@ def _check_user_input(
Raises
------
ValueError
If the dimensionality of the source does not match the combined dimensions
of inputs and outputs. If the outputs list has more than one element.
If the dimensionality of the source does not match the combined
dimensions of inputs and outputs. If the outputs list has more than
one element.
TypeError
If the source is not a list, np.ndarray, or Function object.
Warning
If inputs or outputs do not match for a Function source, or if defaults are
used for inputs, interpolation,and extrapolation for a multidimensional
source.
If inputs or outputs do not match for a Function source, or if
defaults are used for inputs, interpolation,and extrapolation for a
multidimensional source.
Examples
--------
>>> from rocketpy import Function
>>> source = np.array([(1, 1), (2, 4), (3, 9)])
>>> inputs = "x"
>>> outputs = ["y"]
>>> interpolation = 'linear'
>>> extrapolation = 'zero'
>>> inputs, outputs, interpolation, extrapolation = Function._check_user_input(
... source, inputs, outputs, interpolation, extrapolation
... )
>>> inputs
['x']
>>> outputs
['y']
>>> interpolation
'linear'
>>> extrapolation
'zero'
"""
# check output type and dimensions
if isinstance(outputs, str):
Expand All @@ -2685,13 +2766,15 @@ def _check_user_input(

elif len(outputs) > 1:
raise ValueError(
f"Output must either be a string or have dimension 1, it currently has dimension ({len(outputs)})."
"Output must either be a string or have dimension 1, "
+ f"it currently has dimension ({len(outputs)})."
)

# check source for data type
# if list or ndarray, check for dimensions, interpolation and extrapolation
if isinstance(source, (list, np.ndarray)):
# this will also trigger an error if the source is not a list of numbers or if the array is not homogeneous
# this will also trigger an error if the source is not a list of
# numbers or if the array is not homogeneous
source = np.array(source, dtype=np.float64)

# check dimensions
Expand All @@ -2703,7 +2786,8 @@ def _check_user_input(
if inputs == ["Scalar"]:
inputs = [f"Input {i+1}" for i in range(source_dim - 1)]
warnings.warn(
f"Inputs not set, defaulting to {inputs} for multidimensional functions.",
f"Inputs not set, defaulting to {inputs} for "
+ "multidimensional functions.",
)

if interpolation not in [None, "shepard"]:
Expand All @@ -2718,14 +2802,16 @@ def _check_user_input(
if extrapolation is None:
extrapolation = "natural"
warnings.warn(
"Extrapolation not set, defaulting to 'natural' for multidimensional functions.",
"Extrapolation not set, defaulting to 'natural' "
+ "for multidimensional functions.",
)

# check input dimensions
in_out_dim = len(inputs) + len(outputs)
if source_dim != in_out_dim:
raise ValueError(
f"Source dimension ({source_dim}) does not match input and output dimension ({in_out_dim})."
"Source dimension ({source_dim}) does not match input "
+ f"and output dimension ({in_out_dim})."
)

# if function, check for inputs and outputs
Expand Down

0 comments on commit 8ea0124

Please sign in to comment.