-
-
Notifications
You must be signed in to change notification settings - Fork 168
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MNT: move piecewise functions to separate file (#746)
* MNT: move piecewise functions to separate file closes #667 * improved import for linting * MNT: applying code formaters * ENH: simplifying and optimizing the function, implementing tests. * MNT: update changelog and apply changes suggested in review --------- Co-authored-by: Lucas Prates <[email protected]> Co-authored-by: Lucas de Oliveira Prates <[email protected]> Co-authored-by: Gui-FernandesBR <[email protected]>
- Loading branch information
1 parent
7a122ad
commit 2218f0f
Showing
6 changed files
with
134 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,3 @@ | ||
from .function import ( | ||
Function, | ||
PiecewiseFunction, | ||
funcify_method, | ||
reset_funcified_methods, | ||
) | ||
from .function import Function, funcify_method, reset_funcified_methods | ||
from .piecewise_function import PiecewiseFunction | ||
from .vector_matrix import Matrix, Vector |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import numpy as np | ||
|
||
from rocketpy.mathutils.function import Function | ||
|
||
|
||
class PiecewiseFunction(Function): | ||
"""Class for creating piecewise functions. These kind of functions are | ||
defined by a dictionary of functions, where the keys are tuples that | ||
represent the domain of the function. The domains must be disjoint. | ||
""" | ||
|
||
def __new__( | ||
cls, | ||
source, | ||
inputs=None, | ||
outputs=None, | ||
interpolation="spline", | ||
extrapolation=None, | ||
datapoints=100, | ||
): | ||
""" | ||
Creates a piecewise function from a dictionary of functions. The keys of | ||
the dictionary must be tuples that represent the domain of the function. | ||
The domains must be disjoint. The piecewise function will be evaluated | ||
at datapoints points to create Function object. | ||
Parameters | ||
---------- | ||
source: dictionary | ||
A dictionary of Function objects, where the keys are the domains. | ||
inputs : list of strings | ||
A list of strings that represent the inputs of the function. | ||
outputs: list of strings | ||
A list of strings that represent the outputs of the function. | ||
interpolation: str | ||
The type of interpolation to use. The default value is 'spline'. | ||
extrapolation: str | ||
The type of extrapolation to use. The default value is None. | ||
datapoints: int | ||
The number of points in which the piecewise function will be | ||
evaluated to create a base function. The default value is 100. | ||
""" | ||
cls.__validate__source(source) | ||
if inputs is None: | ||
inputs = ["Scalar"] | ||
if outputs is None: | ||
outputs = ["Scalar"] | ||
|
||
input_data = np.array([]) | ||
output_data = np.array([]) | ||
for lower, upper in sorted(source.keys()): | ||
grid = np.linspace(lower, upper, datapoints) | ||
|
||
# since intervals are disjoint and sorted, we only need to check | ||
# if the first point is already included | ||
if input_data.size != 0: | ||
if lower == input_data[-1]: | ||
grid = np.delete(grid, 0) | ||
input_data = np.concatenate((input_data, grid)) | ||
|
||
f = Function(source[(lower, upper)]) | ||
output_data = np.concatenate((output_data, f.get_value(grid))) | ||
|
||
return Function( | ||
np.concatenate(([input_data], [output_data])).T, | ||
inputs=inputs, | ||
outputs=outputs, | ||
interpolation=interpolation, | ||
extrapolation=extrapolation, | ||
) | ||
|
||
@staticmethod | ||
def __validate__source(source): | ||
"""Validates that source is dictionary with non-overlapping | ||
intervals | ||
Parameters | ||
---------- | ||
source : dict | ||
A dictionary of Function objects, where the keys are the domains. | ||
""" | ||
# Check if source is a dictionary | ||
if not isinstance(source, dict): | ||
raise TypeError("source must be a dictionary") | ||
# Check if all keys are tuples | ||
for key in source.keys(): | ||
if not isinstance(key, tuple): | ||
raise TypeError("keys of source must be tuples") | ||
# Check if all domains are disjoint | ||
for lower1, upper1 in source.keys(): | ||
for lower2, upper2 in source.keys(): | ||
if (lower1, upper1) != (lower2, upper2): | ||
if lower1 < upper2 and upper1 > lower2: | ||
raise ValueError("domains must be disjoint") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import pytest | ||
|
||
from rocketpy import PiecewiseFunction | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"source", | ||
[ | ||
((0, 4), lambda x: x), | ||
{"0-4": lambda x: x}, | ||
{(0, 4): lambda x: x, (3, 5): lambda x: 2 * x}, | ||
], | ||
) | ||
def test_invalid_source(source): | ||
"""Test an error is raised when the source parameter is invalid""" | ||
with pytest.raises((TypeError, ValueError)): | ||
PiecewiseFunction(source) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"source", | ||
[ | ||
{(-1, 0): lambda x: -x, (0, 1): lambda x: x}, | ||
{ | ||
(0, 1): lambda x: x, | ||
(1, 2): lambda x: 1, | ||
(2, 3): lambda x: 3 - x, | ||
}, | ||
], | ||
) | ||
@pytest.mark.parametrize("inputs", [None, "X"]) | ||
@pytest.mark.parametrize("outputs", [None, "Y"]) | ||
def test_new(source, inputs, outputs): | ||
"""Test if PiecewiseFunction.__new__ runs correctly""" | ||
PiecewiseFunction(source, inputs, outputs) |