diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index fbacd6846..c4c079414 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -2988,8 +2988,8 @@ def _check_user_input( Returns ------- tuple - A tuple containing the processed inputs, outputs, interpolation, and - extrapolation parameters. + A tuple containing the processed source, inputs, outputs, + interpolation, and extrapolation parameters. Raises ------ @@ -2997,13 +2997,6 @@ def _check_user_input( If the dimensionality of the source does not match the combined dimensions of inputs and outputs. If the outputs list has more than one element. - TypeError - If the source is not a list, np.ndarray, Function object, str or - Path. - Warning - If inputs or outputs do not match for a Function source, or if - defaults are used for inputs, interpolation,and extrapolation for a - multidimensional source. Examples -------- @@ -3072,7 +3065,27 @@ def source_function(_): return source, inputs, outputs, interpolation, extrapolation @staticmethod - def _validate_inputs_outputs(inputs, outputs): # None | st | list[str] + def _validate_inputs_outputs(inputs, outputs): + """Used to validate the inputs and outputs parameters for creating a + Function object. It sets default values if they are not provided. + + Parameters + ---------- + inputs : str, list of str, None + The name(s) of the input variable(s). If None, defaults to "Scalar". + outputs : + The name of the output variables. If None, defaults to "Scalar". + + Returns + ------- + tuple + A tuple containing the validated inputs and outputs parameters. + + Raises + ------ + ValueError + If the output has more than one element. + """ if inputs is None: inputs = ["Scalar"] if outputs is None: @@ -3091,10 +3104,41 @@ def _validate_inputs_outputs(inputs, outputs): # None | st | list[str] @staticmethod def _validate_interpolation_and_extrapolation( - inputs: list[str], interpolation: str, extrapolation: str, source: np.ndarray + inputs, interpolation, extrapolation, source ): + """Used to validate the interpolation and extrapolation methods for + creating a Function object. It sets default values for interpolation + and extrapolation if they are not provided or if they are not supported + for the given source. The inputs and outputs may be modified if the + source is multidimensional. + + Parameters + ---------- + inputs : list of strings + List of inputs, each input is a string. Example: ['x', 'y'] + interpolation : str, None + The type of interpolation to use. The default method is 'spline'. + Currently supported values are 'spline', 'linear', 'polynomial', + 'akima', and 'shepard'. + extrapolation : str, None + The type of extrapolation to use. Currently supported values are + 'constant', 'natural', and 'zero'. The default method is 'constant'. + source : np.ndarray + The source data of the Function object. This has to be a numpy + array. + + Returns + ------- + tuple + A tuple with the validated inputs, interpolation, and extrapolation + parameters (inputs, interpolation, extrapolation). + + Raises + ------ + ValueError + If the source has less than 2 dimensions. + """ source_dim = source.shape[1] - # check interpolation and extrapolation ## single dimension (1D Functions) if source_dim == 2: # possible interpolation values: linear, polynomial, akima and spline @@ -3124,7 +3168,6 @@ def _validate_interpolation_and_extrapolation( ## multiple dimensions elif source_dim > 2: - # check for inputs and outputs if inputs == ["Scalar"]: inputs = [f"Input {i+1}" for i in range(source_dim - 1)] @@ -3148,8 +3191,25 @@ def _validate_interpolation_and_extrapolation( return inputs, interpolation, extrapolation @staticmethod - def _validate_source_dimensions(inputs: list, outputs: list, source: np.ndarray): - # check input dimensions + def _validate_source_dimensions(inputs, outputs, source): + """Used to check whether the source dimensions match the inputs and + outputs. + + Parameters + ---------- + inputs : list of strings + List of inputs, each input is a string. Example: ['x', 'y'] + outputs : list of strings + List of outputs, each output is a string. Example: ['z'] + source : np.ndarray + The source data of the Function object. This has to be a numpy + array. + + Raises + ------ + ValueError + In case the source dimensions do not match the inputs and outputs. + """ source_dim = source.shape[1] in_out_dim = len(inputs) + len(outputs) if source_dim != in_out_dim: