Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT: move piecewise functions to separate file #746

1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Attention: The newest changes should be on top -->

### Changed

- MNT: move piecewise functions to separate file [#746](https://github.com/RocketPy-Team/RocketPy/pull/746)
- DOC: flight comparison improvements [#755](https://github.com/RocketPy-Team/RocketPy/pull/755)

### Fixed
Expand Down
8 changes: 2 additions & 6 deletions rocketpy/mathutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from .function import (
Function,
PiecewiseFunction,
funcify_method,
reset_funcified_methods,
)
from .function import Function, funcify_method, reset_funcified_methods
from .piecewise_function import PiecewiseFunction
from .vector_matrix import Matrix, Vector
103 changes: 0 additions & 103 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3419,109 +3419,6 @@ def __validate_extrapolation(self, extrapolation):
return extrapolation


class PiecewiseFunction(Function):
"""Class for creating piecewise functions. These kind of functions are
defined by a dictionary of functions, where the keys are tuples that
represent the domain of the function. The domains must be disjoint.
"""

def __new__(
cls,
source,
inputs=None,
outputs=None,
interpolation="spline",
extrapolation=None,
datapoints=100,
):
"""
Creates a piecewise function from a dictionary of functions. The keys of
the dictionary must be tuples that represent the domain of the function.
The domains must be disjoint. The piecewise function will be evaluated
at datapoints points to create Function object.

Parameters
----------
source: dictionary
A dictionary of Function objects, where the keys are the domains.
inputs : list of strings
A list of strings that represent the inputs of the function.
outputs: list of strings
A list of strings that represent the outputs of the function.
interpolation: str
The type of interpolation to use. The default value is 'spline'.
extrapolation: str
The type of extrapolation to use. The default value is None.
datapoints: int
The number of points in which the piecewise function will be
evaluated to create a base function. The default value is 100.
"""
if inputs is None:
inputs = ["Scalar"]
if outputs is None:
outputs = ["Scalar"]
# Check if source is a dictionary
if not isinstance(source, dict):
raise TypeError("source must be a dictionary")
# Check if all keys are tuples
for key in source.keys():
if not isinstance(key, tuple):
raise TypeError("keys of source must be tuples")
# Check if all domains are disjoint
for key1 in source.keys():
for key2 in source.keys():
if key1 != key2:
if key1[0] < key2[1] and key1[1] > key2[0]:
raise ValueError("domains must be disjoint")

# Crate Function
def calc_output(func, inputs):
"""Receives a list of inputs value and a function, populates another
list with the results corresponding to the same results.

Parameters
----------
func : Function
The Function object to be
inputs : list, tuple, np.array
The array of points to applied the func to.

Examples
--------
>>> inputs = [0, 1, 2, 3, 4, 5]
>>> def func(x):
... return x*10
>>> calc_output(func, inputs)
[0, 10, 20, 30, 40, 50]

Notes
-----
In the future, consider using the built-in map function from python.
"""
output = np.zeros(len(inputs))
for j, value in enumerate(inputs):
output[j] = func.get_value_opt(value)
return output

input_data = []
output_data = []
for key in sorted(source.keys()):
i = np.linspace(key[0], key[1], datapoints)
i = i[~np.isin(i, input_data)]
input_data = np.concatenate((input_data, i))

f = Function(source[key])
output_data = np.concatenate((output_data, calc_output(f, i)))

return Function(
np.concatenate(([input_data], [output_data])).T,
inputs=inputs,
outputs=outputs,
interpolation=interpolation,
extrapolation=extrapolation,
)


def funcify_method(*args, **kwargs): # pylint: disable=too-many-statements
"""Decorator factory to wrap methods as Function objects and save them as
cached properties.
Expand Down
94 changes: 94 additions & 0 deletions rocketpy/mathutils/piecewise_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np

from rocketpy.mathutils.function import Function


class PiecewiseFunction(Function):
"""Class for creating piecewise functions. These kind of functions are
defined by a dictionary of functions, where the keys are tuples that
represent the domain of the function. The domains must be disjoint.
"""

def __new__(
cls,
source,
inputs=None,
outputs=None,
interpolation="spline",
extrapolation=None,
datapoints=100,
):
"""
Creates a piecewise function from a dictionary of functions. The keys of
the dictionary must be tuples that represent the domain of the function.
The domains must be disjoint. The piecewise function will be evaluated
at datapoints points to create Function object.

Parameters
----------
source: dictionary
A dictionary of Function objects, where the keys are the domains.
inputs : list of strings
A list of strings that represent the inputs of the function.
outputs: list of strings
A list of strings that represent the outputs of the function.
interpolation: str
The type of interpolation to use. The default value is 'spline'.
extrapolation: str
The type of extrapolation to use. The default value is None.
datapoints: int
The number of points in which the piecewise function will be
evaluated to create a base function. The default value is 100.
"""
cls.__validate__source(source)
if inputs is None:
inputs = ["Scalar"]
if outputs is None:
outputs = ["Scalar"]

input_data = np.array([])
output_data = np.array([])
for lower, upper in sorted(source.keys()):
grid = np.linspace(lower, upper, datapoints)

# since intervals are disjoint and sorted, we only need to check
# if the first point is already included
if input_data.size != 0:
if lower == input_data[-1]:
grid = np.delete(grid, 0)
input_data = np.concatenate((input_data, grid))

f = Function(source[(lower, upper)])
output_data = np.concatenate((output_data, f.get_value(grid)))

return Function(
np.concatenate(([input_data], [output_data])).T,
inputs=inputs,
outputs=outputs,
interpolation=interpolation,
extrapolation=extrapolation,
)

@staticmethod
def __validate__source(source):
"""Validates that source is dictionary with non-overlapping
intervals

Parameters
----------
source : dict
A dictionary of Function objects, where the keys are the domains.
"""
# Check if source is a dictionary
if not isinstance(source, dict):
raise TypeError("source must be a dictionary")
# Check if all keys are tuples
for key in source.keys():
if not isinstance(key, tuple):
raise TypeError("keys of source must be tuples")
# Check if all domains are disjoint
for lower1, upper1 in source.keys():
for lower2, upper2 in source.keys():
if (lower1, upper1) != (lower2, upper2):
if lower1 < upper2 and upper1 > lower2:
raise ValueError("domains must be disjoint")
3 changes: 2 additions & 1 deletion rocketpy/motors/tank_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import numpy as np

from ..mathutils.function import Function, PiecewiseFunction, funcify_method
from ..mathutils.function import Function, funcify_method
from ..mathutils.piecewise_function import PiecewiseFunction
from ..plots.tank_geometry_plots import _TankGeometryPlots
from ..prints.tank_geometry_prints import _TankGeometryPrints

Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_piecewise_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from rocketpy import PiecewiseFunction


@pytest.mark.parametrize(
"source",
[
((0, 4), lambda x: x),
{"0-4": lambda x: x},
{(0, 4): lambda x: x, (3, 5): lambda x: 2 * x},
],
)
def test_invalid_source(source):
"""Test an error is raised when the source parameter is invalid"""
with pytest.raises((TypeError, ValueError)):
PiecewiseFunction(source)


@pytest.mark.parametrize(
"source",
[
{(-1, 0): lambda x: -x, (0, 1): lambda x: x},
{
(0, 1): lambda x: x,
(1, 2): lambda x: 1,
(2, 3): lambda x: 3 - x,
},
],
)
@pytest.mark.parametrize("inputs", [None, "X"])
@pytest.mark.parametrize("outputs", [None, "Y"])
def test_new(source, inputs, outputs):
"""Test if PiecewiseFunction.__new__ runs correctly"""
PiecewiseFunction(source, inputs, outputs)